# ASPP-Enhanced Residual SE U-Net for Fetal Head Segmentation
## Training Notebook (Google Colab Compatible)

**Platform:** Google Colab with GPU acceleration

### Architecture Overview

This notebook implements an **ASPP-Enhanced Residual SE U-Net** for medical image segmentation with:

**Core Architecture:**
- U-Net encoder-decoder with residual connections and squeeze-and-excitation (SE) blocks
- ResidualBlockSE in encoder/decoder (two 3×3 convs + BatchNorm + ReLU + SE + skip connection)
- MaxPool2d downsampling (encoder) and ConvTranspose2d upsampling (decoder)

**Key Innovations:**

1. **ASPP Module at Bottleneck**: Multi-scale contextual feature extraction
   - 1×1 convolution (point-wise features)
   - 3×3 atrous convolutions with dilation rates [6, 12, 18]
   - Global average pooling branch (image-level features)
   - Captures features at different scales simultaneously

2. **Squeeze-and-Excitation (SE) Blocks**: Channel-wise attention mechanism
   - Built into every ResidualBlockSE
   - Applied to skip connections before concatenation with decoder
   - Learns to emphasize informative channels and suppress less useful ones
   - Reduction ratio: 16

3. **Residual Connections**: Skip connections within blocks for better gradient flow

**Output:**
- Sigmoid activation for binary segmentation probabilities [0, 1]
- Compatible with DiceBCELoss (combined Dice + BCE loss)

## 1. Environment Setup

### Google Colab Configuration

**Repository:** `https://github.com/TrinhThaiSonDHQT/Fetal-Head-Segmentation`

**Directory Structure:**
- **Project Root:** `/content/Fetal-Head-Segmentation/`
- **Outputs:** `/content/outputs/results/`
  - Checkpoints, logs, predictions, and visualizations
  - Download results after training completes

**Steps:**
1. Clone repository from GitHub
2. Install dependencies (Albumentations 1.3.1)
3. Import required modules and verify CUDA availability

In [None]:
# Clone the GitHub repository
import os

# Check if already cloned
if not os.path.exists('/content/Fetal-Head-Segmentation'):
    print("Cloning repository from GitHub...")
    !git clone https://github.com/TrinhThaiSonDHQT/Fetal-Head-Segmentation.git
    print("✓ Repository cloned successfully")
else:
    print("✓ Repository already exists")

In [None]:
import os
import sys
from pathlib import Path

print("[Google Colab Setup]")

# Setup paths for Google Colab
project_root = Path('/content/Fetal-Head-Segmentation')
output_root = Path('/content/outputs')
cache_root = output_root / 'cache'

# Verify project exists
if not project_root.exists():
    raise RuntimeError(
        f"Project not found at {project_root}\n"
        f"Please run the previous cell to clone the repository from GitHub."
    )

if not (project_root / 'accuracy_focus').exists():
    raise RuntimeError(
        f"'accuracy_focus' folder not found in {project_root}\n"
        f"Please ensure the repository was cloned correctly."
    )

print(f"Project root: {project_root}")
print(f"Output root: {output_root}")

# Add project to path
sys.path.insert(0, str(project_root))

print(f"\n✓ Environment setup complete")

In [None]:
# Install required packages for Google Colab
print("Installing required packages...")

# Colab has most packages pre-installed (PyTorch, NumPy, Matplotlib, OpenCV)
# Pin Albumentations to 1.3.1 for compatibility with both Colab and Kaggle
!pip install -q albumentations==1.3.1

print("\n✓ Packages installed successfully")

In [None]:
import yaml
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

from accuracy_focus.standard_unet.src.losses import DiceBCELoss
from accuracy_focus.improved_unet.src.models.aspp_residual_se_unet.aspp_residual_se_unet_model import ASPPResidualSEUNet, count_parameters

from shared.src.data import HC18Dataset
from shared.src.metrics.segmentation_metrics import dice_coefficient, iou_score, pixel_accuracy
from shared.src.utils.visualization import save_prediction_grid, visualize_sample
from shared.src.utils.transforms import get_transforms

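# Set random seeds for reproducibility (Python, NumPy, PyTorch CPU and GPU)
import random
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)  # silently ignored when CUDA is unavailable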

print(f"\nPyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

## 2. Load Configuration

In [None]:
# Load configuration (use improved_unet config as template)
config_path = project_root / 'accuracy_focus' / 'improved_unet' / 'configs' / 'improved_unet_config.yaml'

with open(config_path, 'r') as f:
    config = yaml.safe_load(f)

# Update model name and parameters for ASPP-Enhanced Residual SE U-Net
config['model']['name'] = 'ASPPResidualSEUNet'
config['model']['reduction_ratio'] = 16  # SE block reduction ratio
config['model']['atrous_rates'] = [6, 12, 18]  # ASPP dilation rates
config['model']['aspp_dropout'] = 0.5  # ASPP dropout rate

# Adjust output paths for Google Colab environment
print(f"Adjusting paths for Google Colab environment...")
config['logging']['checkpoint_dir'] = str(output_root / 'results' / 'checkpoints')
config['logging']['log_dir'] = str(output_root / 'results' / 'logs')
config['logging']['prediction_dir'] = str(output_root / 'results' / 'predictions')
config['logging']['visualization_dir'] = str(output_root / 'results' / 'visualizations')

print(f"  Outputs will be saved to: {output_root / 'results'}")

print("\nConfiguration loaded:")
print(f"  Model: {config['model']['name']}")
print(f"  Base Filters: {config['model']['base_filters']}")
print(f"  SE Reduction Ratio: {config['model']['reduction_ratio']}")
print(f"  ASPP Atrous Rates: {config['model']['atrous_rates']}")
print(f"  ASPP Dropout: {config['model']['aspp_dropout']}")
print(f"  Learning Rate: {config['training']['optimizer']['lr']}")
print(f"  Loss Function: {config['loss']['name']}")
print(f"  Batch Size: {config['training']['batch_size']}")
print(f"  Epochs: {config['training']['num_epochs']}")

## 3. Model Initialization

### ASPP-Enhanced Residual SE U-Net Architecture Details

**Encoder Path:**
- 4 downsampling blocks: 64 → 128 → 256 → 512 filters
- Each block: ResidualBlockSE (Conv2d + BatchNorm + ReLU) × 2 + SE block + residual skip
- Downsampling: MaxPool2d (2×2, stride=2)
- SE blocks provide channel-wise attention
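
The encoder block described above can be sketched roughly as follows. This is a minimal illustration, not the repository's `ResidualBlockSE`: the class name `ResidualBlockSESketch`, the exact layer order, and the 1×1 projection on the skip path are assumptions made for the example.

```python
import torch
import torch.nn as nn


class ResidualBlockSESketch(nn.Module):
    """Hypothetical stand-in for ResidualBlockSE; the real layer order may differ."""

    def __init__(self, in_ch, out_ch, reduction=16):
        super().__init__()
        # Two 3x3 convolutions, each followed by BatchNorm
        self.body = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch))
        hidden = max(out_ch // reduction, 1)
        # SE gate: global average pool -> bottleneck (as 1x1 convs) -> sigmoid
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(out_ch, hidden, 1), nn.ReLU(inplace=True),
            nn.Conv2d(hidden, out_ch, 1), nn.Sigmoid())
        # Assumed 1x1 projection so the skip path matches out_ch when widths differ
        self.skip = nn.Identity() if in_ch == out_ch else nn.Conv2d(in_ch, out_ch, 1, bias=False)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        y = self.body(x)
        y = y * self.se(y)                  # channel-wise re-weighting
        return self.act(y + self.skip(x))   # residual addition


# One encoder stage: block followed by 2x2 max pooling
stage = nn.Sequential(ResidualBlockSESketch(1, 64), nn.MaxPool2d(2, 2))
features = stage(torch.randn(2, 1, 64, 64))
print(features.shape)  # torch.Size([2, 64, 32, 32])
```

The pooling halves the spatial size while the block sets the channel width, which is how the 64 → 128 → 256 → 512 progression is built stage by stage.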

**Bottleneck (ASPP Module):**
- Multi-scale feature extraction with 1024 total filters
- **1×1 convolution**: Point-wise features
- **3×3 atrous convolutions**: Dilation rates [6, 12, 18] for multi-scale context
- **Global Average Pooling**: Image-level features
- **Dropout (0.5)**: Regularization to prevent overfitting
- **Fusion**: Concatenate all branches and project to 1024 channels
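
To make the branch layout concrete, here is a minimal ASPP sketch under stated assumptions. `ASPPSketch`, the per-branch width (`out_ch // 4`), and the placement of dropout after the projection are illustrative choices, not details confirmed by the repository's model file.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ASPPSketch(nn.Module):
    """Illustrative ASPP: 1x1 conv, dilated 3x3 convs, and an image-pooling branch."""

    def __init__(self, in_ch=512, out_ch=1024, rates=(6, 12, 18), dropout=0.5):
        super().__init__()
        branch_ch = out_ch // 4  # assumed per-branch width
        self.point = nn.Sequential(
            nn.Conv2d(in_ch, branch_ch, 1, bias=False),
            nn.BatchNorm2d(branch_ch), nn.ReLU(inplace=True))
        # padding == dilation keeps the spatial size for a 3x3 kernel
        self.atrous = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(in_ch, branch_ch, 3, padding=r, dilation=r, bias=False),
                nn.BatchNorm2d(branch_ch), nn.ReLU(inplace=True))
            for r in rates])
        # Image-level branch: pool to 1x1, then a 1x1 conv (no BatchNorm in this sketch)
        self.pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_ch, branch_ch, 1), nn.ReLU(inplace=True))
        n_branches = 2 + len(rates)
        self.project = nn.Sequential(
            nn.Conv2d(n_branches * branch_ch, out_ch, 1, bias=False),
            nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
            nn.Dropout(dropout))

    def forward(self, x):
        h, w = x.shape[-2:]
        # Broadcast the pooled 1x1 map back to the input resolution
        pooled = F.interpolate(self.pool(x), size=(h, w), mode="bilinear", align_corners=False)
        feats = [self.point(x)] + [branch(x) for branch in self.atrous] + [pooled]
        return self.project(torch.cat(feats, dim=1))


# Small widths keep the demo light; the notebook's bottleneck is 512 -> 1024
aspp = ASPPSketch(in_ch=32, out_ch=64).eval()
with torch.no_grad():
    out = aspp(torch.randn(2, 32, 16, 16))
print(out.shape)  # torch.Size([2, 64, 16, 16])
```

Every branch preserves the spatial size, so the five outputs can be concatenated along the channel axis before the 1×1 projection.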

**Decoder Path:**
- 4 upsampling blocks: 512 → 256 → 128 → 64 filters
- Upsampling: ConvTranspose2d (2×2, stride=2)
- Skip connections: SE-enhanced encoder features concatenated with decoder
- Each block: ResidualBlockSE × 2 after concatenation
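
A quick shape check illustrates one decoder step (upsample, then concatenate the skip feature). The channel counts below are deliberately small and hypothetical; only the `ConvTranspose2d` settings (2×2, stride 2) come from the description above.

```python
import torch
import torch.nn as nn

# Transposed convolution with kernel 2 and stride 2 doubles height and width
up = nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=2, stride=2)

bottleneck = torch.randn(1, 64, 16, 16)  # deepest feature map (demo size)
skip = torch.randn(1, 32, 32, 32)        # encoder feature at the matching resolution

upsampled = up(bottleneck)                     # 16x16 -> 32x32, 64 -> 32 channels
merged = torch.cat([skip, upsampled], dim=1)   # channels add up: 32 + 32 = 64
print(upsampled.shape, merged.shape)
```

The concatenated tensor has twice the channels of the upsampled one, which is why the blocks after concatenation must reduce the width again.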

**Channel Attention (SE Blocks):**
- Built into every ResidualBlockSE in encoder/decoder
- Applied to skip connections before concatenation
- Squeeze: Global average pooling
- Excitation: FC → ReLU → FC → Sigmoid
- Reduction ratio: 16 (balances performance vs. parameters)
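
The squeeze and excitation steps, and the parameter cost of the reduction ratio, can be sketched as below. `SEBlockSketch` is a hypothetical name and the use of `nn.Linear` layers is an assumption; the repository's block may be implemented differently.

```python
import torch
import torch.nn as nn


class SEBlockSketch(nn.Module):
    """Squeeze (global average pool) then excitation (FC -> ReLU -> FC -> Sigmoid)."""

    def __init__(self, channels, reduction=16):
        super().__init__()
        hidden = max(channels // reduction, 1)
        self.squeeze = nn.AdaptiveAvgPool2d(1)
        self.excite = nn.Sequential(
            nn.Linear(channels, hidden), nn.ReLU(inplace=True),
            nn.Linear(hidden, channels), nn.Sigmoid())

    def forward(self, x):
        b, c = x.shape[:2]
        weights = self.excite(self.squeeze(x).view(b, c)).view(b, c, 1, 1)
        return x * weights  # each channel is scaled by a value in (0, 1)


se = SEBlockSketch(512, reduction=16)
n_params = sum(p.numel() for p in se.parameters())
print(n_params)  # 512*32 + 32 + 32*512 + 512 = 33312

x = torch.randn(2, 512, 8, 8)
y = se(x)
print(y.shape)  # torch.Size([2, 512, 8, 8])
```

With 512 channels, a ratio of 16 gives a 32-unit hidden layer, so the gate adds about 33k parameters; halving the ratio would roughly double that cost.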

**Output Layer:**
- Conv2d (1×1) to single channel
- **Sigmoid activation** → outputs probabilities [0, 1]
- Compatible with DiceBCELoss

In [None]:
# Set device
device = torch.device(config['device'] if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Initialize ASPP-Enhanced Residual SE U-Net
model = ASPPResidualSEUNet(
    in_channels=config['model']['in_channels'],
    out_channels=config['model']['out_channels'],
    base_channels=config['model']['base_filters'],
    reduction_ratio=config['model']['reduction_ratio'],
    atrous_rates=config['model']['atrous_rates'],
    aspp_dropout=config['model']['aspp_dropout']
)

model = model.to(device)

# Count parameters
total_params, trainable_params = count_parameters(model)

# Model summary
print(f"\nASPP-Enhanced Residual SE U-Net Architecture:")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
print(f"  Model size: ~{total_params * 4 / (1024**2):.2f} MB (float32)")
print(f"  Input channels: {config['model']['in_channels']}")
print(f"  Output channels: {config['model']['out_channels']}")
print(f"  Base filters: {config['model']['base_filters']}")
print(f"  SE reduction ratio: {config['model']['reduction_ratio']}")
print(f"  ASPP atrous rates: {config['model']['atrous_rates']}")
print(f"  ASPP dropout: {config['model']['aspp_dropout']}")
print(f"  Output activation: Sigmoid")

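# Test forward pass (eval mode avoids BatchNorm batch-statistics issues with a
# single sample; no_grad skips building the autograd graph)
model.eval()
with torch.no_grad():
    test_input = torch.randn(1, config['model']['in_channels'], 256, 256, device=device)
    test_output = model(test_input)
print(f"\nTest forward pass:")
print(f"  Input shape: {test_input.shape}")
print(f"  Output shape: {test_output.shape}")
print(f"  Output range: [{test_output.min().item():.4f}, {test_output.max().item():.4f}]")
print(f"  Output type: Probabilities (sigmoid activated)")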

## 4. Loss Function and Optimization

### Loss Function: DiceBCELoss

**Hybrid Loss Design:**
- **Dice Loss (80%):** Optimizes region overlap (DSC metric)
  - Handles class imbalance naturally
  - Optimizes a differentiable (soft) version of the evaluation metric
- **BCE Loss (20%):** Optimizes pixel-wise classification
  - Provides stable gradients
  - Handles boundary refinement

**Key Features:**
- Expects **probabilities** [0, 1] as input (model outputs sigmoid)
- Smooth parameter (1.0) for numerical stability in Dice calculation
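
As a rough illustration of the weighting described above, the combined loss can be written in a few lines. This is a sketch under assumptions, not the repository's `DiceBCELoss`: the function name `dice_bce_sketch` and the exact placement of the smooth term are hypothetical.

```python
import torch
import torch.nn.functional as F


def dice_bce_sketch(probs, targets, dice_weight=0.8, bce_weight=0.2, smooth=1.0):
    """Weighted soft-Dice + BCE computed on probabilities in [0, 1]."""
    probs = probs.reshape(-1)
    targets = targets.reshape(-1)
    intersection = (probs * targets).sum()
    # smooth keeps the ratio defined when both prediction and target are empty
    dice = (2.0 * intersection + smooth) / (probs.sum() + targets.sum() + smooth)
    bce = F.binary_cross_entropy(probs, targets)
    return dice_weight * (1.0 - dice) + bce_weight * bce


target = torch.zeros(1, 1, 8, 8)
target[..., 2:6, 2:6] = 1.0

perfect = dice_bce_sketch(target.clone(), target)   # prediction equals the mask
inverted = dice_bce_sketch(1.0 - target, target)    # every pixel is wrong
print(f"perfect: {perfect.item():.4f}, inverted: {inverted.item():.4f}")
```

Because the loss takes probabilities, passing raw logits would be a mistake; a model without a final sigmoid would need `binary_cross_entropy_with_logits` instead.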

### Optimizer: Adam
- Adaptive learning rate per parameter
- Learning rate: 1e-3 (configurable)
- Weight decay: 1e-4 (L2 regularization)

### Learning Rate Scheduler: ReduceLROnPlateau
- Monitors validation Dice coefficient (higher is better, so the scheduler `mode` must be `max`)
- Reduces LR by factor of 0.5 when validation plateaus
- Patience: 5 epochs
- Minimum LR: 1e-6
- Smaller steps let the optimizer refine the solution once validation Dice stops improving
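
The scheduler behaviour can be checked in isolation with a dummy parameter and a flat validation score. The values mirror the settings listed above; the dummy parameter and the constant Dice of 0.90 are purely illustrative.

```python
import torch

# A single dummy parameter is enough to construct an optimizer
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="max", factor=0.5, patience=5, min_lr=1e-6)

lrs = []
for epoch in range(8):
    val_dice = 0.90             # flat score: no improvement after the first epoch
    scheduler.step(val_dice)    # mode="max" because a higher Dice is better
    lrs.append(optimizer.param_groups[0]["lr"])

print(lrs)  # stays at 1e-3 until the plateau outlasts patience, then halves to 5e-4
```

With `mode="min"` the same call would treat a rising Dice as a regression, so the mode and the monitored metric have to agree.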

In [None]:
loss_config = config['loss']
dice_weight_config = loss_config.get('dice_weight', 0.8)
bce_weight_config = loss_config.get('bce_weight', 0.2)
smooth_config = loss_config.get('smooth', 1.0)

# Use DiceBCELoss
criterion = DiceBCELoss(
    dice_weight=dice_weight_config,
    bce_weight=bce_weight_config,
    smooth=smooth_config
)

print(f"Loss Function: DiceBCELoss")
print(f"  Dice weight: {dice_weight_config}")
print(f"  BCE weight: {bce_weight_config}")
print(f"  Smooth parameter: {smooth_config}")

# Optimizer (Adam)
optimizer_config = config['training']['optimizer']
optimizer = optim.Adam(
    model.parameters(),
    lr=optimizer_config['lr'],
    betas=tuple(optimizer_config['betas']),
    eps=optimizer_config['eps'],
    weight_decay=optimizer_config['weight_decay']
)
print(f"\nOptimizer: Adam")
print(f"  Learning rate: {optimizer_config['lr']}")
print(f"  Weight decay: {optimizer_config['weight_decay']}")

# Learning rate scheduler
scheduler_config = config['training']['scheduler']
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode=scheduler_config['mode'],
    factor=scheduler_config['factor'],
    patience=scheduler_config['patience'],
    min_lr=scheduler_config['min_lr']
)
print(f"\nScheduler: ReduceLROnPlateau")
print(f"  Mode: {scheduler_config['mode']}")
print(f"  Factor: {scheduler_config['factor']}")
print(f"  Patience: {scheduler_config['patience']}")
print(f"  Min LR: {scheduler_config['min_lr']}")

## 5. Data Loading and Augmentation

### Preprocessing Pipeline (All Images)

Applied consistently to training, validation, and test sets:
1. **Normalization:** Divide pixel values by 255.0 → [0, 1] range
2. **Resizing:** 256×256 pixels (gives every image the same input size)
3. **Tensor Conversion:** NumPy array → PyTorch tensor (C×H×W format)
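
The three preprocessing steps can be sketched with NumPy alone. This is not the repository's `get_transforms`: the function `preprocess_sketch` is hypothetical, the 540×800 input size is an assumed example frame, and nearest-neighbour index sampling stands in for the library's resize.

```python
import numpy as np


def preprocess_sketch(image_u8, size=256):
    """Normalize to [0, 1], resize to size x size, and return a C x H x W array."""
    img = image_u8.astype(np.float32) / 255.0
    h, w = img.shape
    # Nearest-neighbour resize by index sampling (stand-in for a real resize op)
    rows = np.arange(size) * h // size
    cols = np.arange(size) * w // size
    resized = img[rows[:, None], cols[None, :]]
    return resized[None, ...]  # add the channel axis: (1, size, size)


raw = np.random.randint(0, 256, size=(540, 800), dtype=np.uint8)  # assumed frame size
x = preprocess_sketch(raw)
print(x.shape, x.dtype, float(x.min()) >= 0.0, float(x.max()) <= 1.0)
```

For masks, nearest-neighbour sampling matters more than for images, since interpolating a binary mask would introduce values between 0 and 1.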

### Data Augmentation (Training Only)

**On-the-fly augmentation** using Albumentations library:
- **HorizontalFlip:** p=0.5 (mirrors left-right)
- **VerticalFlip:** p=0.5 (mirrors top-bottom)
- **Rotation:** ±20° with p=0.5 (handles probe orientation variations)
- **ShiftScaleRotate:** p=0.5
  - Translation: ±10% (handles positioning variations)
  - Scaling: ±10% (handles zoom variations)

**Benefits:**
- Random transforms are re-sampled every time an image is loaded, so each epoch sees different variants
- Improves model generalization and robustness
- Prevents overfitting on small datasets
- Image-mask transforms synchronized automatically

**Validation/Test:**
- **No augmentation** applied
- Only preprocessing (normalize, resize, tensorize)
- Ensures consistent evaluation metrics
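
The key property of paired augmentation is that the image and its mask receive the same random transform. The NumPy sketch below shows this for flips only; `random_flip_pair` is a hypothetical helper, and the notebook itself relies on Albumentations through `get_transforms`.

```python
import numpy as np


def random_flip_pair(image, mask, rng, p=0.5):
    """Apply the same random horizontal/vertical flips to an image and its mask."""
    if rng.random() < p:   # horizontal flip: reverse the column order
        image, mask = image[:, ::-1], mask[:, ::-1]
    if rng.random() < p:   # vertical flip: reverse the row order
        image, mask = image[::-1, :], mask[::-1, :]
    return image.copy(), mask.copy()


rng = np.random.default_rng(0)
image = rng.random((8, 8))
mask = image > 0.5          # toy mask derived from the image itself

aug_image, aug_mask = random_flip_pair(image, mask, rng)
# The mask still lines up with the image after augmentation
print(np.array_equal(aug_mask, aug_image > 0.5))
```

Drawing the random decision once and applying it to both arrays is what keeps the pair aligned; sampling separately for image and mask would corrupt the labels.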

In [None]:
data_config = config['data']
training_config = config['training']

# Helper to build paths
def get_path(config_path):
    """Helper to handle both absolute and relative paths"""
    p = Path(config_path)
    if p.is_absolute():
        return str(p)
    else:
        return str(project_root / config_path)

# Create augmentation transforms
print("Creating augmentation transforms...")
train_transform = get_transforms(height=256, width=256, is_train=True)
val_transform = get_transforms(height=256, width=256, is_train=False)
print("  Train transform: WITH augmentation (HorizontalFlip, VerticalFlip, Rotation, ShiftScaleRotate)")
print("  Val transform: WITHOUT augmentation (resize + normalize only)")

# Create datasets - using HC18Dataset for on-the-fly augmentation
print("\nCreating training dataset...")
train_dataset = HC18Dataset(
    image_dir=get_path(data_config['train_images']),
    mask_dir=get_path(data_config['train_masks']),
    transform=train_transform
)

print("Creating validation dataset...")
val_dataset = HC18Dataset(
    image_dir=get_path(data_config['val_images']),
    mask_dir=get_path(data_config['val_masks']),
    transform=val_transform
)

print("Creating test dataset...")
test_dataset = HC18Dataset(
    image_dir=get_path(data_config['test_images']),
    mask_dir=get_path(data_config['test_masks']),
    transform=val_transform
)

# Use num_workers=0 for Google Colab to avoid multiprocessing issues
num_workers = 0
print(f"\nDataLoader settings:")
print(f"  num_workers: {num_workers} (disabled for Google Colab)")
print(f"  Batch size: {training_config['batch_size']}")

# Data loaders
train_loader = DataLoader(
    train_dataset,
    batch_size=training_config['batch_size'],
    shuffle=True,
    num_workers=num_workers,
    pin_memory=training_config['pin_memory']
)

val_loader = DataLoader(
    val_dataset,
    batch_size=training_config['batch_size'],
    shuffle=False,
    num_workers=num_workers,
    pin_memory=training_config['pin_memory']
)

test_loader = DataLoader(
    test_dataset,
    batch_size=training_config['batch_size'],
    shuffle=False,
    num_workers=num_workers,
    pin_memory=training_config['pin_memory']
)

print(f"\n{'='*60}")
print(f"Datasets Ready:")
print(f"{'='*60}")
print(f"  Train samples: {len(train_dataset)}")
print(f"  Validation samples: {len(val_dataset)}")
print(f"  Test samples: {len(test_dataset)}")
print(f"  Batch size: {training_config['batch_size']}")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches: {len(val_loader)}")
print(f"  Test batches: {len(test_loader)}")
print(f"  Train augmentation: ENABLED (on-the-fly)")
print(f"  Val/Test augmentation: DISABLED")
print(f"{'='*60}\n")

## 6. Data Quality Verification

Verify data integrity before training to catch preprocessing errors early.

In [None]:
# Get a batch of training data
sample_images, sample_masks = next(iter(train_loader))

print(f"Sample batch:")
print(f"  Images shape: {sample_images.shape}")
print(f"  Masks shape: {sample_masks.shape}")
print(f"  Image range: [{sample_images.min():.4f}, {sample_images.max():.4f}]")
print(f"  Mask range: [{sample_masks.min():.4f}, {sample_masks.max():.4f}]")
print(f"  Mask unique values: {torch.unique(sample_masks)}")
print(f"  Mask mean (% foreground): {sample_masks.mean():.2%}")

# CRITICAL CHECK: Ensure masks are binary {0, 1}
if not torch.all((sample_masks == 0) | (sample_masks == 1)):
    print("\n⚠️  WARNING: Masks are not binary! Check preprocessing.")
else:
    print("\n✓ Masks are properly binary {0, 1}")

# Check if masks have a plausible foreground ratio (2-10% is typical for fetal head;
# anything between 1% and 30% is accepted here)
fg_ratio = sample_masks.mean().item()
if fg_ratio < 0.01 or fg_ratio > 0.3:
    print(f"⚠️  WARNING: Unusual foreground ratio: {fg_ratio:.2%} (outside the accepted 1-30% range)")
else:
    print(f"✓ Foreground ratio looks reasonable: {fg_ratio:.2%}")

# Visualize first sample
visualize_sample(sample_images[0], sample_masks[0])

## 7. Training and Validation Functions

### train_one_epoch()
- Sets model to training mode
- Iterates through training batches
- Forward pass → loss calculation → backward pass → optimizer step
- Returns average epoch loss

### validate()
- Sets model to evaluation mode (disables dropout; BatchNorm uses its running statistics)
- Computes loss and metrics on validation set
- Model outputs probabilities [0, 1] (sigmoid activated)
- Thresholds at 0.5 for binary predictions
- Calculates: Dice coefficient, IoU, Pixel Accuracy
- Returns average metrics across all validation samples
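
A small worked example shows how the three metrics behave on a thresholded prediction. The helper `binary_metrics` is a hypothetical NumPy version; the repository's `dice_coefficient`, `iou_score`, and `pixel_accuracy` may add smoothing terms and therefore return slightly different values.

```python
import numpy as np


def binary_metrics(pred, mask):
    """Dice, IoU, and pixel accuracy for two binary arrays of the same shape."""
    pred = pred.astype(bool)
    mask = mask.astype(bool)
    inter = np.logical_and(pred, mask).sum()
    union = np.logical_or(pred, mask).sum()
    total = pred.sum() + mask.sum()
    dice = 2.0 * inter / total if total > 0 else 1.0  # both empty counts as a match
    iou = inter / union if union > 0 else 1.0
    pa = (pred == mask).mean()
    return float(dice), float(iou), float(pa)


probs = np.zeros((4, 4))
probs[0:2, 1:3] = 0.9        # model is confident about a 2x2 patch
mask = np.zeros((4, 4))
mask[0:2, 0:2] = 1           # ground truth is shifted one column to the left

pred = probs > 0.5           # threshold at 0.5
dice, iou, pa = binary_metrics(pred, mask)
print(f"Dice={dice:.4f} IoU={iou:.4f} PA={pa:.4f}")  # Dice=0.5000 IoU=0.3333 PA=0.7500
```

Pixel accuracy stays high (0.75) even though only half of the predicted region overlaps the target, which is why Dice and IoU are the more informative metrics when the foreground is small.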

In [None]:
def train_one_epoch(model, dataloader, criterion, optimizer, device, epoch):
    """
    Train for one epoch
    """
    model.train()
    running_loss = 0.0
    
    pbar = tqdm(dataloader, desc=f"Epoch {epoch+1} [Train]", leave=False)
    for batch_idx, (images, masks) in enumerate(pbar):
        images = images.to(device)
        masks = masks.to(device)
        
        # Forward pass
        optimizer.zero_grad()
        outputs = model(images)
        
        # Calculate loss
        loss = criterion(outputs, masks)
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        # Update statistics
        running_loss += loss.item()
        pbar.set_postfix({'loss': f"{loss.item():.4f}"})
    
    epoch_loss = running_loss / len(dataloader)
    return epoch_loss


def validate(model, dataloader, criterion, device, epoch):
    """
    Validate the model
    """
    model.eval()
    running_loss = 0.0
    dice_scores = []
    iou_scores = []
    pa_scores = []
    
    with torch.no_grad():
        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1} [Val]", leave=False)
        for batch_idx, (images, masks) in enumerate(pbar):
            images = images.to(device)
            masks = masks.to(device)
            
            # Forward pass
            outputs = model(images)
            
            # Calculate loss
            loss = criterion(outputs, masks)
            running_loss += loss.item()
            
            # Calculate metrics
            # Model outputs probabilities [0, 1] (sigmoid activated)
            preds = (outputs > 0.5).float()
            
            for i in range(images.size(0)):
                dice = dice_coefficient(preds[i], masks[i])
                iou = iou_score(preds[i], masks[i])
                pa = pixel_accuracy(preds[i], masks[i])
                
                dice_scores.append(dice.item())
                iou_scores.append(iou.item())
                pa_scores.append(pa.item())
            
            pbar.set_postfix({
                'loss': f"{loss.item():.4f}",
                'dice': f"{np.mean(dice_scores):.4f}"
            })
    
    # Calculate average metrics
    val_loss = running_loss / len(dataloader)
    val_dice = np.mean(dice_scores)
    val_iou = np.mean(iou_scores)
    val_pa = np.mean(pa_scores)
    
    return val_loss, val_dice, val_iou, val_pa


print("Training functions defined!")

## 8. Training Loop

### Training Configuration
- **Epochs:** Configurable (typically 50-100)
- **Early Stopping:** Monitors validation Dice coefficient
  - Stops if no improvement for N consecutive epochs
  - Prevents overfitting and saves compute time
- **Model Checkpointing:** Saves best model based on validation Dice

### Per-Epoch Workflow
1. Train on full training set
2. Validate on validation set
3. Update learning rate (ReduceLROnPlateau scheduler)
4. Log metrics: loss, Dice, IoU, pixel accuracy
5. Save model if validation Dice improves
6. Check early stopping criterion
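
The checkpoint-and-stop logic in steps 5 and 6 can be isolated from the training code. The sketch below replays a made-up list of validation Dice scores; `run_with_early_stopping` is a hypothetical helper, not a function from the repository.

```python
def run_with_early_stopping(val_dice_per_epoch, patience):
    """Return (best_dice, best_epoch, last_epoch_run) for a sequence of scores."""
    best_dice, best_epoch, waited = 0.0, None, 0
    for epoch, val_dice in enumerate(val_dice_per_epoch):
        if val_dice > best_dice:
            # Improvement: this is where the checkpoint would be saved
            best_dice, best_epoch, waited = val_dice, epoch, 0
        else:
            waited += 1
        if waited >= patience:
            return best_dice, best_epoch, epoch  # stop early
    return best_dice, best_epoch, len(val_dice_per_epoch) - 1


scores = [0.80, 0.85, 0.84, 0.86, 0.86, 0.85, 0.86, 0.90]  # illustrative values
best, best_epoch, last_epoch = run_with_early_stopping(scores, patience=3)
print(best, best_epoch, last_epoch)  # 0.86 3 6
```

Ties do not count as improvement because the comparison is strict, so the run stops after epoch index 6 and never reaches the final 0.90. A patience that is too small can therefore cut training short just before a late gain.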

In [None]:
print(f"{'='*60}")
print(f"Starting Training - ASPP-Enhanced Residual SE U-Net")
print(f"{'='*60}")

# Training configuration
num_epochs = config['training']['num_epochs']
patience = config['training']['early_stopping_patience']

# Training history
history = {
    'train_loss': [],
    'val_loss': [],
    'val_dice': [],
    'val_iou': [],
    'val_pa': [],
    'lr': []
}

best_dice = 0.0
epochs_without_improvement = 0

print(f"Epochs: {num_epochs}")
print(f"Early Stopping Patience: {patience}")
print(f"{'='*60}\n")

# Main training loop
for epoch in range(num_epochs):    
    # Train
    train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device, epoch)
    
    # Validate
    val_loss, val_dice, val_iou, val_pa = validate(model, val_loader, criterion, device, epoch)
    
    # Update learning rate
    old_lr = optimizer.param_groups[0]['lr']
    scheduler.step(val_dice)
    current_lr = optimizer.param_groups[0]['lr']
    
    # Update history
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['val_dice'].append(val_dice)
    history['val_iou'].append(val_iou)
    history['val_pa'].append(val_pa)
    history['lr'].append(current_lr)
    
    # Print epoch summary
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print(f"{'='*100}")
    dice_indicator = ' 🏆' if val_dice > best_dice else ''
    lr_change = f' ⬇️ (reduced from {old_lr:.6f})' if current_lr < old_lr else ''
    
    # Print metrics
    print(f"Train Loss: {train_loss:.4f}")
    print(f"Val Loss: {val_loss:.4f} | Val Dice: {val_dice:.4f}{dice_indicator}")
    print(f"Val mIoU: {val_iou:.4f} | Val mPA: {val_pa:.4f}")
    if lr_change: print(f"LR: {current_lr:.6f}{lr_change}")
    print(f"{'='*100}")
    
    # Check for improvement
    is_best = val_dice > best_dice
    if is_best:
        best_dice = val_dice
        epochs_without_improvement = 0
        
        # Save best model
        checkpoint_dir = Path(get_path(config['logging']['checkpoint_dir']))
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        best_model_path = checkpoint_dir / 'best_model_aspp_residual_se_unet.pth'
        
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_dice': best_dice,
            'history': history,
            'config': config
        }, best_model_path)
        
        print(f"  → Saved best model (Dice: {best_dice:.4f})")
    else:
        epochs_without_improvement += 1
        print(f"  ⚠️  No improvement for {epochs_without_improvement}/{patience} epochs")
    
    # Early stopping
    if epochs_without_improvement >= patience:
        print(f"\n{'='*70}")
        print(f"⛔ EARLY STOPPING TRIGGERED")
        print(f"{'='*70}")
        print(f"  Stopped at epoch: {epoch+1}")
        print(f"  Best Dice Score:  {best_dice:.4f}")
        print(f"  Patience limit:   {patience} epochs without improvement")
        print(f"{'='*70}")
        break

# Final summary
print(f"\n{'='*60}")
print(f"Training Completed!")
print(f"{'='*60}")
print(f"Best Validation Dice: {best_dice:.4f}")
print(f"Best Validation IoU:  {max(history['val_iou']):.4f}")
print(f"Best Validation PA:   {max(history['val_pa']):.4f}")
print(f"{'='*60}\n")

## 9. Training Visualization

Generate plots to analyze training dynamics and convergence:
- **Loss curves:** Training vs validation loss over epochs
- **Dice coefficient:** Validation performance trend
- **IoU score:** Intersection over Union metric progression
- **Learning rate:** ReduceLROnPlateau schedule adjustments

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(15, 12))

# Loss curves
axes[0, 0].plot(history['train_loss'], label='Train Loss', linewidth=2)
axes[0, 0].plot(history['val_loss'], label='Val Loss', linewidth=2)
axes[0, 0].set_xlabel('Epoch', fontsize=12)
axes[0, 0].set_ylabel('Loss (Dice + BCE)', fontsize=12)
axes[0, 0].set_title('Training and Validation Loss', fontsize=14, fontweight='bold')
axes[0, 0].legend(fontsize=11)
axes[0, 0].grid(True, alpha=0.3)

# Dice coefficient
axes[0, 1].plot(history['val_dice'], label='Val Dice', color='green', linewidth=2)
axes[0, 1].axhline(y=best_dice, color='red', linestyle='--', label=f'Best: {best_dice:.4f}')
axes[0, 1].set_xlabel('Epoch', fontsize=12)
axes[0, 1].set_ylabel('Dice Coefficient', fontsize=12)
axes[0, 1].set_title('Validation Dice Coefficient', fontsize=14, fontweight='bold')
axes[0, 1].legend(fontsize=11)
axes[0, 1].grid(True, alpha=0.3)

# IoU
axes[1, 0].plot(history['val_iou'], label='Val IoU', color='orange', linewidth=2)
axes[1, 0].axhline(y=max(history['val_iou']), color='red', linestyle='--', 
                   label=f"Best: {max(history['val_iou']):.4f}")
axes[1, 0].set_xlabel('Epoch', fontsize=12)
axes[1, 0].set_ylabel('IoU Score', fontsize=12)
axes[1, 0].set_title('Validation IoU Score', fontsize=14, fontweight='bold')
axes[1, 0].legend(fontsize=11)
axes[1, 0].grid(True, alpha=0.3)

# Learning rate
axes[1, 1].plot(history['lr'], label='Learning Rate', color='red', linewidth=2)
axes[1, 1].set_xlabel('Epoch', fontsize=12)
axes[1, 1].set_ylabel('Learning Rate', fontsize=12)
axes[1, 1].set_title('Learning Rate Schedule', fontsize=14, fontweight='bold')
axes[1, 1].legend(fontsize=11)
axes[1, 1].grid(True, alpha=0.3)
axes[1, 1].set_yscale('log')

plt.tight_layout()

# Save the figure before plt.show(): the inline backend closes the figure once it
# has been shown, so saving afterwards would write a blank image
log_dir = Path(get_path(config['logging']['log_dir']))
log_dir.mkdir(parents=True, exist_ok=True)
fig.savefig(log_dir / 'training_curves.png', dpi=150, bbox_inches='tight')
plt.show()
print(f"Training curves saved to {log_dir / 'training_curves.png'}")

## 10. Model Inference and Results

### Evaluation on Test Set

Load the best checkpoint (highest validation Dice) and visualize predictions:
- Compare input images, ground truth masks, and model predictions
- Calculate per-sample metrics (Dice, IoU, Pixel Accuracy)
- Assess segmentation quality visually

**Note:** Model outputs probabilities [0, 1] (sigmoid activated), thresholded at 0.5 for binary predictions.

In [None]:
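# Load best model (map_location lets a GPU-trained checkpoint load on CPU as well;
# weights_only=False is needed because the checkpoint also stores config and history)
checkpoint_path = Path(get_path(config['logging']['checkpoint_dir'])) / 'best_model_aspp_residual_se_unet.pth'
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)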
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

print(f"Loaded best model from epoch {checkpoint['epoch']+1}")
print(f"Best Dice Score: {checkpoint['best_dice']:.4f}")

# Get test samples
test_images, test_masks = next(iter(test_loader))
test_images = test_images.to(device)

# Generate predictions
with torch.no_grad():
    # Model outputs probabilities [0, 1] (sigmoid activated)
    test_outputs = model(test_images)
    test_preds = (test_outputs > 0.5).float()

# Visualize predictions
num_samples = min(4, len(test_images))
fig, axes = plt.subplots(num_samples, 3, figsize=(12, 4 * num_samples))

if num_samples == 1:
    axes = axes.reshape(1, -1)

for i in range(num_samples):
    # Move to CPU and convert to numpy
    img = test_images[i, 0].cpu().numpy()
    mask = test_masks[i, 0].numpy()
    pred = test_preds[i, 0].cpu().numpy()
    
    # Calculate metrics for this sample (ensure both tensors on same device)
    pred_tensor = test_preds[i].cpu()
    mask_tensor = test_masks[i].to(pred_tensor.device)
    dice = dice_coefficient(pred_tensor, mask_tensor).item()
    iou = iou_score(pred_tensor, mask_tensor).item()
    pa = pixel_accuracy(pred_tensor, mask_tensor).item()
    
    # Input image
    axes[i, 0].imshow(img, cmap='gray')
    axes[i, 0].set_title('Input Image', fontsize=12, fontweight='bold')
    axes[i, 0].axis('off')
    
    # Ground truth
    axes[i, 1].imshow(mask, cmap='gray')
    axes[i, 1].set_title('Ground Truth', fontsize=12, fontweight='bold')
    axes[i, 1].axis('off')
    
    # Prediction
    axes[i, 2].imshow(pred, cmap='gray')
    axes[i, 2].set_title(f'Prediction\nDice: {dice:.4f} | IoU: {iou:.4f} | PA: {pa:.4f}', 
                         fontsize=10, fontweight='bold')
    axes[i, 2].axis('off')

plt.tight_layout()
plt.show()

# Save predictions
pred_dir = Path(get_path(config['logging']['prediction_dir']))
pred_dir.mkdir(parents=True, exist_ok=True)
save_prediction_grid(test_images[:4].cpu(), test_masks[:4], test_preds[:4].cpu(), 
                    str(pred_dir / 'sample_predictions.png'), num_samples=4)
print(f"Sample predictions saved to {pred_dir / 'sample_predictions.png'}")

## 11. Download Results (Google Colab)

After training completes, download results to your local machine.

In [None]:
# Download results from Google Colab
from google.colab import files
import shutil

print("Creating archive of results...")

# Create zip archive
results_dir = output_root / 'results'
archive_path = output_root / 'aspp_residual_se_unet_results.zip'

# Remove old archive if exists
if archive_path.exists():
    archive_path.unlink()

# Create zip file
shutil.make_archive(
    str(output_root / 'aspp_residual_se_unet_results'),
    'zip',
    str(results_dir)
)

print(f"✓ Archive created: {archive_path}")
print(f"Archive size: {archive_path.stat().st_size / (1024**2):.2f} MB")

# Download archive
print("\nDownloading archive...")
files.download(str(archive_path))
print("✓ Download complete!")

## 12. Summary

### Key Innovations:

1. **ASPP Module at Bottleneck**: Multi-scale feature extraction
   - 1×1 convolution for point-wise features
   - 3×3 atrous convolutions with dilation rates [6, 12, 18]
   - Global average pooling for image-level context
   - Captures features at multiple scales simultaneously

2. **Residual Blocks with SE**: Two 3×3 convolutions + BatchNorm + ReLU + SE attention + skip connections
3. **SE Blocks on Skip Connections**: Channel-wise attention before concatenation
4. **Improved Gradient Flow**: Residual connections throughout the network

### Architecture Highlights:

- **Encoder**: 4 stages with ResidualBlockSE (64→128→256→512)
- **Bottleneck**: ASPP module (512→1024 channels, multi-scale context)
- **Decoder**: 4 stages with ResidualBlockSE (512→256→128→64)
- **Total Parameters**: ~38M parameters (~146 MB)

### Training Configuration:

- **Loss**: DiceBCELoss (0.8 Dice + 0.2 BCE)
- **Optimizer**: Adam (lr=1e-3, weight decay=1e-4)
- **Scheduler**: ReduceLROnPlateau (patience=5, factor=0.5)
- **Augmentation**: HorizontalFlip, VerticalFlip, Rotation, ShiftScaleRotate (on-the-fly)
- **Early Stopping**: Monitors validation Dice (patience=10)

### Expected Performance:

The ASPP module should improve segmentation accuracy by capturing multi-scale contextual information, especially useful for:
- Objects at varying scales
- Better boundary detection
- Improved context understanding
- Handling size variations in fetal head across different gestational ages

### Next Steps:

1. Compare with Standard U-Net and Residual SE U-Net (without ASPP)
2. Analyze performance improvement from ASPP
3. Visualize multi-scale features captured by ASPP
4. Evaluate on full HC18 test set
5. Calculate HD95 metric for boundary accuracy assessment

---

**Download all results** using the cell above to analyze on your local machine.