# medssl_selfcontained_unet3d_v3 — **SE-Residual U-Net++ (Deep Supervision)**

Single-file notebook with:
- SE-Residual U-Net++ (3D) **with deep supervision** (default) and a lightweight 3D U-Net baseline
- **Normalization toggle** (z-score / robust percentiles)
- **Cosine LR with warmup**, **gradient accumulation**, **AMP**, optional **EMA**
- **Sanity checks** and **validation threshold sweep**

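Edit the **Config** cell (and the Task02_Heart overrides cell after it; normalization and scheduler options are set in their own cells, sections 3 and 11), run the sanity checks, then train. 🍀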

## 1) Imports & Seeding

In [1]:
import os, sys, math, random, time, json, shutil, glob
from dataclasses import dataclass
from typing import Optional, Tuple, List, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

try:
    import nibabel as nib
except Exception:
    nib = None

def seed_everything(seed: int = 42, *, deterministic: bool = False):
    """Seed Python, NumPy and PyTorch RNGs and configure cuDNN.

    Args:
        seed: Seed applied to ``random``, ``numpy`` and ``torch`` (CPU and all
            CUDA devices).
        deterministic: When False (the default, matching the previous
            hard-coded behaviour) cuDNN autotuning is enabled for speed, so GPU
            runs are not bit-reproducible. When True, cuDNN is forced into
            deterministic mode and autotuning is disabled.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # benchmark=True lets cuDNN pick the fastest (possibly non-deterministic)
    # algorithm per input shape, so the two flags are set in opposition.
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = not deterministic

seed_everything(1337)


## 2) Config

In [None]:
@dataclass
class Config:
    """Experiment configuration for this notebook.

    Every option the pipeline reads is declared here, so a fresh ``Config()``
    (not only the module-level ``CFG``) carries all of them. Later cells may
    still overwrite fields on ``CFG`` by attribute assignment.
    """
    # Paths (EDIT THESE). List files hold one "<volume> <label>" pair per
    # line; both paths are joined onto data_root.
    data_root: str = "./data/"
    train_list: str = "./splits/train.txt"
    val_list: str   = "./splits/val.txt"
    out_dir: str = "./runs/medssl_unet3d_v3_sepp"

    # Model & I/O
    in_channels: int = 1
    num_classes: int = 1
    patch_size: Tuple[int,int,int] = (96, 96, 96)
    base_channels: int = 24
    model_name: str = "seresunetpp_ds"  # 'seresunetpp_ds' | 'unet3d_light'
    se_reduction: int = 8
    dropout: float = 0.0

    # Training
    epochs: int = 150
    batch_size: int = 1
    lr: float = 3e-4
    weight_decay: float = 1e-5
    grad_clip: Optional[float] = 1.0
    amp: bool = True
    use_ema: bool = False
    ema_decay: float = 0.999

    # Loss / Deep supervision
    dice_smooth: float = 1.0
    bce_weight: float = 0.4
    ds_weights: Tuple[float, ...] = (0.2, 0.3, 0.5)  # x01, x02, x03

    # Eval
    threshold: float = 0.5

    # Normalization (the normalization cell also assigns these on CFG)
    norm_mode: str = "zscore"         # 'zscore' | 'robust'
    zscore_use_nonzero: bool = True   # z-score stats over non-zero voxels only
    zscore_clip: float = 3.0          # clip z-scores to +/- this; 0 disables
    norm_min: float = 0.0             # lower percentile for 'robust'
    norm_max: float = 99.5            # upper percentile for 'robust'

    # Scheduler / optimisation extras (the scheduler cell also assigns these on CFG)
    warmup_epochs: int = 5            # linear warm-up epochs; 0 disables
    min_lr: float = 1e-6              # cosine floor (eta_min)
    accum_steps: int = 1              # micro-batches per optimizer step

CFG = Config()
os.makedirs(CFG.out_dir, exist_ok=True)
print(CFG)


In [None]:
# === Task02_Heart (fold-0) overrides ===
CFG.data_root  = "/path/to/Task02_Heart"
CFG.train_list = "/path/to/splits/heart_fold0_train.txt"
CFG.val_list   = "/path/to/splits/heart_fold0_val.txt"

# Model choice & capacity
CFG.model_name    = "seresunetpp_ds"   # or "unet3d_light" for the baseline
CFG.base_channels = 24                 # try 32 if VRAM allows

# Match nnU-Net plan
CFG.patch_size = (80, 192, 160)
CFG.batch_size = 1                     # keep memory safe; emulate batch=2 via accumulation

# Loss / Deep supervision
CFG.bce_weight = 0.4                   # try 0.3 later
CFG.ds_weights = (0.2, 0.3, 0.5)

# Training schedule (kept from defaults, repeated here for clarity)
CFG.epochs        = 150
CFG.lr            = 3e-4
CFG.weight_decay  = 1e-5
CFG.grad_clip     = 1.0
CFG.amp           = True
CFG.use_ema       = False

# Eval
CFG.threshold     = 0.5
print("Overrides applied:", CFG)


## 3) Utilities & Normalization Toggle

In [None]:
def robust_min_max(x: np.ndarray, qmin=0.0, qmax=99.5):
    """Rescale *x* into [0, 1] using its ``qmin`` / ``qmax`` percentiles as bounds.

    Values outside the percentile window are clipped to 0 or 1. If the window
    is degenerate (upper bound not above the lower one), the upper bound is
    nudged to ``lower + 1e-5`` to avoid division by zero.
    """
    low, high = np.percentile(x, [qmin, qmax])
    if high <= low:
        high = low + 1e-5
    span = high - low
    scaled = (x - low) / span
    return np.clip(scaled, 0.0, 1.0)

def load_volume(path: str) -> np.ndarray:
    """Load a volume from NIfTI (``.nii`` / ``.nii.gz``) or NumPy (``.npy``) as float32.

    The extension check is case-insensitive.

    Raises:
        ImportError: a NIfTI file is requested but nibabel is not installed.
        ValueError: the file extension is not supported.
    """
    lower = path.lower()
    if lower.endswith(('.nii', '.nii.gz')):
        # Explicit raise rather than assert: asserts vanish under ``python -O``.
        if nib is None:
            raise ImportError("Install nibabel to read NIfTI.")
        # Decode straight to float32 instead of float64-then-cast to halve peak memory.
        return nib.load(path).get_fdata(dtype=np.float32)
    if lower.endswith('.npy'):
        return np.load(path).astype(np.float32)
    raise ValueError(f"Unsupported volume ext: {path}")

def center_crop_or_pad(vol: np.ndarray, target: Tuple[int,int,int], mode: str = 'edge') -> np.ndarray:
    """Center-pad and/or center-crop the LAST three axes of *vol* to ``target``.

    Any leading axes (e.g. channels) are left untouched. Axes smaller than the
    target are padded with ``np.pad(..., mode=mode)``; when the padding amount
    is odd, the extra voxel goes after. Axes larger than the target are
    cropped around the center; when the excess is odd, the extra voxel is
    removed from the end.

    Args:
        vol: array with at least 3 dimensions.
        target: desired sizes of the last three axes.
        mode: ``np.pad`` mode; the default 'edge' matches the previous behaviour.
            'constant' (zeros) is usually preferable for label masks.

    Raises:
        ValueError: *vol* has fewer than 3 dims or *target* is not length 3.
    """
    if vol.ndim < 3:
        raise ValueError(f"expected at least 3 dims, got shape {vol.shape}")
    tz, ty, tx = target
    z, y, x = vol.shape[-3:]
    pad_z = max(0, tz - z); pad_y = max(0, ty - y); pad_x = max(0, tx - x)
    if pad_z or pad_y or pad_x:
        lead = [(0, 0)] * (vol.ndim - 3)
        vol = np.pad(vol, lead + [(pad_z//2, pad_z - pad_z//2),
                                  (pad_y//2, pad_y - pad_y//2),
                                  (pad_x//2, pad_x - pad_x//2)], mode=mode)
        z, y, x = vol.shape[-3:]
    sz = (z - tz)//2; sy = (y - ty)//2; sx = (x - tx)//2
    # Ellipsis keeps leading axes intact and slices only the spatial ones.
    return vol[..., sz:sz+tz, sy:sy+ty, sx:sx+tx]

# === Normalization options ===
# Assigned directly on the shared CFG instance; normalize_volume reads each
# one via getattr with a fallback default.
CFG.norm_mode = 'zscore'            # 'zscore' | 'robust'
CFG.zscore_use_nonzero = True       # z-score stats from non-zero voxels only
CFG.zscore_clip = 3.0               # clip z-scores to +/- this value; 0 disables
CFG.norm_min = 0.0                  # lower percentile for 'robust'
CFG.norm_max = 99.5                 # upper percentile for 'robust'

def normalize_volume(vol: np.ndarray, cfg=CFG) -> np.ndarray:
    """Normalize an intensity volume according to ``cfg.norm_mode``; returns float32.

    * 'robust': percentile rescale to [0, 1] between ``cfg.norm_min`` and
      ``cfg.norm_max`` (see robust_min_max).
    * anything else: z-score using the mean/std of non-zero voxels (when
      ``cfg.zscore_use_nonzero``) or of all voxels, then optionally clipped
      to +/- ``cfg.zscore_clip``. Falls back to whole-volume stats when the
      mask selects nothing.
    """
    if getattr(cfg, 'norm_mode', 'zscore') == 'robust':
        lo_pct = getattr(cfg, 'norm_min', 0.0)
        hi_pct = getattr(cfg, 'norm_max', 99.5)
        return robust_min_max(vol, qmin=lo_pct, qmax=hi_pct).astype(np.float32)

    # z-score (nnU-Net-like)
    if getattr(cfg, 'zscore_use_nonzero', True):
        mask = vol != 0
    else:
        mask = np.ones_like(vol, dtype=bool)
    source = vol[mask] if mask.any() else vol
    mean = source.mean()
    std = source.std()
    std = std + 1e-8  # guard against division by zero on flat volumes
    out = (vol - mean) / std

    limit = getattr(cfg, 'zscore_clip', 0.0) or 0.0
    if limit and limit > 0:
        out = np.clip(out, -limit, limit)
    return out.astype(np.float32)

print('Normalization mode:', CFG.norm_mode)


## 4) Dataset & Dataloader

In [None]:
class VolPairDataset(Dataset):
    """Paired (volume, label) dataset driven by a whitespace-separated list file.

    Each non-empty line of *list_file* holds ``<volume_path> <label_path>``;
    both are joined onto ``cfg.data_root``. Lines starting with ``#`` are
    ignored. Each item is normalized (normalize_volume), binarized (label > 0),
    center-cropped/padded to ``cfg.patch_size``, given a leading channel axis
    and returned as float32 tensors of shape ``(1, *cfg.patch_size)``.
    With ``augment=True``, independent 50% flips are applied along each
    spatial axis to both arrays.
    """
    def __init__(self, list_file: str, cfg: "Config", augment: bool = False):
        self.items = []
        with open(list_file, 'r', encoding='utf-8') as f:
            for lineno, raw in enumerate(f, 1):
                line = raw.strip()
                if not line or line.startswith('#'):
                    continue
                parts = line.split()
                if len(parts) != 2:
                    # Clear, located error instead of an opaque unpacking failure.
                    raise ValueError(
                        f"{list_file}:{lineno}: expected '<volume> <label>', got {line!r}")
                self.items.append((parts[0], parts[1]))
        self.cfg = cfg
        self.augment = augment

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        vpath, lpath = self.items[idx]
        vol = load_volume(os.path.join(self.cfg.data_root, vpath))
        lab = load_volume(os.path.join(self.cfg.data_root, lpath))

        vol = normalize_volume(vol, self.cfg)
        lab = (lab > 0).astype(np.float32)

        # NOTE(review): labels are padded with the same default ('edge') mode as
        # images, which replicates border label values; zero padding is usually
        # preferable for masks.
        vol = center_crop_or_pad(vol, self.cfg.patch_size)
        lab = center_crop_or_pad(lab, self.cfg.patch_size)

        vol = vol[None, ...]
        lab = lab[None, ...]

        if self.augment:
            # Axis 0 is the channel axis; flip each spatial axis with p=0.5.
            for axis in (1, 2, 3):
                if random.random() < 0.5:
                    vol = np.flip(vol, axis=axis)
                    lab = np.flip(lab, axis=axis)

        # Copies make the (possibly negatively strided) views contiguous for torch.
        vol = torch.from_numpy(vol.copy())
        lab = torch.from_numpy(lab.copy())
        return vol, lab

def make_loaders(cfg: Config):
    """Build the (train, val) DataLoaders.

    Train: shuffled, flip-augmented, ``cfg.batch_size`` per batch.
    Val: fixed order, no augmentation, batch size 1.
    Worker counts may be overridden with optional ``cfg.num_workers`` /
    ``cfg.val_num_workers`` (defaults 4 and 2, as before).
    """
    train_ds = VolPairDataset(cfg.train_list, cfg, augment=True)
    val_ds   = VolPairDataset(cfg.val_list,   cfg, augment=False)
    # Pinned host memory only speeds up host->GPU copies; skip it on CPU-only runs.
    pin = torch.cuda.is_available()
    train_workers = int(getattr(cfg, 'num_workers', 4))
    val_workers = int(getattr(cfg, 'val_num_workers', 2))
    train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True,
                              num_workers=train_workers, pin_memory=pin)
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=False,
                            num_workers=val_workers, pin_memory=pin)
    return train_loader, val_loader


## 5) Model Blocks: Residual + SE (3D)

In [None]:
class SE3D(nn.Module):
    """Squeeze-and-Excitation channel gate for 3D feature maps (e.g. N, C, D, H, W).

    Each channel is global-average-pooled, passed through a 1x1x1 bottleneck
    (C -> max(1, C // reduction) -> C) with SiLU in between, and the input is
    rescaled channel-wise by the sigmoid of the result.
    """
    def __init__(self, channels: int, reduction: int = 8):
        super().__init__()
        hidden = max(1, channels // reduction)
        # Submodule names and creation order are kept: they define the
        # state_dict keys and the parameter-initialisation RNG sequence.
        self.pool = nn.AdaptiveAvgPool3d(1)
        self.fc1 = nn.Conv3d(channels, hidden, kernel_size=1)
        self.act = nn.SiLU(inplace=True)
        self.fc2 = nn.Conv3d(hidden, channels, kernel_size=1)
        self.gate = nn.Sigmoid()

    def forward(self, x):
        squeezed = self.pool(x)
        excitation = self.gate(self.fc2(self.act(self.fc1(squeezed))))
        return x * excitation

class ConvResSE3D(nn.Module):
    """Residual 3D conv block with a squeeze-and-excitation gate.

    Path: conv3x3x3 -> InstanceNorm -> LeakyReLU(0.1) -> conv3x3x3 -> InstanceNorm,
    summed with a shortcut (identity, or a 1x1x1 conv when channel counts
    differ), then SE gate -> LeakyReLU -> Dropout3d (identity when dropout=0).
    All convs use padding=1 or kernel 1, so spatial size is preserved.
    """
    def __init__(self, in_ch: int, out_ch: int, se_reduction: int = 8, dropout: float = 0.0):
        super().__init__()
        # Attribute names and creation order are kept: they define state_dict
        # keys and the parameter-initialisation RNG sequence.
        self.conv1 = nn.Conv3d(in_ch, out_ch, 3, padding=1)
        self.norm1 = nn.InstanceNorm3d(out_ch)
        self.conv2 = nn.Conv3d(out_ch, out_ch, 3, padding=1)
        self.norm2 = nn.InstanceNorm3d(out_ch)
        self.act = nn.LeakyReLU(0.1, inplace=True)
        self.se = SE3D(out_ch, se_reduction)
        use_dropout = dropout > 0
        self.drop = nn.Dropout3d(p=dropout) if use_dropout else nn.Identity()
        needs_projection = in_ch != out_ch
        self.skip = nn.Conv3d(in_ch, out_ch, 1) if needs_projection else nn.Identity()

    def forward(self, x):
        shortcut = self.skip(x)
        out = self.act(self.norm1(self.conv1(x)))
        out = self.norm2(self.conv2(out)) + shortcut
        # The SE gate is applied after the residual sum, so it rescales the
        # shortcut contribution as well.
        return self.drop(self.act(self.se(out)))


## 6) Lightweight 3D U-Net (baseline option)

In [None]:
class ConvBlock3D(nn.Module):
    """Two stages of Conv3d(3x3x3, padding=1) -> InstanceNorm3d -> LeakyReLU(0.1).

    The first stage maps ``in_ch`` -> ``out_ch``, the second ``out_ch`` -> ``out_ch``;
    spatial size is preserved.
    """
    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        for stage_in in (in_ch, out_ch):
            layers += [
                nn.Conv3d(stage_in, out_ch, 3, padding=1),
                nn.InstanceNorm3d(out_ch),
                nn.LeakyReLU(0.1, inplace=True),
            ]
        # A single Sequential named `conv` keeps state_dict keys (conv.0.*, conv.3.*).
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)

class UNet3D_Light(nn.Module):
    """Plain three-level 3D U-Net used as the lightweight baseline.

    Encoder widths are base, 2*base and 4*base with an 8*base bottleneck. Each
    decoder stage upsamples with a stride-2 transposed conv, concatenates the
    matching encoder features and refines them with ConvBlock3D. Returns a
    single logits tensor with ``n_classes`` channels at input resolution.
    Spatial input dims must be divisible by 8; otherwise the skip
    concatenations will not line up.
    """
    def __init__(self, in_ch=1, n_classes=1, base=16):
        super().__init__()
        c1, c2, c3, c4 = base, base * 2, base * 4, base * 8
        # Registration order fixes state_dict key order and init RNG order.
        self.enc1 = ConvBlock3D(in_ch, c1)
        self.pool1 = nn.MaxPool3d(2)
        self.enc2 = ConvBlock3D(c1, c2)
        self.pool2 = nn.MaxPool3d(2)
        self.enc3 = ConvBlock3D(c2, c3)
        self.pool3 = nn.MaxPool3d(2)
        self.bottleneck = ConvBlock3D(c3, c4)
        self.up3 = nn.ConvTranspose3d(c4, c3, 2, stride=2)
        self.dec3 = ConvBlock3D(c4, c3)   # upsampled (c3) + skip (c3)
        self.up2 = nn.ConvTranspose3d(c3, c2, 2, stride=2)
        self.dec2 = ConvBlock3D(c3, c2)   # upsampled (c2) + skip (c2)
        self.up1 = nn.ConvTranspose3d(c2, c1, 2, stride=2)
        self.dec1 = ConvBlock3D(c2, c1)   # upsampled (c1) + skip (c1)
        self.outc = nn.Conv3d(c1, n_classes, 1)

    def forward(self, x):
        # Encoder: keep each level's features for the skip connections.
        skip1 = self.enc1(x)
        skip2 = self.enc2(self.pool1(skip1))
        skip3 = self.enc3(self.pool2(skip2))
        feat = self.bottleneck(self.pool3(skip3))
        # Decoder: upsample, concatenate [upsampled, skip] on channels, refine.
        stages = (
            (self.up3, self.dec3, skip3),
            (self.up2, self.dec2, skip2),
            (self.up1, self.dec1, skip1),
        )
        for up, dec, skip in stages:
            feat = dec(torch.cat([up(feat), skip], dim=1))
        return self.outc(feat)


## 7) **SE-Residual U-Net++ (3D) with Deep Supervision**

In [None]:
class UNetPP3D_SERes(nn.Module):
    """Four-level 3D U-Net++ built from SE-residual blocks, with deep supervision.

    Node ``xij`` is level i (0 = full resolution) in nested column j. Encoder
    widths are base, 2*base, 4*base, 8*base. ``forward`` returns a list of
    three logits tensors ``[head(x01), head(x02), head(x03)]``, all at input
    resolution. Spatial input dims should be divisible by 8 (three 2x poolings).

    NOTE(review): one transposed conv per level (up1_0, up2_1, up3_2) is reused
    by every nested node at that level, so those weights are shared; canonical
    UNet++ uses one per node. Changing it would alter the state_dict keys.
    """
    def __init__(self, in_ch: int, n_classes: int, base: int = 24, se_reduction: int = 8, dropout: float = 0.0):
        super().__init__()
        c0, c1, c2, c3 = base, base * 2, base * 4, base * 8
        # Attribute names and creation order are kept: they define state_dict
        # keys and the parameter-initialisation RNG sequence.
        self.pool = nn.MaxPool3d(2)
        self.up3_2 = nn.ConvTranspose3d(c3, c2, 2, stride=2)
        self.up2_1 = nn.ConvTranspose3d(c2, c1, 2, stride=2)
        self.up1_0 = nn.ConvTranspose3d(c1, c0, 2, stride=2)

        # Backbone column (j = 0)
        self.x00 = ConvResSE3D(in_ch, c0, se_reduction, dropout)
        self.x10 = ConvResSE3D(c0, c1, se_reduction, dropout)
        self.x20 = ConvResSE3D(c1, c2, se_reduction, dropout)
        self.x30 = ConvResSE3D(c2, c3, se_reduction, dropout)

        # Column j = 1: same-level skip + upsampled lower node
        self.x01 = ConvResSE3D(c0 + c1, c0, se_reduction, dropout)
        self.x11 = ConvResSE3D(c1 + c2, c1, se_reduction, dropout)
        self.x21 = ConvResSE3D(c2 + c3, c2, se_reduction, dropout)

        # Column j = 2
        self.x02 = ConvResSE3D(c0 + c0 + c1, c0, se_reduction, dropout)
        self.x12 = ConvResSE3D(c1 + c1 + c2, c1, se_reduction, dropout)

        # Column j = 3
        self.x03 = ConvResSE3D(c0 + c0 + c0 + c1, c0, se_reduction, dropout)

        # One 1x1x1 head per full-resolution nested node (deep supervision)
        self.head1 = nn.Conv3d(c0, n_classes, 1)
        self.head2 = nn.Conv3d(c0, n_classes, 1)
        self.head3 = nn.Conv3d(c0, n_classes, 1)

    def forward(self, x):
        # Backbone
        e0 = self.x00(x)
        e1 = self.x10(self.pool(e0))
        e2 = self.x20(self.pool(e1))
        e3 = self.x30(self.pool(e2))

        # j = 1
        n01 = self.x01(torch.cat([e0, self.up1_0(e1)], dim=1))
        n11 = self.x11(torch.cat([e1, self.up2_1(e2)], dim=1))
        n21 = self.x21(torch.cat([e2, self.up3_2(e3)], dim=1))

        # j = 2 (dense skips: all earlier nodes on the same level)
        n02 = self.x02(torch.cat([e0, n01, self.up1_0(n11)], dim=1))
        n12 = self.x12(torch.cat([e1, n11, self.up2_1(n21)], dim=1))

        # j = 3
        n03 = self.x03(torch.cat([e0, n01, n02, self.up1_0(n12)], dim=1))

        heads = ((self.head1, n01), (self.head2, n02), (self.head3, n03))
        return [head(feat) for head, feat in heads]


## 8) Build Model Helper

In [None]:
def build_model(cfg: Config) -> nn.Module:
    """Instantiate the network named by ``cfg.model_name``.

    'seresunetpp_ds' -> UNetPP3D_SERes (forward returns a list of 3 logits);
    'unet3d_light'   -> UNet3D_Light (forward returns one logits tensor).

    Raises:
        ValueError: for any other model name.
    """
    name = cfg.model_name
    if name == 'unet3d_light':
        return UNet3D_Light(cfg.in_channels, cfg.num_classes, base=cfg.base_channels)
    if name == 'seresunetpp_ds':
        return UNetPP3D_SERes(
            cfg.in_channels,
            cfg.num_classes,
            base=cfg.base_channels,
            se_reduction=cfg.se_reduction,
            dropout=cfg.dropout,
        )
    raise ValueError(f"Unknown model: {cfg.model_name}")


## 9) Losses & Metrics (Dice + BCE) with Deep Supervision Support

In [None]:
class DiceLoss(nn.Module):
    """Soft Dice loss on sigmoid probabilities.

    Dice is computed per (sample, channel) over axes 2-4, with ``smooth``
    added to both numerator and denominator, averaged, and returned as
    ``1 - mean(dice)``.
    """
    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits, targets):
        p = torch.sigmoid(logits)
        spatial = (2, 3, 4)
        overlap = torch.sum(p * targets, dim=spatial)
        total = torch.sum(p, dim=spatial) + torch.sum(targets, dim=spatial)
        eps = self.smooth
        score = (2 * overlap + eps) / (total + eps)
        return 1 - score.mean()

class BCEDiceLoss(nn.Module):
    """Weighted mix of BCE-with-logits and soft Dice: ``w * BCE + (1 - w) * Dice``."""
    def __init__(self, bce_weight=0.5, smooth=1.0):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()
        self.dice = DiceLoss(smooth)
        self.w = bce_weight

    def forward(self, logits, targets):
        bce_term = self.bce(logits, targets)
        dice_term = self.dice(logits, targets)
        return self.w * bce_term + (1 - self.w) * dice_term

def dice_coefficient_from_logits(logits, targets, thr=0.5):
    """Hard Dice between ``sigmoid(logits) > thr`` and binary *targets*.

    Dice is computed per (sample, channel) over axes 2-4 and averaged. A pair
    where both prediction and target are empty counts as a perfect match
    (1.0) instead of 0.0, so correctly predicting "nothing" is not penalised.

    Returns:
        The mean Dice as a Python float.
    """
    with torch.no_grad():
        preds = (torch.sigmoid(logits) > thr).float()
        targets = targets.float()
        dims = (2, 3, 4)
        inter = (preds * targets).sum(dim=dims)
        union = preds.sum(dim=dims) + targets.sum(dim=dims)
        dice = (2 * inter) / (union + 1e-8)
        dice = torch.where(union > 0, dice, torch.ones_like(dice))
        return dice.mean().item()

def compute_loss(logits, targets, loss_fn: nn.Module, ds_weights: Tuple[float, ...]):
    """Loss for a single output, or a weighted sum over deep-supervision heads.

    Args:
        logits: a tensor, or a list/tuple of tensors (one per supervision head),
            each scored against the same *targets*.
        targets: ground-truth tensor passed unchanged to *loss_fn*.
        loss_fn: callable ``loss_fn(logits, targets) -> scalar tensor``.
        ds_weights: one weight per head; ignored for a single tensor.

    Raises:
        ValueError: the number of heads differs from ``len(ds_weights)``.
    """
    if not isinstance(logits, (list, tuple)):
        return loss_fn(logits, targets)
    if len(logits) != len(ds_weights):
        # A real exception rather than assert, which is stripped under -O.
        raise ValueError(f"DS heads ({len(logits)}) vs weights ({len(ds_weights)})")
    return sum(w * loss_fn(out, targets) for out, w in zip(logits, ds_weights))


## 10) (Optional) SSL Loader & EMA

In [None]:
class ModelEMA:
    """Exponential moving average of a model's weights.

    Keeps a frozen deep copy (``self.ema``). Each ``update()`` blends every
    floating-point state entry as ``ema = decay * ema + (1 - decay) * model``.
    Non-floating entries (e.g. integer counters such as BatchNorm's
    ``num_batches_tracked``) cannot be blended in place and are copied as-is.
    """
    def __init__(self, model, decay=0.999):
        import copy
        self.model = model
        self.ema = copy.deepcopy(model)
        for p in self.ema.parameters():
            p.requires_grad_(False)
        self.decay = decay

    @torch.no_grad()
    def update(self):
        # state_dict() tensors share storage with the modules, so in-place ops
        # below update the EMA model directly.
        msd, esd = self.model.state_dict(), self.ema.state_dict()
        for k, ema_t in esd.items():
            model_t = msd[k].detach()
            if ema_t.dtype.is_floating_point:
                ema_t.mul_(self.decay).add_(model_t, alpha=1 - self.decay)
            else:
                # mul_ by a float on an integer tensor would raise.
                ema_t.copy_(model_t)

def load_ssl_weights(model: nn.Module, ckpt_path: Optional[str] = None):
    """Best-effort load of pretrained (e.g. self-supervised encoder) weights into *model*.

    Accepts a raw state_dict or a dict wrapping one under 'model' or
    'state_dict'. Keys absent from *model* and keys whose tensor shape
    differs are skipped; ``load_state_dict(strict=False)`` alone would still
    raise on shape mismatches. Prints missing / unexpected / mismatched
    counts. Does nothing when *ckpt_path* is falsy or does not exist.

    Raises:
        TypeError: the checkpoint does not contain a dict-like state_dict.
    """
    if not ckpt_path or not os.path.exists(ckpt_path):
        print("[SSL] No checkpoint provided - skipping.")
        return
    print(f"[SSL] Loading encoder weights from {ckpt_path}")
    # NOTE: torch.load unpickles data; only load checkpoints from trusted sources.
    state = torch.load(ckpt_path, map_location='cpu')
    if isinstance(state, dict):
        for key in ('model', 'state_dict'):
            if isinstance(state.get(key), dict):
                state = state[key]
                break
    if not isinstance(state, dict):
        raise TypeError(f"[SSL] Expected a state_dict-like mapping, got {type(state).__name__}")

    own = model.state_dict()
    loadable, mismatched, unexpected = {}, [], []
    for k, v in state.items():
        if k not in own:
            unexpected.append(k)
        elif hasattr(v, 'shape') and tuple(v.shape) == tuple(own[k].shape):
            loadable[k] = v
        else:
            mismatched.append(k)
    missing, _ = model.load_state_dict(loadable, strict=False)
    print("[SSL] Missing:", len(missing), "Unexpected:", len(unexpected),
          "Shape-mismatched:", len(mismatched))


## 11) Scheduler & Training (cosine warmup + accumulation + AMP)

In [None]:
from torch.optim.lr_scheduler import SequentialLR, LambdaLR, CosineAnnealingLR

# Scheduler extras, assigned directly on the shared CFG instance
# (build_scheduler and train read them).
CFG.warmup_epochs = 5     # linear warm-up epochs; 0 disables warm-up
CFG.min_lr = 1e-6         # cosine floor (eta_min)
CFG.accum_steps = 1       # micro-batches accumulated per optimizer step

def build_scheduler(optimizer, cfg):
    """Per-epoch LR schedule: optional linear warm-up, then cosine decay.

    When ``cfg.warmup_epochs`` > 0, the LR is scaled by (e + 1) / warmup for
    epochs e = 0 .. warmup-1, after which CosineAnnealingLR takes over. The
    cosine phase spans ``max(1, cfg.epochs - warmup)`` epochs and decays to
    ``cfg.min_lr``. Meant to be stepped once per epoch.
    """
    n_warm = int(getattr(cfg, 'warmup_epochs', 0))
    t_max = max(1, cfg.epochs - n_warm)
    floor = getattr(cfg, 'min_lr', 0.0)
    if n_warm <= 0:
        return CosineAnnealingLR(optimizer, T_max=t_max, eta_min=floor)
    # Warm-up must be constructed before the cosine scheduler (same order as
    # they run), since each scheduler's constructor touches the optimizer LR.
    warm = LambdaLR(optimizer, lr_lambda=lambda e: (e + 1) / max(1, n_warm))
    decay = CosineAnnealingLR(optimizer, T_max=t_max, eta_min=floor)
    return SequentialLR(optimizer, [warm, decay], milestones=[n_warm])

def validate(model, loader, device, cfg: Config):
    """Mean hard Dice of *model* over *loader* at ``cfg.threshold``.

    Switches the model to eval mode and leaves it there. When the model
    returns a list (deep supervision), only the last element is scored.
    Returns 0.0 for an empty loader.
    """
    model.eval()
    scores = []
    with torch.no_grad():
        for batch_vol, batch_lab in loader:
            batch_vol = batch_vol.to(device, non_blocking=True)
            batch_lab = batch_lab.to(device, non_blocking=True)
            prediction = model(batch_vol)
            if isinstance(prediction, list):
                prediction = prediction[-1]
            scores.append(dice_coefficient_from_logits(prediction, batch_lab, cfg.threshold))
    if not scores:
        return 0.0
    return float(np.mean(scores))

def _optimizer_step(model, optimizer, scaler, grad_clip, ema):
    """Apply one AMP-scaled optimizer step (optionally grad-clipped), reset grads, update EMA."""
    if grad_clip:
        # Unscale first so the clipping threshold applies to the true gradient norm.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad(set_to_none=True)
    if ema is not None:
        ema.update()


def train(cfg: Config = CFG):
    """Train the configured model, validating after every epoch.

    Writes a CSV log to ``<out_dir>/train_log.txt`` and saves checkpoints
    ``<out_dir>/last.pt`` (every epoch) and ``<out_dir>/best.pt`` (highest
    validation Dice) as ``{'model': state_dict, 'epoch': int, 'val_dice': float}``,
    the 'model'-keyed format sweep_val_thresholds reads. With EMA enabled,
    the EMA weights (the ones validated) are saved. An optional
    ``cfg.ssl_ckpt`` path is passed to load_ssl_weights.

    Returns:
        The best validation Dice observed.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device:", device)
    # torch.cuda.amp only applies to CUDA; disable it explicitly on CPU.
    use_amp = bool(cfg.amp) and device.type == 'cuda'
    # getattr: Config instances created without the scheduler cell may lack it.
    accum = max(1, int(getattr(cfg, 'accum_steps', 1)))
    os.makedirs(cfg.out_dir, exist_ok=True)

    train_loader, val_loader = make_loaders(cfg)

    model = build_model(cfg).to(device)
    load_ssl_weights(model, ckpt_path=getattr(cfg, 'ssl_ckpt', None))

    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    scheduler = build_scheduler(optimizer, cfg)
    loss_fn = BCEDiceLoss(bce_weight=cfg.bce_weight, smooth=cfg.dice_smooth)
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    ema = ModelEMA(model, decay=cfg.ema_decay) if cfg.use_ema else None

    best_dice = -1.0
    log_path = os.path.join(cfg.out_dir, 'train_log.txt')
    with open(log_path, 'w') as f:
        f.write('epoch,loss,valDice,best,lr\n')

    for epoch in range(1, cfg.epochs + 1):
        model.train()
        optimizer.zero_grad(set_to_none=True)
        epoch_loss, n_seen, micro = 0.0, 0, 0
        t0 = time.time()

        for vol, lab in train_loader:
            vol = vol.to(device, non_blocking=True)
            lab = lab.to(device, non_blocking=True)
            with torch.cuda.amp.autocast(enabled=use_amp):
                out = model(vol)
                loss_raw = compute_loss(out, lab, loss_fn, cfg.ds_weights)
            # Divide so gradients summed over `accum` micro-batches form an average.
            scaler.scale(loss_raw / accum).backward()
            micro += 1
            if micro >= accum:
                _optimizer_step(model, optimizer, scaler, cfg.grad_clip, ema)
                micro = 0
            epoch_loss += loss_raw.item() * vol.size(0)
            n_seen += vol.size(0)

        if micro > 0:  # flush a partially filled accumulation window
            _optimizer_step(model, optimizer, scaler, cfg.grad_clip, ema)

        avg_loss = epoch_loss / max(1, n_seen)
        eval_model = ema.ema if ema is not None else model
        val_dice = validate(eval_model, val_loader, device, cfg)

        ckpt = {'model': eval_model.state_dict(), 'epoch': epoch, 'val_dice': val_dice}
        torch.save(ckpt, os.path.join(cfg.out_dir, 'last.pt'))
        if val_dice > best_dice:
            best_dice = val_dice
            torch.save(ckpt, os.path.join(cfg.out_dir, 'best.pt'))

        lr_now = optimizer.param_groups[0]['lr']
        line = f"Epoch {epoch:03d} | loss {avg_loss:.4f} | valDice {val_dice:.4f} | best {best_dice:.4f} | lr {lr_now:.2e}"
        print(line, f"| {time.time()-t0:.1f}s")
        with open(log_path, 'a') as f:
            f.write(f"{epoch},{avg_loss:.6f},{val_dice:.6f},{best_dice:.6f},{lr_now:.8f}\n")

        scheduler.step()

    print("Training finished. Best Dice:", best_dice)
    return best_dice


## 12) **Quick Sanity Checks**

In [None]:
@torch.no_grad()
def peek_batch(cfg: Config = CFG):
    """Print normalization mode, tensor shapes, intensity range and label values of one training batch."""
    train_loader = make_loaders(cfg)[0]
    sample_vol, sample_lab = next(iter(train_loader))
    print("Normalization:", getattr(cfg, 'norm_mode', 'zscore'))
    lo, hi = float(sample_vol.min()), float(sample_vol.max())
    print("vol:", sample_vol.shape, "lab:", sample_lab.shape, "min/max:", lo, hi)
    print("unique lab values:", torch.unique(sample_lab))

def quick_overfit_steps(steps=20, cfg: Config = CFG):
    """Wiring check: repeatedly fit a single training batch and print loss / Dice.

    A fresh model is trained with Adam(lr=1e-3) on one fixed batch for *steps*
    iterations; on a correctly wired pipeline the loss is expected to drop.
    AMP is used only when CUDA is available.

    Returns:
        List of ``(loss, dice)`` per step (previously nothing was returned).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    use_amp = bool(cfg.amp) and device.type == 'cuda'
    tl, _ = make_loaders(cfg)
    vol, lab = next(iter(tl))
    vol, lab = vol.to(device), lab.to(device)

    model = build_model(cfg).to(device)
    model.train()
    loss_fn = BCEDiceLoss(bce_weight=cfg.bce_weight, smooth=cfg.dice_smooth)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    history = []
    for i in range(1, steps + 1):
        opt.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=use_amp):
            out = model(vol)
            loss = compute_loss(out, lab, loss_fn, cfg.ds_weights)
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        # Score the last (deepest) head when the model returns a list.
        last = out[-1] if isinstance(out, list) else out
        d = dice_coefficient_from_logits(last, lab, cfg.threshold)
        history.append((loss.item(), d))
        print(f"step {i:02d} loss {loss.item():.4f} dice {d:.4f}")
    return history


## 13) Validation Threshold Sweep

In [None]:
import numpy as np
@torch.no_grad()
def sweep_val_thresholds(model_path: str, thresholds=None, cfg: Config = CFG):
    """Score validation Dice at several probability thresholds for one checkpoint.

    The model runs once per validation case and all thresholds are scored on
    that single output. Dice is computed with dice_coefficient_from_logits,
    the same metric validate() uses. Checkpoints may be a raw state_dict or a
    dict with a 'model' entry; key mismatches are reported (loading stays
    non-strict).

    Args:
        model_path: path to the checkpoint file.
        thresholds: iterable of probability thresholds; defaults to 0.10..0.90
            in steps of 0.05.
        cfg: configuration used to build the model and validation loader.

    Returns:
        ``{'scores': [(thr, mean_dice), ...], 'best': (best_thr, best_dice)}``.
    """
    if thresholds is None:
        thresholds = np.arange(0.10, 0.91, 0.05)
    thresholds = [float(t) for t in thresholds]
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = build_model(cfg).to(device)
    state = torch.load(model_path, map_location=device)
    state_dict = state['model'] if isinstance(state, dict) and 'model' in state else state
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing or unexpected:
        # Non-strict loading would otherwise hide evaluating a wrong/partial checkpoint.
        print(f"[sweep] WARNING: {len(missing)} missing / {len(unexpected)} unexpected keys in {model_path}")
    model.eval()
    _, val_loader = make_loaders(cfg)

    per_thr = [[] for _ in thresholds]
    for vol, lab in val_loader:
        vol = vol.to(device, non_blocking=True)
        lab = lab.to(device, non_blocking=True)
        out = model(vol)
        logits = out[-1] if isinstance(out, list) else out
        for scores, thr in zip(per_thr, thresholds):
            scores.append(dice_coefficient_from_logits(logits, lab, thr))

    results = [(thr, float(np.mean(s)) if s else 0.0) for thr, s in zip(thresholds, per_thr)]
    best_thr, best_dice = max(results, key=lambda t: t[1]) if results else (0.5, 0.0)
    print("Threshold sweep results:")
    for thr, sc in results:
        print(f"{thr:.2f}: {sc:.4f}")
    print(f"Best threshold = {best_thr:.2f}, Dice = {best_dice:.4f}")
    return {'scores': results, 'best': (best_thr, best_dice)}


## 14) Entry Point

In [None]:
if __name__ == "__main__":
    # Typical workflow; uncomment the steps you want to run.
    #   Step 1 - sanity checks:
    #       peek_batch()
    #       quick_overfit_steps(steps=30)
    #   Step 2 - full training:
    #       train(CFG)
    pass
