In [None]:
# =========================
# Imports
# =========================
import os
from collections import defaultdict

import numpy as np
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import matplotlib.pyplot as plt
from IPython.display import clear_output


In [None]:
# =========================
# Dataset Class for Semantic Segmentation
# =========================
class SegmentationDatasetFromDF(Dataset):
    """Map-style dataset pairing images and label masks held in two DataFrames.

    Row ``i`` of ``df_img`` (column ``'image_data'``) is paired with row ``i`` of
    ``df_seg`` (column ``'annotation_data'``). Both frames are re-indexed to
    ``0..n-1`` so the pairing is positional, whatever their original index was.
    Raw mask IDs are remapped to training IDs; IDs missing from the mapping
    become ``ignore_index``.

    Args:
        df_img: DataFrame whose ``'image_data'`` column holds PIL images or
            arrays accepted by ``PIL.Image.fromarray``.
        df_seg: DataFrame whose ``'annotation_data'`` column holds masks of raw
            label IDs (ndarray, or anything ``np.asarray`` accepts).
        transform: Optional callable applied to the PIL image.
        mask_transform: Optional callable applied to the PIL mask built from
            the remapped uint8 train-ID array.
        id_to_trainId: Optional ``{raw_id: train_id}`` mapping. When ``None``
            (default) the global ``id_to_trainId_map`` (expected to be defined
            elsewhere in the notebook) is read at conversion time, as before.
        ignore_index: Train ID given to raw IDs absent from the mapping.
            Defaults to 255, matching ``CrossEntropyLoss(ignore_index=255)``
            used for training.
    """

    def __init__(self, df_img, df_seg, transform=None, mask_transform=None,
                 id_to_trainId=None, ignore_index=255):
        self.df_img = df_img.reset_index(drop=True)
        self.df_seg = df_seg.reset_index(drop=True)
        self.transform = transform
        self.mask_transform = mask_transform
        # Plain data only (no lambdas) so the dataset stays picklable for
        # DataLoader worker processes started with the 'spawn' method.
        self.id_to_trainId = id_to_trainId
        self.ignore_index = ignore_index

    @property
    def id_to_trainId_vec(self):
        """Backward-compatible alias for the mask ID conversion (returns uint8)."""
        return self._convert_ids

    def _convert_ids(self, mask):
        """Return ``mask`` with raw IDs replaced by train IDs, as a uint8 array.

        Each distinct raw ID is looked up once (via ``np.unique``) instead of
        once per pixel. Lookups use ``dict.get`` semantics, so unknown IDs map
        to ``ignore_index``. The final uint8 cast wraps out-of-range train IDs
        in the same way a plain ``.astype(np.uint8)`` does (e.g. -1 -> 255).
        Empty masks are returned as empty uint8 arrays of the same shape.
        """
        mapping = id_to_trainId_map if self.id_to_trainId is None else self.id_to_trainId
        mask = np.asarray(mask)
        unique_ids, inverse = np.unique(mask, return_inverse=True)
        lut = np.array([mapping.get(raw_id.item(), self.ignore_index) for raw_id in unique_ids])
        # reshape: the inverse indices are flat on older NumPy and input-shaped on newer NumPy.
        return lut[inverse].reshape(mask.shape).astype(np.uint8)

    def __len__(self):
        return len(self.df_img)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for row ``idx`` after optional transforms."""
        image = self.df_img.loc[idx, 'image_data']
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)

        mask = Image.fromarray(self._convert_ids(self.df_seg.loc[idx, 'annotation_data']))

        if self.transform:
            image = self.transform(image)
        if self.mask_transform:
            mask = self.mask_transform(mask)

        return image, mask


In [None]:
# =========================
# Training and Validation Functions
# =========================
def train_one_epoch(model, dataloader, optimizer, criterion, device, epoch_num, log_every=100):
    """Train ``model`` for one pass over ``dataloader`` and return the mean batch loss.

    Args:
        model: Segmentation model whose forward pass returns a mapping with an
            ``'out'`` logits tensor (as torchvision's DeepLabV3 models do).
        dataloader: Sized iterable of ``(images, masks)`` batches. Dimension 1
            of ``masks`` is squeezed (dropped if it has size 1) and the result
            is cast to ``long`` class indices for ``criterion``.
        optimizer: Optimizer over ``model``'s parameters; stepped once per batch.
        criterion: Loss called as ``criterion(logits, targets)``.
        device: Device each batch's tensors are moved to.
        epoch_num: Epoch label used only in the log lines.
        log_every: Print the current batch loss every ``log_every`` batches;
            the final batch is always printed. ``None`` or values <= 0 keep
            only the final-batch line.

    Returns:
        float: Sum of per-batch losses divided by the number of batches, i.e.
        an unweighted mean. It is kept this way so values stay comparable with
        losses recorded in earlier checkpoints.

    Raises:
        ValueError: If ``dataloader`` has length 0. Previously this surfaced
            as a ``ZeroDivisionError`` after the loop.
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError(f"train_one_epoch: empty dataloader for epoch {epoch_num}")
    periodic_logging = log_every is not None and log_every > 0

    model.train()
    running_loss = 0.0

    for batch_idx, (images, masks) in enumerate(dataloader, start=1):
        images, masks = images.to(device), masks.to(device)

        optimizer.zero_grad()
        outputs = model(images)['out']
        loss = criterion(outputs, masks.squeeze(1).long())

        loss.backward()
        optimizer.step()

        # Call .item() once per batch: on a GPU it forces a host/device sync.
        batch_loss = loss.item()
        running_loss += batch_loss

        if (periodic_logging and batch_idx % log_every == 0) or batch_idx == num_batches:
            print(f"Epoch [{epoch_num}], Batch [{batch_idx}/{num_batches}], Loss: {batch_loss:.4f}")

    epoch_loss = running_loss / num_batches
    print(f"Epoch [{epoch_num}] completed. Average Loss: {epoch_loss:.4f}")
    return epoch_loss


def validate_one_epoch(model, dataloader, criterion, device, epoch_num, log_every=10):
    """Evaluate ``model`` on ``dataloader`` without gradients and return the mean batch loss.

    The model is left in eval mode on return.

    Args:
        model: Segmentation model whose forward pass returns a mapping with an
            ``'out'`` logits tensor (as torchvision's DeepLabV3 models do).
        dataloader: Sized iterable of ``(images, masks)`` batches. Dimension 1
            of ``masks`` is squeezed (dropped if it has size 1) and the result
            is cast to ``long`` class indices for ``criterion``.
        criterion: Loss called as ``criterion(logits, targets)``.
        device: Device each batch's tensors are moved to.
        epoch_num: Epoch label used only in the log lines.
        log_every: Print the current batch loss every ``log_every`` batches;
            the final batch is always printed. ``None`` or values <= 0 keep
            only the final-batch line.

    Returns:
        float: Sum of per-batch losses divided by the number of batches (an
        unweighted mean, consistent with ``train_one_epoch``).

    Raises:
        ValueError: If ``dataloader`` has length 0. Previously this surfaced
            as a ``ZeroDivisionError`` after the loop.
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError(f"validate_one_epoch: empty dataloader for epoch {epoch_num}")
    periodic_logging = log_every is not None and log_every > 0

    model.eval()
    running_val_loss = 0.0

    with torch.no_grad():
        for batch_idx, (images, masks) in enumerate(dataloader, start=1):
            images, masks = images.to(device), masks.to(device)

            outputs = model(images)['out']
            # Call .item() once per batch: on a GPU it forces a host/device sync.
            batch_loss = criterion(outputs, masks.squeeze(1).long()).item()
            running_val_loss += batch_loss

            if (periodic_logging and batch_idx % log_every == 0) or batch_idx == num_batches:
                print(f"Validation Epoch [{epoch_num}], Batch [{batch_idx}/{num_batches}], Loss: {batch_loss:.4f}")

    avg_val_loss = running_val_loss / num_batches
    print(f"Validation Epoch [{epoch_num}] completed. Average Loss: {avg_val_loss:.4f}")
    return avg_val_loss


In [None]:
# =========================
# Model Setup
# =========================
# Number of output channels (classes) predicted by the replacement head.
N_CLASSES = 19
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# DeepLabV3 with a ResNet-50 backbone, initialised from torchvision's
# pretrained weights (downloaded on first use).
# NOTE(review): `pretrained=True` is torchvision's legacy flag; newer releases
# prefer the `weights=` argument and emit a deprecation warning for this one.
model_backbone_pretrained = torchvision.models.segmentation.deeplabv3_resnet50(
    pretrained=True
)

# Swap the main classifier for a freshly initialised head that emits
# N_CLASSES channels; 2048 is the feature width the ResNet-50 backbone feeds it.
# NOTE(review): only `.classifier` is replaced. Any auxiliary head keeps its
# original output size; training reads only `['out']`, so confirm it is unused.
model_backbone_pretrained.classifier = (
    torchvision.models.segmentation.deeplabv3.DeepLabHead(2048, N_CLASSES)
)

# nn.Module.to() moves the parameters in place and returns the same module, so
# `model` and `model_backbone_pretrained` refer to one object.
model = model_backbone_pretrained.to(device)


In [None]:
# =========================
# Training Settings
# =========================
# Total epoch budget for the run (presumably the training loop runs from
# start_epoch up to this value — confirm in the loop cell).
num_epochs = 19
# Checkpoint cadence in epochs (consumed by the training loop, not shown here).
save_every_n = 2
# Presumably a Colab-mounted Google Drive folder holding this run's checkpoints;
# it is created if missing.
checkpoint_dir = '/content/drive/MyDrive/Vision Project/backbone_resnet50_pretrained'
os.makedirs(checkpoint_dir, exist_ok=True)

# =========================
# Load Checkpoint
# =========================
# Resume from a fixed checkpoint file.
# NOTE(review): the epoch number is hard-coded in the filename; update it (or
# lift it into a constant) when resuming from a different epoch.
checkpoint_path = os.path.join(checkpoint_dir, 'checkpoint_epoch_10.pth')
# map_location lets a checkpoint saved on another device load onto `device`.
# torch.load unpickles the file, so only load checkpoints you produced yourself.
checkpoint = torch.load(checkpoint_path, map_location=device)

# NOTE(review): optimizer.load_state_dict below restores the saved param_groups
# (including 'lr'), so the lr=1e-6 given here is presumably overwritten by the
# checkpoint's value. Confirm whether a fresh learning rate was intended.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-6)
# Pixels labelled 255 (the fill value the dataset uses for unmapped IDs) are
# excluded from the loss.
criterion = torch.nn.CrossEntropyLoss(ignore_index=255)

# Load checkpoint states: model weights first, then the optimizer's saved state.
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# Build the scheduler on the restored optimizer, then load its own saved state
# on top. The order of these steps matters.
# NOTE(review): `verbose=True` is deprecated in recent PyTorch releases.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', patience=3, factor=0.1, verbose=True
)
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

# Bookkeeping carried over from the interrupted run.
# NOTE(review): whether 'epoch' is the last completed epoch or the next one to
# run depends on how the checkpoint was written. Confirm before using it as the
# loop's starting point.
start_epoch = checkpoint['epoch']
train_losses = checkpoint['train_losses']
val_losses = checkpoint['val_losses']
best_val_loss = checkpoint['best_val_loss']


In [None]:
def analyze_backbone_percent_difference(checkpoint_path, device="cpu", epsilon=1e-8):
    """
    Compare transferred checkpoint backbone vs torchvision DeepLabV3 ResNet-50 backbone.
    Shows differences as percentage relative to DeepLab weights.
    """
    # Load checkpoint
    ckpt = torch.load(checkpoint_path, map_location=device)
    state_dict = ckpt.get('model_state_dict', ckpt)

    # Extract backbone weights
    ckpt_backbone = {k: v for k, v in state_dict.items() if "backbone" in k}
    norm_ckpt = {k.replace("backbone.body.", "").replace("backbone.", ""): v for k, v in ckpt_backbone.items()}

    # Load DeepLab backbone
    deeplab = torchvision.models.segmentation.deeplabv3_resnet50(pretrained=True).backbone
    deeplab_state = deeplab.state_dict()

    common = set(norm_ckpt.keys()).intersection(deeplab_state.keys())

    group_diffs = defaultdict(list)
    sample_diffs = {}
    for k in sorted(common):
        t1 = norm_ckpt[k].detach().to(device)
        t2 = deeplab_state[k].detach().to(device)
        diff_percent = torch.mean(torch.abs(t1 - t2) / (torch.abs(t2) + epsilon)) * 100.0

        group = k.split('.')[0]
        group_diffs[group].append(diff_percent.item())
        sample_diffs[k] = diff_percent.item()

    # Display grouped results
    print(f"✅ Common keys: {len(common)}")
    for group, diffs in group_diffs.items():
        avg_diff = sum(diffs) / len(diffs)
        print(f"📂 {group:<8} | keys: {len(diffs):<4} | avg % diff: {avg_diff:.2f}%")

    # Sample differences
    print("\n🔎 Sample percent differences (first 10):")
    for k in list(sample_diffs.keys())[:10]:
        print(f"  {k:<40} diff={sample_diffs[k]:.2f}%")

    # Plot bar chart
    groups, avgs = zip(*[(g, sum(d)/len(d)) for g, d in group_diffs.items
