<a href="https://colab.research.google.com/github/aalvarez359/3dunet-vs-medam-brain-tumor/blob/main/training%20/unet_msd_training/3d_unet_msd_train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

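Run scripts 1, 2 and 4 to train the 4-modality MSD U-Net.

Run scripts 1, 2 and 5 to train the FLAIR-only MSD U-Net (script 5 reuses the `dataset_split.json` written by script 4, so run script 4 at least once first).

Run scripts 1, 2 and 6 to train on BraTS 2024 with the t2f or t2w modality.

Make sure every directory path in the scripts points to the right location before running.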


#1

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive; the dataset and checkpoint paths used
# by the later cells all live under this mount point.
drive.mount('/content/drive')

Mounted at /content/drive


#2

In [None]:
!pip install nibabel segmentation-models-pytorch-3d torch torchvision torchaudio tqdm

Collecting segmentation-models-pytorch-3d
  Downloading segmentation_models_pytorch_3d-1.0.2-py3-none-any.whl.metadata (724 bytes)
Collecting pretrainedmodels==0.7.4 (from segmentation-models-pytorch-3d)
  Downloading pretrainedmodels-0.7.4.tar.gz (58 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.8/58.8 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting efficientnet-pytorch==0.7.1 (from segmentation-models-pytorch-3d)
  Downloading efficientnet_pytorch-0.7.1.tar.gz (21 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting timm==0.9.7 (from segmentation-models-pytorch-3d)
  Downloading timm-0.9.7-py3-none-any.whl.metadata (58 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.8/58.8 kB[0m [31m6.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting timm-3d==1.0.1 (from segmentation-models-pytorch-3d)
  Downloading timm_3d-1.0.1-py3-none-any.whl.metadata (588 bytes)
Col

#3

In [None]:
import os, json
from pathlib import Path
import numpy as np

# Root folder that contains all the BraTS-GLI-xxxxx-yyy case folders
BRATS2024_ROOT = "/content/drive/MyDrive/BraTS2024_TrainingData_extracted/BraTS2024-BraTS-GLI-TrainingData/training_data1_v2"

# Where to save the dataset JSON
JSON_PATH_BRATS2024 = os.path.join(BRATS2024_ROOT, "dataset_brats2024_cases.json")

SEG_SUFFIX = "-seg.nii.gz"
# Modalities recorded when present; FLAIR-like t2f is mandatory and handled separately.
OPTIONAL_MODALITIES = ("t2w", "t1c", "t1n")

# Fail fast on a wrong root: os.walk() silently yields nothing for a missing
# directory, and the makedirs() below would then create it and write an empty JSON.
if not os.path.isdir(BRATS2024_ROOT):
    raise FileNotFoundError(f"BraTS2024 root not found: {BRATS2024_ROOT}")

cases = []

for root, dirs, files in os.walk(BRATS2024_ROOT):
    # os.walk() returns entries in filesystem order; sorting (in place for
    # `dirs`, so the traversal itself is ordered) makes the case list reproducible.
    dirs.sort()
    for fname in sorted(files):
        if not fname.endswith(SEG_SUFFIX):
            continue

        seg_path = os.path.join(root, fname)   # exists by construction (listed by os.walk)
        base = fname[:-len(SEG_SUFFIX)]        # e.g. BraTS-GLI-02077-100

        # Require at least FLAIR + seg
        t2f_path = os.path.join(root, base + "-t2f.nii.gz")  # FLAIR-like
        if not os.path.exists(t2f_path):
            print("Skipping (missing FLAIR or seg):", base)
            continue

        case = {"case_id": base, "seg": seg_path, "t2f": t2f_path}
        # Optional modalities are stored as None when the file is absent.
        for modality in OPTIONAL_MODALITIES:
            modality_path = os.path.join(root, f"{base}-{modality}.nii.gz")
            case[modality] = modality_path if os.path.exists(modality_path) else None
        cases.append(case)

print(f"Found {len(cases)} cases with FLAIR + seg.")

dataset = {"training": cases}

# Make sure directory for JSON exists (matters if JSON_PATH_BRATS2024 is moved elsewhere)
os.makedirs(os.path.dirname(JSON_PATH_BRATS2024), exist_ok=True)

with open(JSON_PATH_BRATS2024, "w") as fh:
    json.dump(dataset, fh, indent=2)

print("Saved dataset JSON to:", JSON_PATH_BRATS2024)


Found 1350 cases with FLAIR + seg.
Saved dataset JSON to: /content/drive/MyDrive/BraTS2024_TrainingData_extracted/BraTS2024-BraTS-GLI-TrainingData/training_data1_v2/dataset_brats2024_cases.json


#4

In [None]:
#4 modality 3D U-Net
import os, json, time, math, random
from pathlib import Path
import numpy as np
import nibabel as nib
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.amp import autocast, GradScaler
from tqdm import tqdm
#!pip install nibabel segmentation-models-pytorch-3d torch torchvision torchaudio tqdm
# pip install segmentation-models-pytorch-3d nibabel tqdm
import segmentation_models_pytorch_3d as smp3d

# Permit TF32 kernels for CUDA matmuls and cuDNN operations.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

JSON_PATH = "/content/drive/MyDrive/Task01_BrainTumour_extracted/Task01_BrainTumour/dataset.json"   # dataset JSON; its "training" list holds the image/label pairs
# SAVE_DIR  = "checkpoints_smp3d_b7" # Original save directory
SAVE_DIR  = "/content/drive/MyDrive/BrainTumor_Checkpoints" # Drive folder for dataset_split.json, last.pt, best_*.pt and training_history.json
os.makedirs(SAVE_DIR, exist_ok=True)

SEED = 42   # seeds the RNGs and the train/val/test shuffle
#TRAIN_VAL_SPLIT = 0.90

# Compute & memory knobs. The active values are the A100 settings; the
# commented-out lines are the alternatives used for an L4 GPU.
PATCH_SIZE = (128, 224, 224)       # (D, H, W) of the training crop and of the sliding-window inference window
#PATCH_SIZE = (128, 192, 192)
OVERLAP    = 0.75               # fractional sliding-window overlap (stride = window * (1 - OVERLAP))
#OVERLAP    = 0.5
#BATCH_SIZE = 2                # L4 GPU
BATCH_SIZE = 6                # A100 GPU
ACCUM_STEPS = 1                # A100 GPU: batches accumulated per optimizer step
#ACCUM_STEPS = 4                # L4 GPU
#NUM_WORKERS = 2                 #L4 GPU
NUM_WORKERS = 4                 # A100: DataLoader worker processes

# Training schedule
EPOCHS = 75
LR = 3e-4             # also used as the OneCycleLR peak learning rate
WEIGHT_DECAY = 1e-4   # AdamW weight decay

# Data/spec
IN_CHANNELS = 4                 # (FLAIR, T1w, T1gd, T2w)
N_CLASSES   = 4                 # (bg, edema, non-enh, enh)

# -------------------------
# Reproducibility
# -------------------------
def set_seed(s=SEED):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs with `s`."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(s)
set_seed()

# Speed over bit-exact determinism: do not force deterministic cuDNN kernels,
# and let cuDNN benchmark algorithms for the (fixed-size) inputs.
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True


# -------------------------
# NIfTI I/O + preprocessing
# -------------------------
def percentile_clip(arr, lo=0.5, hi=99.5):
    """Return a float32 copy of `arr` clipped to its [lo, hi] percentile range.

    Args:
        arr: array-like of intensities (any shape).
        lo:  lower percentile (0-100) used as the clip floor.
        hi:  upper percentile (0-100) used as the clip ceiling.
    """
    values = arr.astype(np.float32)
    floor, ceiling = np.percentile(values, [lo, hi])
    return np.clip(values, floor, ceiling)

def zscore_per_channel(x, eps=1e-8):
    """Z-score each channel of `x` (indexed along axis 0) independently.

    Mean and standard deviation are computed over the non-zero voxels of the
    channel; a channel with no non-zero voxel uses all of its voxels. The
    standard deviation is floored at `eps` to avoid division by zero.

    Returns a new float32 array; the input is not modified.
    """
    out = x.astype(np.float32)  # astype copies, so the caller's array stays intact
    for c, channel in enumerate(out):
        nonzero = (channel != 0)
        reference = channel[nonzero] if nonzero.any() else channel
        mean = reference.mean()
        std = max(float(reference.std()), eps)
        out[c] = (channel - mean) / std
    return out

def load_nifti_image_chwd(path_img):
    """Load a 4-D NIfTI image and return it as a normalized (C, D, H, W) float32 array.

    The file must hold an (H, W, D, IN_CHANNELS) array. Each channel is
    clipped to its 0.5-99.5 percentile range, the axes are reordered to
    (C, D, H, W), and every channel is z-scored via `zscore_per_channel`.

    Raises:
        AssertionError: if the array is not 4-D with IN_CHANNELS channels last.
    """
    data = nib.load(str(path_img)).get_fdata(dtype=np.float32)
    assert data.ndim == 4 and data.shape[3] == IN_CHANNELS, f"{path_img} shape {data.shape} not HWD4"

    clipped = []
    for c in range(data.shape[3]):
        clipped.append(percentile_clip(data[..., c], 0.5, 99.5))

    hwdc = np.stack(clipped, axis=3)      # H W D C
    cdhw = hwdc.transpose(3, 2, 0, 1)     # C D H W
    return zscore_per_channel(cdhw)

def load_nifti_label_dhw(path_lab):
    """Load a NIfTI label volume stored as (H, W, D) and return it as int64 (D, H, W).

    Values are rounded to the nearest integer before the cast.
    """
    raw = nib.load(str(path_lab)).get_fdata(dtype=np.float32)  # H W D
    labels = np.rint(raw).astype(np.int64)
    return labels.transpose(2, 0, 1)  # D H W

# -------------------------
# Cropping & SW Inference
# -------------------------
def random_crop_3d(image_cdhw, label_dhw, crop, tries=8, fg_bias=0.5):
    """Draw a random, optionally foreground-biased crop from an image/label pair.

    Args:
        image_cdhw: array of shape (C, D, H, W).
        label_dhw:  array of shape (D, H, W).
        crop:       (cd, ch, cw) crop size; must fit inside the volume.
        tries:      number of candidate windows examined before giving up.
        fg_bias:    a candidate without any label > 0 is still accepted with
                    probability (1 - fg_bias).

    Returns:
        (image_patch, label_patch), both views into the inputs, taken at the
        same window. After `tries` rejected candidates one more window is
        drawn and returned unconditionally.

    Raises:
        AssertionError: if the crop is larger than the volume on any axis.
    """
    _, D, H, W = image_cdhw.shape
    cd, ch, cw = crop
    assert D >= cd and H >= ch and W >= cw, f"Patch {crop} > vol {(D,H,W)}"

    def _draw_window():
        # Origins are drawn in z, y, x order (keeps the RNG stream layout fixed).
        z0 = np.random.randint(0, D - cd + 1)
        y0 = np.random.randint(0, H - ch + 1)
        x0 = np.random.randint(0, W - cw + 1)
        return slice(z0, z0 + cd), slice(y0, y0 + ch), slice(x0, x0 + cw)

    for _ in range(tries):
        zs, ys, xs = _draw_window()
        patch_lab = label_dhw[zs, ys, xs]
        accept_regardless = np.random.rand() > fg_bias
        if accept_regardless or np.any(patch_lab > 0):
            return image_cdhw[:, zs, ys, xs], patch_lab

    # fallback: plain random window
    zs, ys, xs = _draw_window()
    return image_cdhw[:, zs, ys, xs], label_dhw[zs, ys, xs]

@torch.no_grad()
def sliding_window_inference(volume_cdhw, model, window, overlap, device="cuda"):
    """Predict a full volume by averaging softmax outputs of overlapping windows.

    Args:
        volume_cdhw: NumPy array of shape (C, D, H, W).
        model:       torch module mapping (1, C, d, h, w) -> (1, K, d, h, w) logits.
        window:      (wd, wh, ww) window size fed to the model.
        overlap:     fractional overlap between neighbouring windows; the
                     stride per axis is int(window * (1 - overlap)), at least 1.
        device:      device on which inference and accumulation run.

    Returns:
        (pred, prob): `pred` is the (D, H, W) argmax label map and `prob` the
        (K, D, H, W) window-averaged class probabilities, both on `device`.
        K is taken from the model output.
    """
    model.eval()

    def _origins(size, win, stride):
        # Regular grid of window origins plus one origin flush with the far
        # edge. Without the extra origin the tail of the axis is never
        # predicted whenever the stride does not divide (size - win).
        last = size - win
        origins = list(range(0, last + 1, stride))
        if origins[-1] != last:
            origins.append(last)
        return origins

    _, D, H, W = volume_cdhw.shape
    wd, wh, ww = window

    # A volume smaller than the window on some axis is zero-padded at the far
    # end so that one full window fits; the padding is cropped off again below.
    pad = [(0, 0)] + [(0, max(w - s, 0)) for s, w in zip((D, H, W), (wd, wh, ww))]
    if any(after for _, after in pad):
        volume_cdhw = np.pad(volume_cdhw, pad, mode="constant")
    _, PD, PH, PW = volume_cdhw.shape

    sd = max(1, int(wd * (1 - overlap)))
    sh = max(1, int(wh * (1 - overlap)))
    sw = max(1, int(ww * (1 - overlap)))

    out_prob = None  # allocated once the model reveals its number of classes
    out_norm = torch.zeros((1, PD, PH, PW), dtype=torch.float32, device=device)

    for z0 in _origins(PD, wd, sd):
        for y0 in _origins(PH, wh, sh):
            for x0 in _origins(PW, ww, sw):
                patch = volume_cdhw[:, z0:z0+wd, y0:y0+wh, x0:x0+ww]
                pt = torch.from_numpy(patch).unsqueeze(0).to(device)  # 1,C,D,H,W
                logits = model(pt)                                    # 1,K,d,h,w
                probs = F.softmax(logits, dim=1)[0]
                if out_prob is None:
                    out_prob = torch.zeros((probs.shape[0], PD, PH, PW),
                                           dtype=torch.float32, device=device)
                out_prob[:, z0:z0+wd, y0:y0+wh, x0:x0+ww] += probs
                out_norm[:, z0:z0+wd, y0:y0+wh, x0:x0+ww] += 1.0

    # Every voxel is covered by at least one window, so out_norm >= 1 everywhere.
    out_prob /= out_norm
    out_prob = out_prob[:, :D, :H, :W]    # drop any padding
    pred = torch.argmax(out_prob, dim=0)  # D H W
    return pred, out_prob

# -------------------------
# Dataset
# -------------------------
class BratsTrainPatches(Dataset):
    """Training dataset that yields one augmented random patch per case.

    `items` is a list of dicts with "image" and "label" file paths. Each
    access loads the full volume, crops a PATCH_SIZE window (foreground
    biased) and applies random rotation, flips and intensity jitter.
    """

    def __init__(self, items):
        self.items = items

    def __len__(self):
        return len(self.items)

    def __getitem__(self, i):
        """Return (image, label) tensors: float (C, d, h, w) and int64 (d, h, w)."""
        entry = self.items[i]
        img = load_nifti_image_chwd(entry["image"])  # C D H W
        lab = load_nifti_label_dhw(entry["label"])   # D H W

        # Foreground-biased random crop
        img, lab = random_crop_3d(img, lab, PATCH_SIZE, fg_bias=0.85)

        # --- Geometric augmentations ---
        # In-plane (H, W) rotation by a random multiple of 90 degrees. An odd
        # number of quarter turns swaps H and W, so the patch shape is only
        # preserved when the in-plane patch is square.
        if np.random.rand() < 0.5:
            quarter_turns = np.random.randint(0, 4)
            img = np.rot90(img, k=quarter_turns, axes=(2, 3)).copy()
            lab = np.rot90(lab, k=quarter_turns, axes=(1, 2)).copy()

        # Independent flips along D, H, W; image axis a matches label axis a - 1
        # because the image carries a leading channel axis.
        for img_axis in (1, 2, 3):
            if np.random.rand() < 0.33:
                img = np.flip(img, axis=img_axis).copy()
                lab = np.flip(lab, axis=img_axis - 1).copy()

        # --- Light intensity augmentations ---
        if np.random.rand() < 0.15:
            gain = 0.9 + 0.2 * np.random.rand()   # global scale in [0.9, 1.1)
            img = img * gain
        if np.random.rand() < 0.15:
            noise = np.random.normal(0, 0.05, img.shape).astype(np.float32)
            img = img + noise

        return torch.from_numpy(img), torch.from_numpy(lab).long()


# -------------------------
# Losses & Metrics
# -------------------------
class SoftDiceLoss(nn.Module):
    """Soft (probabilistic) multi-class Dice loss for 5-D logits.

    Args:
        smooth:     additive smoothing applied to numerator and denominator.
        include_bg: when False, class 0 is left out of the Dice average.
    """

    def __init__(self, smooth=1.0, include_bg=False):
        super().__init__()
        self.smooth = smooth
        self.include_bg = include_bg

    def forward(self, logits, targets):
        """Return 1 - mean per-class Dice.

        `logits` is (N, C, D, H, W); `targets` is an integer (N, D, H, W)
        label map. Dice is computed per class over the batch and all spatial
        axes, then averaged over the selected classes.
        """
        probs = F.softmax(logits, dim=1)
        _n, _c, _d, _h, _w = probs.shape  # fails early unless the input is 5-D
        target_1hot = torch.zeros_like(probs).scatter_(1, targets.unsqueeze(1), 1)

        first_class = 0 if self.include_bg else 1
        p = probs[:, first_class:, ...]
        t = target_1hot[:, first_class:, ...]

        reduce_axes = (0, 2, 3, 4)
        intersection = (p * t).sum(dim=reduce_axes)
        cardinality = (p + t).sum(dim=reduce_axes)
        per_class_dice = (2.0 * intersection + self.smooth) / (cardinality + self.smooth)
        return 1.0 - per_class_dice.mean()

class CEDice(nn.Module):
    """Weighted sum of cross-entropy and soft Dice loss.

    Args:
        w_ce:            weight of the cross-entropy term.
        w_dice:          weight of the Dice term.
        include_bg_dice: forwarded to SoftDiceLoss as `include_bg`.

    The `ce` attribute is a plain nn.CrossEntropyLoss and may be replaced
    after construction (e.g. with a class-weighted instance).
    """

    def __init__(self, w_ce=1.0, w_dice=1.0, include_bg_dice=False):
        super().__init__()
        self.ce = nn.CrossEntropyLoss()
        self.dice = SoftDiceLoss(include_bg=include_bg_dice)
        self.w_ce, self.w_dice = w_ce, w_dice

    def forward(self, logits, y):
        """Return w_ce * CE(logits, y) + w_dice * Dice(logits, y)."""
        ce_term = self.w_ce * self.ce(logits, y)
        dice_term = self.w_dice * self.dice(logits, y)
        return ce_term + dice_term

@torch.no_grad()
def dice_per_class(pred, target, n_classes=N_CLASSES):
    """Hard Dice score for every class id in range(n_classes).

    `pred` and `target` are integer label tensors of equal shape; 3-D inputs
    get a leading batch axis. A class absent from both tensors scores 1.0.

    Returns:
        list of `n_classes` Python floats, index = class id.
    """
    if pred.ndim == 3:
        pred, target = pred.unsqueeze(0), target.unsqueeze(0)

    scores = []
    for cls in range(n_classes):
        pred_c = (pred == cls)
        true_c = (target == cls)
        overlap = (pred_c & true_c).sum().item()
        total = pred_c.sum().item() + true_c.sum().item()
        scores.append((2 * overlap / total) if total > 0 else 1.0)
    return scores

# -------------------------
# Data split from JSON
# -------------------------
def load_train_val_from_json(json_path, train_frac=0.75, val_frac=0.05):
    """Return (train_items, val_items, test_items) for the dataset.

    The split is locked: if SAVE_DIR/dataset_split.json exists it is reloaded
    as is (and `json_path` and the fractions are ignored), so that every
    experiment pointed at the same file shares the same test set. Otherwise
    the "training" list of `json_path` is shuffled with a SEED-seeded RNG,
    cut into train / val / test, and the result is written to that file.

    Args:
        json_path:  dataset JSON whose "training" entries are dicts with
                    "image" and "label" paths relative to the JSON's folder.
        train_frac: fraction of all cases used for training.
        val_frac:   fraction of all cases used for validation; the remaining
                    cases form the test set.
    """
    split_path = os.path.join(SAVE_DIR, "dataset_split.json")

    # ---- 1) Locked split already on disk: reload it ----
    if os.path.exists(split_path):
        with open(split_path, "r") as f:
            saved = json.load(f)
        train_items, val_items, test_items = saved["train"], saved["val"], saved["test"]
        print(f"Loaded existing dataset split from {split_path}")
        print(f"Train={len(train_items)}  Val={len(val_items)}  Test={len(test_items)}")
        return train_items, val_items, test_items

    # ---- 2) No split yet: create one and persist it ----
    with open(json_path, "r") as f:
        items = json.load(f)["training"]  # list of {"image":..., "label":...}

    # Store absolute paths so the saved split stays valid wherever it is reloaded.
    root = Path(json_path).parent
    for it in items:
        for key in ("image", "label"):
            it[key] = str((root / it[key]).resolve())

    n = len(items)
    order = np.arange(n)
    np.random.RandomState(SEED).shuffle(order)

    n_train = int(n * train_frac)
    val_end = n_train + int(n * val_frac)

    train_items = [items[i] for i in order[:n_train]]
    val_items = [items[i] for i in order[n_train:val_end]]
    test_items = [items[i] for i in order[val_end:]]

    with open(split_path, "w") as f:
        json.dump({"train": train_items, "val": val_items, "test": test_items}, f, indent=2)
    print(f"Saved dataset split listing to {split_path}")
    print(f"Train={len(train_items)}  Val={len(val_items)}  Test={len(test_items)}")

    return train_items, val_items, test_items

# -------------------------
# Model
# -------------------------
def build_model(device):
    """Build the 3D U-Net (EfficientNet-B7 encoder, no pretrained weights) on `device`.

    Input channels and output classes come from IN_CHANNELS / N_CLASSES.
    """
    unet_kwargs = dict(
        encoder_name="efficientnet-b7",
        encoder_weights=None,
        in_channels=IN_CHANNELS,
        classes=N_CLASSES,
        # slimmer decoder to save VRAM
        decoder_channels=(192, 128, 64, 32, 16),
    )
    net = smp3d.Unet(**unet_kwargs)
    return net.to(device)

# -------------------------
# Validation
# -------------------------
@torch.no_grad()
def validate_full_volume(model, val_items, device):
    """Evaluate `model` on whole validation volumes with sliding-window inference.

    Args:
        model:     segmentation network.
        val_items: list of dicts with "image" and "label" paths.
        device:    device used for inference.

    Returns:
        (mean_dice, per_class): `mean_dice` is the average over cases of the
        mean Dice of all non-background classes; `per_class` maps class ids
        1, 2 and 3 to their Dice averaged over cases.
    """
    model.eval()

    tumor_classes = (1, 2, 3)
    case_means = []                                    # per-case mean over non-bg classes
    class_scores = {cls: [] for cls in tumor_classes}  # per-case Dice of each tumor class

    for case in tqdm(val_items, desc="Validating (full-volume)", leave=False):
        vol = load_nifti_image_chwd(case["image"])      # C D H W
        lab = load_nifti_label_dhw(case["label"])       # D H W

        pred, _ = sliding_window_inference(vol, model, PATCH_SIZE, OVERLAP, device)

        # dices[0] is background, dices[1:] the foreground classes
        dices = dice_per_class(pred.cpu(), torch.from_numpy(lab))
        case_means.append(np.mean(dices[1:]))
        for cls in tumor_classes:
            class_scores[cls].append(float(dices[cls]))

    mean_dice = float(np.mean(case_means))
    per_class = {cls: float(np.mean(scores)) for cls, scores in class_scores.items()}
    return mean_dice, per_class

def save_checkpoint(path, epoch, model, opt, sched, scaler, best_val):
    """Write a training checkpoint to `path` with torch.save.

    The checkpoint holds the epoch number, the state dicts of the model,
    optimizer, LR scheduler and grad scaler, the best validation score so
    far, and a small "config" record describing the model setup.
    """
    config = {
        "encoder": "efficientnet-b7",
        "in_channels": IN_CHANNELS,
        "classes": N_CLASSES,
        "patch_size": PATCH_SIZE,
    }
    checkpoint = {
        "epoch": epoch,
        "state_dict": model.state_dict(),
        "optimizer": opt.state_dict(),
        "scheduler": sched.state_dict(),
        "scaler": scaler.state_dict(),
        "best_val": float(best_val),
        "config": config,
    }
    torch.save(checkpoint, path)

def resume_if_possible(model, opt, sched, scaler, load_scheduler=False):
    """Restore training state from SAVE_DIR/last.pt when that file exists.

    Model weights are always loaded; optimizer and scaler state are loaded
    when present in the checkpoint. Scheduler state is only loaded when
    `load_scheduler` is True and `sched` is not None, and a failure to load
    it is reported and skipped rather than raised.

    Returns:
        (start_epoch, best_val): the epoch to continue from (saved epoch + 1)
        and the best validation score stored in the checkpoint, or
        (1, -1.0) when there is no checkpoint.
    """
    last_ckpt = os.path.join(SAVE_DIR, "last.pt")
    if not os.path.exists(last_ckpt):
        return 1, -1.0

    ckpt = torch.load(last_ckpt, map_location="cpu")

    # weights
    model.load_state_dict(ckpt["state_dict"])

    # optimizer
    if "optimizer" in ckpt:
        opt.load_state_dict(ckpt["optimizer"])

    # scheduler (optional, kept off when a different scheduler type is built afterwards)
    restore_sched = load_scheduler and "scheduler" in ckpt and sched is not None
    if restore_sched:
        try:
            sched.load_state_dict(ckpt["scheduler"])
        except Exception as e:
            print("⚠️ Skipping scheduler state load:", e)

    # scaler
    if "scaler" in ckpt:
        scaler.load_state_dict(ckpt["scaler"])

    start_epoch = int(ckpt.get("epoch", 0)) + 1
    best_val = float(ckpt.get("best_val", -1.0))
    print(f"↩ Resumed from {last_ckpt}: start_epoch={start_epoch}, best_val={best_val:.4f}")
    return start_epoch, best_val


# -------------------------
# Train loop
# -------------------------
def train():
    """Train the 3D U-Net on random patches and validate on full volumes.

    Per epoch: one pass over the patch loader (mixed precision on CUDA, with
    gradient accumulation over ACCUM_STEPS batches), full-volume validation,
    an update of SAVE_DIR/training_history.json, a `last.pt` checkpoint, and a
    `best_epoch*_dice*.pt` checkpoint whenever the validation mean Dice
    improves. If `last.pt` exists, training resumes from it and a fresh
    OneCycleLR schedule is built for the remaining epochs.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("Device:", device)
    # Autocast + loss scaling are only enabled on CUDA; the CPU fallback runs plain fp32.
    use_amp = (device == "cuda")

    # 3-way split: train / val / test (the test items are not used during training)
    train_items, val_items, test_items = load_train_val_from_json(JSON_PATH, train_frac=0.75, val_frac=0.05)
    print(f"Train={len(train_items)}  Val={len(val_items)}  Test={len(test_items)}")

    train_ds = BratsTrainPatches(train_items)
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=(device=="cuda"), drop_last=True)

    model = build_model(device)
    loss_fn = CEDice(w_ce=1.0, w_dice=1.0, include_bg_dice=False)
    # Replace the default CE term with a class-weighted, label-smoothed one.
    cls_w = torch.tensor([1.0, 2.0, 2.3, 1.2], device=device)
    loss_fn.ce = nn.CrossEntropyLoss(weight=cls_w, label_smoothing=0.03)
    opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    scaler = GradScaler(device=device, enabled=use_amp)

    # The scheduler is rebuilt below, so no scheduler state is restored here.
    start_epoch, best_val = resume_if_possible(model, opt, None, scaler, load_scheduler=False)

    # ---- Build OneCycleLR for the remaining epochs ----
    # OneCycle counts optimizer steps, not dataloader iterations. An optimizer
    # step happens once per ACCUM_STEPS batches, and a trailing partial group
    # never steps, hence the floor division.
    steps_per_epoch = len(train_loader) // ACCUM_STEPS
    remaining_epochs = max(1, EPOCHS - (start_epoch - 1))  # e.g., 160 - 103 + 1 if resuming at 103

    # Reset the LR (a resumed optimizer carries its old LR) before OneCycle builds its curve.
    for pg in opt.param_groups:
        pg['lr'] = LR

    sched = torch.optim.lr_scheduler.OneCycleLR(
        opt,
        max_lr=LR,
        steps_per_epoch=steps_per_epoch,   # number of optimizer.step() calls per epoch
        epochs=remaining_epochs,
        pct_start=0.1,
        div_factor=10,
        final_div_factor=100,
    )

    # ---- init or load training history (for plots and resume) ----
    history_path = os.path.join(SAVE_DIR, "training_history.json")
    if os.path.exists(history_path):
        with open(history_path, "r") as f:
            history = json.load(f)
    else:
        history = {
            "epoch": [],
            "train_loss": [],
            "val_mean_dice": [],
            "dice_edema": [],
            "dice_non_enh": [],
            "dice_enh": [],
            "lr": []
        }

    for epoch in range(start_epoch, EPOCHS+1):
        model.train()
        t0 = time.time()
        running = 0.0
        opt.zero_grad(set_to_none=True)
        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS}")
        for i, (x, y) in enumerate(pbar):
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            with autocast(device_type=device, enabled=use_amp):
                logits = model(x)
                # Scale so that the accumulated gradient is the mean over the group.
                loss = loss_fn(logits, y) / ACCUM_STEPS

            scaler.scale(loss).backward()
            if ((i+1) % ACCUM_STEPS) == 0:
                scaler.step(opt)
                scaler.update()
                opt.zero_grad(set_to_none=True)
                sched.step()

            running += loss.item() * ACCUM_STEPS   # undo the accumulation scaling for logging
            pbar.set_postfix(loss=f"{(running/(i+1)):.4f}")

        train_loss = running / max(1, len(train_loader))
        # --- Validation ---
        val_mean_dice, per_class = validate_full_volume(model, val_items, device)
        dt = time.time() - t0

        print(f"[{epoch:03d}] "
            f"train_loss={train_loss:.4f}  "
            f"val_meanDice(no-bg)={val_mean_dice:.4f}  "
            f"Dice(edema)={per_class[1]:.4f}  "
            f"Dice(non-enh)={per_class[2]:.4f}  "
            f"Dice(enh)={per_class[3]:.4f}  "
            f"lr={sched.get_last_lr()[0]:.2e}  ({dt:.1f}s)")

        # ---- Save training history for plotting ----
        history["epoch"].append(epoch)
        history["train_loss"].append(train_loss)
        history["val_mean_dice"].append(val_mean_dice)
        history["dice_edema"].append(per_class[1])
        history["dice_non_enh"].append(per_class[2])
        history["dice_enh"].append(per_class[3])
        history["lr"].append(sched.get_last_lr()[0])

        with open(history_path, "w") as f:
            json.dump(history, f, indent=2)

        # Update the running best BEFORE writing last.pt: a resumed run reads
        # best_val from last.pt, and a stale value there would let a worse
        # epoch be saved as a new "best" after resuming.
        is_best = val_mean_dice > best_val
        if is_best:
            best_val = val_mean_dice

        # save last
        save_checkpoint(
            os.path.join(SAVE_DIR, "last.pt"),
            epoch, model, opt, sched, scaler, best_val
        )

        # save best
        if is_best:
            best_path = os.path.join(SAVE_DIR, f"best_epoch{epoch}_dice{best_val:.4f}.pt")
            save_checkpoint(best_path, epoch, model, opt, sched, scaler, best_val)
            print("✔ Saved", best_path)

    print(f"Best val mean Dice (no-bg): {best_val:.4f}")

if __name__ == "__main__":
    train()


Device: cuda
Loaded existing dataset split from /content/drive/MyDrive/BrainTumor_Checkpoints/dataset_split.json
Train=363  Val=24  Test=97
Train=363  Val=24  Test=97


KeyboardInterrupt: 

#5

In [None]:
# FLAIR-only MSD 3D U-Net

import os, json, time, math, random
from pathlib import Path
import numpy as np
import nibabel as nib
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.amp import autocast, GradScaler
from tqdm import tqdm
# !pip install nibabel segmentation-models-pytorch-3d torch torchvision torchaudio tqdm
import segmentation_models_pytorch_3d as smp3d

# Permit TF32 kernels for CUDA matmuls and cuDNN operations.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# --- Paths ---
JSON_PATH = "/content/drive/MyDrive/Task01_BrainTumour_extracted/Task01_BrainTumour/dataset.json"   # original MSD/BRATS json

# Reuse the SAME split as the 4-modality experiment: this file must already
# exist, the FLAIR-only split loader only reads it and never creates one.
SPLIT_JSON_PATH = "/content/drive/MyDrive/BrainTumor_Checkpoints/dataset_split.json"

# FLAIR-only outputs go to their own folder so they do not overwrite the
# 4-modality checkpoints.
SAVE_DIR  = "/content/drive/MyDrive/BrainTumor_Checkpoints_FLAIR"
os.makedirs(SAVE_DIR, exist_ok=True)

SEED = 42

# Compute & memory knobs (copied from the 4-modality script)
PATCH_SIZE = (128, 224, 224)       # (D, H, W) of the training crop and of the sliding-window inference window
OVERLAP    = 0.75                  # fractional sliding-window overlap (stride = window * (1 - OVERLAP))
BATCH_SIZE = 6                     # A100 GPU (worked for the 4-modality run)
ACCUM_STEPS = 1
NUM_WORKERS = 4

# Training schedule
EPOCHS = 95
LR = 3e-4
WEIGHT_DECAY = 1e-4

# Data/spec
IN_CHANNELS = 1                 # FLAIR only: the loader keeps channel 0 of each image
N_CLASSES   = 4                 # (bg, edema, non-enh, enh)

# -------------------------
# Reproducibility
# -------------------------
def set_seed(s=SEED):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs with `s`."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(s)
set_seed()

# Speed over bit-exact determinism: do not force deterministic cuDNN kernels,
# and let cuDNN benchmark algorithms for the (fixed-size) inputs.
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True


# -------------------------
# NIfTI I/O + preprocessing
# -------------------------
def percentile_clip(arr, lo=0.5, hi=99.5):
    """Return a float32 copy of `arr` clipped to its [lo, hi] percentile range.

    Args:
        arr: array-like of intensities (any shape).
        lo:  lower percentile (0-100) used as the clip floor.
        hi:  upper percentile (0-100) used as the clip ceiling.
    """
    values = arr.astype(np.float32)
    floor, ceiling = np.percentile(values, [lo, hi])
    return np.clip(values, floor, ceiling)

def zscore_per_channel(x, eps=1e-8):
    """Z-score each channel of `x` (indexed along axis 0) independently.

    Mean and standard deviation are computed over the non-zero voxels of the
    channel; a channel with no non-zero voxel uses all of its voxels. The
    standard deviation is floored at `eps` to avoid division by zero.

    Returns a new float32 array; the input is not modified.
    """
    out = x.astype(np.float32)  # astype copies, so the caller's array stays intact
    for c, channel in enumerate(out):
        nonzero = (channel != 0)
        reference = channel[nonzero] if nonzero.any() else channel
        mean = reference.mean()
        std = max(float(reference.std()), eps)
        out[c] = (channel - mean) / std
    return out

def load_nifti_image_chwd(path_img):
    """FLAIR-only loader: return channel 0 of a 4-D NIfTI image as (1, D, H, W) float32.

    The file must hold an (H, W, D, C) array with at least one channel.
    Channel 0 (FLAIR) is clipped to its 0.5-99.5 percentile range, reshaped
    to (C=1, D, H, W) and z-scored via `zscore_per_channel`.

    Raises:
        AssertionError: if the array is not 4-D or has no channel.
    """
    data = nib.load(str(path_img)).get_fdata(dtype=np.float32)   # (H, W, D, C)
    assert data.ndim == 4 and data.shape[3] >= 1, f"{path_img} shape {data.shape} not HWD4"

    flair_hwd = percentile_clip(data[..., 0], 0.5, 99.5)            # (H, W, D)
    flair_cdhw = np.transpose(flair_hwd[..., None], (3, 2, 0, 1))   # (H, W, D, 1) -> (1, D, H, W)
    return zscore_per_channel(flair_cdhw)

def load_nifti_label_dhw(path_lab):
    """Load a NIfTI label volume stored as (H, W, D) and return it as int64 (D, H, W).

    Values are rounded to the nearest integer before the cast.
    """
    raw = nib.load(str(path_lab)).get_fdata(dtype=np.float32)  # H W D
    labels = np.rint(raw).astype(np.int64)
    return labels.transpose(2, 0, 1)  # D H W

# -------------------------
# Cropping & SW Inference
# -------------------------
def random_crop_3d(image_cdhw, label_dhw, crop, tries=8, fg_bias=0.5):
    """Draw a random, optionally foreground-biased crop from an image/label pair.

    Args:
        image_cdhw: array of shape (C, D, H, W).
        label_dhw:  array of shape (D, H, W).
        crop:       (cd, ch, cw) crop size; must fit inside the volume.
        tries:      number of candidate windows examined before giving up.
        fg_bias:    a candidate without any label > 0 is still accepted with
                    probability (1 - fg_bias).

    Returns:
        (image_patch, label_patch), both views into the inputs, taken at the
        same window. After `tries` rejected candidates one more window is
        drawn and returned unconditionally.

    Raises:
        AssertionError: if the crop is larger than the volume on any axis.
    """
    _, D, H, W = image_cdhw.shape
    cd, ch, cw = crop
    assert D >= cd and H >= ch and W >= cw, f"Patch {crop} > vol {(D,H,W)}"

    def _draw_window():
        # Origins are drawn in z, y, x order (keeps the RNG stream layout fixed).
        z0 = np.random.randint(0, D - cd + 1)
        y0 = np.random.randint(0, H - ch + 1)
        x0 = np.random.randint(0, W - cw + 1)
        return slice(z0, z0 + cd), slice(y0, y0 + ch), slice(x0, x0 + cw)

    for _ in range(tries):
        zs, ys, xs = _draw_window()
        patch_lab = label_dhw[zs, ys, xs]
        accept_regardless = np.random.rand() > fg_bias
        if accept_regardless or np.any(patch_lab > 0):
            return image_cdhw[:, zs, ys, xs], patch_lab

    # fallback: plain random window
    zs, ys, xs = _draw_window()
    return image_cdhw[:, zs, ys, xs], label_dhw[zs, ys, xs]

@torch.no_grad()
def sliding_window_inference(volume_cdhw, model, window, overlap, device="cuda"):
    """Predict a full volume by averaging softmax outputs of overlapping windows.

    Args:
        volume_cdhw: NumPy array of shape (C, D, H, W).
        model:       torch module mapping (1, C, d, h, w) -> (1, K, d, h, w) logits.
        window:      (wd, wh, ww) window size fed to the model.
        overlap:     fractional overlap between neighbouring windows; the
                     stride per axis is int(window * (1 - overlap)), at least 1.
        device:      device on which inference and accumulation run.

    Returns:
        (pred, prob): `pred` is the (D, H, W) argmax label map and `prob` the
        (K, D, H, W) window-averaged class probabilities, both on `device`.
        K is taken from the model output.
    """
    model.eval()

    def _origins(size, win, stride):
        # Regular grid of window origins plus one origin flush with the far
        # edge. Without the extra origin the tail of the axis is never
        # predicted whenever the stride does not divide (size - win).
        last = size - win
        origins = list(range(0, last + 1, stride))
        if origins[-1] != last:
            origins.append(last)
        return origins

    _, D, H, W = volume_cdhw.shape
    wd, wh, ww = window

    # A volume smaller than the window on some axis is zero-padded at the far
    # end so that one full window fits; the padding is cropped off again below.
    pad = [(0, 0)] + [(0, max(w - s, 0)) for s, w in zip((D, H, W), (wd, wh, ww))]
    if any(after for _, after in pad):
        volume_cdhw = np.pad(volume_cdhw, pad, mode="constant")
    _, PD, PH, PW = volume_cdhw.shape

    sd = max(1, int(wd * (1 - overlap)))
    sh = max(1, int(wh * (1 - overlap)))
    sw = max(1, int(ww * (1 - overlap)))

    out_prob = None  # allocated once the model reveals its number of classes
    out_norm = torch.zeros((1, PD, PH, PW), dtype=torch.float32, device=device)

    for z0 in _origins(PD, wd, sd):
        for y0 in _origins(PH, wh, sh):
            for x0 in _origins(PW, ww, sw):
                patch = volume_cdhw[:, z0:z0+wd, y0:y0+wh, x0:x0+ww]
                pt = torch.from_numpy(patch).unsqueeze(0).to(device)  # 1,C,D,H,W
                logits = model(pt)                                    # 1,K,d,h,w
                probs = F.softmax(logits, dim=1)[0]
                if out_prob is None:
                    out_prob = torch.zeros((probs.shape[0], PD, PH, PW),
                                           dtype=torch.float32, device=device)
                out_prob[:, z0:z0+wd, y0:y0+wh, x0:x0+ww] += probs
                out_norm[:, z0:z0+wd, y0:y0+wh, x0:x0+ww] += 1.0

    # Every voxel is covered by at least one window, so out_norm >= 1 everywhere.
    out_prob /= out_norm
    out_prob = out_prob[:, :D, :H, :W]    # drop any padding
    pred = torch.argmax(out_prob, dim=0)  # D H W
    return pred, out_prob

# -------------------------
# Dataset
# -------------------------
class BratsTrainPatches(Dataset):
    """Training dataset that yields one augmented random patch per case.

    `items` is a list of dicts with "image" and "label" file paths. Each
    access loads the full volume, crops a PATCH_SIZE window (foreground
    biased) and applies random rotation, flips and intensity jitter.
    """

    def __init__(self, items):
        self.items = items

    def __len__(self):
        return len(self.items)

    def __getitem__(self, i):
        """Return (image, label) tensors: float (C, d, h, w) and int64 (d, h, w)."""
        entry = self.items[i]
        img = load_nifti_image_chwd(entry["image"])  # C D H W (C=1 here)
        lab = load_nifti_label_dhw(entry["label"])   # D H W

        # Foreground-biased random crop
        img, lab = random_crop_3d(img, lab, PATCH_SIZE, fg_bias=0.85)

        # --- Geometric augmentations ---
        # In-plane (H, W) rotation by a random multiple of 90 degrees. An odd
        # number of quarter turns swaps H and W, so the patch shape is only
        # preserved when the in-plane patch is square.
        if np.random.rand() < 0.5:
            quarter_turns = np.random.randint(0, 4)
            img = np.rot90(img, k=quarter_turns, axes=(2, 3)).copy()
            lab = np.rot90(lab, k=quarter_turns, axes=(1, 2)).copy()

        # Independent flips along D, H, W; image axis a matches label axis a - 1
        # because the image carries a leading channel axis.
        for img_axis in (1, 2, 3):
            if np.random.rand() < 0.33:
                img = np.flip(img, axis=img_axis).copy()
                lab = np.flip(lab, axis=img_axis - 1).copy()

        # --- Light intensity augmentations ---
        if np.random.rand() < 0.15:
            gain = 0.9 + 0.2 * np.random.rand()   # global scale in [0.9, 1.1)
            img = img * gain
        if np.random.rand() < 0.15:
            noise = np.random.normal(0, 0.05, img.shape).astype(np.float32)
            img = img + noise

        return torch.from_numpy(img), torch.from_numpy(lab).long()


# -------------------------
# Losses & Metrics
# -------------------------
class SoftDiceLoss(nn.Module):
    """Soft (probabilistic) multi-class Dice loss for 5-D logits.

    Args:
        smooth:     additive smoothing applied to numerator and denominator.
        include_bg: when False, class 0 is left out of the Dice average.
    """

    def __init__(self, smooth=1.0, include_bg=False):
        super().__init__()
        self.smooth = smooth
        self.include_bg = include_bg

    def forward(self, logits, targets):
        """Return 1 - mean per-class Dice.

        `logits` is (N, C, D, H, W); `targets` is an integer (N, D, H, W)
        label map. Dice is computed per class over the batch and all spatial
        axes, then averaged over the selected classes.
        """
        probs = F.softmax(logits, dim=1)
        _n, _c, _d, _h, _w = probs.shape  # fails early unless the input is 5-D
        target_1hot = torch.zeros_like(probs).scatter_(1, targets.unsqueeze(1), 1)

        first_class = 0 if self.include_bg else 1
        p = probs[:, first_class:, ...]
        t = target_1hot[:, first_class:, ...]

        reduce_axes = (0, 2, 3, 4)
        intersection = (p * t).sum(dim=reduce_axes)
        cardinality = (p + t).sum(dim=reduce_axes)
        per_class_dice = (2.0 * intersection + self.smooth) / (cardinality + self.smooth)
        return 1.0 - per_class_dice.mean()

class CEDice(nn.Module):
    """Weighted sum of cross-entropy and soft Dice loss.

    Args:
        w_ce:            weight of the cross-entropy term.
        w_dice:          weight of the Dice term.
        include_bg_dice: forwarded to SoftDiceLoss as `include_bg`.

    The `ce` attribute is a plain nn.CrossEntropyLoss and may be replaced
    after construction (e.g. with a class-weighted instance).
    """

    def __init__(self, w_ce=1.0, w_dice=1.0, include_bg_dice=False):
        super().__init__()
        self.ce = nn.CrossEntropyLoss()
        self.dice = SoftDiceLoss(include_bg=include_bg_dice)
        self.w_ce, self.w_dice = w_ce, w_dice

    def forward(self, logits, y):
        """Return w_ce * CE(logits, y) + w_dice * Dice(logits, y)."""
        ce_term = self.w_ce * self.ce(logits, y)
        dice_term = self.w_dice * self.dice(logits, y)
        return ce_term + dice_term

@torch.no_grad()
def dice_per_class(pred, target, n_classes=N_CLASSES):
    """Hard Dice score for every class id in range(n_classes).

    `pred` and `target` are integer label tensors of equal shape; 3-D inputs
    get a leading batch axis. A class absent from both tensors scores 1.0.

    Returns:
        list of `n_classes` Python floats, index = class id.
    """
    if pred.ndim == 3:
        pred, target = pred.unsqueeze(0), target.unsqueeze(0)

    scores = []
    for cls in range(n_classes):
        pred_c = (pred == cls)
        true_c = (target == cls)
        overlap = (pred_c & true_c).sum().item()
        total = pred_c.sum().item() + true_c.sum().item()
        scores.append((2 * overlap / total) if total > 0 else 1.0)
    return scores

# -------------------------
# Data split from JSON
# -------------------------
def load_train_val_from_json(json_path, train_frac=0.75, val_frac=0.05):
    """Return the (train, val, test) item lists stored in ``SPLIT_JSON_PATH``.

    FLAIR-only version: the split written by the 4-modality run is always
    reused, so ``json_path``, ``train_frac`` and ``val_frac`` are accepted for
    signature compatibility but are not read.

    Raises:
        AssertionError: if ``SPLIT_JSON_PATH`` does not exist.
    """
    assert os.path.exists(SPLIT_JSON_PATH), f"Split file not found: {SPLIT_JSON_PATH}"
    with open(SPLIT_JSON_PATH, "r") as fh:
        saved_split = json.load(fh)
    train_items, val_items, test_items = (
        saved_split[part] for part in ("train", "val", "test")
    )
    print(f"Loaded existing dataset split from {SPLIT_JSON_PATH}")
    print(f"Train={len(train_items)}  Val={len(val_items)}  Test={len(test_items)}")
    return train_items, val_items, test_items

# -------------------------
# Model (hard-coded B7)
# -------------------------
def build_model(device):
    """Build the EfficientNet-B7 3D U-Net (no pretrained weights) on ``device``.

    Input channels and class count come from ``IN_CHANNELS`` / ``N_CLASSES``.
    """
    unet_kwargs = dict(
        encoder_name="efficientnet-b7",
        encoder_weights=None,
        in_channels=IN_CHANNELS,      # 1 channel now
        classes=N_CLASSES,
        # slimmer decoder to save VRAM
        decoder_channels=(192, 128, 64, 32, 16),
    )
    net = smp3d.Unet(**unet_kwargs)
    return net.to(device)

# -------------------------
# Validation (full volume)
# -------------------------
@torch.no_grad()
def validate_full_volume(model, val_items, device):
    """Run sliding-window inference on every validation case and average Dice.

    Args:
        model: network to evaluate (switched to eval mode here).
        val_items: list of dicts with ``"image"`` and ``"label"`` file paths.
        device: device passed through to ``sliding_window_inference``.

    Returns:
        ``(mean_dice, per_class)`` where ``mean_dice`` is the case-average of
        the mean Dice over all non-background classes and ``per_class`` maps
        class ids 1, 2, 3 to their case-averaged Dice.  With an empty
        ``val_items`` every value is 0.0 instead of NaN.
    """
    model.eval()

    tumor_classes = (1, 2, 3)
    mean_dice_all = []          # per-case mean over the non-background classes
    per_class_dice = {cls: [] for cls in tumor_classes}

    for it in tqdm(val_items, desc="Validating (full-volume)", leave=False):
        vol = load_nifti_image_chwd(it["image"])      # C D H W per the original note (C=1 now)
        lab = load_nifti_label_dhw(it["label"])       # D H W

        pred, _ = sliding_window_inference(vol, model, PATCH_SIZE, OVERLAP, device)

        dices = dice_per_class(pred.cpu(), torch.from_numpy(lab))
        mean_dice_all.append(np.mean(dices[1:]))

        for cls in tumor_classes:
            per_class_dice[cls].append(float(dices[cls]))

    def _average(values):
        # np.mean([]) yields NaN plus a RuntimeWarning; report 0.0 for "nothing validated".
        return float(np.mean(values)) if values else 0.0

    final_mean_dice = _average(mean_dice_all)
    final_per_class_dice = {cls: _average(per_class_dice[cls]) for cls in tumor_classes}

    return final_mean_dice, final_per_class_dice

def save_checkpoint(path, epoch, model, opt, sched, scaler, best_val):
    """Serialize the full training state to ``path`` with ``torch.save``.

    The payload holds the epoch number, model/optimizer/scaler state dicts,
    the scheduler state dict (``{}`` when ``sched`` is None), the best
    validation score so far and a small ``config`` record of the run.
    """
    run_config = {
        "encoder": "efficientnet-b7",
        "in_channels": IN_CHANNELS,
        "classes": N_CLASSES,
        "patch_size": PATCH_SIZE,
    }
    payload = {
        "epoch": epoch,
        "state_dict": model.state_dict(),
        "optimizer": opt.state_dict(),
        "scheduler": {} if sched is None else sched.state_dict(),
        "scaler": scaler.state_dict(),
        "best_val": float(best_val),
        "config": run_config,
    }
    torch.save(payload, path)

def resume_if_possible(model, opt, sched, scaler, load_scheduler=False):
    """Restore training state from ``SAVE_DIR/last.pt`` when that file exists.

    Args:
        model, opt, scaler: objects whose state is restored in place.
        sched: scheduler to restore; only touched when ``load_scheduler`` is
            True and ``sched`` is not None.
        load_scheduler: opt-in switch for restoring the scheduler state.

    Returns:
        ``(start_epoch, best_val)``: the epoch after the checkpointed one and
        the stored best validation score, or ``(1, -1.0)`` when there is no
        checkpoint.
    """
    last_ckpt = os.path.join(SAVE_DIR, "last.pt")
    if not os.path.exists(last_ckpt):
        return 1, -1.0

    ckpt = torch.load(last_ckpt, map_location="cpu")

    # weights
    model.load_state_dict(ckpt["state_dict"])

    # optimizer
    if "optimizer" in ckpt:
        opt.load_state_dict(ckpt["optimizer"])

    # scheduler (optional, best effort)
    if load_scheduler and "scheduler" in ckpt and sched is not None:
        try:
            sched.load_state_dict(ckpt["scheduler"])
        except Exception as e:
            print("⚠️ Skipping scheduler state load:", e)

    # scaler: a checkpoint written with a disabled GradScaler stores an empty
    # dict, and an enabled scaler refuses to load that, so keep the fresh
    # scaler in that case instead of aborting the resume.
    scaler_state = ckpt.get("scaler")
    if scaler_state:
        scaler.load_state_dict(scaler_state)
    elif "scaler" in ckpt:
        print("⚠️ Skipping empty scaler state; using a fresh GradScaler")

    start_epoch = int(ckpt.get("epoch", 0)) + 1
    best_val = float(ckpt.get("best_val", -1.0))
    print(f"↩ Resumed from {last_ckpt}: start_epoch={start_epoch}, best_val={best_val:.4f}")
    return start_epoch, best_val


# -------------------------
# Train loop
# -------------------------
def train():
    """Train the single-channel 3D U-Net, validating and checkpointing each epoch.

    Steps:
      1. Load the saved train/val/test split and build the patch DataLoader.
      2. Build the model, CE+Dice loss (class-weighted, label-smoothed CE),
         AdamW optimizer and GradScaler.
      3. Resume from ``last.pt`` in ``SAVE_DIR`` if present.
      4. Build a fresh OneCycleLR schedule over the remaining epochs.
      5. For each epoch: train with gradient accumulation over ``ACCUM_STEPS``
         micro-batches, run full-volume validation, append to
         ``training_history.json``, write ``last.pt`` and, on improvement of
         the validation mean Dice, a ``best_epoch{epoch}_dice{score}.pt``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Mixed precision is only enabled on GPU; on CPU the run is plain fp32
    # rather than requesting CUDA autocast / grad scaling that is unavailable.
    use_amp = (device == "cuda")
    print("Device:", device)

    # 3-way split: train / val / test (reused from 4-mod split).
    # load_train_val_from_json already prints the split sizes.
    train_items, val_items, _test_items = load_train_val_from_json(JSON_PATH, train_frac=0.75, val_frac=0.05)

    train_ds = BratsTrainPatches(train_items)
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=NUM_WORKERS, pin_memory=(device == "cuda"), drop_last=True)

    model = build_model(device)
    loss_fn = CEDice(w_ce=1.0, w_dice=1.0, include_bg_dice=False)
    # Swap the default CE term for a class-weighted, label-smoothed one.
    cls_w = torch.tensor([1.0, 2.0, 2.3, 1.2], device=device)
    loss_fn.ce = nn.CrossEntropyLoss(weight=cls_w, label_smoothing=0.03)
    opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    sched = None
    scaler = GradScaler(device=device, enabled=use_amp)

    start_epoch, best_val = resume_if_possible(model, opt, sched, scaler, load_scheduler=False)

    # ---- Build OneCycleLR for the remaining epochs (same as the 4-mod script) ----
    # One scheduler step per optimizer step; never 0 so OneCycleLR stays valid
    # even when the loader is shorter than ACCUM_STEPS.
    steps_per_epoch = max(1, len(train_loader) // ACCUM_STEPS)
    remaining_epochs = max(1, EPOCHS - (start_epoch - 1))

    # A resumed optimizer carries the old schedule's LR; reset before re-scheduling.
    for pg in opt.param_groups:
        pg['lr'] = LR

    sched = torch.optim.lr_scheduler.OneCycleLR(
        opt,
        max_lr=LR,
        steps_per_epoch=steps_per_epoch,
        epochs=remaining_epochs,
        pct_start=0.1,
        div_factor=10,
        final_div_factor=100,
    )

    # ---- init or load training history ----
    history_path = os.path.join(SAVE_DIR, "training_history.json")
    if os.path.exists(history_path):
        with open(history_path, "r") as f:
            history = json.load(f)
    else:
        history = {
            "epoch": [],
            "train_loss": [],
            "val_mean_dice": [],
            "dice_edema": [],
            "dice_non_enh": [],
            "dice_enh": [],
            "lr": []
        }

    last_path = os.path.join(SAVE_DIR, "last.pt")

    for epoch in range(start_epoch, EPOCHS+1):
        model.train()
        t0 = time.time()
        running = 0.0
        # Also drops any gradients left over from an incomplete accumulation
        # window at the end of the previous epoch.
        opt.zero_grad(set_to_none=True)
        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS}")
        for i, (x, y) in enumerate(pbar):
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            with autocast(device_type=device, enabled=use_amp):
                logits = model(x)
                # Scale so the accumulated gradient matches one large batch.
                loss = loss_fn(logits, y) / ACCUM_STEPS

            scaler.scale(loss).backward()
            if ((i+1) % ACCUM_STEPS) == 0:
                scaler.step(opt)
                scaler.update()
                opt.zero_grad(set_to_none=True)
                sched.step()

            # Undo the accumulation scaling for reporting.
            running += loss.item() * ACCUM_STEPS
            pbar.set_postfix(loss=f"{(running/(i+1)):.4f}")

        train_loss = running / max(1, len(train_loader))

        # --- Validation ---
        val_mean_dice, per_class = validate_full_volume(model, val_items, device)
        dt = time.time() - t0
        current_lr = sched.get_last_lr()[0]

        print(f"[{epoch:03d}] "
              f"train_loss={train_loss:.4f}  "
              f"val_meanDice(no-bg)={val_mean_dice:.4f}  "
              f"Dice(edema)={per_class[1]:.4f}  "
              f"Dice(non-enh)={per_class[2]:.4f}  "
              f"Dice(enh)={per_class[3]:.4f}  "
              f"lr={current_lr:.2e}  ({dt:.1f}s)")

        # ---- Save training history ----
        history["epoch"].append(epoch)
        history["train_loss"].append(train_loss)
        history["val_mean_dice"].append(val_mean_dice)
        history["dice_edema"].append(per_class[1])
        history["dice_non_enh"].append(per_class[2])
        history["dice_enh"].append(per_class[3])
        history["lr"].append(current_lr)

        with open(history_path, "w") as f:
            json.dump(history, f, indent=2)

        # Update the best score BEFORE writing last.pt so a later resume sees
        # the true best value and does not promote a worse epoch to "best".
        improved = val_mean_dice > best_val
        if improved:
            best_val = val_mean_dice

        # save last (full training state)
        save_checkpoint(last_path, epoch, model, opt, sched, scaler, best_val)

        # save best (also full state so you can deploy or resume from best)
        if improved:
            best_path = os.path.join(SAVE_DIR, f"best_epoch{epoch}_dice{best_val:.4f}.pt")
            save_checkpoint(best_path, epoch, model, opt, sched, scaler, best_val)
            print("✔ Saved", best_path)

    print(f"Best val mean Dice (no-bg): {best_val:.4f}")

# Entry point: start (or resume) training when this cell/script is executed directly.
if __name__ == "__main__":
    train()


Device: cuda
Loaded existing dataset split from /content/drive/MyDrive/BrainTumor_Checkpoints/dataset_split.json
Train=363  Val=24  Test=97
Train=363  Val=24  Test=97
↩ Resumed from /content/drive/MyDrive/BrainTumor_Checkpoints_FLAIR/last.pt: start_epoch=86, best_val=0.5087


Epoch 86/95: 100%|██████████| 60/60 [02:17<00:00,  2.29s/it, loss=0.7694]


[086] train_loss=0.7694  val_meanDice(no-bg)=0.4643  Dice(edema)=0.6825  Dice(non-enh)=0.3105  Dice(enh)=0.3999  lr=3.00e-04  (175.8s)


Epoch 87/95: 100%|██████████| 60/60 [02:34<00:00,  2.57s/it, loss=0.7752]


[087] train_loss=0.7752  val_meanDice(no-bg)=0.4590  Dice(edema)=0.7113  Dice(non-enh)=0.3273  Dice(enh)=0.3384  lr=2.91e-04  (192.2s)


Epoch 88/95: 100%|██████████| 60/60 [02:38<00:00,  2.64s/it, loss=0.7823]


[088] train_loss=0.7823  val_meanDice(no-bg)=0.4657  Dice(edema)=0.6458  Dice(non-enh)=0.3351  Dice(enh)=0.4163  lr=2.64e-04  (195.8s)


Epoch 89/95: 100%|██████████| 60/60 [02:33<00:00,  2.56s/it, loss=0.7854]


[089] train_loss=0.7854  val_meanDice(no-bg)=0.4507  Dice(edema)=0.7129  Dice(non-enh)=0.3504  Dice(enh)=0.2889  lr=2.24e-04  (194.0s)


Epoch 90/95: 100%|██████████| 60/60 [02:38<00:00,  2.64s/it, loss=0.7867]


[090] train_loss=0.7867  val_meanDice(no-bg)=0.4707  Dice(edema)=0.6926  Dice(non-enh)=0.3793  Dice(enh)=0.3403  lr=1.75e-04  (197.1s)


Epoch 91/95: 100%|██████████| 60/60 [02:34<00:00,  2.57s/it, loss=0.7773]


[091] train_loss=0.7773  val_meanDice(no-bg)=0.4796  Dice(edema)=0.6917  Dice(non-enh)=0.3785  Dice(enh)=0.3687  lr=1.23e-04  (192.6s)


Epoch 92/95: 100%|██████████| 60/60 [02:38<00:00,  2.64s/it, loss=0.7666]


[092] train_loss=0.7666  val_meanDice(no-bg)=0.4902  Dice(edema)=0.6998  Dice(non-enh)=0.3667  Dice(enh)=0.4043  lr=7.45e-05  (196.2s)


Epoch 93/95: 100%|██████████| 60/60 [02:35<00:00,  2.59s/it, loss=0.7621]


[093] train_loss=0.7621  val_meanDice(no-bg)=0.4867  Dice(edema)=0.7099  Dice(non-enh)=0.3619  Dice(enh)=0.3884  lr=3.48e-05  (193.5s)


Epoch 94/95: 100%|██████████| 60/60 [02:36<00:00,  2.60s/it, loss=0.7550]


[094] train_loss=0.7550  val_meanDice(no-bg)=0.4816  Dice(edema)=0.7178  Dice(non-enh)=0.3639  Dice(enh)=0.3631  lr=9.04e-06  (194.3s)


Epoch 95/95: 100%|██████████| 60/60 [02:34<00:00,  2.57s/it, loss=0.7623]


[095] train_loss=0.7623  val_meanDice(no-bg)=0.4819  Dice(edema)=0.7164  Dice(non-enh)=0.3729  Dice(enh)=0.3563  lr=3.03e-07  (193.8s)
Best val mean Dice (no-bg): 0.5087


#6

In [None]:
# BraTS2024 FLAIR-only (or T2w-only) 3D U-Net

import os, json, time, math, random
from pathlib import Path

import numpy as np
import nibabel as nib

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.amp import autocast, GradScaler
from tqdm import tqdm

import segmentation_models_pytorch_3d as smp3d

# -----------------------------------------------------------------------------
# Global config
# -----------------------------------------------------------------------------
# Allow TF32 for matmuls and cuDNN convolutions.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Seed used by set_seed() and for shuffling when a new split is created.
SEED = 42

# --------- MODALITY SWITCH ----------
# Key of the per-case dict entry that holds the image path to train on.
# Use "t2f" for FLAIR-only or "t2w" for T2-weighted.
MODALITY_NAME = "t2w"   # choices: "t2f" (FLAIR), "t2w" (T2-weighted), etc.
# -----------------------------------

# Paths (adjust BRATS2024_ROOT to your actual path)
BRATS2024_ROOT = "/content/drive/MyDrive/BraTS2024_TrainingData_extracted/BraTS2024-BraTS-GLI-TrainingData/training_data1_v2"
# Dataset index listing every case; read only when no split file exists yet.
JSON_PATH = os.path.join(BRATS2024_ROOT, "dataset_brats2024_cases.json")


# Where to store checkpoints + training history for THIS experiment
# (one directory per modality, created on cell execution).
SAVE_DIR = f"/content/drive/MyDrive/BrainTumor_Checkpoints_BraTS2024_{MODALITY_NAME.upper()}"
os.makedirs(SAVE_DIR, exist_ok=True)

# The split file is pinned to the T2F experiment directory rather than SAVE_DIR,
# presumably so every modality trains/validates on the same cases — confirm.
# Per-experiment alternative, kept for reference:
#SPLIT_JSON_PATH = os.path.join(SAVE_DIR, "dataset_split_brats2024.json")
SPLIT_JSON_PATH = "/content/drive/MyDrive/BrainTumor_Checkpoints_BraTS2024_T2F/dataset_split_brats2024.json"

# Compute & memory knobs (similar to before)
# PATCH_SIZE is (D, H, W); it is used both for training crops and as the
# sliding-window size during validation.
PATCH_SIZE  = (128, 160, 160)  # or even (128, 144, 144) if you want extra safety
OVERLAP     = 0.5              # fractional window overlap; fine for val/inference
BATCH_SIZE  = 2
ACCUM_STEPS = 3                # micro-batches accumulated per optimizer step
NUM_WORKERS = 2

# Training schedule
EPOCHS = 75
LR = 3e-4                      # also the OneCycleLR peak learning rate
WEIGHT_DECAY = 1e-4

# Data/spec
IN_CHANNELS = 1      # single modality (FLAIR or T2w)
# NOTE(review): the BraTS 2024 GLI convention is believed to be 1=NETC, 2=SNFH,
# 3=ET, 4=RC; the mapping below (and the history key names) may be mislabeled —
# confirm against the dataset documentation. Labels > 3 are mapped to 0 by the
# label loader.
N_CLASSES   = 4      # 0=bg, 1=ET, 2=NETC, 3=SNFH

# -----------------------------------------------------------------------------
# Reproducibility
# -----------------------------------------------------------------------------
def set_seed(s=SEED):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs with ``s``."""
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seed_fn in seeders:
        seed_fn(s)

# Seed all RNGs once when the cell runs.
set_seed()
# Favour throughput over bit-exact reproducibility: let cuDNN benchmark and pick
# the fastest kernels, and do not force deterministic algorithms.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False

# -----------------------------------------------------------------------------
# NIfTI I/O + preprocessing
# -----------------------------------------------------------------------------
def percentile_clip(arr, lo=0.5, hi=99.5):
    """Clip ``arr`` to its [``lo``, ``hi``] percentile range.

    The input is first copied to float32; percentiles are computed over the
    whole array (all voxels, including zeros).  Returns the clipped copy.
    """
    values = arr.astype(np.float32)
    lower, upper = np.percentile(values, [lo, hi])
    return np.clip(values, lower, upper)

def zscore_per_channel(x, eps=1e-8):
    """Z-score each channel of ``x`` (C, D, H, W) and return a float32 copy.

    Mean and standard deviation are taken over the non-zero voxels of the
    channel when there are any, otherwise over the whole channel.  The
    standard deviation is floored at ``eps`` to avoid division by zero.  The
    normalization is applied to every voxel, including the zero ones.
    """
    normed = x.astype(np.float32)   # astype copies, so the caller's array is untouched
    n_channels = normed.shape[0]
    for ch in range(n_channels):
        channel = normed[ch]
        nonzero = channel != 0
        stats_src = channel[nonzero] if nonzero.any() else channel
        mean = stats_src.mean()
        std = max(float(stats_src.std()), eps)
        normed[ch] = (channel - mean) / std
    return normed

def load_single_modality_chdw(path_img):
    """Load one BraTS2024 modality file as a normalized (1, D, H, W) array.

    The file is read as a 3-D (H, W, D) float32 volume, then:
      1. clipped to its 0.5-99.5 percentile range,
      2. reordered to (D, H, W) and given a leading channel axis,
      3. z-scored with ``zscore_per_channel``.

    Note that, despite the ``chdw`` suffix in the name, the returned axis
    order is (C, D, H, W).

    Raises:
        AssertionError: if the file does not hold a 3-D array.
    """
    volume_hwd = nib.load(str(path_img)).get_fdata(dtype=np.float32)   # (H, W, D)
    assert volume_hwd.ndim == 3, f"{path_img} shape {volume_hwd.shape} not HWD"

    clipped = percentile_clip(volume_hwd, 0.5, 99.5)                   # (H, W, D)
    volume_cdhw = np.transpose(clipped, (2, 0, 1))[None, ...]          # (C=1, D, H, W)
    return zscore_per_channel(volume_cdhw)

def load_nifti_label_dhw(path_lab):
    """Load a segmentation file as an int64 (D, H, W) array with labels in {0,1,2,3}.

    Values are rounded to the nearest integer; anything below 0 or above 3 is
    mapped to 0 (background).
    """
    raw_hwd = nib.load(str(path_lab)).get_fdata(dtype=np.float32)  # H, W, D
    labels = np.rint(raw_hwd).astype(np.int64)
    # force into 0..3
    out_of_range = (labels < 0) | (labels > 3)
    labels[out_of_range] = 0
    return np.transpose(labels, (2, 0, 1))  # (D, H, W)

# -----------------------------------------------------------------------------
# Cropping & Sliding-window inference
# -----------------------------------------------------------------------------
def random_crop_3d(image_cdhw, label_dhw, crop, tries=8, fg_bias=0.5):
    """Draw a random, spatially aligned crop from an image/label pair.

    Args:
        image_cdhw: array of shape (C, D, H, W).
        label_dhw: array of shape (D, H, W).
        crop: ``(depth, height, width)`` of the patch.
        tries: number of candidate windows to draw before giving up on the
            foreground preference.
        fg_bias: a candidate is accepted when a uniform draw exceeds
            ``fg_bias`` or when the label patch contains any value > 0, so
            higher values favour patches with foreground.

    Returns:
        ``(image_patch, label_patch)`` as slices of the inputs.

    Raises:
        AssertionError: if the crop is larger than the volume on any axis.
    """
    _, D, H, W = image_cdhw.shape
    cd, ch, cw = crop
    assert D >= cd and H >= ch and W >= cw, f"Patch {crop} > vol {(D,H,W)}"

    def _draw_window():
        # Corner is drawn in z, y, x order so the RNG stream matches per draw.
        z0 = np.random.randint(0, D - cd + 1)
        y0 = np.random.randint(0, H - ch + 1)
        x0 = np.random.randint(0, W - cw + 1)
        return (slice(z0, z0 + cd), slice(y0, y0 + ch), slice(x0, x0 + cw))

    all_channels = (slice(None),)
    for _attempt in range(tries):
        window = _draw_window()
        label_patch = label_dhw[window]
        if (np.random.rand() > fg_bias) or np.any(label_patch > 0):
            return image_cdhw[all_channels + window], label_patch

    # fallback random: take whatever window comes up next
    window = _draw_window()
    return image_cdhw[all_channels + window], label_dhw[window]

@torch.no_grad()
def sliding_window_inference(volume_cdhw, model, window, overlap, device="cuda"):
    """Predict a full volume by averaging softmax outputs over overlapping windows.

    Args:
        volume_cdhw: NumPy array of shape (C, D, H, W).
        model: network mapping (1, C, d, h, w) to (1, N_CLASSES, d, h, w) logits.
        window: ``(depth, height, width)`` of each window.
        overlap: fractional overlap between neighbouring windows; the stride
            per axis is ``int(window * (1 - overlap))``, at least 1.
        device: device used for the model input and the accumulators.

    Returns:
        ``(pred, out_prob)``: the (D, H, W) argmax label map and the
        (N_CLASSES, D, H, W) averaged probabilities, both on ``device``.

    Raises:
        ValueError: if the volume is smaller than the window on any axis.
    """
    model.eval()
    C, D, H, W = volume_cdhw.shape
    wd, wh, ww = window
    if D < wd or H < wh or W < ww:
        raise ValueError(f"Window {tuple(window)} is larger than volume {(D, H, W)}")

    def _window_starts(size, win):
        # Regular stride positions plus a final window flush with the far edge.
        # Without that last start, the tail of an axis is never predicted when
        # (size - win) is not a multiple of the stride and silently becomes
        # background in the argmax.
        stride = max(1, int(win * (1 - overlap)))
        last = size - win
        starts = list(range(0, last + 1, stride))
        if starts[-1] != last:
            starts.append(last)
        return starts

    out_prob = torch.zeros((N_CLASSES, D, H, W), dtype=torch.float32, device=device)
    out_norm = torch.zeros((1, D, H, W), dtype=torch.float32, device=device)

    for z0 in _window_starts(D, wd):
        for y0 in _window_starts(H, wh):
            for x0 in _window_starts(W, ww):
                patch = volume_cdhw[:, z0:z0+wd, y0:y0+wh, x0:x0+ww]
                pt = torch.from_numpy(patch).unsqueeze(0).to(device)  # 1,C,D,H,W
                logits = model(pt)                                    # 1,C,d,h,w
                probs = F.softmax(logits, dim=1)[0]
                out_prob[:, z0:z0+wd, y0:y0+wh, x0:x0+ww] += probs
                out_norm[:, z0:z0+wd, y0:y0+wh, x0:x0+ww] += 1.0

    # Every voxel is now covered at least once, so out_norm >= 1 everywhere.
    out_prob /= (out_norm + 1e-8)
    pred = torch.argmax(out_prob, dim=0)  # D,H,W
    return pred, out_prob

# -----------------------------------------------------------------------------
# Dataset
# -----------------------------------------------------------------------------
class Brats2024TrainPatches(Dataset):
    """Training dataset yielding one augmented random patch per case.

    Args:
        items: list of per-case dicts.  Each dict is read for the key given by
            ``modality_name`` (image path), ``"seg"`` (label path) and
            ``"case_id"`` (used only in the error message).
        modality_name: which image key to load for every case.

    ``__getitem__`` returns ``(image, label)`` tensors: a float image patch of
    shape (1, *PATCH_SIZE) and an int64 label patch of shape PATCH_SIZE.
    """

    def __init__(self, items, modality_name=MODALITY_NAME):
        self.items = items
        self.modality_name = modality_name

    def __len__(self):
        return len(self.items)

    def __getitem__(self, i):
        case = self.items[i]
        img_path, lab_path = case[self.modality_name], case["seg"]

        assert img_path is not None, f"{self.modality_name} missing for case {case['case_id']}"

        img = load_single_modality_chdw(img_path)  # (1, D, H, W)
        lab = load_nifti_label_dhw(lab_path)       # (D, H, W)

        # Foreground-biased crop down to the training patch size.
        img, lab = random_crop_3d(img, lab, PATCH_SIZE, fg_bias=0.85)

        # Random multiple-of-90-degree rotation in the H-W plane.
        if np.random.rand() < 0.5:
            quarter_turns = np.random.randint(0, 4)
            img = np.rot90(img, k=quarter_turns, axes=(2, 3)).copy()
            lab = np.rot90(lab, k=quarter_turns, axes=(1, 2)).copy()

        # Independent random flips along D, H and W.  The image has a leading
        # channel axis, so its spatial axes sit one position later.
        for img_axis in (1, 2, 3):
            if np.random.rand() < 0.33:
                img = np.flip(img, axis=img_axis).copy()
                lab = np.flip(lab, axis=img_axis - 1).copy()

        # Light intensity augmentation: random gain, then additive Gaussian noise.
        if np.random.rand() < 0.15:
            gain = 0.9 + 0.2 * np.random.rand()
            img = img * gain
        if np.random.rand() < 0.15:
            noise = np.random.normal(0, 0.05, img.shape).astype(np.float32)
            img = img + noise

        return torch.from_numpy(img), torch.from_numpy(lab).long()

# -----------------------------------------------------------------------------
# Losses & metrics
# -----------------------------------------------------------------------------
class SoftDiceLoss(nn.Module):
    """Soft Dice loss over class probabilities obtained from logits.

    Args:
        smooth: constant added to numerator and denominator of each Dice ratio.
        include_bg: if False, class 0 is left out of the average.

    ``forward`` takes (N, C, D, H, W) logits and (N, D, H, W) integer targets
    and returns ``1 - mean Dice``; each class's Dice is pooled over the batch
    and spatial axes.
    """

    def __init__(self, smooth=1.0, include_bg=False):
        super().__init__()
        self.include_bg = include_bg
        self.smooth = smooth

    def forward(self, logits, targets):
        probs = F.softmax(logits, dim=1)
        # Unpacking asserts a 5-D (N, C, D, H, W) input.
        _n, _c, _d, _h, _w = probs.shape
        onehot = torch.zeros_like(probs).scatter_(1, targets.unsqueeze(1), 1)
        start = 0 if self.include_bg else 1
        probs, onehot = probs[:, start:], onehot[:, start:]
        non_class_axes = (0, 2, 3, 4)
        overlap = (probs * onehot).sum(dim=non_class_axes)
        mass = (probs + onehot).sum(dim=non_class_axes)
        dice_scores = (2.0 * overlap + self.smooth) / (mass + self.smooth)
        return 1.0 - dice_scores.mean()

class CEDice(nn.Module):
    """Combined loss: ``w_ce * CrossEntropy + w_dice * SoftDice``.

    Args:
        w_ce: weight of the cross-entropy term.
        w_dice: weight of the soft Dice term.
        include_bg_dice: passed to ``SoftDiceLoss`` as ``include_bg``.

    The ``ce`` attribute may be reassigned after construction to use a
    differently configured cross-entropy loss.
    """

    def __init__(self, w_ce=1.0, w_dice=1.0, include_bg_dice=False):
        super().__init__()
        self.ce = nn.CrossEntropyLoss()
        self.dice = SoftDiceLoss(include_bg=include_bg_dice)
        self.w_ce, self.w_dice = w_ce, w_dice

    def forward(self, logits, y):
        weighted_ce = self.w_ce * self.ce(logits, y)
        weighted_dice = self.w_dice * self.dice(logits, y)
        return weighted_ce + weighted_dice

@torch.no_grad()
def dice_per_class(pred, target, n_classes=N_CLASSES):
    """Return the hard Dice score of each class index in ``range(n_classes)``.

    ``pred`` and ``target`` are integer label tensors of equal shape (3-D
    inputs get a batch axis prepended).  A class missing from both tensors
    scores 1.0.  The result is a list of Python floats.
    """
    if pred.ndim == 3:
        pred, target = pred.unsqueeze(0), target.unsqueeze(0)

    def _dice_for(label):
        in_pred, in_true = pred == label, target == label
        overlap = (in_pred & in_true).sum().item()
        total = in_pred.sum().item() + in_true.sum().item()
        return (2 * overlap / total) if total > 0 else 1.0

    return [_dice_for(label) for label in range(n_classes)]

@torch.no_grad()
def dice_whole_tumor(pred, target, eps=1e-8):
    """Dice score of the "whole tumor" region, i.e. every voxel with label > 0.

    Args:
        pred, target: integer label tensors of the same shape.
        eps: small constant added to the denominator.

    Returns:
        1.0 when neither tensor has any foreground, otherwise
        ``2 * |pred ∩ target| / (|pred| + |target| + eps)``.
    """
    pred_fg, true_fg = pred > 0, target > 0
    total = pred_fg.sum().item() + true_fg.sum().item()
    if total == 0:
        return 1.0
    overlap = (pred_fg & true_fg).sum().item()
    return 2.0 * overlap / (total + eps)

# -----------------------------------------------------------------------------
# Data split (75% train, 5% val, 20% test)
# -----------------------------------------------------------------------------
def load_train_val_from_json(json_path, train_frac=0.75, val_frac=0.05):
    """Load the train/val/test split, creating and saving it on first use.

    If ``SPLIT_JSON_PATH`` exists, its ``"train"``, ``"val"`` and ``"test"``
    lists are returned unchanged.  Otherwise the function:
      - reads the ``"training"`` list from ``json_path``,
      - shuffles it with ``np.random.RandomState(SEED)``,
      - takes ``train_frac`` for training, ``val_frac`` for validation and the
        remainder for testing (75% / 5% / 20% with the defaults),
      - writes the split to ``SPLIT_JSON_PATH``, creating its parent directory
        if needed.

    Returns:
        ``(train_items, val_items, test_items)`` lists of case dicts.
    """
    if os.path.exists(SPLIT_JSON_PATH):
        with open(SPLIT_JSON_PATH, "r") as f:
            split = json.load(f)
        train_items = split["train"]
        val_items   = split["val"]
        test_items  = split["test"]
        print(f"Loaded existing split from {SPLIT_JSON_PATH}")
        print(f"Train={len(train_items)}  Val={len(val_items)}  Test={len(test_items)}")
        return train_items, val_items, test_items

    with open(json_path, "r") as f:
        js = json.load(f)

    items = js["training"]
    n = len(items)
    idx = np.arange(n)
    rng = np.random.RandomState(SEED)
    rng.shuffle(idx)

    n_train = int(n * train_frac)
    n_val   = int(n * val_frac)

    # Whatever is left after train and val becomes the test set.
    train_idx = idx[:n_train]
    val_idx   = idx[n_train:n_train+n_val]
    test_idx  = idx[n_train+n_val:]

    train_items = [items[int(i)] for i in train_idx]
    val_items   = [items[int(i)] for i in val_idx]
    test_items  = [items[int(i)] for i in test_idx]

    split = {
        "train": train_items,
        "val":   val_items,
        "test":  test_items,
    }

    # SPLIT_JSON_PATH is not necessarily inside SAVE_DIR, so its directory may
    # not exist yet; create it rather than failing on the write.
    split_dir = os.path.dirname(SPLIT_JSON_PATH)
    if split_dir:
        os.makedirs(split_dir, exist_ok=True)
    with open(SPLIT_JSON_PATH, "w") as f:
        json.dump(split, f, indent=2)

    print(f"Created new split at {SPLIT_JSON_PATH}")
    print(f"Train={len(train_items)}  Val={len(val_items)}  Test={len(test_items)}")
    return train_items, val_items, test_items

# -----------------------------------------------------------------------------
# Model
# -----------------------------------------------------------------------------
def build_model(device):
    """Create the EfficientNet-B7 3D U-Net without pretrained weights and move it to ``device``.

    Channel and class counts are taken from ``IN_CHANNELS`` and ``N_CLASSES``.
    """
    unet_kwargs = dict(
        encoder_name="efficientnet-b7",
        encoder_weights=None,
        in_channels=IN_CHANNELS,
        classes=N_CLASSES,
        decoder_channels=(192, 128, 64, 32, 16),
    )
    net = smp3d.Unet(**unet_kwargs)
    return net.to(device)

# -----------------------------------------------------------------------------
# Validation (full volume, per-class + whole tumor)
# -----------------------------------------------------------------------------
@torch.no_grad()
def validate_full_volume(model, val_items, device):
    """Evaluate ``model`` on full validation volumes with sliding-window inference.

    Cases whose ``MODALITY_NAME`` entry is None are skipped.

    Returns:
        ``(mean_dice, per_class, whole_dice)``:
          - ``mean_dice``: case-average of the mean Dice over the
            non-background classes,
          - ``per_class``: dict mapping class ids 1, 2, 3 to case-averaged Dice,
          - ``whole_dice``: case-averaged Dice of the label > 0 region.
        Each value is 0.0 when no case was evaluated.
    """
    model.eval()

    tumor_classes = (1, 2, 3)
    case_means = []                                   # mean over classes 1-3
    class_scores = {cls: [] for cls in tumor_classes}
    whole_scores = []

    for case in tqdm(val_items, desc="Validating (full-volume)", leave=False):
        img_path = case[MODALITY_NAME]
        lab_path = case["seg"]
        if img_path is None:
            continue

        vol = load_single_modality_chdw(img_path)  # (1,D,H,W)
        lab = load_nifti_label_dhw(lab_path)       # (D,H,W)

        pred, _ = sliding_window_inference(vol, model, PATCH_SIZE, OVERLAP, device)
        pred_cpu = pred.cpu()
        truth = torch.from_numpy(lab)

        dices = dice_per_class(pred_cpu, truth)
        case_means.append(np.mean(dices[1:]))
        for cls in tumor_classes:
            class_scores[cls].append(float(dices[cls]))

        whole_scores.append(float(dice_whole_tumor(pred_cpu, truth)))

    def _average(values):
        return float(np.mean(values)) if values else 0.0

    per_class = {cls: _average(class_scores[cls]) for cls in tumor_classes}
    return _average(case_means), per_class, _average(whole_scores)

# -----------------------------------------------------------------------------
# Checkpoint helpers
# -----------------------------------------------------------------------------
def save_checkpoint(path, epoch, model, opt, sched, scaler, best_val_whole):
    """Write the full training state to ``path`` with ``torch.save``.

    Stored keys: ``epoch``, ``state_dict``, ``optimizer``, ``scheduler``
    (``{}`` when ``sched`` is None), ``scaler``, ``best_val_whole`` and a
    ``config`` record of encoder, channels, classes, patch size and modality.
    """
    run_config = {
        "encoder": "efficientnet-b7",
        "in_channels": IN_CHANNELS,
        "classes": N_CLASSES,
        "patch_size": PATCH_SIZE,
        "modality": MODALITY_NAME,
    }
    payload = {
        "epoch": epoch,
        "state_dict": model.state_dict(),
        "optimizer": opt.state_dict(),
        "scheduler": {} if sched is None else sched.state_dict(),
        "scaler": scaler.state_dict(),
        "best_val_whole": float(best_val_whole),
        "config": run_config,
    }
    torch.save(payload, path)

def resume_if_possible(model, opt, sched, scaler, load_scheduler=False):
    """Resume THIS BraTS2024 experiment from ``SAVE_DIR/last.pt`` if it exists.

    (Old MSD checkpoints live elsewhere and are not picked up.)

    Args:
        model, opt, scaler: objects whose state is restored in place.
        sched: scheduler to restore; only touched when ``load_scheduler`` is
            True and ``sched`` is not None.
        load_scheduler: opt-in switch for restoring the scheduler state.

    Returns:
        ``(start_epoch, best_val_whole)``: the epoch after the checkpointed
        one and the stored best whole-tumor Dice, or ``(1, -1.0)`` when no
        checkpoint exists.
    """
    last_ckpt = os.path.join(SAVE_DIR, "last.pt")
    if not os.path.exists(last_ckpt):
        return 1, -1.0

    ckpt = torch.load(last_ckpt, map_location="cpu")

    model.load_state_dict(ckpt["state_dict"])

    if "optimizer" in ckpt:
        opt.load_state_dict(ckpt["optimizer"])

    # Scheduler restore is best effort: a mismatch is reported, not fatal.
    if load_scheduler and "scheduler" in ckpt and sched is not None:
        try:
            sched.load_state_dict(ckpt["scheduler"])
        except Exception as e:
            print("⚠️ Skipping scheduler state load:", e)

    # A checkpoint written with a disabled GradScaler stores an empty dict,
    # which an enabled scaler refuses to load; keep the fresh scaler instead
    # of aborting the resume.
    scaler_state = ckpt.get("scaler")
    if scaler_state:
        scaler.load_state_dict(scaler_state)
    elif "scaler" in ckpt:
        print("⚠️ Skipping empty scaler state; using a fresh GradScaler")

    start_epoch = int(ckpt.get("epoch", 0)) + 1
    best_val_whole = float(ckpt.get("best_val_whole", -1.0))
    print(f"↩ Resumed from {last_ckpt}: start_epoch={start_epoch}, best_wholeDice={best_val_whole:.4f}")
    return start_epoch, best_val_whole

# -----------------------------------------------------------------------------
# Train loop
# -----------------------------------------------------------------------------
def _load_history(history_path, start_epoch):
    """Load the per-epoch history, keeping only epochs that precede `start_epoch`."""
    keys = (
        "epoch",
        "train_loss",
        "val_mean_dice",
        "val_whole_dice",
        "dice_ET",
        "dice_NETC",
        "dice_SNFH",
        "lr",
    )
    history = {}
    if os.path.exists(history_path):
        with open(history_path, "r") as f:
            history = json.load(f)

    # Make sure every series we append to exists (also builds the fresh-run default).
    for key in keys:
        history.setdefault(key, [])

    # Drop rows for epochs that are about to be (re-)run, so a resume after an
    # interruption between the history write and the checkpoint write, or a
    # fresh run next to an old history file, does not produce duplicate epochs.
    epochs = history["epoch"]
    keep = [i for i, e in enumerate(epochs) if e < start_epoch]
    if len(keep) != len(epochs):
        for key in keys:
            series = history[key]
            history[key] = [series[i] for i in keep if i < len(series)]
    return history


def train():
    """
    Train the segmentation model on BraTS2024 patches and validate every epoch.

    Flow:
      1. Load the train/val/test split from JSON_PATH and build the patch loader.
      2. Build the model, a CE+Dice loss with class-weighted, label-smoothed
         cross-entropy, an AdamW optimizer and a gradient scaler.
      3. Resume from SAVE_DIR/last.pt when available (scheduler state is not
         restored; a new OneCycleLR is created over the remaining epochs).
      4. Per epoch: train with mixed precision and gradient accumulation, run
         `validate_full_volume`, append metrics to training_history.json, save
         last.pt, and save an extra checkpoint whenever the validation
         whole-tumor Dice improves.

    Note on accumulation: the optimizer steps once every ACCUM_STEPS batches.
    Batches after the last full group in an epoch do not trigger a step; their
    gradients are cleared by the `zero_grad` at the start of the next epoch.
    The scheduler's step budget (`steps_per_epoch`) is sized to match this.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("Device:", device)
    print("Modality:", MODALITY_NAME)

    # Checkpoints and the history file are written here.
    os.makedirs(SAVE_DIR, exist_ok=True)

    train_items, val_items, test_items = load_train_val_from_json(JSON_PATH, train_frac=0.75, val_frac=0.05)
    print(f"Train={len(train_items)}  Val={len(val_items)}  Test={len(test_items)}")

    train_ds = Brats2024TrainPatches(train_items, modality_name=MODALITY_NAME)
    train_loader = DataLoader(
        train_ds,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=(device == "cuda"),
        drop_last=True,
    )

    model = build_model(device)
    loss_fn = CEDice(w_ce=1.0, w_dice=1.0, include_bg_dice=False)
    # Replace the loss's CE term with a class-weighted, label-smoothed version.
    cls_w = torch.tensor([1.0, 2.0, 2.3, 1.2], device=device)
    loss_fn.ce = nn.CrossEntropyLoss(weight=cls_w, label_smoothing=0.03)

    opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    sched = None
    scaler = GradScaler(device="cuda" if device == "cuda" else "cpu")

    start_epoch, best_val_whole = resume_if_possible(model, opt, sched, scaler, load_scheduler=False)

    # OneCycleLR over the epochs that are still to run.
    # One scheduler step per optimizer step, i.e. per full accumulation group.
    steps_per_epoch = max(1, len(train_loader) // ACCUM_STEPS)
    remaining_epochs = max(1, EPOCHS - (start_epoch - 1))

    # A restored optimizer carries the LR of the interrupted schedule; reset it.
    for pg in opt.param_groups:
        pg["lr"] = LR

    sched = torch.optim.lr_scheduler.OneCycleLR(
        opt,
        max_lr=LR,
        steps_per_epoch=steps_per_epoch,
        epochs=remaining_epochs,
        pct_start=0.1,
        div_factor=10,
        final_div_factor=100,
    )

    # training history (trimmed to the epochs that precede start_epoch)
    history_path = os.path.join(SAVE_DIR, "training_history.json")
    history = _load_history(history_path, start_epoch)

    for epoch in range(start_epoch, EPOCHS + 1):
        model.train()
        t0 = time.time()
        running = 0.0
        opt.zero_grad(set_to_none=True)

        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS}")
        for i, (x, y) in enumerate(pbar):
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            with autocast(device_type="cuda" if device == "cuda" else "cpu"):
                logits = model(x)
                # Scale so the accumulated gradient is the mean over the group.
                loss = loss_fn(logits, y) / ACCUM_STEPS

            scaler.scale(loss).backward()
            if ((i + 1) % ACCUM_STEPS) == 0:
                scaler.step(opt)
                scaler.update()
                opt.zero_grad(set_to_none=True)
                sched.step()

            # Undo the accumulation scaling for reporting.
            running += loss.item() * ACCUM_STEPS
            pbar.set_postfix(loss=f"{(running / (i + 1)):.4f}")

        train_loss = running / max(1, len(train_loader))

        # Validation
        val_mean_dice, per_class, val_whole_dice = validate_full_volume(model, val_items, device)
        dt = time.time() - t0  # includes validation time

        print(
            f"[{epoch:03d}] "
            f"train_loss={train_loss:.4f}  "
            f"val_meanDice(no-bg)={val_mean_dice:.4f}  "
            f"Dice(ET)={per_class[1]:.4f}  "
            f"Dice(NETC)={per_class[2]:.4f}  "
            f"Dice(SNFH)={per_class[3]:.4f}  "
            f"wholeDice(label>0)={val_whole_dice:.4f}  "
            f"lr={sched.get_last_lr()[0]:.2e}  ({dt:.1f}s)"
        )

        # history
        history["epoch"].append(epoch)
        history["train_loss"].append(train_loss)
        history["val_mean_dice"].append(val_mean_dice)
        history["val_whole_dice"].append(val_whole_dice)
        history["dice_ET"].append(per_class[1])
        history["dice_NETC"].append(per_class[2])
        history["dice_SNFH"].append(per_class[3])
        history["lr"].append(sched.get_last_lr()[0])

        with open(history_path, "w") as f:
            json.dump(history, f, indent=2)

        # Update the running best BEFORE writing last.pt, so a resumed run
        # starts from the true best score instead of the previous one.
        is_best = val_whole_dice > best_val_whole
        if is_best:
            best_val_whole = val_whole_dice

        # save last
        save_checkpoint(
            os.path.join(SAVE_DIR, "last.pt"),
            epoch,
            model,
            opt,
            sched,
            scaler,
            best_val_whole,
        )

        # save best based on whole-tumor dice
        if is_best:
            best_path = os.path.join(SAVE_DIR, f"best_epoch{epoch}_wholeDice{best_val_whole:.4f}.pt")
            save_checkpoint(best_path, epoch, model, opt, sched, scaler, best_val_whole)
            print("✔ Saved", best_path)

    print(f"Best val whole-tumor Dice: {best_val_whole:.4f}")

# Entry point: start training when this code is executed as the main module.
if __name__ == "__main__":
    train()


  self.setter(val)


Device: cuda
Modality: t2w
Loaded existing split from /content/drive/MyDrive/BrainTumor_Checkpoints_BraTS2024_T2F/dataset_split_brats2024.json
Train=1012  Val=67  Test=271
Train=1012  Val=67  Test=271


Epoch 1/75: 100%|██████████| 506/506 [21:22<00:00,  2.54s/it, loss=1.8527]


[001] train_loss=1.8527  val_meanDice(no-bg)=0.3758  Dice(ET)=0.4478  Dice(NETC)=0.3364  Dice(SNFH)=0.3433  wholeDice(label>0)=0.3683  lr=4.17e-05  (1535.9s)
✔ Saved /content/drive/MyDrive/BrainTumor_Checkpoints_BraTS2024_T2W/best_epoch1_wholeDice0.3683.pt


Epoch 2/75: 100%|██████████| 506/506 [16:15<00:00,  1.93s/it, loss=1.2946]


[002] train_loss=1.2946  val_meanDice(no-bg)=0.4347  Dice(ET)=0.6418  Dice(NETC)=0.3190  Dice(SNFH)=0.3433  wholeDice(label>0)=0.3449  lr=7.47e-05  (1059.5s)


Epoch 3/75: 100%|██████████| 506/506 [16:23<00:00,  1.94s/it, loss=1.1461]


[003] train_loss=1.1461  val_meanDice(no-bg)=0.4949  Dice(ET)=0.6418  Dice(NETC)=0.4997  Dice(SNFH)=0.3433  wholeDice(label>0)=0.5267  lr=1.23e-04  (1066.4s)
✔ Saved /content/drive/MyDrive/BrainTumor_Checkpoints_BraTS2024_T2W/best_epoch3_wholeDice0.5267.pt


Epoch 4/75: 100%|██████████| 506/506 [16:25<00:00,  1.95s/it, loss=1.1082]


[004] train_loss=1.1082  val_meanDice(no-bg)=0.4966  Dice(ET)=0.6418  Dice(NETC)=0.5049  Dice(SNFH)=0.3433  wholeDice(label>0)=0.5281  lr=1.79e-04  (1071.1s)
✔ Saved /content/drive/MyDrive/BrainTumor_Checkpoints_BraTS2024_T2W/best_epoch4_wholeDice0.5281.pt


Epoch 5/75: 100%|██████████| 506/506 [16:23<00:00,  1.94s/it, loss=1.0932]


[005] train_loss=1.0932  val_meanDice(no-bg)=0.4680  Dice(ET)=0.5137  Dice(NETC)=0.5471  Dice(SNFH)=0.3433  wholeDice(label>0)=0.5796  lr=2.33e-04  (1068.3s)
✔ Saved /content/drive/MyDrive/BrainTumor_Checkpoints_BraTS2024_T2W/best_epoch5_wholeDice0.5796.pt


Epoch 6/75: 100%|██████████| 506/506 [16:23<00:00,  1.94s/it, loss=1.0866]


[006] train_loss=1.0866  val_meanDice(no-bg)=0.4802  Dice(ET)=0.6119  Dice(NETC)=0.5003  Dice(SNFH)=0.3284  wholeDice(label>0)=0.5334  lr=2.74e-04  (1070.2s)


Epoch 7/75: 100%|██████████| 506/506 [16:19<00:00,  1.94s/it, loss=1.0714]


[007] train_loss=1.0714  val_meanDice(no-bg)=0.4401  Dice(ET)=0.5970  Dice(NETC)=0.5702  Dice(SNFH)=0.1532  wholeDice(label>0)=0.5966  lr=2.97e-04  (1063.5s)
✔ Saved /content/drive/MyDrive/BrainTumor_Checkpoints_BraTS2024_T2W/best_epoch7_wholeDice0.5966.pt


Epoch 8/75: 100%|██████████| 506/506 [16:17<00:00,  1.93s/it, loss=1.0621]


[008] train_loss=1.0621  val_meanDice(no-bg)=0.4921  Dice(ET)=0.6418  Dice(NETC)=0.5683  Dice(SNFH)=0.2662  wholeDice(label>0)=0.6035  lr=3.00e-04  (1062.9s)
✔ Saved /content/drive/MyDrive/BrainTumor_Checkpoints_BraTS2024_T2W/best_epoch8_wholeDice0.6035.pt


Epoch 9/75: 100%|██████████| 506/506 [16:16<00:00,  1.93s/it, loss=1.0567]


[009] train_loss=1.0567  val_meanDice(no-bg)=0.4933  Dice(ET)=0.6418  Dice(NETC)=0.5867  Dice(SNFH)=0.2515  wholeDice(label>0)=0.6178  lr=3.00e-04  (1063.0s)
✔ Saved /content/drive/MyDrive/BrainTumor_Checkpoints_BraTS2024_T2W/best_epoch9_wholeDice0.6178.pt


Epoch 10/75: 100%|██████████| 506/506 [16:16<00:00,  1.93s/it, loss=1.0419]


[010] train_loss=1.0419  val_meanDice(no-bg)=0.4588  Dice(ET)=0.6418  Dice(NETC)=0.5738  Dice(SNFH)=0.1607  wholeDice(label>0)=0.6165  lr=2.99e-04  (1062.1s)


Epoch 11/75: 100%|██████████| 506/506 [16:16<00:00,  1.93s/it, loss=1.0360]


[011] train_loss=1.0360  val_meanDice(no-bg)=0.5077  Dice(ET)=0.6418  Dice(NETC)=0.5945  Dice(SNFH)=0.2869  wholeDice(label>0)=0.6275  lr=2.98e-04  (1061.3s)
✔ Saved /content/drive/MyDrive/BrainTumor_Checkpoints_BraTS2024_T2W/best_epoch11_wholeDice0.6275.pt


Epoch 12/75: 100%|██████████| 506/506 [16:17<00:00,  1.93s/it, loss=1.0340]


[012] train_loss=1.0340  val_meanDice(no-bg)=0.5013  Dice(ET)=0.6269  Dice(NETC)=0.5929  Dice(SNFH)=0.2842  wholeDice(label>0)=0.6337  lr=2.97e-04  (1063.8s)
✔ Saved /content/drive/MyDrive/BrainTumor_Checkpoints_BraTS2024_T2W/best_epoch12_wholeDice0.6337.pt


Epoch 13/75: 100%|██████████| 506/506 [16:17<00:00,  1.93s/it, loss=1.0233]


[013] train_loss=1.0233  val_meanDice(no-bg)=0.4510  Dice(ET)=0.4926  Dice(NETC)=0.5864  Dice(SNFH)=0.2740  wholeDice(label>0)=0.6223  lr=2.95e-04  (1064.0s)


Epoch 14/75: 100%|██████████| 506/506 [16:19<00:00,  1.94s/it, loss=1.0204]


[014] train_loss=1.0204  val_meanDice(no-bg)=0.5087  Dice(ET)=0.6119  Dice(NETC)=0.5773  Dice(SNFH)=0.3368  wholeDice(label>0)=0.6102  lr=2.93e-04  (1065.3s)


Epoch 15/75: 100%|██████████| 506/506 [16:19<00:00,  1.94s/it, loss=1.0124]


[015] train_loss=1.0124  val_meanDice(no-bg)=0.5184  Dice(ET)=0.6058  Dice(NETC)=0.5924  Dice(SNFH)=0.3571  wholeDice(label>0)=0.6194  lr=2.91e-04  (1065.1s)


Epoch 16/75: 100%|██████████| 506/506 [16:17<00:00,  1.93s/it, loss=1.0095]


[016] train_loss=1.0095  val_meanDice(no-bg)=0.4747  Dice(ET)=0.4867  Dice(NETC)=0.6002  Dice(SNFH)=0.3372  wholeDice(label>0)=0.6301  lr=2.88e-04  (1063.1s)


Epoch 17/75: 100%|██████████| 506/506 [16:18<00:00,  1.93s/it, loss=1.0099]


[017] train_loss=1.0099  val_meanDice(no-bg)=0.4985  Dice(ET)=0.5533  Dice(NETC)=0.6024  Dice(SNFH)=0.3397  wholeDice(label>0)=0.6308  lr=2.86e-04  (1066.9s)


Epoch 18/75: 100%|██████████| 506/506 [16:17<00:00,  1.93s/it, loss=0.9996]


[018] train_loss=0.9996  val_meanDice(no-bg)=0.4385  Dice(ET)=0.4541  Dice(NETC)=0.6014  Dice(SNFH)=0.2599  wholeDice(label>0)=0.6251  lr=2.82e-04  (1065.7s)


Epoch 19/75: 100%|██████████| 506/506 [16:17<00:00,  1.93s/it, loss=1.0028]


[019] train_loss=1.0028  val_meanDice(no-bg)=0.4465  Dice(ET)=0.3875  Dice(NETC)=0.6157  Dice(SNFH)=0.3363  wholeDice(label>0)=0.6525  lr=2.79e-04  (1064.5s)
✔ Saved /content/drive/MyDrive/BrainTumor_Checkpoints_BraTS2024_T2W/best_epoch19_wholeDice0.6525.pt


Epoch 20/75: 100%|██████████| 506/506 [16:22<00:00,  1.94s/it, loss=0.9970]


[020] train_loss=0.9970  val_meanDice(no-bg)=0.4673  Dice(ET)=0.4636  Dice(NETC)=0.5944  Dice(SNFH)=0.3441  wholeDice(label>0)=0.6258  lr=2.75e-04  (1067.4s)


Epoch 21/75: 100%|██████████| 506/506 [16:20<00:00,  1.94s/it, loss=0.9937]


[021] train_loss=0.9937  val_meanDice(no-bg)=0.4314  Dice(ET)=0.3656  Dice(NETC)=0.5882  Dice(SNFH)=0.3406  wholeDice(label>0)=0.6386  lr=2.71e-04  (1067.4s)


Epoch 22/75: 100%|██████████| 506/506 [16:23<00:00,  1.94s/it, loss=0.9967]


[022] train_loss=0.9967  val_meanDice(no-bg)=0.5123  Dice(ET)=0.5412  Dice(NETC)=0.6266  Dice(SNFH)=0.3691  wholeDice(label>0)=0.6535  lr=2.67e-04  (1069.6s)
✔ Saved /content/drive/MyDrive/BrainTumor_Checkpoints_BraTS2024_T2W/best_epoch22_wholeDice0.6535.pt


Epoch 23/75: 100%|██████████| 506/506 [16:24<00:00,  1.95s/it, loss=0.9883]


[023] train_loss=0.9883  val_meanDice(no-bg)=0.3885  Dice(ET)=0.2912  Dice(NETC)=0.6102  Dice(SNFH)=0.2640  wholeDice(label>0)=0.6471  lr=2.63e-04  (1071.2s)


Epoch 24/75: 100%|██████████| 506/506 [16:24<00:00,  1.95s/it, loss=0.9894]


[024] train_loss=0.9894  val_meanDice(no-bg)=0.4313  Dice(ET)=0.3429  Dice(NETC)=0.6106  Dice(SNFH)=0.3404  wholeDice(label>0)=0.6505  lr=2.58e-04  (1070.1s)


Epoch 25/75: 100%|██████████| 506/506 [16:24<00:00,  1.95s/it, loss=0.9847]


[025] train_loss=0.9847  val_meanDice(no-bg)=0.4288  Dice(ET)=0.3561  Dice(NETC)=0.6129  Dice(SNFH)=0.3174  wholeDice(label>0)=0.6452  lr=2.53e-04  (1070.4s)


Epoch 26/75: 100%|██████████| 506/506 [16:24<00:00,  1.95s/it, loss=0.9836]


[026] train_loss=0.9836  val_meanDice(no-bg)=0.4446  Dice(ET)=0.4004  Dice(NETC)=0.5575  Dice(SNFH)=0.3760  wholeDice(label>0)=0.6109  lr=2.48e-04  (1072.5s)


Epoch 27/75: 100%|██████████| 506/506 [16:24<00:00,  1.95s/it, loss=0.9843]


[027] train_loss=0.9843  val_meanDice(no-bg)=0.4652  Dice(ET)=0.4312  Dice(NETC)=0.6230  Dice(SNFH)=0.3413  wholeDice(label>0)=0.6619  lr=2.42e-04  (1073.6s)
✔ Saved /content/drive/MyDrive/BrainTumor_Checkpoints_BraTS2024_T2W/best_epoch27_wholeDice0.6619.pt


Epoch 28/75: 100%|██████████| 506/506 [16:23<00:00,  1.94s/it, loss=0.9799]


[028] train_loss=0.9799  val_meanDice(no-bg)=0.4504  Dice(ET)=0.3871  Dice(NETC)=0.6127  Dice(SNFH)=0.3513  wholeDice(label>0)=0.6513  lr=2.37e-04  (1069.3s)


Epoch 29/75: 100%|██████████| 506/506 [16:23<00:00,  1.94s/it, loss=0.9763]


[029] train_loss=0.9763  val_meanDice(no-bg)=0.5018  Dice(ET)=0.5403  Dice(NETC)=0.6349  Dice(SNFH)=0.3304  wholeDice(label>0)=0.6615  lr=2.31e-04  (1071.2s)


Epoch 30/75: 100%|██████████| 506/506 [16:24<00:00,  1.94s/it, loss=0.9756]


[030] train_loss=0.9756  val_meanDice(no-bg)=0.5060  Dice(ET)=0.5183  Dice(NETC)=0.6185  Dice(SNFH)=0.3811  wholeDice(label>0)=0.6533  lr=2.25e-04  (1073.8s)


Epoch 31/75: 100%|██████████| 506/506 [16:23<00:00,  1.94s/it, loss=0.9712]


[031] train_loss=0.9712  val_meanDice(no-bg)=0.4890  Dice(ET)=0.4972  Dice(NETC)=0.6166  Dice(SNFH)=0.3531  wholeDice(label>0)=0.6514  lr=2.19e-04  (1071.9s)


Epoch 32/75: 100%|██████████| 506/506 [16:24<00:00,  1.94s/it, loss=0.9703]


[032] train_loss=0.9703  val_meanDice(no-bg)=0.4802  Dice(ET)=0.4615  Dice(NETC)=0.6292  Dice(SNFH)=0.3497  wholeDice(label>0)=0.6590  lr=2.13e-04  (1073.7s)


Epoch 33/75: 100%|██████████| 506/506 [16:24<00:00,  1.94s/it, loss=0.9670]


[033] train_loss=0.9670  val_meanDice(no-bg)=0.4811  Dice(ET)=0.4686  Dice(NETC)=0.6326  Dice(SNFH)=0.3422  wholeDice(label>0)=0.6615  lr=2.06e-04  (1069.2s)


Epoch 34/75: 100%|██████████| 506/506 [16:20<00:00,  1.94s/it, loss=0.9684]


[034] train_loss=0.9684  val_meanDice(no-bg)=0.4184  Dice(ET)=0.3828  Dice(NETC)=0.5973  Dice(SNFH)=0.2750  wholeDice(label>0)=0.6471  lr=2.00e-04  (1065.8s)


Epoch 35/75:   3%|▎         | 13/506 [00:41<26:09,  3.18s/it, loss=0.9875]


KeyboardInterrupt: 