# Notebook 3: Model Training and Benchmarking

### Objectives:
1.  **Set up the Training Environment:** Configure hyperparameters, data loaders, models, and loss functions.
2.  **Train and Benchmark Models:** Train our three candidate models (`BaselineUNet`, `ResNetUNet`, `TransUNet`) on the processed dataset.
3.  **Evaluate Performance:** Analyze training progress with loss curves and evaluate final models on the unseen test set using the Dice score.
4.  **Visualize Results:** Create high-quality visualizations of model predictions for the final report and presentation.

## 1. Setup, Imports, and Path Definitions

We begin by importing all necessary libraries and defining the project's directory structure. This ensures our environment is correctly configured.

In [1]:
import os
import sys  # Import the sys module to manipulate Python's path
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import matplotlib.pyplot as plt
from glob import glob
import random

# --- Define and Add Root Directory to Path ---
# This notebook runs from the 'notebooks' folder, so Python cannot see the
# 'src' package in the parent directory by default. Adding the project root
# to sys.path prevents a ModuleNotFoundError on the import below.
ROOT_DIR = os.path.abspath(os.path.join(os.getcwd(), '..'))
if ROOT_DIR not in sys.path:
    sys.path.append(ROOT_DIR)
    print(f"Added Root Directory to Path: {ROOT_DIR}")

# --- Import custom models from src (resolvable now that ROOT_DIR is on sys.path) ---
from src.models import BaselineUNet, ResNetUNet, TransUNet
print("Successfully imported custom models.")

# --- Define Project Directories ---
PROCESSED_DATA_DIR = os.path.join(ROOT_DIR, 'data', 'processed')
MODELS_DIR = os.path.join(ROOT_DIR, 'models')
FIGURES_DIR = os.path.join(ROOT_DIR, 'figures')

# Create directories if they don't exist
os.makedirs(MODELS_DIR, exist_ok=True)
os.makedirs(FIGURES_DIR, exist_ok=True)

print(f"Processed Data Directory: {PROCESSED_DATA_DIR}")
print(f"Saved Models Directory: {MODELS_DIR}")
print(f"Saved Figures Directory: {FIGURES_DIR}")

Added Root Directory to Path: D:\Coding\GitHub\MRI-Tumor-Segmentation
Successfully imported custom models.
Processed Data Directory: D:\Coding\GitHub\MRI-Tumor-Segmentation\data\processed
Saved Models Directory: D:\Coding\GitHub\MRI-Tumor-Segmentation\models
Saved Figures Directory: D:\Coding\GitHub\MRI-Tumor-Segmentation\figures
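
The `ROOT_DIR` computed above depends on the notebook being launched from the `notebooks` folder. As a minimal sketch of a less fragile alternative, the hypothetical helper `find_project_root` below (not part of this project) walks up from a starting directory until it finds one that contains a `src` folder. The temporary directory layout is only there to make the example runnable on its own.

```python
import tempfile
from pathlib import Path


def find_project_root(start: Path, marker: str = "src") -> Path:
    """Return the nearest ancestor of `start` (or `start` itself) containing `marker`."""
    start = start.resolve()
    for candidate in (start, *start.parents):
        if (candidate / marker).is_dir():
            return candidate
    raise FileNotFoundError(f"No directory containing '{marker}' at or above {start}")


# Demo on a throwaway layout that mimics <root>/src and <root>/notebooks
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp).resolve()
    (root / "src").mkdir()
    (root / "notebooks").mkdir()

    found_from_notebooks = find_project_root(root / "notebooks")
    found_from_root = find_project_root(root)
    print(found_from_notebooks == root, found_from_root == root)
```

With a helper like this, the same cell would work whether the kernel starts in the project root or in `notebooks`.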


## 2. Configuration & Hyperparameters

This cell contains all the key parameters for our training run. To benchmark a different model, simply change the `MODEL_TO_TRAIN` variable and re-run the notebook.

In [2]:
# --- CONFIGURATION & HYPERPARAMETERS ---

# Choose which model to train: 'BaselineUNet', 'ResNetUNet', or 'TransUNet'
MODEL_TO_TRAIN = 'ResNetUNet'

# Data parameters
TEST_SPLIT_SIZE = 0.15
VALIDATION_SPLIT_SIZE = 0.15 # Fraction of the full dataset (rescaled after the test split)
RANDOM_STATE = 42

# Training parameters
LEARNING_RATE = 1e-4
BATCH_SIZE = 16
NUM_EPOCHS = 25 # Start with a reasonable number, can be increased

# System parameters
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"DEVICE: {DEVICE}")
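NUM_WORKERS = 0 # 0 loads data in the main process; os.cpu_count() would use one worker per core
PIN_MEMORY = DEVICE == "cuda" # Pinned host memory only speeds up host-to-GPU copies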

# Ensure reproducibility
random.seed(RANDOM_STATE)
np.random.seed(RANDOM_STATE)
torch.manual_seed(RANDOM_STATE)
if torch.cuda.is_available():
    torch.cuda.manual_seed(RANDOM_STATE)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

DEVICE: cuda
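
As a small illustration of what the seeding lines buy us, the sketch below uses only Python's `random` module (the `draw` helper is hypothetical): reseeding with the same value reproduces the same sequence of draws. The NumPy and PyTorch seeds set above play the same role for their own generators.

```python
import random


def draw(seed: int, n: int = 5) -> list:
    """Reseed Python's global RNG and draw n integers."""
    random.seed(seed)
    return [random.randint(0, 1000) for _ in range(n)]


first = draw(42)
second = draw(42)  # same seed, so the same sequence is expected
print(first == second)
```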


## 3. Custom PyTorch Dataset

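We define a custom `Dataset` class, the standard PyTorch interface for feeding samples to a model. Each slice is loaded from disk the first time it is requested and then kept in an in-memory LRU (least recently used) cache of up to `cache_size` samples, so repeated accesses avoid further disk reads while memory use stays bounded.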

In [3]:
import collections

class BrainMRIDataset(Dataset):
    def __init__(self, file_list, cache_size=8000): # Cache can hold 8000 items
        self.file_list = file_list
        self.cache_size = cache_size
        # Using an OrderedDict to implement a simple LRU (Least Recently Used) cache
        self.cache = collections.OrderedDict()
        
        print(f"Initialized dataset with {len(file_list)} files.")
        print(f"Cache size set to {cache_size} items.")

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        # --- Check if the item is in the cache ---
        if idx in self.cache:
            # Move the accessed item to the end to mark it as recently used
            self.cache.move_to_end(idx)
            return self.cache[idx]
        
        # --- If not in cache (cache miss), load from disk ---
        image_path = self.file_list[idx]
        mask_path = image_path.replace("_image.npy", "_mask.npy")
        
        image = np.load(image_path).astype(np.float32)
        mask = np.load(mask_path).astype(np.float32)

        image_tensor = torch.from_numpy(image).permute(2, 0, 1)
        mask_tensor = torch.from_numpy(mask).unsqueeze(0)
        
        # --- Manage cache size ---
        if len(self.cache) >= self.cache_size:
            # Pop the least recently used item (the first item in the dict)
            self.cache.popitem(last=False)
        
        # Add the newly loaded item to the cache
        self.cache[idx] = (image_tensor, mask_tensor)
        
        return image_tensor, mask_tensor
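
To make the eviction behaviour concrete, here is a minimal standalone sketch of the same `OrderedDict` pattern. `TinyLRU` is a hypothetical toy class, and the string it stores stands in for a loaded `(image, mask)` pair.

```python
from collections import OrderedDict


class TinyLRU:
    """Toy LRU cache that mirrors the move_to_end / popitem pattern used above."""

    def __init__(self, capacity: int):
        self.capacity = capacity
        self.items = OrderedDict()
        self.loads = 0  # Counts simulated disk reads (cache misses)

    def get(self, key):
        if key in self.items:
            # Cache hit: mark the entry as most recently used
            self.items.move_to_end(key)
            return self.items[key]

        # Cache miss: simulate loading from disk
        self.loads += 1
        value = f"sample-{key}"
        if len(self.items) >= self.capacity:
            # Evict the least recently used entry (front of the dict)
            self.items.popitem(last=False)
        self.items[key] = value
        return value


cache = TinyLRU(capacity=2)
for key in [0, 1, 0, 2, 1]:
    cache.get(key)

print(list(cache.items), cache.loads)
```

Key `1` is evicted when `2` arrives because `0` was touched more recently, so the final request for `1` counts as another miss.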

## 4. Data Splitting and Loaders

We split our list of processed slices into training, validation, and test sets. This is a critical step to ensure we can evaluate our model's generalization ability on unseen data. Then, we wrap these sets in `DataLoader` objects, which handle batching and shuffling automatically.

In [4]:
# Get all image file paths
all_files = glob(os.path.join(PROCESSED_DATA_DIR, "*_image.npy"))

# First, split into training+validation and test sets
train_val_files, test_files = train_test_split(
    all_files, test_size=TEST_SPLIT_SIZE, random_state=RANDOM_STATE
)

# Then, split the training+validation set into training and validation sets
relative_val_size = VALIDATION_SPLIT_SIZE / (1 - TEST_SPLIT_SIZE)
train_files, val_files = train_test_split(
    train_val_files, test_size=relative_val_size, random_state=RANDOM_STATE
)

# Create Dataset objects
train_dataset = BrainMRIDataset(train_files)
val_dataset = BrainMRIDataset(val_files)
test_dataset = BrainMRIDataset(test_files)

# Create DataLoader objects
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

# --- Print a summary of the data splitting ---
print("--- Data Splitting Summary ---")
print(f"Total slices: {len(all_files)}")
print(f"Training set:   {len(train_files)} slices ({len(train_files)/len(all_files):.1%})")
print(f"Validation set: {len(val_files)} slices ({len(val_files)/len(all_files):.1%})")
print(f"Test set:       {len(test_files)} slices ({len(test_files)/len(all_files):.1%})")
print("------------------------------")

Initialized dataset with 7484 files.
Cache size set to 8000 items.
Initialized dataset with 1604 files.
Cache size set to 8000 items.
Initialized dataset with 1604 files.
Cache size set to 8000 items.
--- Data Splitting Summary ---
Total slices: 10692
Training set:   7484 slices (70.0%)
Validation set: 1604 slices (15.0%)
Test set:       1604 slices (15.0%)
------------------------------
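
The split sizes above can be checked by hand. The sketch below reproduces the arithmetic in plain Python, assuming the held-out share is rounded up to a whole number of slices (an assumption that matches the counts printed above). It also derives the number of batches per epoch for the batch size of 16 set in the configuration cell.

```python
import math

TOTAL = 10692
TEST_SPLIT_SIZE = 0.15
VALIDATION_SPLIT_SIZE = 0.15
BATCH_SIZE = 16

# First split: hold out 15% of all slices for testing
n_test = math.ceil(TEST_SPLIT_SIZE * TOTAL)
n_train_val = TOTAL - n_test

# Second split: rescale so validation is still 15% of the full dataset
relative_val_size = VALIDATION_SPLIT_SIZE / (1 - TEST_SPLIT_SIZE)
n_val = math.ceil(relative_val_size * n_train_val)
n_train = n_train_val - n_val

# Batches per epoch (the last batch may be smaller than BATCH_SIZE)
train_batches = math.ceil(n_train / BATCH_SIZE)
val_batches = math.ceil(n_val / BATCH_SIZE)

print(n_train, n_val, n_test, train_batches, val_batches)
```

The rescaling step is why `relative_val_size` is about 0.176 rather than 0.15: it is applied to the 85% of the data that remains after the test split.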


## 5. Loss Function and Evaluation Metric

We define the loss function that the model minimizes during training. For segmentation, we combine Binary Cross-Entropy (BCE) with Dice loss:
- **BCE Loss:** Penalizes the classification error of each pixel independently.
- **Dice Loss:** Directly optimizes a differentiable version of the Dice coefficient, our main evaluation metric. Because it measures overlap with the tumor region rather than per-pixel accuracy, it is much less affected by class imbalance (far more non-tumor pixels than tumor pixels).

In [5]:
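class DiceLoss(nn.Module):
    """Soft Dice loss computed per sample and channel, then averaged."""

    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth  # Avoids division by zero when both masks are empty

    def forward(self, y_pred, y_true):
        # Expects probabilities and targets of shape (N, C, H, W)
        y_pred = y_pred.contiguous()
        y_true = y_true.contiguous()

        # Sum over the spatial dimensions H and W
        intersection = (y_pred * y_true).sum(dim=(2, 3))
        denominator = y_pred.sum(dim=(2, 3)) + y_true.sum(dim=(2, 3))
        dice = (2. * intersection + self.smooth) / (denominator + self.smooth)

        return 1 - dice.mean()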

def dice_coefficient(y_pred, y_true, smooth=1e-6):
    y_pred = (y_pred > 0.5).float() # Binarize the prediction
    intersection = (y_pred * y_true).sum()
    dice = (2. * intersection + smooth) / (y_pred.sum() + y_true.sum() + smooth)
    return dice.item()

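# Combined loss: BCE on the raw logits plus Dice loss on the sigmoid probabilities
def combined_loss(y_pred, y_true):
    # BCEWithLogitsLoss applies the sigmoid internally, which is more
    # numerically stable than a separate sigmoid followed by BCELoss
    bce = nn.BCEWithLogitsLoss()(y_pred, y_true)
    dice = DiceLoss()(torch.sigmoid(y_pred), y_true)
    return bce + dice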
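
To see why an overlap metric matters under class imbalance, consider a toy "slice" of 100 pixels of which only 5 are tumor. The sketch below uses plain Python lists and two hypothetical helpers, `dice` and `accuracy`; it follows the same formula as `dice_coefficient` but is not the notebook's implementation.

```python
def dice(pred, truth, smooth=1e-6):
    """Dice coefficient for two flat binary masks."""
    intersection = sum(p * t for p, t in zip(pred, truth))
    return (2 * intersection + smooth) / (sum(pred) + sum(truth) + smooth)


def accuracy(pred, truth):
    """Fraction of pixels classified correctly."""
    return sum(p == t for p, t in zip(pred, truth)) / len(truth)


truth = [1] * 5 + [0] * 95           # 5 tumor pixels out of 100
all_background = [0] * 100           # Predicts "no tumor" everywhere
partial = [1] * 3 + [0] * 97         # Finds 3 of the 5 tumor pixels

print(f"all background: acc={accuracy(all_background, truth):.2f}, "
      f"dice={dice(all_background, truth):.2f}")
print(f"partial:        acc={accuracy(partial, truth):.2f}, "
      f"dice={dice(partial, truth):.2f}")
```

The all-background prediction reaches 95% accuracy while its Dice score is essentially zero, which is exactly the failure mode that pixel-wise metrics can hide.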

## 6. Training & Validation Loop

This is the core of our notebook. We define two functions:
1. `train_fn`: Handles one epoch of training. It iterates through the `train_loader`, performs forward and backward passes, and updates the model's weights.
2. `evaluate_fn`: Handles model evaluation on the validation or test set. It calculates the loss and Dice score without updating the model.

The main loop in Section 7 then calls both functions once per epoch, records the loss and Dice history, and saves the model whenever the validation Dice score improves.

In [6]:
def train_fn(loader, model, optimizer, loss_fn, scaler):
    model.train()  # Make sure dropout and batch-norm layers are in training mode
    loop = tqdm(loader, desc="Training")
    running_loss = 0.0

    for data, targets in loop:
        data = data.to(device=DEVICE)
        targets = targets.to(device=DEVICE)

        # Forward pass under mixed precision (active only on CUDA)
        with torch.autocast(device_type=DEVICE, enabled=(DEVICE == "cuda")):
            predictions = model(data)
            loss = loss_fn(predictions, targets)

        # Backward pass: the scaler keeps small fp16 gradients from underflowing
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()
        loop.set_postfix(loss=loss.item())

    # Mean loss per batch over the epoch
    return running_loss / len(loader)

def evaluate_fn(loader, model, loss_fn, device):
    model.eval()
    loop = tqdm(loader, desc="Evaluating")
    val_loss = 0.0
    val_dice = 0.0

    with torch.no_grad():
        for data, targets in loop:
            data = data.to(device=device)
            targets = targets.to(device=device)
            
            predictions = model(data)
            loss = loss_fn(predictions, targets)
            val_loss += loss.item()
            
            # Calculate dice score
            preds_sigmoid = torch.sigmoid(predictions)
            dice = dice_coefficient(preds_sigmoid, targets)
            val_dice += dice
            loop.set_postfix(dice=dice)
            
    model.train()
    return val_loss / len(loader), val_dice / len(loader)
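
Note that `evaluate_fn` reports the mean of per-batch Dice scores, which is not the same as a single Dice score over all pixels. The toy sketch below (plain Python, made-up masks, hypothetical `dice` helper) shows how a batch with no tumor pixels and a few false positives can pull the per-batch mean down.

```python
def dice(pred, truth, smooth=1e-6):
    """Dice coefficient for two flat binary masks."""
    intersection = sum(p * t for p, t in zip(pred, truth))
    return (2 * intersection + smooth) / (sum(pred) + sum(truth) + smooth)


# Batch A: 10 tumor pixels, 10 predicted, 8 of them overlapping
pred_a = [1] * 8 + [1] * 2 + [0] * 2 + [0] * 8
true_a = [1] * 8 + [0] * 2 + [1] * 2 + [0] * 8

# Batch B: no tumor at all, but 2 false-positive pixels
pred_b = [1] * 2 + [0] * 18
true_b = [0] * 20

per_batch_mean = (dice(pred_a, true_a) + dice(pred_b, true_b)) / 2
pooled = dice(pred_a + pred_b, true_a + true_b)

print(f"mean of per-batch Dice: {per_batch_mean:.3f}")
print(f"Dice over all pixels:   {pooled:.3f}")
```

Neither convention is wrong, but scores are only comparable across models when they are computed the same way.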

## 7. Run the Training

Now, we bring everything together. We select the model based on our configuration, initialize the optimizer, and start the main training loop. We will save the model with the best validation Dice score.

In [None]:
# --- Model Selection ---
if MODEL_TO_TRAIN == 'BaselineUNet':
    model = BaselineUNet(in_channels=4, out_channels=1).to(DEVICE)
elif MODEL_TO_TRAIN == 'ResNetUNet':
    model = ResNetUNet(in_channels=4, out_channels=1).to(DEVICE)
elif MODEL_TO_TRAIN == 'TransUNet':
    model = TransUNet(in_channels=4, out_channels=1).to(DEVICE)
else:
    raise ValueError(f"Unknown model: {MODEL_TO_TRAIN}")

# --- Optimizer, Loss, and Scaler ---
loss_fn = combined_loss # Our custom combined loss
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Gradient scaling supports mixed-precision training; it is enabled only on CUDA
scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE == "cuda"))

# --- Training History and Best Model State ---
history = {'train_loss': [], 'val_loss': [], 'val_dice': []}
best_val_dice = -1.0
model_save_path = os.path.join(MODELS_DIR, f'{MODEL_TO_TRAIN}_best_model.pth')

# --- Print Training Configuration ---
print("--- Training Run Initiated ---")
print(f"Model: {MODEL_TO_TRAIN}")
print(f"Epochs: {NUM_EPOCHS}")
print(f"Batch Size: {BATCH_SIZE}")
print(f"Learning Rate: {LEARNING_RATE}")
print(f"Device: {DEVICE}")
print("------------------------------")

# --- Main Training Loop ---
for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")
    train_loss = train_fn(train_loader, model, optimizer, loss_fn, scaler)
    val_loss, val_dice = evaluate_fn(val_loader, model, loss_fn, DEVICE)

    print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Dice: {val_dice:.4f}")

    # Save history
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['val_dice'].append(val_dice)

    # Save the best model
    if val_dice > best_val_dice:
        print(f"Validation Dice score improved from {best_val_dice:.4f} to {val_dice:.4f}. Saving model to {model_save_path}")
        best_val_dice = val_dice
        torch.save(model.state_dict(), model_save_path)

print("\n--- Training Finished ---")

--- Training Run Initiated ---
Model: ResNetUNet
Epochs: 25
Batch Size: 16
Learning Rate: 0.0001
Device: cuda
------------------------------

Epoch 1/25


Training: 100%|███████████████████████████████████████| 468/468 [14:07<00:00,  1.81s/it, loss=1.53]
Evaluating: 100%|████████████████████████████████████| 101/101 [03:39<00:00,  2.17s/it, dice=0.126]


Train Loss: 1.5791 | Val Loss: 1.4839 | Val Dice: 0.1432
Validation Dice score improved from -1.0000 to 0.1432. Saving model to D:\Coding\GitHub\MRI-Tumor-Segmentation\models\ResNetUNet_best_model.pth

Epoch 2/25


Training: 100%|███████████████████████████████████████| 468/468 [02:40<00:00,  2.91it/s, loss=1.48]
Evaluating: 100%|████████████████████████████████████| 101/101 [00:08<00:00, 11.95it/s, dice=0.126]


Train Loss: 1.4795 | Val Loss: 1.4703 | Val Dice: 0.1432

Epoch 3/25


Training: 100%|███████████████████████████████████████| 468/468 [01:03<00:00,  7.36it/s, loss=1.44]
Evaluating: 100%|████████████████████████████████████| 101/101 [00:07<00:00, 13.20it/s, dice=0.126]


Train Loss: 1.4726 | Val Loss: 1.4668 | Val Dice: 0.1432

Epoch 4/25


Training: 100%|███████████████████████████████████████| 468/468 [00:58<00:00,  8.03it/s, loss=1.51]
Evaluating: 100%|████████████████████████████████████| 101/101 [00:07<00:00, 12.73it/s, dice=0.126]


Train Loss: 1.4704 | Val Loss: 1.4651 | Val Dice: 0.1432

Epoch 5/25


Training: 100%|███████████████████████████████████████| 468/468 [00:54<00:00,  8.54it/s, loss=1.48]
Evaluating: 100%|████████████████████████████████████| 101/101 [00:07<00:00, 13.66it/s, dice=0.126]


Train Loss: 1.4692 | Val Loss: 1.4641 | Val Dice: 0.1432

Epoch 6/25


Training: 100%|███████████████████████████████████████| 468/468 [00:55<00:00,  8.43it/s, loss=1.55]
Evaluating: 100%|████████████████████████████████████| 101/101 [00:07<00:00, 13.61it/s, dice=0.126]


Train Loss: 1.4685 | Val Loss: 1.4638 | Val Dice: 0.1432
Validation Dice score improved from 0.1432 to 0.1432. Saving model to D:\Coding\GitHub\MRI-Tumor-Segmentation\models\ResNetUNet_best_model.pth

Epoch 7/25


Training: 100%|███████████████████████████████████████| 468/468 [00:54<00:00,  8.59it/s, loss=1.54]
Evaluating: 100%|████████████████████████████████████| 101/101 [00:07<00:00, 13.47it/s, dice=0.126]


Train Loss: 1.4681 | Val Loss: 1.4634 | Val Dice: 0.1432
Validation Dice score improved from 0.1432 to 0.1432. Saving model to D:\Coding\GitHub\MRI-Tumor-Segmentation\models\ResNetUNet_best_model.pth

Epoch 8/25


Training: 100%|███████████████████████████████████████| 468/468 [00:55<00:00,  8.38it/s, loss=1.45]
Evaluating: 100%|████████████████████████████████████| 101/101 [00:07<00:00, 12.91it/s, dice=0.126]


Train Loss: 1.4678 | Val Loss: 1.4633 | Val Dice: 0.1432
Validation Dice score improved from 0.1432 to 0.1432. Saving model to D:\Coding\GitHub\MRI-Tumor-Segmentation\models\ResNetUNet_best_model.pth

Epoch 9/25


Training:  18%|███████▎                                | 85/468 [00:10<00:45,  8.40it/s, loss=1.46]
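
The log above contains lines such as "improved from 0.1432 to 0.1432": the score rose by less than the four decimals shown, and the checkpoint was rewritten anyway. A minimal sketch of one way to ignore such negligible gains is to require a minimum improvement. The `min_delta` parameter and the `improved_epochs` helper are hypothetical, and the scores are made-up values for illustration.

```python
def improved_epochs(val_scores, min_delta=0.0):
    """Return the 1-based epochs at which a checkpoint would be saved."""
    best = float("-inf")
    saved = []
    for epoch, score in enumerate(val_scores, start=1):
        # Save only when the score beats the best so far by more than min_delta
        if score > best + min_delta:
            best = score
            saved.append(epoch)
    return saved


# Made-up validation Dice scores with tiny gains followed by a real one
scores = [0.14320, 0.14320, 0.14321, 0.14322, 0.20000]

print(improved_epochs(scores))                  # Strict comparison, as in the loop above
print(improved_epochs(scores, min_delta=1e-3))  # Ignores gains below 0.001
```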

## 8. Visualize Training Progress

Plotting the training and validation metrics over time is crucial for understanding how our model learned. It helps us spot issues like overfitting and determine if the model trained for a sufficient number of epochs. These plots are essential for our final presentation.

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 6))
fig.suptitle(f'Training History for {MODEL_TO_TRAIN}', fontsize=16)

# Plotting Loss
ax1.plot(history['train_loss'], label='Training Loss')
ax1.plot(history['val_loss'], label='Validation Loss')
ax1.set_title('Loss vs. Epochs')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.legend()
ax1.grid(True)

# Plotting Validation Dice Score
ax2.plot(history['val_dice'], label='Validation Dice Score', color='green')
ax2.set_title('Validation Dice Score vs. Epochs')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Dice Score')
ax2.legend()
ax2.grid(True)

plt.tight_layout(rect=[0, 0.03, 1, 0.95])

# Save the figure
figure_path = os.path.join(FIGURES_DIR, f'{MODEL_TO_TRAIN}_training_history.png')
plt.savefig(figure_path, dpi=300)
print(f"Training history plot saved to: {figure_path}")
plt.show()
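
Besides plotting, the same `history` dictionary can be inspected numerically. The sketch below uses made-up values (not results from this run) to find the epoch with the best validation Dice score and to track the gap between validation and training loss, whose growth is a common sign of overfitting.

```python
# Made-up history values for illustration only
history = {
    "train_loss": [1.58, 1.20, 0.95, 0.80],
    "val_loss": [1.48, 1.25, 1.10, 1.15],
    "val_dice": [0.14, 0.35, 0.52, 0.50],
}

# Epoch (1-based) with the highest validation Dice score
val_dice = history["val_dice"]
best_epoch = max(range(len(val_dice)), key=val_dice.__getitem__) + 1

# Gap between validation and training loss per epoch
gaps = [v - t for t, v in zip(history["train_loss"], history["val_loss"])]

print(f"Best epoch: {best_epoch}")
print("Val - train loss gap:", [round(g, 2) for g in gaps])
```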

## 9. Final Evaluation on Test Set

We now load the best checkpoint (selected by validation Dice score during training) and evaluate it once on the held-out test set. Because the test set played no part in training or model selection, this gives an unbiased estimate of the model's performance on unseen slices.

In [None]:
# Load the best checkpoint; map_location keeps this working on a CPU-only machine
model.load_state_dict(torch.load(model_save_path, map_location=DEVICE))

print(f"--- Evaluating Best {MODEL_TO_TRAIN} on Test Set ---")
test_loss, test_dice = evaluate_fn(test_loader, model, loss_fn, DEVICE)

print("\n--- Final Test Set Performance ---")
print(f"Model: {MODEL_TO_TRAIN}")
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Dice Score: {test_dice:.4f}")
print("------------------------------------")
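
Since each run of this notebook evaluates a single model, the test metrics need to be stored somewhere to compare all three candidates. The sketch below shows one possible approach: a hypothetical `record_result` helper that merges results into a JSON file. The file name `benchmark_results.json` and the numbers are placeholders, not outputs of this notebook.

```python
import json
import tempfile
from pathlib import Path


def record_result(results_path: Path, model_name: str, test_loss: float, test_dice: float) -> dict:
    """Add or overwrite one model's test metrics in a JSON results file."""
    results = json.loads(results_path.read_text()) if results_path.exists() else {}
    results[model_name] = {"test_loss": test_loss, "test_dice": test_dice}
    results_path.write_text(json.dumps(results, indent=2))
    return results


# Demo in a temporary directory with placeholder numbers
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "benchmark_results.json"
    record_result(path, "BaselineUNet", test_loss=0.9, test_dice=0.5)
    results = record_result(path, "ResNetUNet", test_loss=0.8, test_dice=0.6)

# Rank models by test Dice score, best first
ranking = sorted(results, key=lambda name: results[name]["test_dice"], reverse=True)
print(ranking)
```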

## 10. Visualize Predictions

Metrics summarize performance, but visual inspection shows where a model succeeds and fails. Here we take a few random samples from the test set and plot the input image, the ground truth mask, and the model's prediction side by side. This serves as a qualitative check that complements the Dice score.

In [None]:
num_samples = 5
model.eval()

fig, axes = plt.subplots(num_samples, 3, figsize=(15, 5 * num_samples))
fig.suptitle(f'Sample Predictions for {MODEL_TO_TRAIN} on Test Set', fontsize=20, y=1.02)

with torch.no_grad():
    for i in range(num_samples):
        # Get a random sample from the test set
        idx = random.randint(0, len(test_dataset) - 1)
        image, mask = test_dataset[idx]
        image_gpu = image.unsqueeze(0).to(DEVICE)
        
        # Get model prediction
        pred_mask = model(image_gpu)
        pred_mask = (torch.sigmoid(pred_mask) > 0.5).float().cpu().squeeze(0)
        
        # Prepare for plotting (C, H, W) -> (H, W, C) for image
        image = image.permute(1, 2, 0).numpy()
        mask = mask.squeeze(0).numpy()
        pred_mask = pred_mask.squeeze(0).numpy()
        
        # We'll visualize one of the MRI channels, e.g., T1c (channel 0)
        display_image = image[:, :, 0]
        
        # Plot Input Image
        axes[i, 0].imshow(display_image, cmap='bone')
        axes[i, 0].set_title(f'Sample {i+1}: Input (T1c)')
        axes[i, 0].axis('off')
        
        # Plot Ground Truth Mask
        axes[i, 1].imshow(mask, cmap='magma')
        axes[i, 1].set_title(f'Sample {i+1}: Ground Truth')
        axes[i, 1].axis('off')
        
        # Plot Predicted Mask
        axes[i, 2].imshow(pred_mask, cmap='magma')
        axes[i, 2].set_title(f'Sample {i+1}: Prediction')
        axes[i, 2].axis('off')

plt.tight_layout()

# Save the figure
figure_path = os.path.join(FIGURES_DIR, f'{MODEL_TO_TRAIN}_test_predictions.png')
plt.savefig(figure_path, dpi=300, bbox_inches='tight')
print(f"\nPrediction visualization saved to: {figure_path}")
plt.show()
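
Drawing each index with `random.randint` inside the loop can select the same slice more than once. If distinct samples are preferred, a minimal sketch is to draw all indices up front with `random.sample`. The dataset size below is the test set size reported earlier, and the local `rng` object is used only to keep the example self-contained.

```python
import random

rng = random.Random(42)   # Local generator so the example is reproducible
dataset_size = 1604       # Number of test slices reported in the split summary
num_samples = 5

# Sample without replacement: every index appears at most once
indices = rng.sample(range(dataset_size), k=num_samples)
print(indices)
```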

## End of Notebook 3

This concludes the model training and benchmarking phase. We have successfully:
- Trained one of our selected model architectures.
- Saved the best performing weights.
- Plotted the training history to analyze its learning behavior.
- Evaluated its final performance on the unseen test set.
- Visualized its segmentation capabilities on sample images.

To benchmark the other models, simply change the `MODEL_TO_TRAIN` variable at the top of this notebook and re-run all cells. Once all models are benchmarked, we will be ready to proceed to **Notebook 4: Inference and App Preparation**.