# Lab 2.2.4: Segmentation Lab

**Module:** 2.2 - Computer Vision  
**Time:** 3 hours  
**Difficulty:** ⭐⭐⭐⭐

---

## 🎯 Learning Objectives

By the end of this notebook, you will:
- [ ] Understand semantic vs instance vs panoptic segmentation
- [ ] Implement the U-Net architecture from scratch
- [ ] Train a segmentation model on real data
- [ ] Evaluate segmentation using IoU (Intersection over Union)

---

## 📚 Prerequisites

- Completed: Labs 2.2.1-2.2.3
- Knowledge of: CNNs, skip connections, encoder-decoder architectures

---

## 🌍 Real-World Context

**Image segmentation is crucial for:**

- 🏥 **Medical imaging**: Outlining tumors, organs, cell structures
- 🚗 **Autonomous driving**: Understanding road, lane, sidewalk boundaries
- 🛰️ **Satellite imagery**: Land use classification, disaster assessment
- 📸 **Photo editing**: Background removal, portrait mode
- 🤖 **Robotics**: Understanding what can be grasped or navigated

---

## 🧒 ELI5: What is Image Segmentation?

> **Imagine coloring in a coloring book...** 🎨
>
> - **Classification**: "This coloring book page has a cat" ✓/✗
> - **Detection**: "There's a cat shape starting here and ending there" ⬜
> - **Segmentation**: "Color every pixel that belongs to the cat" 🎨
>
> Segmentation gives you the **exact shape** of every object, pixel by pixel!

### Three Types of Segmentation

```
Input Image:          Semantic:             Instance:             Panoptic:
┌──────────────┐      ┌──────────────┐      ┌──────────────┐      ┌──────────────┐
│    🐱  🐱    │      │    ▓▓  ▓▓    │      │    ██  ▒▒    │      │    ██  ▒▒    │
│              │  →   │              │  →   │              │  →   │              │
│  Background  │      │  ░░░░░░░░░░  │      │  ░░░░░░░░░░  │      │  ░░░░░░░░░░  │
└──────────────┘      └──────────────┘      └──────────────┘      └──────────────┘
                      All cats = same       Cat 1 ≠ Cat 2         Cats + Background
                      color (class)         (different IDs)       all labeled
```

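- **Semantic**: Every pixel gets a CLASS label; all pixels of the same class share it, so the two cats cannot be told apart
- **Instance**: Each OBJECT gets a unique ID (separates individual cats)
- **Panoptic**: Semantic + instance combined; every pixel gets a class label, and each countable object also gets its own ID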
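
To make the three label formats concrete, here is a minimal sketch on a toy 2×6 image containing two cats. The class ID `8` matches the `cat` index in the Pascal VOC class list used later in this lab. The `class_id * 1000 + instance_id` packing for the panoptic map is an illustrative assumption, not something this lab's dataset provides.

```python
CAT = 8  # index of 'cat' in the VOC class list; 0 is background

# Semantic map: one class ID per pixel; both cats share the same label.
semantic = [
    [0, CAT, CAT, 0, CAT, CAT],
    [0, CAT, CAT, 0, CAT, CAT],
]

# Instance map: one object ID per pixel (0 = no object); the two cats differ.
instance = [
    [0, 1, 1, 0, 2, 2],
    [0, 1, 1, 0, 2, 2],
]

# Panoptic map: pack class and instance into a single ID per pixel
# (hypothetical encoding chosen for illustration).
panoptic = [
    [cls * 1000 + inst for cls, inst in zip(sem_row, inst_row)]
    for sem_row, inst_row in zip(semantic, instance)
]

print(sorted({v for row in semantic for v in row}))  # [0, 8]
print(sorted({v for row in instance for v in row}))  # [0, 1, 2]
print(sorted({v for row in panoptic for v in row}))  # [0, 8001, 8002]
```

The semantic map cannot distinguish the cats, the instance map cannot name their class, and the panoptic map carries both pieces of information.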

In [None]:
# Setup
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import VOCSegmentation
import matplotlib.pyplot as plt
import numpy as np
from typing import Tuple, List, Optional, Dict
from tqdm.auto import tqdm
from pathlib import Path
import time

# DGX Spark optimizations
torch.backends.cudnn.benchmark = True
torch.set_float32_matmul_precision('high')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🖥️  Device: {device}")
if torch.cuda.is_available():
    print(f"🎮 GPU: {torch.cuda.get_device_name(0)}")
    print(f"💾 GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

---

## Part 1: Understanding U-Net Architecture

### 🧒 ELI5: U-Net

> **Imagine describing a picture to a friend over the phone...**
>
> 1. **First**, you give a high-level summary: "It's a beach scene"
> 2. **Then**, you add details: "There's an umbrella on the left"
> 3. **Finally**, you get specific: "The umbrella is red with white stripes, positioned at..."
>
> **U-Net works similarly:**
> - **Encoder (going down)**: Captures "what" is in the image (abstract features)
> - **Decoder (going up)**: Recovers "where" things are (spatial details)
> - **Skip connections**: Pass spatial details directly to help localization!

### U-Net Architecture

```
Input (256×256×3)
    │
    ▼
┌─────────┐ ────────────────────────────────────────────────────────┐
│ Conv×2  │ 64 filters                                              │ Skip 1
└────┬────┘                                                         │
     ▼ MaxPool                                                      │
┌─────────┐ ───────────────────────────────────────────────┐        │
│ Conv×2  │ 128 filters                                    │ Skip 2 │
└────┬────┘                                                │        │
     ▼ MaxPool                                             │        │
┌─────────┐ ─────────────────────────────────┐             │        │
│ Conv×2  │ 256 filters                      │ Skip 3      │        │
└────┬────┘                                  │             │        │
     ▼ MaxPool                               │             │        │
┌─────────┐ ───────────────────┐             │             │        │
│ Conv×2  │ 512 filters        │ Skip 4      │             │        │
└────┬────┘                    │             │             │        │
     ▼ MaxPool                 │             │             │        │
┌─────────┐                    │             │             │        │
│ Conv×2  │ 1024 (bottleneck)  │             │             │        │
└────┬────┘                    │             │             │        │
     ▼ UpConv                  │             │             │        │
┌─────────┐ ◀──────────────────┘             │             │        │
│ Conv×2  │ + Concatenate Skip 4             │             │        │
└────┬────┘                                  │             │        │
     ▼ UpConv                                │             │        │
┌─────────┐ ◀────────────────────────────────┘             │        │
│ Conv×2  │ + Concatenate Skip 3                           │        │
└────┬────┘                                                │        │
     ▼ UpConv                                              │        │
┌─────────┐ ◀──────────────────────────────────────────────┘        │
│ Conv×2  │ + Concatenate Skip 2                                    │
└────┬────┘                                                         │
     ▼ UpConv                                                       │
┌─────────┐ ◀───────────────────────────────────────────────────────┘
│ Conv×2  │ + Concatenate Skip 1
└────┬────┘
     ▼
┌─────────┐
│ Conv 1×1│ → num_classes channels
└─────────┘
    │
    ▼
Output (256×256×num_classes)
```
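
The diagram can be checked with a little arithmetic. The sketch below traces (channels, height, width) through the four pooling steps and back up, assuming the channel widths shown in the diagram, 2×2 max pooling that halves each side, and upsampling that doubles it. The helper name `trace_unet_shapes` is made up for this illustration.

```python
def trace_unet_shapes(size=256, base=64, depth=4):
    """Return (stage, channels, height, width) tuples for a U-Net-style network."""
    shapes = []
    channels, side = base, size
    shapes.append(("enc0", channels, side, side))

    # Encoder: every level doubles the channels and halves the resolution.
    for level in range(1, depth + 1):
        channels *= 2
        side //= 2
        name = "bottleneck" if level == depth else f"enc{level}"
        shapes.append((name, channels, side, side))

    # Decoder: every level halves the channels and doubles the resolution.
    for level in range(depth - 1, -1, -1):
        channels //= 2
        side *= 2
        shapes.append((f"dec{level}", channels, side, side))

    return shapes


for stage, c, h, w in trace_unet_shapes():
    print(f"{stage:<10} {c:>4} x {h} x {w}")
```

Note that the `UNet` class in this lab defaults to `bilinear=True`, which halves the bottleneck and decoder widths, so its channel counts are smaller than this trace while the spatial sizes are the same.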

In [None]:
class DoubleConv(nn.Module):
    """
    Double convolution block: (Conv → BN → ReLU) × 2
    
    This is the basic building block of U-Net.
    """
    
    def __init__(self, in_channels: int, out_channels: int):
        super(DoubleConv, self).__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.double_conv(x)


class Down(nn.Module):
    """Downsampling: MaxPool → DoubleConv"""
    
    def __init__(self, in_channels: int, out_channels: int):
        super(Down, self).__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upsampling: UpConv → Concat skip → DoubleConv"""
    
    def __init__(self, in_channels: int, out_channels: int, bilinear: bool = True):
        super(Up, self).__init__()
        
        if bilinear:
            # Use bilinear upsampling (fewer parameters)
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            # Use transposed convolution (learnable upsampling)
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
    
    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """x1: from decoder, x2: skip connection from encoder"""
        x1 = self.up(x1)
        
        # Handle size mismatch (input height/width might not be divisible by 16,
        # so pooling rounds down and the upsampled map can be smaller than the skip)
        diff_y = x2.size()[2] - x1.size()[2]
        diff_x = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
                       diff_y // 2, diff_y - diff_y // 2])
        
        # Concatenate skip connection
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

In [None]:
class UNet(nn.Module):
    """
    U-Net for semantic segmentation.
    
    Original paper: "U-Net: Convolutional Networks for Biomedical Image Segmentation"
    by Olaf Ronneberger et al., 2015
    
    Key innovations:
    - Encoder-decoder with skip connections
    - Works with limited training data
    - Precise localization through skip connections
    
    Args:
        n_channels: Number of input channels (3 for RGB)
        n_classes: Number of output classes
        bilinear: Use bilinear upsampling (vs transposed conv)
    """
    
    def __init__(self, n_channels: int = 3, n_classes: int = 21, bilinear: bool = True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        
        # Encoder (downsampling path)
        self.inc = DoubleConv(n_channels, 64)    # Initial conv
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)   # Bottleneck
        
        # Decoder (upsampling path)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        
        # Output layer (1×1 conv to reduce to num_classes)
        self.outc = nn.Conv2d(64, n_classes, kernel_size=1)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Encoder (shape comments below assume bilinear=True)
        x1 = self.inc(x)      # [B, 64, H, W]
        x2 = self.down1(x1)   # [B, 128, H/2, W/2]
        x3 = self.down2(x2)   # [B, 256, H/4, W/4]
        x4 = self.down3(x3)   # [B, 512, H/8, W/8]
        x5 = self.down4(x4)   # [B, 512, H/16, W/16] (bottleneck)
        
        # Decoder with skip connections
        x = self.up1(x5, x4)  # [B, 256, H/8, W/8]
        x = self.up2(x, x3)   # [B, 128, H/4, W/4]
        x = self.up3(x, x2)   # [B, 64, H/2, W/2]
        x = self.up4(x, x1)   # [B, 64, H, W]
        
        # Output
        logits = self.outc(x) # [B, n_classes, H, W]
        return logits


# Test the model
model = UNet(n_channels=3, n_classes=21)
dummy_input = torch.randn(1, 3, 256, 256)
output = model(dummy_input)

print(f"📊 U-Net Architecture:")
print(f"   Input shape:  {dummy_input.shape}")
print(f"   Output shape: {output.shape}")
print(f"   Parameters:   {sum(p.numel() for p in model.parameters()):,}")

---

## Part 2: Data Preparation (Pascal VOC)

We'll use the Pascal VOC 2012 segmentation dataset, a classic benchmark with 21 classes (20 object categories plus background).
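
VOC segmentation masks are stored as palette PNGs: each pixel holds a class index, and a palette maps that index to a display color. The palette follows a fixed bit pattern, so it can be generated instead of typed by hand. The sketch below is a plain-Python version of that pattern; the function name `voc_colormap` is chosen here for illustration.

```python
def voc_colormap(num_colors=256):
    """Generate the Pascal VOC palette as a list of (R, G, B) tuples."""
    colormap = []
    for index in range(num_colors):
        r = g = b = 0
        c = index
        # Spread the bits of the class index over the high bits of R, G, B.
        for j in range(8):
            r |= ((c >> 0) & 1) << (7 - j)
            g |= ((c >> 1) & 1) << (7 - j)
            b |= ((c >> 2) & 1) << (7 - j)
            c >>= 3
        colormap.append((r, g, b))
    return colormap


palette = voc_colormap(21)
print(palette[1])   # aeroplane: (128, 0, 0)
print(palette[8])   # cat: (64, 0, 0)
print(palette[15])  # person: (192, 128, 128)
```

The first 21 entries match the hand-written `VOC_COLORMAP` table defined in this lab.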

In [None]:
# VOC class names and colors
VOC_CLASSES = [
    'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
    'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
    'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
]

# VOC color palette
VOC_COLORMAP = np.array([
    [0, 0, 0],        # background
    [128, 0, 0],      # aeroplane
    [0, 128, 0],      # bicycle
    [128, 128, 0],    # bird
    [0, 0, 128],      # boat
    [128, 0, 128],    # bottle
    [0, 128, 128],    # bus
    [128, 128, 128],  # car
    [64, 0, 0],       # cat
    [192, 0, 0],      # chair
    [64, 128, 0],     # cow
    [192, 128, 0],    # diningtable
    [64, 0, 128],     # dog
    [192, 0, 128],    # horse
    [64, 128, 128],   # motorbike
    [192, 128, 128],  # person
    [0, 64, 0],       # pottedplant
    [128, 64, 0],     # sheep
    [0, 192, 0],      # sofa
    [128, 192, 0],    # train
    [0, 64, 128],     # tvmonitor
], dtype=np.uint8)

print("📋 Pascal VOC Classes:")
for i, cls in enumerate(VOC_CLASSES):
    color = VOC_COLORMAP[i]
    print(f"   {i:2d}: {cls:<15} RGB{tuple(color)}")

In [None]:
class VOCSegmentationDataset(Dataset):
    """
    Pascal VOC Segmentation dataset with proper transforms.
    """
    
    def __init__(
        self,
        root: str = '../data',
        image_set: str = 'train',
        image_size: int = 256,
        download: bool = True
    ):
        self.image_size = image_size
        
        # Load VOC dataset
        self.dataset = VOCSegmentation(
            root=root,
            year='2012',
            image_set=image_set,
            download=download
        )
        
        # Transforms for image
        self.image_transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224, 0.225])
        ])
        
        # Transforms for mask (no normalization!)
        self.mask_transform = transforms.Compose([
            transforms.Resize((image_size, image_size), 
                            interpolation=transforms.InterpolationMode.NEAREST),
        ])
    
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        image, mask = self.dataset[idx]
        
        # Transform image
        image = self.image_transform(image)
        
        # Transform mask
        mask = self.mask_transform(mask)
        mask = torch.from_numpy(np.array(mask)).long()
        
        # VOC marks boundary/void pixels with 255. Map them to 0 (background) for simplicity:
        # F.one_hot in the Dice loss requires every label to be < num_classes.
        mask[mask == 255] = 0
        
        return image, mask

In [None]:
# Load datasets
# Note: Pascal VOC 2012 is a ~2GB download. First run may take several minutes.
print("📂 Loading Pascal VOC 2012...")
print("   ⚠️ First run will download ~2GB. This may take several minutes.")

train_dataset = VOCSegmentationDataset(image_set='train')
val_dataset = VOCSegmentationDataset(image_set='val')

# ⚠️ DGX SPARK NOTE: When using Docker with num_workers > 0, use --ipc=host flag
# Example: docker run --gpus all -it --rm --ipc=host nvcr.io/nvidia/pytorch:25.11-py3
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=4, pin_memory=True)

print(f"✅ Dataset loaded:")
print(f"   Training:   {len(train_dataset):,} images")
print(f"   Validation: {len(val_dataset):,} images")

In [None]:
def visualize_segmentation(image: torch.Tensor, mask: torch.Tensor, prediction: Optional[torch.Tensor] = None):
    """
    Visualize image, ground truth mask, and optionally prediction.
    """
    # Denormalize image
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    image_vis = image.cpu() * std + mean
    image_vis = image_vis.permute(1, 2, 0).numpy().clip(0, 1)
    
    # Convert mask to colored image
    mask_np = mask.cpu().numpy()
    mask_colored = VOC_COLORMAP[mask_np]
    
    if prediction is not None:
        fig, axes = plt.subplots(1, 4, figsize=(16, 4))
        
        pred_np = prediction.cpu().numpy()
        pred_colored = VOC_COLORMAP[pred_np]
        
        axes[0].imshow(image_vis)
        axes[0].set_title('Input Image')
        axes[0].axis('off')
        
        axes[1].imshow(mask_colored)
        axes[1].set_title('Ground Truth')
        axes[1].axis('off')
        
        axes[2].imshow(pred_colored)
        axes[2].set_title('Prediction')
        axes[2].axis('off')
        
        # Overlay
        axes[3].imshow(image_vis)
        axes[3].imshow(pred_colored, alpha=0.5)
        axes[3].set_title('Overlay')
        axes[3].axis('off')
    else:
        fig, axes = plt.subplots(1, 3, figsize=(12, 4))
        
        axes[0].imshow(image_vis)
        axes[0].set_title('Input Image')
        axes[0].axis('off')
        
        axes[1].imshow(mask_colored)
        axes[1].set_title('Segmentation Mask')
        axes[1].axis('off')
        
        # Overlay
        axes[2].imshow(image_vis)
        axes[2].imshow(mask_colored, alpha=0.5)
        axes[2].set_title('Overlay')
        axes[2].axis('off')
    
    plt.tight_layout()
    plt.show()

# Visualize samples
images, masks = next(iter(train_loader))
for i in range(3):
    visualize_segmentation(images[i], masks[i])

---

## Part 3: Segmentation Loss Functions

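Per-pixel cross-entropy alone can be dominated by frequent classes such as background, so segmentation models are often trained with an overlap-based loss such as Dice, or with a combination of the two.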

### 📝 Key Function: F.one_hot()

Before we define our loss functions, let's understand `F.one_hot()` which converts class indices to one-hot encoded tensors:

```python
import torch.nn.functional as F

# Class indices for 3 pixels (classes 0, 2, 1)
indices = torch.tensor([0, 2, 1])

# Convert to one-hot with 4 classes
one_hot = F.one_hot(indices, num_classes=4)
# Result: tensor([[1, 0, 0, 0],   <- class 0
#                 [0, 0, 1, 0],   <- class 2
#                 [0, 1, 0, 0]])  <- class 1
```

This is essential for computing per-class losses in segmentation!
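
For segmentation masks the same call is applied to a `[B, H, W]` tensor, which puts the new class axis last. The short sketch below shows the layout change that the Dice loss in this lab relies on, using a toy 2×2 mask with 3 classes (the tensor values are made up for illustration).

```python
import torch
import torch.nn.functional as F

# Toy mask: batch of 1, 2x2 pixels, class indices in {0, 1, 2}
mask = torch.tensor([[[0, 2],
                      [1, 1]]])              # [B, H, W] = [1, 2, 2]

one_hot = F.one_hot(mask, num_classes=3)     # [B, H, W, C] = [1, 2, 2, 3]
one_hot_chw = one_hot.permute(0, 3, 1, 2)    # [B, C, H, W] = [1, 3, 2, 2]

print(one_hot_chw.shape)  # torch.Size([1, 3, 2, 2])
print(one_hot_chw[0, 1])  # channel for class 1: ones on the bottom row only
```

After the `permute`, the one-hot targets line up element by element with the `[B, C, H, W]` logits, so per-class sums can be taken over the batch and spatial axes.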

In [None]:
class DiceLoss(nn.Module):
    """
    Dice Loss for segmentation.
    
    Dice = 2 * |A ∩ B| / (|A| + |B|)
    
    Handles class imbalance better than cross-entropy.
    """
    
    def __init__(self, smooth: float = 1.0):
        super(DiceLoss, self).__init__()
        self.smooth = smooth
    
    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Args:
            logits: [B, C, H, W] - raw model output
            targets: [B, H, W] - class indices
        """
        num_classes = logits.shape[1]
        
        # Convert logits to probabilities
        probs = F.softmax(logits, dim=1)
        
        # One-hot encode targets
        targets_one_hot = F.one_hot(targets, num_classes).permute(0, 3, 1, 2).float()
        
        # Compute Dice per class
        dims = (0, 2, 3)  # Batch, H, W
        intersection = (probs * targets_one_hot).sum(dims)
        # |A| + |B|: the sum of both areas (not the set union used by IoU)
        cardinality = probs.sum(dims) + targets_one_hot.sum(dims)
        
        dice = (2 * intersection + self.smooth) / (cardinality + self.smooth)
        
        # Average over classes (excluding background if desired)
        return 1 - dice.mean()


class CombinedLoss(nn.Module):
    """
    Combination of Cross-Entropy and Dice Loss.
    
    CE helps with hard examples, Dice handles class imbalance.
    """
    
    def __init__(self, ce_weight: float = 0.5, dice_weight: float = 0.5):
        super(CombinedLoss, self).__init__()
        self.ce = nn.CrossEntropyLoss(ignore_index=255)  # Skip void pixels if any remain (the dataset above remaps 255 to 0)
        self.dice = DiceLoss()
        self.ce_weight = ce_weight
        self.dice_weight = dice_weight
    
    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        ce_loss = self.ce(logits, targets)
        dice_loss = self.dice(logits, targets)
        return self.ce_weight * ce_loss + self.dice_weight * dice_loss

### 🧒 ELI5: Dice Loss vs Cross-Entropy

> **Imagine you're grading an artist's drawing of a cat...**
>
> **Cross-Entropy**: "How confident were you at each pixel?"
> - Penalizes wrong predictions strongly
> - Can be dominated by background (if 90% of pixels are background)
>
> **Dice Loss**: "How much overlap between your drawing and the real cat?"
> - Directly measures segmentation quality
> - Naturally handles class imbalance
>
> **Combined**: Get the best of both worlds!
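
A quick numeric sketch of the imbalance problem. It uses pixel accuracy as a stand-in for any per-pixel average (an assumption made to keep the arithmetic simple; cross-entropy also averages over pixels) and compares it with per-class Dice on a toy mask where only 10% of the pixels are foreground. The helper `dice_score` is hypothetical and works on flat Python lists.

```python
def dice_score(pred, target, cls):
    """Dice = 2|A ∩ B| / (|A| + |B|) for one class over flat label lists."""
    inter = sum(1 for p, t in zip(pred, target) if p == cls and t == cls)
    size_pred = sum(1 for p in pred if p == cls)
    size_target = sum(1 for t in target if t == cls)
    denom = size_pred + size_target
    return 2 * inter / denom if denom else float("nan")


# 100 pixels: 90 background (0) and 10 cat (1)
target = [0] * 90 + [1] * 10
lazy_pred = [0] * 100  # a model that predicts background everywhere

accuracy = sum(p == t for p, t in zip(lazy_pred, target)) / len(target)
print(f"pixel accuracy:  {accuracy:.2f}")                          # 0.90
print(f"cat Dice:        {dice_score(lazy_pred, target, 1):.2f}")  # 0.00
print(f"background Dice: {dice_score(lazy_pred, target, 0):.2f}")  # 0.95
```

The per-pixel average looks good even though the cat is missed entirely, while the cat's Dice score drops to zero. Averaging Dice over classes therefore gives the small class as much weight as the background.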

---

## Part 4: Evaluation Metrics (IoU/mIoU)
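
Before the tensor implementation, a small worked example of IoU on a one-dimensional "image" of 8 pixels. The helper `iou_score` is a hypothetical plain-Python counterpart written for this illustration; the last lines use the identity Dice = 2·IoU / (1 + IoU), which holds for hard (0/1) masks.

```python
def iou_score(pred, target, cls):
    """IoU = |A ∩ B| / |A ∪ B| for one class over flat label lists."""
    inter = sum(1 for p, t in zip(pred, target) if p == cls and t == cls)
    union = sum(1 for p, t in zip(pred, target) if p == cls or t == cls)
    return inter / union if union else float("nan")


target = [0, 0, 1, 1, 1, 1, 0, 0]
pred   = [0, 1, 1, 1, 1, 0, 0, 0]  # same object, shifted one pixel to the left

iou = iou_score(pred, target, cls=1)
dice = 2 * iou / (1 + iou)  # Dice and IoU always rank predictions the same way

print(f"IoU(class 1)  = {iou:.2f}")   # 0.60 (3 shared pixels / 5 covered pixels)
print(f"Dice(class 1) = {dice:.2f}")  # 0.75
```

A one-pixel shift on a four-pixel object already costs 40% IoU, which is why small objects tend to pull mIoU down.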

In [None]:
def compute_iou(pred: torch.Tensor, target: torch.Tensor, num_classes: int = 21) -> Dict[str, float]:
    """
    Compute Intersection over Union (IoU) per class and mean IoU.
    
    IoU = |A ∩ B| / |A ∪ B|
    
    Args:
        pred: [B, H, W] - predicted class indices
        target: [B, H, W] - ground truth class indices
        num_classes: Number of classes
    
    Returns:
        Dictionary with per-class IoU and mIoU
    """
    ious = []
    
    for cls in range(num_classes):
        pred_cls = (pred == cls)
        target_cls = (target == cls)
        
        intersection = (pred_cls & target_cls).sum().float()
        union = (pred_cls | target_cls).sum().float()
        
        if union > 0:
            iou = intersection / union
            ious.append(iou.item())
        else:
            ious.append(float('nan'))  # Class not present
    
    # Compute mIoU (ignoring classes not present)
    valid_ious = [iou for iou in ious if not np.isnan(iou)]
    miou = np.mean(valid_ious) if valid_ious else 0.0
    
    return {
        'per_class_iou': ious,
        'miou': miou
    }

# Demo
pred = torch.argmax(torch.randn(1, 21, 256, 256), dim=1)
target = torch.randint(0, 21, (1, 256, 256))
metrics = compute_iou(pred, target)

print(f"📊 IoU Metrics (random predictions):")
print(f"   mIoU: {metrics['miou']:.1%}")

---

## Part 5: Training U-Net

In [None]:
def train_segmentation(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    epochs: int = 10,
    lr: float = 0.001,
    device: torch.device = device
) -> Dict[str, List[float]]:
    """
    Train a segmentation model.
    """
    model = model.to(device)
    criterion = CombinedLoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    
    history = {'train_loss': [], 'val_loss': [], 'val_miou': []}
    
    for epoch in range(epochs):
        # Training
        model.train()
        train_loss = 0
        
        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}')
        for images, masks in pbar:
            images, masks = images.to(device), masks.to(device)
            
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            pbar.set_postfix({'loss': f'{loss.item():.4f}'})
        
        train_loss /= len(train_loader)
        history['train_loss'].append(train_loss)
        
        # Validation
        model.eval()
        val_loss = 0
        all_preds, all_targets = [], []
        
        with torch.no_grad():
            for images, masks in val_loader:
                images, masks = images.to(device), masks.to(device)
                outputs = model(images)
                loss = criterion(outputs, masks)
                val_loss += loss.item()
                
                preds = outputs.argmax(dim=1)
                all_preds.append(preds.cpu())
                all_targets.append(masks.cpu())
        
        val_loss /= len(val_loader)
        history['val_loss'].append(val_loss)
        
        # Compute mIoU
        all_preds = torch.cat(all_preds)
        all_targets = torch.cat(all_targets)
        metrics = compute_iou(all_preds, all_targets)
        history['val_miou'].append(metrics['miou'])
        
        scheduler.step()
        
        print(f"   Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, mIoU: {metrics['miou']:.1%}")
    
    return history

In [None]:
# Train the model
print("🏋️ Training U-Net on Pascal VOC...")
print("="*50)

model = UNet(n_channels=3, n_classes=21)
start_time = time.time()

history = train_segmentation(
    model, 
    train_loader, 
    val_loader,
    epochs=10,
    lr=0.001
)

train_time = time.time() - start_time
print(f"\n✅ Training complete in {train_time/60:.1f} minutes")
print(f"   Final mIoU: {history['val_miou'][-1]:.1%}")

In [None]:
# Plot training history
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

axes[0].plot(history['train_loss'], label='Train', linewidth=2)
axes[0].plot(history['val_loss'], label='Validation', linewidth=2)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('📉 Training and Validation Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

axes[1].plot(history['val_miou'], label='mIoU', linewidth=2, color='green')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('mIoU')
axes[1].set_title('📈 Validation mIoU')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

---

## Part 6: Visualize Predictions

In [None]:
# Visualize predictions on validation set
model.eval()

for batch_idx, (images, masks) in enumerate(val_loader):
    if batch_idx >= 1:  # Just one batch
        break
    
    images = images.to(device)
    with torch.no_grad():
        outputs = model(images)
        predictions = outputs.argmax(dim=1)
    
    # Show first 4 samples
    for i in range(min(4, len(images))):
        visualize_segmentation(
            images[i].cpu(),
            masks[i],
            predictions[i].cpu()
        )

---

## ✋ Try It Yourself

1. **Modify U-Net**: Add more channels (128, 256, 512, 1024 instead of 64, 128, 256, 512)
2. **Add data augmentation**: Random rotation, flipping for both image and mask
3. **Try different losses**: Pure Dice loss vs pure CE vs combined

<details>
<summary>💡 Hint for augmentation using torchvision (built-in)</summary>

For joint image-mask augmentation, use torchvision's functional transforms and draw each random decision once, so the image and the mask receive exactly the same operation:

```python
import torchvision.transforms.functional as TF
import random

def joint_transform(image, mask):
    # Random horizontal flip
    if random.random() > 0.5:
        image = TF.hflip(image)
        mask = TF.hflip(mask)
    
    # Random vertical flip
    if random.random() > 0.5:
        image = TF.vflip(image)
        mask = TF.vflip(mask)
    
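    # Random rotation by a multiple of 90 degrees
    # (TF.rotate defaults to nearest-neighbor interpolation, which keeps mask labels valid;
    #  for non-square inputs, pass expand=True or rotate after resizing to a square)
    if random.random() > 0.5:
        angle = random.choice([90, 180, 270])
        image = TF.rotate(image, angle)
        mask = TF.rotate(mask, angle)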
    
    return image, mask
```

</details>

<details>
<summary>💡 Hint for augmentation using albumentations (advanced)</summary>

The `albumentations` library provides powerful joint image-mask augmentation. **You'll need to install it first:**

```bash
pip install albumentations
```

Then use it like this:

```python
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Define joint transform for image AND mask
transform = A.Compose([
    A.RandomCrop(256, 256),  # inputs must be at least 256x256; resize or pad smaller images first
    A.HorizontalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2(),
])

# Apply to both image and mask together (albumentations expects NumPy arrays,
# e.g. np.array(pil_image), not PIL images)
transformed = transform(image=image, mask=mask)
aug_image = transformed['image']
aug_mask = transformed['mask']
```

**Key benefit:** Albumentations ensures the same random transformation is applied to both image and mask, maintaining spatial correspondence.

</details>

In [None]:
# YOUR CODE HERE



---

## ⚠️ Common Mistakes

### Mistake 1: Using wrong interpolation for masks

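```python
# mask: [H, W] integer class indices; F.interpolate expects a float [N, C, H, W] tensor

# ❌ Wrong: Bilinear interpolation blends neighboring class IDs into invalid values!
mask = F.interpolate(mask[None, None].float(), size=(256, 256), mode='bilinear')

# ✅ Right: Use nearest neighbor for class labels, then convert back to integers
mask = F.interpolate(mask[None, None].float(), size=(256, 256), mode='nearest')[0, 0].long()
```
**Why:** Masks hold class indices (0, 1, 2...). Bilinear interpolation produces in-between values (0.5, 1.3...) that are not valid classes.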
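
The effect is easy to reproduce on a single row of labels. The two helpers below are simplified, hypothetical 1-D resizers written for this illustration (they are not the PyTorch kernels): one copies the nearest label, the other blends neighbors linearly.

```python
def resize_row_nearest(row, new_len):
    """Nearest-neighbor resize of a 1-D label row: only copies existing labels."""
    scale = len(row) / new_len
    return [row[min(int(i * scale), len(row) - 1)] for i in range(new_len)]


def resize_row_linear(row, new_len):
    """Linear resize of a 1-D row: blends the two neighboring values."""
    out = []
    for i in range(new_len):
        pos = i * (len(row) - 1) / (new_len - 1)
        left = int(pos)
        right = min(left + 1, len(row) - 1)
        frac = pos - left
        out.append(row[left] * (1 - frac) + row[right] * frac)
    return out


labels = [0, 0, 15, 15]  # background next to 'person' (class 15)
print(resize_row_nearest(labels, 7))  # [0, 0, 0, 0, 15, 15, 15]
print(resize_row_linear(labels, 7))   # [0.0, 0.0, 0.0, 7.5, 15.0, 15.0, 15.0]
```

The blended value 7.5 sits between `car` (7) and `cat` (8), so rounding it would introduce a class that never appeared in the original mask.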

### Mistake 2: Forgetting ignore index

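```python
# ❌ Wrong: Void pixels (255) stay in the loss
criterion = nn.CrossEntropyLoss()  # 255 is not a valid class index for 21 classes!

# ✅ Right: Ignore boundary/void pixels
criterion = nn.CrossEntropyLoss(ignore_index=255)
```
**Why:** VOC uses 255 for object boundaries and other void regions. That value is not one of the 21 classes, so these pixels should be excluded from the loss instead of being treated as a class.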
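
A minimal sketch of what `ignore_index` does, assuming a toy 2×2 target with one void pixel and random logits: the loss equals the mean negative log-probability over the remaining pixels only. The manual computation is written out for illustration and is not how PyTorch implements the loss internally.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(1, 3, 2, 2)            # [B, C, H, W] with 3 classes
target = torch.tensor([[[0, 255],
                        [2, 1]]])           # one pixel marked as void (255)

loss = nn.CrossEntropyLoss(ignore_index=255)(logits, target)

# Manual check: average the negative log-probability over valid pixels only.
log_probs = F.log_softmax(logits, dim=1)
valid = target != 255
safe_target = target.clone()
safe_target[~valid] = 0                     # placeholder index, masked out below
picked = log_probs.gather(1, safe_target.unsqueeze(1)).squeeze(1)
manual = -picked[valid].mean()

print(torch.allclose(loss, manual))  # True
```

Because the void pixel is dropped from both the sum and the count, changing the logits at that location has no effect on the loss or its gradients.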

### Mistake 3: Size mismatch with skip connections

```python
# ❌ Wrong: Direct concatenation without size check
x = torch.cat([encoder_feature, decoder_feature], dim=1)  # May crash!

# ✅ Right: Pad (or crop) so the spatial sizes match before concatenating
diff_y = encoder_feature.size(2) - decoder_feature.size(2)
diff_x = encoder_feature.size(3) - decoder_feature.size(3)
decoder_feature = F.pad(decoder_feature, [diff_x // 2, diff_x - diff_x // 2,
                                          diff_y // 2, diff_y - diff_y // 2])
x = torch.cat([encoder_feature, decoder_feature], dim=1)
```
**Why:** Pooling rounds odd sizes down, so upsampling by 2 does not always restore the encoder's height and width.
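
The mismatch can be predicted with integer arithmetic alone. This is a minimal sketch, assuming 2×2 max pooling with floor rounding (the PyTorch default) and 2× upsampling; `skip_padding` is a hypothetical helper that mirrors the padding computed in this lab's `Up.forward`.

```python
def skip_padding(encoder_hw, decoder_hw):
    """Return F.pad-style [left, right, top, bottom] padding for the decoder map."""
    diff_y = encoder_hw[0] - decoder_hw[0]
    diff_x = encoder_hw[1] - decoder_hw[1]
    return [diff_x // 2, diff_x - diff_x // 2,
            diff_y // 2, diff_y - diff_y // 2]


# A 100x100 input is halved four times with floor division ...
sizes = [100]
for _ in range(4):
    sizes.append(sizes[-1] // 2)
print(sizes)  # [100, 50, 25, 12, 6]

# ... so doubling 12 gives 24, one pixel short of the 25x25 skip feature.
print(skip_padding((25, 25), (24, 24)))  # [0, 1, 0, 1]
```

Inputs whose sides are multiples of 16, such as the 256×256 images used here, never trigger the padding, because four halvings divide them exactly.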

---

## 🎉 Checkpoint

You've learned:
- ✅ Semantic vs instance vs panoptic segmentation
- ✅ U-Net encoder-decoder architecture with skip connections
- ✅ Dice loss for handling class imbalance
- ✅ IoU/mIoU evaluation metrics
- ✅ Training and visualizing segmentation models

---

## 🚀 Challenge (Optional)

**Swap in a pre-trained DeepLabV3 model:**

Instead of training an encoder from scratch, use a segmentation model with a pre-trained ResNet backbone:

1. Use `torchvision.models.segmentation.deeplabv3_resnet50`
2. Compare performance with your U-Net
3. Fine-tune only the classifier head first, then the full model

<details>
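<summary>💡 Starting Code</summary>

```python
from torchvision.models.segmentation import deeplabv3_resnet50

# `pretrained=True` is deprecated in recent torchvision releases; use `weights` instead
model = deeplabv3_resnet50(weights="DEFAULT")
model.classifier[-1] = nn.Conv2d(256, 21, kernel_size=1)  # 21 VOC classes
# Note: the forward pass returns a dict; the logits are in model(x)['out']
```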

</details>

In [None]:
# YOUR CHALLENGE CODE HERE



---

## 📖 Further Reading

- [U-Net Paper](https://arxiv.org/abs/1505.04597)
- [DeepLab Series](https://arxiv.org/abs/1706.05587)
- [Panoptic Segmentation](https://arxiv.org/abs/1801.00868)
- [SegFormer](https://arxiv.org/abs/2105.15203) - A modern transformer-based architecture

---

## 🧹 Cleanup

In [None]:
# Clear GPU memory
import gc

del model
gc.collect()              # Drop lingering Python references first
torch.cuda.empty_cache()  # Then release cached GPU blocks

print("✅ Cleanup complete!")
if torch.cuda.is_available():
    print(f"💾 GPU Memory Free: {torch.cuda.mem_get_info()[0] / 1e9:.1f} GB")