# EfficientNet-B0 Experimentation on Cityscapes Dataset for Semantic Segmentation

This notebook implements a series of experiments to evaluate and improve the performance of EfficientNet-B0 on the Cityscapes dataset for semantic segmentation.

## Overview

1. **Baseline Experiment**: Train EfficientNet-B0 with segmentation head
2. **Modified Models**:
   - Add CBAM (Convolutional Block Attention Module)
   - Switch to Mish activation function
   - Add DeepLabV3+ segmentation head
3. **Comparative Analysis**: Compare and analyze the results across all models

In [None]:
import os
from google.colab import drive

# Mount Google Drive
drive.mount('/content/drive')

# Absolute path to the datasets directory, so it does not depend on the current working directory
datasets_dir = '/content/drive/MyDrive/NTU-AI6103-DEEP-LEARNING-AND-APPLICATIONS/Group-Assignment/datasets/Cityscapes'

# Check if the directory exists
if os.path.exists(datasets_dir):
    print(f"Datasets directory found: {datasets_dir}")
else:
    print(f"Datasets directory not found: {datasets_dir}")
    print("Please make sure the path is correct and the directory exists in your Google Drive.")

## 1. Environment Setup

First, let's import all necessary libraries for our experiments:

- PyTorch and related libraries for deep learning
- EfficientNet implementation
- Data processing libraries (NumPy, Pandas, etc.)
- Visualization and progress tracking tools

The setup cell also checks CUDA availability so that GPU acceleration is used when available.

In [None]:
# Install required dependencies
%pip install torch torchvision torchaudio
%pip install efficientnet_pytorch
%pip install numpy pandas matplotlib
%pip install tqdm scikit-learn
%pip install jupyter

# For CUDA compatibility check
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU device: {torch.cuda.get_device_name(0)}")

In [None]:
# Import standard libraries
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import time
from tqdm.notebook import tqdm

# PyTorch imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, random_split
import torchvision
from torchvision import transforms, models
from efficientnet_pytorch import EfficientNet

# Set seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Check if CUDA is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

## 2. Data Preparation

### 2.1 Loading Cityscapes Dataset

Here we load the Cityscapes dataset from its original directory structure and build train/validation/test datasets from its official splits (the code below uses every image/label pair it finds and does not subsample). Cityscapes is well suited to segmentation tasks because it provides pixel-level annotations for urban street scenes.
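
The loader relies on the Cityscapes naming scheme, in which every `*_leftImg8bit.png` image has a matching `*_gtFine_labelIds.png` annotation in a parallel directory tree. As a minimal sketch of that pairing (the helper name `label_path_for` and the short example roots are illustrative assumptions, not part of the notebook):

```python
import os

def label_path_for(image_path, images_dir, annotations_dir):
    """Map a leftImg8bit image path to its gtFine labelIds path.

    Hypothetical helper for illustration; it mirrors the string
    substitution used when collecting image/label pairs.
    """
    rel = os.path.relpath(image_path, images_dir)   # <split>/<city>/<name>_leftImg8bit.png
    split_city, img_name = os.path.split(rel)
    img_id = img_name.replace('_leftImg8bit.png', '')
    return os.path.join(annotations_dir, split_city, f"{img_id}_gtFine_labelIds.png")

example = label_path_for(
    os.path.join('leftImg8bit', 'train', 'aachen', 'aachen_000000_000019_leftImg8bit.png'),
    'leftImg8bit',
    'gtFine',
)
print(example)
```

Pairs whose label file is missing can then be skipped with an `os.path.exists` check, as the collection function does.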

In [None]:
# Clone the Cityscapes repository if not already present
!git clone https://github.com/mcordts/cityscapesScripts.git
%pip install -e cityscapesScripts

In [None]:
# Import Cityscapes helper functions
import sys
import os
import glob
import random
from PIL import Image

# Add the cityscapesScripts directory to the Python path
cwd = os.getcwd()
cityscapes_path = os.path.join(cwd, 'cityscapesScripts')
if cityscapes_path not in sys.path:
    sys.path.append(cityscapes_path)

# Now import the modules
from cityscapesscripts.helpers.labels import trainId2label, id2label

# Define paths to dataset directories
cityscapes_root = '/content/drive/MyDrive/NTU-AI6103-DEEP-LEARNING-AND-APPLICATIONS/Group-Assignment/datasets/Cityscapes'
images_dir = os.path.join(cityscapes_root, 'leftImg8bit_trainvaltest', 'leftImg8bit')
annotations_dir = os.path.join(cityscapes_root, 'gtFine_trainvaltest', 'gtFine')

# Print images_dir and annotations_dir to check if they are formed correctly
print(f"Images directory: {images_dir}")
print(f"Annotations directory: {annotations_dir}")

# Function to collect image and label pairs from train, val, and test folders
def collect_dataset_files():
    splits = ['train', 'val', 'test']
    datasets = {'train': [], 'val': [], 'test': []}
    for split in splits:
        split_img_dir = os.path.join(images_dir, split)
        split_label_dir = os.path.join(annotations_dir, split)
        city_dirs = [d for d in os.listdir(split_img_dir) if os.path.isdir(os.path.join(split_img_dir, d))]
        image_paths = []
        label_paths = []
        for city in city_dirs:
            city_img_dir = os.path.join(split_img_dir, city)
            city_img_files = glob.glob(os.path.join(city_img_dir, '*_leftImg8bit.png'))
            for img_path in city_img_files:
                img_name = os.path.basename(img_path)
                img_id = img_name.replace('_leftImg8bit.png', '')
                label_name = f"{img_id}_gtFine_labelIds.png"
                label_path = os.path.join(split_label_dir, city, label_name)
                if os.path.exists(label_path):
                    image_paths.append(img_path)
                    label_paths.append(label_path)
        datasets[split] = (image_paths, label_paths)
    return datasets['train'][0], datasets['train'][1], datasets['val'][0], datasets['val'][1], datasets['test'][0], datasets['test'][1]

# Map Cityscapes IDs to train IDs (0–18, 255 for void)
def map_cityscapes_labels(label):
    label_np = np.array(label, dtype=np.uint8)
    mapped_label = np.full_like(label_np, 255, dtype=np.uint8)
    id_to_trainid = {
        7: 0, 8: 1, 11: 2, 12: 3, 13: 4, 17: 5, 19: 6, 20: 7, 21: 8,
        22: 9, 23: 10, 24: 11, 25: 12, 26: 13, 27: 14, 28: 15, 31: 16,
        32: 17, 33: 18
    }
    for id_, train_id in id_to_trainid.items():
        mapped_label[label_np == id_] = train_id

    # # Debug: Print the number of pixels for each class in the first few images
    # if random.random() < 0.05:  # Only print for ~5% of images to avoid flooding output
    #     unique_values, counts = np.unique(mapped_label, return_counts=True)
    #     print("Label distribution:")
    #     for val, count in zip(unique_values, counts):
    #         class_name = 'void' if val == 255 else trainId2label[val].name
    #         print(f"  Class {val} ({class_name}): {count} pixels")

    return Image.fromarray(mapped_label)

# Define dataset class
class CityscapesDataset(Dataset):
    def __init__(self, image_paths, label_paths, transform=None, target_transform=None):
        self.image_paths = image_paths
        self.label_paths = label_paths
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')
        label_path = self.label_paths[idx]
        label = Image.open(label_path)
        label = map_cityscapes_labels(label)
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
            label = label.squeeze(0).long()
        return image, label

# Collect dataset
print("Collecting dataset files...")
train_image_paths, train_label_paths, val_image_paths, val_label_paths, test_image_paths, test_label_paths = collect_dataset_files()
print(f"Found {len(train_image_paths)} train pairs, {len(val_image_paths)} val pairs, {len(test_image_paths)} test pairs")

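# Define transforms
# Geometric augmentations (flip, rotation) must be applied identically to an image and
# its mask, so they are handled jointly in AugmentedCityscapesDataset below. The
# image-only pipeline keeps just the photometric augmentations, which do not move pixels.
import torchvision.transforms.functional as TF

train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ColorJitter(brightness=0.3, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.GaussianBlur(kernel_size=3),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
test_transform = val_transform

target_transform = transforms.Compose([
    transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.NEAREST),
    transforms.Lambda(lambda x: torch.from_numpy(np.array(x)).long())
])

# Training dataset variant that shares random geometric augmentation between image and mask
class AugmentedCityscapesDataset(CityscapesDataset):
    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert('RGB')
        label = map_cityscapes_labels(Image.open(self.label_paths[idx]))
        if random.random() < 0.5:
            image, label = TF.hflip(image), TF.hflip(label)
        angle = random.uniform(-10, 10)
        image = TF.rotate(image, angle, interpolation=transforms.InterpolationMode.BILINEAR)
        # Nearest-neighbor keeps class IDs intact; corners exposed by the rotation become void (255)
        label = TF.rotate(label, angle, interpolation=transforms.InterpolationMode.NEAREST, fill=255)
        return self.transform(image), self.target_transform(label)

# Create datasets
train_dataset = AugmentedCityscapesDataset(train_image_paths, train_label_paths, transform=train_transform, target_transform=target_transform)
val_dataset = CityscapesDataset(val_image_paths, val_label_paths, transform=val_transform, target_transform=target_transform)
test_dataset = CityscapesDataset(test_image_paths, test_label_paths, transform=test_transform, target_transform=target_transform)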

# Create dataloaders
batch_size = 4
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True)

print(f"Train dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")

# Check label distribution in the training dataset
def check_label_distribution(dataset, num_samples=5):
    print("\nChecking label distribution in dataset...")
    class_counts = np.zeros(20)  # 19 classes + void (255)

    for i in range(min(num_samples, len(dataset))):
        idx = np.random.randint(len(dataset))
        _, label = dataset[idx]
        # If label has channel dimension, remove it
        if label.dim() == 3 and label.shape[0] == 1:
            label = label.squeeze(0)

        unique_values, counts = np.unique(label.numpy(), return_counts=True)
        print(f"Image {i+1}/{num_samples} (idx {idx}) - Unique values: {unique_values}")

        for val, count in zip(unique_values, counts):
            class_idx = 19 if val == 255 else val  # Store void class (255) at index 19
            class_counts[class_idx] += count

    # Print summary
    print("\nClass distribution summary:")
    class_names = [
        'road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'traffic light',
        'traffic sign', 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car',
        'truck', 'bus', 'train', 'motorcycle', 'bicycle', 'void'
    ]

    for i in range(20):
        if class_counts[i] > 0:
            print(f"Class {i if i < 19 else 255} ({class_names[i]}): {class_counts[i]:.0f} pixels")

# Run the check on training dataset
check_label_distribution(train_dataset, num_samples=10)

# Show a sample image from the dataset
def show_sample(dataset, idx=0):
    img, label = dataset[idx]

    # Denormalize the image
    mean = torch.tensor([0.485, 0.456, 0.406])
    std = torch.tensor([0.229, 0.224, 0.225])
    img_denorm = (img * std[:, None, None] + mean[:, None, None]).clamp(0, 1)  # keep values in imshow's [0, 1] range

    # Create a color-mapped version of the label for better visualization
    # Convert label tensor to numpy and ensure it's 2D by squeezing out the channel dimension
    label_np = label.squeeze().numpy()  # Remove the channel dimension (1,224,224) -> (224,224)

    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    plt.imshow(img_denorm.permute(1, 2, 0).numpy())
    plt.title('Image')

    plt.subplot(1, 2, 2)
    plt.imshow(label_np, cmap='viridis')  # Add a colormap for better visualization
    plt.title('Segmentation Mask')
    plt.colorbar()  # Add a colorbar to show the mapping of class IDs to colors
    plt.show()

# Visualize a random sample from the training dataset
show_sample(train_dataset, idx=np.random.randint(len(train_dataset)))

# Also visualize a sample from validation and test to ensure everything looks correct
show_sample(val_dataset, idx=np.random.randint(len(val_dataset)))
show_sample(test_dataset, idx=np.random.randint(len(test_dataset)))

### Setting Up Data Transformations and Loading Dataset

The data-loading cell above configures the preprocessing pipelines for training and for validation/testing:
1. Training transforms include data augmentation (flips, rotations, color jitter)
2. All images are resized to 224x224 pixels to match EfficientNet-B0's input size
3. Images are normalized using ImageNet mean and standard deviation
4. Segmentation masks are also resized to 224x224 but using nearest-neighbor interpolation to preserve label values

We then load the Cityscapes dataset using our custom dataset class that handles both images and their corresponding segmentation masks.
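
To see why masks need nearest-neighbor interpolation, consider what happens at a boundary between two classes. The following plain-Python sketch (a 1-D toy example with hypothetical helper names, not the torchvision implementation) downsamples a row of class IDs by a factor of two in two different ways:

```python
def downsample_nearest(row):
    """Keep every second label; output values are always existing class IDs."""
    return row[::2]

def downsample_average(row):
    """Average neighbouring pairs, as a smoothing interpolation would."""
    return [(row[i] + row[i + 1]) / 2 for i in range(0, len(row) - 1, 2)]

# A row crossing from road (trainId 0) to bicycle (trainId 18)
row = [0, 0, 0, 18, 18, 18]

nearest = downsample_nearest(row)
averaged = downsample_average(row)
print(nearest)   # every value is still one of the original class IDs
print(averaged)  # the boundary value becomes 9.0, which is the trainId of 'terrain'
```

Averaging invents a class that is not present in the scene, which is why label maps are resized with nearest-neighbor sampling while images can use smoother interpolation.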

## 3. Baseline Model: EfficientNet-B0 for Segmentation

### Defining the Baseline EfficientNet-B0 Segmentation Model

This cell implements our baseline segmentation model by:
1. Creating a custom EfficientNetB0Segmentation class that uses the pre-trained model as an encoder
2. Adding a decoder network that upsamples features to produce full-resolution segmentation masks
3. Setting up the model to output predictions for 19 classes (Cityscapes semantic classes) at each pixel
4. Initializing the model and moving it to the appropriate device (GPU if available)
5. Setting up the loss function (Cross-Entropy for segmentation), optimizer (SGD with momentum), and learning rate scheduler
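
As a quick sanity check on the decoder's spatial sizes: a transposed convolution with kernel size k, stride s, and padding p maps a side length n to (n - 1)·s - 2p + k when no output padding or dilation is used. The sketch below assumes a 224×224 input and an overall encoder stride of 32, and traces four such layers with k=4, s=2, p=1:

```python
def conv_transpose_out(n, kernel_size=4, stride=2, padding=1):
    """Output side length of a transposed convolution (no output_padding, dilation=1)."""
    return (n - 1) * stride - 2 * padding + kernel_size

side = 224 // 32          # encoder feature map: 7x7 for a 224x224 input
sizes = [side]
for _ in range(4):        # four stride-2 transposed convolutions
    side = conv_transpose_out(side)
    sizes.append(side)

print(sizes)  # [7, 14, 28, 56, 112]
```

Under these assumptions the decoder ends at half the input resolution, and the bilinear interpolation at the end of `forward` supplies the remaining 2x upsampling.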

In [None]:
class EfficientNetB0Segmentation(nn.Module):
    def __init__(self, num_classes=19):  # Cityscapes has 19 classes with trainId
        super(EfficientNetB0Segmentation, self).__init__()
        # Load the pre-trained EfficientNet-B0 model as the encoder
        self.encoder = EfficientNet.from_pretrained('efficientnet-b0')
        # Get the number of features from the last layer
        self.encoder_features = self.encoder._fc.in_features

        # Remove the classification head
        self.encoder._fc = nn.Identity()

        # Create a simple decoder for segmentation
        self.decoder = nn.Sequential(
            # Four stride-2 transposed convolutions: 16x upsampling of the H/32 encoder features
            nn.ConvTranspose2d(self.encoder_features, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, num_classes, kernel_size=3, padding=1)
        )

    def forward(self, x):
        # Extract features from the encoder
        features = self.encoder.extract_features(x)  # Shape: [B, C, H/32, W/32]

        # Pass through decoder to get segmentation map
        segmentation_map = self.decoder(features)  # Shape: [B, num_classes, H/2, W/2] after 16x upsampling

        # Bilinearly upsample the remaining 2x so the output size matches the input size
        if segmentation_map.shape[-2:] != x.shape[-2:]:
            segmentation_map = F.interpolate(segmentation_map, size=x.shape[-2:], mode='bilinear', align_corners=True)

        return segmentation_map

# Initialize the baseline segmentation model
baseline_model = EfficientNetB0Segmentation().to(device)

# Define loss function and optimizer for segmentation
# Ignore index 255 which is the 'ignored' label in Cityscapes
criterion = nn.CrossEntropyLoss(ignore_index=255)
optimizer = optim.SGD(baseline_model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.1)

### Implementing Training and Evaluation Functions

This cell defines two essential functions for model training and evaluation in a segmentation task:
1. `train_one_epoch`: Handles a complete training cycle for semantic segmentation, including:
   - Forward and backward passes through the network
   - Gradient computation and parameter updates
   - Loss and segmentation metrics tracking (mean IoU, pixel accuracy)
2. `evaluate`: Performs model evaluation on validation or test data:
   - Forward passes without gradient computation (using `torch.no_grad()`)
   - Computes segmentation metrics (mean IoU, pixel accuracy)
   - Prints per-class IoU alongside the aggregate metrics
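
For reference, per-class IoU is TP / (TP + FP + FN), accumulated over all pixels whose label is not the ignore value. A minimal pure-Python sketch on flattened toy arrays follows; the function name `mean_iou` and the toy data are illustrative assumptions, whereas the notebook's functions operate on tensors:

```python
def mean_iou(preds, labels, num_classes, ignore_index=255):
    """Mean IoU over classes that occur in predictions or labels.

    Pixels labelled `ignore_index` are excluded entirely, so predictions
    made on them count neither as hits nor as false positives.
    """
    intersection = [0] * num_classes
    union = [0] * num_classes
    for p, t in zip(preds, labels):
        if t == ignore_index:
            continue
        if p == t:
            intersection[p] += 1
            union[p] += 1
        else:
            union[p] += 1   # false positive for the predicted class
            union[t] += 1   # false negative for the true class
    ious = [i / u for i, u in zip(intersection, union) if u > 0]
    return sum(ious) / len(ious) if ious else 0.0

labels = [0, 0, 1, 1, 255, 255]
preds = [0, 1, 1, 1, 0, 1]
result = mean_iou(preds, labels, num_classes=3)
print(result)  # class 0: 1/2, class 1: 2/3, class 2 absent
```

Accumulating intersections and unions over the whole epoch before dividing, rather than averaging per-batch IoU values, avoids giving undue weight to batches in which a class barely appears.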

In [None]:
def train_one_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    running_pixel_acc = 0.0
    running_iou = 0.0
    processed_data = 0
    num_classes = 19  # Cityscapes has 19 classes with trainId (0-18)

    # Initialize tensors to track intersection and union for each class
    class_intersection = torch.zeros(num_classes).to(device)
    class_union = torch.zeros(num_classes).to(device)

    # Debug counters
    valid_label_count = 0
    batch_count = 0

    for inputs, labels in tqdm(dataloader, desc="Training"):
        batch_count += 1
        inputs = inputs.to(device)

        # Label shape debug - before any processing
        if batch_count == 1:
            print(f"Original label shape: {labels.shape}, dtype: {labels.dtype}")
            print(f"Label min: {labels.min()}, max: {labels.max()}, unique: {torch.unique(labels)}")

        # Remove channel dimension for CrossEntropyLoss [B, H, W]
        # If labels have shape [B, 1, H, W], squeeze them to [B, H, W]
        if labels.dim() == 4 and labels.shape[1] == 1:
            labels = labels.squeeze(1)

        labels = labels.long().to(device)  # Ensure labels are of type Long

        # Label shape debug - after processing
        if batch_count == 1:
            print(f"Processed label shape: {labels.shape}, dtype: {labels.dtype}")
            print(f"Label min: {labels.min()}, max: {labels.max()}")
            print(f"Label unique values: {torch.unique(labels)}")

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)  # Shape: [B, num_classes, H, W]

        # Debug first batch outputs
        if batch_count == 1:
            print(f"Model output shape: {outputs.shape}, dtype: {outputs.dtype}")

        # Compute loss - CrossEntropyLoss expects (N,C,d1,d2...) for input and (N,d1,d2...) for target
        loss = criterion(outputs, labels)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        # Get predictions
        _, preds = torch.max(outputs, 1)  # Shape: [B, H, W]

        # Calculate pixel accuracy (ignoring void pixels with value 255)
        valid_pixels = (labels != 255)  # Create mask for valid pixels
        valid_label_count += torch.sum(valid_pixels).item()
        correct_pixels = torch.sum((preds == labels) & valid_pixels).item()
        total_valid_pixels = torch.sum(valid_pixels).item()
        pixel_acc = correct_pixels / (total_valid_pixels + 1e-8)

        # Calculate IoU (Intersection over Union) for each class
        for cls in range(num_classes):
            # For each class, find valid (non-void) pixels where prediction is this class
            pred_inclass = (preds == cls) & valid_pixels
            # Find pixels where ground truth is this class
            target_inclass = (labels == cls)
            # Calculate intersection (pixels that are this class in both pred and target)
            intersection = torch.sum(pred_inclass & target_inclass).item()
            # Calculate union (pixels that are this class in either pred or target)
            union = torch.sum(pred_inclass | target_inclass).item()
            # Accumulate for epoch-level metrics
            class_intersection[cls] += intersection
            class_union[cls] += union

        # Calculate batch mean IoU (for tracking only)
        batch_intersection = torch.zeros(num_classes).to(device)
        batch_union = torch.zeros(num_classes).to(device)
        classes_present = []

        for cls in range(num_classes):
            pred_cls = (preds == cls) & valid_pixels  # exclude predictions on void pixels
            target_cls = (labels == cls)
            batch_intersection[cls] = torch.sum(pred_cls & target_cls).item()
            batch_union[cls] = torch.sum(pred_cls | target_cls).item()
            if torch.sum(target_cls) > 0:
                classes_present.append(cls)

        # Avoid division by zero with small epsilon
        batch_iou = batch_intersection / (batch_union + 1e-8)

        # Only consider classes that are present in this batch
        valid_classes = (batch_union > 0)
        if torch.sum(valid_classes) > 0:
            iou = torch.mean(batch_iou[valid_classes]).item()
        else:
            iou = 0.0

        # Update statistics
        running_loss += loss.item() * inputs.size(0)
        running_pixel_acc += pixel_acc * inputs.size(0)
        running_iou += iou * inputs.size(0)
        processed_data += inputs.size(0)

        # Debug for first few batches
        if batch_count <= 3:
            print(f"Batch {batch_count} - Classes present: {classes_present}")
            print(f"Batch {batch_count} - pixel_acc: {pixel_acc:.4f}, IoU: {iou:.4f}")

    # Print final debug info
    print(f"Total valid label pixels: {valid_label_count}")

    # Calculate epoch-level metrics
    train_loss = running_loss / processed_data
    train_pixel_acc = running_pixel_acc / processed_data

    # Calculate per-class IoU for the entire epoch (more accurate than batch-wise)
    class_iou = class_intersection / (class_union + 1e-8)
    # Only consider classes that actually appeared in the dataset
    valid_classes = (class_union > 0)

    # Map trainId to class names for better interpretability
    class_names = [
        'road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'traffic light',
        'traffic sign', 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car',
        'truck', 'bus', 'train', 'motorcycle', 'bicycle'
    ]

    # Print per-class IoU with class names
    print("\nPer-class IoU:")
    for cls in range(num_classes):
        if class_union[cls] > 0:
            print(f"{class_names[cls]}: {class_iou[cls]:.4f}")

    if torch.sum(valid_classes) > 0:
        mean_iou = torch.mean(class_iou[valid_classes]).item()
    else:
        mean_iou = 0.0

    # Print summary
    print(f"Valid classes: {torch.sum(valid_classes).item()} out of {num_classes}")
    print(f"Mean IoU: {mean_iou:.4f}, Train loss: {train_loss:.4f}, Train pixel acc: {train_pixel_acc:.4f}")

    # Return the relevant metrics
    return train_loss, train_pixel_acc, mean_iou

In [None]:
def evaluate(model, dataloader, criterion, device):
    model.eval()
    running_loss = 0.0
    running_pixel_acc = 0.0
    running_iou = 0.0
    processed_data = 0
    num_classes = 19  # Cityscapes has 19 classes with trainId

    # Initialize tensors to track intersection and union for each class
    class_intersection = torch.zeros(num_classes).to(device)
    class_union = torch.zeros(num_classes).to(device)

    # Map trainId to class names for better interpretability
    class_names = [
        'road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'traffic light',
        'traffic sign', 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car',
        'truck', 'bus', 'train', 'motorcycle', 'bicycle'
    ]

    batch_count = 0

    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc="Evaluating"):
            batch_count += 1
            inputs = inputs.to(device)

            # Remove channel dimension for CrossEntropyLoss [B, H, W]
            # If labels have shape [B, 1, H, W], squeeze them to [B, H, W]
            if labels.dim() == 4 and labels.shape[1] == 1:
                labels = labels.squeeze(1)

            labels = labels.long().to(device)  # Ensure labels are of type Long

            # Debug first batch labels and shapes
            if batch_count == 1:
                print(f"Label shape: {labels.shape}, dtype: {labels.dtype}")
                print(f"Label min: {labels.min()}, max: {labels.max()}")
                print(f"Label unique values: {torch.unique(labels)}")

            # Forward pass
            outputs = model(inputs)

            # Debug model output shape
            if batch_count == 1:
                print(f"Model output shape: {outputs.shape}, dtype: {outputs.dtype}")

            # Compute loss
            loss = criterion(outputs, labels)

            # Get predictions
            _, preds = torch.max(outputs, 1)  # Shape: [B, H, W]

            # Calculate pixel accuracy (ignoring void pixels with value 255)
            valid_pixels = (labels != 255)  # Create mask for valid pixels
            correct_pixels = torch.sum((preds == labels) & valid_pixels).item()
            total_valid_pixels = torch.sum(valid_pixels).item()
            pixel_acc = correct_pixels / (total_valid_pixels + 1e-8)

            # Calculate IoU (Intersection over Union) for each class
            for cls in range(num_classes):
                # For each class, find valid (non-void) pixels where prediction is this class
                pred_inclass = (preds == cls) & valid_pixels
                # Find pixels where ground truth is this class
                target_inclass = (labels == cls)
                # Calculate intersection (pixels that are this class in both pred and target)
                intersection = torch.sum(pred_inclass & target_inclass).item()
                # Calculate union (pixels that are this class in either pred or target)
                union = torch.sum(pred_inclass | target_inclass).item()
                # Accumulate for epoch-level metrics
                class_intersection[cls] += intersection
                class_union[cls] += union

            # Calculate batch mean IoU (for tracking only)
            batch_iou = torch.zeros(num_classes).to(device)
            valid_classes = torch.zeros(num_classes, dtype=torch.bool).to(device)

            for cls in range(num_classes):
                pred_cls = (preds == cls) & valid_pixels  # exclude predictions on void pixels
                target_cls = (labels == cls)
                intersection = torch.sum(pred_cls & target_cls).item()
                union = torch.sum(pred_cls | target_cls).item()
                if union > 0:
                    batch_iou[cls] = intersection / union
                    valid_classes[cls] = True

            # Only consider classes that are present in this batch
            if torch.sum(valid_classes) > 0:
                iou = torch.mean(batch_iou[valid_classes]).item()
            else:
                iou = 0.0

            # Update statistics
            running_loss += loss.item() * inputs.size(0)
            running_pixel_acc += pixel_acc * inputs.size(0)
            running_iou += iou * inputs.size(0)
            processed_data += inputs.size(0)

    # Calculate epoch-level metrics
    eval_loss = running_loss / processed_data
    eval_pixel_acc = running_pixel_acc / processed_data

    # Calculate per-class IoU for the entire evaluation set
    class_iou = class_intersection / (class_union + 1e-8)
    # Only consider classes that actually appeared in the dataset
    valid_classes = (class_union > 0)

    # Print per-class IoU with class names
    print("\nPer-class IoU:")
    for cls in range(num_classes):
        if class_union[cls] > 0:
            print(f"{class_names[cls]}: {class_iou[cls]:.4f}")

    if torch.sum(valid_classes) > 0:
        mean_iou = torch.mean(class_iou[valid_classes]).item()
    else:
        mean_iou = 0.0

    print(f"Valid classes: {torch.sum(valid_classes).item()} out of {num_classes}")
    print(f"Mean IoU: {mean_iou:.4f}, Eval loss: {eval_loss:.4f}, Pixel acc: {eval_pixel_acc:.4f}")

    return eval_loss, eval_pixel_acc, mean_iou

### Complete Model Training Pipeline

This cell defines and executes the full training pipeline for semantic segmentation:
1. Implements the `train_model` function that orchestrates training over multiple epochs
   - Tracks training and validation metrics in a history dictionary
   - Keeps a copy of the best model weights based on validation IoU and restores them after the last epoch (training itself is not stopped early)
   - Adjusts learning rate using the scheduler based on validation loss
2. Imports the `copy` module to maintain a copy of the best model weights
3. Trains the baseline segmentation model (`train_model` defaults to 10 epochs, but the call below passes `num_epochs=1`)
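
The interplay between best-model tracking and plateau-based learning-rate decay can be illustrated without a model. The sketch below is a simplified re-implementation of the idea behind `ReduceLROnPlateau` (it ignores the threshold and cooldown options of the real scheduler), driven by made-up validation numbers. All names and values are illustrative:

```python
def simulate(val_losses, val_ious, lr=0.01, patience=3, factor=0.1):
    """Track the best epoch by IoU and decay lr when the loss plateaus."""
    best_iou, best_epoch = 0.0, None
    best_loss, bad_epochs = float('inf'), 0
    for epoch, (loss, iou) in enumerate(zip(val_losses, val_ious), start=1):
        # Plateau logic: count epochs without a new minimum loss
        if loss < best_loss:
            best_loss, bad_epochs = loss, 0
        else:
            bad_epochs += 1
        if bad_epochs > patience:
            lr *= factor
            bad_epochs = 0
        # Checkpoint logic: remember the epoch with the highest IoU
        if iou > best_iou:
            best_iou, best_epoch = iou, epoch
    return best_epoch, best_iou, lr

# Made-up numbers purely for illustration
losses = [1.00, 0.90, 0.95, 0.95, 0.95, 0.95]
ious = [0.20, 0.30, 0.28, 0.31, 0.29, 0.30]
best_epoch, best_iou, final_lr = simulate(losses, ious)
print(best_epoch, best_iou, final_lr)
```

In this toy run the best IoU occurs at epoch 4 even though the lowest loss was at epoch 2, which shows why the checkpoint criterion (IoU) and the scheduler criterion (loss) can disagree.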

In [None]:
# Training loop for baseline model
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs=10):
    history = {
        'train_loss': [],
        'train_pixel_acc': [],
        'train_iou': [],
        'val_loss': [],
        'val_pixel_acc': [],
        'val_iou': []
    }

    best_model_wts = copy.deepcopy(model.state_dict())
    best_iou = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)

        # Train phase
        train_loss, train_pixel_acc, train_iou = train_one_epoch(model, train_loader, criterion, optimizer, device)
        print(f'Train Loss: {train_loss:.4f} Pixel Acc: {train_pixel_acc:.4f} IoU: {train_iou:.4f}')

        # Validation phase
        val_loss, val_pixel_acc, val_iou = evaluate(model, val_loader, criterion, device)
        print(f'Val Loss: {val_loss:.4f} Pixel Acc: {val_pixel_acc:.4f} IoU: {val_iou:.4f}')

        # Update learning rate
        scheduler.step(val_loss)

        # Deep copy the model if it's the best
        if val_iou > best_iou:
            best_iou = val_iou
            best_model_wts = copy.deepcopy(model.state_dict())
            print(f'New best model with IoU: {best_iou:.4f}')

        # Update history
        history['train_loss'].append(train_loss)
        history['train_pixel_acc'].append(train_pixel_acc)
        history['train_iou'].append(train_iou)
        history['val_loss'].append(val_loss)
        history['val_pixel_acc'].append(val_pixel_acc)
        history['val_iou'].append(val_iou)

        print()

    # Load best model weights
    model.load_state_dict(best_model_wts)
    return model, history

import copy

# Train the baseline segmentation model
baseline_model_trained, baseline_history = train_model(
    baseline_model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    scheduler,
    num_epochs=1
)

<!-- ### Evaluating the Baseline Model on Test Set

This cell evaluates the trained baseline segmentation model on the unseen test data:
1. Computes test loss, pixel accuracy, and mean IoU using the previously defined `evaluate` function
2. Prints the results to compare with later model variants
3. These segmentation-specific metrics provide a comprehensive assessment of how well the model performs at the pixel level -->

**The test set is skipped for now.**

In [None]:
# # Evaluate the baseline model on test set
# baseline_test_loss, baseline_test_pixel_acc, baseline_test_iou = evaluate(baseline_model_trained, test_loader, criterion, device)
# print(f'Baseline Model - Test Loss: {baseline_test_loss:.4f} Pixel Acc: {baseline_test_pixel_acc:.4f} IoU: {baseline_test_iou:.4f}')

### Visualizing the Training Results

This cell defines and uses a function to visualize training progress for semantic segmentation:
1. Creates the `plot_training_history` function that generates three plots:
   - Training and validation loss curves
   - Training and validation pixel accuracy curves
   - Training and validation mean IoU curves
2. Visualizes the baseline model's training history to analyze convergence and potential overfitting

In [None]:
# Visualize the training history
def plot_training_history(history, title):
    epochs = range(1, len(history['train_loss'])+1)

    plt.figure(figsize=(18, 5))

    plt.subplot(1, 3, 1)
    plt.plot(epochs, history['train_loss'], 'bo-', label='Training Loss')
    plt.plot(epochs, history['val_loss'], 'ro-', label='Validation Loss')
    plt.title(f'{title} - Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()

    plt.subplot(1, 3, 2)
    plt.plot(epochs, history['train_pixel_acc'], 'bo-', label='Training Pixel Accuracy')
    plt.plot(epochs, history['val_pixel_acc'], 'ro-', label='Validation Pixel Accuracy')
    plt.title(f'{title} - Pixel Accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Pixel Accuracy')
    plt.legend()

    plt.subplot(1, 3, 3)
    plt.plot(epochs, history['train_iou'], 'bo-', label='Training IoU')
    plt.plot(epochs, history['val_iou'], 'ro-', label='Validation IoU')
    plt.title(f'{title} - Mean IoU')
    plt.xlabel('Epochs')
    plt.ylabel('Mean IoU')
    plt.legend()

    plt.tight_layout()
    plt.show()

# Visualize baseline model training history
plot_training_history(baseline_history, 'Baseline EfficientNet-B0 Segmentation')

## 4. Modified Models

### 4.1 EfficientNet-B0 with CBAM (Convolutional Block Attention Module)

CBAM enhances the representational power by focusing on important features and suppressing unnecessary ones. For segmentation tasks, this attention mechanism is particularly helpful as it allows the model to focus on relevant spatial regions and feature channels, leading to more accurate pixel-wise predictions.

### Implementing the CBAM Attention Module

This cell implements the Convolutional Block Attention Module (CBAM) and integrates it with EfficientNet-B0 for segmentation:
1. Defines the `ChannelAttention` class that focuses on important channels
2. Defines the `SpatialAttention` class that emphasizes informative regions
3. Combines both in the `CBAM` class
4. Creates an `EfficientNetB0WithCBAM` class that incorporates CBAM into the segmentation model architecture
5. Implements a decoder structure to convert encoded features to segmentation masks
6. Initializes the model and sets up its optimizer and scheduler
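
As a rough sense of scale, the number of weights that CBAM adds can be worked out by hand from the layer shapes used in this section. The sketch below assumes the 1280-channel feature map produced by EfficientNet-B0 and the default `ratio=16` and `kernel_size=7`; the helper function names are illustrative and not part of the notebook.

```python
def channel_attention_params(in_planes: int, ratio: int = 16) -> int:
    """Weights in the shared two-layer 1x1-conv MLP (both convs are bias-free)."""
    hidden = in_planes // ratio
    return in_planes * hidden + hidden * in_planes


def spatial_attention_params(kernel_size: int = 7) -> int:
    """Weights in the single bias-free conv mapping 2 pooled maps to 1 attention map."""
    return 2 * 1 * kernel_size * kernel_size


in_planes = 1280  # assumed channel count of the EfficientNet-B0 feature map
channel = channel_attention_params(in_planes)
spatial = spatial_attention_params()
print(f"channel attention: {channel:,} weights")
print(f"spatial attention: {spatial:,} weights")
print(f"CBAM total:        {channel + spatial:,} weights")
```

Almost all of the added weights sit in the channel-attention MLP, so `ratio` is the main knob for trading capacity against cost.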

In [None]:
# Implementing CBAM
class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.fc = nn.Sequential(
            nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return self.sigmoid(out)

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        padding = kernel_size // 2
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        concat = torch.cat([avg_out, max_out], dim=1)
        out = self.conv(concat)
        return self.sigmoid(out)

class CBAM(nn.Module):
    def __init__(self, in_planes, ratio=16, kernel_size=7):
        super(CBAM, self).__init__()
        self.channel_att = ChannelAttention(in_planes, ratio)
        self.spatial_att = SpatialAttention(kernel_size)

    def forward(self, x):
        x = x * self.channel_att(x)
        x = x * self.spatial_att(x)
        return x

# EfficientNet-B0 with CBAM attention for segmentation
class EfficientNetB0WithCBAM(nn.Module):
    def __init__(self, num_classes=19):  # 19 classes for Cityscapes segmentation
        super(EfficientNetB0WithCBAM, self).__init__()
        # Load the pre-trained EfficientNet-B0 model as encoder
        self.encoder = EfficientNet.from_pretrained('efficientnet-b0')
        in_features = self.encoder._fc.in_features

        # Remove the classification head from the encoder
        self.encoder._fc = nn.Identity()

        # Add CBAM at the end of feature extraction
        self.cbam = CBAM(in_features)

        # Create decoder for segmentation (similar to baseline but with CBAM in between)
        self.decoder = nn.Sequential(
            # Four stride-2 transposed convolutions upsample the stride-32 features by 16x (to H/2, W/2)
            nn.ConvTranspose2d(in_features, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, num_classes, kernel_size=3, padding=1)
        )

    def forward(self, x):
        # Extract features from the encoder
        features = self.encoder.extract_features(x)  # Shape: [B, C, H/32, W/32]

        # Apply CBAM attention
        features_with_attention = self.cbam(features)

        # Pass through decoder to get segmentation map
        segmentation_map = self.decoder(features_with_attention)  # Shape: [B, num_classes, H/2, W/2]

        # Ensure output size matches input size
        if segmentation_map.shape[-2:] != x.shape[-2:]:
            segmentation_map = F.interpolate(segmentation_map, size=x.shape[-2:], mode='bilinear', align_corners=True)

        return segmentation_map

# Initialize the CBAM model
cbam_model = EfficientNetB0WithCBAM().to(device)

# Define loss function, optimizer and scheduler for the CBAM segmentation model
cbam_criterion = nn.CrossEntropyLoss(ignore_index=255)  # Ignore index 255 which is the 'ignored' label in Cityscapes
cbam_optimizer = optim.SGD(cbam_model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
cbam_scheduler = optim.lr_scheduler.ReduceLROnPlateau(cbam_optimizer, 'min', patience=3, factor=0.1)
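
The decoder's output resolution follows from the transposed-convolution size formula. A minimal sketch of that arithmetic, assuming a hypothetical 512-pixel input side and the `kernel_size=4, stride=2, padding=1` settings used in the decoder (`conv_transpose_out` is an illustrative helper, not a PyTorch function):

```python
def conv_transpose_out(size: int, kernel_size: int = 4, stride: int = 2,
                       padding: int = 1, dilation: int = 1,
                       output_padding: int = 0) -> int:
    """Output length of a transposed convolution along one spatial dimension."""
    return ((size - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)


input_size = 512          # hypothetical input height/width
size = input_size // 32   # encoder output at stride 32
sizes = [size]
for _ in range(4):        # the decoder stacks four transposed convolutions
    size = conv_transpose_out(size)
    sizes.append(size)
print(sizes)  # [16, 32, 64, 128, 256]
```

Each layer doubles the spatial size, so four layers recover a factor of 16 from the stride-32 encoder output. The remaining factor of 2 is covered by the `F.interpolate` call in `forward`.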

### Training and Evaluating the CBAM Segmentation Model

Here we train and evaluate the EfficientNet-B0 model enhanced with CBAM for semantic segmentation:
1. Train the model using the same training function as the baseline (the cell below sets `num_epochs=1`, matching the baseline run)
2. Track segmentation metrics (mean IoU, pixel accuracy) during training
3. Visualize the training history to compare with the baseline segmentation model
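
The metrics themselves are computed by the `evaluate` and `train_model` functions defined earlier in the notebook. As a reminder of what they measure, here is a generic, pure-Python sketch of pixel accuracy and mean IoU with the Cityscapes ignore label (255). It is an illustration under the assumption that ignored pixels are excluded from both metrics, not a copy of the notebook's implementation, and `segmentation_metrics` is a hypothetical name.

```python
IGNORE_INDEX = 255


def segmentation_metrics(preds, targets, num_classes):
    """Return (pixel_accuracy, mean_iou) for flat sequences of class ids."""
    correct = 0
    valid = 0
    intersection = [0] * num_classes
    union = [0] * num_classes
    for p, t in zip(preds, targets):
        if t == IGNORE_INDEX:
            continue  # ignored pixels count toward neither metric
        valid += 1
        if p == t:
            correct += 1
            intersection[t] += 1
            union[t] += 1
        else:
            union[p] += 1
            union[t] += 1
    # Average only over classes that appear in the predictions or the targets.
    ious = [i / u for i, u in zip(intersection, union) if u > 0]
    pixel_acc = correct / valid if valid else 0.0
    mean_iou = sum(ious) / len(ious) if ious else 0.0
    return pixel_acc, mean_iou


# Toy example with 3 classes; the last pixel is unlabeled.
preds = [0, 0, 1, 1, 2, 2, 0]
targets = [0, 1, 1, 1, 2, 0, 255]
acc, miou = segmentation_metrics(preds, targets, num_classes=3)
print(f"pixel acc = {acc:.3f}, mean IoU = {miou:.3f}")
```

Pixel accuracy is dominated by large classes such as road and sky, whereas mean IoU weights every class equally, which is why both are tracked.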

In [None]:
# Train the CBAM segmentation model
cbam_model_trained, cbam_history = train_model(
    cbam_model,
    train_loader,
    val_loader,
    cbam_criterion,
    cbam_optimizer,
    cbam_scheduler,
    num_epochs=1
)

In [None]:
# # Evaluate CBAM model on test set
# cbam_test_loss, cbam_test_pixel_acc, cbam_test_iou = evaluate(cbam_model_trained, test_loader, cbam_criterion, device)
# print(f'CBAM Model - Test Loss: {cbam_test_loss:.4f} Pixel Acc: {cbam_test_pixel_acc:.4f} IoU: {cbam_test_iou:.4f}')

# Visualize CBAM model training history
plot_training_history(cbam_history, 'EfficientNet-B0 with CBAM for Segmentation')

### Detailed Training History Analysis for CBAM vs Baseline Segmentation

This cell creates a comprehensive DataFrame containing the epoch-by-epoch segmentation metrics for both models:
1. Collects per-epoch training and validation losses
2. Collects per-epoch training and validation IoU and pixel accuracy values
3. Organizes data into a DataFrame for detailed analysis

This information enables us to pinpoint exactly when and how the CBAM model's segmentation performance diverges from the baseline.

In [None]:
# Log epoch-wise training metrics for comparative analysis
epochs = list(range(1, len(baseline_history['train_loss'])+1))

train_data = {
    'Epoch': epochs,
    'Baseline Train Loss': baseline_history['train_loss'],
    'Baseline Val Loss': baseline_history['val_loss'],
    'CBAM Train Loss': cbam_history['train_loss'],
    'CBAM Val Loss': cbam_history['val_loss'],
    'Baseline Train IoU': baseline_history['train_iou'],
    'Baseline Val IoU': baseline_history['val_iou'],
    'CBAM Train IoU': cbam_history['train_iou'],
    'CBAM Val IoU': cbam_history['val_iou'],
    'Baseline Train Pixel Acc': baseline_history['train_pixel_acc'],
    'Baseline Val Pixel Acc': baseline_history['val_pixel_acc'],
    'CBAM Train Pixel Acc': cbam_history['train_pixel_acc'],
    'CBAM Val Pixel Acc': cbam_history['val_pixel_acc']
}

training_df = pd.DataFrame(train_data)
print("Training History Comparison for Segmentation:")
display(training_df)

### 4.2 EfficientNet-B0 with Mish Activation Function for Segmentation

Mish is a smooth, self-regularized, non-monotonic activation function, defined as `x * tanh(softplus(x))`, that has been reported to outperform ReLU and its variants on a range of tasks. For semantic segmentation, its smoothness and small negative outputs may improve gradient flow and, in turn, pixel-level feature representations.
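
To make the shape of the function concrete, the following pure-Python sketch evaluates Mish, `x * tanh(softplus(x))`, at a few points and compares it with ReLU. It uses only the standard library; the scalar helpers are for illustration and are separate from the `nn.Module` version used in the model.

```python
import math


def softplus(x: float) -> float:
    # Numerically stable form of log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def mish(x: float) -> float:
    return x * math.tanh(softplus(x))


def relu(x: float) -> float:
    return max(x, 0.0)


for x in (-3.0, -1.0, -0.5, 0.0, 1.0, 3.0):
    print(f"x = {x:+.1f}   mish = {mish(x):+.4f}   relu = {relu(x):+.4f}")
```

Unlike ReLU, Mish lets small negative values through: it dips to a minimum around x ≈ -1.2 and then returns toward zero for more negative inputs, which is what "non-monotonic" refers to. For large positive inputs it approaches the identity.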

### Implementing the Mish Activation Function for Segmentation

This cell implements the Mish activation function and integrates it with EfficientNet-B0 for segmentation:
1. Defines the `Mish` activation class (formula: x * tanh(softplus(x)))
2. Creates an `EfficientNetB0WithMish` class that replaces all ReLU activations with Mish
3. Implements a recursive function to replace activations throughout the model
4. Adds a segmentation decoder to produce pixel-wise predictions
5. Initializes the model and sets up its optimizer and scheduler

In [None]:
# Implementing Mish activation
class Mish(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x * torch.tanh(F.softplus(x))

# EfficientNet-B0 with Mish activations for segmentation
class EfficientNetB0WithMish(nn.Module):
    def __init__(self, num_classes=19):  # 19 classes for Cityscapes segmentation
        super(EfficientNetB0WithMish, self).__init__()
        # Load the pre-trained EfficientNet-B0 model as encoder
        self.efficient_net = EfficientNet.from_pretrained('efficientnet-b0')

        # Get the correct number of features from the encoder
        self.in_features = self.efficient_net._fc.in_features  # This should be 1280 for EfficientNet-B0

        # Replace all activation functions with Mish
        self._replace_relu_with_mish(self.efficient_net)

        # Decoder structure with correct channel sizes
        self.decoder = nn.ModuleList([
            nn.Sequential(
                nn.ConvTranspose2d(self.in_features, 256, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(256),
                Mish(),
            ),
            nn.Sequential(
                nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(128),
                Mish(),
            ),
            nn.Sequential(
                nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(64),
                Mish(),
            ),
            nn.Sequential(
                nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(32),
                Mish(),
            )
        ])

        # Final segmentation head
        self.segmentation_head = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            Mish(),
            nn.Conv2d(32, num_classes, kernel_size=1)
        )

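    def _replace_relu_with_mish(self, model):
        # efficientnet_pytorch builds its backbone with Swish/MemoryEfficientSwish
        # rather than nn.ReLU, so those classes are matched as well; checking for
        # nn.ReLU alone would leave the backbone activations unchanged.
        from efficientnet_pytorch.utils import MemoryEfficientSwish, Swish
        for name, module in model.named_children():
            if isinstance(module, (nn.ReLU, Swish, MemoryEfficientSwish)):
                setattr(model, name, Mish())
            else:
                self._replace_relu_with_mish(module)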

    def forward(self, x):
        # Original input size for later upsampling
        input_size = x.size()[2:]

        # Extract features from the EfficientNet backbone
        features = self.efficient_net.extract_features(x)  # Output shape: [B, 1280, H/32, W/32]

        # Apply decoder blocks
        x = features
        for decoder_block in self.decoder:
            x = decoder_block(x)

        # Apply final segmentation head
        x = self.segmentation_head(x)

        # Upsample to match original input size if needed
        if x.shape[-2:] != input_size:
            x = F.interpolate(x, size=input_size, mode='bilinear', align_corners=False)

        return x

# Initialize the Mish model for segmentation
mish_model = EfficientNetB0WithMish(num_classes=19).to(device)

# Define optimizer for Mish segmentation model
mish_optimizer = optim.SGD(mish_model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
mish_scheduler = optim.lr_scheduler.ReduceLROnPlateau(mish_optimizer, 'min', patience=3, factor=0.1)
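
The effect of `ReduceLROnPlateau` with `patience=3` and `factor=0.1` can be checked in isolation. The sketch below uses a single dummy parameter in place of the model and a made-up sequence of validation losses; it assumes that `train_model` calls `scheduler.step(val_loss)` once per epoch, which is defined earlier in the notebook and not repeated here.

```python
import torch
from torch import nn, optim

param = nn.Parameter(torch.zeros(1))  # stand-in for the model parameters
optimizer = optim.SGD([param], lr=0.01, momentum=0.9, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.1)

# Hypothetical validation losses: one improvement, then a plateau.
val_losses = [1.0, 0.9, 0.9, 0.9, 0.9, 0.9]
lrs = []
for epoch, val_loss in enumerate(val_losses, start=1):
    optimizer.step()          # in real training this follows the backward pass
    scheduler.step(val_loss)  # the scheduler monitors the validation loss
    lrs.append(optimizer.param_groups[0]['lr'])
    print(f"epoch {epoch}: val_loss={val_loss:.2f}  lr={lrs[-1]:.4f}")
```

The learning rate only drops once the loss has failed to improve for more than `patience` consecutive epochs, so a run of 10 epochs leaves room for at most two reductions.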

### Training and Evaluating the Mish Model for Segmentation

Here we train and evaluate the EfficientNet-B0 model with Mish activation functions for semantic segmentation:
1. Train the model for 10 epochs using the same training function as before
2. Track segmentation metrics (mIoU, pixel accuracy) on the training and validation sets; test-set evaluation is skipped for now
3. Visualize the training history to analyze the impact of the Mish activation

In [None]:
# Train the Mish segmentation model
mish_model_trained, mish_history = train_model(
    mish_model,
    train_loader,
    val_loader,
    criterion,
    mish_optimizer,
    mish_scheduler,
    num_epochs=10
)

In [None]:
# Visualize Mish model training history for segmentation
plot_training_history(mish_history, 'EfficientNet-B0 with Mish for Segmentation')

### Detailed Training History Analysis for Mish vs Baseline

This cell creates a comprehensive DataFrame of epoch-by-epoch training metrics:
1. Compares training and validation losses between Mish and baseline models
2. Compares training and validation pixel accuracy between the models
3. Allows for fine-grained analysis of how Mish affects the training dynamics

In [None]:
# Log epoch-wise training metrics for comparative analysis
epochs = list(range(1, len(baseline_history['train_loss'])+1))

train_data = {
    'Epoch': epochs,
    'Baseline Train Loss': baseline_history['train_loss'],
    'Baseline Val Loss': baseline_history['val_loss'],
    'Mish Train Loss': mish_history['train_loss'],
    'Mish Val Loss': mish_history['val_loss'],
    'Baseline Train Pixel Acc': baseline_history['train_pixel_acc'],
    'Baseline Val Pixel Acc': baseline_history['val_pixel_acc'],
    'Mish Train Pixel Acc': mish_history['train_pixel_acc'],
    'Mish Val Pixel Acc': mish_history['val_pixel_acc']
}

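# The baseline and Mish runs may use different numbers of epochs; building the
# columns as Series pads the shorter history with NaN instead of raising a ValueError.
training_df = pd.DataFrame({k: pd.Series(v) for k, v in train_data.items() if k != 'Epoch'})
training_df.insert(0, 'Epoch', list(range(1, len(training_df) + 1)))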
print("Training History Comparison (Baseline vs Mish):")
display(training_df)

### 4.3 EfficientNet-B0 with DeepLabV3+ Segmentation Head

DeepLabV3+ is a semantic segmentation architecture that combines atrous (dilated) convolutions, which enlarge the receptive field without additional downsampling, with an encoder-decoder structure.

### Implementing DeepLabV3+ Segmentation Head

This cell implements the DeepLabV3+ architecture with EfficientNet-B0 as the backbone:
1. Creates the `ASPP` (Atrous Spatial Pyramid Pooling) module that captures multi-scale information
   - Uses multiple dilated convolutions with different rates
   - Includes global pooling to capture context
2. Implements the `DeepLabV3Plus` class that combines:
   - EfficientNet backbone for feature extraction
   - ASPP module for multi-scale processing
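   - Decoder for generating the final segmentation output (this simplified version does not use the low-level skip connection of the original DeepLabV3+ decoder)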
3. Initializes the model and sets up optimizer and scheduler

In [None]:
# Implementing DeeplabV3+ segmentation head
class ASPP(nn.Module):
    def __init__(self, in_channels, out_channels, rates=(6, 12, 18)):
        super(ASPP, self).__init__()

        self.aspp = nn.ModuleList()

        # 1x1 convolution
        self.aspp.append(nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        ))

        # Atrous convolutions
        for rate in rates:
            self.aspp.append(nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 3, padding=rate, dilation=rate, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU()
            ))

        # Global average pooling
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

        # Output layer
        self.output = nn.Sequential(
            nn.Conv2d(out_channels * (len(rates) + 2), out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    def forward(self, x):
        size = x.size()[2:]

        outputs = []
        for module in self.aspp:
            outputs.append(module(x))

        # Process global average pooling branch
        gap_output = self.global_avg_pool(x)
        gap_output = F.interpolate(gap_output, size=size, mode='bilinear', align_corners=True)
        outputs.append(gap_output)

        # Concatenate and process through output layer
        x = torch.cat(outputs, dim=1)
        return self.output(x)

class DeepLabV3Plus(nn.Module):
    def __init__(self, base_model, num_classes=19, output_stride=16):
        super(DeepLabV3Plus, self).__init__()
        self.backbone = base_model
        in_features = self.backbone._fc.in_features

        # Remove the classification head
        self.backbone._fc = nn.Identity()

        # Low-level features from earlier layers for skip connection
        self.low_level_features = 64  # Adjust based on EfficientNet architecture

        # ASPP module
        self.aspp = ASPP(in_features, 256)

        # Decoder
        self.decoder = nn.Sequential(
            nn.Conv2d(256, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, num_classes, 1)
        )

    def forward(self, x):
        input_size = x.size()[2:]

        # Extract features
        features = self.backbone.extract_features(x)

        # Apply ASPP
        x = self.aspp(features)

        # Decoder
        x = self.decoder(x)

        # Upsampling to original size
        x = F.interpolate(x, size=input_size, mode='bilinear', align_corners=True)

        return x

# Initialize the DeepLabV3+ model for segmentation
base_model = EfficientNet.from_pretrained('efficientnet-b0')
deeplabv3_model = DeepLabV3Plus(base_model, num_classes=19).to(device)  # 19 classes for Cityscapes

# Define loss function and optimizer for semantic segmentation
deeplabv3_criterion = nn.CrossEntropyLoss(ignore_index=255)  # Ignore index 255 which is the 'ignored' label in Cityscapes
deeplabv3_optimizer = optim.SGD(deeplabv3_model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
deeplabv3_scheduler = optim.lr_scheduler.ReduceLROnPlateau(deeplabv3_optimizer, 'min', patience=3, factor=0.1)
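
The ASPP branches keep the feature-map size because each 3x3 atrous convolution uses `padding` equal to its `dilation`. The arithmetic can be sketched as follows, assuming a hypothetical 512-pixel input side (so a 16x16 map at stride 32); `conv_out` and `effective_kernel` are illustrative helpers rather than library functions.

```python
def conv_out(size, kernel_size, padding, dilation, stride=1):
    """Output length of a (dilated) convolution along one spatial dimension."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def effective_kernel(kernel_size, dilation):
    """Extent covered by a dilated kernel, measured in input pixels."""
    return dilation * (kernel_size - 1) + 1


feature_size = 512 // 32  # hypothetical 512-pixel input at stride 32
results = {}
for rate in (6, 12, 18):
    out = conv_out(feature_size, kernel_size=3, padding=rate, dilation=rate)
    span = effective_kernel(3, rate)
    results[rate] = (out, span)
    print(f"rate={rate:2d}: output size={out}, effective kernel={span}x{span}")
```

Under this assumed input size, the rate-18 branch spans 37 pixels on a 16-pixel-wide map, so every tap except the centre one falls into the zero padding and the branch behaves much like a 1x1 convolution. Larger inputs, or smaller rates, make the dilated branches more useful.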

### Training and Evaluating the DeepLabV3+ Model

Here we train and evaluate the EfficientNet-B0 model with DeepLabV3+ segmentation head:
1. Train the model for 10 epochs using the same training function
2. Track segmentation metrics (mean IoU, pixel accuracy) on the training and validation sets; test-set evaluation is skipped for now
3. Visualize the training history to analyze how the segmentation head affects performance

In [None]:
# Train the DeepLabV3+ segmentation model
deeplabv3_model_trained, deeplabv3_history = train_model(
    deeplabv3_model,
    train_loader,
    val_loader,
    deeplabv3_criterion,
    deeplabv3_optimizer,
    deeplabv3_scheduler,
    num_epochs=10
)

In [None]:
# Visualize DeepLabV3+ model training history for segmentation
plot_training_history(deeplabv3_history, 'EfficientNet-B0 with DeepLabV3+ for Segmentation')

### Detailed Training History Analysis for DeepLabV3+ vs Baseline

This cell creates a comprehensive comparison of training metrics between models:
1. Collects epoch-by-epoch training and validation losses
2. Collects epoch-by-epoch training and validation pixel accuracies
3. Organizes the data into a DataFrame for detailed analysis

This information helps identify how the DeepLabV3+ architecture changes learning dynamics.

In [None]:
# Log epoch-wise training metrics for comparative analysis
epochs = list(range(1, len(baseline_history['train_loss'])+1))

train_data = {
    'Epoch': epochs,
    'Baseline Train Loss': baseline_history['train_loss'],
    'Baseline Val Loss': baseline_history['val_loss'],
    'DeeplabV3+ Train Loss': deeplabv3_history['train_loss'],
    'DeeplabV3+ Val Loss': deeplabv3_history['val_loss'],
    'Baseline Train Pixel Acc': baseline_history['train_pixel_acc'],
    'Baseline Val Pixel Acc': baseline_history['val_pixel_acc'],
    'DeeplabV3+ Train Pixel Acc': deeplabv3_history['train_pixel_acc'],
    'DeeplabV3+ Val Pixel Acc': deeplabv3_history['val_pixel_acc']
}

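# The baseline and DeepLabV3+ runs may use different numbers of epochs; building the
# columns as Series pads the shorter history with NaN instead of raising a ValueError.
training_df = pd.DataFrame({k: pd.Series(v) for k, v in train_data.items() if k != 'Epoch'})
training_df.insert(0, 'Epoch', list(range(1, len(training_df) + 1)))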
print("Training History Comparison (Baseline vs DeeplabV3+):")
display(training_df)

### 4.4 Combined Approach: EfficientNet-B0 with CBAM, Mish, and DeepLabV3+

After testing each modification individually, we now explore combining all three enhancements:
1. CBAM for attention-based feature refinement
2. Mish activation function for better gradient flow
3. DeepLabV3+ segmentation head for multi-scale feature extraction

This combined approach is intended to leverage the strengths of each individual modification; whether the gains actually add up is an empirical question that the comparison against the baseline addresses.

In [None]:
# Combined model: EfficientNet-B0 with CBAM, Mish, and DeepLabV3+
class CombinedModel(nn.Module):
    def __init__(self, num_classes=19):  # 19 classes for Cityscapes segmentation
        super(CombinedModel, self).__init__()
        # Initialize the EfficientNet-B0 backbone
        self.efficient_net = EfficientNet.from_pretrained('efficientnet-b0')
        self.in_features = self.efficient_net._fc.in_features  # Should be 1280 for EfficientNet-B0

        # Remove the classification head
        self.efficient_net._fc = nn.Identity()

        # Replace ReLU with Mish in the backbone
        self._replace_relu_with_mish(self.efficient_net)

        # Add CBAM module
        self.cbam = CBAM(self.in_features)

        # Add ASPP module (from DeepLabV3+)
        self.aspp = ASPP(self.in_features, 256)

        # Add decoder (from DeepLabV3+ but with Mish activation)
        self.decoder = nn.Sequential(
            nn.Conv2d(256, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            Mish(),  # Using Mish instead of ReLU
            nn.Conv2d(256, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            Mish(),  # Using Mish instead of ReLU
            nn.Conv2d(256, num_classes, 1)
        )

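    def _replace_relu_with_mish(self, model):
        # efficientnet_pytorch builds its backbone with Swish/MemoryEfficientSwish
        # rather than nn.ReLU, so those classes are matched as well; checking for
        # nn.ReLU alone would leave the backbone activations unchanged.
        from efficientnet_pytorch.utils import MemoryEfficientSwish, Swish
        for name, module in model.named_children():
            if isinstance(module, (nn.ReLU, Swish, MemoryEfficientSwish)):
                setattr(model, name, Mish())
            else:
                self._replace_relu_with_mish(module)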

    def forward(self, x):
        # Store input size for later upsampling
        input_size = x.size()[2:]

        # Extract features from EfficientNet backbone
        features = self.efficient_net.extract_features(x)  # [B, 1280, H/32, W/32]

        # Apply CBAM attention
        features_with_attention = self.cbam(features)

        # Apply ASPP module from DeepLabV3+
        x = self.aspp(features_with_attention)

        # Apply decoder
        x = self.decoder(x)

        # Upsampling to original size
        x = F.interpolate(x, size=input_size, mode='bilinear', align_corners=True)

        return x

# Initialize the combined model
combined_model = CombinedModel(num_classes=19).to(device)  # 19 classes for Cityscapes

# Define loss function, optimizer and scheduler for the combined segmentation model
combined_criterion = nn.CrossEntropyLoss(ignore_index=255)  # Ignore index 255 which is the 'ignored' label in Cityscapes
combined_optimizer = optim.SGD(combined_model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
combined_scheduler = optim.lr_scheduler.ReduceLROnPlateau(combined_optimizer, 'min', patience=3, factor=0.1)

In [None]:
# Train the combined segmentation model
combined_model_trained, combined_history = train_model(
    combined_model,
    train_loader,
    val_loader,
    combined_criterion,
    combined_optimizer,
    combined_scheduler,
    num_epochs=10
)

In [None]:
# Visualize combined model training history for segmentation
plot_training_history(combined_history, 'EfficientNet-B0 with CBAM, Mish, and DeepLabV3+ for Segmentation')

In [None]:
# Log epoch-wise training metrics for segmentation comparative analysis
epochs = list(range(1, len(baseline_history['train_loss'])+1))

train_data = {
    'Epoch': epochs,
    'Baseline Train Loss': baseline_history['train_loss'],
    'Baseline Val Loss': baseline_history['val_loss'],
    'Combined Train Loss': combined_history['train_loss'],
    'Combined Val Loss': combined_history['val_loss'],
    'Baseline Train IoU': baseline_history['train_iou'],
    'Baseline Val IoU': baseline_history['val_iou'],
    'Combined Train IoU': combined_history['train_iou'],
    'Combined Val IoU': combined_history['val_iou'],
    'Baseline Train Pixel Acc': baseline_history['train_pixel_acc'],
    'Baseline Val Pixel Acc': baseline_history['val_pixel_acc'],
    'Combined Train Pixel Acc': combined_history['train_pixel_acc'],
    'Combined Val Pixel Acc': combined_history['val_pixel_acc']
}

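# The baseline and combined runs may use different numbers of epochs; building the
# columns as Series pads the shorter history with NaN instead of raising a ValueError.
training_df = pd.DataFrame({k: pd.Series(v) for k, v in train_data.items() if k != 'Epoch'})
training_df.insert(0, 'Epoch', list(range(1, len(training_df) + 1)))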
print("Training History Comparison for Segmentation (Baseline vs Combined):")
display(training_df)

## 5. Results Comparison and Analysis

Let's compare the performance of all model variants across various metrics.
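
One simple way to line the variants up is to take, for each model, the epoch with the best validation IoU and report the other validation metrics at that epoch. The sketch below shows the idea on placeholder histories; the numbers are made up purely to demonstrate the output format and are not results, and `summarize` is a hypothetical helper.

```python
def summarize(histories):
    """Return one row per model, taken at its best validation-IoU epoch."""
    rows = []
    for name, history in histories.items():
        val_iou = history['val_iou']
        best = max(range(len(val_iou)), key=val_iou.__getitem__)
        rows.append({
            'Model': name,
            'Best Epoch': best + 1,
            'Best Val IoU': val_iou[best],
            'Val Pixel Acc': history['val_pixel_acc'][best],
            'Val Loss': history['val_loss'][best],
        })
    # Highest validation IoU first.
    return sorted(rows, key=lambda row: row['Best Val IoU'], reverse=True)


# Placeholder values only; replace with the real history dictionaries.
example_histories = {
    'Model A': {'val_iou': [0.10, 0.20, 0.15],
                'val_pixel_acc': [0.50, 0.60, 0.55],
                'val_loss': [2.0, 1.5, 1.6]},
    'Model B': {'val_iou': [0.12, 0.18, 0.25],
                'val_pixel_acc': [0.50, 0.58, 0.65],
                'val_loss': [1.9, 1.6, 1.4]},
}
rows = summarize(example_histories)
for row in rows:
    print(row)
```

In the notebook, the dictionary would be built from `baseline_history`, `cbam_history`, `mish_history`, `deeplabv3_history` and `combined_history`, and the rows can be passed to `pd.DataFrame` for display.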

### Saving the Experiment Results

### Setting Up Model Storage

This cell prepares a directory to save our trained segmentation models:
1. Creates a 'models' directory in the current working directory if it doesn't exist
2. Displays the path where models will be saved

Saving models allows us to use them later for inference without retraining.

In [None]:
# Create a models directory if it doesn't exist
import os
models_dir = os.path.join(os.getcwd(), 'models')
os.makedirs(models_dir, exist_ok=True)
print(f"Models will be saved to: {models_dir}")

### Saving the Baseline Segmentation Model

This cell saves the trained baseline segmentation model to disk:
1. Defines the file path for the baseline model
2. Saves a comprehensive checkpoint including:
   - Model state dictionary (weights and parameters)
   - Optimizer state
   - Training history

In [None]:
# Save the baseline segmentation model after training
baseline_model_path = os.path.join(models_dir, 'baseline_efficientnet_b0_segmentation.pth')
torch.save({
    'model_state_dict': baseline_model_trained.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'history': baseline_history,
}, baseline_model_path)
print(f"Baseline segmentation model saved to {baseline_model_path}")

### Loading the Baseline Segmentation Model

This cell defines a function to load the saved baseline segmentation model and demonstrates its usage:
1. Implements the `load_baseline_model` function that:
   - Initializes a fresh model with the same architecture
   - Loads the weights and state from the saved checkpoint
   - Returns the model along with its history

In [None]:
# Load the baseline segmentation model
def load_baseline_model(model_path):
    model = EfficientNetB0Segmentation().to(device)
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    history = checkpoint['history']
    print(f"Loaded baseline segmentation model.")
    return model, history

baseline_model_path = os.path.join(models_dir, 'baseline_efficientnet_b0_segmentation.pth')
baseline_model_trained, baseline_history = load_baseline_model(baseline_model_path)
# The loaded model can now be used for inference
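
For inference, a loaded model should be switched to evaluation mode and run without gradient tracking. The following self-contained sketch shows the save, load and predict round trip on a tiny stand-in network; `make_model`, the file name and the input size are hypothetical and only mirror the checkpoint layout used above.

```python
import os
import tempfile

import torch
from torch import nn


def make_model():
    # Tiny stand-in for a segmentation network: 3 input channels, 19 classes.
    return nn.Sequential(
        nn.Conv2d(3, 8, 3, padding=1),
        nn.BatchNorm2d(8),
        nn.ReLU(),
        nn.Conv2d(8, 19, 1),
    )


model = make_model()
path = os.path.join(tempfile.mkdtemp(), 'toy_checkpoint.pth')
torch.save({'model_state_dict': model.state_dict(),
            'history': {'val_iou': [0.0]}}, path)

restored = make_model()
checkpoint = torch.load(path, map_location='cpu')
restored.load_state_dict(checkpoint['model_state_dict'])
restored.eval()  # use running BatchNorm statistics and disable dropout

image = torch.randn(1, 3, 64, 64)  # hypothetical preprocessed input
with torch.no_grad():  # gradients are not needed for inference
    logits = restored(image)           # [1, 19, 64, 64]
    prediction = logits.argmax(dim=1)  # [1, 64, 64], one class id per pixel
print(prediction.shape)
```

Without `eval()`, BatchNorm layers keep using per-batch statistics, so predictions can vary with the batch contents.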

### Saving the CBAM Model

This cell saves the trained CBAM model to disk:
1. Defines the file path for the CBAM model
2. Saves a comprehensive checkpoint including:
   - Model state dictionary
   - Optimizer state
   - Training history

In [None]:
# Save the CBAM model after training
cbam_model_path = os.path.join(models_dir, 'cbam_efficientnet_b0.pth')
torch.save({
    'model_state_dict': cbam_model_trained.state_dict(),
    'optimizer_state_dict': cbam_optimizer.state_dict(),
    'history': cbam_history,
}, cbam_model_path)
print(f"CBAM model saved to {cbam_model_path}")

### Loading the CBAM Model

This cell defines a function to load the saved CBAM model:
1. Implements the `load_cbam_model` function with the same pattern as the baseline loader
2. Properly initializes the CBAM-specific architecture before loading weights

In [None]:
# Load the CBAM model
def load_cbam_model(model_path):
    model = EfficientNetB0WithCBAM().to(device)
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    history = checkpoint['history']
    print(f"Loaded CBAM segmentation model.")
    return model, history

cbam_model_path = os.path.join(models_dir, 'cbam_efficientnet_b0.pth')
# Example usage:
cbam_model_trained, cbam_history = load_cbam_model(cbam_model_path)
# The loaded model can now be used for inference

### Saving the Mish Model

This cell saves the trained Mish model to disk:
1. Defines the file path for the Mish model
2. Saves the complete checkpoint with model weights, optimizer state and history

In [None]:
# Save the Mish model after training
mish_model_path = os.path.join(models_dir, 'mish_efficientnet_b0.pth')
torch.save({
    'model_state_dict': mish_model_trained.state_dict(),
    'optimizer_state_dict': mish_optimizer.state_dict(),
    'history': mish_history,
}, mish_model_path)
print(f"Mish model saved to {mish_model_path}")

### Loading the Mish Model

This cell defines a function to load the saved Mish model:
1. Implements the `load_mish_model` function that correctly initializes the model with Mish activations
2. Loads the saved weights and states

In [None]:
# Load the Mish model
def load_mish_model(model_path):
    model = EfficientNetB0WithMish().to(device)
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    history = checkpoint['history']
    # test_acc = checkpoint['test_acc']
    # test_loss = checkpoint['test_loss']
    print(f"Loaded Mish segmentation model.")
    return model, history

# Example usage:
loaded_mish_model, loaded_mish_history = load_mish_model(mish_model_path)
# The loaded model can now be used for inference

### Saving the DeepLabV3+ Model

This cell saves the trained DeepLabV3+ model to disk:
1. Defines the file path for the DeepLabV3+ model
2. Saves the complete checkpoint with all necessary information
3. Confirms successful saving with a print statement

In [None]:
# Save the DeepLabV3+ model after training
deeplabv3_model_path = os.path.join(models_dir, 'deeplabv3_efficientnet_b0.pth')
torch.save({
    'model_state_dict': deeplabv3_model_trained.state_dict(),
    'optimizer_state_dict': deeplabv3_optimizer.state_dict(),
    'history': deeplabv3_history
}, deeplabv3_model_path)
print(f"DeeplabV3+ model saved to {deeplabv3_model_path}")

### Loading the DeepLabV3+ Model

This cell defines a function to load the saved DeepLabV3+ model:
1. Implements the `load_deeplabv3_model` function with special handling for the two-component architecture:
   - First initializes a fresh EfficientNet-B0 base model
   - Then creates the DeepLabV3+ model with that base
   - Loads the saved weights and states

In [None]:
# Load the DeeplabV3+ model
def load_deeplabv3_model(model_path):
    base_model = EfficientNet.from_pretrained('efficientnet-b0')  # We need a base model for DeeplabV3+
    model = DeepLabV3Plus(base_model).to(device)
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    history = checkpoint['history']
    print(f"Loaded DeeplabV3+ model.")
    return model, history

# Example usage:
loaded_deeplabv3_model, loaded_deeplabv3_history = load_deeplabv3_model(deeplabv3_model_path)
# The loaded model can now be used for inference

### Saving the Combined Model

This cell saves the trained combined model to disk:
1. Defines the file path for the combined model
2. Saves a comprehensive checkpoint including model weights, optimizer state, history
3. Confirms successful saving with a print statement

In [None]:
# Save the combined model after training
combined_model_path = os.path.join(models_dir, 'combined_model_efficientnet_b0.pth')
torch.save({
    'model_state_dict': combined_model_trained.state_dict(),
    'optimizer_state_dict': combined_optimizer.state_dict(),
    'history': combined_history,
}, combined_model_path)
print(f"Combined model saved to {combined_model_path}")

### Loading the Combined Model

This cell defines a function to load the saved combined model:
1. Implements the `load_combined_model` function that initializes the architecture with all modifications
2. Loads the saved weights and states
3. Provides an example of loading the model for future use

In [None]:
# Load the combined model
def load_combined_model(model_path):
    model = CombinedModel().to(device)
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    history = checkpoint['history']
    print("Loaded combined model.")
    # Only the model and its history are stored in the checkpoint
    return model, history

# Example usage:
loaded_combined_model, loaded_combined_history = load_combined_model(combined_model_path)
# The loaded model can now be used for inference