# Fetal Head Segmentation Training (Universal - Colab/Kaggle Compatible)

This notebook implements the complete training pipeline for the Improved U-Net model.

**Target Performance Metrics:**
- DSC (Dice Similarity Coefficient): ≥97.81%
- mIoU (Mean Intersection over Union): ≥97.90%
- mPA (Mean Pixel Accuracy): ≥99.18%

**Platforms Supported:**
- ✅ Google Colab
- ✅ Kaggle Notebooks

## 1. Setup Environment & Platform Detection

In [None]:
# Work out which hosted environment (if any) this kernel runs in.
# PLATFORM ('colab' | 'kaggle' | 'local') drives every platform-specific branch below.
import os
import sys
from pathlib import Path

IS_COLAB = 'COLAB_GPU' in os.environ or 'google.colab' in sys.modules
IS_KAGGLE = 'KAGGLE_KERNEL_RUN_TYPE' in os.environ

# Colab takes precedence over Kaggle; anything else is treated as a local run.
if IS_COLAB:
    PLATFORM, _platform_banner = 'colab', "✓ Running on Google Colab"
elif IS_KAGGLE:
    PLATFORM, _platform_banner = 'kaggle', "✓ Running on Kaggle"
else:
    PLATFORM, _platform_banner = 'local', "✓ Running locally"

_banner_rule = "=" * 70
print(_banner_rule)
print("PLATFORM DETECTION")
print(_banner_rule)
print(_platform_banner)
print(_banner_rule)


PLATFORM DETECTION
✓ Running locally


In [2]:
# Report the PyTorch build and whether a CUDA device is visible to it.
import torch

cuda_ok = torch.cuda.is_available()
print(f"\nPyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ok}")
if cuda_ok:
    # Only the first device is reported; training below uses the default 'cuda' device.
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")


PyTorch version: 2.9.0+cpu
CUDA available: False


In [3]:
# Install required packages (if needed)
if PLATFORM == 'colab':
    !pip install albumentations==1.3.1 -q
    !pip install pyyaml -q
    print("✓ Packages installed for Colab")
elif PLATFORM == 'kaggle':
    # Kaggle has most packages pre-installed, install only if needed
    try:
        import albumentations
        print(f"✓ Albumentations version: {albumentations.__version__}")
    except ImportError:
        !pip install albumentations==1.3.1 -q
        print("✓ Installed albumentations")

## 2. Setup Paths (Platform-Specific)

### For Google Colab:
1. Automatically clones the repository from GitHub to `/content/Fetal-Head-Segmentation` (skipped if it already exists)
2. Outputs saved to `/content/outputs/` (writable, lost after session)

### For Kaggle:
1. Upload your project as a Kaggle Dataset named `fetal-head-segmentation`
2. Add it as input to your notebook
3. Outputs saved to `/kaggle/working/` (writable, downloadable)

### For Local:
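1. Automatically detects the project root by searching the current directory and its parents for an `accuracy_focus` folder
2. Outputs are saved under `accuracy_focus/improved_unet/`, in the directories set in the YAML config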


In [None]:
# Platform-specific path configuration.
# Sets PROJECT_PATH (repository root containing `accuracy_focus/`),
# OUTPUT_PATH (writable location for results) and CACHE_ROOT, then puts
# PROJECT_PATH on sys.path so the `shared` / `accuracy_focus` packages import.
import subprocess

REPO_URL = 'https://github.com/TrinhThaiSonDHQT/Fetal-Head-Segmentation.git'


def _find_project_root(start):
    """Return `start` or the nearest parent containing `accuracy_focus`, else None."""
    for parent in [start, *start.parents]:
        if (parent / 'accuracy_focus').exists():
            return parent
    return None


def _find_kaggle_project(input_root, preferred):
    """Return the mounted Kaggle dataset that contains an `accuracy_focus` folder.

    The dataset named `preferred` is tried first, then every other mounted
    dataset in sorted (deterministic) order.

    Raises:
        FileNotFoundError: if no dataset is mounted under `input_root`.
        RuntimeError: if datasets are mounted but none contains `accuracy_focus`.
    """
    mounted = sorted(p for p in input_root.iterdir() if p.is_dir()) if input_root.exists() else []
    if not mounted:
        raise FileNotFoundError(
            f"No datasets found in {input_root}/\n"
            "Please add the 'fetal-head-segmentation' dataset to your Kaggle notebook."
        )
    preferred_path = input_root / preferred
    candidates = [preferred_path] + [p for p in mounted if p != preferred_path]
    for candidate in candidates:
        if (candidate / 'accuracy_focus').exists():
            if candidate != preferred_path:
                print(f"⚠️ Using available dataset: {candidate.name}")
            return candidate
    raise RuntimeError(
        f"'accuracy_focus' folder not found in any dataset under {input_root}\n"
        f"Please ensure your dataset structure is correct."
    )


if PLATFORM == 'colab':
    print("\n[Google Colab Setup]")
    repo_path = Path('/content/Fetal-Head-Segmentation')
    if not repo_path.exists():
        print("Cloning repository from GitHub...")
        # check=True raises CalledProcessError on failure instead of printing a false success.
        subprocess.run(['git', 'clone', REPO_URL, str(repo_path)], check=True)
        print("✓ Repository cloned successfully")
    else:
        print("✓ Repository already exists")
    PROJECT_PATH = repo_path
    OUTPUT_PATH = Path('/content/outputs')

elif PLATFORM == 'kaggle':
    print("\n[Kaggle Setup]")
    PROJECT_PATH = _find_kaggle_project(Path('/kaggle/input'), 'fetal-head-segmentation')
    OUTPUT_PATH = Path('/kaggle/working')

else:  # local
    print("\n[Local Setup]")
    PROJECT_PATH = _find_project_root(Path.cwd())
    if PROJECT_PATH is None:
        raise RuntimeError(
            f"Cannot find project root with 'accuracy_focus' folder.\n"
            f"Current directory: {os.getcwd()}"
        )
    OUTPUT_PATH = PROJECT_PATH / 'accuracy_focus' / 'improved_unet'

CACHE_ROOT = OUTPUT_PATH / 'cache'

read_only_note = ' (read-only)' if PLATFORM == 'kaggle' else ''
print(f"✓ Project path{read_only_note}: {PROJECT_PATH}")
print(f"✓ Output path: {OUTPUT_PATH}")

# Put the project first on sys.path without piling up duplicates on re-runs.
project_str = str(PROJECT_PATH)
if project_str in sys.path:
    sys.path.remove(project_str)
sys.path.insert(0, project_str)

print(f"\nProject files: {os.listdir(PROJECT_PATH)[:10]}")  # Show first 10 files
print(f"\n✓ Environment setup complete")


✓ Project path: e:\Fetal Head Segmentation\notebooks

Project files: ['01_data_exploration.ipynb', '02_training_experiments.ipynb', '03_results_analysis.ipynb', 'FHS_Accuracy_Focus.ipynb', 'FHS_Accuracy_Focus_Universal.ipynb', 'rebuild_cache.ipynb', 'results.txt']


## 3. Import Required Libraries

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# Import from shared modules (resolvable because the setup cell put
# PROJECT_PATH at the front of sys.path)
from shared.configs.config_loader import load_config
from shared.src.utils import get_transforms, get_optimizer
from shared.src.utils.train import train_one_epoch, evaluate_model

# Import model and loss from improved_unet package
from accuracy_focus.improved_unet.src.models import ImprovedUNet
from accuracy_focus.improved_unet.src.losses import DiceBCELoss

# NOTE(review): the dataset cell below constructs `HC18Dataset`, but nothing in
# this cell imports it, so "Restart & Run All" fails there with NameError.
# Add the project's HC18Dataset import here — its module path is not visible
# from this notebook (TODO confirm, presumably somewhere under `shared.src`).
print("✓ All imports successful")


## 4. Load Configuration

In [None]:
# Load the Improved U-Net configuration, seed the RNGs and create output directories.
import random

import numpy as np

config_path = PROJECT_PATH / 'accuracy_focus' / 'improved_unet' / 'configs' / 'improved_unet_config.yaml'
config = load_config(str(config_path))

# Config sections used by the cells below
data_cfg = config['data']
model_cfg = config['model']
train_cfg = config['training']
aug_cfg = config['augmentation']
checkpoint_cfg = config['checkpoint']
logging_cfg = config['logging']

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed Python, NumPy and PyTorch (torch.manual_seed seeds every device) so that
# weight init, shuffling and augmentation are repeatable across runs.
# `training.seed` in the config overrides the default.
SEED = train_cfg.get('seed', 42)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


def get_data_path(relative_path):
    """Resolve a config-relative data path against PROJECT_PATH and return it as a str."""
    return str(PROJECT_PATH / relative_path)


if PLATFORM in ('colab', 'kaggle'):
    # Colab/Kaggle: the project location may be read-only (Kaggle input), so
    # every artefact goes under the writable OUTPUT_PATH instead.
    results_root = OUTPUT_PATH / 'results'
    checkpoint_dir = results_root / 'checkpoints'
    log_dir = results_root / 'logs'
    prediction_dir = results_root / 'predictions'
    visualization_dir = results_root / 'visualizations'

    print(f"Adjusting paths for {PLATFORM.upper()} environment...")
    print(f"  Outputs will be saved to: {results_root}")
else:
    # Local: config paths are relative to the improved_unet folder.
    model_root = PROJECT_PATH / 'accuracy_focus' / 'improved_unet'
    checkpoint_dir = model_root / checkpoint_cfg['save_dir']
    log_dir = model_root / logging_cfg['log_dir']
    prediction_dir = model_root / logging_cfg['prediction_dir']
    visualization_dir = model_root / logging_cfg['visualization_dir']

output_dirs = [checkpoint_dir, log_dir]
if logging_cfg.get('save_predictions', False):
    output_dirs += [prediction_dir, visualization_dir]
for output_dir in output_dirs:
    Path(output_dir).mkdir(parents=True, exist_ok=True)

print("="*70)
print("FETAL HEAD SEGMENTATION - IMPROVED U-NET TRAINING")
print("="*70)
print(f"Platform: {PLATFORM.upper()}")
print(f"Device: {device}")
print(f"Random seed: {SEED}")
print(f"Batch Size: {train_cfg['batch_size']}")
print(f"Learning Rate: {train_cfg['optimizer']['lr']}")
print(f"Number of Epochs: {train_cfg['num_epochs']}")
print(f"Checkpoints: {checkpoint_dir}")
print("="*70)


## 5. Prepare Datasets and Data Loaders

In [None]:
print("[1/4] Loading datasets...\n")

# HC18Dataset must come from the "Import Required Libraries" cell; fail with an
# actionable message rather than a bare NameError on a fresh kernel.
if 'HC18Dataset' not in globals():
    raise NameError(
        "HC18Dataset is not defined - import it in the 'Import Required Libraries' "
        "cell before building the datasets."
    )

# Get image size from config
img_size = aug_cfg['preprocessing']['image_size'][0]  # Assumes square images

# Create augmentation transforms
print("Creating augmentation transforms...")
train_transforms = get_transforms(img_size, img_size, is_train=True)
val_transforms = get_transforms(img_size, img_size, is_train=False)
print("  Train transform: WITH augmentation (HorizontalFlip, Rotation, ShiftScaleRotate)")
print("  Val transform: WITHOUT augmentation (resize + normalize only)")

# Standard HC18Dataset: augmentation is applied on-the-fly per sample
print("\nCreating training dataset...")
train_dataset = HC18Dataset(
    get_data_path(data_cfg['train_images']),
    get_data_path(data_cfg['train_masks']),
    transform=train_transforms
)

print("Creating validation dataset...")
val_dataset = HC18Dataset(
    get_data_path(data_cfg['val_images']),
    get_data_path(data_cfg['val_masks']),
    transform=val_transforms
)

# An empty split would otherwise surface later as an opaque DataLoader/sampler error.
if len(train_dataset) == 0 or len(val_dataset) == 0:
    raise RuntimeError(
        f"Empty dataset: {len(train_dataset)} training / {len(val_dataset)} validation samples.\n"
        f"Check data paths: {get_data_path(data_cfg['train_images'])} and "
        f"{get_data_path(data_cfg['val_images'])}"
    )

print(f"\n{'='*70}")
print(f"Datasets Ready:")
print(f"{'='*70}")
print(f"  Training samples: {len(train_dataset)}")
print(f"  Validation samples: {len(val_dataset)}")
print(f"  Test samples available at: {get_data_path(data_cfg['test_images'])}")
print(f"  Image size: {img_size}×{img_size}")
print(f"  Normalization: Divide by 255.0")
print(f"  Train augmentation: ENABLED (applied on-the-fly)")
print(f"  Val augmentation: DISABLED")
print(f"{'='*70}\n")

# Colab/Kaggle: single-process loading avoids multiprocessing issues there.
num_workers = 0 if PLATFORM in ['colab', 'kaggle'] else train_cfg['num_workers']
# Pinned host memory only speeds up host->GPU copies; on CPU it just emits warnings.
pin_memory = bool(train_cfg['pin_memory']) and device.type == 'cuda'
print(f"DataLoader settings:")
print(f"  num_workers: {num_workers} ({'disabled for Colab/Kaggle' if num_workers == 0 else 'local multi-threading'})")
print(f"  pin_memory: {pin_memory}")
print(f"  Batch size: {train_cfg['batch_size']}")

train_loader = DataLoader(
    train_dataset,
    batch_size=train_cfg['batch_size'],
    shuffle=True,
    num_workers=num_workers,
    pin_memory=pin_memory
)
val_loader = DataLoader(
    val_dataset,
    batch_size=train_cfg['batch_size'],
    shuffle=False,
    num_workers=num_workers,
    pin_memory=pin_memory
)

print(f"\n✓ Data loaders created")
print(f"  Training batches: {len(train_loader)}")
print(f"  Validation batches: {len(val_loader)}")


## 6. Initialize Model, Loss, and Optimizer

In [None]:
print("[2/4] Initializing model...\n")

model = ImprovedUNet(
    in_channels=model_cfg['in_channels'],
    out_channels=model_cfg['out_channels']
).to(device)

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# Loss: Dice + BCE, with the BCE term re-weighted for foreground/background imbalance.
loss_cfg = train_cfg['loss']
dice_weight = loss_cfg.get('dice_weight', 0.8)  # Default 0.8 (prioritize Dice)
bce_weight = loss_cfg.get('bce_weight', 0.2)    # Default 0.2

# pos_weight = background_pixels / foreground_pixels. With the default
# foreground_ratio of 0.004 this is (1 - 0.004) / 0.004 = 249.
# `training.loss.foreground_ratio` in the config overrides the default.
# NOTE(review): 0.4% is a hard-coded estimate, not measured here; it directly
# scales the BCE term, so confirm it against the actual training masks.
foreground_ratio = loss_cfg.get('foreground_ratio', 0.004)
if not 0.0 < foreground_ratio < 1.0:
    raise ValueError(f"foreground_ratio must be in (0, 1), got {foreground_ratio}")
pos_weight = (1 - foreground_ratio) / foreground_ratio
print(f"\n⚠️  Class imbalance weighting:")
print(f"   Foreground: {foreground_ratio*100:.2f}%, Background: {(1-foreground_ratio)*100:.2f}%")
print(f"   Using BCEWithLogitsLoss with pos_weight={pos_weight:.1f}")

# pos_weight must live on the same device as the logits passed to the loss.
pos_weight_tensor = torch.tensor([pos_weight], device=device)

loss_fn = DiceBCELoss(
    dice_weight=dice_weight,
    bce_weight=bce_weight,
    pos_weight=pos_weight_tensor
)
print(f"Loss: DiceBCELoss (dice_weight={dice_weight}, bce_weight={bce_weight})")
print(f"      BCE uses pos_weight={pos_weight:.1f} on device={device}")

# Optimizer - supports both SGD and Adam
optimizer_cfg = train_cfg['optimizer']
optimizer = get_optimizer(model, optimizer_cfg)

# ReduceLROnPlateau is stepped with validation Dice in the training loop.
# Dice is higher-is-better, so the mode must be 'max'; a 'min' mode would
# cut the learning rate exactly when the model improves.
scheduler_cfg = train_cfg['scheduler']
scheduler_mode = scheduler_cfg.get('mode', 'max')
if scheduler_mode != 'max':
    print(f"⚠️  scheduler mode '{scheduler_mode}' overridden to 'max' "
          f"(the scheduler is stepped on validation Dice)")
    scheduler_mode = 'max'
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode=scheduler_mode,
    factor=scheduler_cfg['factor'],
    patience=scheduler_cfg['patience'],
    min_lr=scheduler_cfg['min_lr']
)

## 7. Training Loop

In [None]:
print("[3/4] Starting training...")
print("="*70)

num_epochs = train_cfg['num_epochs']
early_stopping_patience = train_cfg.get('early_stopping_patience', 15)  # Default: 15 epochs
# Periodic snapshot interval; `checkpoint.save_every` overrides the default of 10 (0 disables).
save_every = checkpoint_cfg.get('save_every', 10)

best_dice = 0.0
best_epoch = None          # None until the first validation result is recorded
early_stopping_counter = 0
early_stopped = False


def _save_checkpoint(path, epoch, **extra):
    """Write model, optimizer and scheduler state plus `extra` entries to `path`.

    Scheduler state is included so a resumed run continues the
    ReduceLROnPlateau schedule instead of restarting it.
    """
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'config': config,
        **extra,
    }, str(path))


print(f"Early stopping enabled with patience: {early_stopping_patience} epochs\n")

for epoch in range(1, num_epochs + 1):
    print(f"\nEpoch {epoch}/{num_epochs}")
    print("-" * 70)

    train_loss, train_dice = train_one_epoch(
        train_loader, model, optimizer, loss_fn, device, epoch
    )
    val_metrics = evaluate_model(val_loader, model, loss_fn, device)
    val_dice = val_metrics['dice']

    print(f"Train Loss: {train_loss:.4f} | Train Dice: {train_dice:.4f}")
    print(f"Val Loss: {val_metrics['loss']:.4f} | Val Dice: {val_dice:.4f}")
    print(f"Val mIoU: {val_metrics['miou']:.4f} | Val mPA: {val_metrics['pixel_accuracy']:.4f}")

    # Update learning rate based on validation Dice
    current_lr = optimizer.param_groups[0]['lr']
    scheduler.step(val_dice)
    new_lr = optimizer.param_groups[0]['lr']
    if new_lr != current_lr:
        # Scientific notation: a fixed 6-decimal format prints 0.000000 for small LRs.
        print(f"Learning rate reduced: {current_lr:.2e} → {new_lr:.2e}")

    # The first epoch always counts as an improvement, so a best model is saved
    # even if validation Dice never rises above 0.
    if best_epoch is None or val_dice > best_dice:
        best_dice = val_dice
        best_epoch = epoch
        early_stopping_counter = 0

        if checkpoint_cfg['save_best']:
            _save_checkpoint(
                checkpoint_dir / 'best_model.pth', epoch,
                best_dice=best_dice, val_metrics=val_metrics,
            )
            print(f"✓ Saved best model with Dice: {best_dice:.4f} (epoch {epoch})")
    else:
        early_stopping_counter += 1
        print(f"Early stopping counter: {early_stopping_counter}/{early_stopping_patience}")
        if early_stopping_counter >= early_stopping_patience:
            print(f"\n⚠️ Early stopping triggered! No improvement for {early_stopping_patience} epochs.")
            early_stopped = True

    if checkpoint_cfg['save_last']:
        _save_checkpoint(checkpoint_dir / 'last_model.pth', epoch, val_metrics=val_metrics)

    if save_every and epoch % save_every == 0:
        _save_checkpoint(checkpoint_dir / f'checkpoint_epoch_{epoch}.pth', epoch)
        print(f"✓ Saved checkpoint at epoch {epoch}")

    if early_stopped:
        break


## 8. Training Summary

In [None]:
# Summarise the run: best validation Dice, whether early stopping fired,
# and how the result compares with the target in the config (if any).
separator = "=" * 70

print("\n" + separator)
print("TRAINING COMPLETED!")
print(separator)
print(f"Best Dice Score: {best_dice:.4f} ({best_dice*100:.2f}%)")
print(f"Best model saved at: {checkpoint_dir / 'best_model.pth'}")

# `epoch` is only read when training actually stopped early.
stop_message = (
    f"\n⚠️ Training stopped early at epoch {epoch} (no improvement for {early_stopping_patience} epochs)"
    if early_stopped
    else f"\n✓ Completed all {train_cfg['num_epochs']} epochs"
)
print(stop_message)

target_metrics = config.get('target_metrics', {})
if target_metrics:
    target_dice = target_metrics.get('dice', 0)
    print("\nTarget Performance Metrics:")
    print(f"  Target Dice: {target_dice*100:.2f}% | Achieved: {best_dice*100:.2f}%")

    if best_dice >= target_dice:
        print("\n🎉 Target Dice score achieved!")
    else:
        print(f"\n⚠️ Target not reached. Gap: {(target_dice - best_dice)*100:.2f}%")

print(separator)


## 9. Download Trained Model

### Google Colab: 
- Downloads directly to your browser

### Kaggle: 
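- Files are saved to `/kaggle/working/results/`. They are not downloaded automatically: Kaggle keeps them as notebook output when you save a version, and you can download them from the notebook's Output section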

### Local:
- Files are saved under `accuracy_focus/improved_unet/`, in the checkpoint directory set by `checkpoint.save_dir` in the config


In [None]:
# Hand the best checkpoint to the user in a platform-appropriate way.
best_model_path = checkpoint_dir / 'best_model.pth'

if not best_model_path.exists():
    print("⚠️ Best model not found!")
elif PLATFORM == 'colab':
    from google.colab import files
    files.download(str(best_model_path))
    print(f"✓ Downloaded: {best_model_path}")
elif PLATFORM == 'kaggle':
    print(f"✓ Model saved at: {best_model_path}")
    print(f"✓ Files in /kaggle/working will be auto-downloaded when notebook completes")
    print(f"\nSaved files:")
    for saved_name in os.listdir(checkpoint_dir):
        print(f"  - {saved_name}")
else:
    print(f"✓ Model saved at: {best_model_path}")


## 10. Visualize Sample Predictions

In [None]:
import matplotlib.pyplot as plt
import numpy as np


def _sample_dice(pred, target):
    """Return the Dice coefficient of two masks after binarising both at 0.5.

    Two empty masks count as a perfect match (1.0) rather than 0.
    """
    pred_fg = pred > 0.5
    target_fg = target > 0.5
    total = pred_fg.sum() + target_fg.sum()
    if total == 0:
        return 1.0
    return float(2.0 * np.logical_and(pred_fg, target_fg).sum() / total)


# Show predictions from the best checkpoint rather than the final in-memory
# weights, which can be several epochs past the best one when early stopping fires.
best_model_path = checkpoint_dir / 'best_model.pth'
if best_model_path.exists():
    # weights_only=False: the checkpoint also stores `config` and `val_metrics`.
    # Only safe because this notebook wrote the file itself.
    best_checkpoint = torch.load(str(best_model_path), map_location=device, weights_only=False)
    model.load_state_dict(best_checkpoint['model_state_dict'])
    print(f"Visualizing best model (epoch {best_checkpoint['epoch']})")
else:
    print("⚠️ Best model not found - visualizing the current in-memory weights")

model.eval()

# First validation batch (val_loader is not shuffled, so this is reproducible)
with torch.no_grad():
    images, masks = next(iter(val_loader))
    images, masks = images.to(device), masks.to(device)

    # The model outputs logits; sigmoid maps them to probabilities in [0, 1].
    logits = model(images)
    probs = torch.sigmoid(logits)
    preds = (probs > 0.5).float()  # Threshold at 0.5

num_samples = min(4, images.shape[0])
# squeeze=False keeps `axes` 2-D, so axes[i, j] also works when num_samples == 1.
fig, axes = plt.subplots(num_samples, 4, figsize=(16, 4 * num_samples), squeeze=False)

for i in range(num_samples):
    img = images[i].cpu().squeeze().numpy()
    axes[i, 0].imshow(img, cmap='gray')
    axes[i, 0].set_title('Input Image')
    axes[i, 0].axis('off')

    mask = masks[i].cpu().squeeze().numpy()
    axes[i, 1].imshow(mask, cmap='gray')
    axes[i, 1].set_title('Ground Truth')
    axes[i, 1].axis('off')

    # Probability heatmap: shows model confidence per pixel
    prob = probs[i].cpu().squeeze().numpy()
    axes[i, 2].imshow(prob, cmap='jet', vmin=0, vmax=1)
    axes[i, 2].set_title('Prediction Probability')
    axes[i, 2].axis('off')

    pred = preds[i].cpu().squeeze().numpy()
    axes[i, 3].imshow(pred, cmap='gray')
    dice = _sample_dice(pred, mask)
    axes[i, 3].set_title(f'Binary Prediction (Dice: {dice:.3f})')
    axes[i, 3].axis('off')

plt.tight_layout()
plt.show()

print("✓ Visualization complete")
print("  Note: Model outputs logits → sigmoid applied for visualization")