In [1]:
# Brain AI Framework - Model Training Notebook
# 
# This notebook demonstrates the training process for the brain segmentation and abnormality detection models using MONAI framework. It covers:
# 1. Data loading and preparation
# 2. Model architecture setup
# 3. Training and validation
# 4. Model evaluation and saving
# 
# Author: Kishore Kumar Kalligive
# Date: March 31, 2025

# Import necessary libraries
import os
import sys
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

# MONAI imports
import monai
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, ScaleIntensityd,
    CropForegroundd, RandCropByPosNegLabeld, RandAffined, Spacingd,
    ToTensord, RandFlipd, RandRotate90d, RandScaleIntensityd, RandShiftIntensityd
)
from monai.networks.nets import UNet, DynUNet
from monai.metrics import DiceMetric, HausdorffDistanceMetric
from monai.losses import DiceLoss, DiceCELoss, FocalLoss
from monai.inferers import sliding_window_inference
from monai.data import list_data_collate, decollate_batch, Dataset, CacheDataset
from monai.visualize import plot_2d_or_3d_image

# Add project root to path
sys.path.append('..')

# Import project modules
from src.data.dataset import prepare_datalist
from src.models.segmentation import BrainSegmentationModel
from src.models.anomaly import AnomalyDetectionModel
from src.training.trainer import train_epoch, validate_epoch
from src.training.loss import CombinedLoss
from src.utils.metrics import calculate_metrics
import yaml

# Set random seeds for reproducibility (this also makes the np.random-driven
# train/val/test split below repeatable).
monai.utils.set_determinism(seed=42)

# Load configuration. The explicit encoding keeps YAML parsing independent of
# the platform's locale default (e.g. cp1252 on Windows), which would otherwise
# mis-decode any non-ASCII characters in the file.
with open("../config.yml", 'r', encoding='utf-8') as file:
    config = yaml.safe_load(file)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# 1. Data Loading and Preparation

# Load the subject table. The split below reads its 'Subject_ID' and
# 'Diagnosis' columns. len() counts rows, which equals the number of subjects
# only if each subject has a single row.
subjects_df = pd.read_csv(os.path.join(config['data']['clinical_dir'], 'ADNI_T1.csv'))
print(f"Total subjects: {len(subjects_df)}")

# Display the distribution of diagnostic groups
diagnosis_count = subjects_df['Diagnosis'].value_counts()
print("\nDiagnostic group distribution:")
print(diagnosis_count)

plt.figure(figsize=(8, 5))
diagnosis_count.plot(kind='bar')
plt.title('Subject Distribution by Diagnostic Group')
plt.xlabel('Diagnostic Group')
plt.ylabel('Count')
plt.tight_layout()
plt.show()

# Directory holding the preprocessed {Subject_ID}_t1.nii.gz / _seg.nii.gz volumes
data_dir = config['data']['processed_dir']

# Create data dictionaries for MONAI
def _build_file_dicts(subject_records, data_dir):
    """Turn subject records into MONAI-style data dicts.

    Each record must provide 'Subject_ID' and 'Diagnosis'. Image and label paths
    follow the ``{Subject_ID}_t1.nii.gz`` / ``{Subject_ID}_seg.nii.gz`` naming
    inside ``data_dir``.
    """
    return [
        {"image": os.path.join(data_dir, f"{rec['Subject_ID']}_t1.nii.gz"),
         "label": os.path.join(data_dir, f"{rec['Subject_ID']}_seg.nii.gz"),
         "diagnosis": rec['Diagnosis']}
        for rec in subject_records
    ]


def create_data_dicts(subjects_df, data_dir, diagnoses=('CN', 'MCI', 'AD'),
                      train_frac=0.7, val_frac=0.15):
    """Split subjects into train/val/test data dicts, stratified by diagnosis.

    Within each diagnostic group the *unique* subject IDs are shuffled with the
    global NumPy RNG and split into ``train_frac`` / ``val_frac`` / remainder.
    Splitting on unique IDs keeps every row of a subject in the same split, so a
    subject that appears on several rows cannot leak between train and test.

    Args:
        subjects_df: DataFrame with at least 'Subject_ID' and 'Diagnosis' columns.
            It is not modified.
        data_dir: Directory containing the ``*_t1.nii.gz`` / ``*_seg.nii.gz`` files.
        diagnoses: Diagnostic groups to include; rows with any other label are ignored.
        train_frac: Fraction of each group's subjects used for training (floored).
        val_frac: Fraction used for validation (floored); the rest goes to test.

    Returns:
        Tuple ``(train_files, val_files, test_files)``; each is a list of dicts
        with keys "image", "label" and "diagnosis".
    """
    train_subjects, val_subjects, test_subjects = [], [], []

    for diagnosis in diagnoses:
        diag_subjects = subjects_df[subjects_df['Diagnosis'] == diagnosis]

        # permutation() shuffles a copy. Shuffling `.values` in place could write
        # into the DataFrame's own buffer, or fail when pandas hands out a
        # read-only array (copy-on-write). It draws the same RNG sequence as shuffle().
        subject_ids = np.random.permutation(diag_subjects['Subject_ID'].unique())

        n_subjects = len(subject_ids)
        n_train = int(train_frac * n_subjects)
        n_val = int(val_frac * n_subjects)

        splits = (
            (train_subjects, subject_ids[:n_train]),
            (val_subjects, subject_ids[n_train:n_train + n_val]),
            (test_subjects, subject_ids[n_train + n_val:]),
        )
        for bucket, ids in splits:
            bucket.extend(
                diag_subjects[diag_subjects['Subject_ID'].isin(ids)].to_dict('records')
            )

    return (
        _build_file_dicts(train_subjects, data_dir),
        _build_file_dicts(val_subjects, data_dir),
        _build_file_dicts(test_subjects, data_dir),
    )

# Create data splits
train_files, val_files, test_files = create_data_dicts(subjects_df, data_dir)

print(f"Training samples: {len(train_files)}")
print(f"Validation samples: {len(val_files)}")
print(f"Testing samples: {len(test_files)}")


def _base_transforms():
    """Return fresh instances of the deterministic preprocessing shared by all splits.

    Loads image and label, moves the channel axis first, resamples to
    (1.0, 1.0, 1.0) voxel spacing (bilinear for the image, nearest for the label
    so class indices stay integral), rescales image intensities to [0, 1], and
    crops both to the image foreground. Keeping this in one place stops the
    train and val/test pipelines from drifting apart.
    """
    return [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest")),
        ScaleIntensityd(keys=["image"], minv=0.0, maxv=1.0),
        CropForegroundd(keys=["image", "label"], source_key="image"),
    ]


# Training transforms: shared preprocessing + random 96^3 patch sampling and augmentation
train_transforms = Compose(_base_transforms() + [
    RandCropByPosNegLabeld(
        keys=["image", "label"],
        label_key="label",
        spatial_size=(96, 96, 96),
        pos=1,
        neg=1,
        num_samples=4,
    ),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
    RandScaleIntensityd(keys=["image"], factors=0.1, prob=0.5),
    RandShiftIntensityd(keys=["image"], offsets=0.1, prob=0.5),
    ToTensord(keys=["image", "label"]),
])

# Validation/test transforms: deterministic preprocessing only (whole volumes,
# evaluated with sliding-window inference)
val_transforms = Compose(_base_transforms() + [
    ToTensord(keys=["image", "label"]),
])

# Create datasets (deterministic transforms are cached; random ones run per epoch)
train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=4)
val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=4)
test_ds = CacheDataset(data=test_files, transform=val_transforms, cache_rate=1.0, num_workers=4)

# Create data loaders.
# RandCropByPosNegLabeld(num_samples=4) turns each training subject into a list
# of 4 crop dicts. torch's default collate would return a *list* of batched
# dicts, so batch_data["image"] would fail. list_data_collate flattens the
# crops into one dict batch (2 subjects x 4 crops). The val/test loaders use it
# too, for consistent MONAI metadata handling.
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4,
                          pin_memory=True, collate_fn=list_data_collate)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=4,
                        pin_memory=True, collate_fn=list_data_collate)
test_loader = DataLoader(test_ds, batch_size=1, num_workers=4,
                         pin_memory=True, collate_fn=list_data_collate)

# 2. Model Architecture Setup

# Segmentation label map (index -> structure):
#   0 background        1 gray matter       2 white matter
#   3 CSF/ventricles    4 hippocampus       5 abnormality (lesions, tumors)
num_classes = 6

# 3D network: single-channel T1 input, one output channel per class.
model = BrainSegmentationModel(
    spatial_dims=3,
    in_channels=1,
    out_channels=num_classes,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    dropout=0.2,
).to(device)

# Dice + cross-entropy on softmax outputs. Integer labels are one-hot encoded
# inside the loss, and the background channel is excluded.
loss_function = DiceCELoss(to_onehot_y=True, softmax=True, include_background=False)
optimizer = torch.optim.Adam(params=model.parameters(), weight_decay=1e-5, lr=1e-4)
# Cosine decay over the full configured number of epochs (stepped once per epoch).
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer=optimizer,
    T_max=config['training']['epochs'],
)

# Validation metrics: background excluded, aggregated to a single mean value.
hausdorff_metric = HausdorffDistanceMetric(reduction="mean", include_background=False)
dice_metric = DiceMetric(reduction="mean", include_background=False)

# 3. Training and Validation Loop
#
# Each epoch runs one pass over train_loader and steps the cosine LR schedule.
# Every `val_interval` epochs it evaluates on val_loader with sliding-window
# inference. The weights with the best mean validation Dice are written to
# <model_dir>/best_model.pth.

# Training parameters
num_epochs = config['training']['epochs']
val_interval = config['training']['val_interval']
best_metric = -1
best_metric_epoch = -1
train_losses = []  # mean training loss, one entry per epoch
val_metrics = []   # mean validation Dice, one entry per validation round

# Sliding-window settings for validation; the window matches the 96^3 training crops.
roi_size = (96, 96, 96)
sw_batch_size = 4

# Create the checkpoint directory up front so torch.save cannot fail on a missing path.
best_model_path = os.path.join(config['training']['model_dir'], "best_model.pth")
os.makedirs(config['training']['model_dir'], exist_ok=True)

print(f"Training for {num_epochs} epochs...")

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch + 1}/{num_epochs}")

    # Training
    model.train()
    epoch_loss = 0
    step = 0
    start_time = time.time()

    for batch_data in train_loader:
        step += 1
        inputs, labels = batch_data["image"].to(device), batch_data["label"].to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        print(f"{step}/{len(train_loader)}, train_loss: {loss.item():.4f}", end='\r')

    if step == 0:
        # Fail with a clear message instead of a ZeroDivisionError on the next line.
        raise RuntimeError("train_loader produced no batches; check train_files and the data paths.")
    epoch_loss /= step
    train_losses.append(epoch_loss)
    training_time = time.time() - start_time
    print(f"Epoch {epoch + 1} training completed. Average loss: {epoch_loss:.4f}, Time: {training_time:.2f}s")

    scheduler.step()

    # Validation after epochs val_interval, 2 * val_interval, ... (1-based)
    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            for val_data in val_loader:
                val_inputs, val_labels = val_data["image"].to(device), val_data["label"].to(device)

                # Sliding window inference for large volumes
                val_logits = sliding_window_inference(val_inputs, roi_size, sw_batch_size, model)

                # DiceMetric / HausdorffDistanceMetric compare per-class channels.
                # Both the argmax prediction and the integer label map must
                # therefore be one-hot (B, num_classes, ...). Single-channel label
                # maps would make include_background=False meaningless and yield
                # a meaningless score.
                val_pred = monai.networks.utils.one_hot(
                    torch.argmax(val_logits, dim=1, keepdim=True), num_classes=num_classes, dim=1
                )
                val_true = monai.networks.utils.one_hot(val_labels, num_classes=num_classes, dim=1)

                dice_metric(y_pred=val_pred, y=val_true)
                hausdorff_metric(y_pred=val_pred, y=val_true)

            # Aggregate metrics over the whole validation set, then reset for the next round
            dice_result = dice_metric.aggregate().item()
            hausdorff_result = hausdorff_metric.aggregate().item()
            dice_metric.reset()
            hausdorff_metric.reset()

            val_metrics.append(dice_result)

            if dice_result > best_metric:
                best_metric = dice_result
                best_metric_epoch = epoch + 1
                # Save the best model
                torch.save(model.state_dict(), best_model_path)
                print("Saved new best model!")

            print(
                f"Validation Dice score: {dice_result:.4f}, "
                f"Hausdorff distance: {hausdorff_result:.4f}, "
                f"Best Dice: {best_metric:.4f} at epoch {best_metric_epoch}"
            )

# 4. Model Evaluation on Test Set

def _save_prediction_slices(image, label, pred, subject_id, output_dir, n_classes):
    """Save one 3-panel PNG (MRI / ground truth / prediction) per foreground class.

    ``image``, ``label`` and ``pred`` are 5-D tensors indexed as
    ``[0, 0, :, :, z]``. Only the first batch item and channel are drawn, at the
    central index of the last spatial axis. ``pred`` is the argmax label map.
    Class 0 (background) is skipped. Files are written as
    ``{output_dir}/{subject_id}_class{k}.png``; ``output_dir`` must already exist.
    """
    z = image.shape[-1] // 2
    # Move the slices to host memory once, not once per class.
    img_slice = image[0, 0, :, :, z].cpu().numpy()
    gt_slice = label[0, 0, :, :, z].cpu().numpy()
    pred_slice = pred[0, 0, :, :, z].cpu().numpy()

    for class_idx in range(1, n_classes):
        plt.figure(figsize=(12, 4))
        panels = (
            ('Original MRI', None),
            (f'Ground Truth Class {class_idx}', gt_slice == class_idx),
            (f'Prediction Class {class_idx}', pred_slice == class_idx),
        )
        for position, (title, mask) in enumerate(panels, start=1):
            plt.subplot(1, 3, position)
            plt.imshow(img_slice, cmap='gray')
            if mask is not None:
                plt.imshow(mask, alpha=0.5, cmap='jet')
            plt.title(title)
            plt.axis('off')
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, f'{subject_id}_class{class_idx}.png'))
        plt.close()


print("\nEvaluating best model on test dataset...")
best_model_path = os.path.join(config['training']['model_dir'], "best_model.pth")
if os.path.exists(best_model_path):
    # map_location lets a checkpoint written on another device load on this one.
    model.load_state_dict(torch.load(best_model_path, map_location=device))
else:
    # No validation round ran (e.g. val_interval > epochs), so no checkpoint exists.
    print(f"Warning: {best_model_path} not found; evaluating the current (last-epoch) weights.")
model.eval()

# reduction="mean_batch" averages over test cases but keeps one value per
# foreground class; the per-class report below indexes into that array.
test_dice_metric = DiceMetric(include_background=False, reduction="mean_batch")
test_hausdorff_metric = HausdorffDistanceMetric(include_background=False, reduction="mean_batch")

# Test-time augmentation: identity plus a flip along each spatial axis. Flips
# are self-inverse, so applying the same function to the model output maps
# the prediction back to the original orientation.
test_time_aug_transforms = [
    lambda x: x,
    lambda x: torch.flip(x, dims=[-1]),
    lambda x: torch.flip(x, dims=[-2]),
    lambda x: torch.flip(x, dims=[-3]),
]

roi_size = (96, 96, 96)
sw_batch_size = 4
prediction_dir = os.path.join(config['visualization']['output_dir'], 'test_predictions')

with torch.no_grad():
    for test_idx, test_data in enumerate(test_loader):
        test_inputs, test_labels = test_data["image"].to(device), test_data["label"].to(device)

        # Average the (de-augmented) logits of all TTA variants, then take the argmax label
        tta_logits = [
            aug(sliding_window_inference(aug(test_inputs), roi_size, sw_batch_size, model))
            for aug in test_time_aug_transforms
        ]
        test_outputs = torch.argmax(torch.mean(torch.stack(tta_logits), dim=0), dim=1, keepdim=True)

        # The metrics need one-hot (B, num_classes, ...) tensors, not label maps.
        test_pred_onehot = monai.networks.utils.one_hot(test_outputs, num_classes=num_classes, dim=1)
        test_true_onehot = monai.networks.utils.one_hot(test_labels, num_classes=num_classes, dim=1)
        test_dice_metric(y_pred=test_pred_onehot, y=test_true_onehot)
        test_hausdorff_metric(y_pred=test_pred_onehot, y=test_true_onehot)

        # Save example segmentations for roughly 10% of test subjects
        if np.random.rand() < 0.1:
            # test_loader is unshuffled with batch_size=1, so the batch index is the
            # position in test_files. This avoids relying on a MONAI-version-specific
            # meta-dict key. os.path.basename handles both '/' and Windows separators.
            # Only the '_t1' suffix is stripped, because subject IDs may contain
            # underscores themselves (e.g. ADNI-style '002_S_0295').
            image_name = os.path.basename(test_files[test_idx]["image"])
            subject_id = image_name.rsplit('_t1', 1)[0]
            os.makedirs(prediction_dir, exist_ok=True)
            _save_prediction_slices(test_inputs, test_labels, test_outputs,
                                    subject_id, prediction_dir, num_classes)

# Aggregate test metrics: one entry per foreground class (background excluded)
test_dice = test_dice_metric.aggregate().cpu().numpy()
test_hausdorff = test_hausdorff_metric.aggregate().cpu().numpy()

# Print per-class metrics
print("\nTest Metrics (per class):")
for class_idx, (dice_val, hd_val) in enumerate(zip(test_dice, test_hausdorff), start=1):
    print(f"Class {class_idx}: Dice = {dice_val:.4f}, HD = {hd_val:.4f}")

# Print average metrics. nanmean skips classes that were absent from every
# test case, for which the metric is NaN.
print(f"\nAverage Dice score: {np.nanmean(test_dice):.4f}")
print(f"Average Hausdorff Distance: {np.nanmean(test_hausdorff):.4f}")

# 5. Plot Training Curves
# Both panels use 1-based epoch numbers on the x-axis.

plt.figure(figsize=(12, 5))

# Plot training loss (train_losses[k] belongs to epoch k + 1)
plt.subplot(1, 2, 1)
plt.plot(range(1, len(train_losses) + 1), train_losses, label='Training Loss')
plt.title('Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

# Plot validation Dice score. Validation runs when (epoch + 1) % val_interval == 0,
# i.e. after epochs val_interval, 2 * val_interval, ..., so round i belongs
# to epoch (i + 1) * val_interval.
val_epochs = [(i + 1) * val_interval for i in range(len(val_metrics))]
plt.subplot(1, 2, 2)
plt.plot(val_epochs, val_metrics, label='Validation Dice')
plt.title('Validation Dice Score')
plt.xlabel('Epoch')
plt.ylabel('Dice Score')
plt.legend()

plt.tight_layout()
# savefig does not create missing directories
os.makedirs(config['visualization']['output_dir'], exist_ok=True)
plt.savefig(os.path.join(config['visualization']['output_dir'], 'training_curves.png'))
plt.show()

# 6. Save Final Model
# Bundle weights, optimizer state, learning curves, test metrics and the config
# into one checkpoint. At this point `model` holds the weights used for the test
# evaluation (normally the best-validation checkpoint loaded there). The
# optimizer state is from the last training step.
torch.save(
    dict(
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        train_losses=train_losses,
        val_metrics=val_metrics,
        test_dice=test_dice,
        test_hausdorff=test_hausdorff,
        best_metric=best_metric,
        best_metric_epoch=best_metric_epoch,
        config=config,
    ),
    os.path.join(config['training']['model_dir'], "final_model.pth"),
)

print("\nTraining completed. Final model saved.")

# 7. Create Learning Curves Table
# One row per epoch. 'Validation Dice' is NaN for epochs without a validation round.
results_df = pd.DataFrame({
    'Epoch': list(range(1, num_epochs + 1)),
    'Training Loss': train_losses,
})

# Add validation metrics. Validation round i (0-based) ran after 1-based epoch
# (i + 1) * val_interval, i.e. at row index (i + 1) * val_interval - 1. The index
# is computed here directly so it cannot silently wrap to row -1.
val_results = np.full(num_epochs, np.nan)
for round_idx, round_dice in enumerate(val_metrics):
    val_results[(round_idx + 1) * val_interval - 1] = round_dice
results_df['Validation Dice'] = val_results

# Save results to CSV
results_df.to_csv(os.path.join(config['training']['model_dir'], 'training_results.csv'), index=False)
print("Learning curves data saved to CSV.")

ImportError: cannot import name 'prepare_datalist' from 'src.data.dataset' (D:\KLU 4th YEAR\Projects\Brain_AI\src\data\dataset.py)