# Deeplabv3 Imagenet1k - Validation pre Domain Adaptation

In [18]:
from PIL import Image
from cityscapesscripts.helpers import labels
import matplotlib.pyplot as plt
import numpy as np
import os
import random
import time
import gc
from tqdm.auto import tqdm
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import random_split
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.datasets import Cityscapes
from backbones_unet.model.unet import Unet

In [19]:
# Set seeds for reproducibility
def set_seeds(seed=42):
    """Set seeds for reproducibility"""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.mps.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Set the seed
SEED = 42
set_seeds(SEED)
print(f"Seeds set for reproducibility is {SEED}")

Seeds set for reproducibility is 42


In [20]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() 
                      else "mps" if torch.mps.is_available() 
                      else "cpu"
                    )
print(DEVICE)

mps


In [21]:
class SegmentationDataset(Dataset):
    def __init__(self, images_dir, targets_dir, image_transform=None, target_transform=None):
        """
        Args:
            images_dir (string): Directory with all the images.
            targets_dir (string): Directory with all the target masks.
            image_transform (callable, optional): Optional transform to be applied on images.
            target_transform (callable, optional): Optional transform to be applied on targets.
        """
        self.images_dir = images_dir
        self.targets_dir = targets_dir
        self.image_transform = image_transform
        self.target_transform = target_transform
        
        # Get all image filenames
        self.image_filenames = [f for f in os.listdir(images_dir) 
                               if f.lower().endswith(('.png'))]
        self.image_filenames.sort()
        
    def __len__(self):
        return len(self.image_filenames)
    
    def __getitem__(self, idx):
        # Load image
        img_name = self.image_filenames[idx]
        img_path = os.path.join(self.images_dir, img_name)
        image = Image.open(img_path).convert('RGB')
        
        # Load target mask
        target_name = img_name.replace('.png', '_trainId.png')
        target_path = os.path.join(self.targets_dir, target_name)
        target = Image.open(target_path)
        
        # Apply transforms
        if self.image_transform:
            image = self.image_transform(image)
        
        if self.target_transform:
            target = self.target_transform(target)
       # else:
       #     # Default: convert to tensor
       #     target = torch.from_numpy(target)
        
        return image, target

### Loading the Synthetic Dataset

In [22]:
path_images = "syn_resized_images"
path_target = "syn_resized_gt"

image_transform = transforms.Compose([
    transforms.Resize((256, 480)), # We augment from 466x256 to 480x256 in order to avoid the padding artifacts
    transforms.ToTensor(),  # Converts PIL Image to Tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # Normalize with ImageNet parameters
])

target_transform = transforms.Compose([
    transforms.Resize((256, 480), interpolation=Image.NEAREST), # This interpolation ensure that all pixels have a correct value of their class
    transforms.Lambda(lambda x: torch.from_numpy(np.array(x)).long())
])
syn_dataset = SegmentationDataset(images_dir=path_images, targets_dir=path_target, image_transform=image_transform, target_transform=target_transform)

### Splitting Synthetic Dataset

In [23]:
# Get total dataset size
total_size = len(syn_dataset)
print(f"Total dataset size: {total_size}")

# Calculate split sizes (60% train, 10% val, 30% test)
train_size = int(0.6 * total_size)
val_size = int(0.1 * total_size)
test_size = total_size - train_size - val_size

print(f"Train size: {train_size} ({train_size/total_size*100:.1f}%)")
print(f"Validation size: {val_size} ({val_size/total_size*100:.1f}%)")
print(f"Test size: {test_size} ({test_size/total_size*100:.1f}%)")

# Create random splits
syn_train_dataset, syn_val_dataset, syn_test_dataset = random_split(
    syn_dataset, 
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(SEED)  # For reproducibility
)

# Create DataLoaders
batch_size = 8

syn_train_dataloader = DataLoader(
    syn_train_dataset,
    batch_size=batch_size,
    shuffle=True,
    #num_workers=2,
    generator=torch.Generator().manual_seed(SEED) 
)


syn_val_dataloader = DataLoader(
    syn_val_dataset,
    batch_size=batch_size,
    shuffle=True,  
    #num_workers=2,
    generator=torch.Generator().manual_seed(SEED)
)

syn_test_dataloader = DataLoader(
    syn_test_dataset,
    batch_size=batch_size,
    shuffle=True,  
    #num_workers=2,
    generator=torch.Generator().manual_seed(SEED)
)

print(f"\nDataLoaders created:")
print(f"Train batches: {len(syn_train_dataloader)}")
print(f"Validation batches: {len(syn_val_dataloader)}")
print(f"Test batches: {len(syn_test_dataloader)}")

Total dataset size: 12500
Train size: 7500 (60.0%)
Validation size: 1250 (10.0%)
Test size: 3750 (30.0%)

DataLoaders created:
Train batches: 938
Validation batches: 157
Test batches: 469


### Loading the Real Dataset

In [24]:
def cityscapes_mask_to_train_ids(mask):
    mask_np = np.array(mask)
    labels_mapping = {label.id: label.trainId for label in labels.labels if label.trainId != -1}
    labels_mapping[-1] = 255  # Map invalid labels to 255
    mask_np = np.vectorize(labels_mapping.get)(mask_np)
    mask_np = np.where(mask_np <= 18, mask_np, 255).astype(np.uint8)  # Ensure that all pixels are within the range of train IDs
    return torch.from_numpy(mask_np).long()

real_target_transform = transforms.Compose([
    transforms.Resize((256, 480), interpolation=Image.NEAREST), # This interpolation ensure that all pixels have a correct value of their class
    transforms.Lambda(lambda x: cityscapes_mask_to_train_ids(x))  # Convert Cityscapes mask to train IDs
])

real_train_dataset = Cityscapes(
    root='cityscapes',
    split='train',
    mode='fine',
    target_type='semantic',
    transform=image_transform,
    target_transform=real_target_transform
)

real_val_dataset = Cityscapes(
    root='cityscapes',
    split='val',
    mode='fine',
    target_type='semantic',
    transform=image_transform,
    target_transform=real_target_transform
)

real_test_dataset = Cityscapes(
    root='cityscapes',
    split='test',
    mode='fine',
    target_type='semantic',
    transform=image_transform,
    target_transform=real_target_transform
)

# Print dataloaders sizes
real_total_size = len(real_train_dataset) + len(real_val_dataset) + len(real_test_dataset)
print(f"Total dataset size: {real_total_size}")
print(f"Train size: {len(real_train_dataset)} ({len(real_train_dataset)/real_total_size*100:.1f}%)")
print(f"Validation size: {len(real_val_dataset)} ({len(real_val_dataset)/real_total_size*100:.1f}%)")
print(f"Test size: {len(real_test_dataset)} ({len(real_test_dataset)/real_total_size*100:.1f}%)")

Total dataset size: 5000
Train size: 2975 (59.5%)
Validation size: 500 (10.0%)
Test size: 1525 (30.5%)


In [25]:
real_train_dataloader = DataLoader(
    real_train_dataset,
    batch_size=batch_size,
    shuffle=True,
    #num_workers=2,
    generator=torch.Generator().manual_seed(SEED)
)

real_val_dataloader = DataLoader(
    real_val_dataset,
    batch_size=batch_size,
    shuffle=True,  
    #num_workers=2,
    generator=torch.Generator().manual_seed(SEED)
)

real_test_dataloader = DataLoader(
    real_test_dataset,
    batch_size=batch_size,
    shuffle=True,  
    #num_workers=2,
    generator=torch.Generator().manual_seed(SEED)
)

In [26]:
class DiceLoss(nn.Module):
    def __init__(self, smooth=0.1, ignore_index=None):
        super(DiceLoss, self).__init__()
        self.smooth = smooth
        self.ignore_index = ignore_index

    def forward(self, inputs, targets):
        """
        inputs: (N, C, H, W) - logits (non softmaxati)
        targets: (N, H, W)   - ground truth con classi (0...C-1)
        """
        num_classes = inputs.shape[1]
        device = inputs.device  # Get device from input tensor
        
        # Softmax sulle predizioni
        probs = F.softmax(inputs, dim=1)  # (N, C, H, W)
        
        # Handle ignore_index by creating a mask and filtering out ignored pixels
        if self.ignore_index is not None:
            # Create mask for valid pixels
            valid_mask = (targets != self.ignore_index)  # (N, H, W)
            
            # Only process valid pixels
            valid_targets = targets[valid_mask]  # (N_valid,)
            
            # Reshape probs to match and filter valid pixels
            probs_reshaped = probs.permute(0, 2, 3, 1)  # (N, H, W, C)
            valid_probs = probs_reshaped[valid_mask]  # (N_valid, C)
            
        else:
            # No ignore_index, process all pixels
            valid_targets = targets.view(-1)  # (N*H*W,)
            valid_probs = probs.permute(0, 2, 3, 1).contiguous().view(-1, num_classes)  # (N*H*W, C)
        
        # One-hot encoding of valid targets only
        if len(valid_targets) > 0:
            targets_one_hot = F.one_hot(valid_targets, num_classes=num_classes).float()  # (N_valid, C)
        else:
            # If no valid pixels, create a zero loss that maintains gradients
            # Use a small operation on the input to maintain gradient flow
            zero_loss = (inputs * 0.0).sum()  # This maintains gradients from inputs
            return zero_loss
        
        # Calcolo Dice per ogni classe usando solo pixel validi
        intersection = (valid_probs * targets_one_hot).sum(dim=0)  # (C,)
        union = valid_probs.sum(dim=0) + targets_one_hot.sum(dim=0)  # (C,)
        
        
        smooth_tensor = torch.tensor(self.smooth, device=device, dtype=intersection.dtype)
        dice = (2.0 * intersection + smooth_tensor) / (union + smooth_tensor)
        
        # Media sulle classi
        loss = 1.0 - dice.mean()
        return loss

## Maximum Mean Discrepancy Loss (MMD) 
$$
    MMD(X, Y) = \frac{1}{|X|^2} \sum_{i,j} k(x_i, x_j) - \frac{2}{|X||Y|} \sum_{i,j} k(x_i, y_j) + \frac{1}{|Y|^2} \sum_{i,j} k(y_i, y_j)
$$
where:
- $k$ is a kernel function (e.g., Gaussian).
- $X$ is the source domain (MNIST) and $Y$ is the target domain (SVHN).
- $x_i$ and $y_j$ are samples from the respective domains.
- $|X|$ and $|Y|$ are the number of samples in each domain.

In [27]:
# Define MMD loss between source and target feature batches
def rbf_kernel(a, b, sigma):
    a = a.unsqueeze(1) 
    b = b.unsqueeze(0) 
    dist = (a - b).pow(2).sum(2) 
    return torch.exp(-dist / (2 * sigma ** 2))

def MMD_loss(x, y, sigma=1.0):

    K_xx = rbf_kernel(x, x, sigma)
    K_yy = rbf_kernel(y, y, sigma)
    K_xy = rbf_kernel(x, y, sigma)
    mmd = K_xx.mean() + K_yy.mean() - 2 * K_xy.mean()
    return mmd

In [28]:
import ssl
import urllib.request
from contextlib import contextmanager
# Disable SSL verification for urllib requests on MacOS
# This is a workaround for the "SSL: CERTIFICATE_VERIFY_FAILED" error on MacOS
@contextmanager
def no_ssl_verification():
    """Temporarily disable SSL verification"""
    old_context = ssl._create_default_https_context
    ssl._create_default_https_context = ssl._create_unverified_context
    try:
        yield
    finally:
        ssl._create_default_https_context = old_context

NUM_CLASSES = 19
LR = 1e-4

CE_loss = nn.CrossEntropyLoss(ignore_index=255)  # 255 = unlabeled
Dice_loss = DiceLoss(smooth=0.1, ignore_index=255)  

ce_importance = 0.7
mmd_weight = 0.1

def criterion(outputs, targets):
    ce_loss = CE_loss(outputs, targets)
    dice_loss = Dice_loss(outputs, targets)

    cls_loss = ce_importance * ce_loss + (1 - ce_importance) * dice_loss
    return cls_loss


In [29]:
# Load model

# Instantiate the model architecture
if DEVICE.type == 'mps': 
    print("mps detected, using no_ssl_verification")
    with no_ssl_verification():
        model = Unet(
            backbone='resnet50',
            pretrained=True,
            in_channels=3,
            num_classes=NUM_CLASSES,
        )
else:
    model = Unet(
        backbone='resnet50',
        pretrained=True,
        in_channels=3,
        num_classes=NUM_CLASSES,
    )
    
model = model.to(DEVICE)

# Load the checkpoint and filter out aux_classifier keys
best_model_path = "models/unet_imagenet1k_480x256_best_model.pth"
checkpoint = torch.load(best_model_path, map_location=DEVICE)
print (f"Loading model from: {best_model_path}")

name = "DA_unet_imagenet1k_480x256"

state_dict = checkpoint["model_state_dict"]
model.load_state_dict(state_dict, strict=False)
optimizer = optim.Adam(model.parameters(), lr=LR)

mps detected, using no_ssl_verification
Loading model from: models/unet_imagenet1k_480x256_best_model.pth


In [30]:
def compute_iou(preds, labels, num_classes, ignore_index=255):
    
    preds = torch.argmax(preds, dim=1).detach().cpu()  # [B, H, W]
    
    labels = labels.detach().cpu() 
    
    ious = []
    for cls in range(num_classes):
        pred_inds = (preds == cls)
        target_inds = (labels == cls)
        
        # Escludi pixel ignorati
        mask = (labels != ignore_index)
        pred_inds = pred_inds & mask
        target_inds = target_inds & mask

        intersection = (pred_inds & target_inds).sum().item()
        union = (pred_inds | target_inds).sum().item()
        
        if union == 0:
            continue  # salta classe non presente
        ious.append(intersection / union)
    
    if len(ious) == 0:
        return float('nan')  # o 0.0 se preferisci
    return sum(ious) / len(ious)

In [31]:
def clear_memory():
    """Clear memory and cache for all device types"""
    # Run garbage collector
    gc.collect()
    
    # Get device from global scope
    if 'DEVICE' in globals():
        if DEVICE.type == 'cuda':
            torch.cuda.empty_cache()

        if DEVICE.type == 'mps':
            torch.mps.empty_cache()

    # Second GC run to make sure everything is cleaned up
    gc.collect()

In [32]:
def validate_epoch(
        model, criterion, syn_val_dataloader, real_val_dataloader
    ):
    
    # Validation 
    model.eval()
    val_total_loss = 0
    val_iou_source = 0
    val_iou_target = 0

    val_steps = min(len(syn_val_dataloader), len(real_val_dataloader))
    
    with torch.no_grad():

        syn_val_iter = iter(syn_val_dataloader)
        real_val_iter = iter(real_val_dataloader)

        for i in tqdm(range(val_steps), desc="Validation"):
            source_data, source_labels = next(syn_val_iter)
            target_data, target_labels = next(real_val_iter)
            source_data, source_labels = source_data.to(DEVICE), source_labels.to(DEVICE)
            target_data, target_labels = target_data.to(DEVICE), target_labels.to(DEVICE)

            outputs_source = model(source_data)
            outputs_target = model(target_data)

            loss_real = criterion(outputs_source, source_labels)
            loss_target = criterion(outputs_target, target_labels)

            val_total_loss += loss_real.detach().cpu().item() 
            val_iou_source += compute_iou(outputs_source.detach().cpu(), source_labels, NUM_CLASSES)
            val_iou_target += compute_iou(outputs_target.detach().cpu(), target_labels, NUM_CLASSES)

            # Explicit cleanup
            del source_data, source_labels, target_data, target_labels
            del outputs_source, outputs_target, loss_real, loss_target
            
            # Periodic memory clear
            if (i + 1) % 5 == 0:
                clear_memory()
    
    clear_memory()
    
    # Average losses and IOUs
    source_val_avg_loss = val_total_loss / val_steps
    source_val_avg_iou = val_iou_source / val_steps
    target_val_avg_iou = val_iou_target / val_steps
    
    return source_val_avg_loss, source_val_avg_iou, target_val_avg_iou

In [33]:
print(f"Best model loaded successfully!")
print(f"Best validation IoU: {checkpoint['val_iou']:.4f}")
print(f"From epoch: {checkpoint['epoch']}")

Best model loaded successfully!
Best validation IoU: 0.5983
From epoch: 24


In [34]:
source_val_avg_loss, source_val_avg_iou, target_val_avg_iou = validate_epoch(
    model, criterion, syn_val_dataloader, real_val_dataloader
)

print(f"Source Validation Average Loss: {source_val_avg_loss:.4f}")
print(f"Source Validation Average IoU: {source_val_avg_iou:.4f}")
print(f"Target Validation Average IoU: {target_val_avg_iou:.4f}")

Validation: 100%|██████████| 63/63 [01:06<00:00,  1.06s/it]

Source Validation Average Loss: 0.2589
Source Validation Average IoU: 0.5927
Target Validation Average IoU: 0.2793



