In [None]:
import os
import cv2
import random
import numpy as np
import pandas as pd
import nibabel as nib
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import zipfile
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
from scipy.ndimage import zoom
from torch.cuda.amp import autocast, GradScaler
torch.backends.cudnn.benchmark = True


In [None]:

# 2.5D input: 9 adjacent slices (4 before, the centre, 4 after) per sample — deeper context than trial18's 3.
class BraTSDataset(Dataset):
    """2.5D slice dataset over BraTS-style NIfTI volumes.

    Each sample is one axial slice index ``z`` of one patient. The input stacks
    ``patch_depth`` neighbouring slices (centred on ``z``) of every modality along the
    channel axis; the target is the label map of slice ``z``.

    Expected layout: ``data_dir/<patient>/<patient>-<modality>.nii.gz`` and
    ``data_dir/<patient>/<patient>-seg.nii.gz``.

    Args:
        data_dir: root directory with one sub-directory per patient.
        patients: patient IDs to use; defaults to every entry of ``data_dir`` (sorted).
        modalities: modality file suffixes to stack, in channel order.
        patch_depth: number of adjacent slices per sample. Window slices that fall outside
            the volume repeat the first/last slice (edge replication). For an even depth the
            window has one more slice before ``z`` than after it.
        target_size: output ``(H, W)`` of every slice.
        only_tumor_slices: if True, keep only slices whose label map has a non-zero voxel
            (drops background-only slices to reduce class imbalance).

    ``__getitem__`` returns ``(x, y)``:
        x: float32 tensor ``(patch_depth * len(modalities), H, W)``, slice-major
           (all modalities of the first window slice, then the next slice, ...).
        y: int64 tensor ``(H, W)`` with the label map of the centre slice.
    """

    def __init__(self, data_dir, patients=None, modalities=("t1n", "t1c", "t2w", "t2f"),
                 patch_depth=9, target_size=(128, 128), only_tumor_slices=True):
        self.modalities = modalities
        self.patch_depth = patch_depth
        self.half = patch_depth // 2
        self.target_size = target_size
        self.only_tumor_slices = only_tumor_slices
        self.data_dir = data_dir
        self.patients = sorted(os.listdir(self.data_dir)) if patients is None else patients

        # Flat (patient, slice) index built once; label volumes are only read here.
        self.index_map = []
        for pat in self.patients:
            seg = nib.load(self._path(pat, "seg")).get_fdata()
            if self.only_tumor_slices:
                # Axial slices containing at least one labelled (non-zero) voxel.
                slices = np.flatnonzero((seg > 0).any(axis=(0, 1)))
            else:
                slices = range(seg.shape[2])
            self.index_map.extend((pat, int(z)) for z in slices)

    def _path(self, pat, suffix):
        """Return the NIfTI path ``data_dir/<pat>/<pat>-<suffix>.nii.gz``."""
        return os.path.join(self.data_dir, pat, f"{pat}-{suffix}.nii.gz")

    def _window(self, z, depth):
        """Slice indices of the window around ``z``, clamped to ``[0, depth - 1]``.

        Clamping reproduces edge padding: out-of-range positions reuse the first/last slice.
        """
        offsets = np.arange(self.patch_depth) - self.half
        return np.clip(z + offsets, 0, depth - 1)

    def __len__(self):
        return len(self.index_map)

    def __getitem__(self, idx):
        pat, z = self.index_map[idx]
        H, W = self.target_size

        # patch[s, m] = window slice s of modality m, resized to (H, W).
        patch = np.empty((self.patch_depth, len(self.modalities), H, W), dtype=np.float32)
        for m, mod in enumerate(self.modalities):
            img = nib.load(self._path(pat, mod)).get_fdata().astype(np.float32)
            # Z-score with statistics of the non-zero voxels of the whole volume;
            # fall back to all voxels for an all-zero volume.
            mask = img != 0
            if mask.any():
                mu, sigma = img[mask].mean(), img[mask].std()
            else:
                mu, sigma = img.mean(), img.std()
            # Only the slices that enter the sample are normalised and resized.
            for s, zi in enumerate(self._window(z, img.shape[2])):
                slice_norm = (img[:, :, zi] - mu) / (sigma + 1e-8)
                patch[s, m] = cv2.resize(slice_norm, (W, H), interpolation=cv2.INTER_CUBIC)
        x = patch.reshape(self.patch_depth * len(self.modalities), H, W)

        # Target: centre-slice labels, resized with nearest neighbour so labels stay integral.
        # (Cast to int32 explicitly: OpenCV has no int64 resize path.)
        seg = nib.load(self._path(pat, "seg")).get_fdata()
        y = cv2.resize(seg[:, :, z].astype(np.int32), (W, H),
                       interpolation=cv2.INTER_NEAREST).astype(np.int64)

        return torch.from_numpy(x).float(), torch.from_numpy(y).long()

# === Data Setup ===
# Dataset root; set BRATS_DATA_DIR to point elsewhere instead of editing the notebook.
data_dir = os.environ.get(
    "BRATS_DATA_DIR",
    "/scratch/scai/mtech/aib232081/brain_tumor_detection/BraTS2024_dataset/BraTS2024-BraTS-GLI-TrainingData/training_data1_v2",
)
all_patients = sorted(os.listdir(data_dir))  # one sub-directory per patient (1350 patients)

# Patient-level ~80/10/10 split, so no slice of a val/test patient is ever trained on.
# test_size=0.1 holds out 10% of all patients; 0.1111 (~1/9) of the remaining 90% is another ~10%.
# train_test_split rounds the held-out count up: 1350 patients -> 1080 train / 135 val / 135 test.
SPLIT_SEED = 42
train_val_patients, test_patients = train_test_split(all_patients, test_size=0.1, random_state=SPLIT_SEED)
train_patients, val_patients = train_test_split(train_val_patients, test_size=0.1111, random_state=SPLIT_SEED)
assert not (set(train_patients) & set(val_patients)), "train/val patients overlap"
assert not (set(train_val_patients) & set(test_patients)), "train/test patients overlap"

PATCH_DEPTH = 9   # slices per 2.5D sample
BATCH_SIZE = 128

# Datasets
# NOTE(review): val/test also use only_tumor_slices=True, so their metrics ignore
# background-only slices (false positives there are never measured).
train_ds = BraTSDataset(data_dir, patients=train_patients, patch_depth=PATCH_DEPTH, only_tumor_slices=True)
val_ds   = BraTSDataset(data_dir, patients=val_patients, patch_depth=PATCH_DEPTH, only_tumor_slices=True)
test_ds  = BraTSDataset(data_dir, patients=test_patients, patch_depth=PATCH_DEPTH, only_tumor_slices=True)

# DataLoaders. Training uses 10 worker processes kept alive across epochs
# (persistent_workers) that each prefetch 4 batches, to keep the GPU fed.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=10, pin_memory=True,
                          persistent_workers=True, prefetch_factor=4)
val_loader   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True)
test_loader  = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True)


## Comparison with Trial 18
| Feature                     | Trial18 (baseline)  | **Trial23**                                       |
| --------------------------- | ------------------- | ------------------------------------------------- |
| Depth of Input              | 3 slices            | **9 slices** (richer context)                     |
| Encoder Blocks              | Simple Conv-BN-ReLU | **Residual Blocks**                               |
| Skip Connection Enhancement | None                | **Attention Gates** (AG)                          |
| Normalization Type          | BatchNorm           | **GroupNorm** (more stable for small batch sizes) |
| Decoder Path                | Normal Conv + Up    | **Attention + Residual Decoder**                  |


In [None]:

class ResidualConvBlock(nn.Module):
    """Two-convolution residual block with a learned 1x1 projection shortcut.

    Output is ``ReLU(F(x) + P(x))`` where
      * ``F`` = Conv3x3 -> GroupNorm -> ReLU -> Conv3x3 -> GroupNorm, and
      * ``P`` = a 1x1 convolution mapping ``in_ch`` to ``out_ch`` (used even when the
        channel counts are equal, so the shortcut is never a plain identity).

    Padding keeps the spatial size unchanged. ``out_ch`` must be divisible by 8, the
    GroupNorm group count.

    Why: the shortcut path eases gradient flow through deep encoder/decoder stacks.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = [
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.GroupNorm(8, out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.GroupNorm(8, out_ch),
        ]
        # Attribute names/order are kept stable: saved state_dicts use these keys.
        self.conv = nn.Sequential(*layers)
        self.relu = nn.ReLU(inplace=True)
        self.residual = nn.Conv2d(in_ch, out_ch, kernel_size=1)

    def forward(self, x):
        shortcut = self.residual(x)
        return self.relu(self.conv(x) + shortcut)

class AttentionGate(nn.Module):
    """Additive attention gate that re-weights encoder skip features.

    ``forward(g, x)`` returns ``x * sigmoid(psi(ReLU(W_g(g) + W_x(x))))``:
    both inputs are projected to ``F_int`` channels (1x1 conv + GroupNorm), summed,
    passed through ReLU, then reduced to a single-channel sigmoid map that scales
    every channel of ``x``. ``g`` and ``x`` are expected to share spatial size.

    Args:
        F_g: channels of the gating signal ``g`` (upsampled decoder features).
        F_l: channels of the skip features ``x`` (encoder output).
        F_int: channels of the intermediate projection; must be divisible by 8 (GroupNorm groups).

    Why: lets the decoder suppress skip activations outside the regions it is currently
    attending to, instead of passing all early-layer features through unchanged.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()
        # Module names/order are kept stable: saved state_dicts use these keys.
        self.W_g = nn.Sequential(nn.Conv2d(F_g, F_int, kernel_size=1), nn.GroupNorm(8, F_int))
        self.W_x = nn.Sequential(nn.Conv2d(F_l, F_int, kernel_size=1), nn.GroupNorm(8, F_int))
        self.psi = nn.Sequential(nn.Conv2d(F_int, 1, kernel_size=1), nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        attention = self.psi(self.relu(self.W_g(g) + self.W_x(x)))
        return x * attention

class UNet2p5D_AttentionResidual(nn.Module):
    """2.5D U-Net with residual conv blocks and attention-gated skip connections.

    Encoder: three ResidualConvBlock stages (``base_ch``, x2, x4 channels), each followed
    by 2x max-pooling, then a bottleneck block at ``base_ch * 8`` channels.
    Decoder: three stages, each upsampling with a stride-2 transposed conv, gating the
    matching encoder output with an AttentionGate, concatenating
    ``[upsampled, gated skip]`` and fusing with a ResidualConvBlock.
    A final 1x1 conv yields logits of shape ``(B, n_classes, H, W)``.

    Args:
        in_ch: input channels (stacked slices x modalities; 36 = 9 x 4 in this notebook).
        base_ch: channels of the first encoder stage; must keep ``base_ch // 2`` divisible by 8
            for the GroupNorm layers.
        n_classes: number of output logit channels.

    Input height and width must be divisible by 8 (three 2x poolings), otherwise the
    upsampled maps do not line up with their skip connections.
    """

    def __init__(self, in_ch=36, base_ch=64, n_classes=5):
        super().__init__()
        c1, c2, c3, c4 = base_ch, base_ch * 2, base_ch * 4, base_ch * 8

        # Encoder (module names/order kept stable for checkpoint compatibility).
        self.enc1 = ResidualConvBlock(in_ch, c1)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = ResidualConvBlock(c1, c2)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = ResidualConvBlock(c2, c3)
        self.pool3 = nn.MaxPool2d(2)
        self.bottleneck = ResidualConvBlock(c3, c4)

        # Decoder: upsample -> gate the skip -> fuse the concatenation (2 * channels in).
        self.up3 = nn.ConvTranspose2d(c4, c3, kernel_size=2, stride=2)
        self.att3 = AttentionGate(F_g=c3, F_l=c3, F_int=c2)
        self.dec3 = ResidualConvBlock(c4, c3)

        self.up2 = nn.ConvTranspose2d(c3, c2, kernel_size=2, stride=2)
        self.att2 = AttentionGate(F_g=c2, F_l=c2, F_int=c1)
        self.dec2 = ResidualConvBlock(c3, c2)

        self.up1 = nn.ConvTranspose2d(c2, c1, kernel_size=2, stride=2)
        self.att1 = AttentionGate(F_g=c1, F_l=c1, F_int=base_ch // 2)
        self.dec1 = ResidualConvBlock(c2, c1)

        self.out = nn.Conv2d(c1, n_classes, kernel_size=1)

    def forward(self, x):
        # Encoder: remember each stage's pre-pooling output as a skip connection.
        skips = []
        h = x
        for encode, pool in ((self.enc1, self.pool1), (self.enc2, self.pool2), (self.enc3, self.pool3)):
            h = encode(h)
            skips.append(h)
            h = pool(h)
        h = self.bottleneck(h)

        # Decoder: deepest skip first.
        stages = (
            (self.up3, self.att3, self.dec3),
            (self.up2, self.att2, self.dec2),
            (self.up1, self.att1, self.dec1),
        )
        for (upsample, gate, decode), skip in zip(stages, reversed(skips)):
            g = upsample(h)
            h = decode(torch.cat([g, gate(g, skip)], dim=1))

        return self.out(h)


In [None]:
# Loss: alpha * class-weighted cross-entropy + beta * Generalized Dice Loss (GDL),
# weighting Dice more heavily (alpha=0.3, beta=0.7) to counter class imbalance.
# The GDL term uses a temperature-scaled softmax, softmax(logits / T): T < 1 sharpens the
# distribution so the Dice term focuses on confident predictions.


def generalized_dice_loss(pred, target, epsilon=1e-6, temperature=0.7):
    """Generalized Dice Loss on temperature-scaled softmax probabilities.

    Args:
        pred: logits of shape (B, C, H, W).
        target: integer labels of shape (B, H, W) with values in [0, C).
        epsilon: added to the final denominator to guard against division by zero.
        temperature: logits are divided by this before the softmax.

    Returns:
        Scalar tensor ``1 - mean over batch of the weighted Dice score``.

    Per sample, class ``c`` gets weight ``1 / volume_c**2`` (volume = ground-truth pixel
    count), so large classes such as background contribute little. A class absent from a
    sample's ground truth gets the largest weight among that sample's present classes
    (MONAI's convention). The previous ``1 / epsilon`` weight (1e6) made the absent-class
    term dominate the denominator, which pushed Dice to ~0 (loss to ~1) as soon as any
    probability mass landed on a class missing from the slice.
    """
    probs = F.softmax(pred / temperature, dim=1)
    one_hot = F.one_hot(target, num_classes=probs.shape[1]).permute(0, 3, 1, 2).to(probs.dtype)

    volume = one_hot.sum(dim=(2, 3))  # (B, C) ground-truth pixel count per class
    present = volume > 0
    weights = torch.where(present, 1.0 / volume.clamp_min(1.0) ** 2, torch.zeros_like(volume))
    # Every pixel carries a label, so each sample has at least one present class.
    max_weight = weights.max(dim=1, keepdim=True).values
    weights = torch.where(present, weights, max_weight.expand_as(weights))

    intersect = (probs * one_hot).sum(dim=(2, 3))
    union = probs.sum(dim=(2, 3)) + volume
    dice = 2 * (weights * intersect).sum(dim=1) / ((weights * union).sum(dim=1) + epsilon)

    return 1 - dice.mean()


def combined_loss(pred, target, alpha=0.3, beta=0.7, temperature=0.7, background_weight=0.01):
    """Weighted cross-entropy plus Generalized Dice Loss.

    Args:
        pred: logits of shape (B, C, H, W).
        target: integer labels of shape (B, H, W).
        alpha, beta: weights of the CE and GDL terms.
        temperature: softmax temperature passed to the GDL term (CE uses raw logits).
        background_weight: CE weight of class 0; every other class has weight 1.

    Returns:
        ``(total, ce, gdl)``: ``total`` is the tensor to back-propagate; ``ce`` and ``gdl``
        are Python floats for logging.
    """
    # One weight per output channel (so any number of classes works); down-weight background.
    class_weights = torch.ones(pred.shape[1], device=pred.device)
    class_weights[0] = background_weight

    ce = F.cross_entropy(pred, target, weight=class_weights, reduction='mean')
    gdl = generalized_dice_loss(pred, target, temperature=temperature)

    total = alpha * ce + beta * gdl
    return total, ce.item(), gdl.item()


In [None]:
# Import model and loss
# from model import UNet2p5D_AttentionResidual
# from loss import combined_loss
import time

def compute_metrics(pred, target, num_classes=5):
    """Per-class Dice and IoU of hard (argmax) predictions, averaged over the batch.

    Args:
        pred: logits of shape (B, C, H, W); the predicted label is the argmax over C.
        target: integer labels of shape (B, H, W).
        num_classes: number of label classes (one-hot width).

    Returns:
        ``(dice, iou)``: NumPy arrays of shape ``(num_classes,)``, each the batch mean of
        the per-sample scores.

    A class absent from both the prediction and the target of a sample scores 1.0 for that
    sample (nothing missed, no false positives) instead of 0.0. With 0.0, rare classes were
    reported near zero even for a perfect model.
    """
    with torch.no_grad():  # metrics only; never build an autograd graph
        # Softmax is monotonic, so the argmax of the logits gives the same labels.
        pred_labels = pred.argmax(dim=1)
        one_hot_pred = F.one_hot(pred_labels, num_classes).permute(0, 3, 1, 2).float()
        one_hot_target = F.one_hot(target.to(pred_labels.device), num_classes).permute(0, 3, 1, 2).float()

        intersect = (one_hot_pred * one_hot_target).sum(dim=(2, 3))
        pred_sum = one_hot_pred.sum(dim=(2, 3))
        target_sum = one_hot_target.sum(dim=(2, 3))
        both_empty = (pred_sum + target_sum) == 0
        ones = torch.ones_like(intersect)

        dice = torch.where(both_empty, ones, 2 * intersect / (pred_sum + target_sum + 1e-6))
        union = pred_sum + target_sum - intersect
        iou = torch.where(both_empty, ones, intersect / (union + 1e-6))

    return dice.mean(dim=0).cpu().numpy(), iou.mean(dim=0).cpu().numpy()

# ==== Model, optimizer, mixed-precision scaler and LR schedule ====
torch.manual_seed(42)  # reproducible weight init (cudnn.benchmark may still select nondeterministic kernels)

# Input channels = slices per sample x modalities (36 = 9 x 4 for the datasets built above).
in_channels = train_ds.patch_depth * len(train_ds.modalities)
model = UNet2p5D_AttentionResidual(in_ch=in_channels, base_ch=64, n_classes=5).cuda()

optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
scaler = GradScaler()  # loss scaling for fp16 autocast
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)
checkpoint_path = "checkpoint_trial23.pth"
log_path = "training_logs_trial23.csv"
best_model_path = "unet2p5d_best_trial23.pth"

# Per-class Dice column name -> label index in the segmentation map.
# NOTE(review): follows the BraTS 2024 GLI label convention (1=NETC, 2=SNFH, 3=ET, 4=RC).
# Earlier runs logged index 1 as ET, 2 as NETC, 3 as SNFH; confirm against the dataset docs.
CLASS_LABELS = {'ET': 3, 'NETC': 1, 'SNFH': 2, 'RC': 4}

# ==== Resume: training history ====
# train_loss / val_loss are per-batch means (logs written before this change stored per-epoch sums).
start_epoch = 0
if os.path.exists(log_path):
    history_df = pd.read_csv(log_path)
    history = history_df.to_dict(orient='list')
    # 'epoch' is logged 1-based, so the largest logged value is the next 0-based epoch index.
    start_epoch = int(history_df['epoch'].max()) if len(history_df) else 0
    print(f"Resuming from epoch {start_epoch + 1}")
else:
    history = {
        'epoch': [], 'train_loss': [], 'val_loss': [],
        'mean_dice_train': [], 'mean_dice_val': [],
        'mean_iou_train': [], 'mean_iou_val': [],
        'dice_ET_train': [], 'dice_NETC_train': [], 'dice_SNFH_train': [], 'dice_RC_train': [],
        'dice_ET_val': [], 'dice_NETC_val': [], 'dice_SNFH_val': [], 'dice_RC_val': [],
        'epoch_time_sec': [], 'lr': [], 'train_ce': [], 'train_gdl': [], 'val_ce': [], 'val_gdl': []
    }
    print("Starting new training")

# ==== Resume: model / optimizer / scheduler / scaler / early-stopping state ====
best_val_loss = float('inf')
early_stop_counter = 0
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # Older checkpoints lack these entries; keep the freshly initialised state then.
    if 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    if 'scaler_state_dict' in checkpoint:
        scaler.load_state_dict(checkpoint['scaler_state_dict'])
    early_stop_counter = checkpoint.get('early_stop_counter', 0)
    best_val_loss = checkpoint['best_val_loss']
    if checkpoint.get('loss_reduction') != 'mean':
        # Older checkpoints stored the SUM of per-batch val losses; convert to the per-batch mean.
        best_val_loss /= len(val_loader)
    start_epoch = checkpoint.get('epoch', 0) + 1
    print(f"Loaded checkpoint with best_val_loss = {best_val_loss:.4f}")


def _run_epoch(loader, train, desc=None):
    """Run one pass over ``loader`` with the global model; step the optimizer when ``train``.

    Returns ``(mean_loss, mean_ce, mean_gdl, dice, iou)``: the losses are per-batch means;
    ``dice`` / ``iou`` are arrays of shape (num_batches, n_classes) from ``compute_metrics``.
    """
    model.train(train)
    loss_sum = ce_sum = gdl_sum = 0.0
    dice_rows, iou_rows = [], []
    batches = tqdm(loader, desc=desc) if desc else loader
    with torch.set_grad_enabled(train):
        for xb, yb in batches:
            xb = xb.cuda(non_blocking=True)
            yb = yb.cuda(non_blocking=True)
            with autocast():
                pred = model(xb)
                loss, ce_val, gdl_val = combined_loss(pred, yb, alpha=0.3, beta=0.7, temperature=0.7)
            if train:
                optimizer.zero_grad()
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            loss_sum += loss.item()
            ce_sum += ce_val
            gdl_sum += gdl_val
            d, i = compute_metrics(pred.detach(), yb)
            dice_rows.append(d)
            iou_rows.append(i)
    n_batches = max(len(dice_rows), 1)
    return (loss_sum / n_batches, ce_sum / n_batches, gdl_sum / n_batches,
            np.array(dice_rows), np.array(iou_rows))


# ==== Training ====
num_epochs = 50
early_stop_patience = 6  # stop after this many consecutive epochs without val-loss improvement

for epoch in range(start_epoch, num_epochs):
    # Checked first so the triggering epoch is still checkpointed and logged, and a resumed
    # run that had already triggered early stopping stops immediately.
    if early_stop_counter >= early_stop_patience:
        print(f"Early stopping at epoch {epoch}")
        break

    start_time = time.time()
    train_loss, train_ce, train_gdl, dice_all, iou_all = _run_epoch(
        train_loader, train=True, desc=f"Training Epoch {epoch+1}")
    val_loss, val_ce, val_gdl, val_dice_all, val_iou_all = _run_epoch(val_loader, train=False)

    # Means over batches and all classes (background included).
    mean_dice_train, mean_iou_train = dice_all.mean(), iou_all.mean()
    mean_dice_val, mean_iou_val = val_dice_all.mean(), val_iou_all.mean()

    epoch_time = time.time() - start_time
    lr = optimizer.param_groups[0]['lr']  # LR used during this epoch (before the scheduler step)
    print(f"Epoch {epoch+1} -- Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
    print(f"Mean Dice: Train={mean_dice_train:.4f}, Val={mean_dice_val:.4f}")

    scheduler.step(val_loss)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        early_stop_counter = 0
        torch.save(model.state_dict(), best_model_path)
        print("Saved best model.")
    else:
        early_stop_counter += 1

    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'scaler_state_dict': scaler.state_dict(),
        'early_stop_counter': early_stop_counter,
        'best_val_loss': best_val_loss,
        'loss_reduction': 'mean',
    }, checkpoint_path)

    history['epoch'].append(epoch + 1)
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['train_ce'].append(train_ce)
    history['train_gdl'].append(train_gdl)
    history['val_ce'].append(val_ce)
    history['val_gdl'].append(val_gdl)
    history['mean_dice_train'].append(float(mean_dice_train))
    history['mean_dice_val'].append(float(mean_dice_val))
    history['mean_iou_train'].append(float(mean_iou_train))
    history['mean_iou_val'].append(float(mean_iou_val))
    for class_name, label in CLASS_LABELS.items():
        history[f'dice_{class_name}_train'].append(float(dice_all[:, label].mean()))
        history[f'dice_{class_name}_val'].append(float(val_dice_all[:, label].mean()))
    history['epoch_time_sec'].append(epoch_time)
    history['lr'].append(lr)

    pd.DataFrame(history).to_csv(log_path, index=False)
    print(f"Epoch {epoch+1} completed and logged.\n")

print("Training complete.")
