<a href="https://colab.research.google.com/github/Bhargav021/BraTs-Challenge-6/blob/main/BraTS_OPtimizations.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

In [None]:
# ===================================================================
# CELL 0: DRIVE MOUNT & ENVIRONMENT SETUP
# ===================================================================
# Mount Google Drive so the BraTS data and model checkpoints under
# /content/drive/MyDrive are reachable from this runtime.
# force_remount=True requests a fresh mount even if Drive is already mounted.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [None]:
# Install TorchIO (used by the preprocessing cell below for reorientation,
# resampling and Nyul histogram standardization). Not version-pinned, so the
# installed version may change between sessions.
!pip install torchio

Collecting torchio
  Downloading torchio-0.21.0-py3-none-any.whl.metadata (52 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/52.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m52.6/52.6 kB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting deprecated>=1.2 (from torchio)
  Downloading deprecated-1.3.1-py2.py3-none-any.whl.metadata (5.9 kB)
Collecting simpleitk!=2.0.*,!=2.1.1.1,>=1.3 (from torchio)
  Downloading simpleitk-2.5.3-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (7.4 kB)
Downloading torchio-0.21.0-py3-none-any.whl (193 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m193.9/193.9 kB[0m [31m6.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading deprecated-1.3.1-py2.py3-none-any.whl (11 kB)
Downloading simpleitk-2.5.3-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (52.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m52.6/5

# Pre-processing (re-run)

In [None]:
# =========================================================
#                  IMPORTS
# =========================================================
import os
import shutil
from pathlib import Path
import torch
import numpy as np
import SimpleITK as sitk
import torchio as tio
from scipy.ndimage import binary_dilation

# =========================================================
#               GLOBAL CONFIG
# =========================================================
# File suffixes of the four MRI modalities expected in every case folder
# (files are named "<case>-<suffix>.nii.gz").
MODALITIES = (
    "t1c",
    "t1n",
    "t2f",
    "t2w",
)
# Number of (sorted) cases used to fit the Nyul landmarks when none are cached.
NYUL_TRAIN_PATIENTS = 50
# Where the trained Nyul landmarks are cached between runs.
LANDMARKS_PATH = "/content/nyul_landmarks.npy"

# =========================================================
#               BASIC I/O HELPERS
# =========================================================
def load_nifti(path):
    """Read a NIfTI file with SimpleITK.

    Args:
        path: file path (str or Path; converted to str for SimpleITK).

    Returns:
        ``(image, array)``: the SimpleITK image (with its geometry metadata)
        and its voxel data as a numpy array.
    """
    image = sitk.ReadImage(str(path))
    voxels = sitk.GetArrayFromImage(image)
    print(f"DEBUG: Loaded NIfTI {path}, shape: {voxels.shape}, dtype: {voxels.dtype}")
    return image, voxels

def save_pt(image_tensor, label_tensor, out_path):
    """Serialize one preprocessed case to ``out_path`` with ``torch.save``.

    The stored dict has two entries: ``"image"`` cast to float32 and
    ``"label"`` cast to int64.
    """
    payload = dict(image=image_tensor.float(), label=label_tensor.long())
    torch.save(payload, out_path)
    print(f"DEBUG: Saved {out_path}")

# =========================================================
#       RAS + RESAMPLE
# =========================================================
def force_ras(img_sitk):
    """Reorient a SimpleITK image to canonical (RAS+) orientation via TorchIO.

    Args:
        img_sitk: input SimpleITK image.

    Returns:
        A new SimpleITK image produced by ``tio.ToCanonical``.
    """
    image = tio.ScalarImage.from_sitk(img_sitk)
    # Convert back to SimpleITK only once: as_sitk() copies the whole volume,
    # and the previous version did it twice (once just for the debug print).
    canonical_sitk = tio.ToCanonical()(image).as_sitk()
    # A view is enough to read the shape; it avoids another full-volume copy.
    print(f"DEBUG: force_ras done, shape: {sitk.GetArrayViewFromImage(canonical_sitk).shape}")
    return canonical_sitk

def resample_1mm(img_sitk, interp='linear'):
    """Resample a SimpleITK image to 1 mm isotropic spacing via TorchIO.

    Args:
        img_sitk: input SimpleITK image.
        interp: TorchIO interpolation name passed as ``image_interpolation``;
            the caller uses 'linear' for intensities and 'nearest' for labels.

    Returns:
        The resampled image as a new SimpleITK image.
    """
    image = tio.ScalarImage.from_sitk(img_sitk)
    resampler = tio.Resample((1, 1, 1), image_interpolation=interp)
    # Convert once (as_sitk() copies the volume) and reuse for print + return.
    resampled_sitk = resampler(image).as_sitk()
    print(f"DEBUG: Resampled to 1mm, interpolation={interp}, shape: {sitk.GetArrayViewFromImage(resampled_sitk).shape}")
    return resampled_sitk

# =========================================================
#     BRAIN MASK USING OTSU OVER MULTIPLE MODALITIES
# =========================================================
def brain_mask_from_modalities(mod_np_list):
    """Binary foreground mask from Otsu thresholding of the mean modality.

    All modality arrays are stacked and divided by ONE shared 99th-percentile
    value. They are averaged voxel-wise, and the mean volume is thresholded
    with ``sitk.OtsuThreshold`` (insideValue=0, outsideValue=1, 200 bins), so
    voxels on the bright side of the Otsu threshold become 1.

    Returns:
        uint8 array with the shape of one modality.
    """
    volumes = np.stack(mod_np_list, axis=0)
    scale = np.percentile(volumes, 99) + 1e-8
    mean_volume = (volumes / scale).mean(0)
    otsu = sitk.OtsuThreshold(sitk.GetImageFromArray(mean_volume), 0, 1, 200)
    mask_np = sitk.GetArrayFromImage(otsu).astype(np.uint8)
    print(f"DEBUG: Brain mask computed, shape: {mask_np.shape}, sum={mask_np.sum()}")
    return mask_np

# =========================================================
#    MASKED CLAHE (slice-wise)
# =========================================================
def clahe_slice_np(arr, mask, clip_limit=2.0, tile_grid=(8,8)):
    """Slice-wise CLAHE on a volume, with intensity scaling driven by a mask.

    For every slice ``arr[z]`` that contains mask voxels:
      1. the slice is min-max rescaled using the masked voxels only;
      2. the result is clipped to [0, 1] and quantised to uint8;
      3. OpenCV CLAHE is applied;
      4. the output is stored as float32 in [0, 1].
    Slices without any mask voxels are copied through unchanged.

    Args:
        arr: volume indexed as ``arr[z]`` -> 2-D slice.
        mask: binary (0/1) array with the same shape as ``arr``.
        clip_limit: CLAHE clip limit (``cv2.createCLAHE(clipLimit=...)``).
        tile_grid: CLAHE tile grid (``cv2.createCLAHE(tileGridSize=...)``).

    Returns:
        Array with the shape and dtype of ``arr`` holding the processed slices.
    """
    import cv2
    out = np.zeros_like(arr)
    clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid)
    for z in range(arr.shape[0]):
        sl = arr[z]
        inside = mask[z] == 1
        if not inside.any():
            # NOTE(review): these slices keep their raw intensity scale, unlike
            # the [0, 1] range of the enhanced slices.
            out[z] = sl
            continue
        lo = sl[inside].min()
        hi = sl[inside].max()
        sl_norm = sl - lo
        if hi - lo > 0:
            sl_norm = sl_norm / (hi - lo)
        # Voxels outside the mask can lie outside the masked [min, max] range.
        # Without clipping, negative values and values > 1 wrap around when
        # cast to uint8 (e.g. -0.1 * 255 -> 230), creating bright artefacts.
        sl_uint8 = (np.clip(sl_norm, 0.0, 1.0) * 255).astype(np.uint8)
        enhanced = clahe.apply(sl_uint8)
        out[z] = enhanced.astype(np.float32) / 255.0
    print(f"DEBUG: CLAHE applied, shape: {arr.shape}")
    return out

# =========================================================
#        GRADIENT MAGNITUDE FOR 3D VOLUME
# =========================================================
def gradient_mag(volume):
    """Voxel-wise L2 norm of the 3-D ``np.gradient`` of ``volume`` (as float32).

    ``np.gradient`` uses central differences in the interior and one-sided
    differences at the borders. The input must be 3-D.
    """
    d0, d1, d2 = np.gradient(volume.astype(np.float32))
    grad = np.sqrt(d0 * d0 + d1 * d1 + d2 * d2)
    print(f"DEBUG: Gradient magnitude computed, shape: {grad.shape}")
    return grad

# =========================================================
#        NYUL HISTOGRAM STANDARDIZATION
# =========================================================
def train_nyul_landmarks(patient_arrays):
    """Fit TorchIO Nyul (histogram-standardization) landmarks and cache them.

    ``tio.HistogramStandardization.train`` takes file paths, so every array in
    ``patient_arrays`` (a list of per-case lists of modality arrays) is written
    to a temporary NIfTI file first. The temporary directory is always removed.

    All modalities of all cases are pooled into ONE landmark set, which is
    saved with ``np.save`` to ``LANDMARKS_PATH``.
    NOTE(review): Nyul is usually fitted per modality; pooling T1c/T1n/FLAIR/T2w
    histograms is unusual — confirm this is intended.

    Returns:
        LANDMARKS_PATH.
    """
    import tempfile
    scratch_dir = tempfile.mkdtemp()
    print(f"DEBUG: Training Nyul landmarks on {len(patient_arrays)} patients...")
    try:
        nifti_paths = []
        for case_idx, case in enumerate(patient_arrays):
            for mod_idx, volume in enumerate(case):
                target = os.path.join(scratch_dir, f"patient{case_idx}_mod{mod_idx}.nii.gz")
                sitk.WriteImage(sitk.GetImageFromArray(volume), target)
                nifti_paths.append(target)
        landmarks = tio.HistogramStandardization.train(nifti_paths)
        np.save(LANDMARKS_PATH, landmarks)
        print(f"DEBUG: Saved Nyul landmarks → {LANDMARKS_PATH}")
        return LANDMARKS_PATH
    finally:
        shutil.rmtree(scratch_dir)

def load_landmarks(path):
    """Load Nyul landmarks saved with ``np.save`` as a TorchIO-ready dict.

    ``tio.HistogramStandardization`` expects ``{image_key: landmarks}``.
    Accepted on-disk layouts:
      * a 1-D array -> the same landmarks are used for every key
        ``m0 .. m{n-1}`` with ``n = len(MODALITIES)``;
      * a dict -> returned as-is. ``np.save`` stores a dict as a 0-d object
        array, so that form is unwrapped with ``.item()`` first.

    Raises:
        ValueError: for any other layout.
    """
    # SECURITY: allow_pickle=True can execute code embedded in a crafted file;
    # only load landmark files produced by this notebook.
    landmarks_raw = np.load(path, allow_pickle=True)
    print(f"DEBUG: Raw landmarks type: {type(landmarks_raw)}, shape: {getattr(landmarks_raw, 'shape', 'N/A')}, contents: {landmarks_raw}")
    # np.save(dict) round-trips as a 0-d object ndarray; without unwrapping,
    # the dict branch below could never be reached for .npy files.
    if (isinstance(landmarks_raw, np.ndarray) and landmarks_raw.ndim == 0
            and landmarks_raw.dtype == object):
        landmarks_raw = landmarks_raw.item()
    if isinstance(landmarks_raw, np.ndarray) and landmarks_raw.ndim == 1:
        # Share one landmark vector across all modality keys m0..m3.
        n_mods = len(MODALITIES)
        landmarks_dict = {f"m{i}": landmarks_raw for i in range(n_mods)}
        print(f"DEBUG: Converted 1D landmarks array to dict: keys={list(landmarks_dict.keys())}")
        return landmarks_dict
    if isinstance(landmarks_raw, dict):
        return landmarks_raw
    raise ValueError(f"Unexpected landmarks type: {type(landmarks_raw)}")

def apply_nyul(mod_np_list, landmarks_path):
    """Apply Nyul histogram standardization to each modality array.

    Each array becomes a single-channel ``tio.ScalarImage`` under key ``m<i>``
    (the key scheme ``load_landmarks`` produces for 1-D landmark files). The
    standardized voxel arrays are returned in input order.
    """
    landmarks = load_landmarks(landmarks_path)
    images = {}
    for i, volume in enumerate(mod_np_list):
        images[f"m{i}"] = tio.ScalarImage(tensor=volume[None])
    subj = tio.Subject(**images)
    print(f"DEBUG: Applying Nyul. Subject keys: {list(subj.keys())}, landmarks type: {type(landmarks)}")
    # masking_method=None is safe; Otsu is not allowed in TorchIO >=0.21
    standardize = tio.HistogramStandardization(landmarks=landmarks, masking_method=None)
    standardized = standardize(subj)
    return [standardized[key].data.numpy()[0] for key in images]

# =========================================================
#           Z-SCORE NORMALIZATION (MASKED)
# =========================================================
def masked_z_norm(img, mask):
    """Z-score ``img`` using the mean/std of the voxels where ``mask == 1``.

    Returns ``img`` itself (unchanged) if fewer than 10 voxels are masked, and
    an all-zero array if the masked std is below 1e-6.
    """
    inside = img[mask == 1]
    if inside.size >= 10:
        mu = inside.mean()
        sigma = inside.std()
        return np.zeros_like(img) if sigma < 1e-6 else (img - mu) / sigma
    return img

# =========================================================
#         COMPUTE BOUNDING BOX WITH DILATION
# =========================================================
def compute_bbox(mask):
    """Tight bounding box of the non-zero voxels of a 3-D mask, as slices.

    An all-zero mask yields slices spanning the whole volume. No dilation is
    done here; the caller dilates ``mask`` beforehand if a margin is wanted.
    """
    nz = np.nonzero(mask)
    if nz[0].size == 0:
        depth, height, width = mask.shape
        return slice(0, depth), slice(0, height), slice(0, width)
    return tuple(slice(nz[axis].min(), nz[axis].max() + 1) for axis in range(3))

# =========================================================
#                 MAIN PREPROCESS
# =========================================================
def _case_file(case_dir, suffix):
    """Path ``<case_dir>/<case_name>-<suffix>.nii.gz`` (BraTS file naming)."""
    return case_dir / f"{case_dir.name}-{suffix}.nii.gz"


def _is_complete_case(case_dir):
    """True when ``case_dir`` holds every modality in MODALITIES plus ``-seg``."""
    return all(_case_file(case_dir, suffix).exists() for suffix in (*MODALITIES, "seg"))


def _load_enhanced_modalities(case_dir):
    """Load a case's modalities and apply RAS -> 1 mm -> brain mask -> CLAHE.

    Returns:
        (mod_np, mask): float32 CLAHE-enhanced arrays in MODALITIES order, and
        the uint8 brain mask (Otsu, dilated by 2 iterations) they share.
    """
    imgs_sitk = [load_nifti(_case_file(case_dir, mod))[0] for mod in MODALITIES]
    imgs_r = [resample_1mm(force_ras(x), "linear") for x in imgs_sitk]
    mod_np = [sitk.GetArrayFromImage(x).astype(np.float32) for x in imgs_r]
    mask = brain_mask_from_modalities(mod_np)
    mask = binary_dilation(mask, iterations=2).astype(np.uint8)
    mod_np = [clahe_slice_np(m, mask) for m in mod_np]
    return mod_np, mask


def preprocess_folder(input_root, output_root):
    """Preprocess every BraTS case folder under ``input_root`` into ``.pt`` files.

    Per case:
      1. RAS reorientation and 1 mm resampling (linear for images, nearest for
         the label);
      2. Otsu brain mask (dilated by 2 iterations);
      3. slice-wise CLAHE;
      4. Nyul standardization;
      5. masked z-score;
      6. T2w gradient magnitude;
      7. crop to the mask bounding box;
      8. save to ``output_root/<case>.pt``.

    Sub-folders missing any ``<case>-{t1c,t1n,t2f,t2w,seg}.nii.gz`` file are
    skipped with a warning instead of aborting the whole run (previously a
    stray folder such as ``InputScans`` raised a SimpleITK RuntimeError).

    Nyul landmarks are reused from ``LANDMARKS_PATH`` if that file exists
    (delete it to retrain). Otherwise they are trained on the first
    ``NYUL_TRAIN_PATIENTS`` valid cases. The training arrays go through the
    same RAS/resample/CLAHE steps as the arrays the landmarks are applied to;
    before, they were fitted on raw scans. This costs one extra pass over
    those cases.
    """
    input_root = Path(input_root)
    output_root = Path(output_root)
    output_root.mkdir(parents=True, exist_ok=True)

    candidates = sorted(d for d in input_root.iterdir() if d.is_dir() and not d.name.startswith('.'))
    cases = []
    for case_dir in candidates:
        if _is_complete_case(case_dir):
            cases.append(case_dir)
        else:
            print(f"WARNING: Skipping {case_dir.name}: missing modality or segmentation files")
    print(f"DEBUG: Found {len(cases)} cases")

    # ------------------------------
    # PREPARE NYUL LANDMARKS
    # ------------------------------
    if os.path.exists(LANDMARKS_PATH):
        print(f"DEBUG: Loading existing Nyul landmarks from {LANDMARKS_PATH}")
        landmarks_path = LANDMARKS_PATH
    else:
        print(f"DEBUG: Training Nyul landmarks on first {NYUL_TRAIN_PATIENTS} cases...")
        # Fit on the same representation the landmarks will be applied to.
        train_arrays = [_load_enhanced_modalities(case_dir)[0]
                        for case_dir in cases[:NYUL_TRAIN_PATIENTS]]
        landmarks_path = train_nyul_landmarks(train_arrays)

    # ----------------------------------------------------
    # PROCESS EVERY CASE
    # ----------------------------------------------------
    for case_dir in cases:
        print(f"DEBUG: Processing {case_dir.name}")

        # ---- Images: RAS + resample + brain mask + CLAHE ----
        mod_np, mask = _load_enhanced_modalities(case_dir)

        # ---- Label: RAS + nearest-neighbour resample ----
        label_sitk, _ = load_nifti(_case_file(case_dir, "seg"))
        label_np = sitk.GetArrayFromImage(resample_1mm(force_ras(label_sitk), "nearest"))
        print(f"DEBUG: After resampling, label shape: {label_np.shape}")

        # ---- Nyul + masked z-score ----
        mod_np = apply_nyul(mod_np, landmarks_path)
        mod_np = [masked_z_norm(m, mask) for m in mod_np]

        # ---- Gradient(T2w) ----
        grad = gradient_mag(mod_np[3])

        # ---- Final channel order ----
        # NOTE(review): channel 2 holds the T2w gradient, so the FLAIR (t2f)
        # volume mod_np[2] is processed but never saved. Confirm this is
        # intended and that downstream models expect 4 channels.
        final_modalities = [
            mod_np[0],  # t1c
            mod_np[1],  # t1n
            grad,       # gradient(T2w)
            mod_np[3],  # t2w
        ]

        # ---- Crop to the (dilated) brain mask ----
        bbox = compute_bbox(mask)
        final_modalities = [m[bbox] for m in final_modalities]
        label_np_c = label_np[bbox]

        # ---- Save ----
        image_tensor = torch.tensor(np.stack(final_modalities, axis=0))
        label_tensor = torch.tensor(label_np_c)
        save_pt(image_tensor, label_tensor, output_root / f"{case_dir.name}.pt")

# =========================================================
#                 MAIN EXECUTION
# =========================================================
def main():
    """Preprocess the BraTS-PEDs 2024 training set stored in Google Drive."""
    src_root = "/content/drive/MyDrive/BraTS Challenge 2025 - Task 6/BraTS2024-PED-Challenge-TrainingData/BraTS-PEDs2024_Training"
    dst_root = "/content/drive/MyDrive/BraTS Challenge 2025 - Task 6/BraTS2024-PED-Challenge-TrainingData/processed_pt"
    preprocess_folder(input_root=src_root, output_root=dst_root)


# In a notebook kernel __name__ is "__main__", so running this cell starts
# preprocessing immediately.
if __name__ == "__main__":
    main()


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
DEBUG: Resampled to 1mm, interpolation=linear, shape: (155, 240, 240)
DEBUG: force_ras done, shape: (155, 240, 240)
DEBUG: Resampled to 1mm, interpolation=linear, shape: (155, 240, 240)
DEBUG: force_ras done, shape: (155, 240, 240)
DEBUG: Resampled to 1mm, interpolation=nearest, shape: (155, 240, 240)
DEBUG: After resampling, label shape: (155, 240, 240)
DEBUG: Brain mask computed, shape: (155, 240, 240), sum=2532467
DEBUG: CLAHE applied, shape: (155, 240, 240)
DEBUG: CLAHE applied, shape: (155, 240, 240)
DEBUG: CLAHE applied, shape: (155, 240, 240)
DEBUG: CLAHE applied, shape: (155, 240, 240)
DEBUG: Raw landmarks type: <class 'numpy.ndarray'>, shape: (13,), contents: [0.00000000e+00 4.90940774e-02 1.52195963e-01 2.89141189e-01
 4.92147924e-01 1.05982656e+00 1.76370030e+00 2.55459212e+00
 7.72693043e+00 2.04514365e+01 3.50247878e+01 5.95142056e+01
 1.00000000e+02]
DEBUG: Converted 1D landmarks array to dict: keys=['m0', '

RuntimeError: Exception thrown in SimpleITK ImageFileReader_Execute: /work/src/Code/IO/src/sitkImageReaderBase.cxx:91:
sitk::ERROR: The file "/content/drive/MyDrive/BraTS Challenge 2025 - Task 6/BraTS2024-PED-Challenge-TrainingData/BraTS-PEDs2024_Training/InputScans/InputScans-t1c.nii.gz" does not exist.

# Training

## Training env setup

In [None]:
# --- 1. SETUP KAGGLE ENVIRONMENT ---
# NOTE(review): this notebook also mounts Drive via google.colab and uses
# /content paths — confirm which platform this cell targets.
# --- Install Dependencies ---
# Replace the preinstalled torch stack with CUDA 12.1 (cu121) wheels and add
# MONAI, torch-geometric, scikit-image, einops, TorchIO and kornia. Only monai
# is constrained (>=1.3.0); everything else is unpinned, so installed versions
# can drift between sessions.
!pip uninstall -y monai torch torchvision torchaudio
!pip install "monai[nibabel,tqdm]>=1.3.0" --upgrade -q
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
!pip install torch-geometric scikit-image einops -q
!pip install torchio
!pip install kornia -q
# A restart is needed so the freshly installed torch is the one imported.
print("Environment setup complete. Please RESTART THE SESSION now from the 'Run' menu.")

[0mFound existing installation: torch 2.8.0+cu126
Uninstalling torch-2.8.0+cu126:
  Successfully uninstalled torch-2.8.0+cu126
Found existing installation: torchvision 0.23.0+cu126
Uninstalling torchvision-0.23.0+cu126:
  Successfully uninstalled torchvision-0.23.0+cu126
Found existing installation: torchaudio 2.8.0+cu126
Uninstalling torchaudio-2.8.0+cu126:
  Successfully uninstalled torchaudio-2.8.0+cu126
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.7/2.7 MB[0m [31m89.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m899.7/899.7 MB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m594.3/594.3 MB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.2/10.2 MB[0m [31m126.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m88.0/88.0 MB[0m [31m29.1 MB/s[0m eta [36m0:00

## Optimized - v1

In [None]:
import os, glob, time, random, gc, warnings
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.amp import autocast, GradScaler
from torch.utils.checkpoint import checkpoint
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from torch.optim.lr_scheduler import OneCycleLR
import torchio as tio
import psutil
from collections import OrderedDict

# NOTE: silences ALL warnings (deprecations, numerical issues) for the session.
warnings.filterwarnings("ignore")

# =============== OPTIMIZED Configuration ===============
DATA_DIR = "/content/drive/MyDrive/InputScans_Final"
MODEL_SAVE_PATH = "/content/drive/MyDrive/best_optimized_model_v2.pt"

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEED = 42
# Seed Python `random`, NumPy, torch (CPU) and all CUDA devices.
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)

# OPTIMIZATION 1: Reduced model size & faster training
EPOCHS = 250
BATCH_SIZE = 2  # Per-step batch size (reduced from 3)
GRADIENT_ACCUMULATION_STEPS = 3  # Reduced from 4; effective batch = BATCH_SIZE * 3 = 6
INITIAL_LR = 1e-3  # OPTIMIZATION 2: Higher LR (was 5e-4)
WEIGHT_DECAY = 1e-4
EARLY_STOPPING_PATIENCE = 50
PATCHES_PER_SCAN = 2
TUMOR_SAMPLING_PROB = 0.85
WARMUP_EPOCHS = 5

IN_CHANNELS = 5
NUM_CLASSES = None  # filled in later — presumably via detect_num_classes(); confirm
PATCH_SIZE = (128, 128, 128)
GRADIENT_CLIP_VAL = 12.0

# OPTIMIZATION 3: Larger RAM caches
MAX_CACHE_SIZE_GB = 35  # Train cache budget: 35GB (was 10GB)
VAL_CACHE_SIZE_GB = 12  # Val cache budget: 12GB (was 5GB) - intended to hold the whole validation set
NUM_WORKERS = 0
USE_CHANNELS_LAST = True
# NOTE(review): EnhancedUNet.forward never calls torch.utils.checkpoint, so
# this flag has no visible effect.
USE_GRADIENT_CHECKPOINTING = True

USE_DROPOUT = True
DROPOUT_RATE = 0.2

# =============== Utilities ===============
def get_memory_usage():
    """Percentage of system RAM currently in use, as reported by psutil."""
    vm = psutil.virtual_memory()
    return vm.percent

def memory_cleanup():
    """Run Python garbage collection and, if CUDA is available, empty its cache."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()

def get_num_groups(channels):
    """GroupNorm group count: largest of 32/16/8/4/2/1 dividing ``channels``.

    The result is capped at ``channels``.
    """
    candidates = (32, 16, 8, 4, 2, 1)
    return next((min(g, channels) for g in candidates if channels % g == 0), 1)

def detect_num_classes(file_paths, check_n=20):
    """Infer the number of segmentation classes from saved ``.pt`` cases.

    Loads up to ``check_n`` files, collects the unique values of each file's
    ``'label'`` tensor and returns ``max(label) + 1``. Falls back to 5 classes
    (max label 4) when no file could be read.

    Unreadable files are skipped with a printed warning. A bare ``except:``
    used to hide every error, including KeyboardInterrupt.
    """
    observed = set()
    for p in file_paths[:check_n]:
        try:
            d = torch.load(p, map_location='cpu')
            observed.update(torch.unique(d['label']).tolist())
        except Exception as exc:
            print(f"WARNING: could not read labels from {p}: {exc}")
    max_label = int(max(observed)) if observed else 4
    classes = max_label + 1
    print(f"Detected {classes} classes from labels: {sorted(observed)}")
    return classes

def compute_class_weights(file_paths, num_classes, sample_size=50):
    """Inverse-frequency class weights from the labels of up to ``sample_size`` files.

    ``weight_c = total / (num_classes * count_c)``, then divided by the
    class-0 weight and clipped to [0.1, 10].

    Unreadable files are skipped with a printed warning. If no voxels could be
    counted at all, uniform weights (all ones) are returned; the old code
    produced NaNs (0/0) in that case.

    Returns:
        float32 tensor of shape (num_classes,).
    """
    print(f"Computing class weights from {min(sample_size, len(file_paths))} samples...")
    class_counts = np.zeros(num_classes)

    for p in tqdm(file_paths[:sample_size], desc="Analyzing class distribution"):
        try:
            label = torch.load(p, map_location='cpu')['label']
        except Exception as exc:
            print(f"WARNING: could not read labels from {p}: {exc}")
            continue
        for c in range(num_classes):
            class_counts[c] += (label == c).sum().item()

    total = class_counts.sum()
    if total == 0:
        print("WARNING: no label voxels counted; using uniform class weights")
        return torch.ones(num_classes, dtype=torch.float32)

    class_weights = total / (num_classes * class_counts + 1e-8)
    class_weights = class_weights / class_weights[0]
    class_weights = np.clip(class_weights, 0.1, 10.0)

    print(f"Class distribution: {class_counts}")
    print(f"Class weights: {class_weights}")

    return torch.FloatTensor(class_weights)

def pad_to_patch(img, lbl, patch_size):
    """Zero-pad image and label so each spatial dim is at least ``patch_size``.

    Args:
        img: (C, D, H, W) tensor.
        lbl: (D, H, W) or (1, D, H, W) tensor.
        patch_size: (D, H, W) minimum size.

    Padding is split evenly, with any odd voxel going after. Returns
    ``(img, lbl)`` with the leading singleton dim of ``lbl`` squeezed away.
    """
    lbl = lbl.unsqueeze(0) if lbl.ndim == 3 else lbl
    per_dim = []
    for size, target in zip(img.shape[1:], patch_size):
        deficit = max(target - size, 0)
        per_dim.append((deficit // 2, deficit - deficit // 2))
    if any(before or after for before, after in per_dim):
        # F.pad lists the LAST dimension first: (W_l, W_r, H_l, H_r, D_l, D_r).
        pad_spec = [amount for pair in reversed(per_dim) for amount in pair]
        img = F.pad(img, pad_spec, mode='constant', value=0)
        lbl = F.pad(lbl, pad_spec, mode='constant', value=0)
    return img, lbl.squeeze(0)

def nnunet_normalization(image):
    """nnU-Net-style per-channel normalization over the foreground (voxels > 0).

    For every channel of ``image`` (C, ...):
      1. foreground = voxels with value > 0;
      2. foreground values are clipped to their 0.5th / 99.5th percentiles;
      3. the clipped foreground is z-scored with its own mean and std
         (skipped when std is not positive).
    Background voxels (<= 0) are left untouched. ``image`` is modified in place
    and also returned.

    NOTE(review): if inputs are already z-scored (contain negative tissue
    values), only the positive voxels count as foreground — confirm this
    matches the data in DATA_DIR.
    """
    for c in range(image.shape[0]):
        channel = image[c]  # view: writes below go straight into `image`
        fg = channel > 0
        if not fg.any():
            continue
        values = channel[fg]
        q = torch.tensor([0.005, 0.995], device=values.device, dtype=values.dtype)
        lo, hi = torch.quantile(values, q)
        # Clip only the foreground; clamping the whole channel used to lift
        # background zeros to `lo` and leave them un-normalized.
        values = values.clamp(lo, hi)
        # Statistics come from the clipped values that are actually normalized.
        std_val = values.std()
        if std_val > 0:
            values = (values - values.mean()) / (std_val + 1e-8)
        channel[fg] = values
    return image

# =============== Losses ===============
class GeneralizedDiceLoss(nn.Module):
    """Weighted mean of per-class soft Dice scores, returned as ``1 - mean``.

    Args:
        num_classes: number of channels in ``pred`` / label values in ``target``.
        class_weights: optional 1-D tensor of per-class weights. When ``None``,
            weights are recomputed on every call as
            ``1 / (|target_c|^2 + smooth)``.
        smooth: additive smoothing in each Dice numerator and denominator.

    NOTE(review): this averages per-class Dice scores. The GDL of Sudre et al.
    (2017) instead sums weighted intersections and weighted unions before
    taking a single ratio.
    """
    def __init__(self, num_classes, class_weights=None, smooth=1.0):
        super().__init__()
        self.num_classes = num_classes
        self.smooth = smooth
        self.class_weights = class_weights

    def forward(self, pred, target):
        """pred: (N, C, D, H, W) logits; target: (N, D, H, W) integer labels."""
        probs = F.softmax(pred, dim=1)
        onehot = F.one_hot(target, self.num_classes).permute(0, 4, 1, 2, 3).float()

        if self.class_weights is not None:
            weights = self.class_weights.to(pred.device)
        else:
            weights = torch.stack([
                1.0 / (onehot[:, c].sum() ** 2 + self.smooth)
                for c in range(self.num_classes)
            ]).to(pred.device)

        weighted = 0
        for c in range(self.num_classes):
            p_c, t_c = probs[:, c], onehot[:, c]
            overlap = (p_c * t_c).sum()
            total = p_c.sum() + t_c.sum()
            weighted += weights[c] * ((2.0 * overlap + self.smooth) / (total + self.smooth))

        return 1.0 - (weighted / weights.sum())

class FocalLoss(nn.Module):
    """Multi-class focal loss (Lin et al., 2017) with optional per-class alpha.

    Per element: ``alpha[y] * (1 - p_y) ** gamma * (-log p_y)``, averaged over
    all elements, where ``p_y`` is the softmax probability of the true class.

    Args:
        alpha: optional 1-D tensor of per-class weights (None = all ones).
        gamma: focusing parameter; gamma = 0 gives per-element-averaged
            (alpha-weighted) cross-entropy.
    """
    def __init__(self, alpha=None, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, pred, target):
        """pred: (N, C, ...) logits; target: (N, ...) class indices."""
        # pt must come from the UNweighted CE. With weight=alpha,
        # exp(-ce) = p_y ** alpha[y], which is not the true-class probability
        # and distorts the (1 - pt) ** gamma focusing term.
        ce = F.cross_entropy(pred, target, reduction='none')
        pt = torch.exp(-ce)
        focal = ((1 - pt) ** self.gamma) * ce
        if self.alpha is not None:
            focal = focal * self.alpha.to(device=focal.device, dtype=focal.dtype)[target]
        return focal.mean()

class CombinedLoss(nn.Module):
    """Weighted sum of Generalized Dice, focal and cross-entropy losses.

    All three terms share ``class_weights``. ``forward`` returns
    ``(total, parts)``: the differentiable total and a dict of the individual
    terms as Python floats under keys 'gdl', 'focal' and 'ce'.
    """
    def __init__(self, num_classes, class_weights,
                 gdl_weight=2.0, focal_weight=1.0, ce_weight=0.5):
        super().__init__()
        self.num_classes = num_classes
        self.gdl_weight = gdl_weight
        self.focal_weight = focal_weight
        self.ce_weight = ce_weight

        self.gdl = GeneralizedDiceLoss(num_classes, class_weights)
        self.focal = FocalLoss(alpha=class_weights, gamma=2.0)
        self.ce = nn.CrossEntropyLoss(weight=class_weights)

    def forward(self, pred, target):
        terms = {
            'gdl': self.gdl(pred, target),
            'focal': self.focal(pred, target),
            'ce': self.ce(pred, target),
        }
        total = (self.gdl_weight * terms['gdl']
                 + self.focal_weight * terms['focal']
                 + self.ce_weight * terms['ce'])
        return total, {name: value.item() for name, value in terms.items()}

# =============== Attention Modules ===============
class ChannelAttention(nn.Module):
    """CBAM-style channel attention: ``x * sigmoid(fc(avgpool(x)) + fc(maxpool(x)))``.

    ``fc`` is a shared 1x1x1-conv bottleneck MLP. ``reduction`` is halved
    until it divides ``channels`` (or reaches 1), so the hidden width
    ``channels // reduction`` is exact.
    """
    def __init__(self, channels, reduction=16):
        super().__init__()
        r = reduction
        while r > 1 and channels % r:
            r //= 2
        hidden = channels // r

        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        self.max_pool = nn.AdaptiveMaxPool3d(1)

        self.fc = nn.Sequential(
            nn.Conv3d(channels, hidden, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv3d(hidden, channels, 1, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        gate = self.sigmoid(self.fc(self.avg_pool(x)) + self.fc(self.max_pool(x)))
        return x * gate

class SpatialAttention(nn.Module):
    """CBAM-style spatial attention.

    Computes ``x * sigmoid(conv7([mean_c(x), max_c(x)]))``, where the
    channel-wise mean and max maps are stacked and passed through a 7x7x7
    conv (padding 3) to produce one gate per voxel.
    """
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv3d(2, 1, kernel_size=7, padding=3, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        mean_map = x.mean(dim=1, keepdim=True)
        max_map = x.max(dim=1, keepdim=True).values
        gate = self.sigmoid(self.conv(torch.cat([mean_map, max_map], dim=1)))
        return x * gate

# =============== OPTIMIZATION 4: Smaller Model Architecture ===============
# Base filters reduced from 40 to 28 (compromise between 24-32 range from research)
class EnhancedUNet(nn.Module):
    """Optimized 3D U-Net with attention - SMALLER model (28 base filters instead of 40)

    Encoder widths are base_filters x (1, 2, 4, 8) with a 16x bottleneck, and
    the decoder mirrors them through skip connections. Every encoder block
    (including the bottleneck) ends with ChannelAttention then
    SpatialAttention; decoder blocks have no attention.

    Input: (N, in_channels, D, H, W). There are four 2x max-pools, and each
    decoder level concatenates a 2x-upsampled tensor with the matching encoder
    output, so D, H and W must be divisible by 16 for the shapes to line up.

    Returns logits with ``num_classes`` channels:
      * training mode: list ``[ds4, ds3, ds2, final]``, all at the input
        resolution (the deep-supervision heads are trilinearly upsampled);
      * eval mode: only the final logits tensor.

    NOTE(review): ``forward`` does not use ``torch.utils.checkpoint``, so any
    gradient-checkpointing flag has no effect on this model. Submodule
    creation order defines state_dict / parameter order; keep it stable for
    checkpoint compatibility.
    """
    def __init__(self, in_channels, num_classes, base_filters=28):
        super().__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes

        # Encoder with reduced filters
        self.enc1 = self._encoder_block(in_channels, base_filters)
        self.enc2 = self._encoder_block(base_filters, base_filters * 2)
        self.enc3 = self._encoder_block(base_filters * 2, base_filters * 4)
        self.enc4 = self._encoder_block(base_filters * 4, base_filters * 8)

        # Bottleneck
        self.bottleneck = self._encoder_block(base_filters * 8, base_filters * 16)

        # Decoder (input channels = upsampled + skip connection)
        self.dec4 = self._decoder_block(base_filters * 16 + base_filters * 8, base_filters * 8)
        self.dec3 = self._decoder_block(base_filters * 8 + base_filters * 4, base_filters * 4)
        self.dec2 = self._decoder_block(base_filters * 4 + base_filters * 2, base_filters * 2)
        self.dec1 = self._decoder_block(base_filters * 2 + base_filters, base_filters)

        # Deep supervision outputs: 1x1x1 conv logit heads on decoder levels 4, 3, 2
        self.deep_sup4 = nn.Conv3d(base_filters * 8, num_classes, 1)
        self.deep_sup3 = nn.Conv3d(base_filters * 4, num_classes, 1)
        self.deep_sup2 = nn.Conv3d(base_filters * 2, num_classes, 1)

        # Final output
        self.out_conv = nn.Conv3d(base_filters, num_classes, 1)

        # Shared, parameter-free 2x down/up-sampling used at every level
        self.pool = nn.MaxPool3d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False)

    def _encoder_block(self, in_ch, out_ch):
        """Two 3x3x3 Conv-GroupNorm-ReLU stages, then channel and spatial attention.

        Dropout3d(DROPOUT_RATE) sits between the stages when the module-level
        USE_DROPOUT is true at construction time (otherwise nn.Identity).
        """
        return nn.Sequential(
            nn.Conv3d(in_ch, out_ch, 3, padding=1, bias=False),
            nn.GroupNorm(get_num_groups(out_ch), out_ch),
            nn.ReLU(inplace=True),
            nn.Dropout3d(DROPOUT_RATE) if USE_DROPOUT else nn.Identity(),
            nn.Conv3d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.GroupNorm(get_num_groups(out_ch), out_ch),
            nn.ReLU(inplace=True),
            ChannelAttention(out_ch),
            SpatialAttention()
        )

    def _decoder_block(self, in_ch, out_ch):
        """Two 3x3x3 Conv-GroupNorm-ReLU stages (optional Dropout3d between), no attention."""
        return nn.Sequential(
            nn.Conv3d(in_ch, out_ch, 3, padding=1, bias=False),
            nn.GroupNorm(get_num_groups(out_ch), out_ch),
            nn.ReLU(inplace=True),
            nn.Dropout3d(DROPOUT_RATE) if USE_DROPOUT else nn.Identity(),
            nn.Conv3d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.GroupNorm(get_num_groups(out_ch), out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        """Return deep-supervision logits (training) or final logits (eval)."""
        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))

        # Bottleneck
        b = self.bottleneck(self.pool(e4))

        # Decoder with skip connections (channel-wise concat after 2x upsampling)
        d4 = self.dec4(torch.cat([self.up(b), e4], dim=1))
        d3 = self.dec3(torch.cat([self.up(d4), e3], dim=1))
        d2 = self.dec2(torch.cat([self.up(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up(d2), e1], dim=1))

        # Deep supervision: auxiliary heads upsampled to the input's spatial size
        if self.training:
            ds4 = F.interpolate(self.deep_sup4(d4), x.shape[2:], mode='trilinear', align_corners=False)
            ds3 = F.interpolate(self.deep_sup3(d3), x.shape[2:], mode='trilinear', align_corners=False)
            ds2 = F.interpolate(self.deep_sup2(d2), x.shape[2:], mode='trilinear', align_corners=False)
            return [ds4, ds3, ds2, self.out_conv(d1)]
        else:
            return self.out_conv(d1)

# =============== OPTIMIZATION 5: Enhanced Dataset with Aggressive Caching ===============
class OptimizedDataset(Dataset):
    """Patch dataset over preprocessed ``.pt`` cases with an in-RAM cache.

    Each file is a dict with ``'image'`` (indexed as C, D, H, W) and
    ``'label'`` (D, H, W) tensors. Files are cached in RAM, in order, until
    ``max_cache_gb`` would be exceeded; the remaining files are read from
    disk on every access.

    ``len(dataset) == len(file_paths) * patches_per_scan``. Item ``idx`` is one
    ``patch_size`` patch from scan ``idx // patches_per_scan``:
      * training: centred on a random tumour voxel with probability
        ``tumor_prob`` (if the scan has tumour), otherwise at a uniform random
        centre. Windows are shifted to stay fully inside the volume instead
        of being truncated and then zero-padded;
      * validation: a centre crop.
    Augmentation is applied only after ``update_transforms`` has been called
    on a training dataset.
    """
    def __init__(self, file_paths, patch_size, patches_per_scan=1,
                 is_training=True, tumor_prob=0.5, max_cache_gb=35):
        self.file_paths = file_paths
        self.patch_size = patch_size
        self.patches_per_scan = patches_per_scan
        self.is_training = is_training
        self.tumor_prob = tumor_prob
        self.max_cache_gb = max_cache_gb

        # Scans are cached eagerly at construction time.
        self.cache = {}
        self.cache_indices = []
        self._build_cache()

        self.transforms = None
        self.epoch = 0

    def _build_cache(self):
        """Load scans into ``self.cache`` in order until the byte budget would be exceeded.

        Caching stops at the first scan that does not fit; later, smaller
        scans are not tried. Unreadable files are reported and skipped.
        """
        print(f"Building cache (max {self.max_cache_gb}GB)...")
        cache_size_bytes = 0
        max_cache_bytes = self.max_cache_gb * 1024**3

        for idx, path in enumerate(tqdm(self.file_paths, desc="Caching data")):
            try:
                data = torch.load(path, map_location='cpu')
                img = data['image'].float()
                lbl = data['label'].long()
            except Exception as exc:
                print(f"WARNING: could not cache {path}: {exc}")
                continue

            item_size = img.element_size() * img.nelement() + lbl.element_size() * lbl.nelement()
            if cache_size_bytes + item_size >= max_cache_bytes:
                break
            self.cache[idx] = {'image': img, 'label': lbl}
            self.cache_indices.append(idx)
            cache_size_bytes += item_size

        cache_size_gb = cache_size_bytes / 1024**3
        print(f"Cached {len(self.cache_indices)}/{len(self.file_paths)} samples ({cache_size_gb:.2f}GB)")

        # Set lookup keeps this O(n) instead of O(n^2) list membership tests.
        cached = set(self.cache_indices)
        self.non_cached_indices = [i for i in range(len(self.file_paths)) if i not in cached]

    def update_transforms(self, epoch):
        """Rebuild the TorchIO augmentation pipeline for ``epoch``.

        Elastic-deformation displacement shrinks as training progresses:
        3.0 before epoch 50, 2.0 before 100, 1.0 afterwards. Non-training
        datasets keep ``self.transforms`` unchanged (None unless set manually).
        """
        self.epoch = epoch

        # Stronger augmentation in early epochs
        if epoch < 50:
            aug_strength = 0.3
        elif epoch < 100:
            aug_strength = 0.2
        else:
            aug_strength = 0.1

        if self.is_training:
            self.transforms = tio.Compose([
                tio.RandomFlip(axes=(0, 1, 2), p=0.5),
                tio.RandomAffine(
                    scales=(0.9, 1.1),
                    degrees=15,
                    translation=10,
                    p=0.5
                ),
                tio.RandomElasticDeformation(
                    num_control_points=7,
                    max_displacement=aug_strength * 10,
                    p=0.3
                ),
                tio.RandomGamma(log_gamma=(-0.3, 0.3), p=0.3),
            ])

    def __len__(self):
        return len(self.file_paths) * self.patches_per_scan

    def _load_scan(self, scan_idx):
        """Return ``(img, lbl)`` for ``scan_idx``; unreadable files fall through to the next scan.

        Cached tensors are cloned so in-place normalization cannot corrupt the
        cache.

        Raises:
            RuntimeError: if no scan can be read. The old code recursed forever
            in that case.
        """
        n_scans = len(self.file_paths)
        for offset in range(n_scans):
            i = (scan_idx + offset) % n_scans
            if i in self.cache:
                data = self.cache[i]
                return data['image'].clone(), data['label'].clone()
            try:
                data = torch.load(self.file_paths[i], map_location='cpu')
                return data['image'].float(), data['label'].long()
            except Exception as exc:
                print(f"WARNING: could not load {self.file_paths[i]}: {exc}")
        raise RuntimeError("OptimizedDataset: none of the scan files could be loaded")

    def _training_window(self, lbl, spatial_shape):
        """Start/end indices of a random training patch kept inside the volume."""
        # Tumor-focused sampling
        if random.random() < self.tumor_prob and (lbl > 0).any():
            tumor_coords = torch.where(lbl > 0)
            pick = random.randint(0, len(tumor_coords[0]) - 1)
            center = [tumor_coords[i][pick].item() for i in range(3)]
        else:
            center = [random.randint(ps // 2, s - ps // 2)
                      for ps, s in zip(self.patch_size, spatial_shape)]

        # Shift windows that would cross the upper edge back inside the volume
        # (the volume is already >= patch_size after pad_to_patch), rather
        # than truncating them and zero-padding afterwards.
        starts = [min(max(0, c - ps // 2), max(0, s - ps))
                  for c, ps, s in zip(center, self.patch_size, spatial_shape)]
        ends = [min(st + ps, s) for st, ps, s in zip(starts, self.patch_size, spatial_shape)]
        return starts, ends

    def __getitem__(self, idx):
        scan_idx = idx // self.patches_per_scan
        img, lbl = self._load_scan(scan_idx)

        # Pad if needed, then normalize
        img, lbl = pad_to_patch(img, lbl, self.patch_size)
        img = nnunet_normalization(img)

        # Extract patch
        if self.is_training:
            starts, ends = self._training_window(lbl, img.shape[1:])
        else:
            # Center crop for validation
            starts = [(s - ps) // 2 for s, ps in zip(img.shape[1:], self.patch_size)]
            ends = [min(s + ps, img.shape[i+1]) for i, (s, ps) in enumerate(zip(starts, self.patch_size))]
        img = img[:, starts[0]:ends[0], starts[1]:ends[1], starts[2]:ends[2]]
        lbl = lbl[starts[0]:ends[0], starts[1]:ends[1], starts[2]:ends[2]]

        # Safety net: guarantee the patch is exactly patch_size.
        pad_needed = []
        for curr, target in zip(img.shape[1:], self.patch_size):
            if curr < target:
                total_pad = target - curr
                pad_before = total_pad // 2
                pad_after = total_pad - pad_before
                pad_needed.extend([pad_before, pad_after])
            else:
                pad_needed.extend([0, 0])

        if any(p > 0 for p in pad_needed):
            # Pad format for F.pad is (W_left, W_right, H_left, H_right, D_left, D_right)
            pad_format = [pad_needed[4], pad_needed[5], pad_needed[2], pad_needed[3], pad_needed[0], pad_needed[1]]
            img = F.pad(img, pad_format, mode='constant', value=0)
            lbl = F.pad(lbl.unsqueeze(0), pad_format, mode='constant', value=0).squeeze(0)

        # Apply augmentation
        if self.transforms is not None:
            subject = tio.Subject(
                image=tio.ScalarImage(tensor=img),  # Already (C, D, H, W) - 4D
                label=tio.LabelMap(tensor=lbl.unsqueeze(0))  # Make it (1, D, H, W) - 4D
            )
            subject = self.transforms(subject)
            img = subject.image.tensor  # Keep as (C, D, H, W)
            lbl = subject.label.tensor.squeeze(0).long()  # Back to (D, H, W)

        return img, lbl

# =============== Metrics ===============
def calculate_metrics(pred, target, num_classes, empty_score=1.0):
    """Mean foreground Dice of a batch, one score per class 1..num_classes-1.

    Args:
        pred: per-class scores with the class axis at dim 1 (logits or
            probabilities; only the argmax over dim 1 is used).
        target: integer label map shaped like ``pred`` without the class axis.
        num_classes: number of classes including background (class 0).
        empty_score: Dice recorded for a class that is absent from both the
            prediction and the target. Scoring such classes 0.0 penalises a
            correct "nothing here" prediction; 1.0 is the common convention.

    Returns:
        {'dice': mean of the per-class scores}; ``empty_score`` when there are
        no foreground classes at all.
    """
    labels = torch.argmax(pred, dim=1)
    dice_scores = []

    for c in range(1, num_classes):
        pred_c = labels == c
        target_c = target == c
        denom = pred_c.sum().float() + target_c.sum().float()

        if denom > 0:
            intersection = (pred_c & target_c).sum().float()
            dice_scores.append((2.0 * intersection / denom).item())
        else:
            # Class absent in both prediction and ground truth: correct prediction.
            dice_scores.append(float(empty_score))

    if not dice_scores:
        return {'dice': float(empty_score)}
    return {'dice': float(np.mean(dice_scores))}

# =============== Training ===============
def train_epoch(model, train_loader, optimizer, loss_fn, scheduler, device, scaler, epoch, accumulation_steps):
    """Run one training epoch with AMP, gradient accumulation and gradient clipping.

    The optimizer (and the LR scheduler) step once every ``accumulation_steps``
    batches. A trailing partial group at the end of the epoch is flushed, so
    those batches are not silently discarded. Batches with a non-finite loss
    are skipped, but never block the step for the other batches in the group.

    A list of outputs (deep supervision) is combined with fixed weights
    [0.5, 0.75, 1.0, 1.0]. Logged loss components come from the last output.

    Returns:
        (mean_loss, {'gdl', 'focal', 'ce'}), averaged over batches with a finite loss.
    """
    model.train()
    total_loss = 0
    gdl_sum = 0
    focal_sum = 0
    ce_sum = 0
    num_batches = 0
    pending = 0  # batches back-propagated since the last optimizer step

    def _optimizer_step():
        # Unscale first so clipping sees the true (unscaled) gradient norm.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), GRADIENT_CLIP_VAL)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
        # The LR schedule advances once per optimizer step, not once per batch.
        scheduler.step()

    optimizer.zero_grad()

    pbar = tqdm(train_loader, desc=f"Epoch {epoch}")
    for batch_idx, (inputs, targets) in enumerate(pbar):
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        if USE_CHANNELS_LAST:
            inputs = inputs.to(memory_format=torch.channels_last_3d)

        with autocast(device_type=device.type):
            outputs = model(inputs)

            if isinstance(outputs, list):
                weights = [0.5, 0.75, 1.0, 1.0]
                loss = 0
                components = None
                for w, o in zip(weights, outputs):
                    l, comp = loss_fn(o, targets)
                    loss += w * l
                    if o is outputs[-1]:
                        # Reuse instead of running loss_fn on outputs[-1] a second time.
                        components = comp
                loss = loss / sum(weights)
                if components is None:
                    _, components = loss_fn(outputs[-1], targets)
            else:
                loss, components = loss_fn(outputs, targets)

            loss = loss / accumulation_steps

        finite = bool(torch.isfinite(loss))
        if finite:
            scaler.scale(loss).backward()
            pending += 1

        # Step at every group boundary as long as something was back-propagated,
        # even if this particular batch's loss was non-finite.
        if pending and (batch_idx + 1) % accumulation_steps == 0:
            _optimizer_step()
            pending = 0

        if finite:
            total_loss += loss.item() * accumulation_steps
            gdl_sum += components['gdl']
            focal_sum += components['focal']
            ce_sum += components['ce']
            num_batches += 1

            pbar.set_postfix({
                'loss': f'{loss.item() * accumulation_steps:.4f}',
                'gdl': f'{components["gdl"]:.3f}',
                'lr': f'{scheduler.get_last_lr()[0]:.6f}'
            })

        if batch_idx % 50 == 0:
            memory_cleanup()

    # Flush a trailing partial accumulation group; otherwise the zero_grad() at the
    # start of the next epoch throws those gradients away.
    if pending:
        _optimizer_step()

    return total_loss / max(num_batches, 1), {
        'gdl': gdl_sum / max(num_batches, 1),
        'focal': focal_sum / max(num_batches, 1),
        'ce': ce_sum / max(num_batches, 1)
    }

@torch.no_grad()
def validate(model, val_loader, loss_fn, device, num_classes):
    """Evaluate the model: mean per-batch loss and mean per-batch foreground Dice.

    For models that return a list of outputs (deep supervision), only the last
    output is scored.
    """
    model.eval()
    batch_losses = []
    batch_dice = []

    for inputs, targets in tqdm(val_loader, desc="Validation"):
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        if USE_CHANNELS_LAST:
            inputs = inputs.to(memory_format=torch.channels_last_3d)

        with autocast(device_type=device.type):
            logits = model(inputs)
            logits = logits[-1] if isinstance(logits, list) else logits
            loss, _ = loss_fn(logits, targets)
            probs = F.softmax(logits, dim=1)

        batch_losses.append(loss.item())
        batch_dice.append(calculate_metrics(probs, targets, num_classes)['dice'])

    return {
        'loss': sum(batch_losses) / max(len(batch_losses), 1),
        'dice': np.mean(batch_dice)
    }

# =============== Main ===============
def main():
    """Train the v1 EnhancedUNet and keep the checkpoint with the best validation Dice.

    Steps:
    1. List DATA_DIR/*.pt and detect the class count.
    2. Split the files 80/20 with SEED and estimate class weights.
    3. Build the cached train/val datasets.
    4. Train with AdamW + OneCycleLR + AMP, validating every epoch.
    5. Early-stop after EARLY_STOPPING_PATIENCE epochs without a Dice improvement.

    The best checkpoint is written to MODEL_SAVE_PATH.
    """
    import math  # local import: only needed to size the LR schedule

    print("=" * 80)
    print(f"PyTorch Version: {torch.__version__}")
    print(f"Device: {DEVICE}")
    if torch.cuda.is_available():
        print(f"CUDA Device: {torch.cuda.get_device_name(0)}")
        print(f"CUDA Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f}GB")

    all_files = sorted(glob.glob(os.path.join(DATA_DIR, "*.pt")))
    print(f"Found {len(all_files)} files")

    global NUM_CLASSES
    NUM_CLASSES = detect_num_classes(all_files)

    train_files, val_files = train_test_split(all_files, test_size=0.2, random_state=SEED)
    print(f"Train: {len(train_files)}, Val: {len(val_files)}")

    # Compute class weights
    class_weights = compute_class_weights(train_files, NUM_CLASSES, sample_size=50)
    print(f"\nClass Weights: {class_weights}")

    # OPTIMIZATION: Aggressive caching with larger limits
    train_dataset = OptimizedDataset(
        train_files, PATCH_SIZE, PATCHES_PER_SCAN,
        is_training=True, tumor_prob=TUMOR_SAMPLING_PROB,
        max_cache_gb=MAX_CACHE_SIZE_GB
    )

    # OPTIMIZATION: Cache entire validation set
    val_dataset = OptimizedDataset(
        val_files, PATCH_SIZE, patches_per_scan=1,
        is_training=False, max_cache_gb=VAL_CACHE_SIZE_GB
    )

    train_loader = DataLoader(
        train_dataset, batch_size=BATCH_SIZE, shuffle=True,
        num_workers=NUM_WORKERS, pin_memory=False,
        persistent_workers=False
    )

    val_loader = DataLoader(
        val_dataset, batch_size=BATCH_SIZE, shuffle=False,
        num_workers=NUM_WORKERS, pin_memory=False
    )

    print(f"Training batches per epoch: {len(train_loader)}")
    print(f"Effective batch size: {BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS}")

    # base_filters=28 (was 40)
    model = EnhancedUNet(IN_CHANNELS, NUM_CLASSES, base_filters=28).to(DEVICE)

    if USE_CHANNELS_LAST:
        model = model.to(memory_format=torch.channels_last_3d)

    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model parameters: {total_params/1e6:.2f}M")

    optimizer = optim.AdamW(model.parameters(), lr=INITIAL_LR, weight_decay=WEIGHT_DECAY)

    # train_epoch calls scheduler.step() once per *optimizer* step, i.e. every
    # GRADIENT_ACCUMULATION_STEPS batches, not once per batch. Sizing the schedule
    # by batches stretched it by the accumulation factor: warmup was too slow and
    # the LR never annealed. ceil() is an upper bound on optimizer steps per epoch,
    # so the scheduler is never stepped past total_steps (which would raise).
    steps_per_epoch = max(1, math.ceil(len(train_loader) / GRADIENT_ACCUMULATION_STEPS))
    total_steps = steps_per_epoch * EPOCHS
    scheduler = OneCycleLR(
        optimizer,
        max_lr=INITIAL_LR,
        total_steps=total_steps,
        pct_start=0.3,  # 30% warmup
        anneal_strategy='cos',
        div_factor=25.0,  # initial_lr = max_lr/25
        final_div_factor=10000.0
    )

    # Combined loss
    loss_fn = CombinedLoss(
        NUM_CLASSES,
        class_weights=class_weights.to(DEVICE),
        gdl_weight=2.0,
        focal_weight=1.0,
        ce_weight=0.5
    ).to(DEVICE)

    scaler = GradScaler(init_scale=1024, growth_interval=2000)

    best_dice = 0.0
    patience = 0

    print(f"Starting training for {EPOCHS} epochs...")

    for epoch in range(1, EPOCHS + 1):
        train_dataset.update_transforms(epoch)

        train_loss, loss_components = train_epoch(
            model, train_loader, optimizer, loss_fn,
            scheduler, DEVICE, scaler, epoch, GRADIENT_ACCUMULATION_STEPS
        )

        # Validate every epoch
        val_results = validate(model, val_loader, loss_fn, DEVICE, NUM_CLASSES)

        print(f"\nEpoch {epoch:3d}/{EPOCHS}")
        print(f"Train - Loss: {train_loss:.4f} | GDL: {loss_components['gdl']:.4f}")
        print(f"Val   - Loss: {val_results['loss']:.4f} | Dice: {val_results['dice']:.4f}")
        print(f"LR: {scheduler.get_last_lr()[0]:.6f}")

        if val_results['dice'] > best_dice:
            best_dice = val_results['dice']
            patience = 0

            checkpoint = {
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'epoch': epoch,
                'best_dice': best_dice,
                'num_classes': NUM_CLASSES,
                'class_weights': class_weights
            }
            torch.save(checkpoint, MODEL_SAVE_PATH)
            print(f"✓ New best! Dice: {best_dice:.4f}")

            if best_dice >= 0.75:
                print(f"Target 0.75 achieved!")
        else:
            patience += 1
            print(f"Patience: {patience}/{EARLY_STOPPING_PATIENCE}")

        if patience >= EARLY_STOPPING_PATIENCE:
            print(f"Early stopping at epoch {epoch}")
            break

        if epoch % 10 == 0:
            memory_cleanup()

    print(f"\n{'='*80}")
    print(f"Training completed! Best Dice: {best_dice:.4f}")
    print(f"Model saved to: {MODEL_SAVE_PATH}")
    print(f"{'='*80}")

if __name__ == "__main__":
    # Executing this cell (where __name__ is "__main__") starts training immediately.
    main()

PyTorch Version: 2.5.1+cu121
Device: cuda
CUDA Device: NVIDIA L4
CUDA Memory: 22.2GB
Found 260 files
Detected 5 classes from labels: [0, 1, 2, 3, 4]
Train: 208, Val: 52
Computing class weights from 50 samples...


Analyzing class distribution: 100%|██████████| 50/50 [02:55<00:00,  3.52s/it]


Class distribution: [4.11637563e+08 2.74323000e+05 1.93015500e+06 7.19530000e+04
 1.96406000e+05]
Class weights: [ 1. 10. 10. 10. 10.]

Class Weights: tensor([ 1., 10., 10., 10., 10.])
Building cache (max 35GB)...


Caching data:  78%|███████▊  | 162/208 [06:16<01:47,  2.33s/it]


Cached 162/208 samples (34.97GB)
Building cache (max 12GB)...


Caching data: 100%|██████████| 52/52 [02:20<00:00,  2.71s/it]


Cached 52/52 samples (11.54GB)
Training batches per epoch: 104
Effective batch size: 12
Model parameters: 18.06M
Starting training for 250 epochs...


Epoch 1:   1%|          | 1/104 [00:58<1:40:59, 58.83s/it, loss=4.3511, gdl=0.985, lr=0.000040]


OutOfMemoryError: CUDA out of memory. Tried to allocate 896.00 MiB. GPU 0 has a total capacity of 22.16 GiB of which 493.38 MiB is free. Process 35714 has 21.67 GiB memory in use. Of the allocated memory 20.08 GiB is allocated by PyTorch, and 1.37 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

## Optimized - v2

In [None]:
import os, glob, time, random, gc, warnings
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.amp import autocast, GradScaler
from torch.utils.checkpoint import checkpoint
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from torch.optim.lr_scheduler import OneCycleLR
import torchio as tio
import psutil
from collections import OrderedDict

# NOTE(review): this silences *all* warnings (including torch/torchio deprecations);
# consider narrowing it to specific categories/modules.
warnings.filterwarnings("ignore")

# =============== OPTIMIZED Configuration ===============
DATA_DIR = "/content/drive/MyDrive/InputScans_Final"
MODEL_SAVE_PATH = "/content/drive/MyDrive/best_optimized_model_v2.pt"

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)

# OPTIMIZATION 1: Minimal model + larger batch size
EPOCHS = 250
BATCH_SIZE = 8  # Increased from 4 (smaller model = more GPU room)
GRADIENT_ACCUMULATION_STEPS = 2  # Reduced from 3 -> effective batch = 8 * 2 = 16 (v1 was 4 * 3 = 12)
INITIAL_LR = 1e-3  # OPTIMIZATION 2: Higher LR (was 5e-4)
WEIGHT_DECAY = 1e-4
EARLY_STOPPING_PATIENCE = 50
PATCHES_PER_SCAN = 2
TUMOR_SAMPLING_PROB = 0.85
WARMUP_EPOCHS = 5  # NOTE(review): not referenced by the code in this cell; OneCycleLR's pct_start controls warmup

IN_CHANNELS = 5
NUM_CLASSES = None  # filled in by main() via detect_num_classes()
PATCH_SIZE = (128, 128, 128)  # LeanUNet needs each dimension divisible by 16
GRADIENT_CLIP_VAL = 12.0

# OPTIMIZATION 3: Aggressive RAM caching - utilize available 50GB
MAX_CACHE_SIZE_GB = 35  # Train cache: 35GB (was 10GB)
VAL_CACHE_SIZE_GB = 12  # Val cache: 12GB (was 5GB) - CACHE ENTIRE VALIDATION SET
NUM_WORKERS = 0
USE_CHANNELS_LAST = True
# NOTE(review): LeanUNet as written never uses torch.utils.checkpoint, so this flag has no effect.
USE_GRADIENT_CHECKPOINTING = True

# NOTE(review): main() builds LeanUNet without passing a dropout rate, so these two
# flags are not wired into the model; the per-block Dropout3d rate is 0.1 regardless.
USE_DROPOUT = True
DROPOUT_RATE = 0.1  # Light dropout for speed

# =============== Utilities ===============
def get_memory_usage():
    """Return current system-wide RAM utilisation in percent, as reported by psutil."""
    stats = psutil.virtual_memory()
    return stats.percent

def memory_cleanup():
    """Run the Python garbage collector, then release cached CUDA blocks when CUDA exists."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()

def get_num_groups(channels):
    """Largest GroupNorm group count from (32, 16, 8, 4, 2, 1) that divides ``channels``.

    The result is capped at ``channels``; 1 is the fallback.
    """
    return next(
        (min(groups, channels) for groups in (32, 16, 8, 4, 2, 1) if channels % groups == 0),
        1,
    )

def detect_num_classes(file_paths, check_n=20):
    """Infer the number of segmentation classes from the labels of the first files.

    Loads up to ``check_n`` files, collects the distinct values of their
    'label' tensors and returns ``max(label) + 1``. A non-contiguous label set
    such as {0, 1, 2, 4} therefore still gets a slot for every index. Falls
    back to 5 classes (max label 4) when no file could be read.

    Files that fail to load, or that lack a 'label' entry, are reported and
    skipped instead of being silently ignored.
    """
    observed = set()
    for path in file_paths[:check_n]:
        try:
            data = torch.load(path, map_location='cpu')
            observed.update(torch.unique(data['label']).tolist())
        except Exception as exc:  # keep going, but say why the file was skipped
            print(f"detect_num_classes: skipping {path}: {exc}")
    max_label = int(max(observed)) if observed else 4
    classes = max_label + 1
    print(f"Detected {classes} classes from labels: {sorted(observed)}")
    return classes

def compute_class_weights(file_paths, num_classes, sample_size=50):
    """Inverse-frequency class weights estimated from the first ``sample_size`` files.

    The raw weight is weight_c = total / (num_classes * count_c). It is then
    rescaled so background (class 0) has weight 1, and clipped to [0.1, 10].

    Files that cannot be read are reported and skipped. When no voxels could
    be counted at all, uniform weights (all ones) are returned instead of NaNs.

    Returns:
        float32 tensor of length ``num_classes``.
    """
    print(f"Computing class weights from {min(sample_size, len(file_paths))} samples...")
    class_counts = np.zeros(num_classes)

    for p in tqdm(file_paths[:sample_size], desc="Analyzing class distribution"):
        try:
            label = torch.load(p, map_location='cpu')['label']
            for c in range(num_classes):
                class_counts[c] += (label == c).sum().item()
        except Exception as exc:  # keep going, but say why the file was skipped
            print(f"compute_class_weights: skipping {p}: {exc}")

    total = class_counts.sum()
    if total == 0:
        # Nothing counted: total/(...) would be 0 and the normalisation 0/0 -> NaN.
        print("compute_class_weights: no voxels counted, using uniform weights")
        class_weights = np.ones(num_classes)
    else:
        class_weights = total / (num_classes * class_counts + 1e-8)
        class_weights = class_weights / class_weights[0]
        class_weights = np.clip(class_weights, 0.1, 10.0)

    print(f"Class distribution: {class_counts}")
    print(f"Class weights: {class_weights}")

    return torch.FloatTensor(class_weights)

def pad_to_patch(img, lbl, patch_size):
    """Zero-pad image and label so every spatial dim is at least ``patch_size``.

    img is indexed as (C, D, H, W); lbl may be (D, H, W) or already carry a
    leading channel dim. Padding is split evenly, with the extra voxel going
    after. Returns (img, lbl) with lbl's leading singleton dim squeezed away.
    """
    lbl4 = lbl.unsqueeze(0) if lbl.ndim == 3 else lbl

    per_axis = []
    for size, wanted in zip(img.shape[1:], patch_size):
        deficit = max(wanted - size, 0)
        per_axis.append((deficit // 2, deficit - deficit // 2))

    if any(before or after for before, after in per_axis):
        # F.pad takes the LAST dim first: (W_lo, W_hi, H_lo, H_hi, D_lo, D_hi).
        flat = [amount for pair in reversed(per_axis) for amount in pair]
        img = F.pad(img, flat, mode='constant', value=0)
        lbl4 = F.pad(lbl4, flat, mode='constant', value=0)

    return img, lbl4.squeeze(0)

def nnunet_normalization(image):
    """Per-channel percentile clipping + z-scoring of the foreground, in place.

    For each channel, the foreground is the set of voxels > 0. Only those
    voxels are processed:
    - They are clipped to their [0.5, 99.5] percentiles.
    - They are standardised with the mean/std of the (unclipped) foreground
      values. This step is skipped when std is not > 0.

    Background voxels are left untouched. Previously the clamp was applied to
    the whole channel, which lifted the zero background to the low percentile
    and left it there, unnormalised.

    Returns the same ``image`` tensor, indexed as (C, ...).
    """
    for c in range(image.shape[0]):
        modality = image[c]  # view: writes below go straight into ``image``
        mask = modality > 0
        if not mask.any():
            continue

        values = modality[mask]
        q = torch.tensor([0.005, 0.995], device=values.device, dtype=values.dtype)
        p1, p99 = torch.quantile(values, q)

        mean_val = values.mean()
        std_val = values.std()
        clipped = torch.clamp(values, p1, p99)
        if std_val > 0:
            clipped = (clipped - mean_val) / (std_val + 1e-8)
        modality[mask] = clipped
    return image

# =============== Losses ===============
class GeneralizedDiceLoss(nn.Module):
    """Weighted soft-Dice loss over all classes (background included).

    For each class k, with p = softmax(pred) and t = one_hot(target) summed
    over the whole batch:
        dice_k = (2 * sum(p_k * t_k) + smooth) / (sum(p_k) + sum(t_k) + smooth)
        loss   = 1 - sum_k(w_k * dice_k) / sum_k(w_k)

    w is ``class_weights`` when given. Otherwise w_k = 1 / (|t_k|^2 + smooth),
    i.e. generalized-Dice inverse squared volume.

    pred: logits with classes on dim 1 (permute below implies 5D, N,C,D,H,W);
    target: integer labels (N, D, H, W).
    """
    def __init__(self, num_classes, class_weights=None, smooth=1.0):
        super().__init__()
        self.num_classes = num_classes
        self.smooth = smooth
        self.class_weights = class_weights

    def _class_weights_for(self, one_hot, device):
        """Return the per-class weight vector on ``device``."""
        if self.class_weights is not None:
            return self.class_weights.to(device)
        inv_sq_volume = [1.0 / (one_hot[:, k].sum() ** 2 + self.smooth)
                         for k in range(self.num_classes)]
        return torch.stack(inv_sq_volume).to(device)

    def forward(self, pred, target):
        probs = F.softmax(pred, dim=1)
        one_hot = F.one_hot(target, self.num_classes).permute(0, 4, 1, 2, 3).float()
        weights = self._class_weights_for(one_hot, probs.device)

        weighted_dice = 0
        for k in range(self.num_classes):
            p_k = probs[:, k]
            t_k = one_hot[:, k]
            overlap = (p_k * t_k).sum()
            denom = p_k.sum() + t_k.sum()
            weighted_dice += weights[k] * ((2.0 * overlap + self.smooth) / (denom + self.smooth))

        return 1.0 - (weighted_dice / weights.sum())

class FocalLoss(nn.Module):
    """Multi-class focal loss: mean over voxels of alpha[y] * (1 - p_y)**gamma * -log(p_y).

    Args:
        alpha: optional per-class weight tensor, indexed by the target label.
        gamma: focusing exponent; larger values down-weight easy voxels more.

    The modulating factor must use the true target probability p_y. The
    previous version passed ``alpha`` into ``cross_entropy``, so ``exp(-ce)``
    became p_y**alpha[y]. That made confident, correct predictions on
    heavily weighted classes look "hard".
    """
    def __init__(self, alpha=None, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, pred, target):
        # Unweighted CE so that exp(-ce) is exactly p_y.
        ce = F.cross_entropy(pred, target, reduction='none')
        pt = torch.exp(-ce)
        focal = ((1 - pt) ** self.gamma) * ce
        if self.alpha is not None:
            # Per-voxel class weight applied outside the modulating factor.
            alpha = self.alpha.to(device=focal.device, dtype=focal.dtype)
            focal = focal * alpha[target]
        return focal.mean()

class CombinedLoss(nn.Module):
    """Weighted sum of class-weighted Dice, focal and cross-entropy losses.

    total = gdl_weight * Dice + focal_weight * Focal + ce_weight * CE

    ``forward`` returns (total_tensor, {'gdl': float, 'focal': float, 'ce': float}).
    The floats come from ``.item()`` and are for logging only.
    """
    def __init__(self, num_classes, class_weights,
                 gdl_weight=2.0, focal_weight=1.0, ce_weight=0.5):
        super().__init__()
        self.num_classes = num_classes
        self.gdl_weight, self.focal_weight, self.ce_weight = gdl_weight, focal_weight, ce_weight

        self.gdl = GeneralizedDiceLoss(num_classes, class_weights)
        self.focal = FocalLoss(alpha=class_weights, gamma=2.0)
        self.ce = nn.CrossEntropyLoss(weight=class_weights)

    def forward(self, pred, target):
        terms = {
            'gdl': self.gdl(pred, target),
            'focal': self.focal(pred, target),
            'ce': self.ce(pred, target),
        }
        total = (self.gdl_weight * terms['gdl']
                 + self.focal_weight * terms['focal']
                 + self.ce_weight * terms['ce'])
        return total, {name: value.item() for name, value in terms.items()}

# =============== TRULY OPTIMIZED: Minimal, Fast 3D U-Net ===============
# NO attention, NO deep supervision, base_filters=16
# Target: ~3-5M parameters, <10s/batch, <12GB GPU
class LeanUNet(nn.Module):
    """Minimal 3D U-Net: 4 pooling levels, one Conv-InstanceNorm-ReLU-Dropout block per level.

    Args:
        in_channels: number of input channels.
        num_classes: number of output channels (raw logits, no softmax).
        base_filters: channels at the first level; doubled at every level down.
        dropout_rate: Dropout3d probability in every block. It defaults to 0.1,
            the value previously hard-coded, so a DROPOUT_RATE setting can now
            be passed through.

    Every spatial dimension of the input must be a multiple of 16 (four 2x
    poolings). Otherwise the upsampled decoder features cannot be
    concatenated with the skip connections.
    """
    def __init__(self, in_channels, num_classes, base_filters=16, dropout_rate=0.1):
        super().__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.dropout_rate = dropout_rate  # read by _make_layer, so set it first
        f = base_filters

        # Encoder - single conv per block for speed
        self.enc1 = self._make_layer(in_channels, f)
        self.enc2 = self._make_layer(f, f * 2)
        self.enc3 = self._make_layer(f * 2, f * 4)
        self.enc4 = self._make_layer(f * 4, f * 8)

        # Bottleneck
        self.bottleneck = self._make_layer(f * 8, f * 16)

        # Decoder - input is the upsampled features concatenated with the skip
        self.dec4 = self._make_layer(f * 16 + f * 8, f * 8)
        self.dec3 = self._make_layer(f * 8 + f * 4, f * 4)
        self.dec2 = self._make_layer(f * 4 + f * 2, f * 2)
        self.dec1 = self._make_layer(f * 2 + f, f)

        # Single 1x1x1 output head (no deep supervision)
        self.out_conv = nn.Conv3d(f, num_classes, 1)

        self.pool = nn.MaxPool3d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False)

    def _make_layer(self, in_ch, out_ch):
        """Conv3d(3x3x3, no bias) -> InstanceNorm3d -> ReLU -> Dropout3d(self.dropout_rate)."""
        return nn.Sequential(
            nn.Conv3d(in_ch, out_ch, 3, padding=1, bias=False),
            nn.InstanceNorm3d(out_ch),
            nn.ReLU(inplace=True),
            nn.Dropout3d(self.dropout_rate),
        )

    def forward(self, x):
        spatial = tuple(x.shape[2:])
        if any(s % 16 for s in spatial):
            raise ValueError(f"LeanUNet needs spatial sizes divisible by 16, got {spatial}")

        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))

        # Bottleneck
        b = self.bottleneck(self.pool(e4))

        # Decoder with skip connections
        d4 = self.dec4(torch.cat([self.up(b), e4], dim=1))
        d3 = self.dec3(torch.cat([self.up(d4), e3], dim=1))
        d2 = self.dec2(torch.cat([self.up(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up(d2), e1], dim=1))

        return self.out_conv(d1)

# =============== OPTIMIZATION 5: Enhanced Dataset with Aggressive Caching ===============
class OptimizedDataset(Dataset):
    """3D patch dataset over preprocessed ``.pt`` scans, with an in-RAM cache.

    Each file is read with ``torch.load`` and must provide 'image' and 'label'
    entries. The code indexes image as (C, D, H, W) and label as (D, H, W);
    this is assumed from the slicing below, so confirm it against the
    preprocessing cell.

    Scans are padded to at least ``patch_size`` and normalized once, when they
    are loaded. Up to ``max_cache_gb`` of the resulting volumes are kept in
    RAM, with labels stored as uint8 when their values fit, so more scans
    fit. Non-cached scans are loaded, padded and normalized on every access.

    Each scan yields ``patches_per_scan`` items:
    - Training items are random patches, centered on a random foreground
      voxel with probability ``tumor_prob``.
    - Validation items are a deterministic center crop.

    Patch windows are shifted inward at volume borders, so they always hold
    real data of exactly ``patch_size``.
    """
    def __init__(self, file_paths, patch_size, patches_per_scan=1,
                 is_training=True, tumor_prob=0.5, max_cache_gb=35):
        self.file_paths = file_paths
        self.patch_size = patch_size
        self.patches_per_scan = patches_per_scan
        self.is_training = is_training
        self.tumor_prob = tumor_prob
        self.max_cache_gb = max_cache_gb

        # scan index -> {'image': normalized float volume, 'label': label volume}
        self.cache = {}
        self.cache_indices = []
        self._build_cache()

        self.transforms = None
        self.epoch = 0

    # ------------------------------------------------------------------ loading
    def _load_volume(self, scan_idx):
        """Load one scan from disk, pad it to >= patch_size and normalize it."""
        data = torch.load(self.file_paths[scan_idx], map_location='cpu')
        img = data['image'].float()
        lbl = data['label'].long()
        img, lbl = pad_to_patch(img, lbl, self.patch_size)
        return nnunet_normalization(img), lbl

    @staticmethod
    def _compact_label(lbl):
        """Return ``lbl`` as uint8 when all values are in 0..255 (8x less RAM than int64)."""
        if lbl.numel() > 0 and int(lbl.min()) >= 0 and int(lbl.max()) <= 255:
            return lbl.to(torch.uint8)
        return lbl

    def _build_cache(self):
        """Cache preprocessed scans in file order until the RAM budget is reached."""
        print(f"Building cache (max {self.max_cache_gb}GB)...")
        cache_size_bytes = 0
        max_cache_bytes = self.max_cache_gb * 1024**3

        for idx, path in enumerate(tqdm(self.file_paths, desc="Caching data")):
            try:
                img, lbl = self._load_volume(idx)
            except Exception as exc:  # skip unreadable scans, but say why
                print(f"Skipping {path} while caching: {exc}")
                continue
            lbl = self._compact_label(lbl)

            item_size = img.element_size() * img.nelement() + lbl.element_size() * lbl.nelement()
            if cache_size_bytes + item_size >= max_cache_bytes:
                break  # stop at the first scan that does not fit in the budget
            self.cache[idx] = {'image': img, 'label': lbl}
            self.cache_indices.append(idx)
            cache_size_bytes += item_size

        cache_size_gb = cache_size_bytes / 1024**3
        print(f"Cached {len(self.cache_indices)}/{len(self.file_paths)} samples ({cache_size_gb:.2f}GB)")

        # Scans that will be read from disk on every access.
        self.non_cached_indices = [i for i in range(len(self.file_paths)) if i not in self.cache]

    def _fetch_volume(self, idx):
        """Return the padded, normalized (img, lbl) volume for item ``idx``.

        Cached volumes are returned without copying, so callers must copy
        before mutating. A scan that fails to load is replaced by the next
        item's scan. After trying every item, a RuntimeError is raised instead
        of recursing forever.
        """
        n_items = len(self)
        last_exc = None
        for offset in range(n_items):
            scan_idx = ((idx + offset) % n_items) // self.patches_per_scan
            entry = self.cache.get(scan_idx)
            if entry is not None:
                return entry['image'], entry['label']
            try:
                return self._load_volume(scan_idx)
            except Exception as exc:
                last_exc = exc
        raise RuntimeError(f"OptimizedDataset: no scan could be loaded (last error: {last_exc})")

    # ------------------------------------------------------------- augmentation
    def update_transforms(self, epoch):
        """Rebuild the training augmentation pipeline for ``epoch``.

        Elastic displacement shrinks as training progresses: 3 for epochs
        below 50, 2 below 100, and 1 afterwards. Other transforms are fixed.
        Non-training datasets keep ``transforms`` as None.
        """
        self.epoch = epoch

        # Stronger augmentation in early epochs
        if epoch < 50:
            aug_strength = 0.3
        elif epoch < 100:
            aug_strength = 0.2
        else:
            aug_strength = 0.1

        if self.is_training:
            self.transforms = tio.Compose([
                tio.RandomFlip(axes=(0, 1, 2), p=0.5),
                tio.RandomAffine(
                    scales=(0.9, 1.1),
                    degrees=15,
                    translation=10,
                    p=0.5
                ),
                tio.RandomElasticDeformation(
                    num_control_points=7,
                    max_displacement=aug_strength * 10,
                    p=0.3
                ),
                tio.RandomGamma(log_gamma=(-0.3, 0.3), p=0.3),
            ])

    # ----------------------------------------------------------------- sampling
    def __len__(self):
        return len(self.file_paths) * self.patches_per_scan

    def _patch_starts(self, spatial_shape, lbl):
        """Per-axis start index of a patch that lies fully inside the volume."""
        if not self.is_training:
            # Deterministic center crop for validation.
            return [max(0, (s - ps) // 2) for s, ps in zip(spatial_shape, self.patch_size)]

        if random.random() < self.tumor_prob and (lbl > 0).any():
            # Tumor-focused sampling: center on a random foreground voxel.
            tumor_coords = torch.where(lbl > 0)
            pick = random.randint(0, len(tumor_coords[0]) - 1)
            center = [tumor_coords[axis][pick].item() for axis in range(3)]
        else:
            center = [random.randint(ps // 2, s - ps // 2)
                      for ps, s in zip(self.patch_size, spatial_shape)]

        # Shift the window inward at the borders instead of cropping short and
        # zero-padding; the chosen center voxel still lies inside the patch.
        return [min(max(0, c - ps // 2), max(0, s - ps))
                for c, ps, s in zip(center, self.patch_size, spatial_shape)]

    def __getitem__(self, idx):
        img, lbl = self._fetch_volume(idx)

        starts = self._patch_starts(img.shape[1:], lbl)
        window = tuple(slice(st, st + ps) for st, ps in zip(starts, self.patch_size))
        # Copy only the patch: the full volume may be a cache entry.
        img = img[(slice(None),) + window].clone()
        lbl = lbl[window].to(torch.long, copy=True)

        # Safety net; volumes were padded to >= patch_size, so normally a no-op.
        img, lbl = pad_to_patch(img, lbl, self.patch_size)

        # Apply augmentation
        if self.transforms is not None:
            subject = tio.Subject(
                image=tio.ScalarImage(tensor=img),  # (C, D, H, W)
                label=tio.LabelMap(tensor=lbl.unsqueeze(0))  # (1, D, H, W)
            )
            subject = self.transforms(subject)
            img = subject.image.tensor  # (C, D, H, W)
            lbl = subject.label.tensor.squeeze(0).long()  # back to (D, H, W)

        return img, lbl

# =============== Metrics ===============
def calculate_metrics(pred, target, num_classes, empty_score=1.0):
    """Mean foreground Dice of a batch, one score per class 1..num_classes-1.

    Args:
        pred: per-class scores with the class axis at dim 1 (logits or
            probabilities; only the argmax over dim 1 is used).
        target: integer label map shaped like ``pred`` without the class axis.
        num_classes: number of classes including background (class 0).
        empty_score: Dice recorded for a class that is absent from both the
            prediction and the target. Scoring such classes 0.0 penalises a
            correct "nothing here" prediction; 1.0 is the common convention.

    Returns:
        {'dice': mean of the per-class scores}; ``empty_score`` when there are
        no foreground classes at all.
    """
    labels = torch.argmax(pred, dim=1)
    dice_scores = []

    for c in range(1, num_classes):
        pred_c = labels == c
        target_c = target == c
        denom = pred_c.sum().float() + target_c.sum().float()

        if denom > 0:
            intersection = (pred_c & target_c).sum().float()
            dice_scores.append((2.0 * intersection / denom).item())
        else:
            # Class absent in both prediction and ground truth: correct prediction.
            dice_scores.append(float(empty_score))

    if not dice_scores:
        return {'dice': float(empty_score)}
    return {'dice': float(np.mean(dice_scores))}

# =============== Training ===============
def train_epoch(model, train_loader, optimizer, loss_fn, scheduler, device, scaler, epoch, accumulation_steps):
    """Run one training epoch with AMP, gradient accumulation and gradient clipping.

    The optimizer (and the LR scheduler) step once every ``accumulation_steps``
    batches. A trailing partial group at the end of the epoch is flushed, so
    those batches are not silently discarded. Batches with a non-finite loss
    are skipped, but never block the step for the other batches in the group.

    Returns:
        (mean_loss, {'gdl', 'focal', 'ce'}), averaged over batches with a finite loss.
    """
    model.train()
    total_loss = 0
    gdl_sum = 0
    focal_sum = 0
    ce_sum = 0
    num_batches = 0
    pending = 0  # batches back-propagated since the last optimizer step

    def _optimizer_step():
        # Unscale first so clipping sees the true (unscaled) gradient norm.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), GRADIENT_CLIP_VAL)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
        # The LR schedule advances once per optimizer step, not once per batch.
        scheduler.step()

    optimizer.zero_grad()

    pbar = tqdm(train_loader, desc=f"Epoch {epoch}")
    for batch_idx, (inputs, targets) in enumerate(pbar):
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        if USE_CHANNELS_LAST:
            inputs = inputs.to(memory_format=torch.channels_last_3d)

        with autocast(device_type=device.type):
            outputs = model(inputs)
            loss, components = loss_fn(outputs, targets)

            loss = loss / accumulation_steps

        finite = bool(torch.isfinite(loss))
        if finite:
            scaler.scale(loss).backward()
            pending += 1

        # Step at every group boundary as long as something was back-propagated,
        # even if this particular batch's loss was non-finite.
        if pending and (batch_idx + 1) % accumulation_steps == 0:
            _optimizer_step()
            pending = 0

        if finite:
            total_loss += loss.item() * accumulation_steps
            gdl_sum += components['gdl']
            focal_sum += components['focal']
            ce_sum += components['ce']
            num_batches += 1

            pbar.set_postfix({
                'loss': f'{loss.item() * accumulation_steps:.4f}',
                'gdl': f'{components["gdl"]:.3f}',
                'lr': f'{scheduler.get_last_lr()[0]:.6f}'
            })

        if batch_idx % 50 == 0:
            memory_cleanup()

    # Flush a trailing partial accumulation group; otherwise the zero_grad() at the
    # start of the next epoch throws those gradients away.
    if pending:
        _optimizer_step()

    return total_loss / max(num_batches, 1), {
        'gdl': gdl_sum / max(num_batches, 1),
        'focal': focal_sum / max(num_batches, 1),
        'ce': ce_sum / max(num_batches, 1)
    }

@torch.no_grad()
def validate(model, val_loader, loss_fn, device, num_classes):
    """Evaluate the model: mean per-batch loss and mean per-batch foreground Dice."""
    model.eval()
    batch_losses = []
    batch_dice = []

    for inputs, targets in tqdm(val_loader, desc="Validation"):
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        if USE_CHANNELS_LAST:
            inputs = inputs.to(memory_format=torch.channels_last_3d)

        with autocast(device_type=device.type):
            logits = model(inputs)
            loss, _ = loss_fn(logits, targets)
            probs = F.softmax(logits, dim=1)

        batch_losses.append(loss.item())
        batch_dice.append(calculate_metrics(probs, targets, num_classes)['dice'])

    return {
        'loss': sum(batch_losses) / max(len(batch_losses), 1),
        'dice': np.mean(batch_dice)
    }

# =============== Main ===============
def main():
    """Run the full training pipeline end to end.

    Discovers the preprocessed ``*.pt`` volumes in ``DATA_DIR``, detects the
    label count (stored in the module-level ``NUM_CLASSES``), splits 80/20 into
    train/val with ``SEED``, computes class weights, builds the RAM-cached
    datasets and loaders, then trains ``LeanUNet`` with AdamW, OneCycleLR and
    a GradScaler. The checkpoint with the best validation Dice is written to
    ``MODEL_SAVE_PATH``. Training stops early after ``EARLY_STOPPING_PATIENCE``
    consecutive epochs without improvement.

    Raises:
        FileNotFoundError: if ``DATA_DIR`` contains no ``*.pt`` files.
    """
    print("=" * 80)
    print(f"PyTorch Version: {torch.__version__}")
    print(f"Device: {DEVICE}")
    if torch.cuda.is_available():
        print(f"CUDA Device: {torch.cuda.get_device_name(0)}")
        print(f"CUDA Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f}GB")

    all_files = sorted(glob.glob(os.path.join(DATA_DIR, "*.pt")))
    print(f"Found {len(all_files)} files")
    # Fail fast with a clear message rather than an opaque error from
    # detect_num_classes / train_test_split on an empty list.
    if not all_files:
        raise FileNotFoundError(f"No .pt files found in {DATA_DIR}")

    global NUM_CLASSES
    NUM_CLASSES = detect_num_classes(all_files)

    train_files, val_files = train_test_split(all_files, test_size=0.2, random_state=SEED)
    print(f"Train: {len(train_files)}, Val: {len(val_files)}")

    # Compute class weights
    class_weights = compute_class_weights(train_files, NUM_CLASSES, sample_size=50)
    print(f"\nClass Weights: {class_weights}")

    # OPTIMIZATION: Aggressive caching with larger limits
    train_dataset = OptimizedDataset(
        train_files, PATCH_SIZE, PATCHES_PER_SCAN,
        is_training=True, tumor_prob=TUMOR_SAMPLING_PROB,
        max_cache_gb=MAX_CACHE_SIZE_GB
    )

    # Validation set is cached up to VAL_CACHE_SIZE_GB
    val_dataset = OptimizedDataset(
        val_files, PATCH_SIZE, patches_per_scan=1,
        is_training=False, max_cache_gb=VAL_CACHE_SIZE_GB
    )

    train_loader = DataLoader(
        train_dataset, batch_size=BATCH_SIZE, shuffle=True,
        num_workers=NUM_WORKERS, pin_memory=False,
        persistent_workers=False
    )

    val_loader = DataLoader(
        val_dataset, batch_size=BATCH_SIZE, shuffle=False,
        num_workers=NUM_WORKERS, pin_memory=False
    )

    effective_batch_size = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS
    print(f"Training batches per epoch: {len(train_loader)}")
    print(f"Effective batch size: {effective_batch_size}")

    # Minimal model; base_filters is shared with the banner below so they agree.
    base_filters = 16
    model = LeanUNet(IN_CHANNELS, NUM_CLASSES, base_filters=base_filters).to(DEVICE)

    if USE_CHANNELS_LAST:
        model = model.to(memory_format=torch.channels_last_3d)

    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model parameters: {total_params/1e6:.2f}M")

    optimizer = optim.AdamW(model.parameters(), lr=INITIAL_LR, weight_decay=WEIGHT_DECAY)

    # train_epoch calls scheduler.step() only once per optimizer update, i.e.
    # every GRADIENT_ACCUMULATION_STEPS batches; leftover batches at the end of
    # an epoch do not trigger a step. Sizing total_steps by *batches* stretched
    # the schedule by the accumulation factor, so the 30% warm-up consumed ~60%
    # of training and the cosine decay never finished. max(1, ...) keeps
    # OneCycleLR valid (it rejects total_steps <= 0) for very small loaders.
    steps_per_epoch = max(1, len(train_loader) // GRADIENT_ACCUMULATION_STEPS)
    total_steps = steps_per_epoch * EPOCHS
    scheduler = OneCycleLR(
        optimizer,
        max_lr=INITIAL_LR,
        total_steps=total_steps,
        pct_start=0.3,  # 30% warmup
        anneal_strategy='cos',
        div_factor=25.0,  # initial_lr = max_lr/25
        final_div_factor=10000.0
    )
    print(f"Scheduler steps: {total_steps} ({steps_per_epoch} per epoch)")

    # Combined loss
    loss_fn = CombinedLoss(
        NUM_CLASSES,
        class_weights=class_weights.to(DEVICE),
        gdl_weight=2.0,
        focal_weight=1.0,
        ce_weight=0.5
    ).to(DEVICE)

    scaler = GradScaler(init_scale=1024, growth_interval=2000)

    best_dice = 0.0
    patience = 0

    # Banner built from the live config; the previous hardcoded text
    # ("batch size 6 ... effective 12", "~3-5M params") contradicted the
    # values actually used.
    print(f"\n{'='*80}")
    print("OPTIMIZATIONS APPLIED:")
    print(f"1. MINIMAL model: {base_filters} base filters, NO attention, NO deep supervision "
          f"({total_params/1e6:.2f}M params)")
    print(f"2. Higher learning rate: {INITIAL_LR:g} with OneCycleLR scheduler")
    print(f"3. Aggressive RAM caching: {MAX_CACHE_SIZE_GB}GB train + {VAL_CACHE_SIZE_GB}GB val")
    print("4. Validation set cached in RAM")
    print("5. AdamW optimizer + OneCycleLR")
    print(f"6. Batch size: {BATCH_SIZE} (with gradient accumulation "
          f"{GRADIENT_ACCUMULATION_STEPS} = effective {effective_batch_size})")
    print("7. InstanceNorm (faster than GroupNorm)")
    print(f"{'='*80}\n")

    print(f"Starting training for {EPOCHS} epochs...")
    print(f"Expected: MUCH faster training (~15-20 min/epoch) with smaller, leaner model")

    for epoch in range(1, EPOCHS + 1):
        train_dataset.update_transforms(epoch)

        train_loss, loss_components = train_epoch(
            model, train_loader, optimizer, loss_fn,
            scheduler, DEVICE, scaler, epoch, GRADIENT_ACCUMULATION_STEPS
        )

        # Validate every epoch
        val_results = validate(model, val_loader, loss_fn, DEVICE, NUM_CLASSES)

        print(f"\nEpoch {epoch:3d}/{EPOCHS}")
        print(f"Train - Loss: {train_loss:.4f} | GDL: {loss_components['gdl']:.4f}")
        print(f"Val   - Loss: {val_results['loss']:.4f} | Dice: {val_results['dice']:.4f}")
        print(f"LR: {scheduler.get_last_lr()[0]:.6f}")

        if val_results['dice'] > best_dice:
            best_dice = val_results['dice']
            patience = 0

            checkpoint = {
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                # Needed to resume mixed-precision training with the same loss scale.
                'scaler_state_dict': scaler.state_dict(),
                'epoch': epoch,
                'best_dice': best_dice,
                'num_classes': NUM_CLASSES,
                'class_weights': class_weights
            }
            torch.save(checkpoint, MODEL_SAVE_PATH)
            print(f"✓ New best! Dice: {best_dice:.4f}")

            if best_dice >= 0.75:
                print(f"Target 0.75 achieved!")
        else:
            patience += 1
            print(f"Patience: {patience}/{EARLY_STOPPING_PATIENCE}")

        if patience >= EARLY_STOPPING_PATIENCE:
            print(f"Early stopping at epoch {epoch}")
            break

        if epoch % 10 == 0:
            memory_cleanup()

    print(f"\n{'='*80}")
    print(f"Training completed! Best Dice: {best_dice:.4f}")
    # A checkpoint is only written when Dice improves on the initial 0.0.
    if best_dice > 0.0:
        print(f"Model saved to: {MODEL_SAVE_PATH}")
    else:
        print("No checkpoint saved: validation Dice never exceeded 0.")
    print(f"{'='*80}")

# Inside a Jupyter/Colab kernel __name__ is "__main__", so executing this
# cell starts training immediately; importing the code as a module does not.
if __name__ == "__main__":
    main()

PyTorch Version: 2.5.1+cu121
Device: cuda
CUDA Device: NVIDIA L4
CUDA Memory: 22.2GB
Found 260 files
Detected 5 classes from labels: [0, 1, 2, 3, 4]
Train: 208, Val: 52
Computing class weights from 50 samples...


Analyzing class distribution: 100%|██████████| 50/50 [00:33<00:00,  1.48it/s]


Class distribution: [4.11637563e+08 2.74323000e+05 1.93015500e+06 7.19530000e+04
 1.96406000e+05]
Class weights: [ 1. 10. 10. 10. 10.]

Class Weights: tensor([ 1., 10., 10., 10., 10.])
Building cache (max 35GB)...


Caching data:  78%|███████▊  | 162/208 [01:23<00:23,  1.94it/s]


Cached 162/208 samples (34.97GB)
Building cache (max 12GB)...


Caching data: 100%|██████████| 52/52 [00:35<00:00,  1.46it/s]


Cached 52/52 samples (11.54GB)
Training batches per epoch: 52
Effective batch size: 16
Model parameters: 2.94M

OPTIMIZATIONS APPLIED:
1. MINIMAL model: 16 base filters, NO attention, NO deep supervision (~3-5M params)
2. Higher learning rate: 1e-3 with OneCycleLR scheduler
3. Aggressive RAM caching: 35GB train + 12GB val
4. Validation set fully cached in RAM
5. AdamW optimizer + OneCycleLR
6. Larger batch size: 6 (with gradient accumulation 2 = effective 12)
7. InstanceNorm (faster than GroupNorm)

Starting training for 250 epochs...
Expected: MUCH faster training (~15-20 min/epoch) with smaller, leaner model


Epoch 1:   2%|▏         | 1/52 [00:44<37:55, 44.62s/it, loss=4.2333, gdl=0.981, lr=0.000040]

In [None]:
# Release the Colab runtime once the cells above have finished.
# NOTE(review): presumably this disconnects the VM (and its GPU) so it stops
# consuming compute units. Any in-memory state, including the RAM caches, is
# lost; only files already written (e.g. MODEL_SAVE_PATH on Drive) survive.
# Confirm against the google.colab documentation.
# Under "Run All" this cell only runs if training completed without raising.
from google.colab import runtime
runtime.unassign()