# Training Experiments - Quick Iteration & Testing

This notebook is designed for rapid prototyping and hyperparameter experimentation.

**Purpose:**
- Quick training runs (5-10 epochs) for fast iteration
- Test model architecture changes
- Experiment with hyperparameters
- Visualize training progress in real-time

**For the full 100-epoch training run, use the `main.py` script.**

## 1. Import Libraries

In [None]:
import sys
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from torch.optim import Adam
import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import tqdm

# Add project root to path
sys.path.append('..')

from src.data import HC18Dataset
from src.models import ImprovedUNet
from src.losses import DiceBCELoss
from src.utils import get_transforms
from train import train_one_epoch, evaluate_model

# Pick the compute device: the first CUDA GPU when PyTorch can see one,
# otherwise the CPU. Report the choice (and GPU name/memory if applicable).
_cuda_ok = torch.cuda.is_available()
DEVICE = torch.device('cuda' if _cuda_ok else 'cpu')
print(f"Using device: {DEVICE}")
if _cuda_ok:
    _gpu_name = torch.cuda.get_device_name(0)
    _gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU: {_gpu_name}")
    print(f"Memory: {_gpu_mem_gb:.2f} GB")

## 2. Quick Experiment Configuration

In [None]:
# === EXPERIMENT SETTINGS (adjust for quick tests) ===
# Target input resolution, passed to get_transforms in the data cell.
IMG_HEIGHT, IMG_WIDTH = 256, 256
BATCH_SIZE = 8
NUM_EPOCHS = 5          # short smoke run; raise to ~10 for a more thorough check
LEARNING_RATE = 0.001   # Adam step size (same value as torch.optim.Adam's default)
SUBSET_SIZE = 100       # training images sampled per run; None (or 0) -> full dataset

# Dataset locations, relative to the working directory
# (expected to be the notebook's folder, given the '../' prefix).
TRAIN_IMG_DIR = '../dataset/training_set/images'
TRAIN_MASK_DIR = '../dataset/training_set/masks'
VAL_IMG_DIR = '../dataset/test_set/images'
VAL_MASK_DIR = '../dataset/test_set/masks'

# Echo the configuration so each run's settings are visible in the output.
_rule = "=" * 60
print(_rule)
print("EXPERIMENT CONFIGURATION")
print(_rule)
print(f"Image Size:      {IMG_HEIGHT}×{IMG_WIDTH}")
print(f"Batch Size:      {BATCH_SIZE}")
print(f"Epochs:          {NUM_EPOCHS}")
print(f"Learning Rate:   {LEARNING_RATE}")
print(f"Subset Size:     {SUBSET_SIZE or 'Full dataset'}")
print(f"Device:          {DEVICE}")
print(_rule)

## 3. Load Data

In [None]:
# Build train/validation datasets (optionally subsampled) and their DataLoaders.
# Train and val use get_transforms with is_train=True / False respectively.
train_transforms = get_transforms(IMG_HEIGHT, IMG_WIDTH, is_train=True)
val_transforms = get_transforms(IMG_HEIGHT, IMG_WIDTH, is_train=False)

# Create datasets
train_dataset_full = HC18Dataset(TRAIN_IMG_DIR, TRAIN_MASK_DIR, transform=train_transforms)
val_dataset_full = HC18Dataset(VAL_IMG_DIR, VAL_MASK_DIR, transform=val_transforms)

# Subsampling for quick experiments. A fixed seed keeps the subset identical
# across runs, so different hyperparameters are compared on the same images.
SUBSET_SEED = 42
VAL_SUBSET_SIZE = 50

if SUBSET_SIZE:
    rng = np.random.default_rng(SUBSET_SEED)
    # Clamp to the dataset size: sampling without replacement raises
    # ValueError when asked for more items than exist.
    n_train = min(SUBSET_SIZE, len(train_dataset_full))
    n_val = min(VAL_SUBSET_SIZE, len(val_dataset_full))
    train_indices = rng.choice(len(train_dataset_full), n_train, replace=False)
    val_indices = rng.choice(len(val_dataset_full), n_val, replace=False)
    # Plain Python ints so the wrapped dataset never sees numpy integer indices.
    train_dataset = Subset(train_dataset_full, train_indices.tolist())
    val_dataset = Subset(val_dataset_full, val_indices.tolist())
else:
    train_dataset = train_dataset_full
    val_dataset = val_dataset_full

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

# Pinned (page-locked) host memory only helps host->GPU copies; on a CPU-only
# machine it just triggers a warning, so enable it only when training on CUDA.
use_pin_memory = DEVICE.type == 'cuda'

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=0, pin_memory=use_pin_memory)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=0, pin_memory=use_pin_memory)

print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")

## 4. Initialize Model

In [None]:
# Build the Improved U-Net plus its loss, optimizer and LR scheduler,
# and print a short parameter summary.
model = ImprovedUNet(in_channels=1, out_channels=1).to(DEVICE)

# Model summary
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print("="*60)
print("MODEL ARCHITECTURE: Improved U-Net")
print("="*60)
print(f"Total parameters:      {total_params:,}")
print(f"Trainable parameters:  {trainable_params:,}")
# 4 bytes per FP32 parameter.
print(f"Model size:            ~{total_params * 4 / (1024**2):.2f} MB (FP32)")
print("="*60)

# Scheduler settings live in one place so the printed summary cannot drift
# from what the scheduler is actually configured with.
SCHEDULER_FACTOR = 0.5
SCHEDULER_PATIENCE = 2

# Loss and optimizer
loss_fn = DiceBCELoss()
optimizer = Adam(model.parameters(), lr=LEARNING_RATE)
# mode='max' because the training loop steps the scheduler with validation
# Dice, where higher is better.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=SCHEDULER_FACTOR, patience=SCHEDULER_PATIENCE
)

print("\nLoss Function: Dice + BCE")
print(f"Optimizer: Adam (LR={LEARNING_RATE})")
print(f"Scheduler: ReduceLROnPlateau (factor={SCHEDULER_FACTOR}, patience={SCHEDULER_PATIENCE})")

## 5. Quick Training Loop

In [None]:
# Quick training loop: train/validate for NUM_EPOCHS, record per-epoch metrics
# in `history`, step the LR scheduler on validation Dice, and keep a copy of
# the best-Dice weights.
history = {
    'train_loss': [],
    'train_dice': [],
    'val_loss': [],
    'val_dice': [],
    'val_miou': [],
    'val_pa': [],
    'lr': [],  # learning rate used in each epoch (ReduceLROnPlateau may lower it)
}

best_dice = 0.0
best_epoch = None
# Snapshot of the weights that achieved best_dice (None until one is found).
# Restore with model.load_state_dict(best_state_dict).
best_state_dict = None

print("\n" + "="*60)
print("STARTING TRAINING")
print("="*60)

for epoch in range(1, NUM_EPOCHS + 1):
    print(f"\nEpoch {epoch}/{NUM_EPOCHS}")
    print("-" * 60)

    # Train
    train_loss, train_dice = train_one_epoch(
        train_loader, model, optimizer, loss_fn, DEVICE, epoch
    )

    # Validate
    val_metrics = evaluate_model(val_loader, model, loss_fn, DEVICE)

    # Read the LR before scheduler.step(), i.e. the rate this epoch trained with.
    current_lr = optimizer.param_groups[0]['lr']

    # Store history
    history['train_loss'].append(train_loss)
    history['train_dice'].append(train_dice)
    history['val_loss'].append(val_metrics['loss'])
    history['val_dice'].append(val_metrics['dice'])
    history['val_miou'].append(val_metrics['miou'])
    history['val_pa'].append(val_metrics['pixel_accuracy'])
    history['lr'].append(current_lr)

    # Print
    print(f"Train: Loss={train_loss:.4f}, Dice={train_dice:.4f}")
    print(f"Val:   Loss={val_metrics['loss']:.4f}, Dice={val_metrics['dice']:.4f}, "
          f"mIoU={val_metrics['miou']:.4f}, PA={val_metrics['pixel_accuracy']:.4f}")
    print(f"LR:    {current_lr:.2e}")

    # Update scheduler (configured with mode='max', so it reacts to Dice plateaus)
    scheduler.step(val_metrics['dice'])

    # Track best
    if val_metrics['dice'] > best_dice:
        best_dice = val_metrics['dice']
        best_epoch = epoch
        # state_dict() returns references to the live parameters, which later
        # optimizer steps overwrite in place -- clone them to keep a real snapshot.
        best_state_dict = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        print(f"✓ New best Dice: {best_dice:.4f}")

print("\n" + "="*60)
print(f"TRAINING COMPLETE - Best Dice: {best_dice:.4f}")
if best_epoch is not None:
    print(f"Best epoch: {best_epoch} (weights kept in best_state_dict)")
print("="*60)

## 6. Plot Training Curves

In [None]:
# Plot per-epoch curves from `history`: train vs. val loss and Dice on the top
# row, validation mIoU and pixel accuracy on the bottom row.
# The x-axis comes from what was actually recorded rather than NUM_EPOCHS, so
# plotting still works after an interrupted run or a later NUM_EPOCHS change
# (otherwise matplotlib raises "x and y must have same first dimension").
epochs_range = range(1, len(history['train_loss']) + 1)

fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Top row: train vs. validation for the same metric.
for ax, metric, title in ((axes[0, 0], 'loss', 'Loss'), (axes[0, 1], 'dice', 'Dice Score')):
    ax.plot(epochs_range, history[f'train_{metric}'], 'o-', label='Train', linewidth=2)
    ax.plot(epochs_range, history[f'val_{metric}'], 's-', label='Val', linewidth=2)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(title)
    ax.legend()
    ax.grid(True, alpha=0.3)

# Bottom row: validation-only metrics.
for ax, key, title, ylabel, color in (
    (axes[1, 0], 'val_miou', 'Validation mIoU', 'mIoU', 'green'),
    (axes[1, 1], 'val_pa', 'Validation Pixel Accuracy', 'Pixel Accuracy', 'purple'),
):
    ax.plot(epochs_range, history[key], 'o-', color=color, linewidth=2)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 7. Visualize Predictions

In [None]:
# Show input / ground truth / thresholded prediction / error map for the first
# n_samples validation images, using the model's current weights.
# NOTE(review): predictions are thresholded at 0.5 as if they were
# probabilities, which assumes ImprovedUNet ends in a sigmoid. If it returns raw
# logits, apply torch.sigmoid(...) before thresholding -- TODO confirm in src.models.
model.eval()
n_samples = 6

# squeeze=False keeps `axes` 2-D even when n_samples == 1.
fig, axes = plt.subplots(n_samples, 4, figsize=(16, 4 * n_samples), squeeze=False)
shown = 0

with torch.no_grad():
    for images, masks in val_loader:
        if shown >= n_samples:
            break
        # Only the inputs need to go to the device; masks are used on the CPU.
        predictions = model(images.to(DEVICE)).cpu()

        # Take as many images as needed from each batch (not just the first),
        # so n_samples rows are filled even when there are few batches.
        for image, mask_t, pred_t in zip(images, masks, predictions):
            if shown >= n_samples:
                break
            img = image.squeeze().numpy()
            mask = mask_t.squeeze().numpy()
            pred_binary = (pred_t.squeeze().numpy() > 0.5).astype(np.float32)
            # Nonzero wherever the prediction and the ground truth disagree.
            error = np.abs(mask - pred_binary)

            panels = (
                (img, 'Input Image', 'gray'),
                (mask, 'Ground Truth', 'gray'),
                (pred_binary, 'Prediction', 'gray'),
                (error, 'Error Map', 'hot'),
            )
            for ax, (data, title, cmap) in zip(axes[shown], panels):
                ax.imshow(data, cmap=cmap)
                ax.set_title(title)
                ax.axis('off')
            shown += 1

# Blank out rows left unused when the validation set has fewer than n_samples images.
for ax in axes[shown:].ravel():
    ax.axis('off')

# Set the title before tight_layout so the layout reserves room for it.
fig.suptitle('Predictions on Validation Set', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()