In [None]:
# from google.colab import drive
# drive.mount('/content/drive')
# !pwd

In [None]:
!pip install torchsummary

In [None]:
# # prompt: take me to the RA Data folder

# %cd /content/drive/MyDrive/RA Data

In [1]:
import torch
import torch.nn as nn
import torchvision.models as models
from torchsummary import summary
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset, random_split
from torch.cuda.amp import GradScaler, autocast
from torchvision.utils import save_image
from PIL import Image
import os
from PIL import Image
import numpy as np
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm


In [None]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(in_channels)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += residual
        out = self.relu(out)
        return out

In [2]:
# class DualTaskResNet(nn.Module):
#     def __init__(self, num_classes=1, pretrained=True):
#         super(DualTaskResNet, self).__init__()
#         # Load a pretrained ResNet
#         self.resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)

#         # Remove the last fully connected layer
#         self.features = nn.Sequential(*list(self.resnet.children())[:-2])

#         # Upsampling layers for mask generation
#         self.upsample1 = self._make_upsample_block(2048, 1024)
#         self.upsample2 = self._make_upsample_block(1024, 512)
#         self.upsample3 = self._make_upsample_block(512, 256)
#         self.upsample4 = self._make_upsample_block(256, 64)
#         self.upsample5 = self._make_upsample_block(64, num_classes, final=True)

#         # Enhanced skeletonization layers
#         self.skeleton_decoder = nn.Sequential(
#             nn.Conv2d(num_classes, 32, kernel_size=3, padding=1),
#             nn.BatchNorm2d(32),
#             nn.ReLU(inplace=True),
#             ResidualBlock(32),
#             nn.Conv2d(32, 16, kernel_size=3, padding=1),
#             nn.BatchNorm2d(16),
#             nn.ReLU(inplace=True),
#             ResidualBlock(16),
#             nn.Conv2d(16, num_classes, kernel_size=1)
#         )

#     def _make_upsample_block(self, in_channels, out_channels, final=False):
#         layers = []
#         for i in range(3):  # Stack 3 ConvTranspose2d layers
#             if i == 0:
#                 layers.append(nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2))
#             else:
#                 layers.append(nn.ConvTranspose2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1))
#             if not final:
#                 layers.append(nn.BatchNorm2d(out_channels))
#                 layers.append(nn.ReLU(inplace=True))
#         return nn.Sequential(*layers)

#     def forward(self, x, task='mask'):
#         x = self.features(x)
#         x = self.upsample1(x)
#         x = self.upsample2(x)
#         x = self.upsample3(x)
#         x = self.upsample4(x)
#         mask = self.upsample5(x)

#         if task == 'mask':
#             return torch.sigmoid(mask)
#         elif task == 'skeleton':
#             skeleton = self.skeleton_decoder(mask)
#             return torch.sigmoid(skeleton)

#     def freeze_encoder(self):
#         for param in self.features.parameters():
#             param.requires_grad = False

#     def unfreeze_encoder(self):
#         for param in self.features.parameters():
#             param.requires_grad = True

#     def freeze_decoder(self):
#         for layer in [self.upsample1, self.upsample2, self.upsample3, self.upsample4, self.upsample5]:
#             for param in layer.parameters():
#                 param.requires_grad = False

#     def unfreeze_decoder(self):
#         for layer in [self.upsample1, self.upsample2, self.upsample3, self.upsample4, self.upsample5]:
#             for param in layer.parameters():
#                 param.requires_grad = True

#     def finetune_for_skeleton(self):
#         # Freeze the encoder and decoder
#         self.freeze_encoder()
#         self.freeze_decoder()

#         # Unfreeze only the skeleton decoder
#         for param in self.skeleton_decoder.parameters():
#             param.requires_grad = True

#     def prepare_for_full_training(self):
#         # Unfreeze all layers for full training
#         self.unfreeze_encoder()
#         self.unfreeze_decoder()
#         for param in self.skeleton_decoder.parameters():
#             param.requires_grad = True

In [3]:
class DualTaskResNet(nn.Module):
    def __init__(self, num_classes=1, pretrained=True):
        super(DualTaskResNet, self).__init__()
        
        resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None)
        self.resnet18 = nn.Sequential(*list(resnet.children())[:-2])
        self.layer0 = nn.Sequential(*list(self.resnet18.children())[:3])
        self.layer1 = nn.Sequential(*list(self.resnet18.children())[3:5])
        self.layer2 = self.resnet18[5]
        self.layer3 = self.resnet18[6]
        self.layer4 = self.resnet18[7]

        # Dilated convolutions for layer2
        self.dilation_conv1 = self._make_dilated_conv(128, 256, 2)
        self.dilation_conv2 = self._make_dilated_conv(128, 256, 4)
        self.dilation_conv3 = self._make_dilated_conv(128, 256, 8)
        self.dilation_conv4 = self._make_dilated_conv(128, 256, 16)
        self.dilation_conv5 = self._make_dilated_conv(128, 256, 32)

        # Dilated convolutions for layer3
        self.dilation_conv6 = self._make_dilated_conv(256, 512, 2)
        self.dilation_conv7 = self._make_dilated_conv(256, 512, 4)
        self.dilation_conv8 = self._make_dilated_conv(256, 512, 8)
        self.dilation_conv9 = self._make_dilated_conv(256, 512, 16)
        self.dilation_conv10 = self._make_dilated_conv(256, 512, 32)

        # Upsampling path
        self.upsample1 = self._make_transpose_conv(512, 512, 2)  # 7x7 -> 14x14
        self.upsample2 = self._make_transpose_conv(3072, 512, 2)  # 14x14 -> 28x28
        self.upsample3 = self._make_transpose_conv(1792, 256, 2)  # 28x28 -> 56x56
        self.upsample4 = self._make_transpose_conv(256, 128, 2)  # 56x56 -> 112x112
        self.upsample5 = self._make_transpose_conv(128, 64, 2)  # 112x112 -> 224x224
        self.convf = nn.Conv2d(64, num_classes, kernel_size=1)

        # Task-specific output layers
        self.mask_output = nn.Conv2d(num_classes, num_classes, kernel_size=1)
        self.skeleton_output = nn.Conv2d(num_classes, num_classes, kernel_size=1)

    def _make_dilated_conv(self, in_channels, out_channels, dilation):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=dilation, dilation=dilation),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def _make_transpose_conv(self, in_channels, out_channels, scale_factor):
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=scale_factor, padding=0, output_padding=0),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, img, task='mask'):
        # Expected input size: 224x224x3
        layer0 = self.layer0(img)  # 112x112x64
        layer1 = self.layer1(layer0)  # 56x56x64
        layer2 = self.layer2(layer1)  # 28x28x128
        layer3 = self.layer3(layer2)  # 14x14x256
        layer4 = self.layer4(layer3)  # 7x7x512

        # Apply dilation to layer2 (28x28x128)
        y1 = self.dilation_conv1(layer2)
        y2 = self.dilation_conv2(layer2)
        y3 = self.dilation_conv3(layer2)
        y4 = self.dilation_conv4(layer2)
        y5 = self.dilation_conv5(layer2)
        y = torch.cat([y1, y2, y3, y4, y5], dim=1)  # 28x28x1280

        # Apply dilation to layer3 (14x14x256)
        z1 = self.dilation_conv6(layer3)
        z2 = self.dilation_conv7(layer3)
        z3 = self.dilation_conv8(layer3)
        z4 = self.dilation_conv9(layer3)
        z5 = self.dilation_conv10(layer3)
        z = torch.cat([z1, z2, z3, z4, z5], dim=1)  # 14x14x2560

        # Upsampling path
        x = self.upsample1(layer4)  # 14x14x512
        x = torch.cat([x, z], dim=1)  # 14x14x3072
        x = self.upsample2(x)  # 28x28x512
        x = torch.cat([x, y], dim=1)  # 28x28x1792
        x = self.upsample3(x)  # 56x56x256
        x = self.upsample4(x)  # 112x112x128
        x = self.upsample5(x)  # 224x224x64
        x = self.convf(x)  # 224x224xnum_classes

        if task == 'mask':
            output = self.mask_output(x)
        elif task == 'skeleton':
            output = self.skeleton_output(x)
        else:
            raise ValueError("Task must be either 'mask' or 'skeleton'")

        return torch.sigmoid(output)



In [4]:
# Test function
def test_dualtask_resnet():
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Create model instance
    model = DualTaskResNet(num_classes=1, pretrained=False).to(device)

    # Print model summary
    summary(model, (3, 224, 224))

    # Generate random input
    batch_size = 4
    input_tensor = torch.randn(batch_size, 3, 224, 224).to(device)

    # Test mask generation
    print("\nTesting mask generation:")
    mask_output = model(input_tensor, task='mask')
    print(f"Mask output shape: {mask_output.shape}")
    print(f"Mask output min: {mask_output.min().item():.4f}, max: {mask_output.max().item():.4f}")

    # Test skeletonization
    print("\nTesting skeletonization:")
    skeleton_output = model(input_tensor, task='skeleton')
    print(f"Skeleton output shape: {skeleton_output.shape}")
    print(f"Skeleton output min: {skeleton_output.min().item():.4f}, max: {skeleton_output.max().item():.4f}")

    # Test if outputs are different
    print("\nChecking if mask and skeleton outputs are different:")
    is_different = not torch.allclose(mask_output, skeleton_output)
    print(f"Outputs are different: {is_different}")



In [5]:
# test_dualtask_resnet()

In [6]:
class DualTaskDataset(torch.utils.data.Dataset):
    def __init__(self, image_dir, mask_dir, skeleton_dir, image_transform=None, mask_transform=None):
        self.image_dir = os.path.normpath(image_dir)
        self.mask_dir = os.path.normpath(mask_dir)
        self.skeleton_dir = os.path.normpath(skeleton_dir)
        self.image_transform = image_transform
        self.mask_transform = mask_transform
        self.image_files = [f for f in os.listdir(self.image_dir) if f.endswith('.jpeg') or f.endswith('.jpg') or f.endswith('.png')]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = self.image_files[idx]
        img_path = os.path.normpath(os.path.join(self.image_dir, img_name))

        # Construct the mask and skeleton paths
        base_name = os.path.splitext(img_name)[0]
        mask_name = base_name + '.png'
        skeleton_name = base_name + '.png'
        mask_path = os.path.normpath(os.path.join(self.mask_dir, mask_name))
        skeleton_path = os.path.normpath(os.path.join(self.skeleton_dir, skeleton_name))

        # Load the images
        image = Image.open(img_path).convert('RGB')
        mask = Image.open(mask_path).convert('L')
        skeleton = Image.open(skeleton_path).convert('L')

        # Apply transforms
        if self.image_transform:
            image = self.image_transform(image)
        if self.mask_transform:
            mask = self.mask_transform(mask)
            skeleton = self.mask_transform(skeleton)

        return image, mask, skeleton, img_name  # Return the image name as well


def get_data_loaders(image_dir, mask_dir, skeleton_dir, batch_size=32, train_split=0.8, val_split=0.1, test_split=0.1):
    # Define transforms for input images
    image_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5514, 0.4094, 0.3140], std=[0.1299, 0.1085, 0.0914])
    ])

    # Define transforms for masks and skeletons
    mask_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])

    # Create dataset
    full_dataset = DualTaskDataset(image_dir, mask_dir, skeleton_dir,
                                   image_transform=image_transform,
                                   mask_transform=mask_transform)

    # Calculate split sizes
    total_size = len(full_dataset)
    train_size = int(train_split * total_size)
    val_size = int(val_split * total_size)
    test_size = total_size - train_size - val_size

    # Split the dataset
    train_dataset, val_dataset, test_dataset = random_split(full_dataset, [train_size, val_size, test_size])

    # Optimize DataLoader
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
                              # num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
                            # num_workers=1, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
                             # num_workers=1, pin_memory=True)

    return train_loader, val_loader, test_loader


In [7]:
def linear_loss_decay(epoch, total_epochs, start_weight=0.5, end_weight=1.0):
    return start_weight + (end_weight - start_weight) * (epoch / total_epochs)

def calculate_iou(pred, target, threshold=0.5):
    pred = (pred > threshold).float()
    intersection = (pred * target).sum()
    union = pred.sum() + target.sum() - intersection
    return (intersection + 1e-6) / (union + 1e-6)  # Adding small epsilon to avoid division by zero

def calculate_pos_weight(train_loader):
    total_pixels = 0
    skeleton_pixels = 0
    for _, _, skeletons, _ in tqdm(train_loader, desc="Calculating pos_weight"):
        total_pixels += skeletons.numel()
        skeleton_pixels += skeletons.sum().item()
    neg_pos_ratio = (total_pixels - skeleton_pixels) / skeleton_pixels
    return torch.tensor([neg_pos_ratio])


In [8]:
def validate(model, val_loader, criterion, device, task):
    model.eval()
    val_loss = 0.0
    iou_sum = 0.0

    with torch.no_grad():
        for images, masks, skeletons, _ in val_loader:
            images = images.to(device)
            targets = masks.to(device) if task == 'mask' else skeletons.to(device)

            outputs = model(images, task=task)
            loss = criterion(outputs, targets)

            val_loss += loss.item() * images.size(0)
            iou_sum += calculate_iou(outputs, targets)

    val_loss /= len(val_loader.dataset)
    val_iou = iou_sum / len(val_loader)

    return val_loss, val_iou

In [9]:
class WeightedDiceLoss(nn.Module):
    def __init__(self, smooth=1., beta=0.5):
        super(WeightedDiceLoss, self).__init__()
        self.smooth = smooth
        self.beta = beta  # Weight for the foreground class

    def forward(self, pred, target):
        pred = torch.sigmoid(pred)
        pred = pred.view(-1)
        target = target.view(-1)

        # Calculate weights
        w_fg = self.beta
        w_bg = 1 - self.beta

        # Weighted intersection and union
        intersection = (pred * target * w_fg).sum()
        union = (pred * w_fg).sum() + (target * w_fg).sum()

        # Add background contribution
        intersection += ((1 - pred) * (1 - target) * w_bg).sum()
        union += ((1 - pred) * w_bg).sum() + ((1 - target) * w_bg).sum()

        dice = (2. * intersection + self.smooth) / (union + self.smooth)
        return 1 - dice

In [11]:
# Check
criterion = WeightedDiceLoss(beta=0.7)  # Adjust beta based on your dataset's imbalance
pred = torch.randn(1, 1, 224, 224)  # Example prediction
target = torch.randint(0, 2, (1, 1, 224, 224))  # Example target mask
loss = criterion(pred, target)
print(f"Loss: {loss.item()}")

Loss: 0.4989425539970398


In [12]:
def train_mask_model(model, train_loader, val_loader, num_epochs, device, beta=0.7, validate_every=10):
    best_mask_iou = 0.0
    best_mask_model_weights = None
    mask_train_losses, mask_val_losses, mask_ious, mask_lrs = [], [], [], []

    criterion_mask = WeightedDiceLoss(beta=beta)
    optimizer_mask = torch.optim.Adam(model.parameters(), lr=0.001)

    constant_lr_epochs = num_epochs // 2
    cosine_annealing_epochs = num_epochs - constant_lr_epochs

    scheduler_mask = torch.optim.lr_scheduler.LambdaLR(optimizer_mask, lambda epoch: 1)  # Constant LR

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        
        pbar = tqdm(train_loader, desc=f"Mask Training Epoch {epoch+1}/{num_epochs}")
        for images, masks, _, _ in pbar:
            images, masks = images.to(device), masks.to(device)

            optimizer_mask.zero_grad()
            outputs_mask = model(images, task='mask')
            loss = criterion_mask(outputs_mask, masks)
            loss.backward()
            optimizer_mask.step()

            train_loss += loss.item() * images.size(0)
            
            pbar.set_postfix({
                'Train Loss': f'{loss.item():.4f}',
                'LR': f'{scheduler_mask.get_last_lr()[0]:.6f}'
            })

        train_loss /= len(train_loader.dataset)
        mask_train_losses.append(train_loss)
        mask_lrs.append(scheduler_mask.get_last_lr()[0])

        # Validate every 10 epochs
        if (epoch + 1) % validate_every == 0:
            val_loss, mask_iou = validate(model, val_loader, criterion_mask, device, task='mask')
            mask_val_losses.append(val_loss)
            mask_ious.append(mask_iou)
            
            print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, '
                  f'Val Loss: {val_loss:.4f}, Mask IoU: {mask_iou:.4f}')
            
            if mask_iou > best_mask_iou:
                best_mask_iou = mask_iou
                best_mask_model_weights = model.state_dict().copy()
                print(f'New best mask model saved with IoU: {best_mask_iou:.4f}')

        if epoch == constant_lr_epochs - 1:
            scheduler_mask = CosineAnnealingLR(optimizer_mask, T_max=cosine_annealing_epochs)
        
        scheduler_mask.step()

    visualization_data = {
        'train_losses': mask_train_losses,
        'val_losses': mask_val_losses,
        'ious': mask_ious,
        'lrs': mask_lrs
    }

    return model, best_mask_model_weights, visualization_data

def finetune_skeleton_model(model, train_loader, val_loader, num_epochs, device, best_mask_model_path):
    # Load the best mask model
    state_dict = torch.load(best_mask_model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    
    best_skeleton_iou = 0.0
    best_skeleton_model_weights = None
    skeleton_train_losses, skeleton_val_losses, skeleton_ious, skeleton_lrs = [], [], [], []

    pos_weight = calculate_pos_weight(train_loader)
    criterion_skeleton = nn.BCEWithLogitsLoss(pos_weight=pos_weight.to(device))
    optimizer_skeleton = torch.optim.Adam(model.parameters(), lr=0.001)

    constant_lr_epochs = num_epochs // 2
    cosine_annealing_epochs = num_epochs - constant_lr_epochs
    scheduler_skeleton = torch.optim.lr_scheduler.LambdaLR(optimizer_skeleton, lambda epoch: 1)  # Constant LR

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0

        pbar = tqdm(train_loader, desc=f"Skeleton Finetuning Epoch {epoch+1}/{num_epochs}")
        for batch in pbar:
            images, _, skeletons, _ = batch  # Unpack correctly regardless of the number of items
            images, skeletons = images.to(device), skeletons.to(device)

            optimizer_skeleton.zero_grad()
            outputs_skeleton = model(images, task='skeleton')
            loss = criterion_skeleton(outputs_skeleton, skeletons)
            loss.backward()
            optimizer_skeleton.step()
            
            train_loss += loss.item()
            pbar.set_postfix({
                'Train Loss': f'{loss.item():.4f}',
                'LR': f'{scheduler_skeleton.get_last_lr()[0]:.6f}'
            })
        
        scheduler_skeleton.step()  # Step the scheduler once per epoch
        
        train_loss /= len(train_loader)
        skeleton_train_losses.append(train_loss)
        skeleton_lrs.append(scheduler_skeleton.get_last_lr()[0])

        # Validation at the end of each epoch
        val_loss, skeleton_iou = validate(model, val_loader, criterion_skeleton, device, task='skeleton')
        skeleton_val_losses.append(val_loss)
        skeleton_ious.append(skeleton_iou)
        
        print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, '
              f'Val Loss: {val_loss:.4f}, Skeleton IoU: {skeleton_iou:.4f}')
        
        if skeleton_iou > best_skeleton_iou:
            best_skeleton_iou = skeleton_iou
            best_skeleton_model_weights = model.state_dict().copy()
            print(f'New best skeleton model saved with IoU: {best_skeleton_iou:.4f}')

        if epoch == constant_lr_epochs - 1:
            scheduler_skeleton = CosineAnnealingLR(optimizer_skeleton, T_max=cosine_annealing_epochs)

    visualization_data = {
        'train_losses': skeleton_train_losses,
        'val_losses': skeleton_val_losses,
        'ious': skeleton_ious,
        'lrs': skeleton_lrs
    }

    return model, best_skeleton_model_weights, visualization_data

In [13]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [None]:
image_dir = "DATA/IMG"
mask_dir = "DATA/MASK"
skeleton_dir = "DATA/SKELETON"
batch_size = 32

train_loader, val_loader, test_loader = get_data_loaders(image_dir, mask_dir, skeleton_dir, batch_size)

model = DualTaskResNet(num_classes=1, pretrained=True)
model.to(device)

# Phase 1: Train mask model
print("Phase 1: Training mask model")
model, best_mask_weights, mask_vis_data = train_mask_model(
    model, train_loader, val_loader, num_epochs=100, device=device, beta=0.7
)

# Save the best mask model
torch.save(best_mask_weights, 'best_mask_model.pth')
print("Best mask model saved.")




Phase 1: Training mask model


Mask Training Epoch 1/100: 100%|██████████| 19/19 [11:36<00:00, 36.63s/it, Train Loss=0.5862, LR=0.001000]
Mask Training Epoch 2/100: 100%|██████████| 19/19 [11:49<00:00, 37.33s/it, Train Loss=0.6001, LR=0.001000]
Mask Training Epoch 3/100:   5%|▌         | 1/19 [00:35<10:32, 35.13s/it, Train Loss=0.6117, LR=0.001000]

In [None]:
# Phase 2: Finetune skeleton model
print("Phase 2: Finetuning skeleton model")
model, best_skeleton_weights, skeleton_vis_data = finetune_skeleton_model(
    model, train_loader, val_loader, num_epochs=100, device=device, best_mask_model_path='best_mask_model.pth'
)


# Save the best skeleton model
torch.save(best_skeleton_weights, 'best_skeleton_model.pth')
print("Best skeleton model saved.")



In [None]:
# Combine visualization data
vis_data = {
    'mask': mask_vis_data,
    'skeleton': skeleton_vis_data
}

print("Training completed and best models saved.")

In [None]:
import matplotlib.pyplot as plt

def visualize_training(visualization_data):
    fig, axs = plt.subplots(2, 3, figsize=(20, 12))
    fig.suptitle('Training and Validation Metrics', fontsize=16)

    phases = ['mask', 'skeleton']
    colors = ['blue', 'red']

    for i, phase in enumerate(phases):
        # Plot training and validation loss
        axs[i, 0].plot(visualization_data[phase]['train_losses'], label='Train Loss', color=colors[0])
        axs[i, 0].plot(visualization_data[phase]['val_losses'], label='Val Loss', color=colors[1])
        axs[i, 0].set_title(f'{phase.capitalize()} Loss')
        axs[i, 0].set_xlabel('Epoch')
        axs[i, 0].set_ylabel('Loss')
        axs[i, 0].legend()

        # Plot IoU
        axs[i, 1].plot(visualization_data[phase]['ious'], label='IoU', color=colors[0])
        axs[i, 1].set_title(f'{phase.capitalize()} IoU')
        axs[i, 1].set_xlabel('Epoch')
        axs[i, 1].set_ylabel('IoU')
        axs[i, 1].legend()

        # Plot Learning Rate
        axs[i, 2].plot(visualization_data[phase]['lrs'], label='Learning Rate', color=colors[0])
        axs[i, 2].set_title(f'{phase.capitalize()} Learning Rate')
        axs[i, 2].set_xlabel('Epoch')
        axs[i, 2].set_ylabel('Learning Rate')
        axs[i, 2].legend()

    plt.tight_layout()
    plt.show()



In [None]:
# Visulize Training
visualize_training(vis_data)

In [None]:
test_loader_for_test = test_loader

# Evaluate on test set
test_loss, test_mask_iou, test_skeleton_iou, test_avg_iou = validate(trained_model, test_loader, criterion, device)
print(f'Test Loss: {test_loss:.4f}, Test Mask IoU: {test_mask_iou:.4f}, '
        f'Test Skeleton IoU: {test_skeleton_iou:.4f}, Test Avg IoU: {test_avg_iou:.4f}')

In [None]:


def generate_skeletons(model, test_dataloader, save_dir, device='cuda'):
    model.eval()  # Set the model to evaluation mode
    os.makedirs(save_dir, exist_ok=True)  # Ensure the save directory exists

    with torch.no_grad():  # No need to compute gradients for inference
        for batch in tqdm(test_dataloader, desc="Generating Skeletons"):
            # Unpack the batch
            images, _, _, img_names = batch
            images = images.to(device)

            # Generate skeleton predictions
            skeleton_preds = model(images, task='skeleton')

            # Apply a threshold to get binary skeletons
            skeleton_binary = (skeleton_preds > 0.3).float()

            # Save each predicted skeleton
            for j, img_name in enumerate(img_names):
                # Use the original image name for the skeleton file
                base_name = os.path.splitext(img_name)[0]
                save_path = os.path.join(save_dir, f"{base_name}_skeleton.png")
                save_image(skeleton_binary[j], save_path)
                print(f"Skeleton saved at: {save_path}")





In [None]:
# Load your trained model
model = DualTaskResNet(num_classes=1, pretrained=False)  # Adjust num_classes if necessary
model.load_state_dict(torch.load("best_skeleton_model.pth"))  # Load the trained weights
model = model.to('cuda')  # Send the model to the GPU (or use 'cpu')

# Load your test dataloader (assumes test_dataloader is defined)
# Example: test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False)

# Directory where skeletons will be saved
save_dir = "./skeleton_predictions"

# Call the skeleton generation function
generate_skeletons(model, test_loader, save_dir, device='cuda')