# Task 7.4 Solution: Segmentation Lab

**Module:** 7 - Computer Vision  
**Type:** Solution Notebook

---

This notebook contains solutions for the semantic segmentation exercises using the U-Net architecture.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple, List, Optional

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Device: {device}")

## Exercise Solution: Larger U-Net Architecture

This solution implements a larger U-Net whose channel widths are doubled at every level. This gives the model more capacity, at the cost of roughly 4x the parameters and higher memory use.

**Changes from standard U-Net:**
- Original: 64 → 128 → 256 → 512 → 1024
- This version: 128 → 256 → 512 → 1024 → 2048

In [None]:
class DoubleConv(nn.Module):
    """
    Two stacked (3x3 Conv2d -> BatchNorm2d -> ReLU) stages: the basic U-Net unit.

    The 3x3 convolutions use padding=1, so spatial size is preserved.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the second convolution.
        mid_channels: Channels between the two stages. A falsy value (None or 0)
            means "same as out_channels".
    """

    def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
        super().__init__()
        hidden = mid_channels or out_channels

        # Build both stages in order so submodule indices stay 0..5
        # (conv, bn, relu, conv, bn, relu) inside `double_conv`.
        layers = []
        for c_in, c_out in ((in_channels, hidden), (hidden, out_channels)):
            layers += [
                # No conv bias: the BatchNorm that follows supplies the shift.
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply both conv stages to `x`."""
        return self.double_conv(x)


print("DoubleConv block defined!")

In [None]:
class Down(nn.Module):
    """
    Encoder stage: halve the spatial size with a 2x2 max-pool, then apply DoubleConv.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the DoubleConv.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Downsample `x` by 2 and refine it with two conv stages."""
        return self.maxpool_conv(x)


class Up(nn.Module):
    """
    Decoder stage: upsample the deeper map, concatenate the skip connection,
    then apply DoubleConv.

    Args:
        in_channels: Channel count after concatenating skip + upsampled
            features (this is what the DoubleConv consumes).
        out_channels: Channels produced by the DoubleConv.
        bilinear: If True, upsample with (parameter-free) bilinear
            interpolation. This keeps the channel count, so callers feed
            in_channels // 2 channels here. If False, use a 2x2 transposed
            conv that maps in_channels -> in_channels // 2.
    """

    def __init__(self, in_channels: int, out_channels: int, bilinear: bool = True):
        super().__init__()
        half = in_channels // 2
        if not bilinear:
            self.up = nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, half)

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x1: Deeper (lower-resolution) decoder features.
            x2: Skip-connection features from the encoder.
        """
        x1 = self.up(x1)

        # Odd input sizes leave the upsampled map a pixel short (or long); pad
        # (negative values crop) so it lines up with the skip connection.
        dy = x2.size(2) - x1.size(2)
        dx = x2.size(3) - x1.size(3)
        left, top = dx // 2, dy // 2
        x1 = F.pad(x1, [left, dx - left, top, dy - top])

        return self.conv(torch.cat([x2, x1], dim=1))


print("Down and Up blocks defined!")

In [None]:
class LargerUNet(nn.Module):
    """
    U-Net with every channel width doubled.

    Original U-Net: 64 -> 128 -> 256 -> 512 -> 1024
    This version:   128 -> 256 -> 512 -> 1024 -> 2048

    Benefits:
    - More capacity for complex segmentation tasks
    - Richer feature representation at every level

    Trade-offs:
    - More parameters (~4x standard U-Net, since conv weights scale with
      in_channels * out_channels)
    - Higher memory usage
    - Longer training time

    Args:
        n_channels: Channels of the input image.
        n_classes: Number of output classes (channels of the returned logits).
        bilinear: If True, decoder upsampling is bilinear and the bottleneck /
            decoder widths are halved so the concatenations still line up;
            if False, transposed convolutions are used.
    """

    def __init__(self, n_channels: int = 3, n_classes: int = 21, bilinear: bool = False):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        w1, w2, w3, w4, w5 = 128, 256, 512, 1024, 2048
        shrink = 2 if bilinear else 1

        # Encoder. Registration order is kept fixed: it determines
        # parameters() order and the RNG draws used for initialisation.
        self.inc = DoubleConv(n_channels, w1)
        self.down1 = Down(w1, w2)
        self.down2 = Down(w2, w3)
        self.down3 = Down(w3, w4)
        self.down4 = Down(w4, w5 // shrink)

        # Decoder
        self.up1 = Up(w5, w4 // shrink, bilinear)
        self.up2 = Up(w4, w3 // shrink, bilinear)
        self.up3 = Up(w3, w2 // shrink, bilinear)
        self.up4 = Up(w2, w1, bilinear)

        # 1x1 projection to per-class logits
        self.outc = nn.Conv2d(w1, n_classes, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-pixel class logits with n_classes channels."""
        # Encoder: keep every pre-bottleneck feature map for the skips.
        skips = [self.inc(x)]
        for stage in (self.down1, self.down2, self.down3):
            skips.append(stage(skips[-1]))
        y = self.down4(skips[-1])

        # Decoder: consume the skips deepest-first (x4, x3, x2, x1).
        for stage, skip in zip((self.up1, self.up2, self.up3, self.up4), reversed(skips)):
            y = stage(y, skip)

        return self.outc(y)


# Test the model.
# Inference only: torch.no_grad() stops autograd from recording the forward
# pass. Otherwise the global `output` would keep the whole graph (every
# intermediate activation of this very large model) alive after the cell ends.
large_unet = LargerUNet(n_channels=3, n_classes=21)
x = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    output = large_unet(x)

print("Larger U-Net Architecture")
print("=" * 50)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Parameters: {sum(p.numel() for p in large_unet.parameters()):,}")

## Exercise Solution: Attention U-Net

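A U-Net whose skip connections pass through attention gates. The gates use decoder features to re-weight the encoder features before the two are concatenated in the decoder, so the network can focus on relevant regions.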

In [None]:
class AttentionGate(nn.Module):
    """
    Additive attention gate for a U-Net skip connection.

    Uses the decoder's context (gate signal) to compute a per-pixel weight in
    (0, 1), then multiplies the encoder skip features by that weight.

    Paper: "Attention U-Net: Learning Where to Look" (Oktay et al., 2018)

    Args:
        gate_channels: Channels of the gate signal `g`.
        skip_channels: Channels of the skip features `x`.
        inter_channels: Width of the intermediate embedding both are projected to.
    """

    def __init__(self, gate_channels: int, skip_channels: int, inter_channels: int):
        super().__init__()

        def conv_bn(in_ch: int, out_ch: int) -> nn.Sequential:
            # 1x1 conv without bias (the BatchNorm provides the shift) + BN.
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_ch),
            )

        self.W_g = conv_bn(gate_channels, inter_channels)   # projects the gate signal
        self.W_x = conv_bn(skip_channels, inter_channels)   # projects the skip features
        # Collapse to one channel and squash to (0, 1) attention coefficients.
        self.psi = nn.Sequential(*conv_bn(inter_channels, 1), nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            g: Gate signal from the decoder; may be coarser than `x` because it
               is resized to x's spatial size first.
            x: Skip-connection features from the encoder.

        Returns:
            `x` scaled element-wise by the attention map (same shape as `x`).
        """
        g_resized = F.interpolate(g, size=x.shape[2:], mode='bilinear', align_corners=True)
        gate = self.W_g(g_resized)
        skip = self.W_x(x)
        alpha = self.psi(self.relu(gate + skip))
        return x * alpha


print("AttentionGate defined!")

In [None]:
class AttentionUNet(nn.Module):
    """
    U-Net with Attention Gates.

    Each encoder skip connection is re-weighted by an AttentionGate before it
    is concatenated with the upsampled decoder features. Following Oktay et
    al. (2018), the gate signal is the coarser decoder feature map *before*
    its transposed convolution. AttentionGate resizes it to the skip
    connection's resolution internally.

    Note: the decoder uses plain stride-2 transposed convolutions with no
    padding/cropping, so input height and width should be divisible by 16
    (four 2x poolings).

    Args:
        n_channels: Channels of the input image.
        n_classes: Number of output classes (channels of the returned logits).
    """

    def __init__(self, n_channels: int = 3, n_classes: int = 21):
        super().__init__()

        # Encoder
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024)

        # Attention gates (gate_channels = channels of the coarser decoder map)
        self.att4 = AttentionGate(gate_channels=1024, skip_channels=512, inter_channels=256)
        self.att3 = AttentionGate(gate_channels=512, skip_channels=256, inter_channels=128)
        self.att2 = AttentionGate(gate_channels=256, skip_channels=128, inter_channels=64)
        self.att1 = AttentionGate(gate_channels=128, skip_channels=64, inter_channels=32)

        # Decoder
        self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.conv_up4 = DoubleConv(1024, 512)

        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.conv_up3 = DoubleConv(512, 256)

        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv_up2 = DoubleConv(256, 128)

        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.conv_up1 = DoubleConv(128, 64)

        # Output
        self.outc = nn.Conv2d(64, n_classes, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-pixel class logits with n_classes channels."""
        # Encoder
        x1 = self.inc(x)      # 64 channels
        x2 = self.down1(x1)   # 128
        x3 = self.down2(x2)   # 256
        x4 = self.down3(x3)   # 512
        x5 = self.down4(x4)   # 1024

        # Decoder with attention. Each gate is driven by the coarser decoder
        # map, whose channel count matches that gate's gate_channels
        # (1024, 512, 256, 128). Gating with the *upsampled* map instead
        # (e.g. up4(x5), only 512 channels) makes W_g's 1x1 conv fail.
        x4 = self.att4(x5, x4)
        d4 = self.conv_up4(torch.cat([x4, self.up4(x5)], dim=1))

        x3 = self.att3(d4, x3)
        d3 = self.conv_up3(torch.cat([x3, self.up3(d4)], dim=1))

        x2 = self.att2(d3, x2)
        d2 = self.conv_up2(torch.cat([x2, self.up2(d3)], dim=1))

        x1 = self.att1(d2, x1)
        d1 = self.conv_up1(torch.cat([x1, self.up1(d2)], dim=1))

        return self.outc(d1)


# Test Attention U-Net.
# Inference only: torch.no_grad() stops the global `output` from holding the
# full autograd graph (all intermediate activations) after the cell ends.
att_unet = AttentionUNet(n_channels=3, n_classes=21)
x = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    output = att_unet(x)

print("Attention U-Net Architecture")
print("=" * 50)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Parameters: {sum(p.numel() for p in att_unet.parameters()):,}")

## Exercise Solution: Segmentation Metrics

Implementations of common segmentation metrics: Dice coefficient, IoU (Jaccard index), and pixel accuracy.

In [None]:
def dice_coefficient(pred: torch.Tensor, target: torch.Tensor, smooth: float = 1e-6) -> torch.Tensor:
    """
    Calculate the Dice coefficient over all elements of the inputs.

    Dice = 2 * |A ∩ B| / (|A| + |B|)

    Args:
        pred: Predicted mask (binary or probabilities), any shape
        target: Ground truth mask (binary), same number of elements as `pred`
        smooth: Smoothing factor to avoid division by zero

    Returns:
        0-dim tensor with the Dice coefficient (0 to 1, higher is better)
    """
    # reshape (not view): view(-1) raises on non-contiguous inputs such as
    # transposed or sliced masks, while reshape copies only when it must.
    pred_flat = pred.reshape(-1)
    target_flat = target.reshape(-1)

    intersection = (pred_flat * target_flat).sum()

    dice = (2. * intersection + smooth) / (pred_flat.sum() + target_flat.sum() + smooth)
    return dice


def iou_score(pred: torch.Tensor, target: torch.Tensor, smooth: float = 1e-6) -> torch.Tensor:
    """
    Calculate Intersection over Union (IoU / Jaccard Index) over all elements.

    IoU = |A ∩ B| / |A ∪ B|,  with |A ∪ B| = |A| + |B| - |A ∩ B|

    Args:
        pred: Predicted mask (binary or probabilities), any shape
        target: Ground truth mask (binary), same number of elements as `pred`
        smooth: Smoothing factor to avoid division by zero

    Returns:
        0-dim tensor with the IoU score (0 to 1, higher is better)
    """
    # reshape (not view) so non-contiguous inputs (transposed/sliced masks) work.
    pred_flat = pred.reshape(-1)
    target_flat = target.reshape(-1)

    intersection = (pred_flat * target_flat).sum()
    union = pred_flat.sum() + target_flat.sum() - intersection

    iou = (intersection + smooth) / (union + smooth)
    return iou


def pixel_accuracy(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    Fraction of pixels whose predicted class equals the ground-truth class.

    Args:
        pred: Predicted class indices, e.g. [B, H, W]
        target: Ground-truth class indices, e.g. [B, H, W]

    Returns:
        0-dim float tensor: number of matching pixels / target.numel()
    """
    n_pixels = target.numel()
    n_correct = torch.eq(pred, target).float().sum()
    return n_correct / n_pixels


# Demonstration
print("Segmentation Metrics Demo")
print("=" * 50)

# Two overlapping square masks: the prediction (150x150) slightly overshoots
# the ground truth (130x130) on every side.
pred = torch.zeros(1, 256, 256)
target = torch.zeros_like(pred)
pred[..., 50:200, 50:200] = 1
target[..., 60:190, 60:190] = 1

dice = dice_coefficient(pred, target)
iou = iou_score(pred, target)
# pixel_accuracy compares labels, so pass integer class maps.
acc = pixel_accuracy(pred.long(), target.long())

print(f"Dice Coefficient: {dice:.4f}")
print(f"IoU Score: {iou:.4f}")
print(f"Pixel Accuracy: {acc:.4f}")

## Exercise Solution: Dice Loss

Dice loss is often used for segmentation tasks, especially with imbalanced classes.

In [None]:
class DiceLoss(nn.Module):
    """
    Soft multi-class Dice loss for segmentation.

    Dice Loss = 1 - (mean over classes of the soft Dice coefficient), where
    each class's Dice is computed over the whole batch and all spatial positions.

    Benefits:
    - Works well with imbalanced classes
    - Optimizes a differentiable surrogate of the Dice metric directly

    Args:
        smooth: Added to numerator and denominator to avoid division by zero.
    """

    def __init__(self, smooth: float = 1e-6):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            pred: Raw logits [B, C, *spatial] (e.g. [B, C, H, W] or
                [B, C, D, H, W]). Softmax is applied here, so pass logits,
                not probabilities.
            target: Ground-truth class indices [B, *spatial], values in [0, C).

        Returns:
            0-dim tensor: 1 minus the class-averaged soft Dice score.
        """
        num_classes = pred.shape[1]

        # Convert logits to per-class probabilities.
        pred_soft = F.softmax(pred, dim=1)

        # one_hot appends the class axis last: [B, *spatial, C] -> [B, C, *spatial].
        target_one_hot = F.one_hot(target.long(), num_classes)
        target_one_hot = torch.movedim(target_one_hot, -1, 1).to(pred_soft.dtype)

        # Reduce over batch and every spatial axis, leaving one value per class.
        dims = (0,) + tuple(range(2, pred.dim()))
        intersection = (pred_soft * target_one_hot).sum(dim=dims)
        cardinality = pred_soft.sum(dim=dims) + target_one_hot.sum(dim=dims)

        dice_per_class = (2. * intersection + self.smooth) / (cardinality + self.smooth)
        return 1 - dice_per_class.mean()


class CombinedLoss(nn.Module):
    """
    Weighted sum of Cross-Entropy and Dice losses.

    Cross-Entropy rewards per-pixel correctness, while Dice rewards region
    overlap; mixing them balances both signals.

    Args:
        ce_weight: Multiplier for the cross-entropy term.
        dice_weight: Multiplier for the Dice term.
    """

    def __init__(self, ce_weight: float = 0.5, dice_weight: float = 0.5):
        super().__init__()
        self.ce = nn.CrossEntropyLoss()
        self.dice = DiceLoss()
        self.ce_weight, self.dice_weight = ce_weight, dice_weight

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            pred: Logits [B, C, H, W].
            target: Class indices [B, H, W] (cast to long for cross-entropy).
        """
        weighted_ce = self.ce_weight * self.ce(pred, target.long())
        weighted_dice = self.dice_weight * self.dice(pred, target)
        return weighted_ce + weighted_dice


# Test losses (neither loss has parameters, so building them first draws no random numbers)
dice_loss = DiceLoss()
combined_loss = CombinedLoss()

pred = torch.randn(2, 21, 64, 64)           # [B, C, H, W] raw logits
target = torch.randint(0, 21, (2, 64, 64))  # [B, H, W] class indices

print(f"Dice Loss: {dice_loss(pred, target):.4f}")
print(f"Combined Loss: {combined_loss(pred, target):.4f}")

## Summary

Key concepts covered:

1. **Larger U-Net**: Doubled channel dimensions for more capacity
2. **Attention U-Net**: Attention gates for better feature selection
3. **Segmentation Metrics**: Dice coefficient, IoU, pixel accuracy
4. **Dice Loss**: Region-based loss for imbalanced segmentation

Best practices:
- Use combined CE + Dice loss for stable training
- Monitor both pixel accuracy and IoU metrics
- Consider attention mechanisms for complex scenes

In [None]:
# Cleanup
import gc

# gc.collect() and empty_cache() can only reclaim memory that nothing still
# references. The demo cells above keep the large models (and the last model
# output) alive through notebook globals, so drop those references first.
# pop(..., None) keeps this cell safe even if a cell above was never run.
for _name in ("large_unet", "att_unet", "output"):
    globals().pop(_name, None)
del _name

gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
print("Cleanup complete!")