In [1]:
import os
import cv2
import random
import numpy as np
import pandas as pd
import nibabel as nib
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import zipfile
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
from scipy.ndimage import zoom
from torch.cuda.amp import autocast, GradScaler
# Let cuDNN benchmark its convolution algorithms and cache the fastest one per input shape.
# NOTE(review): this helps only when batch shapes stay constant across iterations — confirm for this data.
torch.backends.cudnn.benchmark = True


In [None]:
# step1
'''
Initializes the dataset for 3D U-Net
modalities: Tuple of 4 MRI types
resize_hw: Only height & width (128×128), depth (182) is untouched
only_tumor=True: if True, only includes patients with tumor (seg > 0)'''

class BraTS3DDataset(Dataset):
    """Per-patient 3D dataset: one sample is a full multi-modal volume plus its label map.

    Files are expected at ``<data_dir>/<patient>/<patient>-<modality>.nii.gz`` and
    ``<data_dir>/<patient>/<patient>-seg.nii.gz``.

    Args:
        data_dir: Root folder containing one sub-folder per patient.
        patients: Patient IDs to use. Defaults to every entry of ``data_dir``, sorted.
        modalities: Modality suffixes; they are stacked in this order as input channels.
        resize_hw: Target size ``(H, W)`` for array axes 1 and 2. Axis 0 is never resampled.
        only_tumor: If True, keep only patients whose segmentation contains a label > 0.
            If False, every patient is kept and no segmentation is read at construction time.

    NOTE(review): the surrounding notebook treats axis 0 as depth (D). nibabel returns the
    axes in on-disk order, so confirm that axis 0 really is the slice axis for this data.
    """

    def __init__(self, data_dir, patients=None, modalities=("t1n", "t1c", "t2w", "t2f"),
                 resize_hw=(128, 128), only_tumor=True):
        self.modalities = modalities
        self.resize_hw = resize_hw  # Only axes 1 and 2 are resized
        self.only_tumor = only_tumor
        self.data_dir = data_dir
        self.patients = sorted(os.listdir(self.data_dir)) if patients is None else patients

        # Scanning the masks means reading one volume per patient, so only do it when the
        # result is actually used for filtering.
        if self.only_tumor:
            self.filtered = [pat for pat in self.patients if self._has_tumor(pat)]
        else:
            self.filtered = list(self.patients)

    def __len__(self):
        """Number of retained patients (one sample = one full 3D volume)."""
        return len(self.filtered)

    def _seg_path(self, pat):
        """Path of the segmentation file of patient ``pat``."""
        return os.path.join(self.data_dir, pat, f"{pat}-seg.nii.gz")

    def _has_tumor(self, pat):
        """True if the patient's segmentation contains at least one voxel with label > 0."""
        seg = nib.load(self._seg_path(pat)).get_fdata()
        return bool(np.any(seg > 0))

    def _zoom_factors(self, shape):
        """Per-axis zoom factors mapping ``shape`` to ``(shape[0], H, W)``."""
        H, W = self.resize_hw
        return (1.0, H / shape[1], W / shape[2])

    @staticmethod
    def _zscore(img):
        """Z-score ``img`` with statistics of its non-zero voxels (all voxels if none are non-zero)."""
        mask = img != 0
        reference = img[mask] if mask.any() else img
        return (img - reference.mean()) / (reference.std() + 1e-8)

    def __getitem__(self, idx):
        """Load, normalize and resize one patient.

        Returns:
            x: float tensor of shape ``(len(modalities), D, H, W)``.
            y: long tensor of shape ``(D, H, W)`` holding the label map.
        """
        pat = self.filtered[idx]
        folder = os.path.join(self.data_dir, pat)

        vols = []
        for mod in self.modalities:
            img = nib.load(os.path.join(folder, f"{pat}-{mod}.nii.gz")).get_fdata().astype(np.float32)
            img = self._zscore(img)
            # order=3: cubic spline interpolation for smooth intensity resampling
            vols.append(zoom(img, self._zoom_factors(img.shape), order=3))

        x = np.stack(vols, axis=0)  # (C, D, H, W)

        seg = nib.load(self._seg_path(pat)).get_fdata()
        # get_fdata() returns floats; round before casting so a value such as 2.9999 cannot
        # silently become label 2.
        seg = np.rint(seg).astype(np.int64)
        # Factors come from the mask's own shape instead of a variable leaked from the
        # modality loop. order=0 (nearest neighbour) keeps labels intact.
        y = zoom(seg, self._zoom_factors(seg.shape), order=0)

        return torch.from_numpy(x).float(), torch.from_numpy(y).long()


# Dataset root. The default is the original cluster path; set BRATS_DATA_DIR to run elsewhere.
data_dir = os.environ.get(
    "BRATS_DATA_DIR",
    "/scratch/scai/mtech/aib232081/brain_tumor_detection/BraTS2024_dataset/BraTS2024-BraTS-GLI-TrainingData/training_data1_v2",
)
# Keep only sub-folders: a stray file in the root (e.g. a README or archive) would otherwise be
# treated as a patient ID and fail when its segmentation is opened.
all_patients = sorted(
    entry for entry in os.listdir(data_dir)
    if os.path.isdir(os.path.join(data_dir, entry))
)

# Two successive 90/10 splits -> about 81% train / 9% val / 10% test
# (for 1350 patients: 1093 / 122 / 135).
train_val, test = train_test_split(all_patients, test_size=0.1, random_state=42)
train, val = train_test_split(train_val, test_size=0.1, random_state=42)

# Datasets (each constructor scans its patients' masks to drop tumor-free cases)
train_ds = BraTS3DDataset(data_dir, patients=train, only_tumor=True)
val_ds   = BraTS3DDataset(data_dir, patients=val, only_tumor=True)
test_ds  = BraTS3DDataset(data_dir, patients=test, only_tumor=True)

# DataLoaders: pinned host memory + persistent worker processes; training also prefetches
train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=8,
                          pin_memory=True, persistent_workers=True, prefetch_factor=4)

val_loader   = DataLoader(val_ds, batch_size=4, shuffle=False, num_workers=4,
                          pin_memory=True, persistent_workers=True)

test_loader  = DataLoader(test_ds, batch_size=2, shuffle=False, num_workers=4,
                          pin_memory=True, persistent_workers=True)

'''
Small batch_size=4 (3D tensors consume a lot of GPU memory)
num_workers=8: use 8 worker processes to load data in parallel
persistent_workers=True: keeps the worker processes alive between epochs instead of re-spawning them
prefetch_factor=4: each worker keeps 4 batches loaded in advance → keeps GPU busy'''

In [None]:
# step2
class ConvBlock3D(nn.Module):
    """Two 3x3x3 convolution stages with a channel-wise dropout in between.

    Each stage is Conv3d -> BatchNorm3d -> ReLU. ``kernel_size=3`` with ``padding=1``
    leaves the spatial size unchanged; only the channel count goes from ``in_ch`` to ``out_ch``.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels of both convolutions.
        p: Dropout probability of the ``Dropout3d`` placed between the two stages.
    """

    def __init__(self, in_ch, out_ch, p=0.3):
        super().__init__()
        first_stage = [
            nn.Conv3d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_ch),
            nn.ReLU(inplace=True),
        ]
        second_stage = [
            nn.Conv3d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_ch),
            nn.ReLU(inplace=True),
        ]
        # Layer order (and therefore the state_dict keys block.0 ... block.6) is fixed:
        # stage 1, dropout, stage 2.
        self.block = nn.Sequential(*first_stage, nn.Dropout3d(p), *second_stage)

    def forward(self, x):
        """Apply the block to ``x`` of shape [B, in_ch, D, H, W]."""
        return self.block(x)

# in_ch=4: input channels = 4 modalities (t1n, t1c, t2w, t2f)
# base_ch=64: number of channels at the first layer, scaled up in deeper layers
# n_classes=5: output segmentation labels (0–4)
class UNet3D(nn.Module):
    """3D U-Net with three pooling levels, a bottleneck and three decoder levels.

    With ``c = base_ch`` the channel layout is:

        encoder:    in_ch -> c -> 2c -> 4c, each level followed by MaxPool3d(2)
        bottleneck: 4c -> 8c
        decoder:    (8c + 4c) -> 4c, (4c + 2c) -> 2c, (2c + c) -> c
        head:       1x1x1 convolution c -> n_classes (raw logits, no softmax)

    Upsampling is trilinear interpolation to the exact spatial size of the matching
    encoder output, so it does not change the channel count and odd sizes are restored
    exactly. The decoder inputs are therefore 12c, 6c and 3c channels wide.

    Args:
        in_ch: Number of input channels.
        base_ch: Channel count of the first encoder level.
        n_classes: Number of output classes.
    """

    def __init__(self, in_ch=4, base_ch=64, n_classes=5):
        super().__init__()
        c = base_ch

        # Encoder path: a conv block followed by a 2x max-pool at every level
        self.enc1 = ConvBlock3D(in_ch, c)
        self.pool1 = nn.MaxPool3d(2)

        self.enc2 = ConvBlock3D(c, c * 2)
        self.pool2 = nn.MaxPool3d(2)

        self.enc3 = ConvBlock3D(c * 2, c * 4)
        self.pool3 = nn.MaxPool3d(2)

        # Deepest level
        self.bottleneck = ConvBlock3D(c * 4, c * 8)

        # Decoder path: input width = upsampled features + skip connection
        self.dec3 = ConvBlock3D(c * 12, c * 4)  # 8c (bottleneck) + 4c (enc3)
        self.dec2 = ConvBlock3D(c * 6, c * 2)   # 4c (dec3) + 2c (enc2)
        self.dec1 = ConvBlock3D(c * 3, c)       # 2c (dec2) + c (enc1)

        # Per-voxel classifier
        self.out_conv = nn.Conv3d(c, n_classes, kernel_size=1)

    @staticmethod
    def _upsample_like(x, ref):
        """Trilinearly resize ``x`` to the spatial size (D, H, W) of ``ref``."""
        return F.interpolate(x, size=ref.shape[2:], mode='trilinear', align_corners=False)

    def forward(self, x):
        """Map ``x`` [B, in_ch, D, H, W] to logits [B, n_classes, D, H, W]."""
        # Encoder: keep every level's output for the skip connections
        skip1 = self.enc1(x)
        skip2 = self.enc2(self.pool1(skip1))
        skip3 = self.enc3(self.pool2(skip2))
        feat = self.bottleneck(self.pool3(skip3))

        # Decoder: upsample to the skip's size, concatenate on channels, refine
        for skip, decoder in ((skip3, self.dec3), (skip2, self.dec2), (skip1, self.dec1)):
            feat = decoder(torch.cat([self._upsample_like(feat, skip), skip], dim=1))

        return self.out_conv(feat)


In [None]:
# step3

def soft_dice_loss_3d(pred, target, epsilon=1e-6):
    """Soft Dice loss for 3D volumetric data.

    Args:
        pred: Raw logits of shape [B, C, D, H, W].
        target: Integer label map of shape [B, D, H, W] with values in [0, C).
        epsilon: Smoothing term added to numerator and denominator.

    Returns:
        Scalar tensor: 1 minus the soft Dice score averaged over every (sample, class) pair.
    """
    spatial_axes = (2, 3, 4)

    # Per-voxel class probabilities, [B, C, D, H, W]
    probs = F.softmax(pred, dim=1)

    # One-hot ground truth moved to channels-first so it lines up with ``probs``
    onehot = F.one_hot(target, num_classes=probs.shape[1])
    onehot = onehot.permute(0, 4, 1, 2, 3).float().to(probs.device)

    # Soft overlap and the sum of both "volumes" (not a union), per sample and class
    overlap = torch.sum(probs * onehot, dim=spatial_axes)
    total_mass = torch.sum(probs, dim=spatial_axes) + torch.sum(onehot, dim=spatial_axes)

    dice_scores = (2. * overlap + epsilon) / (total_mass + epsilon)  # [B, C]
    return 1 - dice_scores.mean()

def combined_3d_loss(pred, target, alpha=0.3, beta=0.7, background_weight=0.01):
    """Weighted sum of class-weighted cross-entropy and soft Dice loss.

    Args:
        pred: Raw logits of shape [B, C, D, H, W].
        target: Integer label map of shape [B, D, H, W].
        alpha: Weight of the cross-entropy term.
        beta: Weight of the Dice term.
        background_weight: Cross-entropy weight of class 0; every other class has weight 1.0.

    Returns:
        Tuple ``(total, ce, dice)``: ``total`` is the differentiable scalar tensor
        ``alpha * ce + beta * dice``; ``ce`` and ``dice`` are Python floats for logging.
    """
    # Build the weight vector from the logits' class dimension instead of a fixed 5-entry
    # list, so the loss works for any number of classes. For C=5 and the default
    # background_weight this is [0.01, 1, 1, 1, 1].
    weights = torch.ones(pred.shape[1], device=pred.device)
    weights[0] = background_weight

    # reduction='mean' with class weights divides by the sum of the weights of the target
    # voxels, so the mostly-background volume does not dominate the average.
    ce = F.cross_entropy(pred, target, weight=weights, reduction='mean')
    dice = soft_dice_loss_3d(pred, target)

    total = alpha * ce + beta * dice
    return total, ce.item(), dice.item()


In [None]:
# step4

import time

# === Metric Calculation ===
def compute_metrics_3d(pred, target, num_classes=5):
    """Per-class hard Dice and IoU of a batch of 3D predictions.

    Args:
        pred: Raw logits of shape [B, C, D, H, W].
        target: Integer label map of shape [B, D, H, W].
        num_classes: Number of classes C.

    Returns:
        Tuple ``(dice, iou)`` of numpy arrays of shape ``(num_classes,)``, each the
        batch average of the per-sample score.

    A class that is absent from both the prediction and the target of a sample scores 1.0.
    Both are empty, which is a perfect match, and this is the same smoothing convention as
    ``soft_dice_loss_3d``. Previously such a class scored 0 and pulled the averages down
    whenever a label did not occur in a volume.
    """
    eps = 1e-6
    spatial_axes = (2, 3, 4)

    with torch.no_grad():
        # argmax over logits equals argmax over softmax probabilities, so skip the softmax
        pred_labels = pred.argmax(dim=1)

        one_hot_pred = F.one_hot(pred_labels, num_classes).permute(0, 4, 1, 2, 3).float()
        one_hot_target = F.one_hot(target, num_classes).permute(0, 4, 1, 2, 3).float()
        one_hot_target = one_hot_target.to(one_hot_pred.device)

        intersect = (one_hot_pred * one_hot_target).sum(dim=spatial_axes)
        pred_sum = one_hot_pred.sum(dim=spatial_axes)
        target_sum = one_hot_target.sum(dim=spatial_axes)
        union = pred_sum + target_sum - intersect

        # eps in numerator AND denominator: empty-vs-empty evaluates to 1 instead of 0
        dice = (2 * intersect + eps) / (pred_sum + target_sum + eps)
        iou = (intersect + eps) / (union + eps)

    return dice.mean(dim=0).cpu().numpy(), iou.mean(dim=0).cpu().numpy()
        # Returns per-class average Dice and IoU (across the batch)

# === Model, Optimizer, and Scaler Setup ===
model = UNet3D(in_ch=4, base_ch=64, n_classes=5).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
scaler = GradScaler()  # dynamic loss scaling for mixed-precision (autocast) training
# Halve the learning rate when the validation loss has not improved for 2 epochs
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

# === Logging ===
log_path = "training_logs_3d_unet_trial24b.csv"
checkpoint_path = "checkpoint_3d_unet_trial24b.pth"
history = []

best_val_loss = float('inf')
epochs = 50

for epoch in range(epochs):
    # ---- Training phase ----
    model.train()
    start_time = time.time()

    total_train_loss = 0
    total_ce, total_dice = 0, 0
    dice_train, iou_train = [], []

    for xb, yb in tqdm(train_loader, desc=f"Epoch {epoch+1} - Train"):
        # Asynchronous host-to-device copies (non_blocking only takes effect when the batch
        # lives in pinned host memory)
        xb, yb = xb.cuda(non_blocking=True), yb.cuda(non_blocking=True)

        with autocast():  # forward pass and loss in mixed precision
            pred = model(xb)
            loss, ce_val, dice_val_ = combined_3d_loss(pred, yb)

        # AMP step: scale the loss, backprop, unscale + step, then adapt the scale factor
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_train_loss += loss.item()
        total_ce += ce_val
        total_dice += dice_val_

        # Metrics are logging-only: detach so they are never tied to the autograd graph
        d, i = compute_metrics_3d(pred.detach(), yb)
        dice_train.append(d)
        iou_train.append(i)

    dice_train = np.array(dice_train)  # [n_batches, n_classes]
    iou_train = np.array(iou_train)

    # ---- Validation phase ----
    model.eval()
    total_val_loss = 0
    val_ce, val_dice = 0, 0
    dice_val, iou_val = [], []

    with torch.no_grad():
        for xb, yb in tqdm(val_loader, desc=f"Epoch {epoch+1} - Val"):
            xb, yb = xb.cuda(non_blocking=True), yb.cuda(non_blocking=True)

            with autocast():
                pred = model(xb)
                loss, ce_val, dice_val_ = combined_3d_loss(pred, yb)

            total_val_loss += loss.item()
            val_ce += ce_val
            val_dice += dice_val_

            d, i = compute_metrics_3d(pred, yb)
            dice_val.append(d)
            iou_val.append(i)

    dice_val = np.array(dice_val)  # [n_batches, n_classes]
    iou_val = np.array(iou_val)

    epoch_time = time.time() - start_time
    lr = optimizer.param_groups[0]['lr']  # rate used during this epoch (read before scheduler.step)

    # Means over all batches AND all classes (class 0 / background included)
    mean_dice_train = dice_train.mean()
    mean_dice_val = dice_val.mean()
    mean_iou_train = iou_train.mean()
    mean_iou_val = iou_val.mean()

    # The scheduler and model selection both use the validation loss summed over batches
    scheduler.step(total_val_loss)

    if total_val_loss < best_val_loss:
        best_val_loss = total_val_loss
        torch.save(model.state_dict(), "best_3d_unet_model_trial24b.pth")

    # Full training state, overwritten every epoch. Scheduler and GradScaler state are
    # included so a resumed run continues with the same LR-plateau bookkeeping and loss
    # scale instead of resetting them.
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'scaler_state_dict': scaler.state_dict(),
        'best_val_loss': best_val_loss
    }, checkpoint_path)

    # NOTE(review): the per-class columns below assume label 1=ET, 2=NETC, 3=SNFH, 4=RC.
    # The BraTS 2024 post-treatment glioma description appears to list 1=NETC, 2=SNFH,
    # 3=ET, 4=RC — confirm against the dataset before interpreting these columns.
    # 'train_loss' / 'val_loss' are sums over batches; the *_ce / *_dice columns are
    # per-batch means.
    history.append({
        'epoch': epoch + 1,
        'train_loss': total_train_loss,
        'val_loss': total_val_loss,
        'mean_dice_train': mean_dice_train,
        'mean_dice_val': mean_dice_val,
        'mean_iou_train': mean_iou_train,
        'mean_iou_val': mean_iou_val,
        'dice_ET_train': dice_train[:,1].mean(),
        'dice_NETC_train': dice_train[:,2].mean(),
        'dice_SNFH_train': dice_train[:,3].mean(),
        'dice_RC_train': dice_train[:,4].mean(),
        'dice_ET_val': dice_val[:,1].mean(),
        'dice_NETC_val': dice_val[:,2].mean(),
        'dice_SNFH_val': dice_val[:,3].mean(),
        'dice_RC_val': dice_val[:,4].mean(),
        'epoch_time_sec': epoch_time,
        'lr': lr,
        'train_ce': total_ce / len(train_loader),
        'train_dice': total_dice / len(train_loader),
        'val_ce': val_ce / len(val_loader),
        'val_dice': val_dice / len(val_loader),
    })

    # Rewrite the whole log each epoch so an interrupted run still leaves a complete CSV
    pd.DataFrame(history).to_csv(log_path, index=False)
    print(f"Epoch {epoch+1} logged. Dice: Train={mean_dice_train:.4f}, Val={mean_dice_val:.4f}")

print("Training complete.")