# medssl_selfcontained_unet3d_v3

Self‑contained 3D medical image segmentation notebook (U‑Net style), designed to be robust and easy to run.

**Highlights**
- Clean, end‑to‑end pipeline (I/O → Dataset → Model → Loss/Metric → Train → Eval/Infer)
- **Quick Sanity Checks** to run *before* training (tiny overfit test)
- Optional AMP + EMA + simple checkpointing
- Hook to load **self‑supervised (SSL) encoder** weights if available

> Set your dataset paths in the *Config* cell, then run the *Quick Sanity Checks* section before full training.

## 1) Imports & Seeding

In [None]:
import os, sys, math, random, time, json, shutil, glob
from dataclasses import dataclass
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

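# NOTE: torchvision's functional transforms target 2D images and are not used in this notebook;
# the only augmentation is random axis flips performed with NumPy inside the dataset.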
import torchvision.transforms.functional as TF

def seed_everything(seed: int = 42):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    cuDNN is intentionally left in non-deterministic autotune mode, so runs
    are seeded but not guaranteed to be bit-for-bit reproducible on GPU.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = False  # favour speed over exact reproducibility
    cudnn.benchmark = True       # let cuDNN pick the fastest kernels


# Seed once at import/cell-execution time.
seed_everything(1337)


## 2) Config

In [None]:
@dataclass
class Config:
    """All tunable settings for data loading, training, loss and evaluation."""

    # ---- Paths (edit these for your dataset) ----
    data_root: str = "./data/"              # prefix joined onto every path read from the split lists
    train_list: str = "./splits/train.txt"  # one "<vol_path> <label_path>" pair per line
    val_list: str = "./splits/val.txt"      # same format as train_list
    out_dir: str = "./runs/medssl_unet3d_v3"  # checkpoints and the training log go here

    # ---- I/O ----
    in_channels: int = 1
    num_classes: int = 1                    # a single output channel => binary segmentation
    patch_size: Tuple[int, int, int] = (96, 96, 96)  # volumes are center-cropped/padded to this size
    norm_min: float = 0.0                   # lower percentile used for intensity scaling
    norm_max: float = 99.5                  # upper percentile used for intensity scaling

    # ---- Training ----
    epochs: int = 100
    batch_size: int = 2
    lr: float = 3e-4
    weight_decay: float = 1e-5
    grad_clip: Optional[float] = 1.0        # max gradient norm; a falsy value disables clipping
    amp: bool = True                        # mixed-precision autocast + gradient scaling
    use_ema: bool = True                    # keep an exponential moving average of the weights
    ema_decay: float = 0.999

    # ---- Loss ----
    dice_smooth: float = 1.0                # additive smoothing term in the Dice ratio
    bce_weight: float = 0.5                 # total = BCE * w + Dice * (1 - w)

    # ---- Evaluation ----
    threshold: float = 0.5                  # probability cut-off used to binarize predictions

# Global configuration instance read by the cells below. Running this cell
# also creates the output directory as a side effect.
CFG = Config()
os.makedirs(CFG.out_dir, exist_ok=True)
print(CFG)


## 3) Utilities: I/O, Normalization, Cropping

In [None]:
# nibabel is an optional dependency: it is only required for .nii / .nii.gz
# volumes. When the import fails, `nib` is set to None and load_volume()
# checks for that before attempting to read a NIfTI file.
try:
    import nibabel as nib  # For NIfTI
except Exception:
    nib = None

def robust_min_max(x: np.ndarray, qmin=0.0, qmax=99.5):
    """Rescale `x` to [0, 1] using percentile bounds instead of raw min/max.

    Args:
        x: Input intensities.
        qmin: Lower percentile (0-100) mapped to 0.
        qmax: Upper percentile (0-100) mapped to 1.

    Returns:
        Array of the same shape with values clipped to [0, 1].
    """
    lower, upper = np.percentile(x, [qmin, qmax])
    if upper <= lower:
        # Degenerate (e.g. constant) input: keep the divisor strictly positive.
        upper = lower + 1e-5
    scaled = (x - lower) / (upper - lower)
    return np.clip(scaled, 0.0, 1.0)

def load_volume(path: str) -> np.ndarray:
    """Load a volume from disk as a float32 NumPy array.

    Supported formats are chosen by file extension (case-insensitive):
    ``.nii`` / ``.nii.gz`` (read through nibabel) and ``.npy``.

    Args:
        path: Path to the volume file.

    Returns:
        The volume data converted to ``np.float32``.

    Raises:
        ImportError: If a NIfTI file is requested but nibabel is not installed.
        ValueError: If the file extension is not supported.
    """
    lowered = path.lower()
    if lowered.endswith(('.nii', '.nii.gz')):
        # Explicit raise instead of `assert`: assertions are stripped under
        # `python -O`, which would turn a missing dependency into an obscure
        # AttributeError on `None.load`.
        if nib is None:
            raise ImportError("Install nibabel to read NIfTI.")
        return nib.load(path).get_fdata().astype(np.float32)
    if lowered.endswith('.npy'):
        return np.load(path).astype(np.float32)
    raise ValueError(f"Unsupported volume ext: {path}")

def center_crop_or_pad(vol: np.ndarray, target: Tuple[int,int,int], mode: str = 'edge') -> np.ndarray:
    """Center-crop and/or pad the last three axes of `vol` to `target`.

    Axes smaller than the target are padded symmetrically (the extra voxel
    goes to the trailing side when the amount is odd); axes larger than the
    target are center-cropped. Any leading axes (e.g. a channel axis) are
    left untouched.

    Args:
        vol: Array with at least three dimensions; the last three are spatial.
        target: Desired spatial size of the last three axes.
        mode: ``np.pad`` mode used when padding. The default ``'edge'``
            replicates border values; note that for label masks this extends
            any foreground touching the border into the padded region.

    Returns:
        Array whose last three axes have exactly the `target` size.

    Raises:
        ValueError: If `vol` has fewer than three dimensions.
    """
    if vol.ndim < 3:
        raise ValueError(f"Expected an array with >= 3 dims, got shape {vol.shape}")
    tz, ty, tx = target
    lead = vol.ndim - 3  # number of non-spatial leading axes

    # Pad every spatial axis that is shorter than the target.
    pads = [max(0, t - s) for s, t in zip(vol.shape[-3:], (tz, ty, tx))]
    if any(pads):
        widths = [(0, 0)] * lead + [(p // 2, p - p // 2) for p in pads]
        vol = np.pad(vol, widths, mode=mode)

    # Center crop (a no-op on axes that already match the target).
    z, y, x = vol.shape[-3:]
    sz, sy, sx = (z - tz) // 2, (y - ty) // 2, (x - tx) // 2
    return vol[..., sz:sz + tz, sy:sy + ty, sx:sx + tx]


## 4) Dataset & Dataloader

In [None]:
class VolPairDataset(Dataset):
    """Dataset of (volume, binary label) pairs listed in a text file.

    Each non-empty line of `list_file` holds two whitespace-separated paths,
    ``<vol_path> <label_path>``, both resolved relative to ``cfg.data_root``.

    Every item is intensity-normalized with percentile scaling, its label is
    binarized with ``label > 0``, and both are center-cropped/padded to
    ``cfg.patch_size``. With ``augment=True`` each spatial axis is flipped
    independently with probability 0.5 (volume and label together).
    """

    def __init__(self, list_file: str, cfg: Config, augment: bool = False):
        """Read the pair list.

        Raises:
            ValueError: If a non-empty line does not contain exactly two paths.
        """
        self.items = []
        with open(list_file, 'r') as f:
            for lineno, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    continue
                parts = line.split()
                if len(parts) != 2:
                    # Same exception type as a failed tuple unpack, but with
                    # enough context to locate the offending entry.
                    raise ValueError(
                        f"{list_file}:{lineno}: expected '<vol_path> <label_path>', got {line!r}"
                    )
                self.items.append((parts[0], parts[1]))
        self.cfg = cfg
        self.augment = augment

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        """Return ``(vol, lab)`` float32 tensors, each shaped (1, D, H, W)."""
        vpath, lpath = self.items[idx]
        vol = load_volume(os.path.join(self.cfg.data_root, vpath))
        lab = load_volume(os.path.join(self.cfg.data_root, lpath))

        vol = robust_min_max(vol, qmin=self.cfg.norm_min, qmax=self.cfg.norm_max)
        lab = (lab > 0).astype(np.float32)  # binary mask

        vol = center_crop_or_pad(vol, self.cfg.patch_size)
        lab = center_crop_or_pad(lab, self.cfg.patch_size)

        # (C, D, H, W)
        vol = vol[None, ...]
        lab = lab[None, ...]

        if self.augment:
            # One independent coin toss per spatial axis, applied to both arrays.
            for axis in (1, 2, 3):
                if random.random() < 0.5:
                    vol = np.flip(vol, axis=axis)
                    lab = np.flip(lab, axis=axis)

        # Force contiguous float32: flipped views have negative strides (which
        # torch.from_numpy rejects), and the normalized volume may come back as
        # float64 depending on NumPy's promotion rules, which would not match
        # the float32 model weights.
        vol = torch.from_numpy(np.ascontiguousarray(vol, dtype=np.float32))
        lab = torch.from_numpy(np.ascontiguousarray(lab, dtype=np.float32))
        return vol, lab

def make_loaders(cfg: Config, num_workers: int = 4, val_num_workers: int = 2):
    """Build the training and validation DataLoaders.

    Args:
        cfg: Configuration providing the split lists and batch size.
        num_workers: Worker processes for the training loader. Pass 0 to load
            in the main process (useful when debugging).
        val_num_workers: Worker processes for the validation loader.

    Returns:
        ``(train_loader, val_loader)``. The training loader shuffles and
        augments; the validation loader uses batch size 1 with no augmentation.
    """
    # Pinned host memory only helps host-to-GPU copies; requesting it without
    # CUDA just produces a warning.
    pin = torch.cuda.is_available()
    train_ds = VolPairDataset(cfg.train_list, cfg, augment=True)
    val_ds = VolPairDataset(cfg.val_list, cfg, augment=False)
    train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=pin)
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=False,
                            num_workers=val_num_workers, pin_memory=pin)
    return train_loader, val_loader


## 5) Model: 3D U‑Net (lightweight)

In [None]:
class ConvBlock3D(nn.Module):
    """Two stacked stages of Conv3d(3x3x3, padding 1) -> InstanceNorm3d -> LeakyReLU(0.1).

    The first stage maps `in_ch` to `out_ch`; the second keeps `out_ch`.
    Spatial size is preserved.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # Built in order so that layer indices (and parameter names) stay 0..5.
        for stage_in in (in_ch, out_ch):
            stages += [
                nn.Conv3d(stage_in, out_ch, 3, padding=1),
                nn.InstanceNorm3d(out_ch),
                nn.LeakyReLU(0.1, inplace=True),
            ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)

class UNet3D(nn.Module):
    """Lightweight 3-level 3D U-Net that returns raw logits.

    The encoder halves the spatial size three times with max pooling while the
    channel width doubles (base -> 8*base at the bottleneck). The decoder
    mirrors it with stride-2 transposed convolutions and concatenated skip
    connections. A final 1x1x1 convolution maps to `n_classes` channels.
    """

    def __init__(self, in_ch=1, n_classes=1, base=16):
        super().__init__()
        w1, w2, w3, w4 = base, base * 2, base * 4, base * 8

        # Encoder
        self.enc1 = ConvBlock3D(in_ch, w1)
        self.pool1 = nn.MaxPool3d(2)
        self.enc2 = ConvBlock3D(w1, w2)
        self.pool2 = nn.MaxPool3d(2)
        self.enc3 = ConvBlock3D(w2, w3)
        self.pool3 = nn.MaxPool3d(2)

        self.bottleneck = ConvBlock3D(w3, w4)

        # Decoder: each dec block sees the upsampled features concatenated
        # with the matching skip, hence twice the upsampled width as input.
        self.up3 = nn.ConvTranspose3d(w4, w3, 2, stride=2)
        self.dec3 = ConvBlock3D(w4, w3)
        self.up2 = nn.ConvTranspose3d(w3, w2, 2, stride=2)
        self.dec2 = ConvBlock3D(w3, w2)
        self.up1 = nn.ConvTranspose3d(w2, w1, 2, stride=2)
        self.dec1 = ConvBlock3D(w2, w1)

        self.outc = nn.Conv3d(w1, n_classes, 1)

    def forward(self, x):
        skip1 = self.enc1(x)
        skip2 = self.enc2(self.pool1(skip1))
        skip3 = self.enc3(self.pool2(skip2))
        feats = self.bottleneck(self.pool3(skip3))

        decoder_path = (
            (self.up3, self.dec3, skip3),
            (self.up2, self.dec2, skip2),
            (self.up1, self.dec1, skip1),
        )
        for upsample, refine, skip in decoder_path:
            feats = refine(torch.cat([upsample(feats), skip], dim=1))

        return self.outc(feats)


## 6) Losses & Metrics (Dice + BCE)

In [None]:
class DiceLoss(nn.Module):
    """Soft Dice loss computed from logits.

    Sigmoid probabilities are compared with the targets per sample and per
    channel over axes (2, 3, 4); the loss is ``1 - mean(dice)``. `smooth` is
    added to both numerator and denominator.
    """

    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits, targets):
        spatial = (2, 3, 4)
        p = torch.sigmoid(logits)
        overlap = torch.sum(p * targets, dim=spatial)
        total = torch.sum(p, dim=spatial) + torch.sum(targets, dim=spatial)
        score = (2 * overlap + self.smooth) / (total + self.smooth)
        return 1 - score.mean()

class BCEDiceLoss(nn.Module):
    """Weighted sum of BCE-with-logits and soft Dice loss.

    ``loss = bce_weight * BCE + (1 - bce_weight) * Dice``
    """

    def __init__(self, bce_weight=0.5, smooth=1.0):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()
        self.dice = DiceLoss(smooth)
        self.w = bce_weight

    def forward(self, logits, targets):
        bce_term = self.bce(logits, targets)
        dice_term = self.dice(logits, targets)
        return self.w * bce_term + (1 - self.w) * dice_term

def dice_coefficient(logits, targets, thr=0.5):
    """Mean hard Dice score of thresholded predictions against `targets`.

    Args:
        logits: Raw model outputs; axes (2, 3, 4) are treated as spatial.
        targets: Binary ground truth with the same shape as `logits`.
        thr: Probability threshold applied after the sigmoid.

    Returns:
        Python float: Dice averaged over the remaining (batch, channel) axes.
        A sample where both prediction and target are empty scores 1.0, since
        the prediction is exactly right; previously it scored 0.
    """
    with torch.no_grad():
        preds = (torch.sigmoid(logits) > thr).float()
        dims = (2, 3, 4)
        inter = (preds * targets).sum(dim=dims)
        denom = preds.sum(dim=dims) + targets.sum(dim=dims)
        dice = torch.where(
            denom > 0,
            (2 * inter) / (denom + 1e-8),
            torch.ones_like(denom),  # empty vs. empty counts as perfect agreement
        )
        return dice.mean().item()


## 7) (Optional) Self‑Supervised Pretraining Hook

In [None]:
def load_ssl_weights(model: nn.Module, ckpt_path: Optional[str] = None):
    """Best-effort load of pretrained (e.g. self-supervised) weights into `model`.

    The checkpoint may be a bare state dict or a dict wrapping one under a
    ``'model'`` or ``'state_dict'`` key. Weights are loaded with
    ``strict=False``, so keys absent on either side are reported but do not
    raise.

    Args:
        model: Module to receive the weights (modified in place).
        ckpt_path: Path to the checkpoint. If empty/None or the file does not
            exist, nothing is loaded and a message is printed.
    """
    if not ckpt_path:
        print("[SSL] No checkpoint provided — skipping.")
        return
    if not os.path.exists(ckpt_path):
        # Distinguish a mistyped path from "nothing requested".
        print(f"[SSL] Checkpoint not found: {ckpt_path} — skipping.")
        return
    print(f"[SSL] Loading encoder weights from {ckpt_path}")
    # NOTE(security): torch.load relies on pickle; only load checkpoints from
    # sources you trust.
    state = torch.load(ckpt_path, map_location='cpu')
    # Unwrap full training checkpoints; without this, strict=False would
    # silently match no keys and leave the model untouched.
    for wrapper_key in ('model', 'state_dict'):
        if isinstance(state, dict) and isinstance(state.get(wrapper_key), dict):
            state = state[wrapper_key]
            break
    missing, unexpected = model.load_state_dict(state, strict=False)
    print("[SSL] Missing:", len(missing), "Unexpected:", len(unexpected))


## 8) Trainer with AMP + EMA + Checkpointing

In [None]:
class ModelEMA:
    """Exponential moving average (EMA) of a model's weights.

    The shadow model is a deep copy of `model`, so it has the same
    architecture and lives on the same device. (Building a fresh
    ``UNet3D`` on CPU made ``update()`` mix CPU and GPU tensors whenever the
    live model was on GPU, and tied this class to one architecture.)

    Attributes:
        model: The live model being trained.
        ema: The averaged copy; frozen and kept in eval mode.
        decay: Weight of the previous average in each update.
    """

    def __init__(self, model, decay=0.999):
        import copy  # local import keeps this block self-contained

        self.model = model
        self.ema = copy.deepcopy(model)
        self.ema.eval()
        self.decay = decay
        for p in self.ema.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self):
        """Blend the live weights in: ``ema = decay * ema + (1 - decay) * model``."""
        msd, esd = self.model.state_dict(), self.ema.state_dict()
        for k, src in msd.items():
            dst = esd[k]
            if dst.dtype.is_floating_point:
                dst.mul_(self.decay).add_(src, alpha=1 - self.decay)
            else:
                # Integer buffers (e.g. BatchNorm's num_batches_tracked)
                # cannot be scaled by a float in place; mirror them instead.
                dst.copy_(src)

def save_ckpt(model, optimizer, epoch, best_dice, path):
    """Write a training checkpoint to `path` with torch.save.

    The saved dict has the keys 'epoch', 'model' (state dict), 'optim'
    (optimizer state dict), 'best_dice' and 'cfg' (the global CFG's fields).
    """
    payload = dict(
        epoch=epoch,
        model=model.state_dict(),
        optim=optimizer.state_dict(),
        best_dice=best_dice,
        cfg=vars(CFG),
    )
    torch.save(payload, path)

def validate(model, loader, device):
    """Return the mean per-batch Dice of `model` over `loader`.

    The model is switched to eval mode and evaluated without gradients.
    Predictions are thresholded at ``CFG.threshold``. An empty loader
    yields 0.0.
    """
    model.eval()
    scores = []
    with torch.no_grad():
        for batch_vol, batch_lab in loader:
            batch_vol = batch_vol.to(device, non_blocking=True)
            batch_lab = batch_lab.to(device, non_blocking=True)
            batch_logits = model(batch_vol)
            scores.append(dice_coefficient(batch_logits, batch_lab, CFG.threshold))
    if not scores:
        return 0.0
    return float(np.mean(scores))

def _train_one_epoch(model, loader, optimizer, loss_fn, scaler, ema, device, use_amp):
    """Run one optimisation pass over `loader`; return the sample-weighted mean loss."""
    model.train()
    loss_sum, seen = 0.0, 0
    for vol, lab in loader:
        vol = vol.to(device, non_blocking=True)
        lab = lab.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=use_amp):
            logits = model(vol)
            loss = loss_fn(logits, lab)
        scaler.scale(loss).backward()
        if CFG.grad_clip:
            # Gradients must be unscaled before their norm is measured.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.grad_clip)
        scaler.step(optimizer)
        scaler.update()
        if ema is not None:
            ema.update()
        loss_sum += loss.item() * vol.size(0)
        seen += vol.size(0)
    return loss_sum / max(1, seen)


def train():
    """Train a UNet3D using the global CFG, validating after every epoch.

    Side effects: prints progress, writes ``train_log.txt`` (CSV) into
    ``CFG.out_dir`` and saves a checkpoint every 5th epoch and whenever the
    validation Dice matches or beats the best so far.

    When EMA is enabled, validation and checkpoints use the EMA weights. The
    checkpoint's optimizer state still belongs to the raw model, so such
    checkpoints are meant for inference rather than exact resumption.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device:", device)
    train_loader, val_loader = make_loaders(CFG)

    model = UNet3D(CFG.in_channels, CFG.num_classes).to(device)
    load_ssl_weights(model, ckpt_path=None)  # set SSL path if available

    optimizer = torch.optim.AdamW(model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay)
    loss_fn = BCEDiceLoss(bce_weight=CFG.bce_weight, smooth=CFG.dice_smooth)

    # CUDA autocast / GradScaler only make sense on a CUDA device; requesting
    # them on CPU just produces warnings.
    use_amp = bool(CFG.amp) and device.type == 'cuda'
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    ema = ModelEMA(model, decay=CFG.ema_decay) if CFG.use_ema else None
    if ema is not None:
        # The shadow weights must sit on the same device as the live model,
        # otherwise ema.update() and validation mix CPU and GPU tensors.
        ema.ema.to(device)

    best_dice = -1.0
    log_path = os.path.join(CFG.out_dir, 'train_log.txt')
    with open(log_path, 'w') as f:
        f.write('epoch,loss,valDice,best\n')

    for epoch in range(1, CFG.epochs+1):
        t0 = time.time()
        avg_loss = _train_one_epoch(model, train_loader, optimizer, loss_fn,
                                    scaler, ema, device, use_amp)

        eval_model = ema.ema if ema is not None else model
        val_dice = validate(eval_model, val_loader, device)
        best_dice = max(best_dice, val_dice)

        line = f"Epoch {epoch:03d} | loss {avg_loss:.4f} | valDice {val_dice:.4f} | best {best_dice:.4f}"
        print(line, f" | {time.time()-t0:.1f}s")
        with open(log_path, 'a') as f:
            f.write(f"{epoch},{avg_loss:.6f},{val_dice:.6f},{best_dice:.6f}\n")

        # best_dice was already updated above, so this also fires on a new best.
        if (epoch % 5 == 0) or (val_dice >= best_dice):
            save_ckpt(eval_model, optimizer, epoch, best_dice, os.path.join(CFG.out_dir, f"ckpt_e{epoch:03d}.pth"))

    print("Training finished. Best Dice:", best_dice)


## 9) **Quick Sanity Checks** (run *before* training)

In [None]:
@torch.no_grad()
def peek_batch(cfg=CFG):
    """Print shapes, intensity range and label values of one training batch."""
    train_loader = make_loaders(cfg)[0]
    vol, lab = next(iter(train_loader))
    lo, hi = float(vol.min()), float(vol.max())
    print("vol shape:", vol.shape, "lab shape:", lab.shape, "min/max:", lo, hi)
    print("unique labels:", torch.unique(lab))

def quick_overfit_steps(steps=20):
    """Repeatedly fit a fresh UNet3D to a single training batch.

    Prints the loss and Dice after each of `steps` optimizer steps so you can
    confirm the pipeline is able to memorize one batch before a full run.
    The Dice is computed from the logits produced before that step's update.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    train_loader = make_loaders(CFG)[0]
    batch_vol, batch_lab = next(iter(train_loader))
    batch_vol = batch_vol.to(device)
    batch_lab = batch_lab.to(device)

    net = UNet3D(CFG.in_channels, CFG.num_classes).to(device)
    criterion = BCEDiceLoss(bce_weight=CFG.bce_weight, smooth=CFG.dice_smooth)
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
    grad_scaler = torch.cuda.amp.GradScaler(enabled=CFG.amp)

    for done in range(steps):
        i = done + 1
        optimizer.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=CFG.amp):
            logits = net(batch_vol)
            loss = criterion(logits, batch_lab)
        grad_scaler.scale(loss).backward()
        grad_scaler.step(optimizer)
        grad_scaler.update()
        d = dice_coefficient(logits, batch_lab, thr=CFG.threshold)
        print(f"step {i:02d} loss {loss.item():.4f} dice {d:.4f}")


## 10) Entry Point

In [None]:
if __name__ == "__main__":
    # Nothing runs by default; uncomment the calls below in order.
    # 1) Run sanity checks first:
    # peek_batch()
    # quick_overfit_steps(steps=30)

    # 2) When checks look good, train:
    # train()
    pass


## 11) Evaluation & Inference Helpers

In [None]:
@torch.no_grad()
def predict_volume(model_path: str, vol_path: str, out_path: str = "./pred.npy"):
    """Predict a foreground probability map for one volume and save it as .npy.

    The checkpoint may be a bare state dict or a dict holding one under
    'model'. The volume is normalized and center-cropped/padded to
    ``CFG.patch_size`` before inference, so the saved array has that spatial
    size rather than the size of the input volume.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    net = UNet3D(CFG.in_channels, CFG.num_classes).to(device)
    ckpt = torch.load(model_path, map_location=device)
    weights = ckpt['model'] if 'model' in ckpt else ckpt
    net.load_state_dict(weights)
    net.eval()

    arr = load_volume(vol_path)
    arr = robust_min_max(arr, qmin=CFG.norm_min, qmax=CFG.norm_max)
    arr = center_crop_or_pad(arr, CFG.patch_size)

    # Add batch and channel axes: (1, 1, D, H, W).
    batch = torch.from_numpy(arr[None, None]).float().to(device)
    probs = torch.sigmoid(net(batch)).cpu().numpy()[0, 0]

    np.save(out_path, probs)
    print("Saved:", out_path)


## 12) Troubleshooting Low Validation Dice
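- **Data/Label alignment**: confirm that each volume and its label share the same voxel grid (shape, spacing, orientation) after any resampling. Misalignment typically caps Dice at around 0.2–0.3.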
- **Binarization**: ensure labels are 0/1; adjust threshold for class imbalance.
- **Normalization**: robust percentile scaling (0–99.5). Try per‑case z‑score if needed.
- **Patch size**: increase to (128,128,128) if memory allows.
- **Loss mix**: try `bce_weight=0.3` or Focal+Dice for heavy imbalance.
- **Augmentations**: start minimal; over‑aggressive aug can hurt.
- **Sanity overfit**: the quick overfit test should reach ≳0.9 Dice on 1–2 cases; if it does not, fix the pipeline before starting a full training run.
- **Eval**: tune threshold on validation set.
- **SSL init**: load encoder weights if available for stability.
