In [None]:
import numpy as np
import pandas as pd
import os
# Kaggle starter boilerplate: print the full path of every file found under
# /kaggle/input (one output line per file).
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

In [None]:
import os, glob, time, json, math, random
from dataclasses import dataclass, asdict
from typing import List, Tuple, Dict

import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from torchvision import models
from torchvision.transforms import functional as TF

import matplotlib.pyplot as plt 

try:
    from tqdm import tqdm
except Exception:
    # Progress bars are purely cosmetic: without tqdm, iterate over the raw object.
    def tqdm(x, **kwargs):
        """Pass-through stand-in for ``tqdm.tqdm``.

        All keyword options (``desc``, ``leave``, ...) are accepted and ignored;
        the iterable ``x`` is returned unchanged.
        """
        return x

def seed_everything(seed: int = 42):
    """Seed every RNG used by this notebook and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy's global RNG, PyTorch's CPU RNG and the
    RNGs of all CUDA devices (in that order), then forces deterministic cuDNN
    kernels and disables cuDNN autotuning (``benchmark``), trading speed for
    run-to-run reproducibility.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

In [None]:
@dataclass
class Config:
    """All paths and hyper-parameters of one run.

    An instance is passed around as ``cfg`` and is stored inside every saved
    checkpoint via ``asdict(cfg)``.
    """
    # Root directory that list_lgg_pairs walks for (image, mask) files.
    data_dir: str = "/kaggle/input/lgg-mri-segmentation/kaggle_3m"
    # Destination for checkpoints, curves and the JSON summary.
    out_dir: str  = "artifacts_lgg_en"
    # Prefix for every artifact file name.
    run_name: str = "lgg_multi_seg_en"

    # (H, W) that every image and mask is resized to.
    img_size: Tuple[int,int] = (256, 256)
    # Fractions of *patients* (not slices) held out for validation / test.
    val_split: float = 0.10
    test_split: float = 0.10
    # DataLoader worker processes.
    num_workers: int = 2
    # Normalise input images with the ImageNet mean/std.
    imagenet_norm: bool = True

    # Total number of epochs, including the frozen warm-up epochs.
    epochs: int = 40
    # Leading epochs during which only parameters named "*classifier*" are trained.
    warmup_epochs_frozen: int = 4
    batch_size: int = 8
    # AdamW learning rate and weight decay (used in both training phases).
    lr: float = 3e-4
    weight_decay: float = 1e-4
    # autocast + GradScaler; only takes effect when device == "cuda".
    mixed_precision: bool = True

    # Seeds the RNGs and the patient-level split.
    seed: int = 42
    # Evaluated once, when the class body is executed.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    # Pin host memory in the DataLoaders; ignored unless device == "cuda".
    pin_memory: bool = True

# Instantiate the run configuration, make the run reproducible and make sure
# the artifact directory exists before anything is written to it.
cfg = Config()
seed_everything(cfg.seed)
os.makedirs(cfg.out_dir, exist_ok=True)
print(f"Device: {cfg.device}")
print(f"Output dir: {cfg.out_dir}")
print(cfg)

In [None]:
# Lower-case file extensions (dot included) treated as image or mask files.
IMG_EXTS = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff")

def is_image(p: str) -> bool:
    """Return True when ``p`` ends in one of ``IMG_EXTS`` (case-insensitive)."""
    _, extension = os.path.splitext(p)
    return extension.lower() in IMG_EXTS

def _patient_id_from(dirpath: str, root: str) -> str:
    rel = os.path.relpath(dirpath, root)
    parts = [p for p in rel.split(os.sep) if p not in (".", "")]
    return parts[0] if parts else os.path.basename(dirpath)

def _normalize_stem(basename: str) -> str:
    stem = os.path.splitext(basename)[0].lower()
    tokens = [
        "_mask", "-mask", " mask",
        "_seg", "-seg", " seg",
        "_segmentation", " segmentation",
        "_gt", "-gt", " groundtruth", " ground_truth",
        "_annotation", "-annotation"
    ]
    for t in tokens:
        stem = stem.replace(t, "")
    stem = stem.replace("-", "_").replace(" ", "")
    while "__" in stem:
        stem = stem.replace("__", "_")
    return stem

def list_lgg_pairs(root: str):
    """Walk ``root`` and pair every image file with its mask.

    Files are grouped by ``(patient_id, normalised_stem)``:
    - the patient id is the first directory below ``root``;
    - the stem has mask-style markers stripped (see ``_normalize_stem``).

    A file counts as a mask when its name contains "mask", "seg", "gt" or
    "annot", or when any directory between ``root`` and the file is named
    "mask" or "masks".

    Returns:
        List of ``(image_path, mask_path, patient_id)`` tuples sorted by key.
        When several files share a key, the first one produced by ``os.walk``
        is used.

    Raises:
        FileNotFoundError: ``root`` is not a directory.
        RuntimeError: no pair could be formed. Up to 10 mask-like files are
            printed first to help debug naming.
    """
    if not os.path.isdir(root):
        raise FileNotFoundError(f"Data root not found: {root}")

    mask_name_hints = ("mask", "seg", "gt", "annot")
    mask_dir_names = ("mask", "masks")

    def is_mask_like(filename, rel_parts):
        # Single rule shared by pairing and diagnostics. Directories are judged
        # by their path *relative to root*, so folders above root cannot flag files.
        lowered = filename.lower()
        return (any(hint in lowered for hint in mask_name_hints)
                or any(part.lower() in mask_dir_names for part in rel_parts))

    images, masks = {}, {}
    mask_like = []  # every mask-classified image, kept for diagnostics
    for dirpath, _, files in os.walk(root):
        rel_parts = os.path.relpath(dirpath, root).split(os.sep)
        pid = _patient_id_from(dirpath, root)  # depends only on the directory
        for f in files:
            if not is_image(f):
                continue
            full = os.path.join(dirpath, f)
            key = (pid, _normalize_stem(f))
            if is_mask_like(f, rel_parts):
                masks.setdefault(key, []).append(full)
                mask_like.append(full)
            else:
                images.setdefault(key, []).append(full)

    pairs = [(images[key][0], masks[key][0], key[0])
             for key in sorted(images.keys() & masks.keys())]

    if not pairs:
        print(f"[Diagnostics] Found {len(mask_like)} mask-like files.")
        for p in mask_like[:10]:
            print("  -", os.path.relpath(p, root))
        raise RuntimeError("No (image, mask) pairs found. Review naming patterns or folder structure.")

    return pairs

def split_by_patient(pairs, val_ratio: float, test_ratio: float, seed: int=42):
    """Split ``(image, mask, patient_id)`` pairs so that no patient spans two splits.

    Patients are sorted, shuffled with ``random.Random(seed)`` and cut into
    ``int(n * val_ratio)`` validation patients, ``int(n * test_ratio)`` test
    patients and the rest for training.

    A positive ratio that rounds down to zero is bumped to one patient
    (validation first, then test), but only while at least one patient remains
    for training. This way small datasets do not get empty validation/test
    splits, and therefore empty DataLoaders.

    Returns:
        ``(train_pairs, val_pairs, test_pairs, shuffled_patient_ids)``.
    """
    patients = sorted({pid for _, _, pid in pairs})
    rng = random.Random(seed)
    rng.shuffle(patients)
    n = len(patients)
    n_val = int(n * val_ratio)
    n_test = int(n * test_ratio)
    if val_ratio > 0 and n_val == 0 and n - n_test > 1:
        n_val = 1
    if test_ratio > 0 and n_test == 0 and n - n_val > 1:
        n_test = 1
    n_train = n - n_val - n_test
    train_p = set(patients[:n_train])
    val_p = set(patients[n_train:n_train + n_val])
    test_p = set(patients[n_train + n_val:])
    tr = [p for p in pairs if p[2] in train_p]
    va = [p for p in pairs if p[2] in val_p]
    te = [p for p in pairs if p[2] in test_p]
    return tr, va, te, patients

class LGGSegmentationDataset(Dataset):
    """Dataset of (MRI slice, binary tumour mask) pairs.

    Each item is ``(image, mask, image_path, mask_path, patient_id)``:
    - ``image`` is a float tensor of shape (3, H, W), ImageNet-normalised when
      ``imagenet_norm`` is set;
    - ``mask`` is a (1, H, W) float tensor of 0./1., obtained by thresholding
      the grey-level mask at 127.
    """
    IMAGENET_MEAN = [0.485, 0.456, 0.406]
    IMAGENET_STD  = [0.229, 0.224, 0.225]

    def __init__(self, pairs, img_size, augment=True, imagenet_norm=True):
        """
        Args:
            pairs: sequence of ``(image_path, mask_path, patient_id)`` tuples.
            img_size: ``(H, W)`` every image and mask is resized to.
            augment: apply random flip + affine jitter (meant for training).
            imagenet_norm: normalise images with IMAGENET_MEAN / IMAGENET_STD.
        """
        self.pairs = pairs
        self.img_size = img_size
        self.augment = augment
        self.imagenet_norm = imagenet_norm

    def __len__(self): return len(self.pairs)

    def _load(self, ip, mp):
        """Open image (as RGB) and mask (as 8-bit grey) and close the files again.

        ``convert`` returns a fully loaded copy, so the returned images no
        longer need the underlying file handles. The context managers close
        them deterministically instead of relying on garbage collection (which
        matters e.g. for multi-frame TIFFs, which PIL keeps open).
        """
        with Image.open(ip) as im:
            img = im.convert("RGB")
        with Image.open(mp) as im:
            msk = im.convert("L")
        return img, msk

    def _resize(self, img, msk):
        """Resize to ``img_size``: bilinear for the image, nearest for the mask so it stays binary."""
        H,W = self.img_size
        img = TF.resize(img, size=[H,W], interpolation=TF.InterpolationMode.BILINEAR)
        msk = TF.resize(msk, size=[H,W], interpolation=TF.InterpolationMode.NEAREST)
        return img, msk

    def _augment(self, img, msk):
        """Apply the same random geometric transform to image and mask.

        - horizontal flip with p = 0.5;
        - rotation in [-20, 20] degrees;
        - translation up to 5% of width/height;
        - scale in [0.95, 1.05];
        - x-shear in [-5, 5] degrees.
        The mask uses nearest interpolation so its values stay binary.
        """
        if random.random() < 0.5:
            img = TF.hflip(img); msk = TF.hflip(msk)
        angle = random.uniform(-20, 20)
        translate = (int(random.uniform(-0.05,0.05)*img.width),
                     int(random.uniform(-0.05,0.05)*img.height))
        scale = random.uniform(0.95, 1.05)
        shear = random.uniform(-5, 5)
        img = TF.affine(img, angle=angle, translate=translate, scale=scale, shear=[shear,0.0],
                        interpolation=TF.InterpolationMode.BILINEAR)
        msk = TF.affine(msk, angle=angle, translate=translate, scale=scale, shear=[shear,0.0],
                        interpolation=TF.InterpolationMode.NEAREST)
        return img, msk

    def __getitem__(self, idx):
        ip, mp, pid = self.pairs[idx]
        img, msk = self._load(ip, mp)
        img, msk = self._resize(img, msk)
        if self.augment:
            img, msk = self._augment(img, msk)
        img_t = TF.to_tensor(img)
        if self.imagenet_norm:
            img_t = TF.normalize(img_t, mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)
        msk_t = torch.from_numpy(np.array(msk, dtype=np.uint8))
        # Binarise (mask pixels are grey levels 0..255) and add a channel axis -> (1, H, W).
        msk_t = (msk_t > 127).float().unsqueeze(0)
        return img_t, msk_t, ip, mp, pid

def make_dataloaders(cfg):
    """Discover the (image, mask) pairs, split them by patient and wrap each split in a DataLoader.

    Returns:
        ``(train_loader, val_loader, test_loader)``. Only the training loader
        augments, shuffles and drops the last incomplete batch. Memory pinning is
        used only when ``cfg.pin_memory`` is set and the device is CUDA.
    """
    all_pairs = list_lgg_pairs(cfg.data_dir)
    train_pairs, val_pairs, test_pairs, patients = split_by_patient(
        all_pairs, cfg.val_split, cfg.test_split, cfg.seed)
    print(f"Patients: {len(patients)}")
    print(f"Pairs: train={len(train_pairs)} | val={len(val_pairs)} | test={len(test_pairs)}")

    use_pin = cfg.pin_memory and (cfg.device == "cuda")

    def build_loader(split_pairs, training):
        dataset = LGGSegmentationDataset(split_pairs, cfg.img_size, augment=training,
                                         imagenet_norm=cfg.imagenet_norm)
        split_opts = {"shuffle": True, "drop_last": True} if training else {"shuffle": False}
        return DataLoader(dataset, batch_size=cfg.batch_size, num_workers=cfg.num_workers,
                          pin_memory=use_pin, **split_opts)

    return (build_loader(train_pairs, True),
            build_loader(val_pairs, False),
            build_loader(test_pairs, False))

In [6]:
# Run once after setting cfg.data_dir to verify pairing:
# pairs = list_lgg_pairs(cfg.data_dir)
# print("Total pairs:", len(pairs))
# for i, (ip, mp, pid) in enumerate(pairs[:5]):
#     print(f"{i+1:02d} [patient={pid}]")
#     print("  image:", os.path.relpath(ip, cfg.data_dir))
#     print("  mask :", os.path.relpath(mp, cfg.data_dir))

In [7]:
from torchvision import models

def build_deeplabv3_resnet50():
    """torchvision DeepLabV3-ResNet50 with its classifier(s) replaced by 1-channel 1x1 convs.

    Weight loading is tried in order of decreasing pretraining:
      1. pretrained segmentation weights (``DeepLabV3_ResNet50_Weights.DEFAULT``);
      2. ``weights=None``, where torchvision still loads its default ImageNet
         backbone weights;
      3. ``weights=None, weights_backbone=None``, a fully random network that
         needs no download.
    Step 3 exists because step 2 also downloads weights, so it fails for the
    same reason as step 1 when there is no network access. Each failure is
    printed instead of being swallowed silently.
    """
    seg = models.segmentation
    attempts = (
        ("pretrained segmentation weights",
         lambda: seg.deeplabv3_resnet50(weights=seg.DeepLabV3_ResNet50_Weights.DEFAULT)),
        ("ImageNet backbone weights",
         lambda: seg.deeplabv3_resnet50(weights=None)),
        ("random initialisation",
         lambda: seg.deeplabv3_resnet50(weights=None, weights_backbone=None)),
    )
    m, last_err = None, None
    for label, make in attempts:
        try:
            m = make()
            break
        except Exception as err:  # typically a failed weight download
            last_err = err
            print(f"[deeplabv3_resnet50] {label} unavailable: {err!r}")
    if m is None:
        raise last_err
    in_ch = m.classifier[-1].in_channels if hasattr(m.classifier[-1], "in_channels") else 256
    m.classifier[-1] = nn.Conv2d(in_ch, 1, kernel_size=1)
    # The auxiliary head only exists when aux_loss is enabled (e.g. with pretrained weights).
    if hasattr(m, "aux_classifier") and m.aux_classifier is not None:
        in_ch_aux = m.aux_classifier[-1].in_channels if hasattr(m.aux_classifier[-1], "in_channels") else 256
        m.aux_classifier[-1] = nn.Conv2d(in_ch_aux, 1, kernel_size=1)
    return m

def build_deeplabv3_resnet101():
    """torchvision DeepLabV3-ResNet101 with its classifier(s) replaced by 1-channel 1x1 convs.

    Weight loading is tried in order of decreasing pretraining:
      1. pretrained segmentation weights (``DeepLabV3_ResNet101_Weights.DEFAULT``);
      2. ``weights=None``, where torchvision still loads its default ImageNet
         backbone weights;
      3. ``weights=None, weights_backbone=None``, a fully random network that
         needs no download.
    The fallback calls ``deeplabv3_resnet101``; the previous ``deeplab3_resnet101``
    does not exist and raised AttributeError.
    """
    seg = models.segmentation
    attempts = (
        ("pretrained segmentation weights",
         lambda: seg.deeplabv3_resnet101(weights=seg.DeepLabV3_ResNet101_Weights.DEFAULT)),
        ("ImageNet backbone weights",
         lambda: seg.deeplabv3_resnet101(weights=None)),
        ("random initialisation",
         lambda: seg.deeplabv3_resnet101(weights=None, weights_backbone=None)),
    )
    m, last_err = None, None
    for label, make in attempts:
        try:
            m = make()
            break
        except Exception as err:  # typically a failed weight download
            last_err = err
            print(f"[deeplabv3_resnet101] {label} unavailable: {err!r}")
    if m is None:
        raise last_err
    in_ch = m.classifier[-1].in_channels if hasattr(m.classifier[-1], "in_channels") else 256
    m.classifier[-1] = nn.Conv2d(in_ch, 1, kernel_size=1)
    # The auxiliary head only exists when aux_loss is enabled (e.g. with pretrained weights).
    if hasattr(m, "aux_classifier") and m.aux_classifier is not None:
        in_ch_aux = m.aux_classifier[-1].in_channels if hasattr(m.aux_classifier[-1], "in_channels") else 256
        m.aux_classifier[-1] = nn.Conv2d(in_ch_aux, 1, kernel_size=1)
    return m

def build_fcn_resnet50():
    """torchvision FCN-ResNet50 with its classifier(s) replaced by 1-channel 1x1 convs.

    Weight loading is tried in order of decreasing pretraining:
      1. pretrained segmentation weights (``FCN_ResNet50_Weights.DEFAULT``);
      2. ``weights=None``, where torchvision still loads its default ImageNet
         backbone weights;
      3. ``weights=None, weights_backbone=None``, a fully random network that
         needs no download.
    Step 3 exists because step 2 also downloads weights and therefore fails
    offline just like step 1.
    """
    seg = models.segmentation
    attempts = (
        ("pretrained segmentation weights",
         lambda: seg.fcn_resnet50(weights=seg.FCN_ResNet50_Weights.DEFAULT)),
        ("ImageNet backbone weights",
         lambda: seg.fcn_resnet50(weights=None)),
        ("random initialisation",
         lambda: seg.fcn_resnet50(weights=None, weights_backbone=None)),
    )
    m, last_err = None, None
    for label, make in attempts:
        try:
            m = make()
            break
        except Exception as err:  # typically a failed weight download
            last_err = err
            print(f"[fcn_resnet50] {label} unavailable: {err!r}")
    if m is None:
        raise last_err
    in_ch = m.classifier[-1].in_channels if hasattr(m.classifier[-1], "in_channels") else 512
    m.classifier[-1] = nn.Conv2d(in_ch, 1, kernel_size=1)
    # The auxiliary head only exists when aux_loss is enabled (e.g. with pretrained weights).
    if hasattr(m, "aux_classifier") and m.aux_classifier is not None:
        in_ch_aux = m.aux_classifier[-1].in_channels if hasattr(m.aux_classifier[-1], "in_channels") else 256
        m.aux_classifier[-1] = nn.Conv2d(in_ch_aux, 1, kernel_size=1)
    return m

def build_lraspp_mobilenet_v3():
    """torchvision LR-ASPP MobileNetV3-Large adapted to a single output channel.

    Weight loading is tried in order of decreasing pretraining:
      1. pretrained segmentation weights (``LRASPP_MobileNet_V3_Large_Weights.DEFAULT``);
      2. ``weights=None``, where torchvision still loads its default ImageNet
         backbone weights;
      3. ``weights=None, weights_backbone=None``, a fully random network that
         needs no download.
    """
    seg = models.segmentation
    attempts = (
        ("pretrained segmentation weights",
         lambda: seg.lraspp_mobilenet_v3_large(weights=seg.LRASPP_MobileNet_V3_Large_Weights.DEFAULT)),
        ("ImageNet backbone weights",
         lambda: seg.lraspp_mobilenet_v3_large(weights=None)),
        ("random initialisation",
         lambda: seg.lraspp_mobilenet_v3_large(weights=None, weights_backbone=None)),
    )
    m, last_err = None, None
    for label, make in attempts:
        try:
            m = make()
            break
        except Exception as err:  # typically a failed weight download
            last_err = err
            print(f"[lraspp_mobilenet_v3] {label} unavailable: {err!r}")
    if m is None:
        raise last_err
    # torchvision's LRASPPHead returns low_classifier(low) + high_classifier(high).
    # Both are plain nn.Conv2d layers, not Sequentials, so they cannot be indexed.
    # Both must emit one channel, otherwise the sum broadcasts back to the
    # original class count.
    head = m.classifier
    head.low_classifier = nn.Conv2d(head.low_classifier.in_channels, 1, kernel_size=1)
    head.high_classifier = nn.Conv2d(head.high_classifier.in_channels, 1, kernel_size=1)
    return m

class ResBlock(nn.Module):
    """Two (3x3 conv -> BatchNorm) layers with a residual shortcut.

    Spatial size is preserved (padding=1). The shortcut is the identity when
    ``in_ch == out_ch`` and a bias-free 1x1 conv otherwise. The final ReLU is
    applied after the residual addition.
    """
    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Submodule creation order is kept: it fixes parameter-init RNG draws and state_dict order.
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_ch)
        if in_ch == out_ch:
            self.proj = nn.Identity()
        else:
            self.proj = nn.Conv2d(in_ch, out_ch, 1, bias=False)

    def forward(self, x):
        shortcut = self.proj(x)
        out = self.conv1(x)
        out = F.relu(self.bn1(out), inplace=True)
        out = self.bn2(self.conv2(out))
        return F.relu(out + shortcut, inplace=True)

class Down(nn.Module):
    """Encoder stage: a ResBlock followed by 2x2 max-pooling.

    ``forward`` returns ``(features, pooled)``. ``features`` is the ResBlock
    output *before* pooling (used as the skip connection); ``pooled`` is its
    half-resolution copy, fed to the next stage.
    """
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.block = ResBlock(in_ch, out_ch)
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        features = self.block(x)
        pooled = self.pool(features)
        return features, pooled

class Up(nn.Module):
    """Decoder stage: 2x transposed-conv upsampling, size match, concat with the skip, ResBlock.

    The upsampled tensor has ``out_ch`` channels and is concatenated after
    ``skip``. ``skip`` must therefore carry ``in_ch - out_ch`` channels for the
    ResBlock input width to match.
    """
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, out_ch, 2, 2)
        self.block = ResBlock(in_ch, out_ch)

    def forward(self, x, skip):
        upsampled = self.up(x)
        dy = skip.shape[-2] - upsampled.shape[-2]
        dx = skip.shape[-1] - upsampled.shape[-1]
        if dx or dy:
            # Pad symmetrically so the spatial size equals the skip's
            # (a negative amount crops instead).
            left, top = dx // 2, dy // 2
            upsampled = F.pad(upsampled, [left, dx - left, top, dy - top])
        merged = torch.cat([skip, upsampled], dim=1)
        return self.block(merged)

class ResUNet(nn.Module):
    """Residual U-Net trained from scratch; outputs one logit channel at input resolution.

    For an H x W input the encoder features are:
    - x0 at H (``base`` channels);
    - x1 at H/2 (2*base);
    - x2 at H/4 (4*base);
    - x3 at H/8 (8*base);
    - bottleneck at H/16 (16*base).
    Each decoder stage upsamples 2x and fuses the skip of matching resolution.
    """
    def __init__(self, base=64):
        super().__init__()
        self.in_conv = ResBlock(3, base)
        self.d1 = Down(base, base*2)
        self.d2 = Down(base*2, base*4)
        self.d3 = Down(base*4, base*8)
        self.bot = ResBlock(base*8, base*16)
        self.u3 = Up(base*16, base*8)
        self.u2 = Up(base*8, base*4)
        self.u1 = Up(base*4, base*2)
        self.u0 = Up(base*2, base)
        self.out = nn.Conv2d(base, 1, 1)

    def forward(self, x):
        x0 = self.in_conv(x)                    # H,    base
        # Down returns its features *before* pooling, so x0 is pooled here.
        # Without this, x1 stays at H, u0 upsamples to 2H, and Up's negative
        # padding silently centre-crops spatially misaligned features back to H.
        x1, p1 = self.d1(F.max_pool2d(x0, 2))   # H/2,  2*base
        x2, p2 = self.d2(p1)                    # H/4,  4*base
        x3, p3 = self.d3(p2)                    # H/8,  8*base
        xb = self.bot(p3)                       # H/16, 16*base
        x = self.u3(xb, x3)                     # H/8
        x = self.u2(x, x2)                      # H/4
        x = self.u1(x, x1)                      # H/2
        x = self.u0(x, x0)                      # H
        return self.out(x)

In [8]:
class DiceLoss(nn.Module):
    """Soft Dice loss on sigmoid probabilities.

    Dice is computed separately for every (sample, channel) over dims 2 and 3,
    then averaged: ``1 - mean((2*sum(p*y) + s) / (sum(p + y) + s))``. The
    ``smooth`` term ``s`` keeps all-empty masks from dividing by zero.
    """
    def __init__(self, smooth: float=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits, targets):
        spatial = (2, 3)
        p = torch.sigmoid(logits)
        numerator = 2.0 * (p * targets).sum(dim=spatial) + self.smooth
        denominator = (p + targets).sum(dim=spatial) + self.smooth
        return 1.0 - (numerator / denominator).mean()

def bce_dice_loss(logits, targets):
    """Combined loss: mean BCE-with-logits plus soft Dice loss (smooth = 1)."""
    bce_term = F.binary_cross_entropy_with_logits(logits, targets)
    dice_term = DiceLoss()(logits, targets)
    return bce_term + dice_term

@torch.no_grad()
def seg_counts(logits, targets, thr: float=0.5):
    """Pixel-level confusion counts of the thresholded sigmoid prediction against ``targets``.

    Returns:
        ``(tp, fp, fn, tn, total)``. The first four are Python floats from
        ``.item()``; ``total`` is ``targets.numel()``. Assumes 0/1 targets;
        soft targets would give fractional counts (TODO confirm this is never
        the case).
    """
    predicted = (torch.sigmoid(logits) > thr).to(targets.dtype)
    actual = targets
    pred_neg = 1 - predicted
    actual_neg = 1 - actual
    tp = (predicted * actual).sum().item()
    fp = (predicted * actual_neg).sum().item()
    fn = (pred_neg * actual).sum().item()
    tn = (pred_neg * actual_neg).sum().item()
    return tp, fp, fn, tn, actual.numel()

@torch.no_grad()
def seg_metrics(tp, fp, fn, tn, total) -> Dict[str,float]:
    """Turn pixel confusion counts into accuracy, precision, recall, F1, IoU and Dice.

    A tiny epsilon (1e-12) in every denominator avoids division by zero, so
    all-zero counts yield 0.0 for every metric.
    """
    eps = 1e-12
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    return {
        "accuracy": float((tp + tn) / (total + eps)),
        "precision": float(precision),
        "recall": float(recall),
        "f1": float(2 * precision * recall / (precision + recall + eps)),
        "iou": float(tp / (tp + fp + fn + eps)),
        "dice": float(2 * tp / (2 * tp + fp + fn + eps)),
    }

In [9]:
def freeze_backbone(model):
    """Freeze every parameter except the segmentation head(s).

    Parameters whose qualified name contains "classifier" (which also covers
    "aux_classifier") stay trainable; all others get ``requires_grad=False``.

    If the model has no such parameter (e.g. ``ResUNet``, which has no
    pretrained backbone to protect), everything is left trainable. Freezing it
    all would hand the warm-up optimizer an empty parameter list, and
    ``torch.optim`` raises ValueError for that.
    """
    head_keys = ("classifier", "aux_classifier")
    named = list(model.named_parameters())
    has_head = any(any(k in name for k in head_keys) for name, _ in named)
    for name, param in named:
        param.requires_grad = (not has_head) or any(k in name for k in head_keys)

def unfreeze_all(model):
    """Mark every parameter of ``model`` as trainable (``requires_grad=True``)."""
    for param in model.parameters():
        param.requires_grad_(True)

def forward_logits(model, x):
    """Run ``model`` on ``x`` and return its main logits.

    A dict output with an "out" key (the torchvision segmentation convention)
    yields that entry; any other output is returned unchanged.
    """
    result = model(x)
    has_main_output = isinstance(result, dict) and "out" in result
    return result["out"] if has_main_output else result

def train_eval_model(model, loaders, cfg, tag: str):
    """Two-phase training of ``model``, then evaluation of its best checkpoint on the test split.

    Phase 1 (``cfg.warmup_epochs_frozen`` epochs) trains only the head
    parameters selected by ``freeze_backbone``. If that leaves nothing
    trainable (no "classifier" parameters), all parameters are trained instead.

    Phase 2 (up to ``cfg.epochs`` in total) trains every parameter with a fresh
    AdamW optimizer and ReduceLROnPlateau scheduler.

    After each epoch the validation loss drives the scheduler. Whenever
    validation Dice improves, the state_dict is saved to
    ``<out_dir>/<run_name>_<tag>_best.pt``. Loss and Dice curves are plotted
    and saved. Finally the best checkpoint is reloaded and scored on the test
    loader (fp32, no gradients).

    Args:
        model: network returning a logits tensor or a dict with key "out".
        loaders: ``(train_loader, val_loader, test_loader)`` yielding
            ``(images, masks, *metadata)`` batches.
        cfg: configuration object (device, lr, weight_decay, epochs, ...).
        tag: short model name used in file names and log lines.

    Returns:
        dict with "tag", "best_val_dice", "test_loss" and the test metrics from
        ``seg_metrics``.
    """
    train_loader, val_loader, test_loader = loaders
    device = cfg.device
    model.to(device)
    use_amp = cfg.mixed_precision and (device == 'cuda')
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    hist = {"epoch": [], "train_loss": [], "val_loss": [], "val_dice": []}
    best_dice = -1.0
    best_path = os.path.join(cfg.out_dir, f"{cfg.run_name}_{tag}_best.pt")

    def run_epoch(loader, optimizer=None, desc=None, amp=use_amp):
        """One pass over ``loader``: trains with ``optimizer`` if given, otherwise evaluates.

        Returns (mean batch loss, seg_metrics over every pixel of the pass).
        """
        train = optimizer is not None
        model.train(train)
        total_loss = 0.0
        tp = fp = fn = tn = total = 0
        # Gradients are only needed for training; evaluation runs without them.
        with torch.set_grad_enabled(train):
            for imgs, msks, *_ in tqdm(loader, desc=desc or ("Train" if train else "Eval"), leave=False):
                imgs = imgs.to(device, non_blocking=True)
                msks = msks.to(device, non_blocking=True)
                with torch.cuda.amp.autocast(enabled=amp):
                    logits = forward_logits(model, imgs)
                    if logits.shape[1] != 1:
                        logits = logits[:, :1, ...]
                    loss = bce_dice_loss(logits, msks)
                if train:
                    # Step the optimizer of the *current* phase (after warm-up this
                    # must update the whole network, not just the head).
                    optimizer.zero_grad(set_to_none=True)
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                total_loss += loss.item()
                a, b, c, d, t = seg_counts(logits, msks)
                tp += a; fp += b; fn += c; tn += d; total += t
        # max(..., 1): an empty loader (e.g. an empty split) must not divide by zero.
        return total_loss / max(len(loader), 1), seg_metrics(tp, fp, fn, tn, total)

    def step_scheduler(scheduler, optimizer, val_loss):
        """Step ReduceLROnPlateau and report LR drops (its ``verbose`` flag is deprecated)."""
        before = [g["lr"] for g in optimizer.param_groups]
        scheduler.step(val_loss)
        after = [g["lr"] for g in optimizer.param_groups]
        if after != before:
            print(f"  ↓ LR reduced to {after[0]:.2e}")

    def record(epoch, tr_loss, va_loss, va_m):
        """Append the epoch to ``hist`` and checkpoint the model if validation Dice improved."""
        nonlocal best_dice
        hist["epoch"].append(epoch)
        hist["train_loss"].append(tr_loss)
        hist["val_loss"].append(va_loss)
        hist["val_dice"].append(va_m["dice"])
        if va_m["dice"] > best_dice:
            best_dice = va_m["dice"]
            torch.save({"model_state": model.state_dict(), "val_metrics": va_m, "cfg": asdict(cfg)}, best_path)
            print(f"  ✓ Saved best (Dice={best_dice:.4f}) → {best_path}")

    # ---- Phase 1: frozen backbone, train the head only ----
    freeze_backbone(model)
    head_params = [p for p in model.parameters() if p.requires_grad]
    if not head_params:
        # No trainable head found (e.g. ResUNet): nothing pretrained to protect, train everything.
        unfreeze_all(model)
        head_params = list(model.parameters())
    head_opt = torch.optim.AdamW(head_params, lr=cfg.lr, weight_decay=cfg.weight_decay)
    head_sch = torch.optim.lr_scheduler.ReduceLROnPlateau(head_opt, mode='min', factor=0.5, patience=3)

    for epoch in range(1, cfg.warmup_epochs_frozen + 1):
        tr_loss, _ = run_epoch(train_loader, head_opt)
        va_loss, va_m = run_epoch(val_loader)
        step_scheduler(head_sch, head_opt, va_loss)
        print(f"[{tag}] Warmup {epoch}/{cfg.warmup_epochs_frozen} | train={tr_loss:.4f} | val={va_loss:.4f} | Dice={va_m['dice']:.4f}")
        record(epoch, tr_loss, va_loss, va_m)

    # ---- Phase 2: fine-tune the whole network ----
    unfreeze_all(model)
    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    sch = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min', factor=0.5, patience=3)

    for epoch in range(cfg.warmup_epochs_frozen + 1, cfg.epochs + 1):
        tr_loss, _ = run_epoch(train_loader, opt)
        va_loss, va_m = run_epoch(val_loader)
        step_scheduler(sch, opt, va_loss)
        print(f"[{tag}] Epoch {epoch}/{cfg.epochs} | train={tr_loss:.4f} | val={va_loss:.4f} "
              f"| Acc={va_m['accuracy']:.4f} P/R/F1={va_m['precision']:.4f}/{va_m['recall']:.4f}/{va_m['f1']:.4f} "
              f"| IoU={va_m['iou']:.4f} | Dice={va_m['dice']:.4f}")
        record(epoch, tr_loss, va_loss, va_m)

    # ---- Curves ----
    def plot_curve(ys, title, ylabel, suffix):
        fig, ax = plt.subplots()
        ax.plot(hist["epoch"], ys)
        ax.set_title(title); ax.set_xlabel("Epoch"); ax.set_ylabel(ylabel); ax.grid(True)
        fig.savefig(os.path.join(cfg.out_dir, f"{cfg.run_name}_{tag}_{suffix}.png"), bbox_inches="tight")
        plt.show()

    plot_curve(hist["train_loss"], f"{tag} — Train Loss", "Loss", "train_loss")
    plot_curve(hist["val_loss"],   f"{tag} — Val Loss",   "Loss", "val_loss")
    plot_curve(hist["val_dice"],   f"{tag} — Val Dice",   "Dice", "val_dice")

    # ---- Test with the best checkpoint (full precision, no gradients) ----
    ckpt = torch.load(best_path, map_location=device)
    model.load_state_dict(ckpt["model_state"])
    test_loss, test_m = run_epoch(test_loader, desc="Test", amp=False)
    return {"tag": tag, "best_val_dice": float(best_dice), "test_loss": float(test_loss), **test_m}


In [10]:
def run_all(cfg):
    """Train and test every candidate architecture on the same patient-level split.

    Writes ``<run_name>_summary.json`` (one result dict per model) and a bar
    chart of test Dice to ``cfg.out_dir``.

    Returns:
        ``(results, best)``. ``results`` is the list of dicts returned by
        ``train_eval_model``. ``best`` is the entry with the highest *validation*
        Dice; selecting on validation rather than test Dice keeps the test split
        an unbiased estimate of the chosen model.
    """
    loaders = make_dataloaders(cfg)

    builders = [
        ("deeplabv3_resnet50",  build_deeplabv3_resnet50),
        ("deeplabv3_resnet101", build_deeplabv3_resnet101),
        ("fcn_resnet50",        build_fcn_resnet50),
        ("lraspp_mobilenetv3",  build_lraspp_mobilenet_v3),
        ("resunet",             lambda: ResUNet(base=64)),
    ]

    results = []
    for tag, builder in builders:
        print(f"\n=== Training {tag} ===")
        model = builder()
        results.append(train_eval_model(model, loaders, cfg, tag))
        # Drop this model before the next one is built, so two networks are
        # never resident at once, and release cached GPU blocks.
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    with open(os.path.join(cfg.out_dir, f"{cfg.run_name}_summary.json"), "w") as f:
        json.dump(results, f, indent=2)

    best = max(results, key=lambda r: r["best_val_dice"])

    print("\n==== Final Comparison (Test) ====")
    for r in results:
        print(r)
    print("\n==== Recommended Best Model (by validation Dice) ====")
    print(best)

    labels = [r["tag"] for r in results]
    scores = [r["dice"] for r in results]
    plt.figure()
    plt.bar(range(len(labels)), scores)
    plt.xticks(range(len(labels)), labels, rotation=15)
    plt.ylabel("Dice")
    plt.title("Model Comparison — Dice (Test)")
    plt.grid(True, axis="y")
    plt.savefig(os.path.join(cfg.out_dir, f"{cfg.run_name}_dice_comparison.png"), bbox_inches="tight")
    plt.show()

    return results, best

In [11]:
@torch.no_grad()
def show_predictions(model, loader, device, max_images: int=3, thr: float=0.5):
    """Plot input, ground truth and thresholded prediction for the first ``max_images`` samples.

    Each sample produces three separate matplotlib figures. The input is
    min-max rescaled per image for display only. Only the first output channel
    of the model is used.
    """
    model.eval()

    def to_logits(net, batch):
        result = net(batch)
        if isinstance(result, dict) and "out" in result:
            return result["out"]
        return result

    def show_panel(title, array, **imshow_kwargs):
        plt.figure()
        plt.title(title)
        plt.imshow(array, **imshow_kwargs)
        plt.axis("off")
        plt.show()

    count = 0
    for imgs, msks, ips, _mps, _pids in loader:
        imgs = imgs.to(device)
        logits = to_logits(model, imgs)
        if logits.shape[1] != 1:
            logits = logits[:, :1, ...]
        preds = (torch.sigmoid(logits).cpu() > thr).float()
        for i in range(imgs.size(0)):
            if count >= max_images:
                return
            rgb = imgs[i].cpu().permute(1, 2, 0).numpy()
            rgb = (rgb - rgb.min()) / (rgb.max() - rgb.min() + 1e-6)
            show_panel(f"Input — {os.path.basename(ips[i])}", rgb)
            show_panel("Ground Truth", msks[i].cpu().squeeze(0).numpy(), cmap="gray")
            show_panel("Prediction", preds[i].squeeze(0).numpy(), cmap="gray")
            count += 1

In [None]:
# Train and evaluate every architecture; `best` is the result entry run_all recommends.
results, best = run_all(cfg)

from pathlib import Path
# Reload the recommended model from the best-validation-Dice checkpoint that
# train_eval_model wrote under this same file-name scheme.
best_tag = best["tag"]
ckpt_path = Path(cfg.out_dir) / f"{cfg.run_name}_{best_tag}_best.pt"
# Tag -> constructor. Must stay in sync with the builder list inside run_all,
# otherwise the rebuilt architecture will not match the checkpoint's state_dict.
builder_map = {
    "deeplabv3_resnet50": build_deeplabv3_resnet50,
    "deeplabv3_resnet101": build_deeplabv3_resnet101,
    "fcn_resnet50": build_fcn_resnet50,
    "lraspp_mobilenetv3": build_lraspp_mobilenet_v3,
    "resunet": lambda: ResUNet(base=64),
}
best_model = builder_map[best_tag]().to(cfg.device)
ckpt = torch.load(ckpt_path, map_location=cfg.device)
best_model.load_state_dict(ckpt["model_state"])
# Rebuild the loaders. list_lgg_pairs sorts its output and split_by_patient is
# seeded with cfg.seed, so for an unchanged data_dir the test split holds the
# same patients as during training.
train_loader, val_loader, test_loader = make_dataloaders(cfg)
show_predictions(best_model, test_loader, cfg.device, max_images=3)