In [None]:
from google.colab import drive
drive.mount('/content/drive')

# ============================================================
# STEP 0 — ENVIRONMENT & LIBRARIES (IMPORTS + VERSION DUMP)
# Cell 1/9 — Run identifiers, artifact paths, and output folders
# ------------------------------------------------------------
from datetime import datetime
from pathlib import Path

# ---- run knobs: later steps read these module-level names ----
DATASET     = "busi"
IMAGE_SIZE  = 256
SEED        = 42
BATCH_SIZE  = 8
EPOCHS      = 10
AMP_ON      = True
MODEL_TAG   = "unetr_model"
CONFIG_PATH = "configs/example.yaml"

# A per-second timestamp suffix gives each run a distinct artifact prefix.
_run_started = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
RUN_NAME = "_".join([MODEL_TAG, DATASET, f"IMG{IMAGE_SIZE}", f"SEED{SEED}", _run_started])

# ---- artifact tree under the mounted Drive folder ----
root = Path("/content/drive/MyDrive/unetr_model_busi_Test")
for sub in ("logs", "checkpoints", "figures", "runs", "summary"):
    root.joinpath(sub).mkdir(parents=True, exist_ok=True)

_summary_dir = root / "summary"
summary_txt_path  = _summary_dir / f"{RUN_NAME}_env.txt"
summary_json_path = _summary_dir / f"{RUN_NAME}_env.json"


Mounted at /content/drive


In [None]:
# ============================================================
# STEP 0 — ENVIRONMENT & LIBRARIES
# Cell 2/9 — Base imports & safe-import helper
# ------------------------------------------------------------
import os, sys, json, time, platform, importlib, random

def _dist_version(module_name: str):
    """Best-effort version lookup from installed-distribution metadata; None if unavailable."""
    try:
        from importlib import metadata
    except ImportError:
        return None
    candidates = [module_name]
    # packages_distributions() maps import names (e.g. "cv2") to distribution
    # names (e.g. "opencv-python"); it only exists on Python >= 3.10.
    mapping = getattr(metadata, "packages_distributions", None)
    if mapping is not None:
        try:
            candidates = list(mapping().get(module_name, [])) + candidates
        except Exception:
            # Deliberate best-effort: a broken metadata scan must not break imports.
            pass
    for dist_name in candidates:
        try:
            return metadata.version(dist_name)
        except Exception:
            continue
    return None

def try_import(name: str):
    """Import module `name` if possible and report its version string.

    Args:
        name: importable module name (e.g. "numpy", "cv2", "PIL").

    Returns:
        (module, version) on success. `version` is the module's ``__version__``
        when present. Otherwise it is the installed distribution's version from
        importlib.metadata, or "unknown" if neither is available.
        On any import failure returns (None, "NOT INSTALLED (<ExceptionType>)").
    """
    try:
        mod = importlib.import_module(name)
        ver = getattr(mod, "__version__", None)
    except Exception as e:  # ImportError, but also e.g. broken native extensions
        return None, f"NOT INSTALLED ({type(e).__name__})"
    if ver is None:
        ver = _dist_version(name) or "unknown"
    return mod, ver

!pip -q install -U monai torchmetrics thop fvcore timm albumentations==1.4.4 psutil pandas


[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/50.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.2/50.2 kB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m91.2/91.2 kB[0m [31m7.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m150.4/150.4 kB[0m [31m10.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.7/2.7 MB[0m [31m41.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m983.2/983.2 kB[0m [31m56.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
# ============================================================
# STEP 0 — ENVIRONMENT & LIBRARIES
# Cell 3/9 — Import common stack (DL, data, viz, utils)
# ------------------------------------------------------------
# Core / data
# Each try_import() call yields (module or None, version string); a missing
# package leaves its module name bound to None instead of raising here.

# Core / data
(numpy, np_ver), (pandas, pd_ver) = map(try_import, ("numpy", "pandas"))

# Deep learning
(
    (torch, torch_ver),
    (torchvision, tv_ver),
    (timm, timm_ver),
    (monai, monai_ver),
    (torchmetrics, tm_ver),
) = map(try_import, ("torch", "torchvision", "timm", "monai", "torchmetrics"))

# Augmentation / image IO / plotting
(albumentations, alb_ver), (cv2, cv2_ver), (PIL, pil_ver), (matplotlib, mpl_ver) = map(
    try_import, ("albumentations", "cv2", "PIL", "matplotlib")
)

# Utilities
(yaml, yaml_ver), (sklearn, sk_ver), (psutil, psutil_ver) = map(
    try_import, ("yaml", "sklearn", "psutil")
)

# FLOP / parameter profilers
(thop, thop_ver), (fvcore, fvcore_ver) = map(try_import, ("thop", "fvcore"))


In [None]:
# ============================================================
# STEP 0 — ENVIRONMENT & LIBRARIES
# Cell 4/9 — Device, CUDA/cuDNN, and GPU VRAM discovery
# ------------------------------------------------------------
# Defaults reported when torch is missing or no GPU is visible.
device         = "cpu"
gpu_name       = "N/A"
total_vram_mb  = "N/A"
total_vram_gb  = "N/A"
cuda_version   = "N/A"
cudnn_version  = "N/A"

if torch is not None:
    cuda_available = torch.cuda.is_available()
    device = "cuda" if cuda_available else "cpu"
    # torch.version.cuda exists on every build but is None on CPU-only wheels,
    # so a getattr default alone would never apply; normalise None -> "N/A".
    cuda_version = getattr(torch.version, "cuda", None) or "N/A"
    try:
        if torch.backends.cudnn.is_available():
            _cudnn_raw = torch.backends.cudnn.version()
            cudnn_version = str(_cudnn_raw) if _cudnn_raw is not None else "N/A"
        else:
            cudnn_version = "N/A"
    except Exception:
        cudnn_version = "N/A"
    if cuda_available:
        # Only device 0 is described; multi-GPU machines report the first card.
        try:
            gpu_name = torch.cuda.get_device_name(0)
            props = torch.cuda.get_device_properties(0)
            total_vram_bytes = getattr(props, "total_memory", 0)
            total_vram_mb = round(total_vram_bytes / (1024**2), 2)
            total_vram_gb = round(total_vram_bytes / (1024**3), 2)
        except Exception:
            gpu_name = "Unknown (query failed)"
            total_vram_mb = "Unknown"
            total_vram_gb = "Unknown"


In [None]:
# ============================================================
# STEP 0 — ENVIRONMENT & LIBRARIES
# Cell 5/9 — Reproducibility (seeds + deterministic flags)
# ------------------------------------------------------------
# Seed every RNG the pipeline touches: Python, NumPy, Torch CPU and all CUDA devices.
random.seed(SEED)
if numpy is not None:
    numpy.random.seed(SEED)

if torch is not None:
    # With deterministic algorithms enabled, CUDA >= 10.2 cuBLAS calls raise a
    # RuntimeError unless this workspace config is set. It must be in the
    # environment before the first cuBLAS call; setdefault keeps a user value.
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
    torch.manual_seed(SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(SEED)
    try:
        torch.use_deterministic_algorithms(True)
    except Exception as e:
        # Report instead of silently continuing without determinism.
        print(f"[WARN] torch.use_deterministic_algorithms(True) failed: {e}")
    try:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    except Exception as e:
        print(f"[WARN] Could not set cuDNN deterministic flags: {e}")


In [None]:
# ============================================================
# STEP 0 — ENVIRONMENT & LIBRARIES
# Cell 6/9 — Assemble environment snapshot dict
# ------------------------------------------------------------
# Snapshot of run knobs, hardware, and library versions (printed and saved below).
_config_path_shown = CONFIG_PATH if Path(CONFIG_PATH).exists() else f"{CONFIG_PATH} (not found)"

_run_section = dict(
    run_name=RUN_NAME,
    datetime=datetime.now().isoformat(timespec="seconds"),
    dataset=DATASET,
    image_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    amp_on=AMP_ON,
    seed=SEED,
    config_path=_config_path_shown,
)

_system_section = dict(
    python=sys.version.split()[0],
    platform=platform.platform(),
    device=device,
    gpu_name=gpu_name,
    gpu_total_vram_mb=total_vram_mb,
    gpu_total_vram_gb=total_vram_gb,
    cuda_version=cuda_version,
    cudnn_version=cudnn_version,
)

# Some keys contain spaces/parentheses, so this section is built from pairs.
_library_section = dict([
    ("torch", torch_ver),
    ("torchvision", tv_ver),
    ("timm", timm_ver),
    ("monai", monai_ver),
    ("torchmetrics", tm_ver),
    ("numpy", np_ver),
    ("pandas", pd_ver),
    ("albumentations", alb_ver),
    ("opencv-python (cv2)", cv2_ver),
    ("Pillow (PIL)", pil_ver),
    ("matplotlib", mpl_ver),
    ("pyyaml", yaml_ver),
    ("scikit-learn", sk_ver),
    ("psutil", psutil_ver),
    ("thop", thop_ver),
    ("fvcore", fvcore_ver),
])

env_info = {"run": _run_section, "system": _system_section, "libraries": _library_section}


In [None]:
# ============================================================
# STEP 0 — ENVIRONMENT & LIBRARIES
# Cell 7/9 — Pretty print snapshot to console
# ------------------------------------------------------------
# Human-readable console dump of env_info; labels are padded to 14 columns.
border = "=" * 70
rule = "-" * 70
run_info = env_info["run"]
sys_info = env_info["system"]

run_rows = [
    ("Run Name", run_info["run_name"]),
    ("Date/Time", run_info["datetime"]),
    ("Dataset", run_info["dataset"]),
    ("Image Size", run_info["image_size"]),
    ("Batch Size", run_info["batch_size"]),
    ("Epochs", run_info["epochs"]),
    ("AMP (mixed precision)", run_info["amp_on"]),
    ("Seed", run_info["seed"]),
    ("Config Path", run_info["config_path"]),
]
system_rows = [
    ("Python", sys_info["python"]),
    ("Platform", sys_info["platform"]),
    ("Device", sys_info["device"]),
    ("GPU", sys_info["gpu_name"]),
    ("GPU VRAM", f"{sys_info['gpu_total_vram_mb']} MB ({sys_info['gpu_total_vram_gb']} GB)"),
    ("CUDA / cuDNN", f"{sys_info['cuda_version']} / {sys_info['cudnn_version']}"),
]

print(border)
print("STEP 0 — ENVIRONMENT & LIBRARIES (IMPORTS + VERSION DUMP)")
print(border)
for label, value in run_rows:
    print(f"{label:<14}: {value}")
print(rule)
for label, value in system_rows:
    print(f"{label:<14}: {value}")
print(rule)
print("Libraries:")
for lib, ver in env_info["libraries"].items():
    print(f"  - {lib:<24} {ver}")
print(border)


STEP 0 — ENVIRONMENT & LIBRARIES (IMPORTS + VERSION DUMP)
Run Name      : unetr_model_busi_IMG256_SEED42_2025-11-04_16-43-23
Date/Time     : 2025-11-04T16:44:36
Dataset       : busi
Image Size    : 256
Batch Size    : 8
Epochs        : 10
AMP (mixed precision): True
Seed          : 42
Config Path   : configs/example.yaml (not found)
----------------------------------------------------------------------
Python        : 3.12.12
Platform      : Linux-6.6.105+-x86_64-with-glibc2.35
Device        : cpu
GPU           : N/A
GPU VRAM      : N/A MB (N/A GB)
CUDA / cuDNN  : 12.6 / 91002
----------------------------------------------------------------------
Libraries:
  - torch                    2.8.0+cu126
  - torchvision              0.23.0+cu126
  - timm                     1.0.21
  - monai                    1.5.1
  - torchmetrics             1.8.2
  - numpy                    2.0.2
  - pandas                   2.3.3
  - albumentations           1.4.4
  - opencv-python (cv2)      4.12.0
  - 

In [None]:
# ============================================================
# STEP 0 — ENVIRONMENT & LIBRARIES
# Cell 8/9 — Save TXT + JSON environment snapshots
# ------------------------------------------------------------
# TXT: one "[SECTION]" header per env_info key, then "  - key: value" lines
# (nested dicts get an extra indentation level).
with open(summary_txt_path, "w", encoding="utf-8") as f:
    f.write(border + "\n")
    f.write("STEP 0 — ENVIRONMENT & LIBRARIES (IMPORTS + VERSION DUMP)\n")
    f.write(border + "\n")
    for section, payload in env_info.items():
        # Section keys are plain str, so .upper() is enough (str has no .UPPER,
        # so the former hasattr(section, 'UPPER') branch could never run).
        f.write(f"[{section.upper()}]\n")
        if isinstance(payload, dict):
            for k, v in payload.items():
                if isinstance(v, dict):
                    f.write(f"  {k}:\n")
                    for kk, vv in v.items():
                        f.write(f"    - {kk}: {vv}\n")
                else:
                    f.write(f"  - {k}: {v}\n")
        f.write("\n")

# JSON: default=str keeps the dump from failing if a non-JSON value
# (e.g. a Path or version object) is ever added to env_info.
with open(summary_json_path, "w", encoding="utf-8") as f:
    json.dump(env_info, f, indent=2, default=str)

print(f"Saved environment snapshots to:\n  • {summary_txt_path}\n  • {summary_json_path}")


Saved environment snapshots to:
  • /content/drive/MyDrive/unetr_model_busi_Test/summary/unetr_model_busi_IMG256_SEED42_2025-11-04_16-43-23_env.txt
  • /content/drive/MyDrive/unetr_model_busi_Test/summary/unetr_model_busi_IMG256_SEED42_2025-11-04_16-43-23_env.json


In [None]:
# ============================================================
# STEP 0 — ENVIRONMENT & LIBRARIES
# Cell 9/9 — Initialize per-run CSV log header
# ------------------------------------------------------------
csv_path = root / "logs" / f"{RUN_NAME}.csv"
_csv_header = "epoch,lr,train_loss,val_loss,train_dice,val_dice,train_iou,val_iou,epoch_time\n"
# Create the file only when absent, so re-running this cell never truncates logged epochs.
if not csv_path.exists():
    csv_path.write_text(_csv_header, encoding="utf-8")
print(f"Initialized log CSV (if new): {csv_path}")


Initialized log CSV (if new): /content/drive/MyDrive/unetr_model_busi_Test/logs/unetr_model_busi_IMG256_SEED42_2025-11-04_16-43-23.csv


In [None]:
# ============================================================
# STEP 1 — DATASET DOWNLOAD (BUSI, 256px) → MedSegBench cache
# Cell 1/6 — Setup: cache dir, env var, and size
# ------------------------------------------------------------
import os
from pathlib import Path

SIZE = 256

# Local (non-Drive) cache directory, exported through MEDSEGBENCH_DIR for later steps.
cache_root = Path("/content/data/MedSegBenchCache")
cache_root.mkdir(parents=True, exist_ok=True)
os.environ.update(MEDSEGBENCH_DIR=str(cache_root))

for _msg in (
    f"[INFO] MEDSEGBENCH_DIR = {cache_root.resolve()}",
    f"[INFO] Target resolution = {SIZE}px",
):
    print(_msg)

!pip -q install medsegbench


[INFO] MEDSEGBENCH_DIR = /content/data/MedSegBenchCache
[INFO] Target resolution = 256px
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.8/154.8 kB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
# ============================================================
# STEP 1 — DATASET DOWNLOAD
# Cell 2/6 — Dataset source details (Zenodo v1 record)
# ------------------------------------------------------------
target_name = f"busi_{SIZE}.npz"
target_path = cache_root.joinpath(target_name)

# Zenodo v1 record hosting the MedSegBench BUSI archives.
url = "https://zenodo.org/records/13358372/files/" + target_name + "?download=1"

# Reference checksum for busi_256.npz; set to "" to skip verification.
expected_md5 = "198aea70968b71adf593b32c41a6e995"

for _label, _value in (
    ("Target file  ", target_name),
    ("Download URL ", url),
    ("Expected MD5 ", expected_md5),
):
    print(f"[INFO] {_label}: {_value}")


[INFO] Target file  : busi_256.npz
[INFO] Download URL : https://zenodo.org/records/13358372/files/busi_256.npz?download=1
[INFO] Expected MD5 : 198aea70968b71adf593b32c41a6e995


In [None]:
# ============================================================
# STEP 1 — DATASET DOWNLOAD
# Cell 3/6 — Helpers (md5sum + download runners)
# ------------------------------------------------------------
import hashlib, subprocess, shutil

def md5sum(path: Path) -> str:
    """Return the hexadecimal MD5 digest of the file at `path`.

    The file is streamed in 8 KiB chunks, so large archives are never fully loaded.
    """
    digest = hashlib.md5()
    with open(path, "rb") as fh:
        while True:
            block = fh.read(8192)
            if not block:
                break
            digest.update(block)
    return digest.hexdigest()

def download_file(url: str, out_path: Path) -> None:
    """Download `url` to `out_path` using curl (preferred) or wget.

    The payload is first written to a sibling ``<name>.part`` file. It is renamed
    onto `out_path` only after the downloader exits successfully. A failed or
    interrupted download therefore never leaves a truncated or empty file at
    `out_path`, which the download step would otherwise treat as already present.

    Args:
        url: source URL (redirects are followed with curl -L).
        out_path: final destination path.

    Raises:
        RuntimeError: neither curl nor wget is on PATH.
        subprocess.CalledProcessError: the downloader exited non-zero
            (the partial file is removed before re-raising).
    """
    out_path = Path(out_path)
    part_path = out_path.with_name(out_path.name + ".part")

    curl = shutil.which("curl")
    if curl:
        print("[INFO] Downloading with curl ...")
        cmd = [curl, "-L", "-f", url, "-o", str(part_path)]
    else:
        wget = shutil.which("wget")
        if not wget:
            raise RuntimeError("Neither curl nor wget is available on PATH.")
        print("[INFO] curl not found; downloading with wget ...")
        # wget -O creates/truncates its output file even when the request fails.
        cmd = [wget, "-O", str(part_path), url]

    try:
        subprocess.run(cmd, check=True)
        part_path.replace(out_path)  # atomic rename within the same directory
    except BaseException:
        # Also covers KeyboardInterrupt; the original exception is re-raised unchanged.
        part_path.unlink(missing_ok=True)
        raise


In [None]:
# ============================================================
# STEP 1 — DATASET DOWNLOAD
# Cell 4/6 — Download (idempotent)
# ------------------------------------------------------------
# Idempotent: an existing file is kept, and the MD5 check in the next cell validates it.
if target_path.exists():
    print(f"[INFO] File already present: {target_path}")
else:
    print(f"[INFO] Downloading to {target_path} ...")
    try:
        download_file(url, target_path)
    except subprocess.CalledProcessError as err:
        raise RuntimeError(f"Downloader failed with return code {err.returncode}.") from err


[INFO] Downloading to /content/data/MedSegBenchCache/busi_256.npz ...
[INFO] Downloading with curl ...


In [None]:
# ============================================================
# STEP 1 — DATASET DOWNLOAD
# Cell 5/6 — Integrity check (MD5)
# ------------------------------------------------------------
got = md5sum(target_path)
print(f"[INFO] MD5 (computed): {got}")

# An empty expected_md5 disables the comparison.
checksum_ok = not expected_md5 or got == expected_md5
if not checksum_ok:
    raise RuntimeError(
        f"MD5 mismatch for {target_name}. Expected {expected_md5}, got {got}.\n"
        "Delete the file and rerun this step to redownload."
    )
print(f"✅ Download + MD5 OK → {target_path}")


[INFO] MD5 (computed): 198aea70968b71adf593b32c41a6e995
✅ Download + MD5 OK → /content/data/MedSegBenchCache/busi_256.npz


In [None]:
# ============================================================
# STEP 1 — DATASET DOWNLOAD
# Cell 6/6 — Ready message
# ------------------------------------------------------------
# Status banner closing STEP 1.
print(
    "[READY] busi (256px) cached in MEDSEGBENCH_DIR.",
    "[NEXT] STEP 2: Reproducibility & Config Lock; STEP 3: load predefined splits and print counts.",
    sep="\n",
)


[READY] busi (256px) cached in MEDSEGBENCH_DIR.
[NEXT] STEP 2: Reproducibility & Config Lock; STEP 3: load predefined splits and print counts.


In [None]:
# ============================================================
# STEP 2 — REPRODUCIBILITY & CONFIG LOCK
# Cell 1/5 — Resolve run knobs, paths, and dataset file
# ------------------------------------------------------------
import os
from pathlib import Path

# Re-resolve the run knobs so this step also works when STEP 0 was skipped:
# values already defined in the kernel win, and the literals below are fallbacks.
try: root
except NameError: root = Path(".")

DATASET    = globals().get("DATASET", "busi")
IMAGE_SIZE = int(globals().get("IMAGE_SIZE", 256))
SEED       = int(globals().get("SEED", 42))
BATCH_SIZE = int(globals().get("BATCH_SIZE", 8))
EPOCHS     = int(globals().get("EPOCHS", 10))
AMP_ON     = bool(globals().get("AMP_ON", True))
# Fallback tag matches STEP 0's MODEL_TAG. The previous "TransUNetLiteTiny_model"
# default looked like a copy-paste leftover and would mislabel artifacts and config.
MODEL_TAG  = globals().get("MODEL_TAG", "unetr_model")
RUN_NAME   = globals().get("RUN_NAME", f"{MODEL_TAG}_{DATASET}_IMG{IMAGE_SIZE}_SEED{SEED}")
CONFIG_PATH = globals().get("CONFIG_PATH", "configs/example.yaml")

# Resolution of the cached NPZ: STEP 1's SIZE if set, else the network input size.
RESOLUTION = int(globals().get("SIZE", IMAGE_SIZE))

msb_dir = Path(os.environ.get("MEDSEGBENCH_DIR", os.path.expanduser("~/.medsegbench")))
busi_file = msb_dir / f"busi_{RESOLUTION}.npz"

print(f"[INFO] Artifacts root        : {root.resolve()}")
print(f"[INFO] MEDSEGBENCH_DIR       : {msb_dir.resolve()}")
print(f"[INFO] Expected busi file    : {busi_file}")
print(f"[INFO] Run                   : {RUN_NAME}")
print(f"[INFO] Model tag             : {MODEL_TAG}")
print(f"[INFO] Seed / ImageSize      : {SEED} / {IMAGE_SIZE}")
print(f"[INFO] Batch / Epochs / AMP  : {BATCH_SIZE} / {EPOCHS} / {AMP_ON}")


[INFO] Artifacts root        : /content/drive/MyDrive/unetr_model_busi_Test
[INFO] MEDSEGBENCH_DIR       : /content/data/MedSegBenchCache
[INFO] Expected busi file    : /content/data/MedSegBenchCache/busi_256.npz
[INFO] Run                   : unetr_model_busi_IMG256_SEED42_2025-11-04_16-43-23
[INFO] Model tag             : unetr_model
[INFO] Seed / ImageSize      : 42 / 256
[INFO] Batch / Epochs / AMP  : 8 / 10 / True


In [None]:
# ============================================================
# STEP 2 — REPRODUCIBILITY & CONFIG LOCK
# Cell 2/5 — Seed + deterministic flags (re-assert)
# ------------------------------------------------------------
import random
import numpy as np
import torch

def set_global_seed(seed: int):
    """Seed Python, NumPy and Torch (CPU + all CUDA devices) and request deterministic kernels.

    Args:
        seed: value passed to every RNG.

    Side effects:
        Sets CUBLAS_WORKSPACE_CONFIG if it is not already set, enables
        torch.use_deterministic_algorithms(True), and puts cuDNN into
        deterministic, non-benchmark mode. If a flag cannot be applied, a
        [WARN] line is printed instead of silently continuing.
    """
    import os  # local import keeps the helper self-contained

    # With deterministic algorithms on, CUDA >= 10.2 cuBLAS calls raise a
    # RuntimeError unless this workspace config is set; setdefault keeps a user value.
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
    if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)
    try:
        torch.use_deterministic_algorithms(True)
    except Exception as e:
        print(f"[WARN] torch.use_deterministic_algorithms(True) failed: {e}")
    try:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    except Exception as e:
        print(f"[WARN] Could not set cuDNN deterministic flags: {e}")

set_global_seed(seed=SEED)
print("[OK] Seeds set and deterministic flags applied (seed={}).".format(SEED))


[OK] Seeds set and deterministic flags applied (seed=42).


In [None]:
# ============================================================
# STEP 2 — REPRODUCIBILITY & CONFIG LOCK
# Cell 3/5 — Build default config (if missing) and load it
# ------------------------------------------------------------
import yaml
# Make sure <root>/configs exists for the generated default config.
(root / "configs").mkdir(parents=True, exist_ok=True)

# Default experiment config, rebuilt from the current run knobs on every run.
# It is dumped with sort_keys=False, so key order here is the order in the YAML.
default_cfg = {
    "run": {
        "run_name": RUN_NAME,
        "seed": SEED,
        "amp_on": AMP_ON,
    },
    "data": {
        "dataset": DATASET,
        "resolution": RESOLUTION,
        "medsegbench_dir": str(msb_dir),
        "predefined_splits": True,
    },
    "train": {
        "image_size": IMAGE_SIZE,
        "batch_size": BATCH_SIZE,
        "epochs": EPOCHS,
        "num_workers": 4,
        "optimizer": {"name": "adamw", "lr": 3e-4, "weight_decay": 1e-4},
        "scheduler": {"name": "cosine", "warmup_epochs": 5},
        "early_stopping": {"monitor": "val_dice", "patience": 20},
        "mixed_precision": AMP_ON,
    },
    "augment": {
        "geometric": {"flip": True, "rotate": True, "scale": True, "elastic": False},
        "appearance": {"brightness_contrast": True, "blur_noise": True},
        "probabilities": {
            "flip": 0.5,
            "rotate": 0.3,
            "scale": 0.3,
            "brightness_contrast": 0.3,
            "blur_noise": 0.2,
        },
    },
    "loss": {"primary": "dice_bce", "weights": {"dice": 0.7, "bce": 0.3}},
    "metrics": {"threshold": 0.5, "report": ["dice", "iou"]},
    "logging": {
        "artifacts_root": str(root.resolve()),
        "print_per_epoch_fields": [
            "epoch", "lr", "train_loss", "val_loss",
            "train_dice", "val_dice", "train_iou", "val_iou", "epoch_time",
        ],
        "save_csv_per_epoch": True,
        "save_best_by": "val_dice",
    },
    "model": {"name": MODEL_TAG, "scale": "auto", "params": {}},
}

user_config_exists = Path(CONFIG_PATH).exists()
if user_config_exists:
    # A user-supplied YAML takes precedence and is loaded as-is.
    cfg_path = Path(CONFIG_PATH)
else:
    # No user YAML: (re)write the generated default and load that instead.
    cfg_path = root / "configs" / "default_busi.yaml"
    with open(cfg_path, "w", encoding="utf-8") as f:
        yaml.safe_dump(default_cfg, f, sort_keys=False)
    print(f"[INFO] Created default config at: {cfg_path.resolve()}")

with open(cfg_path, "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

print(f"[OK] Loaded config from: {cfg_path.resolve()}")
print(f"[INFO] Config run_name: {cfg['run'].get('run_name')}")
print(f"[INFO] Config run_name: {cfg['run'].get('run_name')}")


[INFO] Created default config at: /content/drive/MyDrive/unetr_model_busi_Test/configs/default_busi.yaml
[OK] Loaded config from: /content/drive/MyDrive/unetr_model_busi_Test/configs/default_busi.yaml
[INFO] Config run_name: unetr_model_busi_IMG256_SEED42_2025-11-04_16-43-23


In [None]:
# ============================================================
# STEP 2 — REPRODUCIBILITY & CONFIG LOCK
# Cell 4/5 — Sanity checks: dataset presence & key fields
# ------------------------------------------------------------
# Collect non-fatal problems; they are reported as warnings, not raised.
problems = []

if not busi_file.exists():
    problems.append(f"Missing dataset cache file: {busi_file}")

# cfg comes from yaml.safe_load: an empty file yields None, and an empty
# section yields None too, so check for mappings before using `in`.
cfg_is_mapping = isinstance(cfg, dict)
if not cfg_is_mapping:
    problems.append(f"Config is not a mapping (got {type(cfg).__name__})")

required_keys = [
    ("run", "seed"), ("data", "medsegbench_dir"), ("data", "resolution"),
    ("train", "batch_size"), ("train", "epochs"),
    ("loss", "primary"), ("metrics", "threshold"), ("model", "name"),
]
for sect, key in required_keys:
    section = cfg.get(sect) if cfg_is_mapping else None
    if not isinstance(section, dict) or key not in section:
        problems.append(f"Config missing: {sect}.{key}")

if problems:
    print("[WARN] Sanity check issues:")
    for p in problems: print(" -", p)
else:
    print("[OK] Dataset file present and config has required keys.")

def _cfg_value(sect, key, default=None):
    """Return cfg[sect][key], or `default` when cfg/section/key is missing or not a mapping."""
    section = cfg.get(sect) if isinstance(cfg, dict) else None
    return section.get(key, default) if isinstance(section, dict) else default

# Echo the effective settings. Missing values print as None instead of raising
# a KeyError right after the warnings above.
print(f"[ECHO] Using dataset cache: {busi_file}")
print(f"[ECHO] MEDSEGBENCH_DIR   : {msb_dir}")
print(f"[ECHO] Model name        : {_cfg_value('model', 'name')}")
print(f"[ECHO] Loss              : {_cfg_value('loss', 'primary')} (weights={_cfg_value('loss', 'weights')})")
print(f"[ECHO] Metrics threshold : {_cfg_value('metrics', 'threshold')}")


[OK] Dataset file present and config has required keys.
[ECHO] Using dataset cache: /content/data/MedSegBenchCache/busi_256.npz
[ECHO] MEDSEGBENCH_DIR   : /content/data/MedSegBenchCache
[ECHO] Model name        : unetr_model
[ECHO] Loss              : dice_bce (weights={'dice': 0.7, 'bce': 0.3})
[ECHO] Metrics threshold : 0.5


In [None]:
# ============================================================
# STEP 2 — REPRODUCIBILITY & CONFIG LOCK
# Cell 5/5 — Snapshot config for this run
# ------------------------------------------------------------
# Freeze the exact config used by this run next to the other per-run summaries.
summary_root = root / "summary"
summary_root.mkdir(parents=True, exist_ok=True)
cfg_snapshot = summary_root / f"{RUN_NAME}_config.yaml"
cfg_snapshot.write_text(yaml.safe_dump(cfg, sort_keys=False), encoding="utf-8")

print(f"[OK] Saved config snapshot to: {cfg_snapshot.resolve()}")
print("[NEXT] STEP 3 will load MedSegBench predefined splits and print sample counts per set (no re-splitting).")


[OK] Saved config snapshot to: /content/drive/MyDrive/unetr_model_busi_Test/summary/unetr_model_busi_IMG256_SEED42_2025-11-04_16-43-23_config.yaml
[NEXT] STEP 3 will load MedSegBench predefined splits and print sample counts per set (no re-splitting).


In [None]:
# ============================================================
# STEP 3 — LOAD PREDEFINED SPLITS & PRINT COUNTS (NO RE-SPLIT)
# Cell 1/4 — Resolve paths and open the cached NPZ
# ------------------------------------------------------------
import numpy as np

# Re-derive the cache location so this step can run on its own.
_default_msb = os.path.expanduser("~/.medsegbench")
msb_dir = Path(os.environ.get("MEDSEGBENCH_DIR", _default_msb))
_res_fallback = globals().get("SIZE", 256)
RESOLUTION = int(globals().get("RESOLUTION", _res_fallback))
busi_file = msb_dir / f"busi_{RESOLUTION}.npz"
assert busi_file.exists(), f"Expected dataset file not found: {busi_file}"

# NOTE(review): allow_pickle=True lets np.load unpickle object arrays from a
# downloaded file. The STEP 1 MD5 check mitigates tampering, but if every array
# is plain numeric, allow_pickle=False would be safer — confirm.
npz = np.load(busi_file, allow_pickle=True)
keys = list(npz.keys())

_more = "..." if len(keys) > 12 else ""
print(f"[INFO] Loaded: {busi_file}")
print(f"[ECHO] MEDSEGBENCH_DIR: {msb_dir.resolve()}")
print(f"[INFO] NPZ keys ({len(keys)}): {keys[:12]}{_more}")


[INFO] Loaded: /content/data/MedSegBenchCache/busi_256.npz
[ECHO] MEDSEGBENCH_DIR: /content/data/MedSegBenchCache
[INFO] NPZ keys (18): ['train_images_C1', 'train_label_C1', 'train_images_C2', 'train_label_C2', 'test_images_C1', 'test_label_C1', 'test_images_C2', 'test_label_C2', 'val_images_C1', 'val_label_C1', 'val_images_C2', 'val_label_C2']...


In [None]:
# ============================================================
# STEP 3 — LOAD PREDEFINED SPLITS & PRINT COUNTS (NO RE-SPLIT)
# Cell 2/4 — Infer split format (supports *_label/_labels)
# ------------------------------------------------------------
def _as_list(x):
    if isinstance(x, np.ndarray): x = x.tolist()
    return list(x) if isinstance(x, (list, tuple)) else [x]

def infer_splits(npz_obj):
    k = set(npz_obj.keys())

    # Case A: per-split arrays (preferred)
    def _find_lbl_key(split):
        for suffix in ("masks","mask","labels","label"):
            cand = f"{split}_{suffix}"
            if cand in k: return cand
        return None

    have_all = True
    meta = {}
    for s in ("train","val","test"):
        ik = f"{s}_images"; lk = _find_lbl_key(s)
        if ik in k and lk in k and len(npz_obj[ik]) == len(npz_obj[lk]):
            meta[s] = {"n": len(npz_obj[ik]), "img_key": ik, "lbl_key": lk}
        else:
            have_all = False; break
    if have_all:
        counts = {s: meta[s]["n"] for s in meta}
        idx    = {s: list(range(meta[s]["n"])) for s in meta}
        return counts, idx, "A(images+labels)"

    # Case B: global arrays + explicit indices
    for tri, vai, tei in [("train_idx","val_idx","test_idx"),
                          ("train_indices","val_indices","test_indices"),
                          ("split_train","split_val","split_test")]:
        if tri in k and vai in k and tei in k:
            tr, va, te = _as_list(npz_obj[tri]), _as_list(npz_obj[vai]), _as_list(npz_obj[tei])
            counts = {"train": len(tr), "val": len(va), "test": len(te)}
            idx    = {"train": tr, "val": va, "test": te}
            return counts, idx, "B(indices)"

    raise RuntimeError("Could not infer predefined splits (need per-split arrays or *_idx lists).")

counts, split_idx, pattern = infer_splits(npz)
print("[OK] Split pattern detected: Case " + pattern)


[OK] Split pattern detected: Case A(images+labels)


In [None]:
# ============================================================
# STEP 3 — LOAD PREDEFINED SPLITS (SPEED-UP, OPTIONAL)
# Cell 2.5/4 — Materialize arrays to RAM and rebind `npz`
# ------------------------------------------------------------
# Pull every array out of the NpzFile once and keep it in a plain dict.
# NpzFile decompresses a member on *every* npz[key] access, so caching the
# loaded ndarrays is what saves time. The former obj[:] only produced a view of
# an already-loaded array, so no copy or slicing is needed.
npz_ram = {k: npz[k] for k in keys}

# Release the zip handle. On a re-run `npz` is already a dict (no close()),
# so look the method up instead of swallowing every exception.
_npz_close = getattr(npz, "close", None)
if callable(_npz_close):
    _npz_close()
npz = npz_ram

def _shape(x): return getattr(x, "shape", None)
print("[SPEED] NPZ materialized to RAM. Example shapes:")
for probe in ["train_images", "train_label", "train_masks", "val_images", "val_label", "test_images", "test_label"]:
    if probe in npz:
        print(f"  • {probe}: {_shape(npz[probe])}")


[SPEED] NPZ materialized to RAM. Example shapes:
  • train_images: (452, 256, 256)
  • train_label: (452, 256, 256)
  • val_images: (64, 256, 256)
  • val_label: (64, 256, 256)
  • test_images: (131, 256, 256)
  • test_label: (131, 256, 256)


In [None]:
# ============================================================
# STEP 3 — LOAD PREDEFINED SPLITS & PRINT COUNTS (NO RE-SPLIT)
# Cell 3/4 — Print counts per set
# ------------------------------------------------------------
print("[COUNTS] Samples per split (predefined by MedSegBench)")
for _label, _split in (("Train ", "train"), ("Val   ", "val"), ("Test  ", "test")):
    print(f"  • {_label}: {counts[_split]}")


[COUNTS] Samples per split (predefined by MedSegBench)
  • Train : 452
  • Val   : 64
  • Test  : 131


In [None]:
# ============================================================
# STEP 3 — LOAD PREDEFINED SPLITS & PRINT COUNTS (NO RE-SPLIT)
# Cell 4/4 — Save IDs to disk for reproducibility
# ------------------------------------------------------------
# Split-ID files live under <root>/summary (the current directory if root is undefined).
_artifacts_root = Path(globals().get("root", Path(".")))
summary_dir = _artifacts_root.joinpath("summary")
summary_dir.mkdir(parents=True, exist_ok=True)

def write_list(path: Path, arr):
    """Write each item of `arr` to `path` (UTF-8), one str() form per line, newline-terminated."""
    lines = (f"{item}\n" for item in arr)
    with open(path, "w", encoding="utf-8") as out:
        out.writelines(lines)

# One text file per split, e.g. busi_256_train_ids.txt, holding that split's sample indices.
_split_files = {
    _split: summary_dir / f"busi_{RESOLUTION}_{_split}_ids.txt"
    for _split in ("train", "val", "test")
}
train_ids_path, val_ids_path, test_ids_path = (
    _split_files["train"], _split_files["val"], _split_files["test"]
)

for _split, _path in _split_files.items():
    write_list(_path, split_idx[_split])

print("[OK] Saved split ID lists:")
for _path in (train_ids_path, val_ids_path, test_ids_path):
    print(f"  • {_path}")
print("[NEXT] STEP 4 will cover preprocessing pipeline (resize/normalize) and identical augmentations.")


[OK] Saved split ID lists:
  • /content/drive/MyDrive/unetr_model_busi_Test/summary/busi_256_train_ids.txt
  • /content/drive/MyDrive/unetr_model_busi_Test/summary/busi_256_val_ids.txt
  • /content/drive/MyDrive/unetr_model_busi_Test/summary/busi_256_test_ids.txt
[NEXT] STEP 4 will cover preprocessing pipeline (resize/normalize) and identical augmentations.


In [None]:
# ============================================================
# STEP 4 — PREPROCESSING & AUGMENTATIONS (IDENTICAL POLICY)
# Cell 1/5 — Imports, constants, and NPZ reload
# ------------------------------------------------------------
import numpy as np, torch
from pathlib import Path

# Re-resolve sizes so STEP 4 can run on its own; kernel values win over fallbacks.
_knob = globals().get
RESOLUTION = int(_knob("RESOLUTION", _knob("SIZE", 256)))
IMAGE_SIZE = int(_knob("IMAGE_SIZE", RESOLUTION))
BATCH_SIZE = int(_knob("BATCH_SIZE", 8))

# ImageNet per-channel mean/std (3 channels).
# NOTE(review): the BUSI arrays printed in STEP 3 are (N, 256, 256), i.e. they
# have no channel axis — confirm the Dataset expands them to 3 channels before A.Normalize.
NORM_MEAN = (0.485, 0.456, 0.406)
NORM_STD  = (0.229, 0.224, 0.225)

msb_dir   = Path(os.environ.get("MEDSEGBENCH_DIR", os.path.expanduser("~/.medsegbench")))
busi_file = msb_dir.joinpath(f"busi_{RESOLUTION}.npz")
assert busi_file.exists(), f"Expected dataset file not found: {busi_file}"

# Handle used by the split-inference fallback below.
# NOTE(review): this handle is not closed in the visible cells.
npz_l = np.load(busi_file, allow_pickle=True)

def _as_list(x):
    if isinstance(x, np.ndarray): x = x.tolist()
    return list(x) if isinstance(x, (list, tuple)) else [x]

def _infer_splits(npz_obj):
    k = set(npz_obj.keys())
    def _find_lbl_key(split):
        for suffix in ("masks","mask","labels","label"):
            cand = f"{split}_{suffix}"
            if cand in k: return cand
        return None
    have_all = True
    meta = {}
    for s in ("train","val","test"):
        ik = f"{s}_images"; lk = _find_lbl_key(s)
        if ik in k and lk in k:
            n = len(npz_obj[ik])
            if n != len(npz_obj[lk]): raise RuntimeError(f"{s}: images != labels length")
            meta[s] = {"n": n, "img_key": ik, "lbl_key": lk}
        else:
            have_all = False; break
    if have_all:
        return {s: meta[s]["n"] for s in meta}, {s: list(range(meta[s]["n"])) for s in meta}
    for tri, vai, tei in [("train_idx","val_idx","test_idx"),
                          ("train_indices","val_indices","test_indices"),
                          ("split_train","split_val","split_test")]:
        if tri in k and vai in k and tei in k:
            tr = _as_list(npz_obj[tri]); va = _as_list(npz_obj[vai]); te = _as_list(npz_obj[tei])
            return {"train": len(tr), "val": len(va), "test": len(te)}, {"train": tr, "val": va, "test": te}
    raise RuntimeError("Could not re-infer splits; ensure STEP 3 ran successfully.")

# Reuse STEP 3's split bookkeeping when it is still in the kernel; otherwise recompute from the NPZ.
_have_step3_splits = all(name in globals() for name in ("counts", "split_idx"))
_counts, _split_idx = (counts, split_idx) if _have_step3_splits else _infer_splits(npz_l)

print(f"[INFO] Using NPZ: {busi_file}")
print(f"[INFO] Image size policy: RESOLUTION={RESOLUTION} → NETWORK INPUT={IMAGE_SIZE}")
print(f"[COUNTS] train={_counts['train']}  val={_counts['val']}  test={_counts['test']}")


[INFO] Using NPZ: /content/data/MedSegBenchCache/busi_256.npz
[INFO] Image size policy: RESOLUTION=256 → NETWORK INPUT=256
[COUNTS] train=452  val=64  test=131


In [None]:
# ============================================================
# STEP 4 — PREPROCESSING & AUGMENTATIONS
# Cell 2/5 — Albumentations transforms (train/val/test)
# ------------------------------------------------------------
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Resize only when the network input size differs from the NPZ's stored resolution.
resize_ops = (
    [A.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE, interpolation=1)]  # 1 = cv2.INTER_LINEAR (bilinear)
    if IMAGE_SIZE != RESOLUTION
    else []
)

# Training-only augmentations (border_mode=0 is cv2.BORDER_CONSTANT).
_train_augs = [
    A.HorizontalFlip(p=0.5),
    A.ShiftScaleRotate(shift_limit=0.02, scale_limit=0.15, rotate_limit=15, border_mode=0, p=0.3),
    A.RandomBrightnessContrast(p=0.3),
    A.GaussianBlur(blur_limit=(3,5), p=0.15),
    A.GaussNoise(var_limit=(5.0, 20.0), p=0.15),
]

# Every pipeline ends with ImageNet normalization and HWC -> CHW tensor conversion.
train_tf = A.Compose(
    resize_ops
    + _train_augs
    + [A.Normalize(mean=NORM_MEAN, std=NORM_STD), ToTensorV2(transpose_mask=True)]
)
val_tf = A.Compose(
    resize_ops
    + [A.Normalize(mean=NORM_MEAN, std=NORM_STD), ToTensorV2(transpose_mask=True)]
)

# Test uses exactly the validation pipeline (same object).
test_tf = val_tf
print("[OK] Transforms configured (train/val/test).")


[OK] Transforms configured (train/val/test).


In [None]:
# ============================================================
# STEP 4 — PREPROCESSING & AUGMENTATIONS
# Cell 3/5 — Dataset that reads the NPZ lazily (np.load cannot memory-map .npz archives; ISICNPZDataset-style)
# ------------------------------------------------------------
import torch
from torch.utils.data import Dataset

def _label_key_for(split, npz_obj_or_keys):
    """Return the first existing label key for `split`.

    Candidates are tried in order: `{split}_masks`, `{split}_mask`,
    `{split}_labels`, `{split}_label`. `npz_obj_or_keys` is either a
    set/list/tuple of key names or any object exposing `.keys()`.

    Raises:
        KeyError: if none of the candidate keys is present.
    """
    k = set(npz_obj_or_keys if isinstance(npz_obj_or_keys, (set,list,tuple)) else npz_obj_or_keys.keys())
    for suf in ("masks","mask","labels","label"):
        cand = f"{split}_{suf}"
        if cand in k: return cand
    raise KeyError(f"No label key found for split={split}.")


class ISICNPZDataset(Dataset):
    """(image, mask) segmentation dataset over one split stored in an NPZ file.

    Reads `{split}_images` plus the label array chosen by `_label_key_for`,
    exposing only the rows listed in `indices`.

    `np.load` cannot memory-map .npz archives (`mmap_mode` is ignored for them),
    and every `archive[key]` access reads the WHOLE array again. The split's
    two arrays are therefore loaded once per process on first use and cached,
    instead of re-reading the full split for every sample.

    __getitem__ returns (img_t, msk_t): img_t is the (C, H, W) image tensor;
    msk_t is a float32 mask with a leading channel axis, for both the
    transform and the no-transform path.
    """

    def __init__(self, npz_path, split: str, indices, transform=None):
        super().__init__()
        self.path = str(npz_path)
        self.img_key = f"{split}_images"
        # `with` closes the archive even if a key lookup below raises.
        # NOTE(review): allow_pickle=True allows unpickling object arrays —
        # only use with trusted NPZ files.
        with np.load(self.path, allow_pickle=True) as _peek:
            self.lbl_key = _label_key_for(split, _peek.files)
            self.length = len(_peek[self.img_key])
            n_labels = len(_peek[self.lbl_key])
        assert self.length == n_labels, "Images/labels length mismatch."

        self.split = split
        self.indices = list(indices)
        self.transform = transform
        self._npz = None   # no archive handle is kept open; attribute kept for compatibility
        self._imgs = None  # cached `{split}_images` array, loaded lazily
        self._lbls = None  # cached label array, loaded lazily

    def _ensure_open(self):
        """Load and cache this split's image and label arrays (once per process)."""
        if self._imgs is None or self._lbls is None:
            with np.load(self.path, allow_pickle=True) as z:
                self._imgs = z[self.img_key]
                self._lbls = z[self.lbl_key]

    def __getstate__(self):
        # Do not ship cached arrays when pickled (e.g. spawn-started DataLoader
        # workers); each process reloads them lazily in `_ensure_open`.
        state = self.__dict__.copy()
        state["_npz"] = None
        state["_imgs"] = None
        state["_lbls"] = None
        return state

    def __len__(self): return len(self.indices)

    def __getitem__(self, idx):
        self._ensure_open()
        i = self.indices[idx]
        img = self._imgs[i]  # HxW or HxWx3
        msk = self._lbls[i]  # HxW

        if img.ndim == 2: img = np.stack([img, img, img], axis=-1)
        # Masks with values above 1 (e.g. stored as 0/255) are binarized to {0, 1}.
        if msk.max() > 1: msk = (msk > 127).astype(np.uint8)

        if self.transform is not None:
            out = self.transform(image=img, mask=msk)
            img_t = out["image"]
            msk_t = out["mask"].unsqueeze(0) if out["mask"].ndim == 2 else out["mask"]
            # The tensor conversion presumably keeps the mask's integer dtype;
            # cast so both code paths return float32 masks.
            if torch.is_tensor(msk_t): msk_t = msk_t.float()
        else:
            # float32 constants keep the result float32 (float64 arrays here
            # would not match float32 model weights).
            mean = np.asarray(NORM_MEAN, dtype=np.float32)
            std = np.asarray(NORM_STD, dtype=np.float32)
            img_f = (img.astype(np.float32) / 255.0 - mean) / std
            img_t = torch.from_numpy(img_f).permute(2,0,1).contiguous()
            msk_t = torch.from_numpy(msk.astype(np.float32)).unsqueeze(0)

        return img_t, msk_t

    def __del__(self):
        try:
            if self._npz is not None: self._npz.close()
        except Exception: pass


In [None]:
# ============================================================
# STEP 4 — PREPROCESSING & AUGMENTATIONS
# Cell 4/5 — DataLoaders with safe settings
# ------------------------------------------------------------
from torch.utils.data import DataLoader
import torch, os

train_ds = ISICNPZDataset(busi_file, "train", _split_idx["train"], transform=train_tf)
val_ds   = ISICNPZDataset(busi_file, "val",   _split_idx["val"],   transform=val_tf)
test_ds  = ISICNPZDataset(busi_file, "test",  _split_idx["test"],  transform=test_tf)

num_workers = 2
pin = torch.cuda.is_available()


def _make_loader(ds, batch_size, shuffle, drop_last, workers, pin_memory, generator=None):
    """Build a DataLoader, passing worker-only options only when workers > 0.

    PyTorch rejects `prefetch_factor` / `persistent_workers=True` when
    `num_workers == 0`, so setting `num_workers = 0` for debugging would
    otherwise crash this cell.
    """
    kwargs = dict(batch_size=batch_size, shuffle=shuffle, num_workers=workers,
                  pin_memory=pin_memory, drop_last=drop_last, generator=generator)
    if workers > 0:
        kwargs.update(prefetch_factor=2, persistent_workers=True)
    return DataLoader(ds, **kwargs)


# Seeded generator: reproducible shuffle order, and the DataLoader derives
# worker base seeds (augmentation RNG) from it.
_train_gen = torch.Generator().manual_seed(int(SEED))

train_loader = _make_loader(train_ds, BATCH_SIZE, shuffle=True,  drop_last=True,
                            workers=num_workers, pin_memory=pin, generator=_train_gen)
val_loader   = _make_loader(val_ds,   BATCH_SIZE, shuffle=False, drop_last=False,
                            workers=num_workers, pin_memory=pin)
test_loader  = _make_loader(test_ds,  1,          shuffle=False, drop_last=False,
                            workers=num_workers, pin_memory=pin)

xb, yb = next(iter(train_loader))
print(f"[OK] Train batch shapes: images={tuple(xb.shape)} masks={tuple(yb.shape)}")
print(f"[INFO] num_workers={num_workers}, pin_memory={pin}, batch_size={BATCH_SIZE}")


[OK] Train batch shapes: images=(8, 3, 256, 256) masks=(8, 1, 256, 256)
[INFO] num_workers=2, pin_memory=False, batch_size=8


In [None]:
# ============================================================
# STEP 4 — PREPROCESSING & AUGMENTATIONS
# Cell 5/5 — Policy echo (for the paper/log)
# ------------------------------------------------------------
# Human-readable summary of the preprocessing/augmentation policy (for logs/paper).
_policy_lines = [
    "[POLICY] Preprocessing/Normalization",
    f"  • Resize to: {IMAGE_SIZE}x{IMAGE_SIZE} (if different from NPZ {RESOLUTION})",
    f"  • Normalize (ImageNet): mean={NORM_MEAN}, std={NORM_STD}",
    "[POLICY] Train augmentations",
    "  • HorizontalFlip p=0.5",
    "  • ShiftScaleRotate (±2% shift, ±15% scale, ±15° rotate) p=0.3",
    "  • Brightness/Contrast p=0.3",
    "  • GaussianBlur p=0.15, GaussNoise p=0.15",
    "[POLICY] Val/Test: no augmentations (only resize + normalize)",
    "[READY] DataLoaders prepared. Next: STEP 5 (Data sanity visuals).",
]
print("\n".join(_policy_lines))


[POLICY] Preprocessing/Normalization
  • Resize to: 256x256 (if different from NPZ 256)
  • Normalize (ImageNet): mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
[POLICY] Train augmentations
  • HorizontalFlip p=0.5
  • ShiftScaleRotate (±2% shift, ±15% scale, ±15° rotate) p=0.3
  • Brightness/Contrast p=0.3
  • GaussianBlur p=0.15, GaussNoise p=0.15
[POLICY] Val/Test: no augmentations (only resize + normalize)
[READY] DataLoaders prepared. Next: STEP 5 (Data sanity visuals).


In [None]:
# ============================================================
# STEP 6 — CPU EVALUATION
# Cell 1/4
# ------------------------------------------------------------
# Goal:
#   • Use a pretrained ViT-B/16 encoder (timm) and a U-Net decoder.
#   • Keep helpers compatible with our pipeline (ConvBNReLU, UpBlock).
# ============================================================
import torch, torch.nn as nn, torch.nn.functional as F
import timm

# Reuse helpers if already defined elsewhere
if "ConvBNReLU" not in globals():
    class ConvBNReLU(nn.Module):
        """Conv2d (no bias) -> BatchNorm2d -> in-place ReLU; 3x3/stride 1/pad 1 by default."""
        def __init__(self, in_ch, out_ch, k=3, s=1, p=1):
            super().__init__()
            # Attribute names (conv/bn/act) define the state_dict keys — keep them.
            self.conv = nn.Conv2d(in_ch, out_ch, k, s, p, bias=False)
            self.bn   = nn.BatchNorm2d(out_ch)
            self.act  = nn.ReLU(inplace=True)

        def forward(self, x):
            h = self.conv(x)
            h = self.bn(h)
            return self.act(h)

if "UpBlock" not in globals():
    class UpBlock(nn.Module):
        """Resize `x` to the skip's spatial size, concatenate on channels, then two ConvBNReLU."""
        def __init__(self, in_ch, skip_ch, out_ch):
            super().__init__()
            self.conv1 = ConvBNReLU(in_ch + skip_ch, out_ch)
            self.conv2 = ConvBNReLU(out_ch, out_ch)

        def forward(self, x, skip):
            target_hw = skip.shape[-2:]
            upsampled = F.interpolate(x, size=target_hw, mode="bilinear", align_corners=False)
            fused = torch.cat((upsampled, skip), dim=1)
            fused = self.conv1(fused)
            return self.conv2(fused)

# ============================================================
# STEP 6D — UNETR (paper-style) — Model definition
# Cell 2/4
# ------------------------------------------------------------
# Design notes:
#   • We tap 4 transformer depths (≈[3,6,9,12]) via hooks → tokens.
#   • Convert tokens to 1/16 feature maps, then 1×1 project to widths:
#         f0=96, f1=192, f2=384, f3=768  (all at 1/16 stride; f2 is projected but not consumed by the decoder)
#   • Decoder:
#         l3: 768→384 (deep stream)
#         up2: (384 ⊕ 192) → 192   @ 1/16 → up to 1/8
#         up1: (192 ⊕  96) →  96   @ 1/8  → refine to 1/4 → head
#   • This fixes channel mismatches and keeps shapes explicit.
# ============================================================
class UNETR_PaperStyle(nn.Module):
    """UNETR-style binary segmenter: timm ViT-B/16 encoder + convolutional decoder.

    Outputs of transformer blocks `depths` (1-based) are captured with forward
    hooks, reshaped to (B, C, H/16, W/16) maps, 1x1-projected to 96/192/384/768
    channels and decoded top-down to a 1-channel logit map.

    Args:
        img_size: input side length passed to timm. H and W are assumed to be
            multiples of the 16-pixel patch.
        depths: four 1-based block indices to tap, in ascending order (the taps
            feed proj[0..3] in that order).
        pretrained: load timm's pretrained ViT weights (default True, as before);
            pass False when a checkpoint overwrites every weight anyway.

    forward(x) returns {"logits": (B, 1, H, W)} for H, W divisible by 16.
    """

    @staticmethod
    def _store_tap(module, inputs, output):
        # Stash the block output ON the hooked block (not on the model), so a
        # deepcopy of the model (e.g. an EMA copy) records into its own blocks
        # rather than into the original model's storage.
        module._unetr_tap = output

    def __init__(self, img_size=256, depths=(3,6,9,12), pretrained=True):
        super().__init__()
        # ViT-B/16 encoder (pretrained weights optional)
        self.vit = timm.create_model(
            "vit_base_patch16_224",
            pretrained=pretrained, num_classes=0, global_pool="", img_size=img_size
        )
        self.embed_dim = getattr(self.vit, "num_features", 768)
        self.patch     = 16
        self.depths    = depths

        # 1x1 projections for the 4 tapped stages (all maps are 1/16 stride)
        self.proj = nn.ModuleList([
            nn.Conv2d(self.embed_dim,  96, 1, bias=False),  # stage @ depth≈3
            nn.Conv2d(self.embed_dim, 192, 1, bias=False),  # depth≈6
            nn.Conv2d(self.embed_dim, 384, 1, bias=False),  # depth≈9 (kept for checkpoint compatibility; unused by decoder)
            nn.Conv2d(self.embed_dim, 768, 1, bias=False),  # depth≈12 (deep)
        ])

        # Lateral 1x1 to normalize decoder inputs (prevents channel mismatches)
        self.l3 = nn.Conv2d(768, 384, 1, bias=False)  # deep stream → 384
        self.l2 = nn.Conv2d(192, 192, 1, bias=False)  # mid skip     → 192
        self.l1 = nn.Conv2d( 96,  96, 1, bias=False)  # shallow skip  →  96

        # Decoder top-down
        self.up2 = UpBlock(in_ch=384, skip_ch=192, out_ch=192)  # fuse @1/16 → (then upsampled to 1/8)
        self.up1 = UpBlock(in_ch=192, skip_ch= 96, out_ch= 96)  # fuse @1/8  → output @1/8

        # Refine @1/4 then go to full res
        self.refine_quarter = ConvBNReLU(96, 96)
        self.head = nn.Conv2d(96, 1, 1)

        # Kept for backward compatibility; taps now live on the hooked blocks.
        self._feats = []
        for d in self.depths:
            blk = self.vit.blocks[d - 1]
            blk._unetr_tap = None
            blk.register_forward_hook(self._store_tap)

    def _tokens_to_map(self, t, H, W):
        """(B, P + N, C) tokens → (B, C, H/16, W/16) map, dropping the P leading prefix tokens."""
        gh, gw = H // self.patch, W // self.patch
        n_patch = gh * gw
        if t.size(1) > n_patch:
            # Prefix tokens (CLS and any register tokens) precede the patch tokens.
            t = t[:, t.size(1) - n_patch:, :]
        B, N, C = t.shape
        return t.transpose(1, 2).reshape(B, C, gh, gw)

    def forward(self, x):
        _, _, H, W = x.shape
        taps = [self.vit.blocks[d - 1] for d in self.depths]
        for blk in taps:
            blk._unetr_tap = None  # never reuse a stale tap from an earlier call

        self.vit.forward_features(x)  # return value unused; hooks record block outputs
        feats = [blk._unetr_tap for blk in taps]
        for blk in taps:
            blk._unetr_tap = None  # drop references so token tensors can be freed
        assert len(feats) == 4 and all(t is not None for t in feats), "UNETR taps not captured (expected 4)."

        # All tapped token sets → 1/16 spatial maps, (B, embed_dim, H/16, W/16)
        maps = [self._tokens_to_map(t, H, W) for t in feats]
        f0 = self.proj[0](maps[0])   # 96   @1/16
        f1 = self.proj[1](maps[1])   # 192  @1/16
        # proj[2] is intentionally not evaluated: its output was never consumed
        # by the decoder (the module stays so checkpoints load unchanged).
        f3 = self.proj[3](maps[3])   # 768  @1/16  (deep)

        # Decoder
        x3 = self.l3(f3)                  # 384 @1/16
        d2 = self.up2(x3, self.l2(f1))    # 192 @1/16
        d2 = F.interpolate(d2, scale_factor=2, mode="bilinear", align_corners=False)   # → 1/8

        skip1 = F.interpolate(self.l1(f0), scale_factor=2, mode="bilinear", align_corners=False)  # 96 @1/8
        d1 = self.up1(d2, skip1)          # 96 @1/8

        # refine at 1/4, then to H
        d1_quarter = self.refine_quarter(F.interpolate(d1, scale_factor=2, mode="bilinear", align_corners=False))  # 1/4
        y = F.interpolate(d1_quarter, scale_factor=4, mode="bilinear", align_corners=False)  # → H
        logits = self.head(y)
        return {"logits": logits}


In [None]:
# ============================================================
# STEP 6 — CPU EVALUATION
# Cell 2/4 — CPU fairness + CKPT path + preload 50 RAW test samples
# (SAFE: avoid set_num_interop_threads error if threads already started)
# ------------------------------------------------------------
import os, psutil, pandas as pd
from datetime import datetime

# ---- CPU fairness (as requested) ----
# Pin the benchmark to a single CPU thread so latency numbers are comparable
# across models/runs.
# NOTE(review): torch was already imported in earlier cells (STEP 4), so these
# env vars most likely do NOT reconfigure the already-loaded OpenMP/MKL
# runtimes, and CUDA_VISIBLE_DEVICES presumably only affects CUDA
# initialization that has not happened yet in this session — TODO confirm.
# The intra-op limit that actually applies is torch.set_num_threads(1) below.
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"]  = "1"

import torch

# Always safe at runtime: caps torch's intra-op thread pool.
torch.set_num_threads(1)

# set_num_interop_threads must be called before any parallel work; try, else continue.
try:
    if hasattr(torch, "set_num_interop_threads"):
        torch.set_num_interop_threads(1)
except RuntimeError as e:
    # Already started parallel work; keep current interop setting but log it.
    print(f"[WARN] {e} — continuing with interop_threads={getattr(torch, 'get_num_interop_threads', lambda: 'N/A')()}")

# NOTE(review): DEVICE is not referenced by the visible cells; the model is
# moved to CPU explicitly with .to("cpu") when it is built.
DEVICE = "cpu"
# threads/interop are torch's effective settings; OMP/MKL are only the env-var
# values set above, not necessarily what the loaded runtimes use.
print(f"[CPU] threads={torch.get_num_threads()} "
      f"interop={getattr(torch, 'get_num_interop_threads', lambda: 'N/A')()} "
      f"OMP={os.getenv('OMP_NUM_THREADS')} MKL={os.getenv('MKL_NUM_THREADS')}")

# ---- threshold from cfg (fallback 0.5) ----
# Uses cfg["metrics"]["threshold"] when a `cfg` mapping exists in the session,
# otherwise 0.5; applied to sigmoid probabilities in the evaluation loop.
THRESH = float(cfg.get("metrics", {}).get("threshold", 0.5)) if "cfg" in globals() else 0.5
print(f"[INFO] THRESH={THRESH}")

# ---- your exact trained checkpoint path ----
from pathlib import Path
import numpy as np

# Absolute Google Drive path of the trained checkpoint; edit per run.
CKPT_PATH = "/content/drive/MyDrive/unetr_model_busi_IMG256_SEED42_2025-11-04_13-32-50_best.pt"  # <<< EDIT THIS
if not Path(CKPT_PATH).exists():
    raise FileNotFoundError(f"Checkpoint not found: {CKPT_PATH}")

# ---- build a 50-sample test index (first 50 of split) ----
NUM_SAMPLES = 50
WARMUP = 5  # the first WARMUP samples are run but excluded from timing and metrics
# `_split_idx` is always set by the split-resolution cell (`_counts, _split_idx = ...`),
# whether taken from STEP 3's `split_idx` or re-inferred from the NPZ;
# `split_idx` itself may not exist in the session.
test_ids = list(_split_idx["test"])
if len(test_ids) < NUM_SAMPLES:
    raise ValueError(f"Test set has {len(test_ids)} samples; need at least {NUM_SAMPLES}.")
sel_ids = test_ids[:NUM_SAMPLES]

# ---- preload RAW arrays (avoid disk I/O in timing) ----
# Note: transforms/tensorization are INSIDE timing for end-to-end latency.
# For an np.load NPZ archive every `archive[key]` access re-reads the WHOLE
# array, and a row view keeps that whole array alive. Read each array once from
# `busi_file` (the NPZ the datasets use) and copy out only the selected rows.
test_img_key = "test_images"
with np.load(busi_file, allow_pickle=True) as _npz_test:
    test_lbl_key = _label_key_for("test", _npz_test.files)
    _test_imgs = _npz_test[test_img_key]
    _test_lbls = _npz_test[test_lbl_key]

raw_samples = [(_test_imgs[i].copy(), _test_lbls[i].copy()) for i in sel_ids]
del _test_imgs, _test_lbls  # release the full split arrays
print(f"[OK] Preloaded RAW {len(raw_samples)} test samples into RAM.")


[CPU] threads=1 interop=1 OMP=1 MKL=1
[INFO] THRESH=0.5
[OK] Preloaded RAW 50 test samples into RAM.


In [None]:
# ============================================================
# STEP 6 — CPU EVAL
# Cell 3/4 (REPLACED) — Instantiate model + CLEAN LOAD (CPU)
# ------------------------------------------------------------
from pathlib import Path
import re, torch

IMAGE_SIZE = int(globals().get("IMAGE_SIZE", 256))
model = UNETR_PaperStyle(img_size=IMAGE_SIZE).to("cpu")

# NOTE(review): torch.load unpickles arbitrary Python objects — only load
# checkpoints from a trusted source.
ckpt = torch.load(CKPT_PATH, map_location="cpu")

def unwrap_state(d):
    """Return the parameter dict held by a checkpoint.

    Checks common wrapper keys in priority order and returns the first value
    that is a dict; otherwise returns `d` itself if it is a dict, else {}.
    """
    for k in ["model_state","state_dict","model","net","ema","model_state_dict"]:
        if isinstance(d, dict) and k in d and isinstance(d[k], dict):
            return d[k]
    return d if isinstance(d, dict) else {}

raw = unwrap_state(ckpt)

# 1) strip 'module.' prefix (DataParallel/DDP wrappers)
clean = {(k[7:] if k.startswith("module.") else k): v for k, v in raw.items()}

# 2) drop profiling buffers (from thop/fvcore)
def is_profile_key(k: str) -> bool:
    """True for `total_ops` / `total_params` entries, including root-level ones."""
    return k.rsplit(".", 1)[-1] in ("total_ops", "total_params")
clean = {k:v for k,v in clean.items() if not is_profile_key(k)}

# 3) optional prefix remaps (adapt if your training used different names)
#    e.g., 'backbone.' -> 'vit.' , 'encoder.' -> 'vit.' , 'transformer.' -> 'vit.'
remaps = [
    (r"^backbone\.", "vit."),
    (r"^encoder\.",  "vit."),
    (r"^transformer\.", "vit."),
]
remapped = {}
for k, v in clean.items():
    nk = k
    for pat, rep in remaps:
        nk = re.sub(pat, rep, nk)
    remapped[nk] = v
clean = remapped

# 4) keep keys whose name AND shape match the model; record shape conflicts
model_sd = model.state_dict()
intersect, shape_mismatch = {}, []
for k, v in clean.items():
    if k not in model_sd:
        continue
    if getattr(v, "shape", None) == model_sd[k].shape:
        intersect[k] = v
    else:
        shape_mismatch.append(k)

# 5) report coverage
print(f"[CKPT] total keys in state: {len(raw)}")
print(f"[CKPT] after strip+drop:     {len(clean)}")
print(f"[CKPT] intersect (name+shape): {len(intersect)} / model expects {len(model_sd)}")
if shape_mismatch:
    print(f"[WARN] {len(shape_mismatch)} checkpoint keys skipped (shape mismatch): {shape_mismatch[:20]}")

# 6) load the intersection only; every other tensor keeps the value it got when
#    the model was built (timm-pretrained ViT weights, or random decoder init)
load_res = model.load_state_dict(intersect, strict=False)

still_missing = list(load_res.missing_keys)
print(f"[LOAD] missing (after clean) ~ {len(still_missing)} (showing up to 40)")
print(still_missing[:40])
if still_missing:
    print("[WARN] Some model weights were NOT loaded from the checkpoint; "
          "metrics below will not reflect the trained model.")

model.eval()
print("[OK] Cleaned checkpoint loaded into model on CPU.")


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

[CKPT] total keys in state: 559
[CKPT] after strip+drop:     191
[CKPT] intersect (name+shape): 189 / model expects 189
[LOAD] missing (after clean) ~ 0 (showing up to 40)
[]
[OK] Cleaned checkpoint loaded into model on CPU.


In [None]:
from pathlib import Path
import torch, re

# Diagnostic: compare the checkpoint's layout with what the model expects.
# NOTE(review): torch.load unpickles objects — trusted checkpoints only.
ckpt = torch.load(CKPT_PATH, map_location="cpu")

# Reuse `unwrap_state` from the clean-load cell, so both cells agree on which
# wrapper key ("model_state", "state_dict", ...) holds the weights, instead of
# keeping a second copy with a different key priority.
raw_state = unwrap_state(ckpt)

top_level = list(ckpt.keys())[:20] if isinstance(ckpt, dict) else type(ckpt).__name__
print(f"[CKPT] top-level keys: {top_level}")
print(f"[CKPT] state len: {len(raw_state)}")

# Show a few parameter names to identify architecture/backbone
sample_keys = list(raw_state.keys())[:40]
print("[CKPT] sample param keys:")
for k in sample_keys:
    print("  ", k)

# Also show model keys to compare
model_keys = list(model.state_dict().keys())
print(f"[MODEL] expects {len(model_keys)} tensors")
print("[MODEL] sample expected keys:")
for k in model_keys[:40]:
    print("  ", k)


[CKPT] top-level keys: ['epoch', 'model_state', 'optimizer_state', 'val_dice', 'val_loss', 'cfg']
[CKPT] state len: 559
[CKPT] sample param keys:
   total_ops
   total_params
   vit.cls_token
   vit.pos_embed
   vit.total_ops
   vit.total_params
   vit.patch_embed.total_ops
   vit.patch_embed.total_params
   vit.patch_embed.proj.weight
   vit.patch_embed.proj.bias
   vit.patch_embed.norm.total_ops
   vit.patch_embed.norm.total_params
   vit.patch_drop.total_ops
   vit.patch_drop.total_params
   vit.norm_pre.total_ops
   vit.norm_pre.total_params
   vit.blocks.0.total_ops
   vit.blocks.0.total_params
   vit.blocks.0.norm1.weight
   vit.blocks.0.norm1.bias
   vit.blocks.0.norm1.total_ops
   vit.blocks.0.norm1.total_params
   vit.blocks.0.attn.total_ops
   vit.blocks.0.attn.total_params
   vit.blocks.0.attn.qkv.weight
   vit.blocks.0.attn.qkv.bias
   vit.blocks.0.attn.q_norm.total_ops
   vit.blocks.0.attn.q_norm.total_params
   vit.blocks.0.attn.k_norm.total_ops
   vit.blocks.0.attn.k_nor

In [None]:
# Pure key-set diff. Calling load_state_dict here would overwrite the weights
# installed by the clean-load cell with the raw (un-stripped, un-remapped)
# state, and it raises on any shape mismatch even with strict=False.
_model_keys = model.state_dict().keys()
missing = [k for k in _model_keys if k not in raw_state]
unexpected = [k for k in raw_state if k not in _model_keys]
print(f"[DIFF] missing ({len(missing)}):")
print(missing[:50])
print(f"[DIFF] unexpected ({len(unexpected)}):")
print(unexpected[:50])


[DIFF] missing (0):
[]
[DIFF] unexpected (370):
['total_ops', 'total_params', 'vit.total_ops', 'vit.total_params', 'vit.patch_embed.total_ops', 'vit.patch_embed.total_params', 'vit.patch_embed.norm.total_ops', 'vit.patch_embed.norm.total_params', 'vit.patch_drop.total_ops', 'vit.patch_drop.total_params', 'vit.norm_pre.total_ops', 'vit.norm_pre.total_params', 'vit.blocks.0.total_ops', 'vit.blocks.0.total_params', 'vit.blocks.0.norm1.total_ops', 'vit.blocks.0.norm1.total_params', 'vit.blocks.0.attn.total_ops', 'vit.blocks.0.attn.total_params', 'vit.blocks.0.attn.q_norm.total_ops', 'vit.blocks.0.attn.q_norm.total_params', 'vit.blocks.0.attn.k_norm.total_ops', 'vit.blocks.0.attn.k_norm.total_params', 'vit.blocks.0.attn.norm.total_ops', 'vit.blocks.0.attn.norm.total_params', 'vit.blocks.0.ls1.total_ops', 'vit.blocks.0.ls1.total_params', 'vit.blocks.0.drop_path1.total_ops', 'vit.blocks.0.drop_path1.total_params', 'vit.blocks.0.norm2.total_ops', 'vit.blocks.0.norm2.total_params', 'vit.blocks.

In [None]:
# ============================================================
# STEP 6 — CPU EVALUATION
# Cell 4/4 — Warm-up, timed run, metrics, save CSV+JSON
# (FIX: added `import math`)
# ------------------------------------------------------------
import time, numpy as np, math, json, os, psutil, pandas as pd
from datetime import datetime
from pathlib import Path

# ---- helpers ----
def apply_test_transform(img, msk):
    """Prepare one raw (image, mask) pair with the global `test_tf` pipeline.

    HxW grayscale images are replicated to 3 channels; masks whose max exceeds
    1 are thresholded at 127 into {0, 1} uint8. Returns (x, y), where y always
    carries a leading channel axis.
    """
    rgb = np.stack((img,) * 3, axis=-1) if img.ndim == 2 else img
    binary = (msk > 127).astype(np.uint8) if msk.max() > 1 else msk
    transformed = test_tf(image=rgb, mask=binary)  # resize + normalize + tensor
    x = transformed["image"]
    y = transformed["mask"]
    if y.ndim == 2:
        y = y.unsqueeze(0)
    return x, y

def binarize(prob, thr):
    """Threshold a probability array: 1 where prob >= thr, else 0 (uint8)."""
    above = prob >= thr
    return above.astype(np.uint8)

def dice_iou(pred, mask, eps=1e-7):
    """Dice and IoU between two binary masks.

    Inputs may be any array-likes of equal shape; non-zero means foreground
    (bool/uint8/float all work — the previous bitwise `&` failed on floats).
    `eps` smooths both ratios, so two empty masks score (1.0, 1.0).

    Returns:
        (dice, iou) as Python floats.
    """
    p = np.asarray(pred).astype(bool)
    m = np.asarray(mask).astype(bool)
    inter = np.count_nonzero(p & m)
    union = np.count_nonzero(p | m)
    dice = (2 * inter + eps) / (np.count_nonzero(p) + np.count_nonzero(m) + eps)
    iou  = (inter + eps) / (union + eps)
    return float(dice), float(iou)

# ---- warm-up (excluded from timing and metrics) ----
with torch.no_grad():
    for i in range(WARMUP):
        img, msk = raw_samples[i]
        x, y = apply_test_transform(img, msk)
        _ = model(x.unsqueeze(0))

# ---- timed run ----
lat_ms, dices, ious = [], [], []
cpu_hist, ram_hist  = [], []
proc = psutil.Process(os.getpid())
# psutil.cpu_percent(interval=None) is measured relative to the previous call
# (or psutil's import); its first result is documented as meaningless. Prime it
# here so every recorded sample covers exactly one loop iteration.
# NOTE: this is system-wide CPU utilisation, not just this process.
psutil.cpu_percent(interval=None)
t_start = time.perf_counter()

with torch.no_grad():
    for i in range(WARMUP, NUM_SAMPLES):
        img, msk = raw_samples[i]

        t0 = time.perf_counter()
        x, y = apply_test_transform(img, msk)       # resize+normalize counted in latency
        out = model(x.unsqueeze(0))                 # forward
        logits = out["logits"] if isinstance(out, dict) and "logits" in out else out
        prob = torch.sigmoid(logits).cpu().numpy().squeeze()
        pred = binarize(prob, THRESH)
        t1 = time.perf_counter()                    # inference ends at the binary mask

        # Ground-truth prep and Dice/IoU are evaluation-only work: kept out of latency.
        gt   = (y.squeeze().cpu().numpy() > 0.5).astype(np.uint8)
        d, j = dice_iou(pred, gt)

        dices.append(d); ious.append(j)
        lat_ms.append((t1 - t0) * 1000.0)
        cpu_hist.append(psutil.cpu_percent(interval=None))
        ram_hist.append(proc.memory_info().rss)

t_end = time.perf_counter()

# ---- summary ----
def pct(vals, p):
    """p-th percentile of `vals` with linear interpolation between ranks; NaN if empty."""
    if not vals:
        return float('nan')
    ordered = sorted(vals)
    rank = (len(ordered) - 1) * (p / 100.0)
    lo = math.floor(rank)
    hi = math.ceil(rank)
    if lo == hi:
        return ordered[int(rank)]
    lower_weight = hi - rank
    upper_weight = rank - lo
    return ordered[lo] * lower_weight + ordered[hi] * upper_weight

# ---- summary ----
n = len(lat_ms)
wall_s = t_end - t_start
# Throughput over the timed loop; wall time also includes metric and psutil work.
fps = n / wall_s if wall_s > 0 else float('nan')
summary = dict(
    dice_mean=float(np.mean(dices)) if dices else float('nan'),
    iou_mean=float(np.mean(ious)) if ious else float('nan'),
    lat_median_ms=float(np.median(lat_ms)) if n else float('nan'),
    lat_p90_ms=float(pct(lat_ms, 90)),
    lat_p95_ms=float(pct(lat_ms, 95)),
    lat_min_ms=float(np.min(lat_ms)) if n else float('nan'),
    lat_max_ms=float(np.max(lat_ms)) if n else float('nan'),
    wall_time_s=float(wall_s),
    fps=float(fps),
    # max of per-sample RSS snapshots (a sampled, not true, peak)
    peak_ram_mb=float(max(ram_hist)/(1024*1024)) if ram_hist else float('nan'),
    cpu_mean_pct=float(np.mean(cpu_hist)) if cpu_hist else float('nan'),
    threshold=float(THRESH),
    samples=int(NUM_SAMPLES - WARMUP),
    threads=dict(
        torch_num_threads=torch.get_num_threads(),
        torch_num_interop=torch.get_num_interop_threads() if hasattr(torch, "get_num_interop_threads") else "N/A",
        OMP_NUM_THREADS=os.getenv("OMP_NUM_THREADS"),
        MKL_NUM_THREADS=os.getenv("MKL_NUM_THREADS"),
    ),
    ckpt=CKPT_PATH,
    data=str(busi_file),
)

print(json.dumps(summary, indent=2))

# ---- save artifacts ----
outdir = Path("./cpu_eval_runs"); outdir.mkdir(exist_ok=True, parents=True)
stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
# `test_index` is the index used to fetch each timed sample from `test_images`
# (timed row r corresponds to sel_ids[WARMUP + r]), so slow/poor rows can be traced.
per_image = pd.DataFrame({
    "idx": list(range(n)),
    "test_index": list(sel_ids[WARMUP:WARMUP + n]),
    "latency_ms": lat_ms,
    "dice": dices,
    "iou": ious,
})
csv_path = outdir / f"per_image_{stamp}.csv"
json_path = outdir / f"summary_{stamp}.json"
per_image.to_csv(csv_path, index=False)
with open(json_path, "w") as f: json.dump(summary, f, indent=2)
print("Saved:", csv_path, "|", json_path)


{
  "dice_mean": 0.6293409686141899,
  "iou_mean": 0.5180267852453141,
  "lat_median_ms": 717.4813509999467,
  "lat_p90_ms": 883.691020999936,
  "lat_p95_ms": 899.2539020000777,
  "lat_min_ms": 692.3473820000936,
  "lat_max_ms": 961.0881739999968,
  "wall_time_s": 33.41650012899993,
  "fps": 1.346640127669969,
  "peak_ram_mb": 4210.1015625,
  "cpu_mean_pct": 59.57777777777779,
  "threshold": 0.5,
  "samples": 45,
  "threads": {
    "torch_num_threads": 1,
    "torch_num_interop": 1,
    "OMP_NUM_THREADS": "1",
    "MKL_NUM_THREADS": "1"
  },
  "ckpt": "/content/drive/MyDrive/unetr_model_busi_IMG256_SEED42_2025-11-04_13-32-50_best.pt",
  "data": "/content/data/MedSegBenchCache/busi_256.npz"
}
Saved: cpu_eval_runs/per_image_20251104_165103.csv | cpu_eval_runs/summary_20251104_165103.json


In [None]:
import pandas as pd
from pathlib import Path

# Newest per-image CSV from the CPU-eval cell (timestamped names sort chronologically).
p = sorted(Path("cpu_eval_runs").glob("per_image_*.csv"))[-1]
df = pd.read_csv(p)

latency_stats = df.describe(percentiles=[.5, .9, .95])
print(latency_stats)

slowest_rows = df.sort_values("latency_ms", ascending=False).head()
print("Top 5 slowest:\n", slowest_rows)


             idx  latency_ms          dice           iou
count  45.000000   45.000000  4.500000e+01  4.500000e+01
mean   22.000000  742.264036  6.293410e-01  5.180268e-01
std    13.133926   68.089516  2.947477e-01  2.834800e-01
min     0.000000  692.347382  4.968203e-12  4.968203e-12
50%    22.000000  717.481351  7.549119e-01  6.063120e-01
90%    39.600000  883.691021  9.087899e-01  8.328336e-01
95%    41.800000  899.253902  9.162268e-01  8.454083e-01
max    44.000000  961.088174  9.529537e-01  9.101351e-01
Top 5 slowest:
     idx  latency_ms      dice       iou
27   27  961.088174  0.842170  0.727369
11   11  908.041226  0.754912  0.606312
44   44  899.701846  0.873984  0.776174
12   12  897.462126  0.841581  0.726491
43   43  883.790197  0.460809  0.299384
