# Self‑Contained 3D Segmentation Notebook (U‑Net‑style, Deep Supervision)
This notebook is **self-contained**: it defines the model, dataloaders, losses, training loop, sliding‑window inference, and simple evaluation **inside** the notebook.

### What it covers
- Reads **nnU‑Net v2** plans to mirror *spacing* and *patch size* (for fair comparison).
- Minimal **CT/MRI normalization** and **resampling to target spacing** via SimpleITK.
- A **Residual U‑Net 3D** with **deep supervision**.
- **Dice + Cross‑Entropy** loss, AMP, checkpointing by **val Dice**.
- A simple **sliding‑window** predictor (with optional test‑time mirroring).

> **Where to put it:** save/run in your WSL repo folder (`~/projects/medssl_from_scratch`) or anywhere you like. Just ensure the paths below are correct.


In [None]:

# --- Config: edit these for Heart vs Spleen, paths, and training knobs ---

# Data location and task selection.
DATA_ROOT = "/home/htetaung/data"   # where MSD + nnU-Net folders live
TASK = "Task02_Heart"               # or "Task09_Spleen"
FOLD = 0

# Everything derived from the task: nnU-Net dataset id, dataset name and the
# modality used for normalization ("CT" or "MRI"). Any task other than
# "Task02_Heart" is treated as Spleen/CT.
if TASK == "Task02_Heart":
    DS_ID, NAME, MODALITY = 2, "Heart", "MRI"
else:
    DS_ID, NAME, MODALITY = 9, "Spleen", "CT"

# Model / training knobs.
IN_CHANNELS = 1
NUM_CLASSES = 2          # bg + structure
MAX_EPOCHS = 50
BATCH_SIZE = 2
PATCH = (96, 96, 96)     # the nnU-Net patch is printed below; copy it here if VRAM allows
NUM_WORKERS = 4
AMP = True

# Output folder and the train/val list files (exported from the nnU-Net split).
OUT_DIR = f"./runs_selfcontained/{TASK}_unet_resdeep_fold{FOLD}"
LIST_TRAIN = f"./lists/{TASK}_train_fold{FOLD}.txt"
LIST_VAL = f"./lists/{TASK}_val_fold{FOLD}.txt"

# test-time augmentation (mirror flips over axes)
TTA_MIRROR = True


In [None]:

# --- Imports & sanity ---
import os, json, math, time, random, shutil, csv
from pathlib import Path
from typing import Tuple, List

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Choose the compute device once, then report the environment from it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Torch:", torch.__version__)
print("CUDA:", device.type == "cuda")
if device.type == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))


## Read nnU‑Net plans (spacing & patch size)
Use these values for fair comparisons. If you have enough VRAM, set `PATCH` to the printed patch size.


In [None]:

def read_nnunet_cfg(ds_id, name, data_root=DATA_ROOT):
    """Read the ``3d_fullres`` settings from a dataset's ``nnUNetPlans.json``.

    Args:
        ds_id: integer nnU-Net dataset id (e.g. 2 for Heart).
        name: dataset name used in the folder suffix (e.g. "Heart").
        data_root: folder that contains ``nnunet_preprocessed``.

    Returns:
        ``(spacing, patch_size, batch_size)``. ``spacing`` is the first
        non-empty value among several candidate keys and is ``None`` when
        none of them is present; ``patch_size`` is a tuple (empty when
        missing); ``batch_size`` is ``None`` when missing.

    Raises:
        OSError: the plans file cannot be opened.
        json.JSONDecodeError: the plans file is not valid JSON.
    """
    # The folder id is a zero-padded three-digit number. Formatting the whole
    # id with :03d (instead of a literal "0" + two digits) gives the same path
    # for ids below 100 and a correct one for ids of 100 and above.
    base = f"{data_root}/nnunet_preprocessed/Dataset{ds_id:03d}_{name}"
    with open(f"{base}/nnUNetPlans.json", encoding="utf-8") as f:
        plans = json.load(f)
    cfg = plans.get("configurations", {}).get("3d_fullres", {})
    # Different plan layouts store the spacing under different keys; take the
    # first one that is present and non-empty.
    spacing = (cfg.get("spacing") or cfg.get("resampling_target_spacing") or
               plans.get("target_spacing") or plans.get("spacing"))
    patch_size = tuple(cfg.get("patch_size") or ())
    batch_size = cfg.get("batch_size")
    return spacing, patch_size, batch_size

# Read the plan values for the dataset selected in the config cell. `spacing`
# is reused further down as the resampling target; the patch and batch size
# are only printed for reference.
spacing, nn_patch, nn_bs = read_nnunet_cfg(DS_ID, NAME)
print(f"Dataset {DS_ID} {NAME} -> spacing={spacing}, patch_size={nn_patch}, batch_size={nn_bs}")
# PATCH is the value from the config cell; it is not replaced by nn_patch here.
print(f"Using PATCH (for this notebook) =", PATCH)


## Dataset & preprocessing
- We read list files of `image_path,label_path`.
- **Resample** each image/label to the **nnU‑Net spacing** using SimpleITK (B‑spline interpolation for images, nearest‑neighbour for labels).
- Normalize:
  - **CT:** clip to \[-125, 275\] HU then scale to 0–1 (spleen‑friendly window).
  - **MRI:** z‑score inside the foreground.
- Random augmentations (training only): random flips along each axis and a light gamma transform. Small rotations/scales are not implemented in this minimal version.

> You can expand this with MONAI/TorchIO later; we keep it minimal here.


In [None]:

import nibabel as nib
import SimpleITK as sitk

def load_nii(path):
    """Read an image file with SimpleITK.

    Returns:
        ``(voxels, spacing_zyx, image)``: the voxel data as a float32 array in
        (z, y, x) order, the voxel spacing as a float32 array reordered to
        (z, y, x), and the SimpleITK image object itself.
    """
    image = sitk.ReadImage(path)
    voxels = sitk.GetArrayFromImage(image).astype(np.float32)  # z,y,x
    # GetSpacing() is reversed so that it lines up with the array axes.
    spacing_zyx = np.array(tuple(reversed(image.GetSpacing())), dtype=np.float32)
    return voxels, spacing_zyx, image

def sitk_resample_to_spacing(image_sitk, target_spacing, is_label=False):
    """Resample a SimpleITK image onto a grid with the requested voxel spacing.

    Origin, direction and pixel type of the input are kept; only spacing and
    size change.

    Args:
        image_sitk: SimpleITK image to resample.
        target_spacing: desired spacing in (z, y, x) order; any sequence or
            NumPy array of numbers.
        is_label: nearest-neighbour interpolation when True (label maps),
            B-spline interpolation otherwise.

    Returns:
        The resampled SimpleITK image.

    Raises:
        ValueError: ``target_spacing`` has the wrong number of entries or
            contains a non-positive value.
    """
    orig_spacing = image_sitk.GetSpacing()
    orig_size = image_sitk.GetSize()
    # sitk uses x,y,z. Convert every entry to a plain Python float: callers in
    # this notebook pass NumPy arrays (including float32 ones), and NumPy
    # scalar types can be rejected by SimpleITK's argument conversion.
    target_spacing = tuple(float(s) for s in target_spacing)[::-1]
    if len(target_spacing) != len(orig_spacing):
        raise ValueError(
            f"target_spacing has {len(target_spacing)} entries, "
            f"image has {len(orig_spacing)} dimensions"
        )
    if not all(s > 0 for s in target_spacing):
        raise ValueError(f"target_spacing must be strictly positive, got {target_spacing}")
    # Keep the physical extent: new_size * new_spacing ~= old_size * old_spacing.
    # Clamp to 1 so a very coarse target spacing never yields an empty axis.
    new_size = [max(1, int(round(osz * ospc / tspc)))
                for osz, ospc, tspc in zip(orig_size, orig_spacing, target_spacing)]
    res = sitk.Resample(
        image_sitk,
        new_size,
        sitk.Transform(),
        sitk.sitkNearestNeighbor if is_label else sitk.sitkBSpline,
        image_sitk.GetOrigin(),
        target_spacing,
        image_sitk.GetDirection(),
        0,
        image_sitk.GetPixelID()
    )
    return res

def normalize_ct_hu(x):
    """Clip CT values to the [-125, 275] HU window and rescale to [0, 1] (float32)."""
    lower, upper = -125, 275
    shift = -lower          # moves the window's lower bound to zero
    span = upper + shift    # window width
    windowed = np.clip(x, lower, upper)
    return ((windowed + shift) / span).astype(np.float32)

def normalize_mri_zscore(x, mask=None, eps=1e-6):
    """Z-score ``x`` with statistics taken from the voxels selected by ``mask``.

    ``mask`` defaults to the non-zero voxels of ``x``. When the mask selects
    nothing, the statistics of the whole array are used instead. ``eps`` is
    added to the standard deviation to avoid dividing by zero. The result is
    float32 and has the shape of ``x``.
    """
    if mask is None:
        mask = x != 0
    reference = x[mask] if mask.any() else x
    mean = reference.mean()
    std = reference.std()
    normalized = (x - mean) / (std + eps)
    return normalized.astype(np.float32)

def random_augment(vol, seg):
    """Apply light random augmentation to an image/label pair.

    Each of the first three axes is mirrored independently with probability
    0.5; image and labels always receive the same flips. Afterwards, with
    probability 0.3, the image intensities are passed through a gamma curve
    with exponent drawn from [0.7, 1.3), applied on the min-max normalised
    values and mapped back to the original range. Labels are never altered
    in value.
    """
    for axis in range(3):
        if random.random() < 0.5:
            vol = np.flip(vol, axis=axis).copy()
            seg = np.flip(seg, axis=axis).copy()

    if random.random() < 0.3:
        gamma = 0.7 + 0.6*random.random()
        lo, hi = vol.min(), vol.max()
        span = hi - lo
        unit = (vol - lo) / (span + 1e-6)   # rescaled to roughly [0, 1]
        vol = unit ** gamma * span + lo
    return vol, seg


In [None]:

class NiftiPairDataset(Dataset):
    """Image/label pairs resampled to a common spacing and fitted to a fixed patch.

    ``list_file`` holds one ``image_path,label_path`` pair per non-empty line.
    Every item is read with SimpleITK, resampled to ``spacing_target``
    (z, y, x), normalised according to ``modality`` ("CT" uses
    ``normalize_ct_hu``; anything else uses the z-score over non-zero
    voxels), augmented when ``training`` is True, and finally centre-cropped
    and/or zero-padded to ``patch``.

    ``__getitem__`` returns ``(image, onehot, labels)`` tensors with shapes
    ``(1, *patch)``, ``(NUM_CLASSES, *patch)`` and ``patch``.
    """

    def __init__(self, list_file, patch, spacing_target, modality="CT", training=True):
        # Read the list inside a context manager so the file handle is closed
        # right away, and reject malformed lines here rather than at the
        # first __getitem__ call that happens to hit them.
        self.items = []
        with open(list_file) as f:
            for line_no, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    continue
                fields = [part.strip() for part in line.split(",")]
                if len(fields) != 2:
                    raise ValueError(
                        f"{list_file}:{line_no}: expected 'image_path,label_path', got {line!r}"
                    )
                self.items.append(fields)
        self.patch = patch
        self.spacing_target = np.array(spacing_target, dtype=np.float32)
        self.modality = modality
        self.training = training

    def _crop_or_pad(self, v, s):
        """Centre-crop and/or zero-pad ``v`` and ``s`` to ``self.patch``.

        Axes larger than the patch are cropped around their centre; axes
        smaller than the patch are placed in the middle of a zero-filled
        output. Returns ``(float32 image, int16 labels)``.
        """
        if v.ndim != len(self.patch):
            raise ValueError(f"expected a {len(self.patch)}D volume, got shape {v.shape}")
        outv = np.zeros(self.patch, np.float32)
        outs = np.zeros(self.patch, np.int16)
        dst, src = [], []
        for size, target in zip(v.shape, self.patch):
            n = min(size, target)                     # voxels copied along this axis
            dst_start = max(0, (target - size) // 2)  # offset inside the output (padding case)
            src_start = max(0, (size - target) // 2)  # offset inside the input (cropping case)
            dst.append(slice(dst_start, dst_start + n))
            src.append(slice(src_start, src_start + n))
        dst, src = tuple(dst), tuple(src)
        outv[dst] = v[src]
        outs[dst] = s[src]
        return outv, outs

    def __len__(self): return len(self.items)

    def __getitem__(self, idx):
        ipath, lpath = self.items[idx]
        img_rs = sitk_resample_to_spacing(sitk.ReadImage(ipath), self.spacing_target, is_label=False)
        lab_rs = sitk_resample_to_spacing(sitk.ReadImage(lpath), self.spacing_target, is_label=True)
        vol = sitk.GetArrayFromImage(img_rs).astype(np.float32)  # z,y,x
        seg = sitk.GetArrayFromImage(lab_rs).astype(np.int16)

        if self.modality == "CT":
            vol = normalize_ct_hu(vol)
        else:
            # MRI: foreground mask where not 0
            vol = normalize_mri_zscore(vol, vol != 0)

        if self.training:
            vol, seg = random_augment(vol, seg)

        # The patch is always the centre of the volume, for training too.
        vol, seg = self._crop_or_pad(vol, seg)

        vol = vol[None, ...]               # (1, z, y, x)
        # One-hot encode by comparing against every class id at once.
        class_ids = np.arange(NUM_CLASSES, dtype=seg.dtype).reshape(-1, 1, 1, 1)
        onehot = (seg[None, ...] == class_ids).astype(np.float32)
        return torch.from_numpy(vol), torch.from_numpy(onehot), torch.from_numpy(seg)


## Model: Residual U‑Net 3D with deep supervision
- Residual conv blocks (3D) with InstanceNorm.
- 4 encoder stages + 4 decoder stages.
- **Deep supervision**: auxiliary logits at 3 decoder stages (down‑weighted in the loss).


In [None]:

class ConvBlock3d(nn.Module):
    """Residual block: two 3x3x3 conv + InstanceNorm layers with LeakyReLU.

    The shortcut is the identity when the channel count is unchanged and a
    bias-free 1x1x1 convolution otherwise. Spatial size is preserved.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = nn.Conv3d(in_ch, out_ch, kernel_size=3, padding=1, bias=False)
        self.in1 = nn.InstanceNorm3d(out_ch, affine=True)
        self.conv2 = nn.Conv3d(out_ch, out_ch, kernel_size=3, padding=1, bias=False)
        self.in2 = nn.InstanceNorm3d(out_ch, affine=True)
        self.act = nn.LeakyReLU(negative_slope=0.01, inplace=True)
        if in_ch == out_ch:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Conv3d(in_ch, out_ch, kernel_size=1, bias=False)

    def forward(self, x):
        shortcut = self.skip(x)
        main = self.act(self.in1(self.conv1(x)))
        main = self.in2(self.conv2(main))
        # The second activation is applied after the residual sum.
        return self.act(main + shortcut)

class UpBlock3d(nn.Module):
    """Decoder stage: 2x transposed-conv upsampling, skip concatenation, conv block."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.ConvTranspose3d(in_ch, out_ch, kernel_size=2, stride=2)
        # Applied to the concatenation of the upsampled map and the skip
        # tensor, hence built with in_ch input channels.
        self.conv = ConvBlock3d(in_ch, out_ch)

    def forward(self, x, skip):
        x = self.up(x)
        target = skip.shape[2:]
        # Per-axis (z, y, x) amount by which the skip exceeds the upsampled map.
        gaps = [want - have for want, have in zip(target, x.shape[2:])]
        # F.pad lists the last dimension first; only the trailing side is padded.
        x = F.pad(x, (0, max(0, gaps[2]), 0, max(0, gaps[1]), 0, max(0, gaps[0])))
        if any(gap < 0 for gap in gaps):
            # The upsampled map is larger on some axis: crop it to the skip size.
            x = x[:, :, :target[0], :target[1], :target[2]]
        fused = torch.cat([x, skip], dim=1)
        return self.conv(fused)

class ResidualUNet3D_DS(nn.Module):
    """Residual 3D U-Net with deep supervision.

    Four encoder stages and a bottleneck (channel width doubling from
    ``base`` at every level, 2x max-pooling in between), followed by four
    decoder stages. ``forward`` returns four logit tensors, all at the input
    resolution: the main output plus three auxiliary outputs taken from the
    coarser decoder stages and upsampled trilinearly.
    """

    def __init__(self, in_ch=1, n_classes=2, base=32):
        super().__init__()
        widths = [base * (2 ** level) for level in range(5)]

        # encoder
        self.e1 = ConvBlock3d(in_ch, widths[0])
        self.e2 = ConvBlock3d(widths[0], widths[1])
        self.e3 = ConvBlock3d(widths[1], widths[2])
        self.e4 = ConvBlock3d(widths[2], widths[3])
        self.bottleneck = ConvBlock3d(widths[3], widths[4])

        self.pool = nn.MaxPool3d(2)

        # decoder
        self.u4 = UpBlock3d(widths[4], widths[3])
        self.u3 = UpBlock3d(widths[3], widths[2])
        self.u2 = UpBlock3d(widths[2], widths[1])
        self.u1 = UpBlock3d(widths[1], widths[0])

        # 1x1x1 classification heads: main output + three deep-supervision outputs
        self.out_main = nn.Conv3d(widths[0], n_classes, 1)
        self.out_ds2 = nn.Conv3d(widths[1], n_classes, 1)
        self.out_ds3 = nn.Conv3d(widths[2], n_classes, 1)
        self.out_ds4 = nn.Conv3d(widths[3], n_classes, 1)

    def forward(self, x):
        # Contracting path; every stage output is kept as a skip connection.
        e1 = self.e1(x)
        e2 = self.e2(self.pool(e1))
        e3 = self.e3(self.pool(e2))
        e4 = self.e4(self.pool(e3))
        bn = self.bottleneck(self.pool(e4))

        # Expanding path.
        d4 = self.u4(bn, e4)
        d3 = self.u3(d4, e3)
        d2 = self.u2(d3, e2)
        d1 = self.u1(d2, e1)

        out_main = self.out_main(d1)
        full_res = out_main.shape[2:]

        def to_full_res(logits):
            # Bring an auxiliary output up to the main output's spatial size.
            return F.interpolate(logits, size=full_res, mode="trilinear", align_corners=False)

        ds2 = to_full_res(self.out_ds2(d2))
        ds3 = to_full_res(self.out_ds3(d3))
        ds4 = to_full_res(self.out_ds4(d4))
        return out_main, ds2, ds3, ds4

# quick shape/param check
# Builds a throwaway model instance and pushes one random patch-sized input
# through it without tracking gradients.
# NOTE(review): `m` and `x` are not deleted in this cell, so they presumably
# stay allocated on `device` afterwards — confirm whether that matters for VRAM.
m = ResidualUNet3D_DS(in_ch=IN_CHANNELS, n_classes=NUM_CLASSES, base=32).to(device)
x = torch.randn(1, IN_CHANNELS, *PATCH).to(device)
with torch.no_grad():
    outs = m(x)
# The model returns a tuple of outputs (main + deep supervision); show each shape.
print("Output shapes:", [o.shape for o in outs])
# Parameter count in millions.
print("Params (M):", sum(p.numel() for p in m.parameters())/1e6)


## Losses & metrics
- **Soft Dice + Cross‑Entropy** combined.
- Deep supervision with weights `[1.0, 0.5, 0.25, 0.125]`.


In [None]:

def soft_dice_loss(logits, targets, eps=1e-6):
    """Soft Dice loss averaged over classes.

    Args:
        logits: raw scores of shape (B, C, Z, Y, X).
        targets: one-hot targets of the same shape.
        eps: smoothing term added to numerator and denominator.

    Returns:
        Scalar tensor ``1 - mean_c(dice_c)``, where each class's Dice is
        computed over the whole batch and all spatial positions.
    """
    probs = logits.softmax(dim=1)
    reduce_dims = (0, 2, 3, 4)  # everything except the class axis
    overlap = (probs * targets).sum(dim=reduce_dims)
    total = probs.sum(dim=reduce_dims) + targets.sum(dim=reduce_dims)
    per_class_dice = (2 * overlap + eps) / (total + eps)
    return 1 - per_class_dice.mean()

ce = nn.CrossEntropyLoss()

def combined_loss(logits_list, targets_onehot, targets_label):
    """Weighted Dice + cross-entropy loss over the deep-supervision outputs.

    Args:
        logits_list: sequence of logit tensors, main output first. Outputs
            are paired with the weights ``[1.0, 0.5, 0.25, 0.125]`` in order;
            outputs beyond the fourth are ignored.
        targets_onehot: one-hot targets for the Dice term.
        targets_label: integer label map for the cross-entropy term (cast to
            long here).

    Returns:
        Scalar tensor: weighted sum of ``dice + ce`` per output, divided by
        the sum of the weights that were actually used.

    Raises:
        ValueError: ``logits_list`` is empty.
    """
    weights = [1.0, 0.5, 0.25, 0.125]
    labels = targets_label.long()  # loop-invariant, so cast once
    tot = 0.0
    used = 0.0
    for w, logits in zip(weights, logits_list):
        tot = tot + w * (soft_dice_loss(logits, targets_onehot) + ce(logits, labels))
        used += w
    if used == 0.0:
        raise ValueError("logits_list must contain at least one logits tensor")
    # Dividing by the weights consumed (not by the sum of all four) keeps the
    # loss scale the same when fewer than four outputs are supplied.
    return tot / used

def dice_score_from_logits(logits, targets_label, c=1, eps=1e-6):
    """Hard Dice score for a single class ``c``.

    Predictions are the argmax of ``logits`` over dim 1; the score is
    computed over every element of the batch at once. ``eps`` smooths the
    ratio, so the result is 1 when class ``c`` is absent from both the
    prediction and the target.
    """
    predicted = torch.argmax(logits, dim=1)
    pred_mask = predicted.eq(c)
    true_mask = targets_label.eq(c)
    overlap = (pred_mask & true_mask).sum().float()
    total = pred_mask.sum().float() + true_mask.sum().float()
    return (2*overlap + eps)/(total + eps)


## DataLoaders
Reads the nnU‑Net fold‑0 splits we created earlier. If you haven’t created them yet, run the helper in the other notebook to export `splits_final.json` into list files.


In [None]:

# build loaders
# Make sure the run folder exists before anything is written to it.
os.makedirs(OUT_DIR, exist_ok=True)

# Both splits share patch size, target spacing and modality; only the
# training split is built with training=True (random augmentation).
train_ds = NiftiPairDataset(LIST_TRAIN, PATCH, np.array(spacing), modality=MODALITY, training=True)
val_ds   = NiftiPairDataset(LIST_VAL,   PATCH, np.array(spacing), modality=MODALITY, training=False)

# Training batches are shuffled; validation uses batch size 1 without shuffling.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  num_workers=NUM_WORKERS, pin_memory=True)
val_loader   = DataLoader(val_ds,   batch_size=1,          shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

# Bare expression: displayed as the cell output when run in a notebook.
len(train_ds), len(val_ds)


## Training loop (AMP, checkpoint by best val Dice)


In [None]:

def train_one_epoch(model, loader, opt, scaler):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Args:
        model: network returning the tuple of logits expected by
            ``combined_loss``.
        loader: iterable yielding ``(image, onehot, labels)`` batches.
        opt: optimiser stepped once per batch through ``scaler``.
        scaler: gradient scaler used for loss scaling under AMP.

    Returns:
        Mean of the per-batch loss values as a float, or 0.0 when the loader
        yields no batches (the same convention ``validate`` uses).
    """
    model.train()
    total = 0.0
    n_batches = 0
    for vol, onehot, lab in loader:
        vol = vol.to(device, non_blocking=True)
        onehot = onehot.to(device, non_blocking=True)
        lab = lab.to(device, non_blocking=True)

        opt.zero_grad(set_to_none=True)
        # Only the forward pass and the loss run under autocast.
        with torch.cuda.amp.autocast(enabled=AMP):
            outs = model(vol)
            loss = combined_loss(outs, onehot, lab)
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        total += loss.item()
        n_batches += 1
    # Count batches while iterating so an empty loader cannot divide by zero.
    return total / n_batches if n_batches else 0.0

@torch.no_grad()
def validate(model, loader):
    """Return the mean class-1 Dice of the main output over ``loader``.

    The model is switched to eval mode. Only the first element of the model's
    output tuple is scored; the one-hot targets in each batch are unused.
    Returns 0.0 when the loader yields nothing.
    """
    model.eval()
    per_batch = []
    for batch in loader:
        image, _, labels = batch
        image = image.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        main_logits = model(image)[0]
        score = dice_score_from_logits(main_logits, labels, c=1)
        per_batch.append(score.item())
    if not per_batch:
        return 0.0
    return float(np.mean(per_batch))

def save_ckpt(state, path):
    """Serialise ``state`` to ``path`` with ``torch.save``, replacing atomically.

    The data is first written to ``<path>.tmp`` and then moved over ``path``
    with ``os.replace``. An interrupted save therefore leaves any existing
    checkpoint at ``path`` untouched instead of truncating it.

    Args:
        state: any object ``torch.save`` can serialise.
        path: destination as ``str`` or path-like. Any other object (e.g. an
            open file) is handed to ``torch.save`` unchanged.
    """
    if not isinstance(path, (str, os.PathLike)):
        torch.save(state, path)
        return
    path = os.fspath(path)
    tmp_path = f"{path}.tmp"
    try:
        torch.save(state, tmp_path)
        os.replace(tmp_path, path)
    finally:
        # After a successful replace the temp file is gone; this only removes
        # a partial file left behind by a failed or interrupted save.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

# Fresh model, optimiser and gradient scaler for the actual training run.
model = ResidualUNet3D_DS(in_ch=IN_CHANNELS, n_classes=NUM_CLASSES, base=32).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
scaler = torch.cuda.amp.GradScaler(enabled=AMP)

# Start below any value validate() can return so the first epoch always saves.
best = -1
log_csv = os.path.join(OUT_DIR, "log.csv")
# Opening with "w" starts a new log (any previous one is overwritten) and
# writes the header row.
with open(log_csv, "w", newline="") as f:
    w=csv.writer(f); w.writerow(["epoch","train_loss","val_dice"])

for epoch in range(1, MAX_EPOCHS+1):
    t0=time.time()
    tl = train_one_epoch(model, train_loader, opt, scaler)
    vd = validate(model, val_loader)
    # Append one row per epoch so the log is on disk even if the run is interrupted.
    with open(log_csv, "a", newline="") as f:
        csv.writer(f).writerow([epoch, tl, vd])
    # Keep only the checkpoint with the highest validation Dice so far.
    if vd > best:
        best = vd
        save_ckpt({"epoch":epoch, "state_dict":model.state_dict(), "val_dice":vd},
                  os.path.join(OUT_DIR, "unet3d_best.pth"))
    print(f"Epoch {epoch:03d} | loss {tl:.4f} | val Dice {vd:.4f} | best {best:.4f} | {time.time()-t0:.1f}s")


## Sliding‑window prediction (with optional mirror TTA)
Outputs a NIfTI file aligned to the input image geometry.


In [None]:

def sliding_window_predict(vol_np, model, patch, overlap=0.5, tta_mirror=False):
    """Predict a whole volume by averaging softmax outputs of overlapping patches.

    Args:
        vol_np: 3D array (z, y, x) of already normalised intensities.
        model: network whose first output is the full-resolution logits.
        patch: window size ``(pz, py, px)``.
        overlap: fraction of a window shared with its neighbour along each
            axis; the stride is ``int(size * (1 - overlap))``, at least 1.
        tta_mirror: when True, each patch prediction is averaged with the
            predictions on its three single-axis mirrored copies.

    Returns:
        ``(seg, prob)``: the uint8 argmax label map of shape (z, y, x) and the
        float32 averaged probabilities of shape (NUM_CLASSES, z, y, x).
    """
    model.eval()
    sz, sy, sx = vol_np.shape
    pz, py, px = patch

    def window_starts(size, win):
        # Regular grid of start offsets plus, when the grid stops short, one
        # extra window flush with the end of the axis. Without it the voxels
        # behind the last grid window are never predicted and end up as
        # background.
        step = max(1, int(win * (1 - overlap)))
        last = max(0, size - win)
        starts = list(range(0, last + 1, step))
        if starts[-1] != last:
            starts.append(last)
        return starts

    prob = np.zeros((NUM_CLASSES, sz, sy, sx), np.float32)
    weight = np.zeros((sz, sy, sx), np.float32)

    def run_patch(inp):
        # np.flip returns a negatively-strided view, which torch.from_numpy
        # rejects; materialise a contiguous float32 copy first.
        batch = np.ascontiguousarray(inp[None, None], dtype=np.float32)
        ten = torch.from_numpy(batch).to(device)
        with torch.no_grad(), torch.cuda.amp.autocast(enabled=AMP):
            logits = model(ten)[0]
            p = torch.softmax(logits.float(), dim=1)[0].cpu().numpy()
        return p

    for z in window_starts(sz, pz):
        for y in window_starts(sy, py):
            for x in window_starts(sx, px):
                patch_np = vol_np[z:z+pz, y:y+py, x:x+px]
                # Padding is only needed when the volume is smaller than the
                # window along some axis.
                pad = [(0, max(0, pz - patch_np.shape[0])),
                       (0, max(0, py - patch_np.shape[1])),
                       (0, max(0, px - patch_np.shape[2]))]
                patch_np = np.pad(patch_np, pad, mode="edge")
                p = run_patch(patch_np)
                if tta_mirror:
                    # Input axis k corresponds to axis k+1 of the (C, z, y, x)
                    # output; flip the prediction back before averaging.
                    for axis in range(3):
                        mirrored = run_patch(np.flip(patch_np, axis))
                        p = p + np.flip(mirrored, axis + 1)
                    p = p / 4.0
                # Drop the padded part, then accumulate probabilities and counts.
                p = p[:, :min(pz, sz - z), :min(py, sy - y), :min(px, sx - x)]
                dz, dy, dx = p.shape[1:]
                prob[:, z:z+dz, y:y+dy, x:x+dx] += p
                weight[z:z+dz, y:y+dy, x:x+dx] += 1
    prob /= np.maximum(weight[None], 1e-6)
    seg = np.argmax(prob, axis=0).astype(np.uint8)
    return seg, prob

def predict_single(image_path, out_path):
    """Segment one image file and write the label map to ``out_path``.

    The image is resampled to the notebook's ``spacing``, normalised
    according to ``MODALITY``, predicted with ``sliding_window_predict``
    using the module-level ``model`` (with whatever weights it currently
    holds), and the result is resampled back onto the original image grid
    with nearest-neighbour interpolation before saving.
    """
    original = sitk.ReadImage(image_path)
    resampled = sitk_resample_to_spacing(original, np.array(spacing), is_label=False)
    voxels = sitk.GetArrayFromImage(resampled).astype(np.float32)
    if MODALITY == "MRI":
        voxels = normalize_mri_zscore(voxels)
    else:
        voxels = normalize_ct_hu(voxels)

    labels, _prob = sliding_window_predict(voxels, model, PATCH, overlap=0.5, tta_mirror=TTA_MIRROR)

    # Wrap the labels as an image on the resampled grid, then map them onto
    # the geometry of the image that was read from disk.
    label_img = sitk.GetImageFromArray(labels.astype(np.uint8))
    label_img.CopyInformation(resampled)
    restored = sitk.Resample(label_img, original, sitk.Transform(),
                             sitk.sitkNearestNeighbor, 0, sitk.sitkUInt8)
    sitk.WriteImage(restored, out_path)
    print("Saved:", out_path)


## Example prediction on one test case (edit the filename)


In [None]:

# Example (Heart): change la_001 to a real file present in your imagesTs
# The input is looked up under DATA_ROOT/MSD/<TASK>/imagesTs and the
# prediction is written into the run folder OUT_DIR.
test_img = f"{DATA_ROOT}/MSD/{TASK}/imagesTs/la_001.nii.gz"
out_pred = os.path.join(OUT_DIR, "la_001_pred.nii.gz")
# Predict only when the example file exists; otherwise print a hint.
if os.path.exists(test_img):
    predict_single(test_img, out_pred)
else:
    print("Edit 'test_img' with a real imagesTs filename.")


---
### What next (roadmap inside this notebook)
1) If VRAM allows, set `PATCH = nn_patch` printed earlier.
2) Try **base=48** in the model for more capacity.
3) Add **largest-component post‑processing** for Spleen/Heart.
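4) Keep **mirror TTA** enabled at inference (`TTA_MIRROR=True`, already the default in the config cell) and compare against `TTA_MIRROR=False` to measure its effect.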
5) Swap loss to **Dice + Focal** for heavy imbalance tasks.
6) (Stretch) Replace backbone with a **UNETR‑lite** and add a **3D MAE** pretrainer.
