# Self‑Contained 3D Segmentation — **v2** (SE‑Residual U‑Net++, TTA, LCC post‑proc, 3D‑MAE pretrainer)
This notebook is **self-contained** and includes:

- Reading **nnU‑Net v2** plans (to mirror spacing & patch size for fair comparison)

- **CT/MRI normalization**, **resampling to target spacing** (SimpleITK), light aug

- **SE‑Residual U‑Net++** (residual blocks + squeeze‑excitation + deep supervision)

- **Training loop** (Dice+CE), **AMP**, optional **EMA**, checkpoint by best val Dice

- **Sliding‑window prediction** with **mirror TTA** + **Gaussian blending**

- **Largest-Component Post‑Processing** (good for single-organ tasks)

- **Minimal 3D Masked Autoencoder (MAE) pretrainer** + **transfer** to segmentation encoder


> Place this notebook anywhere (e.g., repo root). Set paths in the config cell.


In [2]:
# === Build ./lists/* from nnU-Net fold-0 splits (Heart + Spleen) ===
import os, json, pathlib

def _find_case(root, cid):
    for ext in (".nii.gz", ".nii"):
        p = os.path.join(root, f"{cid}{ext}")
        if os.path.exists(p):
            return p
    raise FileNotFoundError(f"Case not found with .nii[.gz]: {root}/{cid}")

def make_lists(ds_id: int, name: str, task: str):
    split_p = f"/home/htetaung/data/nnunet_preprocessed/Dataset0{ds_id:02d}_{name}/splits_final.json"
    sp = json.load(open(split_p))[0]   # fold-0
    imgs = f"/home/htetaung/data/MSD/{task}/imagesTr"
    labs = f"/home/htetaung/data/MSD/{task}/labelsTr"
    outdir = pathlib.Path("./lists"); outdir.mkdir(exist_ok=True)

    for tag, ids in (("train_fold0", sp["train"]), ("val_fold0", sp["val"])):
        out = outdir / f"{task}_{tag}.txt"
        with open(out, "w") as f:
            for cid in ids:
                ip = _find_case(imgs, cid)
                lp = _find_case(labs, cid)
                f.write(f"{ip},{lp}\n")
        print(f"Wrote {out}  ({len(ids)} pairs)")

# Heart (Dataset002) + Spleen (Dataset009)
make_lists(2, "Heart",  "Task02_Heart")
make_lists(9, "Spleen", "Task09_Spleen")


Wrote lists/Task02_Heart_train_fold0.txt  (16 pairs)
Wrote lists/Task02_Heart_val_fold0.txt  (4 pairs)
Wrote lists/Task09_Spleen_train_fold0.txt  (32 pairs)
Wrote lists/Task09_Spleen_val_fold0.txt  (9 pairs)


In [3]:
import os, sys, json
LIST_TRAIN = "./lists/Task02_Heart_train_fold0.txt"
LIST_VAL   = "./lists/Task02_Heart_val_fold0.txt"
for p in (LIST_TRAIN, LIST_VAL):
    assert os.path.exists(p), f"Missing: {p}"
print("OK:", LIST_TRAIN, "and", LIST_VAL)

OK: ./lists/Task02_Heart_train_fold0.txt and ./lists/Task02_Heart_val_fold0.txt


In [4]:

# ==== Config ====
DATA_ROOT = "/home/htetaung/data"            # MSD + nnU-Net folders live here
TASK      = "Task02_Heart"                   # or "Task09_Spleen"
DS_ID     = 2 if TASK == "Task02_Heart" else 9
NAME      = "Heart" if DS_ID == 2 else "Spleen"
FOLD      = 0
MODALITY  = "MRI" if TASK == "Task02_Heart" else "CT"
IN_CHANNELS = 1
NUM_CLASSES = 2                              # bg + organ (Heart/Spleen)
PATCH      = (80, 192, 160)                    # you may set to nnU-Net patch printed below if VRAM allows
BATCH_SIZE = 2
MAX_EPOCHS = 50
NUM_WORKERS = 4
AMP = True
USE_EMA = True                               # Exponential Moving Average for stability
OVERLAP = 0.5                                # sliding window overlap
TTA_MIRROR = True                            # mirror test-time augmentation
POSTPROC_KEEP_LARGEST = True                 # largest-component post-proc for class 1

OUT_DIR = f"./runs_selfcontained_v2/{TASK}_seresunetpp_fold{FOLD}"
LIST_TRAIN = f"./lists/{TASK}_train_fold{FOLD}.txt"   # reuse nnU-Net fold split
LIST_VAL   = f"./lists/{TASK}_val_fold{FOLD}.txt"

# MAE pretraining knobs
MAE_PATCH  = (64, 64, 64)
MAE_MASK_RATIO = 0.8
MAE_EPOCHS = 10
MAE_OUT    = f"./runs_selfcontained_v2/{TASK}_mae_pretrain"


In [5]:

# ==== Imports & device ====
import os, json, time, math, random, csv, glob
from pathlib import Path
from typing import Tuple
import numpy as np
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

print("Torch:", torch.__version__)
print("CUDA:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


Torch: 2.5.1
CUDA: True
GPU: NVIDIA GeForce RTX 4070


## Read nnU‑Net plans (spacing & patch size)
Use these to align geometry with the supervised anchor. You may set `PATCH` to `nn_patch` if VRAM allows.


In [6]:

def read_nnunet_cfg(ds_id, name, data_root=DATA_ROOT):
    base = f"{data_root}/nnunet_preprocessed/Dataset0{ds_id:02d}_{name}"
    with open(f"{base}/nnUNetPlans.json") as f:
        plans = json.load(f)
    cfg = plans.get("configurations", {}).get("3d_fullres", {})
    spacing = (cfg.get("spacing") or cfg.get("resampling_target_spacing") or
               plans.get("target_spacing") or plans.get("spacing"))
    patch_size = tuple(cfg.get("patch_size") or ())
    batch_size = cfg.get("batch_size")
    return spacing, patch_size, batch_size

spacing, nn_patch, nn_bs = read_nnunet_cfg(DS_ID, NAME)
print(f"Dataset {DS_ID} {NAME} -> spacing={spacing}, patch_size={nn_patch}, batch_size={nn_bs}")
print("Notebook PATCH =", PATCH, "(set to nn_patch if VRAM allows)")


Dataset 2 Heart -> spacing=[1.3700000047683716, 1.25, 1.25], patch_size=(80, 192, 160), batch_size=2
Notebook PATCH = (80, 192, 160) (set to nn_patch if VRAM allows)


## I/O, resampling, normalization & augmentation
- Resample to nnU‑Net spacing via **SimpleITK** (BSpline for image, Nearest for label)
- **CT**: clip to [-125, 275] HU → scale to [0,1].  **MRI**: z‑score in foreground
- Light 3D aug: flips + gamma


In [7]:

import SimpleITK as sitk

def sitk_resample_to_spacing(image_sitk, target_spacing_xyz, is_label=False):
    # target_spacing_xyz expects (z, y, x)
    orig_spacing = image_sitk.GetSpacing()     # (x,y,z)
    target = (float(target_spacing_xyz[2]), float(target_spacing_xyz[1]), float(target_spacing_xyz[0]))
    orig_size = image_sitk.GetSize()           # (x,y,z)
    new_size = [int(round(osz*ospc/tspc)) for osz,ospc,tspc in zip(orig_size, orig_spacing, target)]
    res = sitk.Resample(
        image_sitk, new_size, sitk.Transform(),
        sitk.sitkNearestNeighbor if is_label else sitk.sitkBSpline,
        image_sitk.GetOrigin(), target, image_sitk.GetDirection(),
        0, image_sitk.GetPixelID()
    )
    return res

def normalize_ct_hu(x):
    x = np.clip(x, -125, 275)
    x = (x + 125)/400.0
    return x.astype(np.float32)

def normalize_mri_z(x, mask=None, eps=1e-6):
    if mask is None: mask = x != 0
    m = x[mask].mean() if mask.any() else x.mean()
    s = x[mask].std() if mask.any() else x.std()
    return ((x - m)/(s+eps)).astype(np.float32)

def random_aug(vol, seg):
    if random.random() < 0.5: vol = vol[::-1].copy(); seg = seg[::-1].copy()
    if random.random() < 0.5: vol = vol[:, ::-1].copy(); seg = seg[:, ::-1].copy()
    if random.random() < 0.5: vol = vol[:, :, ::-1].copy(); seg = seg[:, :, ::-1].copy()
    if random.random() < 0.3:
        g = 0.7 + 0.6*random.random()
        vmin, vmax = vol.min(), vol.max()
        vol = ((vol - vmin)/(vmax - vmin + 1e-6))**g
        vol = vol*(vmax - vmin) + vmin
    return vol, seg


In [8]:

class PairListDataset(Dataset):
    def __init__(self, list_file, patch, spacing, modality="CT", training=True, fg_ratio=0.5):
        self.items = [l.strip().split(",") for l in open(list_file) if l.strip()]
        self.patch = patch
        self.spacing = np.array(spacing, dtype=np.float32)
        self.modality = modality
        self.training = training
        self.fg_ratio = fg_ratio if training else 0.0

    def _rand_center(self, seg):
        z, y, x = np.where(seg > 0)
        if len(z)==0:
            return [np.random.randint(0, seg.shape[0]),
                    np.random.randint(0, seg.shape[1]),
                    np.random.randint(0, seg.shape[2])]
        i = np.random.randint(0, len(z))
        return [int(z[i]), int(y[i]), int(x[i])]

    def _extract_patch(self, v, s, center=None):
        Pz, Py, Px = self.patch
        Z, Y, X = v.shape
        if center is None:
            cz = np.random.randint(0, max(1, Z - Pz + 1))
            cy = np.random.randint(0, max(1, Y - Py + 1))
            cx = np.random.randint(0, max(1, X - Px + 1))
        else:
            cz = max(0, min(center[0] - Pz//2, Z - Pz))
            cy = max(0, min(center[1] - Py//2, Y - Py))
            cx = max(0, min(center[2] - Px//2, X - Px))
        patch_v = v[cz:cz+Pz, cy:cy+Py, cx:cx+Px]
        patch_s = s[cz:cz+Pz, cy:cy+Py, cx:cx+Px]
        # pad if at border
        padz = Pz - patch_v.shape[0]; pady = Py - patch_v.shape[1]; padx = Px - patch_v.shape[2]
        if padz>0 or pady>0 or padx>0:
            pad = [(0,padz),(0,pady),(0,padx)]
            patch_v = np.pad(patch_v, pad, mode="edge")
            patch_s = np.pad(patch_s, pad, mode="edge")
        return patch_v, patch_s

    def __len__(self): return len(self.items)

    def __getitem__(self, idx):
        ip, lp = self.items[idx]
        i_sitk = sitk.ReadImage(ip); l_sitk = sitk.ReadImage(lp)
        i_rs = sitk_resample_to_spacing(i_sitk, self.spacing, is_label=False)
        l_rs = sitk_resample_to_spacing(l_sitk, self.spacing, is_label=True)
        vol = sitk.GetArrayFromImage(i_rs).astype(np.float32)
        seg = sitk.GetArrayFromImage(l_rs).astype(np.int16)

        vol = normalize_ct_hu(vol) if self.modality=="CT" else normalize_mri_z(vol, vol!=0)
        if self.training:
            vol, seg = random_aug(vol, seg)
            if random.random() < self.fg_ratio:
                center = self._rand_center(seg)
                vol, seg = self._extract_patch(vol, seg, center=center)
            else:
                vol, seg = self._extract_patch(vol, seg, center=None)
        else:
            vol, seg = self._extract_patch(vol, seg, center=None)

        vol = vol[None, ...]
        onehot = np.zeros((NUM_CLASSES,)+seg.shape, np.float32)
        for c in range(NUM_CLASSES):
            onehot[c] = (seg==c)
        return torch.from_numpy(vol), torch.from_numpy(onehot), torch.from_numpy(seg)


## SE‑Residual U‑Net++ (deep supervision)


In [9]:

class SEBlock(nn.Module):
    def __init__(self, ch, r=8):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool3d(1)
        self.fc1 = nn.Conv3d(ch, ch//r, 1)
        self.fc2 = nn.Conv3d(ch//r, ch, 1)
        self.act = nn.ReLU(inplace=True)
        self.gate = nn.Sigmoid()
    def forward(self, x):
        s = self.pool(x)
        s = self.fc2(self.act(self.fc1(s)))
        return x * self.gate(s)

class ResSEBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = nn.Conv3d(in_ch, out_ch, 3, padding=1, bias=False)
        self.in1   = nn.InstanceNorm3d(out_ch, affine=True)
        self.conv2 = nn.Conv3d(out_ch, out_ch, 3, padding=1, bias=False)
        self.in2   = nn.InstanceNorm3d(out_ch, affine=True)
        self.se    = SEBlock(out_ch)
        self.act   = nn.LeakyReLU(0.01, inplace=True)
        self.skip  = nn.Conv3d(in_ch, out_ch, 1, bias=False) if in_ch != out_ch else nn.Identity()
    def forward(self, x):
        s = self.skip(x)
        x = self.act(self.in1(self.conv1(x)))
        x = self.in2(self.conv2(x))
        x = self.se(x)
        return self.act(x + s)

class UpBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.ConvTranspose3d(in_ch, out_ch, 2, 2)
        self.conv = ResSEBlock(in_ch, out_ch)
    def forward(self, x, skip):
        x = self.up(x)
        dz, dy, dx = skip.shape[2]-x.shape[2], skip.shape[3]-x.shape[3], skip.shape[4]-x.shape[4]
        x = F.pad(x, (0, max(0,dx), 0, max(0,dy), 0, max(0,dz)))
        if dz<0 or dy<0 or dx<0:
            x = x[:, :, :skip.shape[2], :skip.shape[3], :skip.shape[4]]
        x = torch.cat([x, skip], dim=1)
        return self.conv(x)

class SEResUNetPP(nn.Module):
    def __init__(self, in_ch=1, n_classes=2, base=32):
        super().__init__()
        chs = [base, base*2, base*4, base*8, base*16]
        self.e1 = ResSEBlock(in_ch, chs[0])
        self.e2 = ResSEBlock(chs[0], chs[1])
        self.e3 = ResSEBlock(chs[1], chs[2])
        self.e4 = ResSEBlock(chs[2], chs[3])
        self.bottleneck = ResSEBlock(chs[3], chs[4])
        self.pool = nn.MaxPool3d(2)
        self.u4 = UpBlock(chs[4], chs[3])
        self.u3 = UpBlock(chs[3], chs[2])
        self.u2 = UpBlock(chs[2], chs[1])
        self.u1 = UpBlock(chs[1], chs[0])
        self.out_main = nn.Conv3d(chs[0], n_classes, 1)
        self.out_ds2  = nn.Conv3d(chs[1], n_classes, 1)
        self.out_ds3  = nn.Conv3d(chs[2], n_classes, 1)
        self.out_ds4  = nn.Conv3d(chs[3], n_classes, 1)
    def forward(self, x):
        e1 = self.e1(x)
        e2 = self.e2(self.pool(e1))
        e3 = self.e3(self.pool(e2))
        e4 = self.e4(self.pool(e3))
        bn = self.bottleneck(self.pool(e4))
        d4 = self.u4(bn, e4)
        d3 = self.u3(d4, e3)
        d2 = self.u2(d3, e2)
        d1 = self.u1(d2, e1)
        o_main = self.out_main(d1)
        o2 = F.interpolate(self.out_ds2(d2), size=o_main.shape[2:], mode="trilinear", align_corners=False)
        o3 = F.interpolate(self.out_ds3(d3), size=o_main.shape[2:], mode="trilinear", align_corners=False)
        o4 = F.interpolate(self.out_ds4(d4), size=o_main.shape[2:], mode="trilinear", align_corners=False)
        return o_main, o2, o3, o4

# quick check
tmp = SEResUNetPP(IN_CHANNELS, NUM_CLASSES, base=32).to(device)
x = torch.randn(1, IN_CHANNELS, *PATCH).to(device)
with torch.no_grad():
    outs = tmp(x)
print("Output shapes:", [t.shape for t in outs])
print("Params (M):", sum(p.numel() for p in tmp.parameters())/1e6)
del tmp


Output shapes: [torch.Size([1, 2, 80, 192, 160]), torch.Size([1, 2, 80, 192, 160]), torch.Size([1, 2, 80, 192, 160]), torch.Size([1, 2, 80, 192, 160])]
Params (M): 23.038112


## Losses, metrics, EMA


In [10]:

def soft_dice_loss(logits, onehot, eps=1e-6):
    probs = torch.softmax(logits, dim=1)
    dims = (0,2,3,4)
    inter = torch.sum(probs*onehot, dims)
    denom = torch.sum(probs, dims) + torch.sum(onehot, dims)
    dice = (2*inter + eps)/(denom + eps)
    return 1 - dice.mean()

ce = nn.CrossEntropyLoss()

def combined_loss(logits_list, onehot, y):
    weights = [1.0, 0.5, 0.25, 0.125]
    loss = 0.0
    for w,lg in zip(weights, logits_list):
        loss += w*(soft_dice_loss(lg, onehot) + ce(lg, y.long()))
    return loss/sum(weights)

def dice_from_logits(logits, y, cls=1, eps=1e-6):
    pred = torch.argmax(logits, dim=1)
    inter = torch.sum((pred==cls) & (y==cls)).float()
    denom = torch.sum(pred==cls).float() + torch.sum(y==cls).float()
    return (2*inter + eps)/(denom + eps)

class EMA:
    def __init__(self, model, decay=0.999):
        self.decay = decay
        self.shadow = {k: v.detach().clone() for k,v in model.state_dict().items()}
    def update(self, model):
        with torch.no_grad():
            for k,v in model.state_dict().items():
                self.shadow[k].mul_(self.decay).add_(v.detach(), alpha=1-self.decay)
    def apply_to(self, model):
        model.load_state_dict(self.shadow, strict=False)


## DataLoaders


In [11]:

os.makedirs(OUT_DIR, exist_ok=True)
train_ds = PairListDataset(LIST_TRAIN, PATCH, spacing, modality=MODALITY, training=True, fg_ratio=0.6)
val_ds   = PairListDataset(LIST_VAL,   PATCH, spacing, modality=MODALITY, training=False)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  num_workers=NUM_WORKERS, pin_memory=True)
val_loader   = DataLoader(val_ds,   batch_size=1,          shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
print("train/val:", len(train_ds), len(val_ds))


train/val: 16 4


## Train (AMP, optional EMA)


In [12]:

def train_epoch(model, loader, opt, scaler, ema=None):
    model.train()
    total=0.0
    for vol, onehot, lab in loader:
        vol = vol.to(device, non_blocking=True)
        onehot = onehot.to(device, non_blocking=True)
        lab = lab.to(device, non_blocking=True)
        opt.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=AMP):
            outs = model(vol)
            loss = combined_loss(outs, onehot, lab)
        scaler.scale(loss).backward()
        scaler.step(opt); scaler.update()
        if ema: ema.update(model)
        total += loss.item()
    return total/len(loader)

@torch.no_grad()
def validate(model, loader, ema=None):
    if ema:
        bak = {k: v.detach().clone() for k,v in model.state_dict().items()}
        ema.apply_to(model)
    model.eval()
    ds = []
    for vol, _, lab in loader:
        vol = vol.to(device, non_blocking=True)
        lab = lab.to(device, non_blocking=True)
        lg = model(vol)[0]
        ds.append(dice_from_logits(lg, lab, 1).item())
    if ema:
        model.load_state_dict(bak, strict=False)
    return float(np.mean(ds)) if ds else 0.0

model = SEResUNetPP(IN_CHANNELS, NUM_CLASSES, base=32).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
scaler = torch.cuda.amp.GradScaler(enabled=AMP)
ema = EMA(model, decay=0.999) if USE_EMA else None

best=-1.0
import csv, time
with open(os.path.join(OUT_DIR,"log.csv"),"w",newline="") as f: csv.writer(f).writerow(["epoch","train_loss","val_dice"])

for ep in range(1, MAX_EPOCHS+1):
    t0=time.time()
    tl = train_epoch(model, train_loader, opt, scaler, ema)
    vd = validate(model, val_loader, ema)
    with open(os.path.join(OUT_DIR,"log.csv"),"a",newline="") as f: csv.writer(f).writerow([ep, tl, vd])
    if vd>best:
        best=vd
        torch.save({"epoch":ep, "state_dict":(ema.shadow if ema else model.state_dict()), "val_dice":vd},
                   os.path.join(OUT_DIR, "seresunetpp_best.pth"))
    print(f"Epoch {ep:03d} | loss {tl:.4f} | valDice {vd:.4f} | best {best:.4f}")


  scaler = torch.cuda.amp.GradScaler(enabled=AMP)
  with torch.cuda.amp.autocast(enabled=AMP):


Epoch 001 | loss 0.9945 | valDice 0.0580 | best 0.0580
Epoch 002 | loss 0.5611 | valDice 0.0672 | best 0.0672
Epoch 003 | loss 0.4353 | valDice 0.0613 | best 0.0672
Epoch 004 | loss 0.4136 | valDice 0.0914 | best 0.0914
Epoch 005 | loss 0.3699 | valDice 0.1043 | best 0.1043
Epoch 006 | loss 0.3411 | valDice 0.1146 | best 0.1146
Epoch 007 | loss 0.2595 | valDice 0.1381 | best 0.1381
Epoch 008 | loss 0.2851 | valDice 0.1490 | best 0.1490
Epoch 009 | loss 0.2626 | valDice 0.1589 | best 0.1589
Epoch 010 | loss 0.2087 | valDice 0.1640 | best 0.1640
Epoch 011 | loss 0.2159 | valDice 0.1045 | best 0.1640
Epoch 012 | loss 0.2499 | valDice 0.1624 | best 0.1640
Epoch 013 | loss 0.2186 | valDice 0.1649 | best 0.1649
Epoch 014 | loss 0.2040 | valDice 0.1759 | best 0.1759
Epoch 015 | loss 0.1866 | valDice 0.1875 | best 0.1875
Epoch 016 | loss 0.1945 | valDice 0.1750 | best 0.1875
Epoch 017 | loss 0.1648 | valDice 0.2002 | best 0.2002
Epoch 018 | loss 0.2116 | valDice 0.2024 | best 0.2024
Epoch 019 

## Inference: sliding window + Gaussian blending + mirror TTA + Largest Component Post‑Processing


In [None]:

def gaussian_weight(patch):
    z,y,x = patch
    zz = np.linspace(-1,1,patch[0])[:,None,None]
    yy = np.linspace(-1,1,patch[1])[None,:,None]
    xx = np.linspace(-1,1,patch[2])[None,None,:]
    g = np.exp(-0.5*(zz**2 + yy**2 + xx**2))
    return g.astype(np.float32)

GAUSS = gaussian_weight(PATCH)

def predict_volume(model, img_path, out_path):
    img0 = sitk.ReadImage(img_path)
    img = sitk_resample_to_spacing(img0, spacing, is_label=False)
    vol = sitk.GetArrayFromImage(img).astype(np.float32)
    vol = normalize_ct_hu(vol) if MODALITY=="CT" else normalize_mri_z(vol, vol!=0)

    sz, sy, sx = vol.shape
    pz, py, px = PATCH
    stz = max(1, int(pz*(1-OVERLAP)))
    sty = max(1, int(py*(1-OVERLAP)))
    stx = max(1, int(px*(1-OVERLAP)))

    prob = np.zeros((NUM_CLASSES, sz, sy, sx), np.float32)
    weight = np.zeros((sz, sy, sx), np.float32)

    def run(inp):
        t = torch.from_numpy(inp[None,None]).to(device)
        with torch.no_grad(), torch.cuda.amp.autocast(enabled=AMP):
            lg = model(t)[0]
            p = torch.softmax(lg, dim=1)[0].cpu().numpy()
        return p

    for z in range(0, max(1, sz - pz + 1), stz):
        for y in range(0, max(1, sy - py + 1), sty):
            for x in range(0, max(1, sx - px + 1), stx):
                patch = vol[z:z+pz, y:y+py, x:x+px]
                pad = [(0, max(0, pz - patch.shape[0])),
                       (0, max(0, py - patch.shape[1])),
                       (0, max(0, px - patch.shape[2]))]
                patch = np.pad(patch, pad, mode="edge")
                p = run(patch)
                if TTA_MIRROR:
                    p = (p
                         + np.flip(run(np.flip(patch, 0)), 1)
                         + np.flip(run(np.flip(patch, 1)), 2)
                         + np.flip(run(np.flip(patch, 2)), 3)) / 4.0
                gw = GAUSS[:, :p.shape[1], :p.shape[2]]
                p = p * gw[None]
                z2 = min(pz, sz - z); y2 = min(py, sy - y); x2 = min(px, sx - x)
                prob[:, z:z+z2, y:y+y2, x:x+x2] += p[:, :z2, :y2, :x2]
                weight[z:z+z2, y:y+y2, x:x+x2] += gw[:z2, :y2, :x2]

    prob /= np.maximum(weight[None], 1e-6)
    seg = np.argmax(prob, axis=0).astype(np.uint8)

    if POSTPROC_KEEP_LARGEST and NUM_CLASSES>1:
        s_img = sitk.GetImageFromArray((seg==1).astype(np.uint8)); s_img.CopyInformation(img)
        cc = sitk.ConnectedComponent(s_img)
        rel = sitk.RelabelComponent(cc, sortByObjectSize=True)
        largest = sitk.BinaryThreshold(rel, 1, 1, 1, 0)
        seg = (sitk.GetArrayFromImage(largest)>0).astype(np.uint8)
        out = np.zeros_like(sitk.GetArrayFromImage(img), dtype=np.uint8)
        out[seg==1] = 1
        seg = out

    seg_sitk = sitk.GetImageFromArray(seg.astype(np.uint8)); seg_sitk.CopyInformation(img)
    back = sitk.Resample(seg_sitk, img0, sitk.Transform(), sitk.sitkNearestNeighbor, 0, sitk.sitkUInt8)
    sitk.WriteImage(back, out_path)
    print("Saved:", out_path)


## Example prediction on one test case (edit filename)


In [None]:

test_img = f"{DATA_ROOT}/MSD/{TASK}/imagesTs/la_001.nii.gz"
ckpt = os.path.join(OUT_DIR, "seresunetpp_best.pth")
out_pred = os.path.join(OUT_DIR, "la_001_pred_v2.nii.gz")

model = SEResUNetPP(IN_CHANNELS, NUM_CLASSES, base=32).to(device)
if os.path.exists(ckpt):
    state = torch.load(ckpt, map_location="cpu")
    model.load_state_dict(state["state_dict"], strict=False)
else:
    print("Checkpoint not found at", ckpt, "- train the model first.")

if os.path.exists(test_img):
    predict_volume(model, test_img, out_pred)
else:
    print("Edit 'test_img' with a real imagesTs filename.")


## Minimal 3D MAE pretrainer + transfer to seg encoder


In [None]:

class MAE3D(nn.Module):
    def __init__(self, in_ch=1, emb=96):
        super().__init__()
        self.enc = nn.Sequential(
            nn.Conv3d(in_ch, 32, 3, padding=1), nn.InstanceNorm3d(32, affine=True), nn.LeakyReLU(0.01, True),
            nn.Conv3d(32, 64, 3, stride=2, padding=1), nn.InstanceNorm3d(64, affine=True), nn.LeakyReLU(0.01, True),
            nn.Conv3d(64, emb, 3, stride=2, padding=1), nn.InstanceNorm3d(emb, affine=True), nn.LeakyReLU(0.01, True),
        )
        self.dec = nn.Sequential(
            nn.ConvTranspose3d(emb, 64, 2, 2), nn.InstanceNorm3d(64, affine=True), nn.LeakyReLU(0.01, True),
            nn.ConvTranspose3d(64, 32, 2, 2), nn.InstanceNorm3d(32, affine=True), nn.LeakyReLU(0.01, True),
            nn.Conv3d(32, in_ch, 1)
        )
    def forward(self, x, mask):
        z = self.enc(x)
        r = self.dec(z)
        r = F.interpolate(r, size=x.shape[2:], mode="trilinear", align_corners=False)
        return r*mask, x*mask

def mae_make_mask(shape, ratio):
    B,_,Z,Y,X = shape
    m = torch.zeros((B,1,Z,Y,X), device=device)
    num = int(Z*Y*X*ratio)
    for b in range(B):
        idx = torch.randperm(Z*Y*X, device=device)[:num]
        m.view(B,1,-1)[b,0,idx] = 1.0
    return m

class UnlabeledImagesDataset(Dataset):
    def __init__(self, images, patch, spacing, modality="CT"):
        self.paths = images
        self.patch = patch
        self.spacing = np.array(spacing, dtype=np.float32)
        self.modality = modality
    def __len__(self): return len(self.paths)
    def _extract(self, v):
        Pz,Py,Px = self.patch
        Z,Y,X = v.shape
        z = np.random.randint(0, max(1, Z-Pz+1))
        y = np.random.randint(0, max(1, Y-Py+1))
        x = np.random.randint(0, max(1, X-Px+1))
        vv = v[z:z+Pz, y:y+Py, x:x+Px]
        pad = [(0,max(0,Pz-vv.shape[0])), (0,max(0,Py-vv.shape[1])), (0,max(0,Px-vv.shape[2]))]
        return np.pad(vv, pad, mode="edge")
    def __getitem__(self, idx):
        ip = self.paths[idx]
        i_sitk = sitk.ReadImage(ip)
        i_rs = sitk_resample_to_spacing(i_sitk, spacing, is_label=False)
        vol = sitk.GetArrayFromImage(i_rs).astype(np.float32)
        vol = normalize_ct_hu(vol) if self.modality=="CT" else normalize_mri_z(vol, vol!=0)
        vol = self._extract(vol)
        return torch.from_numpy(vol[None, ...])

def mae_pretrain(imagesTr_dir, epochs=MAE_EPOCHS, mask_ratio=MAE_MASK_RATIO, save_dir=MAE_OUT):
    os.makedirs(save_dir, exist_ok=True)
    imgs = sorted(glob.glob(os.path.join(imagesTr_dir, "*.nii.gz")))
    ds = UnlabeledImagesDataset(imgs, MAE_PATCH, spacing, MODALITY)
    dl = DataLoader(ds, batch_size=2, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
    mae = MAE3D(in_ch=1, emb=96).to(device)
    opt = torch.optim.AdamW(mae.parameters(), lr=1e-3, weight_decay=1e-4)
    scaler = torch.cuda.amp.GradScaler(enabled=AMP)
    for ep in range(1, epochs+1):
        total=0.0
        for v in dl:
            v = v.to(device, non_blocking=True)
            m = mae_make_mask(v.shape, mask_ratio)
            opt.zero_grad(set_to_none=True)
            with torch.cuda.amp.autocast(enabled=AMP):
                pr, gt = mae(v, m)
                loss = F.l1_loss(pr, gt)
            scaler.scale(loss).backward(); scaler.step(opt); scaler.update()
            total += loss.item()
        print(f"[MAE] epoch {ep:03d} | loss {total/len(dl):.4f}")
    torch.save({"encoder": mae.enc.state_dict()}, os.path.join(save_dir, "mae_encoder.pth"))
    print("Saved MAE encoder to", os.path.join(save_dir, "mae_encoder.pth"))
    return os.path.join(save_dir, "mae_encoder.pth")

def load_mae_encoder_into_seg(model, mae_ckpt_path):
    ck = torch.load(mae_ckpt_path, map_location="cpu")["encoder"]
    seg_sd = model.state_dict()
    mapped=0
    for k,v in ck.items():
        for tgt in ["e1", "e2", "e3"]:
            for sub in ["conv1.weight", "conv1.bias", "conv2.weight", "conv2.bias"]:
                tk = f"{tgt}.{sub}"
                if tk in seg_sd and seg_sd[tk].shape == v.shape:
                    seg_sd[tk] = v; mapped += 1; break
    model.load_state_dict(seg_sd, strict=False)
    print(f"Loaded {mapped} MAE params into segmentation encoder.")
