In [None]:
!pip install -q -U monai nibabel

In [None]:
# %% Cell 1 —  setup

import os, random, glob, tarfile, csv, warnings
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler
from tqdm.auto import tqdm

from monai.data import Dataset, CacheDataset
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, EnsureTyped, Orientationd,
    Spacingd, ScaleIntensityd, RandFlipd, RandAffined,
    RandCropByPosNegLabeld, ConcatItemsd, SelectItemsd, Lambdad
)

# NOTE: this silences *all* Python warnings for the rest of the notebook,
# including PyTorch/MONAI diagnostics.
warnings.filterwarnings("ignore")

# ── Paths ────────────────────────────────────────────────────────────────────
# Tar archive with the BraTS2021 training data (attached Kaggle dataset).
BRA2021_TAR = "/kaggle/input/brats-2021-task1/BraTS2021_Training_Data.tar"

# Extraction target; created up front so later cells can iterate over it.
EXTRACT_DIR = Path("/kaggle/working/brats2021")
EXTRACT_DIR.mkdir(parents=True, exist_ok=True)

# Outputs of this phase: per-epoch metrics CSV and the best EMA checkpoint.
WORK_DIR = Path("/kaggle/working")
CSV_LOG = WORK_DIR / "metrics_phase6.csv"
CKPT_BEST_EMA = WORK_DIR / "best_phase6_ema.pth"

# ── Runtime ──────────────────────────────────────────────────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed every RNG this notebook touches directly.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Let cuDNN auto-tune convolution algorithms.
torch.backends.cudnn.benchmark = True

# ── Train config ─────────────────────────────────────────────────────────────
epochs_phase1 = 15
batch_size = 2                  # volumes per loader batch ("P100 safe" per the original note)
patch_size = (112, 112, 112)    # spatial size of each training crop
pos_neg_ratio = 1.0             # `pos` weight of RandCropByPosNegLabeld; `neg` is fixed at 1.0 in the transform
samples_per_vol = 2             # how many patches per volume per epoch
base_lr = 1e-4                  # OneCycle max LR
weight_decay = 5e-5

print(f"✅ Device: {device}")
print(f"🗂️  Extract to: {EXTRACT_DIR}")


In [None]:
# %% Cell 2 — Extract & index BraTS2021 cases (keys: t1ce, t2, flair, seg)
import tarfile, random
from pathlib import Path

# 1) Extract once
# The directory is considered "already extracted" as soon as it contains
# anything at all.
# NOTE(review): an interrupted extraction therefore looks complete on the
# next run; delete EXTRACT_DIR manually if that happens.
if not any(EXTRACT_DIR.iterdir()):
    print("📦 Extracting BraTS2021_Training_Data.tar ...")
    with tarfile.open(BRA2021_TAR, "r") as archive:
        # Use the stdlib "data" extraction filter when this Python provides it:
        # it rejects absolute paths, path traversal ("../") and special files,
        # so a malformed archive cannot write outside EXTRACT_DIR.
        if hasattr(tarfile, "data_filter"):
            archive.extractall(path=EXTRACT_DIR, filter="data")
        else:
            archive.extractall(path=EXTRACT_DIR)
    print("✅ Extracted.")
else:
    print("ℹ️ Already extracted.")

# 2) Match each seg with t1ce, t2, flair
def find_cases_by_seg(root_dir: Path):
    """Index complete cases below ``root_dir``, keyed by their segmentation file.

    Every ``*_seg.nii.gz`` found recursively is paired with the ``t1ce``,
    ``t2`` and ``flair`` files that share its prefix in the same directory.

    Args:
        root_dir: directory that is searched recursively.

    Returns:
        A list of dicts with the string paths under the keys ``"t1ce"``,
        ``"t2"``, ``"flair"`` and ``"seg"``, ordered by sorted seg path.
        Cases missing any of the three modalities are dropped (and counted
        in a printed warning).
    """
    seg_suffix = "_seg.nii.gz"
    modalities = ("t1ce", "t2", "flair")

    seg_paths = sorted(root_dir.rglob("*" + seg_suffix))
    print(f"🔎 Found {len(seg_paths)} seg files. Indexing…")

    found = []
    n_incomplete = 0
    for seg_file in seg_paths:
        case_id = seg_file.name[: -len(seg_suffix)]  # e.g., BraTS2021_01079
        folder = seg_file.parent
        modality_paths = {m: folder / f"{case_id}_{m}.nii.gz" for m in modalities}

        if not all(p.exists() for p in modality_paths.values()):
            n_incomplete += 1
            continue

        record = {m: str(p) for m, p in modality_paths.items()}
        record["seg"] = str(seg_file)
        found.append(record)

    if n_incomplete:
        print(f"⚠️ {n_incomplete} segs missing one or more modalities.")
    return found

# Only the first 500 indexed cases (in sorted seg-path order) are used.
cases = find_cases_by_seg(EXTRACT_DIR)[:500]
print(f"✅ Usable cases: {len(cases)}")

# 3) Train/Val split (shuffle, then hold out `val_frac` of the cases: 80/20)
random.seed(42)  # re-seed so the split is reproducible even if this cell is re-run
random.shuffle(cases)
val_frac = 0.20
n_val = max(1, int(len(cases) * val_frac))  # always keep at least one validation case
train_dicts = cases[:-n_val]
val_dicts   = cases[-n_val:]
print(f"📊 Train: {len(train_dicts)}  |  Val: {len(val_dicts)}")

# quick sanity peek
# Guard on train_dicts (not cases): with a single usable case the training
# split is empty and indexing it would raise IndexError.
if train_dicts:
    print("🧪 sample case dict:", train_dicts[0])


In [None]:
# %% Cell 3 — Transforms (orientation, spacing, label remap, patch sampling)

from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, EnsureTyped, Orientationd,
    Spacingd, ScaleIntensityd, RandFlipd, RandAffined,
    RandCropByPosNegLabeld, ConcatItemsd, SelectItemsd, Lambdad
)
from monai.data import CacheDataset
from monai.data import CacheDataset, list_data_collate
from torch.utils.data import DataLoader
import numpy as np

# BraTS2021 labels: 0=bg, 1=necrotic/non-enhancing core, 2=ED, 4=ET
# Remapped to contiguous class ids {0:bg, 1:core (original label 1 only), 2:ED, 3:ET}
def brats2021_label_map(seg: np.ndarray):
    """Remap raw BraTS2021 label values to contiguous class indices.

    Args:
        seg: label volume; it is cast with ``seg.astype(np.uint8)`` first.

    Returns:
        ``np.uint8`` array of the same shape where 1 -> 1, 2 -> 2, 4 -> 3 and
        every other value becomes 0.
    """
    # 256-entry lookup table: one slot for every possible uint8 label value.
    lookup = np.zeros(256, dtype=np.uint8)
    lookup[[1, 2, 4]] = (1, 2, 3)
    return lookup[seg.astype(np.uint8)]

# ----------------- SHARED PREPROCESSING -----------------
_MODALITY_KEYS = ["t1ce", "t2", "flair"]
_LOADED_KEYS = _MODALITY_KEYS + ["seg"]
_PAIR_KEYS = ["image", "seg"]


def _preprocessing_steps():
    """Return a fresh list of the deterministic steps shared by train and val.

    Load -> channel-first -> RAS orientation -> intensity scaling of the three
    modalities -> label remap -> stack the modalities into "image" -> keep only
    "image" and "seg". Resampling (Spacingd to 1 mm) is currently disabled.
    """
    return [
        LoadImaged(keys=_LOADED_KEYS),
        EnsureChannelFirstd(keys=_LOADED_KEYS),
        Orientationd(keys=_LOADED_KEYS, axcodes="RAS"),
        ScaleIntensityd(keys=_MODALITY_KEYS),
        Lambdad(keys="seg", func=brats2021_label_map),
        ConcatItemsd(keys=_MODALITY_KEYS, name="image"),
        SelectItemsd(keys=_PAIR_KEYS),
    ]


# ----------------- TRAIN TRANSFORMS -----------------
train_transforms = Compose(
    _preprocessing_steps()
    # Paired spatial augments: an independent 50% flip along each spatial axis.
    + [RandFlipd(keys=_PAIR_KEYS, prob=0.5, spatial_axis=axis) for axis in (0, 1, 2)]
    + [
        # Always-on affine jitter; nearest-neighbour for the label map.
        RandAffined(
            keys=_PAIR_KEYS,
            prob=1.0,
            rotate_range=(0.10, 0.10, 0.10),
            translate_range=(10, 10, 10),
            scale_range=(0.10, 0.10, 0.10),
            mode=("bilinear", "nearest"),
            padding_mode="zeros",
        ),
        # Foreground-focused random crops: `samples_per_vol` patches per volume.
        RandCropByPosNegLabeld(
            keys=_PAIR_KEYS,
            label_key="seg",
            spatial_size=patch_size,
            pos=pos_neg_ratio,
            neg=1.0,
            num_samples=samples_per_vol,
            image_key=None,
            image_threshold=0,
        ),
        EnsureTyped(keys=_PAIR_KEYS),
    ]
)

# ----------------- VAL TRANSFORMS -----------------
# Validation uses the full preprocessed volume: no augmentation, no cropping.
val_transforms = Compose(_preprocessing_steps() + [EnsureTyped(keys=_PAIR_KEYS)])

# ----------------- DATASETS & LOADERS -----------------
# Fraction of the training set that CacheDataset pre-computes; 0.0 disables caching.
cache_rate = 0.0

train_ds = CacheDataset(
    train_dicts,
    transform=train_transforms,
    cache_rate=cache_rate,
    num_workers=0,
)
# The validation set is cached in full (cache_rate=1.0).
val_ds = CacheDataset(
    val_dicts,
    transform=val_transforms,
    cache_rate=1.0,
    num_workers=0,
)

# Settings shared by both loaders.
_loader_common = dict(num_workers=0, pin_memory=True)

# IMPORTANT: use list_data_collate for training when using
# RandCropByPosNegLabeld(num_samples>1), because each dataset item is then a
# list of patches rather than a single dict.
train_loader = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=list_data_collate,
    **_loader_common,
)

# Val has no multi-sample cropping; default collate is fine
val_loader = DataLoader(
    val_ds,
    batch_size=1,
    shuffle=False,
    **_loader_common,
)

# sanity check: pull one training batch and show its tensor shapes
b = next(iter(train_loader))
print("🧪 batch image:", tuple(b["image"].shape), "(expect [B,3,Z,Y,X])")
print("🧪 batch seg  :", tuple(b["seg"].shape),   "(expect [B,1,Z,Y,X])")


In [None]:
# %% Cell 4 — Model (same UNet family as before, for 3 modalities)
from torch.utils.checkpoint import checkpoint

def Norm3d(ch, use_gn=False):
    """Build the normalisation layer used throughout the network.

    Args:
        ch: number of feature channels.
        use_gn: if true, return ``nn.GroupNorm`` with 8 groups; otherwise
            return ``nn.InstanceNorm3d``.
    """
    if use_gn:
        return nn.GroupNorm(8, ch)
    return nn.InstanceNorm3d(ch)

class ResidualConvBlock(nn.Module):
    """Two 3x3x3 conv + norm layers with a residual connection.

    Computes ``relu(norm2(conv2(relu(norm1(conv1(x))))) + res(x))`` where
    ``res`` is a 1x1x1 convolution when the channel count changes and the
    identity otherwise. During training the body is run under activation
    checkpointing to trade compute for memory.

    Args:
        in_ch: input channels.
        out_ch: output channels.
        use_gn: forwarded to ``Norm3d`` to pick GroupNorm over InstanceNorm.
    """

    def __init__(self, in_ch, out_ch, use_gn=False):
        super().__init__()
        self.conv1 = nn.Conv3d(in_ch, out_ch, 3, padding=1, bias=False)
        self.norm1 = Norm3d(out_ch, use_gn)
        self.conv2 = nn.Conv3d(out_ch, out_ch, 3, padding=1, bias=False)
        self.norm2 = Norm3d(out_ch, use_gn)
        self.relu  = nn.ReLU(inplace=True)
        self.res   = nn.Conv3d(in_ch, out_ch, 1, bias=False) if in_ch!=out_ch else nn.Identity()

    def _residual_unit(self, z):
        """The actual computation of the block (un-checkpointed)."""
        y = self.relu(self.norm1(self.conv1(z)))
        y = self.norm2(self.conv2(y))
        return self.relu(y + self.res(z))

    def forward(self, x):
        # Checkpointing only pays off when a backward pass will follow.
        if self.training and torch.is_grad_enabled():
            # use_reentrant=False is required for correctness here: the
            # reentrant implementation produces an output that does not
            # require grad whenever none of the *inputs* require grad (the raw
            # image fed to the first encoder block), which silently leaves
            # this block's parameters without gradients.
            return checkpoint(self._residual_unit, x, use_reentrant=False)
        return self._residual_unit(x)

class Encoder(nn.Module):
    """Single-modality encoder: a stack of (ResidualConvBlock, MaxPool3d) stages.

    Args:
        in_ch: channels of the input volume.
        feats: output channels of each stage, shallow to deep.
        use_gn: forwarded to the residual blocks.
    """

    def __init__(self, in_ch, feats=(24,48,96), use_gn=False):
        super().__init__()
        layers=[]; ch=in_ch
        for f in feats:
            layers += [ResidualConvBlock(ch,f,use_gn), nn.MaxPool3d(2)]
            ch=f
        # Layout is [block0, pool0, block1, pool1, ...]; kept as-is so that
        # existing checkpoints (keys "layers.0", "layers.2", ...) still load.
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        """Return the pre-pooling feature map of every stage, shallow to deep."""
        skips=[]
        last_block = len(self.layers) - 2
        for i in range(0,len(self.layers),2):
            x = self.layers[i](x); skips.append(x)
            # Only the skips are returned, so pooling after the deepest block
            # would be wasted work (and would fail on a size-1 feature map).
            if i != last_block:
                x = self.layers[i+1](x)
        return skips

class CBAM3D(nn.Module):
    """Channel gate followed by a spatial gate for 5-D feature maps.

    Args:
        ch: number of input/output channels.
        red: reduction factor of the channel MLP (hidden width ``ch // red``).
    """

    def __init__(self, ch, red=16):
        super().__init__()
        hidden = ch // red
        self.avg = nn.AdaptiveAvgPool3d(1)
        self.max = nn.AdaptiveMaxPool3d(1)
        # Shared bottleneck MLP applied to both pooled descriptors.
        self.mlp = nn.Sequential(
            nn.Conv3d(ch, hidden, 1, bias=False),
            nn.ReLU(),
            nn.Conv3d(hidden, ch, 1, bias=False),
        )
        self.sig = nn.Sigmoid()
        # 7x7x7 conv over a single-channel map, squashed to (0, 1).
        self.spat = nn.Sequential(
            nn.Conv3d(1, 1, 7, padding=3, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # Channel attention: sigmoid of the summed MLP responses.
        channel_gate = self.sig(self.mlp(self.avg(x)) + self.mlp(self.max(x)))
        x = x * channel_gate
        # Spatial attention from the average of the channel-mean and channel-max maps.
        mean_map = x.mean(1, keepdim=True)
        peak_map = x.amax(1, keepdim=True)
        return x * self.spat(0.5 * (mean_map + peak_map))

class SkipFusionBlock(nn.Module):
    """Fuse three same-shaped feature maps into one.

    Each input is rescaled by its own channel gate (global average pool ->
    1x1x1 conv -> ReLU -> 1x1x1 conv -> sigmoid), the three results are
    concatenated on the channel axis, projected back to ``ch`` channels with a
    1x1x1 conv and normalised.

    Args:
        ch: channels of each input (and of the output).
        use_gn: forwarded to ``Norm3d``.
    """

    def __init__(self, ch, use_gn=False):
        super().__init__()
        squeezed = ch // 16

        def _make_gate():
            return nn.Sequential(
                nn.AdaptiveAvgPool3d(1),
                nn.Conv3d(ch, squeezed, 1, bias=False),
                nn.ReLU(),
                nn.Conv3d(squeezed, ch, 1, bias=False),
                nn.Sigmoid(),
            )

        self.g1 = _make_gate()
        self.g2 = _make_gate()
        self.g3 = _make_gate()
        self.conv = nn.Conv3d(3 * ch, ch, 1, bias=False)
        self.norm = Norm3d(ch, use_gn)

    def forward(self, f1,f2,f3):
        pairs = ((f1, self.g1), (f2, self.g2), (f3, self.g3))
        gated = [feat * gate(feat) for feat, gate in pairs]
        return self.norm(self.conv(torch.cat(gated, 1)))

class Decoder(nn.Module):
    """Upsampling path with skip connections and optional auxiliary heads.

    Args:
        feats: channel width of each decoder stage, deep to shallow.
        out_ch: number of output classes.
        bottleneck_ch: channels of the tensor entering the first stage.
        ds_at: stage indices whose auxiliary (deep-supervision) output is produced.
        use_gn: forwarded to the residual blocks.
    """

    def __init__(self, feats=(96,48,24), out_ch=4, bottleneck_ch=256, ds_at=(0,1,2), use_gn=False):
        super().__init__()
        self.up = nn.ModuleList()
        self.dec = nn.ModuleList()
        self.aux = nn.ModuleList()
        self.ds_at = set(ds_at)

        incoming = bottleneck_ch
        for width in feats:
            self.up.append(nn.ConvTranspose3d(incoming, width, 2, 2))
            # The block sees the upsampled map concatenated with its skip.
            self.dec.append(ResidualConvBlock(width * 2, width, use_gn))
            self.aux.append(nn.Conv3d(width, out_ch, 1))
            incoming = width
        self.final = nn.Conv3d(feats[-1], out_ch, 1)

    def forward(self,x,skips):
        """Return ``(final_logits, aux_outs)``; ``aux_outs[i]`` is None when stage i is not in ``ds_at``."""
        aux_outs = []
        for stage, (upsample, block) in enumerate(zip(self.up, self.dec)):
            skip = skips[stage]
            x = upsample(x)
            # Resample when the upsampled map and the skip disagree spatially.
            if x.shape[2:] != skip.shape[2:]:
                x = F.interpolate(x, size=skip.shape[2:], mode='trilinear', align_corners=False)
            x = block(torch.cat([x, skip], 1))
            if stage in self.ds_at:
                aux_outs.append(self.aux[stage](x))
            else:
                aux_outs.append(None)
        return self.final(x), aux_outs

class ModalitySelectiveUNet(nn.Module):
    """U-Net with one encoder per input channel and fused skips.

    The three input channels are encoded independently. Their deepest feature
    maps are blended with softmax weights derived from each map's mean
    activation, passed through the bottleneck, and decoded using skips that
    are fused across the three encoders at every scale.

    Args:
        feats: encoder channel widths, shallow to deep.
        bottleneck_ch: channel width of the bottleneck.
        use_gn: use GroupNorm instead of InstanceNorm.
        ds_at: decoder stages that emit auxiliary outputs.
    """

    def __init__(self, feats=(24,48,96), bottleneck_ch=256, use_gn=False, ds_at=(0,1,2)):
        super().__init__()
        deep_to_shallow = feats[::-1]
        self.encoders = nn.ModuleList([Encoder(1, feats, use_gn) for _ in range(3)])
        self.bottleneck = nn.Sequential(
            nn.Conv3d(feats[-1], bottleneck_ch, 1, bias=False),
            CBAM3D(bottleneck_ch),
            ResidualConvBlock(bottleneck_ch, bottleneck_ch, use_gn),
        )
        self.skip_fusers = nn.ModuleList([SkipFusionBlock(width, use_gn) for width in deep_to_shallow])
        self.decoder = Decoder(
            deep_to_shallow,
            out_ch=4,
            bottleneck_ch=bottleneck_ch,
            ds_at=ds_at,
            use_gn=use_gn,
        )

    def forward(self, x):
        """Return the decoder output ``(logits, aux_outs)`` for a 3-channel input."""
        modalities = (x[:, 0:1], x[:, 1:2], x[:, 2:3])
        per_modality = [encoder(vol) for encoder, vol in zip(self.encoders, modalities)]

        # Modality weighting from the mean activation of each deepest map.
        deepest = [skips[-1] for skips in per_modality]
        stacked = torch.stack(deepest, dim=1)                        # [B,3,C,D,H,W]
        scores = stacked.mean(dim=[2, 3, 4, 5])                      # [B,3]
        alpha = torch.softmax(scores, dim=1).view(-1, 3, 1, 1, 1, 1)
        fused = alpha[:, 0] * deepest[0] + alpha[:, 1] * deepest[1] + alpha[:, 2] * deepest[2]
        x = self.bottleneck(fused)

        # Walk the scales deep -> shallow, fusing the three encoders' skips.
        deep_first = zip(*[reversed(skips) for skips in per_modality])
        fused_skips = []
        for same_scale, fuser in zip(deep_first, self.skip_fusers):
            fused_skips.append(fuser(*same_scale))
        return self.decoder(x, fused_skips)
# Checkpoint from the previous phase used to warm-start this run.
SAVED_CK_PT = "/kaggle/input/best_phase_5/pytorch/default/1/best_phase5_ema.pth"

model = ModalitySelectiveUNet().to(device)

# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
pretrained_state = torch.load(SAVED_CK_PT, map_location=device)
model.load_state_dict(pretrained_state)

n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("🧠 Trainable params:", n_trainable)

In [None]:
# %% Cell 5 — Losses, deep supervision, metrics, TTA

class FocalTverskyLoss(nn.Module):
    """Focal Tversky loss on softmax probabilities.

    Per sample and class: ``TI = (TP + eps) / (TP + alpha*FP + beta*FN + eps)``
    and the loss term is ``(1 - TI) ** gamma``; the terms are optionally scaled
    by per-class weights and then averaged.

    Args:
        alpha: weight of false positives in the Tversky index.
        beta: weight of false negatives in the Tversky index.
        gamma: focal exponent.
        weights: optional per-class weights (one entry per channel, including
            channel 0); stored as the buffer ``w``.
        include_bg: when false, channel 0 is dropped before the loss is computed.
        eps: smoothing constant.
    """

    def __init__(self, alpha=0.7, beta=0.3, gamma=2.0, weights=None, include_bg=False, eps=1e-6):
        super().__init__()
        self.a = alpha
        self.b = beta
        self.g = gamma
        self.eps = eps
        self.include_bg = include_bg
        weight_tensor = None
        if weights is not None:
            weight_tensor = torch.as_tensor(weights, dtype=torch.float32)
        self.register_buffer("w", weight_tensor)

    def forward(self, logits, target):
        """Compute the loss; ``target`` is either index-encoded (1 channel) or one-hot."""
        prob = F.softmax(logits, 1)
        # Index-encoded targets are expanded to one-hot.
        if target.shape[1] == 1:
            target = torch.zeros_like(prob).scatter_(1, target.long(), 1)

        class_w = self.w
        if not self.include_bg:
            prob = prob[:, 1:]
            target = target[:, 1:]
            if class_w is not None:
                class_w = class_w[1:]

        spatial_axes = tuple(range(2, logits.ndim))
        tp = (prob * target).sum(spatial_axes)
        fp = (prob * (1 - target)).sum(spatial_axes)
        fn = ((1 - prob) * target).sum(spatial_axes)

        tversky = (tp + self.eps) / (tp + self.a * fp + self.b * fn + self.eps)
        focal = (1.0 - tversky) ** self.g
        if class_w is not None:
            focal = focal * class_w
        return focal.mean()

# One weight per class index (0..3). The first entry belongs to the
# background channel and is dropped inside the loss because include_bg=False.
class_weights = torch.tensor([0.10, 0.30, 0.25, 0.35], device=device)

# Main segmentation loss.
criterion_main = FocalTverskyLoss(
    alpha=0.7,
    beta=0.3,
    gamma=2.0,
    weights=class_weights,
    include_bg=False,
)

# Binary loss for the auxiliary whole-tumour term, and that term's weight.
bce = nn.BCEWithLogitsLoss()
wt_aux_weight = 0.5

def wt_logits(logits):
    """Collapse the foreground channels (1..3) into one whole-tumour logit.

    Returns the log-sum-exp over channels 1:4, keeping the channel axis.
    NOTE(review): the background logit (channel 0) is not subtracted, so the
    sigmoid of this value is not the softmax foreground probability —
    confirm that this is intended.
    """
    foreground = logits[:, 1:4]
    return foreground.logsumexp(dim=1, keepdim=True)
def wt_targets(target):
    """Binary whole-tumour mask: 1 where the label is greater than 0, else 0 (int64)."""
    is_foreground = target > 0
    return is_foreground.long()

def _resize_target_like(logits, target):
    if target.shape[2:] != logits.shape[2:]:
        return F.interpolate(target.float(), size=logits.shape[2:], mode="nearest").long()
    return target

# Weight of each auxiliary (deep-supervision) output, in decoder-stage order.
aux_weights = [0.25, 0.15, 0.10]

def loss_with_deep_supervision(preds, aux_list, target):
    """Total training loss: main loss + whole-tumour BCE + weighted auxiliary losses.

    Args:
        preds: main logits.
        aux_list: auxiliary logits per decoder stage; ``None`` entries are skipped.
        target: index-encoded labels; resized to each prediction's resolution.

    Returns:
        Scalar loss tensor.
    """
    main_target = _resize_target_like(preds, target)

    total = criterion_main(preds, main_target)
    wt_term = bce(wt_logits(preds), wt_targets(main_target).float())
    total = total + wt_aux_weight * wt_term

    for weight, aux_logits in zip(aux_weights, aux_list):
        if aux_logits is None:
            continue
        aux_target = _resize_target_like(aux_logits, target)
        total = total + weight * criterion_main(aux_logits, aux_target)
    return total

@torch.no_grad()
def wt_soft_dice_from_logits(logits, target, eps=1e-6):
    """Soft whole-tumour Dice between logits and labels, averaged over the batch.

    The predicted whole-tumour map is the maximum softmax probability over
    channels 1:4; the reference map is the maximum over the same channels of
    the one-hot target. Dice uses squared terms in the denominator and is
    defined as 1 when that denominator is (numerically) empty.

    Args:
        logits: class logits with the class axis at dim 1 and three spatial dims.
        target: index-encoded (1 channel) or one-hot labels.
        eps: smoothing / emptiness threshold.

    Returns:
        Python float.
    """
    prob = torch.softmax(logits.detach().float(), dim=1).clamp(1e-7, 1.0)

    if target.shape[1] == 1:
        onehot = torch.zeros_like(prob)
        onehot.scatter_(1, target.long(), 1)
    else:
        onehot = target.float()

    pred_wt = prob[:, 1:4].amax(1)
    true_wt = onehot[:, 1:4].amax(1)

    spatial = [1, 2, 3]
    overlap = (pred_wt * true_wt).sum(spatial)
    pred_energy = (pred_wt * pred_wt).sum(spatial)
    true_energy = (true_wt * true_wt).sum(spatial)
    denom = pred_energy + true_energy

    score = (2 * overlap + eps) / (denom + eps)
    # Nothing predicted and nothing present counts as a perfect match.
    score = torch.where(denom < eps, torch.ones_like(score), score)
    return torch.nan_to_num(score, nan=0.0).mean().item()

def predict_tta(model, x):
    """Average the model's main output over four flip variants of ``x``.

    The variants are: no flip, flip along dim 3, flip along dim 4, and flip
    along both. Each prediction is flipped back before averaging.

    Args:
        model: callable returning ``(logits, aux)``; only ``logits`` is used.
        x: input batch (5-D tensor).
    """
    flip_variants = (None, (3,), (4,), (3, 4))
    predictions = []
    for axes in flip_variants:
        if axes:
            logits, _ = model(torch.flip(x, axes))
            logits = torch.flip(logits, axes)  # undo the flip on the output
        else:
            logits, _ = model(x)
        predictions.append(logits)
    return torch.stack(predictions, 0).mean(0)


In [None]:
# %% Cell 6 — EMA helper, optimizer, OneCycleLR, CSV
class ModelEMA:
    """Exponential moving average of a model's floating-point state.

    Attributes:
        decay: EMA decay factor.
        shadow: name -> averaged tensor, for every floating-point state_dict entry.
        collected: backup of the live weights while the shadow is applied.
    """

    def __init__(self, model, decay=0.999):
        self.decay = decay
        self.shadow = {}
        for name, tensor in model.state_dict().items():
            if tensor.dtype.is_floating_point:
                self.shadow[name] = tensor.detach().clone()
        self.collected = {}

    @torch.no_grad()
    def update(self, model):
        """Fold the model's current weights into the running average."""
        keep = self.decay
        for name, tensor in model.state_dict().items():
            if name not in self.shadow or not tensor.dtype.is_floating_point:
                continue
            self.shadow[name].mul_(keep).add_(tensor, alpha=1.0 - keep)

    def apply_shadow(self, model):
        """Copy the averaged weights into ``model``, backing up the live ones first."""
        self.collected = {}
        live = model.state_dict()
        for name, averaged in self.shadow.items():
            self.collected[name] = live[name].detach().clone()
            live[name].copy_(averaged)

    def restore(self, model):
        """Undo ``apply_shadow`` by copying the backed-up weights into ``model``."""
        live = model.state_dict()
        for name in self.collected:
            live[name].copy_(self.collected[name])
        self.collected = {}

# EMA copy of the weights; it is the one evaluated and checkpointed.
ema = ModelEMA(model, decay=0.999)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=base_lr,
    betas=(0.9, 0.999),
    weight_decay=weight_decay,
)
scaler = GradScaler()

# One scheduler step per optimizer step, for the whole phase.
steps_per_epoch = max(1, len(train_loader))
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=base_lr,
    total_steps=epochs_phase1 * steps_per_epoch,
    pct_start=0.1,            # first 10% of the steps ramp the LR up
    anneal_strategy="cos",
    div_factor=25.0,          # initial LR = max_lr / 25
    final_div_factor=1e4,     # final LR = initial LR / 1e4
)

# Write the CSV header once; an existing log is appended to by the training loop.
if not CSV_LOG.exists():
    with CSV_LOG.open("w", newline="") as log_file:
        header = ["epoch", "train_loss", "val_loss", "train_wt", "val_wt", "lr"]
        csv.writer(log_file).writerow(header)

# Best validation whole-tumour Dice seen so far.
best_wt = 0.0


In [None]:
# %% Cell 7 — Train/Val loop (epochs_phase1 epochs, AMP+EMA+TTA val)
# Per epoch: train with mixed precision and gradient clipping while updating
# the EMA weights every step, then validate the EMA weights with flip-TTA,
# checkpoint them when the validation whole-tumour Dice improves, and append
# one row of metrics to CSV_LOG.
from torch.nn.utils import clip_grad_norm_

for epoch in range(1, epochs_phase1 + 1):
    model.train()
    tr_loss = tr_wt = 0.0  # running sums over the batches of this epoch
    pbar = tqdm(train_loader, desc=f"🟢 Train {epoch:03d}/{epochs_phase1}", leave=False)

    for batch in pbar:
        imgs = batch["image"].to(device)  # [B,3,Z,Y,X]
        segs = batch["seg"].to(device)    # [B,1,Z,Y,X]
        optimizer.zero_grad(set_to_none=True)

        # Forward pass and loss under automatic mixed precision.
        with autocast():
            preds, aux = model(imgs)
            loss = loss_with_deep_supervision(preds, aux, segs)

        scaler.scale(loss).backward()
        # Unscale before clipping so the 1.0 threshold applies to the true gradient norm.
        scaler.unscale_(optimizer)
        clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        # The scheduler is stepped per batch (it was built with steps_per_epoch).
        scheduler.step()
        ema.update(model)

        tr_loss += loss.item()
        tr_wt   += wt_soft_dice_from_logits(preds, segs)

        # Show running means over the batches processed so far.
        n_done = pbar.n + 1
        pbar.set_postfix({"loss": f"{tr_loss/n_done:.4f}",
                          "WT": f"{tr_wt/n_done:.4f}",
                          "lr": f"{optimizer.param_groups[0]['lr']:.2e}"})

    # Validation with EMA+TTA
    # The EMA weights are swapped into the model for evaluation and swapped
    # back out afterwards so training continues from the live weights.
    model.eval(); ema.apply_shadow(model)
    val_loss = val_wt = 0.0
    with torch.no_grad():
        pbar_v = tqdm(val_loader, desc=f"🔵 Val   {epoch:03d}/{epochs_phase1}", leave=False)
        for batch in pbar_v:
            imgs = batch["image"].to(device)
            segs = batch["seg"].to(device)
            with autocast():
                logits = predict_tta(model, imgs)
                # No auxiliary outputs at validation time: pass None for every stage.
                val_loss += loss_with_deep_supervision(logits, [None,None,None], segs).item()
            val_wt   += wt_soft_dice_from_logits(logits, segs)
            n_done = pbar_v.n + 1
            pbar_v.set_postfix({"loss": f"{val_loss/n_done:.4f}", "WT": f"{val_wt/n_done:.4f}"})
    ema.restore(model)

    # Turn the running sums into per-batch means.
    # NOTE(review): `n_val` rebinds the module-level `n_val` set in the
    # train/val split cell (number of validation cases) to the number of
    # validation batches.
    n_tr = len(train_loader); n_val = len(val_loader)
    tr_loss/=n_tr; tr_wt/=n_tr; val_loss/=n_val; val_wt/=n_val

    # Save best EMA
    # The shadow weights are applied again only for the duration of the save.
    if val_wt > best_wt:
        best_wt = val_wt
        ema.apply_shadow(model)
        torch.save(model.state_dict(), CKPT_BEST_EMA)
        ema.restore(model)
        badge = " 🏆"
    else:
        badge = ""

    print(f"📊 Ep {epoch:03d}/{epochs_phase1} | TL {tr_loss:.4f}  TW {tr_wt:.4f}  ||  VL {val_loss:.4f}  VW {val_wt:.4f}{badge}")

    # Append this epoch's metrics (the header is written by the setup cell).
    with open(CSV_LOG, "a", newline="") as f:
        csv.writer(f).writerow([epoch, tr_loss, val_loss, tr_wt, val_wt, optimizer.param_groups[0]['lr']])

print(f"\n✅ Phase 1 complete. Best EMA WT = {best_wt:.4f}")
print(f"💾 Saved to: {CKPT_BEST_EMA}")
