In [1]:
import os
import random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from pathlib import Path
import nibabel as nib
from tqdm.auto import tqdm
from skimage import measure
import gc

# --- CONFIGURATION ---
# DEBUG=True shortens the run: CONFIG["epochs"] becomes 1 and every dataset is
# built with debug=True (restricted to the first two patient ids).
DEBUG = False  # keep False for the full (overnight) run

# Central settings dictionary read by the training / evaluation code below.
CONFIG = {
    "experiment_name": "Exp08_Final_5Fold_CV",  # label only; not read elsewhere in the visible code
    "image_size": 256,                          # side length (pixels) used when resizing slices for evaluation
    "batch_size": 12,                           # passed to both DataLoaders
    "num_workers": 0,                           # DataLoader workers (0 = load in the main process)
    "learning_rate": 1e-4,                      # AdamW learning rate
    "epochs": 1 if DEBUG else 100,              # 1 epoch for debug, 100 for the real run
    "early_stopping_patience": 20,              # epochs without validation improvement before stopping
    "seed": 42,                                 # passed to set_seed()
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "manifest_path": "../data/UT-EndoMRI/D2_Half_Split/d2_half_split_manifest.csv",  # CSV read with pandas
    "data_root": "../data/UT-EndoMRI/D2_Half_Split",  # directory holding the NIfTI files
    "save_dir": "../models/final_5fold_results"       # checkpoints and final_results.csv are written here
}

# Create the output directory up front (including parents) so later saves cannot fail on a missing path.
Path(CONFIG["save_dir"]).mkdir(parents=True, exist_ok=True)

# Report the selected device and make an accidental debug run obvious in the log.
print(f"Running in {CONFIG['device']} mode.")
if DEBUG:
    print("WARNING: DEBUG MODE IS ON. Training will be fake.")

# --- UTILS ---
def set_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs with the same value.

    When CUDA is available, all CUDA generators are seeded as well and cuDNN
    is switched to deterministic mode with benchmarking disabled.

    Args:
        seed: Integer seed applied to every generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    # Nothing GPU-specific to configure on CPU-only machines.
    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Seed every RNG once, before any dataset or model below is constructed.
set_seed(CONFIG["seed"])

# --- DATASET ---
class OvaryDataset(Dataset):
    """Dataset of 2D slices taken from paired image / mask NIfTI volumes.

    Each manifest row supplies a ``pid``, a ``t2fs_path`` (image) and an
    ``ov_path`` (mask). Only the file *names* of those paths are used; they are
    resolved inside ``root_dir``. At construction time every mask volume is
    scanned and one sample is registered per slice (last axis) whose mask has a
    positive sum.

    Args:
        root_dir: Directory containing the NIfTI files.
        df: Manifest DataFrame with ``pid``, ``t2fs_path`` and ``ov_path`` columns.
        image_size: Side length of the square output tensors.
        augment: If True, apply a random rotation and translation per item.
        debug: If True, keep only the rows of the first two patient ids.

    Attributes:
        samples: List of dicts with keys ``img_path``, ``msk_path``,
            ``slice_idx`` and ``pid``; one entry per usable slice.
    """

    def __init__(self, root_dir, df, image_size=256, augment=False, debug=False):
        import warnings  # local import keeps this block self-contained

        self.root_dir = Path(root_dir)
        self.df = df.reset_index(drop=True)

        # In DEBUG mode, slice the dataframe to just 2 patients to be fast
        if debug:
            unique_pids = self.df['pid'].unique()[:2]
            self.df = self.df[self.df['pid'].isin(unique_pids)].reset_index(drop=True)

        self.image_size = image_size
        self.augment = augment
        self.samples = []

        # Pre-scan for valid slices (where mask > 0).
        # This loads every mask volume once, so it can take a while per fold.
        for _, row in self.df.iterrows():
            img_p = self.root_dir / Path(row['t2fs_path']).name
            msk_p = self.root_dir / Path(row['ov_path']).name

            # Rows whose image is not present under root_dir are skipped quietly.
            if not img_p.exists():
                continue

            try:
                msk_vol = nib.load(str(msk_p)).get_fdata()
                # Sum over the two in-plane axes -> one value per slice.
                z_sums = np.sum(msk_vol, axis=(0, 1))
                valid_slices = np.where(z_sums > 0)[0]

                for z in valid_slices:
                    self.samples.append({
                        'img_path': str(img_p),
                        'msk_path': str(msk_p),
                        'slice_idx': int(z),
                        'pid': row['pid']
                    })
            except Exception as exc:
                # Stay best-effort (one bad file must not abort the run), but
                # never drop a patient silently: that would bias the reported
                # metrics without leaving a trace in the log.
                warnings.warn(
                    f"OvaryDataset: skipping pid={row['pid']}; could not index "
                    f"mask '{msk_p}' ({type(exc).__name__}: {exc})",
                    stacklevel=2,
                )

    def __len__(self):
        """Return the number of registered slices."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load, normalise and resize one slice.

        Returns:
            Tuple ``(img, msk)`` of float tensors, each of shape
            ``(1, image_size, image_size)``. ``img`` is scaled to [0, 1] and
            ``msk`` is binary.
        """
        info = self.samples[idx]

        # NOTE: the complete image and mask volumes are read on every call.
        img_vol = nib.load(info['img_path']).get_fdata()
        msk_vol = nib.load(info['msk_path']).get_fdata()
        z = info['slice_idx']

        img = img_vol[:, :, z]
        msk = msk_vol[:, :, z]

        # Clip to the 1st-99th percentile, then min-max scale to [0, 1].
        p1 = np.percentile(img, 1)
        p99 = np.percentile(img, 99)
        img = np.clip(img, p1, p99)
        img = (img - img.min()) / (img.max() - img.min() + 1e-8)

        msk = (msk > 0).astype(np.float32)

        # To Tensor (C, H, W); the in-plane axes are transposed first.
        img = torch.from_numpy(img.T).float().unsqueeze(0)
        msk = torch.from_numpy(msk.T).float().unsqueeze(0)

        # Resize. Antialiasing only applies to bilinear/bicubic modes, so it is
        # not requested for the nearest-neighbour mask resize.
        img = TF.resize(img, [self.image_size, self.image_size], interpolation=T.InterpolationMode.BILINEAR, antialias=True)
        msk = TF.resize(msk, [self.image_size, self.image_size], interpolation=T.InterpolationMode.NEAREST)

        # Augmentation: identical random rotation and shift for image and mask.
        if self.augment:
            angle = random.uniform(-25, 25)
            img = TF.rotate(img, angle, interpolation=T.InterpolationMode.BILINEAR)
            msk = TF.rotate(msk, angle, interpolation=T.InterpolationMode.NEAREST)

            max_shift = int(self.image_size * 0.1)
            t_x = random.randint(-max_shift, max_shift)
            t_y = random.randint(-max_shift, max_shift)
            img = TF.affine(img, angle=0, translate=(t_x, t_y), scale=1.0, shear=0, interpolation=T.InterpolationMode.BILINEAR)
            msk = TF.affine(msk, angle=0, translate=(t_x, t_y), scale=1.0, shear=0, interpolation=T.InterpolationMode.NEAREST)

        return img, msk

# --- MODEL (Attention U-Net) ---
class AttentionGate(nn.Module):
    """Additive attention gate applied to a skip connection.

    The gating signal ``g`` and the skip features ``x`` are each projected to
    ``F_int`` channels with a 1x1 convolution + batch norm, summed, passed
    through ReLU, and reduced to a single-channel sigmoid map that rescales ``x``.

    Args:
        F_g: Channels of the gating signal ``g``.
        F_l: Channels of the skip features ``x``.
        F_int: Channels of the intermediate projection.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()

        def pointwise(in_ch, out_ch):
            # 1x1 conv (stride 1, no padding, with bias) followed by batch norm.
            return [nn.Conv2d(in_ch, out_ch, kernel_size=1), nn.BatchNorm2d(out_ch)]

        self.W_g = nn.Sequential(*pointwise(F_g, F_int))
        self.W_x = nn.Sequential(*pointwise(F_l, F_int))
        self.psi = nn.Sequential(*pointwise(F_int, 1), nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        """Return ``x`` multiplied by the attention map computed from ``g`` and ``x``."""
        joint = self.relu(self.W_g(g) + self.W_x(x))
        attention = self.psi(joint)
        return x * attention

class DoubleConv(nn.Module):
    """Two consecutive (3x3 conv -> batch norm -> ReLU) stages.

    Padding of 1 keeps the spatial size unchanged; the first stage maps
    ``in_channels`` to ``out_channels`` and the second keeps ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        width_in = in_channels
        for _ in range(2):
            stages.extend([
                nn.Conv2d(width_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
            width_in = out_channels
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both convolution stages to ``x``."""
        return self.double_conv(x)

class AttentionUNet(nn.Module):
    """U-Net with attention-gated skip connections.

    Encoder: an input DoubleConv followed by four (max-pool, DoubleConv) stages
    with 64, 128, 256, 512 and 1024 channels. Decoder: four stages, each made of
    a stride-2 transposed convolution, an AttentionGate on the matching encoder
    features, channel concatenation and a DoubleConv. A final 1x1 convolution
    produces ``n_classes`` output channels (raw logits; no activation applied).

    Args:
        n_channels: Number of input channels.
        n_classes: Number of output channels.
    """

    def __init__(self, n_channels=1, n_classes=1):
        super().__init__()

        def down(c_in, c_out):
            return nn.Sequential(nn.MaxPool2d(2), DoubleConv(c_in, c_out))

        def up(c_in, c_out):
            return nn.ConvTranspose2d(c_in, c_out, kernel_size=2, stride=2)

        def gate(channels):
            return AttentionGate(F_g=channels, F_l=channels, F_int=channels // 2)

        # Encoder
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = down(64, 128)
        self.down2 = down(128, 256)
        self.down3 = down(256, 512)
        self.down4 = down(512, 1024)

        # Decoder (upsample, gate, fuse) x 4
        self.up1, self.att1, self.conv1 = up(1024, 512), gate(512), DoubleConv(1024, 512)
        self.up2, self.att2, self.conv2 = up(512, 256), gate(256), DoubleConv(512, 256)
        self.up3, self.att3, self.conv3 = up(256, 128), gate(128), DoubleConv(256, 128)
        self.up4, self.att4, self.conv4 = up(128, 64), gate(64), DoubleConv(128, 64)

        self.outc = nn.Conv2d(64, n_classes, kernel_size=1)

    def forward(self, x):
        """Return the logits for input batch ``x``."""
        enc1 = self.inc(x)
        enc2 = self.down1(enc1)
        enc3 = self.down2(enc2)
        enc4 = self.down3(enc3)
        features = self.down4(enc4)

        # Walk back up, pairing each decoder stage with its encoder skip.
        decoder = (
            (self.up1, self.att1, self.conv1, enc4),
            (self.up2, self.att2, self.conv2, enc3),
            (self.up3, self.att3, self.conv3, enc2),
            (self.up4, self.att4, self.conv4, enc1),
        )
        for upsample, attend, fuse, skip in decoder:
            features = upsample(features)
            gated_skip = attend(g=features, x=skip)
            features = fuse(torch.cat([features, gated_skip], dim=1))

        return self.outc(features)

# --- LOSS & METRICS ---
class DiceBCELoss(nn.Module):
    """Sum of binary cross-entropy (on logits) and soft Dice loss.

    The Dice term is computed over the whole flattened batch, not per sample.

    Args:
        weight: Unused; accepted for backward compatibility.
        size_average: Unused; accepted for backward compatibility.
    """

    def __init__(self, weight=None, size_average=True):
        super(DiceBCELoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        """Compute the combined loss.

        Args:
            inputs: Raw logits.
            targets: Binary targets with the same shape as ``inputs``.
            smooth: Additive constant in the Dice numerator and denominator.

        Returns:
            Scalar tensor ``BCE + (1 - Dice)``.
        """
        bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='mean')

        # reshape (not view) so non-contiguous tensors, e.g. transposed or
        # sliced batches, are flattened instead of raising a RuntimeError.
        probs_flat = torch.sigmoid(inputs).reshape(-1)
        targets_flat = targets.reshape(-1)

        intersection = (probs_flat * targets_flat).sum()
        dice_loss = 1 - (2. * intersection + smooth) / (probs_flat.sum() + targets_flat.sum() + smooth)
        return bce + dice_loss

def keep_largest_component(mask):
    """Post-processing: keep only the largest connected component.

    Args:
        mask: Array to label with ``skimage.measure.label``.

    Returns:
        ``mask`` itself when labelling finds no foreground component; otherwise
        a float32 array that is 1 inside the biggest component and 0 elsewhere.
    """
    labelled = measure.label(mask)
    if labelled.max() == 0:
        return mask

    # Pixel count per label; index 0 is background and is excluded from the vote.
    component_sizes = np.bincount(labelled.ravel())[1:]
    biggest_label = np.argmax(component_sizes) + 1
    return (labelled == biggest_label).astype(np.float32)

# --- TRAINING ENGINE ---
def train_fold(train_loader, val_loader, fold):
    """Train a fresh AttentionUNet for one fold and checkpoint its best epoch.

    The model is optimised with AdamW on DiceBCELoss. After every epoch a
    monitoring Dice is computed on ``val_loader``; the weights are saved each
    time that value improves, and training stops once it has not improved for
    ``CONFIG["early_stopping_patience"]`` consecutive epochs.

    Args:
        train_loader: Iterable of ``(img, msk)`` training batches.
        val_loader: Iterable of ``(img, msk)`` validation batches.
        fold: Fold index, used for logging and in the checkpoint file name.

    Returns:
        Path of the checkpoint written by this call.
    """
    print(f"\n--- Training Fold {fold} ---")
    device = CONFIG["device"]
    model = AttentionUNet().to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG["learning_rate"], weight_decay=1e-5)
    criterion = DiceBCELoss()

    best_dice = 0.0
    patience = 0
    save_path = Path(CONFIG["save_dir"]) / f"best_model_fold_{fold}.pth"
    # Track whether THIS run wrote the checkpoint. Checking the file system
    # instead would accept a stale file left behind by an earlier run.
    checkpoint_written = False

    # Guard the per-epoch averages against empty loaders (e.g. tiny debug folds).
    n_train_batches = max(len(train_loader), 1)
    n_val_batches = max(len(val_loader), 1)

    for epoch in range(CONFIG["epochs"]):
        # Train
        model.train()
        train_loss = 0
        for img, msk in train_loader:
            img, msk = img.to(device), msk.to(device)
            optimizer.zero_grad()
            out = model(img)
            loss = criterion(out, msk)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        # Validate. Monitoring metric only: Dice is pooled over each batch at a
        # fixed 0.5 threshold and then averaged over batches.
        model.eval()
        val_dice = 0
        with torch.no_grad():
            for img, msk in val_loader:
                img, msk = img.to(device), msk.to(device)
                out = model(img)
                pred = (torch.sigmoid(out) > 0.5).float()
                inter = (pred * msk).sum()
                d = (2. * inter) / (pred.sum() + msk.sum() + 1e-8)
                val_dice += d.item()

        val_dice /= n_val_batches

        # Reporting
        if (epoch+1) % 10 == 0 or DEBUG:
            print(f"  Ep {epoch+1}/{CONFIG['epochs']} | Loss: {train_loss/n_train_batches:.4f} | Val Dice: {val_dice:.4f}")

        # Early Stopping
        if val_dice > best_dice:
            best_dice = val_dice
            torch.save(model.state_dict(), save_path)
            checkpoint_written = True
            patience = 0
        else:
            patience += 1

        if patience >= CONFIG["early_stopping_patience"]:
            print(f"  Early stopping at epoch {epoch+1}")
            break

    # If validation Dice never rose above 0 (or no epoch ran), fall back to the
    # final weights so the returned path always belongs to this run.
    if not checkpoint_written:
        torch.save(model.state_dict(), save_path)

    return save_path

def optimize_and_test(model_path, val_ds, test_ds):
    """
    1. Find best params (threshold + post-proc) on VAL set (3D Patient Dice)
    2. Apply those params to TEST set (3D Patient Dice)

    Inference is run once per dataset and the per-slice probability maps are
    cached on the CPU, so the grid search only repeats the cheap thresholding
    and post-processing steps.

    Args:
        model_path: Checkpoint (state dict) to load into a new AttentionUNet.
        val_ds: Dataset exposing ``samples`` (dicts with ``pid``, ``img_path``,
            ``msk_path``, ``slice_idx``); used to pick the parameters.
        test_ds: Dataset with the same ``samples`` layout; used for the final score.

    Returns:
        Tuple ``(test_mean, test_std, best_params)`` where ``best_params`` is
        ``{'thresh': float, 'pp': bool}``.

    Note:
        The per-patient Dice is computed over the slices listed in
        ``dataset.samples`` only. OvaryDataset registers only slices whose mask
        is non-empty, so predictions on mask-free slices are not scored.
    """
    device = CONFIG["device"]
    out_size = [CONFIG["image_size"], CONFIG["image_size"]]

    model = AttentionUNet().to(device)
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine;
    # weights_only=True restricts unpickling to tensors / plain containers.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()

    def predict_patients(dataset):
        """Run the model once per slice; return one (probs, gt_volume) pair per patient."""
        by_patient = {}
        for s in dataset.samples:
            by_patient.setdefault(s['pid'], []).append(s)

        cached = []
        with torch.no_grad():
            for samples in by_patient.values():
                # Volumes already loaded for this patient, keyed by path, so a
                # file is read once instead of once per slice.
                volumes = {}
                probs, gts = [], []
                for s in samples:
                    for key in ('img_path', 'msk_path'):
                        if s[key] not in volumes:
                            volumes[s[key]] = nib.load(s[key]).get_fdata()

                    z = s['slice_idx']
                    img = volumes[s['img_path']][:, :, z]
                    msk = volumes[s['msk_path']][:, :, z]

                    # Preprocess: percentile clip, then min-max scale.
                    p1, p99 = np.percentile(img, 1), np.percentile(img, 99)
                    img = np.clip(img, p1, p99)
                    img = (img - img.min()) / (img.max() - img.min() + 1e-8)

                    # Tensorize
                    img_t = torch.from_numpy(img.T).float().unsqueeze(0).unsqueeze(0).to(device)
                    img_t = TF.resize(img_t, out_size, interpolation=T.InterpolationMode.BILINEAR, antialias=True)

                    # Predict (probabilities are kept; thresholding happens later)
                    probs.append(torch.sigmoid(model(img_t)).cpu().squeeze())

                    # GT, resized to the model output size. Antialiasing does
                    # not apply to nearest-neighbour, so it is not requested.
                    msk_t = torch.from_numpy(msk.T).float().unsqueeze(0).unsqueeze(0)
                    msk_t = TF.resize(msk_t, out_size, interpolation=T.InterpolationMode.NEAREST)
                    gts.append((msk_t.numpy().squeeze() > 0).astype(np.float32))

                cached.append((probs, np.array(gts)))
        return cached

    def score_patients(cached, thresh, do_pp):
        """Return (mean, std) of the per-patient 3D Dice for one parameter setting."""
        scores = []
        for probs, vg in cached:
            preds = []
            for prob in probs:
                pred = (prob > thresh).float().numpy()
                if do_pp:
                    pred = keep_largest_component(pred)
                preds.append(pred)
            vp = np.array(preds)
            dice = (2. * np.sum(vp * vg)) / (np.sum(vp) + np.sum(vg) + 1e-8)
            scores.append(dice)
        if not scores:
            # Empty dataset: report NaN explicitly rather than via a numpy warning.
            return float("nan"), float("nan")
        return np.mean(scores), np.std(scores)

    # --- 1. Optimize on Validation ---
    print("  Optimizing parameters on Validation Set...")
    best_val_score = -1
    best_params = {'thresh': 0.5, 'pp': False}

    # Grid search
    thresholds = [0.3, 0.4, 0.5, 0.6, 0.7]
    pp_options = [False, True]

    val_cache = predict_patients(val_ds)
    for th in thresholds:
        for pp in pp_options:
            score, _ = score_patients(val_cache, th, pp)
            if score > best_val_score:
                best_val_score = score
                best_params = {'thresh': th, 'pp': pp}
    del val_cache  # release the cached maps before the test pass

    print(f"  Best Params found: Thresh={best_params['thresh']}, KeepLargest={best_params['pp']} (Val Dice: {best_val_score:.4f})")

    # --- 2. Evaluate on Test ---
    print("  Evaluating on Test Set with best params...")
    test_mean, test_std = score_patients(predict_patients(test_ds), best_params['thresh'], best_params['pp'])

    return test_mean, test_std, best_params

# --- MAIN EXECUTION ---
# Rotating cross-validation: fold k is the test split, fold k+1 (wrapping
# around) is the validation split, and the remaining folds are used to train.
N_FOLDS = 5
full_df = pd.read_csv(CONFIG["manifest_path"])
final_results = []
results_csv = Path(CONFIG["save_dir"]) / "final_results.csv"

for fold in range(N_FOLDS):
    # Garbage collection to be safe
    gc.collect()
    torch.cuda.empty_cache()

    # Split
    test_fold = fold
    val_fold = (fold + 1) % N_FOLDS
    train_folds = [f for f in range(N_FOLDS) if f != test_fold and f != val_fold]

    print(f"\n================ FOLD {fold} ================")
    print(f"Train: {train_folds}, Val: {val_fold}, Test: {test_fold}")

    # Datasets
    df_train = full_df[full_df['fold'].isin(train_folds)]
    df_val = full_df[full_df['fold'] == val_fold]
    df_test = full_df[full_df['fold'] == test_fold]

    # Pass image_size explicitly: optimize_and_test resizes with
    # CONFIG["image_size"], so the datasets must use the same value rather than
    # their own default, otherwise training and evaluation resolution can diverge.
    ds_kwargs = {"image_size": CONFIG["image_size"], "debug": DEBUG}
    train_ds = OvaryDataset(CONFIG["data_root"], df_train, augment=True, **ds_kwargs)
    val_ds = OvaryDataset(CONFIG["data_root"], df_val, augment=False, **ds_kwargs)
    test_ds = OvaryDataset(CONFIG["data_root"], df_test, augment=False, **ds_kwargs)

    train_loader = DataLoader(train_ds, batch_size=CONFIG["batch_size"], shuffle=True, num_workers=CONFIG["num_workers"])
    val_loader = DataLoader(val_ds, batch_size=CONFIG["batch_size"], shuffle=False, num_workers=CONFIG["num_workers"])

    # Train
    model_path = train_fold(train_loader, val_loader, fold)

    # Optimize & Test
    mean_dice, std_dice, params = optimize_and_test(model_path, val_ds, test_ds)

    print(f"RESULT FOLD {fold}: Test Dice = {mean_dice:.4f} ± {std_dice:.4f}")

    # Save incrementally
    final_results.append({
        'fold': fold,
        'test_dice': mean_dice,
        'test_std': std_dice,
        'best_thresh': params['thresh'],
        'use_pp': params['pp']
    })

    # Rewrite the CSV after every fold so an interrupted run keeps its finished folds.
    pd.DataFrame(final_results).to_csv(results_csv, index=False)

print("\n\n=== DONE ===")
print(pd.DataFrame(final_results))

Running in cuda mode.

Train: [2, 3, 4], Val: 1, Test: 0

--- Training Fold 0 ---
  Ep 10/100 | Loss: 1.2558 | Val Dice: 0.5020
  Ep 20/100 | Loss: 1.1939 | Val Dice: 0.4410
  Ep 30/100 | Loss: 1.1365 | Val Dice: 0.5854
  Ep 40/100 | Loss: 1.0916 | Val Dice: 0.5330
  Ep 50/100 | Loss: 1.0494 | Val Dice: 0.5420
  Early stopping at epoch 50


  model.load_state_dict(torch.load(model_path))


  Optimizing parameters on Validation Set...
  Best Params found: Thresh=0.7, KeepLargest=True (Val Dice: 0.5333)
  Evaluating on Test Set with best params...
RESULT FOLD 0: Test Dice = 0.5085 ± 0.0654

Train: [0, 3, 4], Val: 2, Test: 1

--- Training Fold 1 ---
  Ep 10/100 | Loss: 1.2867 | Val Dice: 0.3437
  Ep 20/100 | Loss: 1.2276 | Val Dice: 0.4673
  Ep 30/100 | Loss: 1.1854 | Val Dice: 0.4102
  Ep 40/100 | Loss: 1.1438 | Val Dice: 0.3138
  Ep 50/100 | Loss: 1.1076 | Val Dice: 0.2750
  Ep 60/100 | Loss: 1.0704 | Val Dice: 0.5322
  Early stopping at epoch 64


  model.load_state_dict(torch.load(model_path))


  Optimizing parameters on Validation Set...
  Best Params found: Thresh=0.5, KeepLargest=True (Val Dice: 0.4970)
  Evaluating on Test Set with best params...
RESULT FOLD 1: Test Dice = 0.4787 ± 0.3072

Train: [0, 1, 4], Val: 3, Test: 2

--- Training Fold 2 ---
  Ep 10/100 | Loss: 1.2999 | Val Dice: 0.0000
  Ep 20/100 | Loss: 1.2307 | Val Dice: 0.4392
  Ep 30/100 | Loss: 1.1737 | Val Dice: 0.4202
  Ep 40/100 | Loss: 1.1269 | Val Dice: 0.4493
  Ep 50/100 | Loss: 1.0911 | Val Dice: 0.4011
  Ep 60/100 | Loss: 1.0373 | Val Dice: 0.6394
  Ep 70/100 | Loss: 0.9964 | Val Dice: 0.5407
  Ep 80/100 | Loss: 0.9500 | Val Dice: 0.5884
  Early stopping at epoch 87
  Optimizing parameters on Validation Set...
  Best Params found: Thresh=0.6, KeepLargest=True (Val Dice: 0.6777)
  Evaluating on Test Set with best params...
RESULT FOLD 2: Test Dice = 0.4294 ± 0.2658

Train: [0, 1, 2], Val: 4, Test: 3

--- Training Fold 3 ---
  Ep 10/100 | Loss: 1.3354 | Val Dice: 0.0000
  Ep 20/100 | Loss: 1.2729 | Val 