# Imports and Setup

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import models, transforms
from torchvision.datasets import Cityscapes
from torchvision.models.segmentation import DeepLabV3_ResNet50_Weights
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# 1. Setup Device: use CUDA when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# 2. Reproducibility (crucial for MSc reports).
# torch.manual_seed seeds the CPU and CUDA generators; NumPy keeps its own RNG.
SEED = 12
torch.manual_seed(SEED)
np.random.seed(SEED)
# Seeding alone does not make GPU runs repeatable: cuDNN may select
# nondeterministic or benchmark-dependent convolution algorithms. Pin it to
# deterministic kernels (at some speed cost). Some other CUDA ops can still be
# nondeterministic; see torch.use_deterministic_algorithms for a stricter mode.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# The Dataset (Cityscapes Adaptation)

In [None]:
# Define transformations.
# Images are resized to 256x512 to speed up training, converted to float tensors
# in [0, 1] and normalised with the ImageNet mean/std used by torchvision's
# pre-trained backbones.
data_transforms = transforms.Compose([
    transforms.Resize((256, 512)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Masks use NEAREST resizing so label ids are never blended into invalid values,
# and PILToTensor keeps the raw integer ids (no rescaling), giving [1, H, W].
target_transforms = transforms.Compose([
    transforms.Resize((256, 512), interpolation=transforms.InterpolationMode.NEAREST),
    transforms.PILToTensor(),
])

# Load the dataset (expected under ./data/cityscapes).
# mode='fine' selects the finely annotated split, the standard segmentation task.
DATA_ROOT = './data/cityscapes'

print("Loading dataset...")
try:
    train_dataset = Cityscapes(DATA_ROOT, split='train', mode='fine',
                               target_type='semantic', transform=data_transforms,
                               target_transform=target_transforms)

    val_dataset = Cityscapes(DATA_ROOT, split='val', mode='fine',
                             target_type='semantic', transform=data_transforms,
                             target_transform=target_transforms)
except Exception as e:
    print(f"Error loading data: {e}")
    print("Ensure Cityscapes is downloaded and extracted to ./data/cityscapes")
    # Re-raise: carrying on would only fail in later cells with a misleading
    # NameError because train_dataset / val_dataset would never be defined.
    raise

# Dataloaders: shuffle only the training split; validation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=2)
print("Data loaded successfully.")

# Visualize Data (Exploratory Data Analysis)

In [None]:
def visualize_sample(dataset, index=0):
    """Display one dataset image side by side with its ground-truth mask.

    Args:
        dataset: Indexable dataset returning ``(image, mask)``. ``image`` is a
            [3, H, W] float tensor normalised with the ImageNet mean/std;
            ``mask`` is a [1, H, W] tensor of label ids.
        index: Position of the sample to display.
    """
    img, mask = dataset[index]

    # Undo Normalize(mean, std) channel-wise: x * std + mean. Clamp because float
    # round-off can push values slightly outside [0, 1], which imshow would
    # otherwise clip with a warning.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    img = (img * std + mean).clamp(0.0, 1.0)

    fig, (ax_img, ax_mask) = plt.subplots(1, 2, figsize=(12, 6))
    ax_img.imshow(img.permute(1, 2, 0).numpy())  # CHW -> HWC for matplotlib
    ax_img.set_title("Input Image")

    # Mask is [1, H, W]; squeeze to [H, W]. 'nearest' stops matplotlib from
    # blending neighbouring label ids into colours that match no class.
    ax_mask.imshow(mask.squeeze(), cmap='jet', interpolation='nearest')
    ax_mask.set_title("Ground Truth Mask")
    fig.tight_layout()
    plt.show()


visualize_sample(train_dataset, index=10)

# Model Selection & Adaptation

In [None]:
# 1. Load DeepLabV3 with a ResNet-50 backbone.
# Weights.DEFAULT loads the best available pre-trained weights.
model = models.segmentation.deeplabv3_resnet50(weights=DeepLabV3_ResNet50_Weights.DEFAULT)

# 2. Adapt the classifier head(s) to the Cityscapes label set.
# The masks are used as loaded (target_type='semantic'), which presumably gives
# the raw labelIds 0..33, i.e. 34 ids. Training on the 19 evaluation classes
# instead would require remapping labelIds to trainIds (void -> 255) first.
num_classes = 34

# The last layer of each head (index 4) is a 1x1 conv producing class logits;
# swap it for a freshly initialised one with num_classes outputs, keeping the
# head's existing input width. The auxiliary head is only present when the
# model was built with aux_loss, so it is adapted only if it exists.
heads = [model.classifier]
if model.aux_classifier is not None:
    heads.append(model.aux_classifier)
for head in heads:
    head[4] = nn.Conv2d(head[4].in_channels, num_classes, kernel_size=(1, 1), stride=(1, 1))

model = model.to(device)

# Loss Function & Optimizer

In [None]:
# Pixel-wise cross-entropy; any target pixel equal to ignore_index is skipped.
# NOTE(review): 255 is Cityscapes' "ignore" value only once labelIds have been
# remapped to trainIds. Raw labelIds masks (what num_classes = 34 assumes)
# presumably contain no 255, so nothing is ignored here -- confirm.
criterion = nn.CrossEntropyLoss(ignore_index=255)

# Adam with a fixed learning rate (a tunable hyperparameter; SGD is an alternative).
learning_rate = 1e-4
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

# The Training Loop

In [None]:
def train_one_epoch(model, loader, criterion, optimizer, aux_weight=0.0):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: Segmentation model whose forward returns a dict with an ``'out'``
            logits tensor [B, C, H, W] and, optionally, an ``'aux'`` logits tensor.
        loader: Iterable of ``(images, masks)`` batches. Masks may carry a
            singleton channel dimension ([B, 1, H, W]); it is squeezed away.
        criterion: Loss called as ``criterion(logits, targets)``.
        optimizer: Optimizer over ``model``'s parameters.
        aux_weight: Weight of the auxiliary-head loss, added to the main loss
            when the model returns ``'aux'``. The default 0.0 trains only the
            main head (the original behaviour).

    Returns:
        Mean of the per-batch losses, as a Python float.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.train()
    # Put batches wherever the model lives instead of relying on a global.
    device = next(model.parameters()).device
    running_loss = 0.0
    num_batches = 0

    for images, masks in loader:
        images = images.to(device)
        masks = masks.to(device).long().squeeze(1)  # [B, H, W] class indices

        optimizer.zero_grad()

        outputs = model(images)  # segmentation models return a dict of logits
        loss = criterion(outputs['out'], masks)
        if aux_weight and 'aux' in outputs:
            loss = loss + aux_weight * criterion(outputs['aux'], masks)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    # Counting batches (rather than len(loader)) also works for unsized iterables.
    if num_batches == 0:
        raise ValueError("loader yielded no batches")
    return running_loss / num_batches

# Training loop: run num_epochs epochs, recording the mean loss of each.
num_epochs = 1
loss_history = []

print("Starting training...")
for epoch in range(num_epochs):
    epoch_loss = train_one_epoch(model, train_loader, criterion, optimizer)
    loss_history.append(epoch_loss)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")

# Plotting convergence. Markers keep short runs visible (a single epoch is one
# point, which a plain line plot would not draw), and the x-axis uses 1-based
# epoch numbers to match the printed log.
epochs = range(1, len(loss_history) + 1)
plt.plot(epochs, loss_history, marker="o")
plt.title("Training Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.show()

# Evaluation (IoU)

In [None]:
def calculate_iou(pred_mask, true_mask, num_classes, ignore_index=255):
    """Return the mean IoU over the classes that occur in one prediction/target pair.

    Args:
        pred_mask: Integer tensor of predicted class indices (any shape).
        true_mask: Integer tensor of ground-truth labels with the same number of
            elements as ``pred_mask``.
        num_classes: Classes ``0 .. num_classes - 1`` are scored.
        ignore_index: Target value whose pixels are excluded from scoring.

    Returns:
        Mean of intersection / union over every class whose union is non-zero
        (classes absent from both masks are skipped), or ``nan`` when no class
        qualifies, e.g. when every pixel is ignored.
    """
    # reshape (not view) also accepts non-contiguous inputs; long() makes the
    # values usable with bincount whatever the incoming integer dtype.
    pred = pred_mask.reshape(-1).long()
    true = true_mask.reshape(-1).long()

    valid = true != ignore_index
    pred = pred[valid]
    true = true[valid]

    def _per_class_count(values):
        # Occurrences of each class 0..num_classes-1; out-of-range values are dropped.
        in_range = (values >= 0) & (values < num_classes)
        return torch.bincount(values[in_range], minlength=num_classes)

    intersection = _per_class_count(true[pred == true])
    # |P u T| = |P| + |T| - |P n T|, per class.
    union = _per_class_count(pred) + _per_class_count(true) - intersection

    present = union > 0
    if not bool(present.any()):
        return float("nan")
    ious = (intersection[present].double() / union[present].double()).tolist()
    return float(np.mean(ious))

# Evaluate on the validation set with dataset-level mIoU: per-class
# intersections and unions are summed over all validation pixels first, and
# IoU = intersection / union is averaged over classes only at the end.
# Averaging per-image scores instead weights every image equally whatever
# classes it contains, and a single image with no scorable pixel yields NaN,
# which would poison the whole average.
model.eval()
void_label = 255  # same value the loss is told to ignore
total_intersection = torch.zeros(num_classes, dtype=torch.long, device=device)
total_union = torch.zeros(num_classes, dtype=torch.long, device=device)

with torch.no_grad():
    for images, masks in val_loader:
        images = images.to(device)
        masks = masks.to(device).squeeze(1).long()

        outputs = model(images)['out']
        preds = torch.argmax(outputs, dim=1)  # Convert logits to class index

        valid = masks != void_label
        pred_px = preds[valid]
        true_px = masks[valid]
        # Slicing to num_classes drops any label id outside the scored range.
        inter = torch.bincount(true_px[pred_px == true_px], minlength=num_classes)[:num_classes]
        pred_count = torch.bincount(pred_px, minlength=num_classes)[:num_classes]
        true_count = torch.bincount(true_px, minlength=num_classes)[:num_classes]
        total_intersection += inter
        total_union += pred_count + true_count - inter

# Only classes that appear in the predictions or the ground truth are averaged.
present = total_union > 0
if present.any():
    mean_iou = (total_intersection[present].double() / total_union[present].double()).mean().item()
    print(f"Mean IoU on Validation Set: {mean_iou:.4f}")
else:
    print("Mean IoU on Validation Set: undefined (no scorable pixels)")

# Critical Reflection