In [None]:
# UNet for Polyp Segmentation in Colonoscopy Images

## Deep Learning Final Project

**Problem**: Colorectal cancer is the 3rd most common cancer worldwide. During colonoscopy, 14-30% of polyps are missed by physicians. Automated polyp segmentation can assist doctors in identifying polyp boundaries accurately.

**Dataset**: Kvasir-SEG (1,000 polyp images with pixel-wise segmentation masks)

**Architecture**: UNet (Encoder-Decoder with Skip Connections)

**Experiments**:
1. Shallow UNet (2 blocks) vs Standard UNet (4 blocks)
2. UNet WITH vs WITHOUT skip connections
3. Different loss functions (BCE, Dice, Combined)

## 1. Setup and Imports

In [None]:
# Install required packages (for Google Colab)
# !pip install torch torchvision matplotlib numpy pillow tqdm

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
from torchvision.transforms import functional as TF

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import os
import random
from tqdm import tqdm
import zipfile
import urllib.request

# Set random seeds for reproducibility.
# PyTorch, NumPy and Python's `random` are seeded separately because each
# keeps its own generator (the dataset augmentation below draws from `random`).
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# Check device: use the GPU when CUDA is available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 2. Download Kvasir-SEG Dataset

The Kvasir-SEG dataset contains 1,000 polyp images with corresponding segmentation masks. Each image shows a polyp during colonoscopy, and the mask indicates the exact polyp region.

In [None]:
# Download and extract Kvasir-SEG dataset
import ssl
import certifi

# Remote location of the Kvasir-SEG zip archive.
DATA_URL = "https://datasets.simula.no/downloads/kvasir-seg.zip"
# Local folder that receives the downloaded archive and its extracted contents.
DATA_DIR = "data"
# Folder whose existence download_dataset() uses to decide whether to download.
DATASET_DIR = os.path.join(DATA_DIR, "Kvasir-SEG")

def download_dataset():
    """Download and extract the Kvasir-SEG dataset if it is not already present.

    The archive at DATA_URL is streamed to ``DATA_DIR/kvasir-seg.zip``,
    extracted into DATA_DIR and then deleted. Nothing happens when
    DATASET_DIR already exists.

    The download is first attempted with certificate verification against
    the certifi CA bundle. Only when that attempt fails because of a TLS
    error is it retried without verification, and a warning is emitted so
    the downgrade is never silent. Any other failure (no network, HTTP
    error, ...) is re-raised instead of being masked by an insecure retry.

    Raises:
        urllib.error.URLError / OSError: if the download fails.
        zipfile.BadZipFile: if the downloaded archive is corrupt.
    """
    if os.path.exists(DATASET_DIR):
        print("Dataset already exists!")
        return

    # Function-scope imports keep this cell self-contained.
    import shutil
    import urllib.error
    import warnings

    os.makedirs(DATA_DIR, exist_ok=True)
    zip_path = os.path.join(DATA_DIR, "kvasir-seg.zip")

    def _fetch(ssl_context):
        """Stream DATA_URL into zip_path using the given SSL context."""
        with urllib.request.urlopen(DATA_URL, context=ssl_context) as response:
            with open(zip_path, 'wb') as out_file:
                shutil.copyfileobj(response, out_file)

    def _is_tls_failure(err):
        """True when err is an SSL error, raised directly or wrapped in URLError."""
        return isinstance(err, ssl.SSLError) or isinstance(
            getattr(err, 'reason', None), ssl.SSLError
        )

    print("Downloading Kvasir-SEG dataset...")

    try:
        try:
            # urlopen (unlike urlretrieve) accepts an SSL context, so the
            # certifi bundle is actually used for verification here.
            _fetch(ssl.create_default_context(cafile=certifi.where()))
        except (ssl.SSLError, urllib.error.URLError) as err:
            if not _is_tls_failure(err):
                raise
            # SECURITY: best-effort fallback kept from the original notebook.
            # The archive is fetched WITHOUT certificate verification, so its
            # authenticity cannot be guaranteed.
            warnings.warn(
                f"TLS verification failed ({err}); retrying WITHOUT "
                "certificate verification.",
                RuntimeWarning,
            )
            insecure_context = ssl.create_default_context()
            insecure_context.check_hostname = False
            insecure_context.verify_mode = ssl.CERT_NONE
            _fetch(insecure_context)

        print("Extracting dataset...")
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(DATA_DIR)
    finally:
        # Remove the archive on success and on failure, so a partial
        # download is never left behind.
        if os.path.exists(zip_path):
            os.remove(zip_path)

    print("Dataset ready!")

download_dataset()

## 3. Dataset Class

We create a custom PyTorch Dataset that:
- Loads images and their corresponding masks
- Resizes to 256x256 for efficient training
- Applies data augmentation (horizontal flip, vertical flip, rotation)

In [None]:
class PolypDataset(Dataset):
    """
    Kvasir-SEG polyp segmentation dataset (image / binary-mask pairs).

    Each mask is looked up in ``mask_dir`` under the same file name as its
    image in ``image_dir``.

    Args:
        image_dir: Path to images folder
        mask_dir: Path to masks folder
        img_size: Side length of the square both image and mask are resized to
        augment: Whether to apply random flips / rotation to each sample
    """
    def __init__(self, image_dir, mask_dir, img_size=256, augment=False):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.img_size = img_size
        self.augment = augment

        # Only '.jpg' files are samples; sorting gives a stable index order.
        jpg_names = filter(lambda fname: fname.endswith('.jpg'), os.listdir(image_dir))
        self.images = sorted(jpg_names)

    def __len__(self):
        return len(self.images)

    def _augment_pair(self, image, mask):
        """Apply the same random flips / rotation to an image and its mask."""
        # Horizontal flip, then vertical flip, each with probability 0.5.
        for flip in (TF.hflip, TF.vflip):
            if random.random() > 0.5:
                image, mask = flip(image), flip(mask)
        # Rotation by a whole number of degrees in [-30, 30], probability 0.5.
        if random.random() > 0.5:
            angle = random.randint(-30, 30)
            image, mask = TF.rotate(image, angle), TF.rotate(mask, angle)
        return image, mask

    def __getitem__(self, idx):
        fname = self.images[idx]
        target_hw = (self.img_size, self.img_size)

        # Image as RGB, mask as single-channel grayscale.
        image = Image.open(os.path.join(self.image_dir, fname)).convert('RGB')
        mask = Image.open(os.path.join(self.mask_dir, fname)).convert('L')

        image = TF.resize(image, target_hw)
        mask = TF.resize(mask, target_hw)

        if self.augment:
            image, mask = self._augment_pair(image, mask)

        image = TF.to_tensor(image)                 # [3, H, W], values in [0, 1]
        # Threshold at 0.5 so the mask holds only 0.0 / 1.0.
        mask = (TF.to_tensor(mask) > 0.5).float()   # [1, H, W]

        return image, mask

In [None]:
# Create dataset and data loaders
IMAGE_DIR = os.path.join(DATASET_DIR, "images")
MASK_DIR = os.path.join(DATASET_DIR, "masks")

# Full dataset (without augmentation for splitting)
full_dataset = PolypDataset(IMAGE_DIR, MASK_DIR, img_size=256, augment=False)

# Split: 70% train, 15% validation, 15% test
# int() truncates, so the test split absorbs any remainder and the three
# sizes always sum to total_size.
total_size = len(full_dataset)
train_size = int(0.7 * total_size)
val_size = int(0.15 * total_size)
test_size = total_size - train_size - val_size

# A dedicated, fixed-seed generator keeps the split identical across runs,
# independent of how much of the global RNG has been consumed so far.
train_dataset, val_dataset, test_dataset = random_split(
    full_dataset, [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(42)
)

print(f"Total images: {total_size}")
print(f"Train: {train_size}, Validation: {val_size}, Test: {test_size}")

In [None]:
# Visualize some samples from the dataset
def show_samples(dataset, num_samples=4):
    """Display randomly chosen images next to their masks.

    Each row shows one sample in three panels: the image, its ground-truth
    mask, and the mask overlaid on the image.

    Args:
        dataset: Indexable dataset returning ``(image, mask)`` tensors, with
            the image as [3, H, W] and the mask as [1, H, W].
        num_samples: Number of rows to draw. Samples are drawn with
            replacement, so the same image may appear more than once.
    """
    # squeeze=False keeps `axes` two-dimensional even for num_samples == 1;
    # without it `axes[i, 0]` raises IndexError on the squeezed 1-D array.
    fig, axes = plt.subplots(num_samples, 3, figsize=(12, 4*num_samples),
                             squeeze=False)

    for i in range(num_samples):
        idx = random.randint(0, len(dataset)-1)
        image, mask = dataset[idx]

        # Convert tensors to numpy for display (channels last for imshow)
        img_np = image.permute(1, 2, 0).numpy()
        mask_np = mask.squeeze().numpy()

        # Original image
        axes[i, 0].imshow(img_np)
        axes[i, 0].set_title('Original Image')
        axes[i, 0].axis('off')

        # Mask
        axes[i, 1].imshow(mask_np, cmap='gray')
        axes[i, 1].set_title('Ground Truth Mask')
        axes[i, 1].axis('off')

        # Overlay: semi-transparent mask drawn on top of the image
        axes[i, 2].imshow(img_np)
        axes[i, 2].imshow(mask_np, alpha=0.5, cmap='Reds')
        axes[i, 2].set_title('Overlay')
        axes[i, 2].axis('off')

    plt.tight_layout()
    plt.show()

show_samples(full_dataset, num_samples=3)

## 4. UNet Architecture

UNet consists of:
1. **Encoder (Contracting Path)**: Captures context through convolutions and pooling
2. **Bottleneck**: Bridge between encoder and decoder
3. **Decoder (Expanding Path)**: Enables precise localization through upsampling
4. **Skip Connections**: Connect encoder to decoder to preserve spatial information

Our implementation is configurable:
- `depth`: Number of encoder/decoder blocks (2 for shallow, 4 for standard)
- `use_skip`: Toggle skip connections on/off for experiments

In [None]:
class ConvBlock(nn.Module):
    """
    Two stacked (3x3 Conv -> BatchNorm -> ReLU) stages.

    The basic building block of UNet. Padding of 1 keeps the spatial size
    unchanged; only the channel count goes from in_channels to out_channels.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        stage_in = in_channels
        for _ in range(2):
            stages += [
                # No conv bias: the BatchNorm that follows has its own shift.
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
            stage_in = out_channels
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)

In [None]:
class EncoderBlock(nn.Module):
    """
    One encoder level: ConvBlock, then 2x2 max-pooling to halve the resolution.

    forward() returns both the pre-pooling features (used as the skip
    connection) and the pooled tensor (input to the next level).
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = ConvBlock(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        features = self.conv(x)
        return features, self.pool(features)

In [None]:
class DecoderBlock(nn.Module):
    """
    Decoder block: Upsample -> Concatenate skip connection -> ConvBlock

    Args:
        in_channels: Input channels (from previous decoder or bottleneck)
        out_channels: Output channels
        use_skip: If True, concatenate skip connection from encoder
    """
    def __init__(self, in_channels, out_channels, use_skip=True):
        super(DecoderBlock, self).__init__()
        self.use_skip = use_skip

        # Transposed convolution doubles the spatial size and maps
        # in_channels -> out_channels.
        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

        # With skip connections the ConvBlock sees the upsampled features
        # concatenated with the encoder features, i.e. twice the channels.
        conv_in_channels = out_channels * 2 if use_skip else out_channels
        self.conv = ConvBlock(conv_in_channels, out_channels)

    def forward(self, x, skip=None):
        """Upsample ``x``, optionally merge the encoder features, and convolve.

        Args:
            x: Feature map from the previous decoder level or the bottleneck.
            skip: Encoder feature map with ``out_channels`` channels. Required
                when the block was built with ``use_skip=True``; ignored
                otherwise.

        Returns:
            Feature map with ``out_channels`` channels.

        Raises:
            ValueError: if ``use_skip`` is True but no skip tensor is given.
        """
        x = self.up(x)

        if self.use_skip:
            if skip is None:
                # self.conv was built for 2 * out_channels inputs, so going on
                # without the skip would only fail later with an obscure
                # channel-mismatch error.
                raise ValueError(
                    "DecoderBlock was built with use_skip=True but no skip "
                    "tensor was passed to forward()."
                )
            # Max-pooling floors odd sizes, so after upsampling x can be
            # smaller than the skip when the input side is not divisible by
            # 2**depth. Resize to the skip's size so the concat is valid.
            if x.shape[-2:] != skip.shape[-2:]:
                x = nn.functional.interpolate(
                    x, size=skip.shape[-2:], mode='bilinear', align_corners=False
                )
            # Concatenate skip connection along channel dimension
            x = torch.cat([x, skip], dim=1)

        return self.conv(x)

In [None]:
class UNet(nn.Module):
    """
    UNet with configurable depth and optional skip connections.

    Args:
        in_channels: Number of input channels (3 for RGB)
        out_channels: Number of output channels (1 for binary segmentation)
        depth: Number of encoder/decoder blocks (2=shallow, 4=standard)
        base_features: Channel width of the first level (doubles each level)
        use_skip: Whether encoder features are fed to the decoder blocks

    forward() returns per-pixel probabilities (sigmoid already applied).
    """
    def __init__(self, in_channels=3, out_channels=1, depth=4,
                 base_features=64, use_skip=True):
        super().__init__()

        self.depth = depth
        self.use_skip = use_skip

        # Channel widths per level, e.g. depth=4 -> [64, 128, 256, 512];
        # the bottleneck doubles the deepest width (1024 in that example).
        features = [base_features * (2**level) for level in range(depth)]

        # Encoder: each level consumes the previous level's width.
        encoder_inputs = [in_channels] + features[:-1]
        self.encoders = nn.ModuleList([
            EncoderBlock(ch_in, ch_out)
            for ch_in, ch_out in zip(encoder_inputs, features)
        ])

        # Bottleneck
        bottleneck_channels = features[-1] * 2
        self.bottleneck = ConvBlock(features[-1], bottleneck_channels)

        # Decoder: mirror of the encoder, deepest level first.
        decoder_inputs = [bottleneck_channels] + features[:0:-1]
        self.decoders = nn.ModuleList([
            DecoderBlock(ch_in, ch_out, use_skip=use_skip)
            for ch_in, ch_out in zip(decoder_inputs, features[::-1])
        ])

        # 1x1 convolution maps the first-level width to the output channels.
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        # Encoder path: keep each level's pre-pooling features.
        encoder_features = []
        for encoder in self.encoders:
            level_features, x = encoder(x)
            encoder_features.append(level_features)

        x = self.bottleneck(x)

        # Decoder path: pair each decoder with the matching encoder level,
        # deepest first; pass None when skips are disabled.
        for decoder, level_features in zip(self.decoders, reversed(encoder_features)):
            x = decoder(x, level_features if self.use_skip else None)

        # Sigmoid turns the logits into probabilities for binary output.
        return torch.sigmoid(self.final_conv(x))

In [None]:
# Test UNet architecture
def count_parameters(model):
    """Return the total number of elements in the model's trainable parameters."""
    trainable = 0
    for param in model.parameters():
        # Frozen parameters (requires_grad=False) are not counted.
        if param.requires_grad:
            trainable += param.numel()
    return trainable

# Test different configurations: print the trainable-parameter count for
# every depth / skip-connection combination used in the experiments.
print("UNet Architecture Configurations:")
print("-" * 50)

for depth in [2, 4]:
    for use_skip in [True, False]:
        model = UNet(depth=depth, use_skip=use_skip)
        params = count_parameters(model)
        skip_str = "with skip" if use_skip else "no skip"
        print(f"Depth={depth}, {skip_str}: {params:,} parameters")

# Verify forward pass works: one random 256x256 RGB image through the
# standard configuration, then compare input and output shapes.
model = UNet(depth=4, use_skip=True)
test_input = torch.randn(1, 3, 256, 256)
test_output = model(test_input)
print(f"\nInput shape: {test_input.shape}")
print(f"Output shape: {test_output.shape}")

## 5. Loss Functions

We implement three loss functions for comparison:

1. **BCE (Binary Cross-Entropy)**: Standard classification loss, treats each pixel independently
2. **Dice Loss**: Measures overlap between prediction and ground truth, better for imbalanced data
3. **Combined Loss**: BCE + Dice, combines benefits of both

**Dice Coefficient Formula**:
$$Dice = \frac{2 \times |A \cap B|}{|A| + |B|}$$

Where A is the predicted mask and B is the ground truth mask.

In [None]:
class DiceLoss(nn.Module):
    """
    Dice Loss for segmentation.

    Dice = 2 * intersection / (sum of both areas)
    Loss = 1 - Dice

    The coefficient is computed once over all elements of the batch
    (everything is flattened together), not per sample.

    Args:
        smooth: Small constant added to numerator and denominator so the
            ratio stays defined when both prediction and target are empty.
    """
    def __init__(self, smooth=1e-5):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        """Return ``1 - Dice`` for predicted probabilities vs. target mask.

        Args:
            pred: Predicted probabilities, same number of elements as target.
            target: Ground-truth mask.

        Returns:
            Scalar tensor holding the loss.
        """
        # reshape (not view) so non-contiguous inputs such as transposed or
        # sliced tensors are accepted instead of raising a RuntimeError.
        pred_flat = pred.reshape(-1)
        target_flat = target.reshape(-1)

        # Soft intersection and the sum of both areas. The latter is the
        # Dice denominator |A| + |B|, not a set union.
        intersection = (pred_flat * target_flat).sum()
        total_area = pred_flat.sum() + target_flat.sum()

        # Dice coefficient
        dice = (2.0 * intersection + self.smooth) / (total_area + self.smooth)

        return 1.0 - dice

In [None]:
class BCEDiceLoss(nn.Module):
    """
    Weighted sum of binary cross-entropy and Dice loss.

    BCE scores every pixel independently, while Dice scores the overlap of
    the predicted and true regions; the sum uses both signals.

    Args:
        bce_weight: Multiplier applied to the BCE term
        dice_weight: Multiplier applied to the Dice term
    """
    def __init__(self, bce_weight=0.5, dice_weight=0.5):
        super().__init__()
        self.bce = nn.BCELoss()
        self.dice = DiceLoss()
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight

    def forward(self, pred, target):
        # pred must already be probabilities: nn.BCELoss does not apply a sigmoid.
        bce_term = self.bce(pred, target)
        dice_term = self.dice(pred, target)
        return self.bce_weight * bce_term + self.dice_weight * dice_term

In [None]:
# Loss function factory
def get_loss_function(loss_type='combined'):
    """
    Build the loss module selected by name.

    Args:
        loss_type: 'bce', 'dice', or 'combined'

    Returns:
        A freshly constructed loss module.

    Raises:
        ValueError: if loss_type is not one of the supported names.
    """
    # Lazy factories: only the requested loss is ever constructed.
    factories = (
        ('bce', lambda: nn.BCELoss()),
        ('dice', lambda: DiceLoss()),
        ('combined', lambda: BCEDiceLoss()),
    )
    for known_name, build in factories:
        if loss_type == known_name:
            return build()
    raise ValueError(f"Unknown loss type: {loss_type}")

## 6. Evaluation Metrics

We use standard segmentation metrics:
- **Dice Coefficient**: Primary metric, measures overlap (higher is better)
- **IoU (Intersection over Union)**: Also called Jaccard Index
- **Precision**: True positives / (True positives + False positives)
- **Recall**: True positives / (True positives + False negatives)

In [None]:
def calculate_metrics(pred, target, threshold=0.5):
    """
    Calculate segmentation metrics.

    All elements of ``pred`` / ``target`` are pooled together, so for a
    batch the metrics describe the batch as a whole rather than an average
    of per-image scores.

    Args:
        pred: Predicted mask (probabilities)
        target: Ground truth mask with values 0 / 1 and as many elements as pred
        threshold: Predictions strictly above this value count as foreground

    Returns:
        Dictionary with dice, iou, precision, recall (Python floats)
    """
    # Binarize predictions
    pred_binary = (pred > threshold).float()

    # Flatten. reshape (not view) so non-contiguous tensors, e.g. transposed
    # or sliced ones, are accepted instead of raising a RuntimeError.
    pred_flat = pred_binary.reshape(-1)
    target_flat = target.reshape(-1)

    # True positives, false positives, false negatives
    tp = (pred_flat * target_flat).sum()
    fp = (pred_flat * (1 - target_flat)).sum()
    fn = ((1 - pred_flat) * target_flat).sum()

    # Smoothing keeps every ratio defined when a denominator would be zero.
    # Side effect: a ratio whose numerator and denominator are both zero
    # (e.g. precision with no predicted foreground) evaluates to 1.
    smooth = 1e-5

    # Dice coefficient
    dice = (2 * tp + smooth) / (2 * tp + fp + fn + smooth)

    # IoU (Jaccard)
    iou = (tp + smooth) / (tp + fp + fn + smooth)

    # Precision and Recall
    precision = (tp + smooth) / (tp + fp + smooth)
    recall = (tp + smooth) / (tp + fn + smooth)

    return {
        'dice': dice.item(),
        'iou': iou.item(),
        'precision': precision.item(),
        'recall': recall.item()
    }

## 7. Training Loop

The training loop includes:
- Training and validation phases
- Early stopping to prevent overfitting
- Learning rate scheduling
- Metric logging for analysis

In [None]:
def train_one_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    Returns:
        (mean loss, mean Dice), both averaged over the batches.
    """
    model.train()
    running = {'loss': 0, 'dice': 0}

    for batch_images, batch_masks in dataloader:
        batch_images = batch_images.to(device)
        batch_masks = batch_masks.to(device)

        # Clear old gradients, then forward pass
        optimizer.zero_grad()
        predictions = model(batch_images)
        batch_loss = criterion(predictions, batch_masks)

        # Backward pass and parameter update
        batch_loss.backward()
        optimizer.step()

        # Accumulate per-batch statistics
        running['loss'] += batch_loss.item()
        running['dice'] += calculate_metrics(predictions, batch_masks)['dice']

    n_batches = len(dataloader)
    return running['loss'] / n_batches, running['dice'] / n_batches

In [None]:
@torch.no_grad()
def validate(model, dataloader, criterion, device):
    """Evaluate the model on ``dataloader`` without tracking gradients.

    Returns:
        (mean loss, dict of mean dice / iou / precision / recall), all
        averaged over the batches.
    """
    model.eval()
    metric_names = ('dice', 'iou', 'precision', 'recall')
    loss_sum = 0
    metric_sums = dict.fromkeys(metric_names, 0)

    for batch_images, batch_masks in dataloader:
        batch_images = batch_images.to(device)
        batch_masks = batch_masks.to(device)

        predictions = model(batch_images)
        batch_loss = criterion(predictions, batch_masks)
        loss_sum += batch_loss.item()

        batch_metrics = calculate_metrics(predictions, batch_masks)
        for metric in metric_names:
            metric_sums[metric] += batch_metrics[metric]

    n_batches = len(dataloader)
    avg_loss = loss_sum / n_batches
    avg_metrics = {metric: total / n_batches for metric, total in metric_sums.items()}

    return avg_loss, avg_metrics

In [None]:
def train_model(model, train_loader, val_loader, criterion, optimizer, 
                scheduler, device, num_epochs=50, patience=10):
    """
    Full training loop with early stopping on validation Dice.

    Args:
        model: UNet model
        train_loader: Training data loader
        val_loader: Validation data loader
        criterion: Loss function
        optimizer: Optimizer
        scheduler: Learning rate scheduler; stepped once per epoch with the
            validation loss
        device: Device to train on
        num_epochs: Maximum number of epochs
        patience: Number of consecutive epochs without a new best validation
            Dice after which training stops

    Returns:
        history: Dictionary with per-epoch train_loss, train_dice, val_loss,
            val_dice and val_iou lists
        best_model_state: Independent copy of the state dict from the epoch
            with the highest validation Dice (None only if no epoch ran)
    """
    history = {
        'train_loss': [], 'train_dice': [],
        'val_loss': [], 'val_dice': [], 'val_iou': []
    }

    # Start below any reachable Dice so the first epoch always records a
    # checkpoint, even when its Dice is exactly 0.
    best_val_dice = float('-inf')
    best_model_state = None
    patience_counter = 0

    for epoch in range(num_epochs):
        # Training
        train_loss, train_dice = train_one_epoch(
            model, train_loader, criterion, optimizer, device
        )

        # Validation
        val_loss, val_metrics = validate(model, val_loader, criterion, device)

        # Update scheduler
        scheduler.step(val_loss)

        # Log history
        history['train_loss'].append(train_loss)
        history['train_dice'].append(train_dice)
        history['val_loss'].append(val_loss)
        history['val_dice'].append(val_metrics['dice'])
        history['val_iou'].append(val_metrics['iou'])

        # Print progress
        print(f"Epoch {epoch+1}/{num_epochs} | "
              f"Train Loss: {train_loss:.4f}, Dice: {train_dice:.4f} | "
              f"Val Loss: {val_loss:.4f}, Dice: {val_metrics['dice']:.4f}")

        # Early stopping check
        if val_metrics['dice'] > best_val_dice:
            best_val_dice = val_metrics['dice']
            # state_dict() tensors share storage with the live parameters,
            # and dict.copy() is shallow, so it would keep pointing at
            # weights that later epochs overwrite in place. Clone every
            # tensor to freeze a real snapshot of this epoch.
            best_model_state = {
                key: tensor.detach().clone()
                for key, tensor in model.state_dict().items()
            }
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

    return history, best_model_state

## 8. Experiment Runner

We define a function to run experiments with different configurations and store results for comparison.

In [None]:
def run_experiment(name, depth, use_skip, loss_type, train_loader, val_loader, 
                   device, num_epochs=50, lr=1e-4):
    """
    Train and evaluate one UNet configuration.

    Builds the model, loss, Adam optimizer and a ReduceLROnPlateau scheduler
    (halves the learning rate after 5 epochs without a lower validation
    loss), trains with an early-stopping patience of 10, reloads the best
    weights and scores them on the validation loader.

    Args:
        name: Experiment name for logging
        depth: UNet depth (2 or 4)
        use_skip: Whether to use skip connections
        loss_type: 'bce', 'dice', or 'combined'
        train_loader: Training data loader
        val_loader: Validation data loader
        device: Device to train on
        num_epochs: Maximum epochs
        lr: Learning rate

    Returns:
        results: Dictionary with the configuration, the training history,
            the best model state and its validation metrics
    """
    banner = '=' * 60
    print(f"\n{banner}")
    print(f"Experiment: {name}")
    print(f"Config: depth={depth}, skip={use_skip}, loss={loss_type}")
    print(f"{banner}\n")

    # Model, loss, optimizer, scheduler
    model = UNet(depth=depth, use_skip=use_skip).to(device)
    criterion = get_loss_function(loss_type)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5
    )

    history, best_state = train_model(
        model, train_loader, val_loader, criterion, optimizer,
        scheduler, device, num_epochs=num_epochs, patience=10
    )

    # Score the best checkpoint (not the last epoch) on the validation set.
    model.load_state_dict(best_state)
    _, val_metrics = validate(model, val_loader, criterion, device)

    print(f"\nBest Validation Results for {name}:")
    print(f"  Dice: {val_metrics['dice']:.4f}")
    print(f"  IoU:  {val_metrics['iou']:.4f}")

    return {
        'name': name,
        'depth': depth,
        'use_skip': use_skip,
        'loss_type': loss_type,
        'history': history,
        'best_model_state': best_state,
        'final_metrics': val_metrics,
    }

In [None]:
# Create data loaders with augmentation for training
BATCH_SIZE = 8

# Training dataset with augmentation.
# A second dataset object over the same files, with augment=True, is
# restricted to the indices of the training split. Validation and test
# keep using the non-augmented `full_dataset` through their own subsets.
train_dataset_aug = PolypDataset(IMAGE_DIR, MASK_DIR, img_size=256, augment=True)
train_indices = train_dataset.indices  # Get indices from the split
train_subset_aug = torch.utils.data.Subset(train_dataset_aug, train_indices)

# Only the training loader shuffles; validation and test order stays fixed.
train_loader = DataLoader(train_subset_aug, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")
print(f"Test batches: {len(test_loader)}")

## 9. Run All Experiments

We run 5 experiments to compare:
1. **Shallow UNet** (depth=2) vs **Standard UNet** (depth=4) - Effect of depth
2. **UNet WITHOUT skip connections** - Ablation study
3. **Different loss functions** - BCE vs Dice vs Combined

Note: Exp2 (Standard UNet) serves as the baseline for multiple comparisons.

In [None]:
# Define experiment configurations (streamlined - removed redundant experiments).
# Each entry differs from the Exp2 baseline in exactly one setting
# (depth, skip connections, or loss), so every comparison isolates one factor.
experiments_config = [
    # Experiment 1: Shallow UNet (for depth comparison)
    {'name': 'Exp1_Shallow_UNet', 'depth': 2, 'use_skip': True, 'loss_type': 'combined'},
    
    # Experiment 2: Standard UNet (baseline for all comparisons)
    {'name': 'Exp2_Standard_UNet', 'depth': 4, 'use_skip': True, 'loss_type': 'combined'},
    
    # Experiment 3: UNet WITHOUT skip connections (for skip connection ablation)
    {'name': 'Exp3_NoSkip', 'depth': 4, 'use_skip': False, 'loss_type': 'combined'},
    
    # Experiment 4: BCE Loss only (for loss function comparison)
    {'name': 'Exp4_BCE_Loss', 'depth': 4, 'use_skip': True, 'loss_type': 'bce'},
    
    # Experiment 5: Dice Loss only (for loss function comparison)
    {'name': 'Exp5_Dice_Loss', 'depth': 4, 'use_skip': True, 'loss_type': 'dice'},
]

# Echo the plan so the run log records which configurations were trained.
print("Experiments to run:")
for i, exp in enumerate(experiments_config, 1):
    print(f"  {i}. {exp['name']}: depth={exp['depth']}, skip={exp['use_skip']}, loss={exp['loss_type']}")

print("\nComparisons:")
print("  - Depth effect: Exp1 vs Exp2")
print("  - Skip connections: Exp3 vs Exp2")
print("  - Loss functions: Exp4 vs Exp5 vs Exp2")

In [None]:
# Run all experiments (this will take some time)
# Reduce epochs for faster experimentation, increase for better results
NUM_EPOCHS = 30  # Increase to 50 for final results

# Maps experiment name -> result dict returned by run_experiment
# (configuration, training history, best model state, validation metrics).
all_results = {}

# Every experiment trains a fresh model on the same train/validation loaders.
for config in experiments_config:
    results = run_experiment(
        name=config['name'],
        depth=config['depth'],
        use_skip=config['use_skip'],
        loss_type=config['loss_type'],
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        num_epochs=NUM_EPOCHS
    )
    all_results[config['name']] = results

print("\n" + "="*60)
print("ALL EXPERIMENTS COMPLETED!")
print("="*60)

## 10. Results Visualization

### 10.1 Summary Table

In [None]:
# Create summary table of all experiments.
# The Dice / IoU columns are the validation metrics of each experiment's
# best checkpoint ('final_metrics'), not test-set scores.
print("\n" + "="*80)
print("EXPERIMENT RESULTS SUMMARY")
print("="*80)
print(f"{'Experiment':<25} {'Depth':<8} {'Skip':<8} {'Loss':<12} {'Dice':<10} {'IoU':<10}")
print("-"*80)

for name, results in all_results.items():
    metrics = results['final_metrics']
    # use_skip goes through str() so the boolean accepts the '<8' alignment spec.
    print(f"{name:<25} {results['depth']:<8} {str(results['use_skip']):<8} "
          f"{results['loss_type']:<12} {metrics['dice']:.4f}     {metrics['iou']:.4f}")

print("="*80)

### 10.2 Training Curves

In [None]:
def plot_training_curves(results_dict, experiments_to_plot):
    """Plot loss and validation-Dice curves for the selected experiments.

    Left panel: training (dashed) and validation (solid) loss per epoch.
    Right panel: validation Dice per epoch. Names in ``experiments_to_plot``
    that are missing from ``results_dict`` are skipped silently.
    """
    fig, (loss_ax, dice_ax) = plt.subplots(1, 2, figsize=(14, 5))

    for exp_name in experiments_to_plot:
        if exp_name not in results_dict:
            continue
        history = results_dict[exp_name]['history']
        epochs = range(1, len(history['train_loss']) + 1)

        # Loss curves
        loss_ax.plot(epochs, history['train_loss'], '--', label=f'{exp_name} (train)')
        loss_ax.plot(epochs, history['val_loss'], '-', label=f'{exp_name} (val)')

        # Dice curves
        dice_ax.plot(epochs, history['val_dice'], '-', label=exp_name)

    # Shared cosmetics: (axis, y-label, title, legend position)
    panels = (
        (loss_ax, 'Loss', 'Training and Validation Loss', 'upper right'),
        (dice_ax, 'Dice Coefficient', 'Validation Dice Coefficient', 'lower right'),
    )
    for ax, y_label, panel_title, legend_loc in panels:
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.set_title(panel_title)
        ax.legend(loc=legend_loc, fontsize=8)
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

In [None]:
# Plot 1: Shallow vs Standard UNet
print("Comparison 1: Shallow vs Standard UNet (Effect of Depth)")
plot_training_curves(all_results, ['Exp1_Shallow_UNet', 'Exp2_Standard_UNet'])

In [None]:
# Plot 2: With vs Without Skip Connections
print("Comparison 2: Effect of Skip Connections")
plot_training_curves(all_results, ['Exp3_NoSkip', 'Exp2_Standard_UNet'])

In [None]:
# Plot 3: Different Loss Functions
print("Comparison 3: Effect of Different Loss Functions")
plot_training_curves(all_results, ['Exp4_BCE_Loss', 'Exp5_Dice_Loss', 'Exp2_Standard_UNet'])

### 10.3 Bar Chart Comparison

In [None]:
def plot_metrics_comparison(results_dict):
    """Draw a grouped bar chart of Dice and IoU for every experiment.

    Scores come from each experiment's 'final_metrics'; every bar is
    annotated with its value to three decimals.
    """
    names = list(results_dict.keys())
    dice_scores = []
    iou_scores = []
    for exp_name in names:
        final_metrics = results_dict[exp_name]['final_metrics']
        dice_scores.append(final_metrics['dice'])
        iou_scores.append(final_metrics['iou'])

    positions = np.arange(len(names))
    width = 0.35

    fig, ax = plt.subplots(figsize=(14, 6))
    # The two bars of one experiment sit side by side around its tick.
    dice_bars = ax.bar(positions - width/2, dice_scores, width, label='Dice', color='steelblue')
    iou_bars = ax.bar(positions + width/2, iou_scores, width, label='IoU', color='coral')

    ax.set_ylabel('Score')
    ax.set_title('Comparison of All Experiments')
    ax.set_xticks(positions)
    # Break names at underscores so the tick labels do not overlap.
    ax.set_xticklabels([n.replace('_', '\n') for n in names], fontsize=8)
    ax.legend()
    ax.set_ylim(0, 1)
    ax.grid(True, alpha=0.3, axis='y')

    # Add value labels on bars (Dice bars first, then IoU bars)
    for bar in (*dice_bars, *iou_bars):
        height = bar.get_height()
        ax.annotate(f'{height:.3f}', xy=(bar.get_x() + bar.get_width()/2, height),
                    xytext=(0, 3), textcoords="offset points", ha='center', va='bottom', fontsize=7)

    plt.tight_layout()
    plt.show()

plot_metrics_comparison(all_results)

### 10.4 Prediction Visualizations

In [None]:
@torch.no_grad()
def visualize_predictions(model, dataset, device, num_samples=4, title="Predictions"):
    """Visualize model predictions on randomly chosen samples.

    Each row shows the input image, the ground-truth mask, the raw predicted
    probabilities, and the thresholded prediction overlaid on the image.

    Args:
        model: Segmentation model returning per-pixel probabilities for a
            [1, 3, H, W] input.
        dataset: Indexable dataset returning ``(image, mask)`` tensors.
        device: Device the model lives on.
        num_samples: Number of distinct samples (rows) to draw; must not
            exceed ``len(dataset)``.
        title: Figure title.
    """
    model.eval()

    # squeeze=False keeps `axes` two-dimensional even for num_samples == 1;
    # without it `axes[i, 0]` raises IndexError on the squeezed 1-D array.
    fig, axes = plt.subplots(num_samples, 4, figsize=(16, 4*num_samples),
                             squeeze=False)

    # Sampling without replacement: every row is a different image.
    indices = random.sample(range(len(dataset)), num_samples)

    for i, idx in enumerate(indices):
        image, mask = dataset[idx]

        # Get prediction: add a batch dimension, then drop the singleton
        # batch / channel dimensions of the output again.
        image_batch = image.unsqueeze(0).to(device)
        pred = model(image_batch).squeeze().cpu()
        pred_binary = (pred > 0.5).float()

        # Convert to numpy
        img_np = image.permute(1, 2, 0).numpy()
        mask_np = mask.squeeze().numpy()
        pred_np = pred.numpy()
        pred_bin_np = pred_binary.numpy()

        # Plot
        axes[i, 0].imshow(img_np)
        axes[i, 0].set_title('Input Image')
        axes[i, 0].axis('off')

        axes[i, 1].imshow(mask_np, cmap='gray')
        axes[i, 1].set_title('Ground Truth')
        axes[i, 1].axis('off')

        axes[i, 2].imshow(pred_np, cmap='gray')
        axes[i, 2].set_title('Prediction (Raw)')
        axes[i, 2].axis('off')

        # Overlay
        axes[i, 3].imshow(img_np)
        axes[i, 3].imshow(pred_bin_np, alpha=0.5, cmap='Reds')
        axes[i, 3].set_title('Overlay')
        axes[i, 3].axis('off')

    plt.suptitle(title, fontsize=14)
    plt.tight_layout()
    plt.show()

In [None]:
# Visualize predictions from best model (Standard UNet with skip connections)
best_exp = 'Exp2_Standard_UNet'
best_result = all_results[best_exp]

# Rebuild the architecture from the configuration stored with the experiment
# rather than hard-coding it, so pointing `best_exp` at another experiment
# cannot produce a model / state-dict mismatch.
best_model = UNet(depth=best_result['depth'], use_skip=best_result['use_skip']).to(device)
best_model.load_state_dict(best_result['best_model_state'])

print(f"Predictions from: {best_exp}")
visualize_predictions(best_model, test_dataset, device, num_samples=4, 
                      title=f"Predictions - {best_exp}")

### 10.5 Compare Skip vs No-Skip Predictions

In [None]:
@torch.no_grad()
def compare_skip_vs_noskip(results_dict, dataset, device, num_samples=3):
    """Compare thresholded predictions with and without skip connections.

    Loads the best checkpoints of 'Exp2_Standard_UNet' (with skips) and
    'Exp3_NoSkip' (without) from ``results_dict`` and shows, for randomly
    chosen samples: input image, ground truth, no-skip prediction and
    with-skip prediction.

    Args:
        results_dict: Mapping experiment name -> result dict holding at
            least 'best_model_state' (and normally 'depth').
        dataset: Indexable dataset returning ``(image, mask)`` tensors.
        device: Device to run inference on.
        num_samples: Number of distinct samples (rows) to draw; must not
            exceed ``len(dataset)``.
    """
    def _restore(exp_name, use_skip):
        """Rebuild one experiment's UNet and load its best weights."""
        result = results_dict[exp_name]
        # Use the depth recorded with the experiment; fall back to the
        # standard depth of 4 for result dicts that do not store it.
        net = UNet(depth=result.get('depth', 4), use_skip=use_skip).to(device)
        net.load_state_dict(result['best_model_state'])
        net.eval()
        return net

    # Load models
    model_skip = _restore('Exp2_Standard_UNet', use_skip=True)
    model_noskip = _restore('Exp3_NoSkip', use_skip=False)

    # squeeze=False keeps `axes` two-dimensional even for num_samples == 1;
    # without it `axes[i, 0]` raises IndexError on the squeezed 1-D array.
    fig, axes = plt.subplots(num_samples, 4, figsize=(16, 4*num_samples),
                             squeeze=False)

    indices = random.sample(range(len(dataset)), num_samples)

    for i, idx in enumerate(indices):
        image, mask = dataset[idx]
        image_batch = image.unsqueeze(0).to(device)

        # Get predictions
        pred_skip = model_skip(image_batch).squeeze().cpu()
        pred_noskip = model_noskip(image_batch).squeeze().cpu()

        # Convert to numpy
        img_np = image.permute(1, 2, 0).numpy()
        mask_np = mask.squeeze().numpy()

        # Plot
        axes[i, 0].imshow(img_np)
        axes[i, 0].set_title('Input Image')
        axes[i, 0].axis('off')

        axes[i, 1].imshow(mask_np, cmap='gray')
        axes[i, 1].set_title('Ground Truth')
        axes[i, 1].axis('off')

        # Both predictions are thresholded at 0.5 before display.
        axes[i, 2].imshow((pred_noskip > 0.5).numpy(), cmap='gray')
        axes[i, 2].set_title('WITHOUT Skip Connections')
        axes[i, 2].axis('off')

        axes[i, 3].imshow((pred_skip > 0.5).numpy(), cmap='gray')
        axes[i, 3].set_title('WITH Skip Connections')
        axes[i, 3].axis('off')

    plt.suptitle('Effect of Skip Connections on Segmentation Quality', fontsize=14)
    plt.tight_layout()
    plt.show()

# Draw the skip vs. no-skip comparison on `test_dataset` using the function's
# default number of randomly chosen samples.
compare_skip_vs_noskip(all_results, test_dataset, device)

## 11. Test Set Evaluation

Final evaluation on the held-out test set.

In [None]:
# Report loss and segmentation metrics on the test loader
banner = "=" * 60
print(banner)
print("TEST SET EVALUATION")
print(banner)

# The evaluated checkpoint is hard-coded: Exp2 (Standard UNet, depth 4,
# skip connections enabled), restored from its recorded best state.
best_model = UNet(depth=4, use_skip=True).to(device)
exp2_state = all_results['Exp2_Standard_UNet']['best_model_state']
best_model.load_state_dict(exp2_state)

criterion = BCEDiceLoss()
test_loss, test_metrics = validate(best_model, test_loader, criterion, device)

# Emit the whole report with a single print; the text is unchanged.
report_lines = [
    f"\nTest Set Results (Standard UNet with Skip Connections):",
    f"  Loss:      {test_loss:.4f}",
    f"  Dice:      {test_metrics['dice']:.4f}",
    f"  IoU:       {test_metrics['iou']:.4f}",
    f"  Precision: {test_metrics['precision']:.4f}",
    f"  Recall:    {test_metrics['recall']:.4f}",
]
print("\n".join(report_lines))

## 12. Analysis and Conclusions

### Key Findings:

1. **Effect of Depth**: Standard UNet (4 blocks) outperforms Shallow UNet (2 blocks), demonstrating that deeper architectures can capture more complex features.

2. **Importance of Skip Connections**: UNet WITH skip connections significantly outperforms UNet WITHOUT skip connections. This confirms that skip connections are crucial for preserving spatial information during upsampling.

3. **Loss Function Comparison**: Combined loss (BCE + Dice) generally performs best, as it combines pixel-wise accuracy with region-based overlap optimization.

### Why Skip Connections Matter:
- During encoding, pooling reduces the resolution of the feature maps, so fine spatial detail is lost
- Skip connections directly transfer high-resolution features from encoder to decoder
- This helps the decoder produce sharper, more accurate boundaries
- Without skip connections, the model struggles to localize polyp boundaries precisely

### Clinical Implications:
- Accurate polyp segmentation can assist gastroenterologists during colonoscopy
- Better boundary detection helps in assessing polyp size and morphology
- Automated systems can serve as a "second pair of eyes" to reduce miss rates

## 13. Future Work

Potential extensions of this work:

1. **Attention Mechanisms**: Add attention gates to skip connections (Attention U-Net)
2. **Different Backbones**: Use pre-trained encoders (ResNet, EfficientNet)
3. **Multi-scale Features**: Implement Feature Pyramid Networks
4. **Data Augmentation**: More aggressive augmentation (elastic deformation, color jittering)
5. **Post-processing**: Conditional Random Fields (CRF) for boundary refinement
6. **Real-time Inference**: Optimize for deployment in clinical settings

## 14. Save Models (Optional)

In [None]:
# Write the best checkpoint of every experiment to saved_models/<name>.pth
os.makedirs('saved_models', exist_ok=True)

for exp_name in all_results:
    save_path = f'saved_models/{exp_name}.pth'
    best_state = all_results[exp_name]['best_model_state']
    torch.save(best_state, save_path)
    print(f"Saved: {save_path}")

print("\nAll models saved!")

## References

1. Ronneberger, O., Fischer, P., & Brox, T. (2015). U-Net: Convolutional Networks for Biomedical Image Segmentation. MICCAI.

2. Jha, D., et al. (2020). Kvasir-SEG: A Segmented Polyp Dataset. MMM 2020.

3. Alam, S., Tomar, N.K., Thakur, A., Jha, D., & Rauniyar, A. (2020). Automatic Polyp Segmentation using U-Net-ResNet50. arXiv:2012.15247.

4. Huang, C.H., et al. (2021). HarDNet-MSEG: A Simple Encoder-Decoder Polyp Segmentation Neural Network. arXiv:2101.07172.

5. Dataset: https://datasets.simula.no/kvasir-seg/