In [None]:
!nvidia-smi

## image-only / text-only training + validation + test-inference stack

Supports MOD.img and MOD.txt.

Trains with class-balanced sampling for image data (text-only training uses plain shuffling).

Saves the best checkpoint as `best_<NETWORK>_<MODALITY>_val_loss.pt`. Despite the `val_loss` suffix in the filename, the checkpoint is selected by the highest validation MCC.

Writes best_*_val_loss_log.json that includes best_val_mcc, best_epoch, etc.

Test inference recursively scans under cfg.out_dir for all model-aware checkpoints of the chosen stem, reads each run’s log to get best_val_mcc, selects the highest-MCC checkpoint (ties → newest file), prints why, and evaluates on the test set.


In [None]:
from __future__ import annotations

# =========================
# Standard library imports
# =========================
import ast
import glob
import inspect
import io
import json
import logging
import math
import os
import pathlib
import random
import re
import shutil
import sys
import time
import warnings
import dataclasses
from collections import defaultdict
from contextlib import nullcontext
from dataclasses import dataclass
from enum import Enum, auto
from pathlib import Path
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple

# =========================
# Third-party imports
# =========================
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import numpy as np
import pandas as pd
from PIL import Image
import sklearn
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, Sampler, WeightedRandomSampler
import torchvision.models as tvm
from torchvision import transforms as T
from transformers import AutoModel, AutoTokenizer
from pytorch_grad_cam import (
    AblationCAM,
    EigenCAM,
    EigenGradCAM,
    GradCAM,
    GradCAMPlusPlus,
    HiResCAM,
    LayerCAM,
    ScoreCAM,
    XGradCAM,
)
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

# Optional dependency
try:
    import timm
    _HAS_TIMM = True
except Exception:
    timm = None
    _HAS_TIMM = False


In [None]:
# Report the interpreter / PyTorch in use, then point Hugging Face at a local cache
# and switch it to offline mode.
# NOTE(review): `transformers` is already imported by the imports cell above; the
# Hugging Face libraries may read these variables at import time, in which case
# setting them here is too late — TODO confirm, or move these assignments above the imports.
# NOTE(review): the cache paths below are absolute and user-specific; adjust per machine.

print(torch.__version__)
print(sys.executable)
print(sys.version)

os.environ["HF_HOME"] = "/gpfs/gsfs12/users/rajaramans2/.cache/huggingface"
os.environ["HF_HUB_CACHE"] = "/gpfs/gsfs12/users/rajaramans2/.cache/huggingface/hub"
os.environ["TRANSFORMERS_OFFLINE"] = "1"
os.environ["HF_HUB_OFFLINE"] = "1"

In [None]:
# Configuration

class MOD(Enum):
    """Input modality of a run: image-only (`img`) or text-only (`txt`)."""
    img = auto()
    txt = auto()

class NETWORK(Enum):
    """
    Supported image backbones.

    The member *name* is what the rest of the notebook keys on: it is lower-cased
    to select a timm / torchvision constructor and it is embedded in checkpoint
    file names via `ckpt_stem`.
    """
    # timm / torchvision families
    dpn68 = auto()
    coatnet0 = auto()
    convnext_nano = auto()
    hrnet32 = auto()
    resnet18 = auto()
    densenet121 = auto()
    mobilenet_v2 = auto()
    # VGG family (built with the batch-norm torchvision variants)
    vgg11 = auto()
    vgg13 = auto()
    vgg16 = auto()
    vgg19 = auto()

@dataclass
class RunConfig:
    """
    Settings for one unimodal (image-only or text-only) run.

    Any split CSV path left as None is filled in by `__post_init__` with
    `<dataset_root>/label_train.csv`, `label_valid.csv` or `label_test.csv`.
    """
    modality: MOD
    network: NETWORK
    n_classes: int = 2
    hidden_dim: int = 256
    epochs: int = 64
    lr: float = 5e-5
    weight_decay: float = 1e-4
    img_size: int = 224
    batch_size: int = 64
    num_workers: int = 2
    pin_memory: bool = True

    # human-readable class labels
    class_names: Tuple[str, ...] = ("normal", "tb")

    # where the data lives and where outputs go
    dataset_root: str = "/dataset"
    csv_train: Optional[str] = None
    csv_valid: Optional[str] = None
    csv_test:  Optional[str] = None
    images_subdir: str = "images"
    reports_subdir: str = "reports"
    out_dir: str = "/models"

    # Hugging Face text encoder / tokenizer (offline-first)
    hf_offline: bool = True
    hf_local_dir: Optional[str] = "/huggingface/hub/models--microsoft--BiomedVLP-CXR-BERT-general"
    hf_text_model_name: str = "microsoft/BiomedVLP-CXR-BERT-general"
    max_tokens: int = 512
    hf_tokenizer_local_dir: Optional[str] = "/huggingface/hub/models--microsoft--BiomedVLP-CXR-BERT-general"

    def __post_init__(self):
        """Derive every unset split CSV path from `dataset_root`."""
        fallbacks = (
            ("csv_train", "label_train.csv"),
            ("csv_valid", "label_valid.csv"),
            ("csv_test", "label_test.csv"),
        )
        for attr, filename in fallbacks:
            if getattr(self, attr) is None:
                setattr(self, attr, os.path.join(self.dataset_root, filename))

    @property
    def images_dir(self) -> str:
        """`dataset_root` joined with `images_subdir`."""
        root, sub = self.dataset_root, self.images_subdir
        return os.path.join(root, sub)

    @property
    def reports_dir(self) -> str:
        """`dataset_root` joined with `reports_subdir`."""
        root, sub = self.dataset_root, self.reports_subdir
        return os.path.join(root, sub)

def ckpt_stem(cfg: RunConfig) -> str:
    """Return the checkpoint file-name stem, `<network name>_<modality name>`."""
    network_name = cfg.network.name
    modality_name = cfg.modality.name
    return "{}_{}".format(network_name, modality_name)

def ckpt_paths(cfg: RunConfig) -> Dict[str, str]:
    """
    Map artifact keys to output paths under `cfg.out_dir`.

    Every file name starts with `best_<stem>` (stem from `ckpt_stem`) followed by
    a per-artifact suffix.
    """
    prefix = os.path.join(cfg.out_dir, f"best_{ckpt_stem(cfg)}")
    suffix_by_key = {
        "best_pt":         "_val_loss.pt",
        "curves_png":      "_val_loss_curves.png",
        "curves_csv":      "_val_loss_history.csv",
        "curves_png_full": "_val_loss_curves_full.png",
        "curves_csv_full": "_val_loss_history_full.csv",
        "log_json":        "_val_loss_log.json",
        "mcc_png":         "_val_mcc_curve.png",
    }
    return {key: prefix + suffix for key, suffix in suffix_by_key.items()}

# Running configuration for this notebook.
# NOTE(review): most values below repeat the RunConfig defaults; they are spelled out
# so the run is self-describing.
# NOTE(review): the csv_* paths are passed explicitly, so __post_init__ does not derive
# them from dataset_root — confirm that "/label_*.csv" (outside dataset_root) is intended.
cfg = RunConfig(
    modality=MOD.img, # or MOD.txt
    network=NETWORK.vgg11,
    n_classes=2,
    hidden_dim=256,
    epochs=64,
    lr=5e-5,
    weight_decay=1e-4,
    img_size=224,
    batch_size=64,
    num_workers=2,
    pin_memory=True,
    dataset_root="/dataset",
    csv_train="/label_train.csv",
    csv_valid="/label_valid.csv",
    csv_test ="/label_test.csv",
    images_subdir="images",
    reports_subdir="reports",
    out_dir="/models",
    class_names=("normal","tb"),
    hf_text_model_name="microsoft/BiomedVLP-CXR-BERT-general",
    hf_local_dir="/huggingface/hub/models--microsoft--BiomedVLP-CXR-BERT-general",
    hf_offline=True
)

In [None]:
# CSV dataset & loaders (Albumentations for image)

def _read_csv_two_cols_no_header(path: str) -> Tuple[List[str], List[int]]:
    df = pd.read_csv(path, header=None)
    if df.shape[1] < 2:
        raise ValueError(f"{path} must have at least 2 columns (img,label)")
    imgs = df.iloc[:, 0].astype(str).tolist()
    labs = df.iloc[:, 1].astype(int).tolist()
    return imgs, labs

def _abs_image_path(dataset_root: str, images_subdir: str, fname: str) -> Path:
    return Path(dataset_root) / images_subdir / fname

def _abs_report_path(dataset_root: str, reports_subdir: str, fname: str) -> Path:
    # prefer <stem>.txt inside reports_subdir, else a direct filename
    p = Path(dataset_root) / reports_subdir / (Path(fname).stem + ".txt")
    if p.exists(): return p
    q = Path(dataset_root) / reports_subdir / fname
    return q

def _one_hot(idx: int, n_classes: int) -> torch.Tensor:
    y = torch.zeros(n_classes, dtype=torch.float32)
    if 0 <= idx < n_classes: y[idx] = 1.0
    return y

class CSVImageOrTextDataset(Dataset):
    """
    Unimodal dataset driven by a (file name, label) split CSV.

    - MOD.img: the image is read as grayscale, expanded to RGB, transformed with
      Albumentations (ImageNet normalisation) and returned as
      {"image", "y_onehot", "filename"}.
    - MOD.txt: the report text is tokenized and returned as
      {"text": {"input_ids","attention_mask"}, "y_onehot", "filename"}.

    Rows whose image (MOD.img) or report (MOD.txt) does not exist on disk are
    skipped at construction time, so `len(dataset)` can be smaller than the CSV.
    """
    def __init__(self, cfg: RunConfig, which: str, modality: MOD,
                 n_classes: int, tokenizer: Optional[AutoTokenizer] = None,
                 max_len: int = 192) -> None:
        """
        Args:
            cfg: run configuration (CSV paths, dataset root, sub-directories, image size).
            which: split to load, one of "train", "valid", "test" (KeyError otherwise).
            modality: MOD.img or MOD.txt.
            n_classes: length of the one-hot target.
            tokenizer: callable tokenizer; required when samples are read for MOD.txt.
            max_len: padded / truncated token length for MOD.txt.
        """
        assert modality in (MOD.img, MOD.txt), "This dataset only supports image-only OR text-only."
        self.cfg = cfg
        self.modality = modality
        self.n_classes = int(n_classes)
        self.tokenizer = tokenizer
        self.max_len = int(max_len)
        csv_path = {"train": cfg.csv_train, "valid": cfg.csv_valid, "test": cfg.csv_test}[which]
        img_names, labels = _read_csv_two_cols_no_header(csv_path)

        # Each item is (image path, class index, report path or None).
        self.items: List[Tuple[Path, int, Optional[Path]]] = []
        for fn, y in zip(img_names, labels):
            ip = _abs_image_path(cfg.dataset_root, cfg.images_subdir, fn)
            rp = _abs_report_path(cfg.dataset_root, cfg.reports_subdir, fn)
            if modality == MOD.img and not ip.exists():
                continue
            has_report = rp.exists()  # checked once; reused for the stored path below
            if modality == MOD.txt and not has_report:
                continue
            self.items.append((ip, int(y), rp if has_report else None))

        size = int(cfg.img_size)
        # Training adds a random CLAHE step; evaluation only resizes and normalises.
        self.train_tfm = A.Compose([
            A.CLAHE(clip_limit=2.0, tile_grid_size=(8, 8), p=0.5),
            A.Resize(size, size, interpolation=cv2.INTER_AREA),
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2()
        ])
        self.eval_tfm = A.Compose([
            A.Resize(size, size, interpolation=cv2.INTER_AREA),
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2()
        ])
        self.tfm = self.train_tfm if which == "train" else self.eval_tfm
        self.return_paths = (which == "test")

    def __len__(self) -> int:
        """Number of usable rows kept from the CSV."""
        return len(self.items)

    def _read_img(self, p: Path) -> torch.Tensor:
        """Load `p` as grayscale, expand to RGB and apply the split's transform."""
        img = cv2.imread(str(p), cv2.IMREAD_GRAYSCALE)
        if img is None:
            # Keep the best-effort fallback, but make it visible: a black frame with
            # a real label would otherwise enter training or evaluation unnoticed.
            warnings.warn(
                f"Could not read image {p}; substituting an all-zero image.",
                RuntimeWarning,
                stacklevel=2,
            )
            img = np.zeros((self.cfg.img_size, self.cfg.img_size), dtype=np.uint8)
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        img = self.tfm(image=img)["image"]
        return img

    def _read_text(self, rp: Optional[Path]) -> Dict[str, torch.Tensor]:
        """
        Read and tokenize the report at `rp`.

        A missing, unreadable or empty report is tokenized as the literal "[EMPTY]".
        Raises RuntimeError when no tokenizer was supplied.
        """
        if self.tokenizer is None:
            raise RuntimeError("Tokenizer is None but MOD.txt was requested.")
        text = ""
        if rp is not None:
            try:
                text = rp.read_text(encoding="utf-8", errors="ignore").strip()
            except Exception as err:
                # Same fallback as before, but no longer silent.
                warnings.warn(
                    f"Could not read report {rp} ({err!r}); using the empty-text placeholder.",
                    RuntimeWarning,
                    stacklevel=2,
                )
                text = ""
        if not text:
            text = "[EMPTY]"
        tok = self.tokenizer(
            text, padding="max_length", truncation=True,
            max_length=self.max_len, return_tensors="pt"
        )
        # The tokenizer returns a batch of one; strip the batch dimension.
        return {"input_ids": tok["input_ids"][0], "attention_mask": tok["attention_mask"][0]}

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Return one sample dict; "filename" is a str, the other values are tensors / dicts of tensors."""
        ip, y_idx, rp = self.items[idx]
        y = _one_hot(y_idx, self.n_classes)
        out: Dict[str, Any] = {"y_onehot": y, "filename": ip.name}
        if self.modality == MOD.img:
            out["image"] = self._read_img(ip)
        else:  # MOD.txt
            out["text"] = self._read_text(rp)
        return out

In [None]:
# Tokenizer, loader builders, imbalanced sampler

# --- robust CSV reader (headerless, keeps first data row) ---
def _read_split_csv_headerless(path: str) -> pd.DataFrame:
    """
    Read a split CSV that may or may not have a header.
    Returns DataFrame with columns ['img','label'] and integer labels.
    """
    if not os.path.isfile(path):
        raise FileNotFoundError(f"CSV not found: {path}")

    # Always header=None so we never drop the first row
    try:
        df = pd.read_csv(path, header=None, sep=None, engine="python", encoding="utf-8-sig")
    except Exception:
        df = pd.read_csv(path, header=None, encoding="utf-8-sig")

    if df.shape[1] < 2:
        raise ValueError(f"{os.path.basename(path)} must have ≥2 columns (filename, label).")

    df = df.iloc[:, :2].copy()
    df.columns = ["img", "label"]

    # Detect and drop a header row if present
    def _is_hdr_token(s: str) -> bool:
        s = (str(s) or "").strip().lower()
        return s in {"img", "image", "filename", "file", "path", "label", "class", "target", "y"} or s.startswith("img")

    if _is_hdr_token(df.iloc[0, 0]) and _is_hdr_token(df.iloc[0, 1]):
        df = df.iloc[1:].reset_index(drop=True)

    df["img"] = df["img"].astype(str).str.strip()
    df["label"] = pd.to_numeric(df["label"], errors="coerce").fillna(0).astype(int)
    return df

# compact per-split summary
def _print_split_summary(cfg: "RunConfig", split: str, *, modality: "MOD") -> None:
    """
    Print a compact summary of one split: CSV rows, usable rows, missing files,
    per-class counts and the first few (file name → label) pairs.

    Report files are located with `_abs_report_path`, i.e. the same lookup the
    dataset uses (`<stem>.txt`, else the raw file name), so the "kept" count for
    text runs agrees with what the dataset keeps.

    Nothing is returned; problems reading the CSV are printed, not raised.
    """
    csv_path = {
        "train": getattr(cfg, "csv_train", None),
        "valid": getattr(cfg, "csv_valid", None),
        "test":  getattr(cfg, "csv_test",  None),
    }[split]
    if not csv_path:
        print(f"[{split}] CSV path not provided.", flush=True)
        return

    try:
        df = _read_split_csv_headerless(csv_path)
    except Exception as e:
        print(f"[{split}] ERROR reading CSV: {e}", flush=True)
        return

    total_csv = len(df)

    images_subdir = getattr(cfg, "images_subdir", "images")
    reports_subdir = getattr(cfg, "reports_subdir", "reports")
    img_dir = os.path.join(cfg.dataset_root, images_subdir)

    def _image_exists() -> pd.Series:
        paths = df["img"].apply(lambda x: os.path.join(img_dir, x))
        return paths.apply(os.path.isfile).astype(bool)

    def _report_exists() -> pd.Series:
        # Same resolution rule as the dataset, so both agree on which rows are usable.
        paths = df["img"].apply(lambda x: str(_abs_report_path(cfg.dataset_root, reports_subdir, x)))
        return paths.apply(os.path.isfile).astype(bool)

    # Each existence mask is computed once and only for the modality that needs it.
    if modality == MOD.img:
        mask = _image_exists()
        mod_str = "images only"
        miss_str = f"Missing images={int((~mask).sum())}"
    elif modality == MOD.txt:
        mask = _report_exists()
        mod_str = "texts only"
        miss_str = f"Missing texts={int((~mask).sum())}"
    else:
        img_ok = _image_exists()
        txt_ok = _report_exists()
        mask = img_ok & txt_ok
        mod_str = "require image & text"
        miss_str = f"Missing images={int((~img_ok).sum())}; Missing texts={int((~txt_ok).sum())}"

    kept = int(mask.sum())
    df_use = df[mask]

    # Class counts (in the usable subset)
    cls_counts: Dict[int, int] = df_use["label"].value_counts().sort_index().to_dict()
    cls_str = " ".join([f"{k}={v}" for k, v in cls_counts.items()]) if cls_counts else "(none)"

    # Show a few filename→label pairs to verify mapping
    head_n = min(5, kept)
    if head_n > 0:
        preview = df_use.head(head_n)
        examples = "\n".join(
            f"      • {name}  →  {int(label)}"
            for name, label in zip(preview["img"], preview["label"])
        )
    else:
        examples = "(no usable rows)"

    print(f"[{split}] CSV rows={total_csv} | kept={kept} ({mod_str}). {miss_str}.", flush=True)
    print(f"[{split}] label counts: {cls_str}", flush=True)
    print(f"[{split}] first {head_n} examples:\n{examples}\n", flush=True)

def build_tokenizer(cfg: "RunConfig") -> Optional[AutoTokenizer]:
    """
    Return the text tokenizer for a MOD.txt run, or None for any other modality.

    Lookup order: the tokenizer directory, then the model directory, then the model
    id restricted to local files. Only when `hf_offline` is false is a normal
    (possibly downloading) load attempted; otherwise RuntimeError is raised.
    """
    if cfg.modality != MOD.txt:
        return None

    offline = bool(getattr(cfg, "hf_offline", True))
    tokenizer_dir = getattr(cfg, "hf_tokenizer_local_dir", None)
    model_dir = getattr(cfg, "hf_local_dir", None)
    model_id = getattr(cfg, "hf_text_model_name", "microsoft/BiomedVLP-CXR-BERT-general")

    def _from_folder(folder):
        # Anything that is not an existing directory, or fails to load, counts as "not found".
        if not (folder and os.path.isdir(folder)):
            return None
        try:
            return AutoTokenizer.from_pretrained(folder, local_files_only=True, use_fast=True)
        except Exception:
            return None

    local = _from_folder(tokenizer_dir) or _from_folder(model_dir)
    if local is not None:
        return local

    try:
        return AutoTokenizer.from_pretrained(model_id, local_files_only=True, use_fast=True)
    except Exception:
        pass

    if offline:
        raise RuntimeError("Offline and no local tokenizer snapshot found.")
    return AutoTokenizer.from_pretrained(model_id, use_fast=True)

class ImbalancedDatasetSampler(Sampler[int]):
    """
    Class-balanced sampler: draws `len(labels)` indices with replacement, each index
    weighted by the inverse frequency of its class.

    `labels` holds one non-negative integer class index per dataset item, in dataset
    order; a tensor or any array-like accepted by `torch.as_tensor` works.
    An empty label set yields an empty iteration instead of an error.
    """
    def __init__(self, labels: torch.Tensor):
        # Work on a private CPU copy so later changes by the caller cannot affect sampling.
        self.labels = torch.as_tensor(labels).clone().cpu().long()
        k = int(self.labels.max().item()) + 1 if self.labels.numel() > 0 else 2
        # clamp_min(1) avoids division by zero for class ids that never occur.
        counts = torch.bincount(self.labels, minlength=k).clamp_min(1)
        weights = (1.0 / counts)[self.labels]
        self.sample_weights = weights.double()
        self.num_samples = len(self.labels)

    def __iter__(self):
        if self.num_samples == 0:
            # torch.multinomial rejects a request for zero samples.
            return iter(())
        idx = torch.multinomial(self.sample_weights, num_samples=self.num_samples, replacement=True)
        return iter(idx.tolist())

    def __len__(self):
        return self.num_samples

def make_loaders(cfg: "RunConfig") -> Tuple[DataLoader, DataLoader, DataLoader]:
    """
    Build the (train, valid, test) DataLoaders for the configured modality.

    - Prints a per-split summary before any dataset is constructed.
    - Image runs train with `ImbalancedDatasetSampler` and `drop_last=True`;
      text runs train with plain shuffling and `drop_last=False`.
    - Validation and test loaders are never shuffled.

    The sampler weights are taken from the rows the training dataset actually kept
    (`ds_tr.items`), not from the raw CSV: the dataset skips rows whose file is
    missing, so CSV-derived labels could be misaligned with dataset indices or
    point past the end of the dataset.
    """
    splits = ("train", "valid", "test")

    # Print summaries *before* constructing datasets
    for split in splits:
        _print_split_summary(cfg, split, modality=cfg.modality)

    # Build tokenizer if needed (None for image runs)
    tok = build_tokenizer(cfg)

    ds_tr, ds_va, ds_te = [
        CSVImageOrTextDataset(cfg, split, cfg.modality, cfg.n_classes,
                              tokenizer=tok, max_len=int(cfg.max_tokens))
        for split in splits
    ]

    batch_size = int(cfg.batch_size)
    n_workers = int(cfg.num_workers)
    shared = dict(
        batch_size=batch_size,
        num_workers=n_workers,
        pin_memory=bool(cfg.pin_memory),
        persistent_workers=n_workers > 0,  # only valid with at least one worker
    )

    # Build loaders (class-balanced sampling for image branch)
    if cfg.modality == MOD.img:
        labels_tr = torch.as_tensor([label for _, label, _ in ds_tr.items], dtype=torch.long)
        sampler = ImbalancedDatasetSampler(labels_tr)
        dl_tr = DataLoader(ds_tr, sampler=sampler, shuffle=False, drop_last=True, **shared)
    else:
        dl_tr = DataLoader(ds_tr, shuffle=True, drop_last=False, **shared)

    dl_va = DataLoader(ds_va, shuffle=False, drop_last=False, **shared)
    dl_te = DataLoader(ds_te, shuffle=False, drop_last=False, **shared)

    # Also print batch counts to sanity check batch sizing
    def _n_batches(n_items: int, bs: int, drop_last: bool) -> int:
        bs = max(1, bs)
        if drop_last:
            return max(0, n_items // bs)
        return (n_items + bs - 1) // bs

    # Use each loader's own drop_last so the text branch is not under-counted.
    print(f"[batches] train={_n_batches(len(ds_tr), batch_size, bool(dl_tr.drop_last))}, "
          f"valid={_n_batches(len(ds_va), batch_size, False)}, "
          f"test={_n_batches(len(ds_te), batch_size, False)}",
          flush=True)
    return dl_tr, dl_va, dl_te

# Build the train/valid/test DataLoaders once for the rest of the notebook (this call also prints the per-split summaries)
train_loader, valid_loader, test_loader = make_loaders(cfg)  # prints split details

In [None]:
# Model (Torch/Timm image backbones; optional post-3×3; Text encoder)

# timm (optional import)
try:
    import timm
    _HAS_TIMM = True
except Exception:
    timm = None
    _HAS_TIMM = False

def _net_name_like(net_obj) -> str:
    """
    Robustly resolve a backbone name:
      - NETWORK enum with `.name`
      - or string
      - or fallback to str(obj)
    """
    if isinstance(net_obj, str):
        return net_obj.lower()
    name = getattr(net_obj, "name", None)
    if name is not None:
        return str(name).lower()
    return str(net_obj).lower()

# -------- timm IDs
# Maps the lower-cased NETWORK member name to the timm model identifier passed to
# timm.create_model. Backbones not listed here are built from torchvision instead.
TIMM_NAME_MAP: Dict[str, str] = {
    "dpn68":         "dpn68.mx_in1k",
    "coatnet0":      "coatnet_0_rw_224.sw_in1k",
    "convnext_nano": "convnext_nano.in12k",
    "hrnet32":       "hrnet_w32.ms_in1k",
}

# -------- helpers --------
def _ends_with_3x3(module: nn.Module) -> bool:
    last_conv = None
    for m in module.modules():
        if isinstance(m, nn.Conv2d):
            last_conv = m
    if last_conv is None:
        return False
    k = getattr(last_conv, "kernel_size", (0, 0))
    if not isinstance(k, tuple):
        k = (k, k)
    return tuple(k) == (3, 3)

def _insert_post3x3(ch: int) -> nn.Sequential:
    # Conv(3×3) + BN + SiLU
    return nn.Sequential(
        nn.Conv2d(ch, ch, kernel_size=3, stride=1, padding=1, bias=False),
        nn.BatchNorm2d(ch),
        nn.SiLU(inplace=True),
    )

def _resolve_hf_local_dir(local_dir: Optional[str]) -> Optional[str]:
    if not local_dir or not os.path.isdir(local_dir):
        return None
    needed = {"config.json", "pytorch_model.bin", "model.safetensors", "rust_model.ot"}
    files = set(os.listdir(local_dir))
    if files & needed:
        return local_dir
    snaps = os.path.join(local_dir, "snapshots")
    refs  = os.path.join(local_dir, "refs", "main")
    if os.path.isdir(snaps):
        if os.path.isfile(refs):
            with open(refs, "r") as f:
                commit = f.read().strip()
            cand = os.path.join(snaps, commit)
            if os.path.isdir(cand):
                return cand
        subdirs = [
            os.path.join(snaps, d) for d in os.listdir(snaps)
            if os.path.isdir(os.path.join(snaps, d))
        ]
        if subdirs:
            subdirs.sort(key=lambda p: os.path.getmtime(p), reverse=True)
            return subdirs[0]
    return None

# Torchvision image backbones
class TorchvisionBackbone(nn.Module):
    """
    ResNet18 / DenseNet121 / MobileNetV2 feature extractor.

    encoder (features) → [optional post-3×3] → GAP → Dropout → (B, C)

    ImageNet weights are requested first; if they cannot be loaded, the model is
    built with random weights and a RuntimeWarning is emitted. `out_dim` is the
    length C of the returned feature vector.

    Raises ValueError for any `which` other than the three networks above.
    """
    def __init__(self, which: "NETWORK", p_drop: float = 0.3):
        super().__init__()
        # --- encoder selection
        if which == NETWORK.resnet18:
            m = self._create(tvm.resnet18, "ResNet18_Weights")
            # Keep every child except the last two, so the output stays (B, C, H, W).
            self.encoder = nn.Sequential(*(list(m.children())[:-2]))
            ch = m.fc.in_features
        elif which == NETWORK.densenet121:
            m = self._create(tvm.densenet121, "DenseNet121_Weights")
            self.encoder = m.features  # (B, C, H, W)
            ch = m.classifier.in_features
        elif which == NETWORK.mobilenet_v2:
            m = self._create(tvm.mobilenet_v2, "MobileNet_V2_Weights")
            self.encoder = m.features  # (B, C, H, W)
            ch = m.classifier[1].in_features
        else:
            raise ValueError(which)

        # ---- order: post-3×3 → GAP → Dropout
        # The extra 3×3 block is added only when the encoder's last conv is not already 3×3.
        self.post3x3 = _insert_post3x3(ch) if not _ends_with_3x3(self.encoder) else nn.Identity()
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        self.drop = nn.Dropout(p=p_drop)
        self.out_dim = ch

    @staticmethod
    def _create(ctor, weights_enum_name: str) -> nn.Module:
        """Build `ctor` with IMAGENET1K_V1 weights, falling back to random init with a warning."""
        try:
            weights = getattr(tvm, weights_enum_name).IMAGENET1K_V1
            return ctor(weights=weights)
        except Exception as err:
            # `except Exception` (not a bare except) so Ctrl-C / SystemExit still propagate.
            warnings.warn(
                f"Could not load ImageNet weights for {getattr(ctor, '__name__', ctor)} "
                f"({err!r}); using randomly initialised weights.",
                RuntimeWarning,
                stacklevel=3,
            )
            return ctor(weights=None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        f = self.encoder(x)           # (B, C, H, W)
        f = self.post3x3(f)           # (B, C, H, W)
        f = self.gap(f).flatten(1)    # (B, C)
        f = self.drop(f)              # (B, C)
        return f

class TorchvisionVGGBackbone(nn.Module):
    """
    VGG11/13/16/19 (batch-norm variants) feature extractor.

    encoder (features) → [optional post-3×3] → GAP → Dropout → (B, C)

    ImageNet weights are requested first; if they are unavailable or fail to load,
    the model is built with random weights and a RuntimeWarning is emitted.
    `out_dim` is the `out_channels` of the last Conv2d in the encoder.
    """
    def __init__(self, which: "NETWORK", p_drop: float = 0.3):
        super().__init__()
        # getattr(..., None) tolerates torchvision builds that lack the weights enums.
        vgg_map = {
            NETWORK.vgg11: (tvm.vgg11_bn,  getattr(tvm, "VGG11_BN_Weights", None)),
            NETWORK.vgg13: (tvm.vgg13_bn,  getattr(tvm, "VGG13_BN_Weights", None)),
            NETWORK.vgg16: (tvm.vgg16_bn,  getattr(tvm, "VGG16_BN_Weights", None)),
            NETWORK.vgg19: (tvm.vgg19_bn,  getattr(tvm, "VGG19_BN_Weights", None)),
        }
        ctor, weights_enum = vgg_map[which]  # KeyError for a non-VGG network

        m = None
        reason = "no weights enum in this torchvision build"
        if weights_enum is not None:
            try:
                m = ctor(weights=getattr(weights_enum, "IMAGENET1K_V1"))
            except Exception as err:
                reason = repr(err)
        if m is None:
            # Same fallback as before, but no longer silent: random weights change results.
            warnings.warn(
                f"Could not load ImageNet weights for {getattr(ctor, '__name__', ctor)} "
                f"({reason}); using randomly initialised weights.",
                RuntimeWarning,
                stacklevel=2,
            )
            m = ctor(weights=None)

        self.encoder = m.features  # (B, C, H, W)

        # Output width = out_channels of the last Conv2d in traversal order.
        ch = None
        for mod in self.encoder.modules():
            if isinstance(mod, nn.Conv2d):
                ch = mod.out_channels
        if ch is None:
            raise RuntimeError("VGG out channels not found")

        self.post3x3 = _insert_post3x3(ch) if not _ends_with_3x3(self.encoder) else nn.Identity()
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        self.drop = nn.Dropout(p=p_drop)
        self.out_dim = ch

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        f = self.encoder(x)           # (B, C, H, W)
        f = self.post3x3(f)           # (B, C, H, W)
        f = self.gap(f).flatten(1)    # (B, C)
        f = self.drop(f)              # (B, C)
        return f

# timm backbones
def _maybe_create_timm(timm_name: str, img_size: int) -> nn.Module:
    """
    Create a timm model safely:
      - Don’t pass img_size to models that don’t accept it (e.g., DPN).
      - num_classes=0 + global_pool='' to expose raw features.
    """
    if not _HAS_TIMM:
        raise RuntimeError("timm is not available")

    base_kwargs = dict(num_classes=0, global_pool="")

    # Some models accept img_size, others don't — try with, then without.
    try:
        return timm.create_model(timm_name, pretrained=True, img_size=img_size, **base_kwargs)
    except TypeError:
        # Retry without img_size
        try:
            return timm.create_model(timm_name, pretrained=True, **base_kwargs)
        except Exception:
            # Fallback to pretrained=False
            return timm.create_model(timm_name, pretrained=False, **base_kwargs)

class TimmBackbone(nn.Module):
    """
    timm `forward_features` wrapper returning a (B, C) feature vector.

      - 4D features: encoder → [optional post-3×3] → GAP → Dropout → (B, C)
      - 3D features: first token (index 0) → Dropout → (B, C)

    `kind` is "map4d" or "tokens" and `out_dim` is C; both are determined by a
    one-off probe with a dummy input at construction time.
    """
    def __init__(self, key_name: str, img_size: int, p_drop: float = 0.1):
        super().__init__()
        if key_name not in TIMM_NAME_MAP:
            raise ValueError(f"Unknown timm backbone key '{key_name}'. Options: {list(TIMM_NAME_MAP.keys())}")
        timm_name = TIMM_NAME_MAP[key_name]

        self.m = _maybe_create_timm(timm_name, img_size=img_size)

        # Probe the feature shape to configure the head. The probe runs in eval mode:
        # in train mode the all-zero dummy batch would be folded into the BatchNorm
        # running statistics of the (pretrained) encoder.
        was_training = self.m.training
        self.m.eval()
        try:
            with torch.no_grad():
                dummy = torch.zeros(1, 3, img_size, img_size)
                f = self.m.forward_features(dummy)
        finally:
            self.m.train(was_training)

        if f.ndim == 4:
            ch = int(f.shape[1])
            # order: post-3×3 → GAP → Dropout
            needs_post = not _ends_with_3x3(self.m)
            self.post3x3 = _insert_post3x3(ch) if needs_post else nn.Identity()
            self.gap = nn.AdaptiveAvgPool2d((1, 1))
            self.drop = nn.Dropout(p=p_drop)
            self.kind = "map4d"
            self.out_dim = ch
        elif f.ndim == 3:
            # ViT-like: CLS/token dim
            self.post3x3 = nn.Identity()
            self.gap = nn.Identity()
            self.drop = nn.Dropout(p=p_drop)
            self.kind = "tokens"
            self.out_dim = int(f.shape[-1])
        else:
            raise RuntimeError("Unexpected feature shape from timm model")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        f = self.m.forward_features(x)
        if self.kind == "map4d":
            f = self.post3x3(f)          # (B, C, H, W)
            f = self.gap(f).flatten(1)   # (B, C)
            f = self.drop(f)             # (B, C)
            return f
        else:
            # tokens / CLS at index 0
            f = f[:, 0]                   # (B, C)
            f = self.drop(f)
            return f

# ------------------------------------------------------------------
# Text fine-tuning policy (copied semantically from multimodal)
# ------------------------------------------------------------------

@dataclass
class TextFTPolicy:
    """
    Which parts of the text encoder are trained.

    Modes:
      - authors_default: embeddings and pooler frozen; only the last encoder block trains
      - freeze_all:      every text-encoder parameter is frozen
      - last_n:          embeddings and pooler frozen; the last `last_n` encoder blocks train
      - train_all:       every text-encoder parameter trains

    `last_n` is only consulted in the "last_n" mode.
    """
    mode: Literal["authors_default", "freeze_all", "last_n", "train_all"] = "authors_default"
    last_n: int = 1


def configure_text_finetune(bert: nn.Module, policy: TextFTPolicy) -> None:
    """
    Set `requires_grad` on the text encoder's parameters according to `policy`.

    - "freeze_all" / "train_all": every parameter is frozen / trainable.
    - "authors_default": embeddings and pooler frozen, only the last encoder layer trains.
    - "last_n": embeddings and pooler frozen, the last `policy.last_n` encoder layers train.

    Sub-modules that are missing or None (`embeddings`, `pooler`, `encoder`) are
    skipped rather than raising, in both layer-wise modes.

    Raises:
        ValueError: `policy.mode` is not one of the four modes above.
    """
    def _set_grad(module: Optional[nn.Module], trainable: bool) -> None:
        if module is None:
            return
        for p in module.parameters():
            p.requires_grad = trainable

    mode = policy.mode

    if mode == "freeze_all":
        _set_grad(bert, False)
        return

    if mode == "train_all":
        _set_grad(bert, True)
        return

    if mode in ("authors_default", "last_n"):
        # "authors_default" is the layer-wise policy with exactly one trainable layer.
        n_trainable = 1 if mode == "authors_default" else int(policy.last_n)

        _set_grad(getattr(bert, "embeddings", None), False)
        _set_grad(getattr(bert, "pooler", None), False)

        layers = list(getattr(getattr(bert, "encoder", None), "layer", []))
        cutoff = max(0, len(layers) - n_trainable)
        for i, layer in enumerate(layers):
            _set_grad(layer, i >= cutoff)
        return

    raise ValueError(f"Unknown TextFTPolicy.mode: {policy.mode}")

# Unimodal classifier (Img-only OR Txt-only)
class ImgOrTxtClassifier(nn.Module):
    """Unimodal classifier that uses either an image encoder or a HF text encoder.

    Exactly one branch is built, depending on ``cfg.modality``:
      - Image (``MOD.img``): image encoder wrapper → shared head.
        The wrappers (``TimmBackbone`` / ``TorchvisionBackbone`` /
        ``TorchvisionVGGBackbone``) are defined elsewhere; they are constructed
        with ``p_drop=0.3`` and must expose ``out_dim``.
        NOTE(review): pooling / dropout presumably happen inside the wrapper — confirm.
      - Text (``MOD.txt``): HF encoder → pooled output (``pooler_output`` when
        available, otherwise the first token of ``last_hidden_state``)
        → Dropout(0.3) → shared head.

    Shared head: ``proj`` (Linear) → ReLU → ``mid`` (Linear) → ``cls`` (Linear),
    returning raw logits with ``n_classes`` outputs.
    """
    def __init__(self, cfg: "RunConfig", text_ft: Optional[TextFTPolicy] = None):
        super().__init__()
        self.modality = cfg.modality
        self.n_classes = int(cfg.n_classes)
        self.hidden = int(cfg.hidden_dim)
        self.img_size = int(getattr(cfg, "img_size", 224))
        # When no policy is given, the text encoder is fully frozen.
        if text_ft is None:
            text_ft = TextFTPolicy(mode="freeze_all")

        # Image branch (only built for MOD.img; otherwise img_enc stays None)
        self.img_enc = None
        img_out = 0
        if self.modality == MOD.img:
            net_name = _net_name_like(cfg.network).lower()

            # Prefer the timm wrapper when timm is available and knows this name;
            # otherwise fall back to the torchvision wrappers by name.
            if _HAS_TIMM and (net_name in TIMM_NAME_MAP):
                self.img_enc = TimmBackbone(net_name, img_size=self.img_size, p_drop=0.3)
            else:
                if net_name in ("resnet18", "densenet121", "mobilenet_v2"):
                    self.img_enc = TorchvisionBackbone(cfg.network, p_drop=0.3)
                elif net_name in ("vgg11", "vgg13", "vgg16", "vgg19"):
                    self.img_enc = TorchvisionVGGBackbone(cfg.network, p_drop=0.3)
                else:
                    raise ValueError(f"Unsupported backbone: {cfg.network} (name='{net_name}')")

            # Feature width fed into the shared head
            img_out = int(self.img_enc.out_dim)

        # Text branch (only built for MOD.txt; otherwise txt_enc stays None)
        self.txt_enc = None
        txt_hid = 0
        if self.modality == MOD.txt:
            resolved = _resolve_hf_local_dir(getattr(cfg, "hf_local_dir", None))
            offline = bool(getattr(cfg, "hf_offline", True))

            if offline:
                # Offline mode never touches the network: a local snapshot is mandatory.
                if not resolved:
                    raise RuntimeError("Offline mode but no valid local HF snapshot found.")
                self.txt_enc = AutoModel.from_pretrained(resolved, local_files_only=True)
            else:
                # Online mode: use the local snapshot if one resolved, else the model id.
                model_id = getattr(cfg, "hf_text_model_name", "microsoft/BiomedVLP-CXR-BERT-general")
                self.txt_enc = AutoModel.from_pretrained(resolved or model_id)

            # Apply the fine-tuning policy (sets requires_grad on encoder params)
            configure_text_finetune(self.txt_enc, text_ft)

            txt_hid = int(self.txt_enc.config.hidden_size)
            # Note: txt_drop only exists on text-modality instances.
            self.txt_drop = nn.Dropout(p=0.3)

        # Shared head: proj → ReLU → mid → cls.
        # For any modality other than MOD.img the input width is txt_hid
        # (which is 0 unless the text branch was built).
        feat_in = img_out if self.modality == MOD.img else txt_hid
        self.proj = nn.Linear(feat_in, self.hidden)
        self.act = nn.ReLU(inplace=True)
        self.mid = nn.Linear(self.hidden, self.hidden)
        self.cls = nn.Linear(self.hidden, self.n_classes)

    # Image-only path
    def forward_image(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits for an image batch ``x`` via the image encoder and head."""
        f = self.img_enc(x)                  # (B, C) per the encoder wrappers — TODO confirm
        z = self.mid(self.act(self.proj(f)))
        return self.cls(z)                   # logits

    # Text-only path
    def forward_text(self, t: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Return class logits for tokenized text ``t``.

        Only ``t["input_ids"]`` and ``t["attention_mask"]`` are passed to the encoder.
        """
        out = self.txt_enc(input_ids=t["input_ids"], attention_mask=t["attention_mask"])
        # Use the encoder's pooled output when it provides one; otherwise take
        # the hidden state of the first token.
        pooled = out.pooler_output if getattr(out, "pooler_output", None) is not None else out.last_hidden_state[:, 0]
        pooled = self.txt_drop(pooled)
        z = self.mid(self.act(self.proj(pooled)))
        return self.cls(z)

    def forward(self, image=None, text=None) -> torch.Tensor:
        """Dispatch to ``forward_image`` or ``forward_text`` based on ``self.modality``.

        Raises:
            ValueError: if the modality is neither ``MOD.img`` nor ``MOD.txt``.
        """
        if self.modality == MOD.img:
            return self.forward_image(image)
        if self.modality == MOD.txt:
            return self.forward_text(text)
        raise ValueError(f"ImgOrTxtClassifier supports only MOD.img or MOD.txt, got {self.modality}")

In [None]:
# Loss (CE only) for either image or text, and related metrics

# Shared cross-entropy criterion; reduction="mean" averages the loss over the batch.
# compute_total_loss calls it with raw logits and integer class indices.
_ce = nn.CrossEntropyLoss(reduction="mean")

def onehot_to_idx(y_onehot: torch.Tensor) -> torch.Tensor:
    """Map each row of ``y_onehot`` to the index of its largest entry (argmax over dim 1)."""
    class_idx = y_onehot.argmax(dim=1)
    return class_idx

def compute_total_loss(logits: torch.Tensor, y_onehot: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, float]]:
    """Compute the cross-entropy loss for a batch.

    Args:
        logits: raw class scores from the model.
        y_onehot: one-hot targets; converted to class indices via ``onehot_to_idx``.

    Returns:
        ``(loss, scalars)`` where ``loss`` is the differentiable loss tensor
        produced by the module-level ``_ce`` criterion and ``scalars`` holds the
        same value as Python floats under the keys ``"total"`` and ``"ce"``.
    """
    y_idx = onehot_to_idx(y_onehot)
    ce = _ce(logits, y_idx)
    # Read the scalar once: every .item() call copies the value to the host
    # (and synchronizes on CUDA), so reuse the result for both keys.
    ce_value = float(ce.detach().item())
    return ce, {"total": ce_value, "ce": ce_value}

@torch.no_grad()
def evaluate_dataloader_classifier(model: nn.Module, loader: DataLoader, modality: MOD) -> Dict[str, float]:
    """Run ``model`` over ``loader`` and compute binary classification metrics.

    The model is switched to eval mode for the duration of the pass, so results
    do not depend on the mode the caller left it in; the previous mode is
    restored on exit. Class index 1 is treated as the positive class.

    Args:
        model: classifier exposing ``forward_image`` / ``forward_text``.
        loader: yields dict batches with ``"y_onehot"`` and either ``"image"``
            (for ``MOD.img``) or ``"text"`` (a dict of tensors, otherwise).
        modality: selects which forward path and batch key is used.

    Returns:
        Dict with confusion counts (``TP``, ``FP``, ``FN``, ``TN``) and
        ``balanced_accuracy``, ``sensitivity``, ``specificity``, ``precision``,
        ``NPV``, ``F1_score``, ``MCC``, ``Cohen_Kappa``, ``ROC_AUC``.
        Metrics fall back to 0.0 when undefined for lack of samples;
        ``ROC_AUC`` is NaN when ``roc_auc_score`` raises ``ValueError``.
    """
    device = next(model.parameters()).device
    was_training = model.training
    model.eval()  # disable dropout / use running BatchNorm stats while scoring
    y_true, y_pred, y_prob = [], [], []
    try:
        for batch in loader:
            if batch["y_onehot"].numel() == 0:
                continue  # skip empty batches
            if modality == MOD.img:
                img = batch["image"].to(device, non_blocking=True)
                logits = model.forward_image(img)
            else:
                txt = {k: v.to(device, non_blocking=True) for k, v in batch["text"].items()}
                logits = model.forward_text(txt)
            # Probability of the positive class (index 1), computed in float32
            probs1 = F.softmax(logits.float(), dim=1)[:, 1].detach().cpu().numpy()
            preds = logits.argmax(dim=1).detach().cpu().numpy()
            y_idx = torch.argmax(batch["y_onehot"], dim=1).detach().cpu().numpy()
            y_true.extend(y_idx.tolist())
            y_pred.extend(preds.tolist())
            y_prob.extend(probs1.tolist())
    finally:
        model.train(was_training)

    y_true = np.asarray(y_true, dtype=int)
    y_pred = np.asarray(y_pred, dtype=int)
    y_prob = np.asarray(y_prob, dtype=float)
    TP = int(((y_true == 1) & (y_pred == 1)).sum())
    FP = int(((y_true == 0) & (y_pred == 1)).sum())
    FN = int(((y_true == 1) & (y_pred == 0)).sum())
    TN = int(((y_true == 0) & (y_pred == 0)).sum())
    bal_acc = balanced_accuracy_score(y_true, y_pred) if y_true.size else 0.0
    sens = recall_score(y_true, y_pred, pos_label=1) if y_true.size else 0.0
    spec = TN / (TN + FP) if (TN + FP) else 0.0
    prec = precision_score(y_true, y_pred, pos_label=1) if (TP + FP) else 0.0
    npv = TN / (TN + FN) if (TN + FN) else 0.0
    f1s = f1_score(y_true, y_pred, pos_label=1) if y_true.size else 0.0
    mcc = matthews_corrcoef(y_true, y_pred) if y_true.size else 0.0
    kappa = cohen_kappa_score(y_true, y_pred) if y_true.size else 0.0
    try:
        roc_auc = roc_auc_score(y_true, y_prob)
    except ValueError:
        roc_auc = float("nan")
    return {
        "TP": TP, "FP": FP, "FN": FN, "TN": TN,
        "balanced_accuracy": float(bal_acc), "sensitivity": float(sens), "specificity": float(spec),
        "precision": float(prec), "NPV": float(npv), "F1_score": float(f1s),
        "MCC": float(mcc), "Cohen_Kappa": float(kappa), "ROC_AUC": float(roc_auc)
    }

In [None]:
# Train loop (save best by highest validation MCC)

def _count_params(m: nn.Module) -> Tuple[int, int]:
    tot = sum(p.numel() for p in m.parameters())
    trn = sum(p.numel() for p in m.parameters() if p.requires_grad)
    return tot, trn

def _format_params(n: int) -> str:
    if n >= 1e6: return f"{n/1e6:.2f}M"
    if n >= 1e3: return f"{n/1e3:.2f}k"
    return str(n)

def _save_curves_png_csv(history: Dict[str, List[float]], out_png: str, out_csv: str) -> None:
    """Write the training history to ``out_csv`` and plot the loss curves to ``out_png``.

    Args:
        history: column name → per-epoch values. The loss curves are drawn only
            when ``"epoch"``, ``"train_total"`` and ``"val_total"`` are all
            present; otherwise an empty, labelled figure is saved.
        out_png: destination of the figure (saved at 200 dpi).
        out_csv: destination of the CSV dump of ``history``.

    Parent directories of both outputs are created when they are non-empty.
    """
    for target in (out_png, out_csv):
        parent = os.path.dirname(target)
        if parent:  # a bare filename has no parent; os.makedirs("") would raise
            os.makedirs(parent, exist_ok=True)

    df = pd.DataFrame(history)
    df.to_csv(out_csv, index=False)

    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        has_curves = all(col in df for col in ("epoch", "train_total", "val_total"))
        if has_curves:
            ax.plot(df["epoch"], df["train_total"], label="train_total", linewidth=2)
            ax.plot(df["epoch"], df["val_total"], label="val_total", linewidth=2, linestyle="--")
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Loss")
        ax.set_title("Training / Validation Loss")
        if has_curves:
            # Only request a legend when there are labelled artists to show
            ax.legend(loc="upper right")
        ax.grid(True, alpha=0.3)
        fig.tight_layout()
        fig.savefig(out_png, dpi=200)
    finally:
        # Always release the figure, even if plotting or saving failed
        plt.close(fig)

def _save_mcc_curve(history: Dict[str, List[float]], out_png: str) -> None:
    """Best-effort plot of validation MCC per epoch, saved to ``out_png`` at 200 dpi.

    Reads ``history["epoch"]`` and ``history["val_MCC"]``. Any failure is
    reported with a ``[WARN]`` message instead of being raised, so a plotting
    problem never aborts the caller.
    """
    fig = None
    try:
        epochs = history["epoch"]
        val_mcc = history["val_MCC"]
        fig, ax = plt.subplots(figsize=(8, 5))
        ax.plot(epochs, val_mcc, linewidth=2)
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Validation MCC")
        ax.set_title("Validation MCC")
        ax.grid(True, alpha=0.3)
        fig.tight_layout()
        fig.savefig(out_png, dpi=200)
    except Exception as exc:
        # Stay best-effort, but make the failure visible rather than silent
        print(f"[WARN] Could not save validation MCC curve to {out_png} ({exc}).")
    finally:
        if fig is not None:
            plt.close(fig)

def _print_model_overview(model: nn.Module, cfg: "RunConfig") -> None:
    """Print the run configuration, encoder widths, parameter counts and full module tree."""
    def emit(line) -> None:
        print(line, flush=True)

    emit("\n========== Model Architecture (FULL) ==========")
    emit(f"Modality: {cfg.modality.name}")
    emit(f"Backbone: {cfg.network.name}")
    emit(f"Hidden dim: {cfg.hidden_dim} | Classes: {cfg.n_classes}")

    # Image encoder details (only when the model actually has one)
    image_encoder = getattr(model, "img_enc", None)
    if image_encoder is not None:
        out_width = getattr(image_encoder, "out_dim", None)
        emit(f"[Image] out_dim={out_width}")
        if hasattr(image_encoder, "post3x3"):
            emit(f"[Image] post3x3: {image_encoder.post3x3.__class__.__name__}")

    # Text encoder details (only when the model actually has one)
    text_encoder = getattr(model, "txt_enc", None)
    if text_encoder is not None:
        encoder_cfg = getattr(text_encoder, "config", None)
        hidden_width = getattr(encoder_cfg, "hidden_size", None)
        emit(f"[Text ] hidden_size={hidden_width}")

    total, trainable = _count_params(model)
    frozen = total - trainable
    emit(f"Parameters: total={_format_params(total)}, trainable={_format_params(trainable)}, frozen={_format_params(frozen)}")
    emit("-------------- BEGIN torch.nn print(model) --------------")
    emit(model)
    emit("--------------- END torch.nn print(model) ---------------\n")

@dataclass
class TrainRuntimeCfg:
    """Runtime options for the training loop (read by ``run_epoch`` and the fit function)."""
    # Enable mixed precision; run_epoch only honors it when the model is on CUDA.
    amp: bool = True
    # Number of consecutive epochs without validation-MCC improvement before early stop.
    patience: int = 10
    # Max gradient norm passed to clip_grad_norm_; None disables clipping.
    grad_clip_norm: Optional[float] = 1.0
    # dtype handed to torch.autocast when AMP is active.
    amp_dtype: torch.dtype = torch.float16
    # NOTE(review): not referenced by the training code visible here — confirm usage.
    log_every: int = 1

def run_epoch(model: nn.Module, loader: DataLoader, cfg: "RunConfig",
              *, train: bool, optim: Optional[torch.optim.Optimizer],
              trcfg: TrainRuntimeCfg) -> Dict[str, float]:
    """Run one pass over ``loader`` and return sample-weighted mean losses.

    Args:
        model: classifier exposing ``forward_image`` / ``forward_text``.
        loader: yields dict batches with ``"y_onehot"`` and ``"image"`` or ``"text"``.
        cfg: run config; ``cfg.modality`` selects the image or text path.
        train: when True, gradients are enabled and the optimizer is stepped;
            the model is put in train/eval mode accordingly.
        optim: optimizer to step; required when ``train`` is True.
        trcfg: AMP and gradient-clipping options.

    Returns:
        ``{"total": ..., "ce": ...}`` averaged over the samples seen
        (both 0.0 when no samples were seen).

    Raises:
        ValueError: if ``train`` is True but ``optim`` is None.
    """
    if train and optim is None:
        # Fail before any compute rather than with an AttributeError mid-epoch
        raise ValueError("run_epoch(train=True) requires an optimizer (optim=None was given).")

    device = next(model.parameters()).device
    use_amp = trcfg.amp and (device.type == "cuda")
    model.train(mode=train)
    sums = {"total": 0.0, "ce": 0.0}
    n_samples = 0

    # Gradient scaler is only needed for AMP training; pick the API this torch offers.
    scaler = None
    if train and use_amp:
        if hasattr(torch, "amp") and hasattr(torch.amp, "GradScaler"):
            scaler = torch.amp.GradScaler("cuda")
        elif hasattr(torch.cuda, "amp"):
            scaler = torch.cuda.amp.GradScaler()

    is_image = (cfg.modality == MOD.img)
    for batch in loader:
        if batch["y_onehot"].numel() == 0:
            continue  # skip empty batches
        bsz = batch["y_onehot"].size(0)
        n_samples += bsz

        if is_image:
            inputs = batch["image"].to(device, non_blocking=True)
        else:
            inputs = {k: v.to(device, non_blocking=True) for k, v in batch["text"].items()}
        target = batch["y_onehot"].to(device, non_blocking=True)

        autocast_ctx = (torch.autocast("cuda", dtype=trcfg.amp_dtype) if use_amp else nullcontext())
        with torch.set_grad_enabled(train), autocast_ctx:
            logits = model.forward_image(inputs) if is_image else model.forward_text(inputs)
            loss, scalars = compute_total_loss(logits, target)

        if train:
            optim.zero_grad(set_to_none=True)
            if scaler is not None:
                scaler.scale(loss).backward()
                if trcfg.grad_clip_norm is not None:
                    # Gradients must be unscaled before their norm is clipped
                    scaler.unscale_(optim)
                    nn.utils.clip_grad_norm_(model.parameters(), trcfg.grad_clip_norm)
                scaler.step(optim)
                scaler.update()
            else:
                loss.backward()
                if trcfg.grad_clip_norm is not None:
                    nn.utils.clip_grad_norm_(model.parameters(), trcfg.grad_clip_norm)
                optim.step()

        # Weight each batch's mean loss by its size so the epoch mean is per-sample
        for k in sums.keys():
            sums[k] += scalars.get(k, 0.0) * bsz

    eps = 1e-12  # avoids division by zero when the loader produced no samples
    return {k: (v / max(n_samples, eps)) for k, v in sums.items()}

def fit_img_or_txt_with_best_mcc_checkpoint(
    model: nn.Module,
    train_loader: DataLoader,
    valid_loader: DataLoader,
    cfg: "RunConfig",
    trcfg: TrainRuntimeCfg = TrainRuntimeCfg(),
) -> Dict[str, object]:
    """Train ``model`` and keep the checkpoint with the highest validation MCC.

    Each epoch runs a training pass, a validation loss pass and a validation
    metrics pass. The first completed epoch always becomes the initial best;
    afterwards a checkpoint is saved whenever validation MCC improves by more
    than 1e-8. Training stops early after ``trcfg.patience`` consecutive
    epochs without improvement. Loss/MCC curves are written after the loop and
    the best checkpoint is reloaded into ``model`` (best-effort).

    Args:
        model: classifier to train (image-only or text-only).
        train_loader / valid_loader: loaders for the two splits.
        cfg: run config (``modality``, ``out_dir``, ``epochs``, ``lr``,
            ``weight_decay``, ``network`` are read here).
        trcfg: AMP / patience / clipping options.

    Returns:
        Dict with ``best_val_MCC``, ``best_epoch``, ``improvements``,
        ``early_stopped``, ``last_epoch_val_total``, ``last_epoch_val_MCC``,
        ``paths`` and ``history``. ``best_val_MCC`` is -1.0 and ``best_epoch``
        is 0 when no epoch was run.
    """
    assert cfg.modality in (MOD.img, MOD.txt), "Training here supports image-only or text-only."

    # Print full summary **before** training begins
    _print_model_overview(model, cfg)

    os.makedirs(cfg.out_dir, exist_ok=True)
    paths = ckpt_paths(cfg)

    # Backward-compat fallback for mcc_png if missing in ckpt_paths
    if "mcc_png" not in paths:
        stem = f"{cfg.network.name}_{cfg.modality.name}"
        paths["mcc_png"] = os.path.join(cfg.out_dir, f"best_{stem}_val_loss_mcc.png")

    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    history = {"epoch": [], "train_total": [], "val_total": [], "val_MCC": []}

    best_mcc = -1.0
    best_epoch = 0
    bad_epochs = 0
    improvements = 0
    stopped_early = False
    last_va_total = None
    last_va_mcc = None
    epochs_run = 0

    def write_log(n_epochs: int) -> None:
        # Log read back at test time to rank runs by best validation MCC
        with open(paths["log_json"], "w") as f:
            json.dump({
                "best_val_mcc": float(best_mcc),
                "best_epoch": int(best_epoch),
                "epochs_run": int(n_epochs)
            }, f, indent=2)

    for ep in range(1, cfg.epochs + 1):
        epochs_run = ep
        tr = run_epoch(model, train_loader, cfg, train=True,  optim=optimizer, trcfg=trcfg)
        va = run_epoch(model, valid_loader, cfg, train=False, optim=None,      trcfg=trcfg)
        last_va_total = va["total"]

        # compute validation metrics → MCC
        val_metrics = evaluate_dataloader_classifier(model, valid_loader, cfg.modality)
        last_va_mcc = val_metrics["MCC"]

        # "-∞" only while nothing has been recorded yet; a real negative best
        # MCC is a legitimate value and is shown as a number.
        have_best = improvements > 0
        best_str = (f"{best_mcc:.4f}" if have_best else "-∞")
        streak_str = f"{bad_epochs}/{trcfg.patience}"
        print(f"[{ep:03d}/{cfg.epochs} | best_val_MCC={best_str} | no_improve={streak_str}] "
              f"Train: total={tr['total']:.4f}  |  Valid: total={va['total']:.4f}  MCC={last_va_mcc:.4f}")

        history["epoch"].append(ep)
        history["train_total"].append(tr["total"])
        history["val_total"].append(va["total"])
        history["val_MCC"].append(last_va_mcc)

        # selection criterion: **maximize validation MCC**
        # The first epoch is always accepted so a checkpoint exists even when
        # the initial MCC equals the -1.0 placeholder.
        if (not have_best) or (last_va_mcc > best_mcc + 1e-8):
            prev_str = (f"{best_mcc:.6f}" if have_best else "-∞")
            print(f"  ↑ val_MCC improved: {prev_str} → {last_va_mcc:.6f}  (saving checkpoint; reset no_improve=0)")
            best_mcc = last_va_mcc
            best_epoch = ep
            bad_epochs = 0
            improvements += 1
            torch.save(model.state_dict(), paths["best_pt"])
            print(f"  ↳ saved: {paths['best_pt']}")
            write_log(ep)
        else:
            bad_epochs += 1
            print(f"  ↳ no MCC improvement (no_improve={bad_epochs}/{trcfg.patience})")
            if bad_epochs >= trcfg.patience:
                print(f"EARLY STOP: validation MCC did not improve for {bad_epochs} consecutive epochs "
                      f"(patience={trcfg.patience}). Stopping at epoch {ep}.")
                stopped_early = True
                break

    # Refresh the log so "epochs_run" reflects the whole run, not just the
    # epoch of the last improvement.
    if improvements > 0:
        write_log(epochs_run)

    # Curves
    _save_curves_png_csv(history, paths["curves_png"], paths["curves_csv"])
    _save_curves_png_csv(history, paths["curves_png_full"], paths["curves_csv_full"])
    _save_mcc_curve(history, paths["mcc_png"])

    # Restore best checkpoint
    try:
        state = torch.load(paths["best_pt"], map_location="cpu")
        model.load_state_dict(state, strict=True)
        print(f"Restored best (by MCC) from {paths['best_pt']}  |  best_epoch={best_epoch}  best_val_MCC={best_mcc:.6f}")
    except Exception as e:
        print(f"[WARN] Could not reload best checkpoint ({e}).")

    return {
        "best_val_MCC": float(best_mcc),
        "best_epoch": int(best_epoch),
        "improvements": int(improvements),
        "early_stopped": bool(stopped_early),
        "last_epoch_val_total": float(last_va_total) if last_va_total is not None else None,
        "last_epoch_val_MCC": float(last_va_mcc) if last_va_mcc is not None else None,
        "paths": paths,
        "history": history,
    }

In [None]:
# TRAINING RUN (pick MOD.img OR MOD.txt)

if __name__ == "__main__":
    # --- Choose one modality: MOD.img or MOD.txt
    cfg = RunConfig(
        modality=MOD.img,    # <-- change to MOD.txt for text-only
        network=NETWORK.vgg11, # choose backbone (dpn68 / coatnet0, etc)
        n_classes=2,
        hidden_dim=256,
        epochs=64,
        lr=5e-5,
        weight_decay=1e-4,
        img_size=224,
        batch_size=64,
        num_workers=2,
        pin_memory=True,
        dataset_root="/dataset",
        csv_train="/label_train.csv",
        csv_valid="/label_valid.csv",
        csv_test ="/label_test.csv",
        images_subdir="images",
        reports_subdir="reports",
        out_dir="/models",
        class_names=("normal","tb"),
        hf_offline=True,
        hf_local_dir="/huggingface/hub/models--microsoft--BiomedVLP-CXR-BERT-general",
        hf_tokenizer_local_dir="/huggingface/hub/models--microsoft--BiomedVLP-CXR-BERT-general",
    )

    # Build loaders
    train_loader, valid_loader, test_loader = make_loaders(cfg)

    # Device + model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ImgOrTxtClassifier(cfg).to(device)

    # The fit function prints the full model summary (architecture + parameter
    # counts) before training starts, so it is not printed a second time here.

    # Train (select best by validation MCC; filename unchanged)
    trcfg = TrainRuntimeCfg(amp=True, patience=10, grad_clip_norm=1.0, amp_dtype=torch.float16)
    out = fit_img_or_txt_with_best_mcc_checkpoint(model, train_loader, valid_loader, cfg, trcfg=trcfg)

    # Evaluate best on VALID and TEST
    valid_metrics = evaluate_dataloader_classifier(model, valid_loader, cfg.modality)
    test_metrics = evaluate_dataloader_classifier(model, test_loader, cfg.modality)

    print("\n[BEST CHECKPOINT (by MCC)]", flush=True)
    print(f"checkpoint : {out['paths']['best_pt']}", flush=True)
    print(f"validation MCC : {valid_metrics['MCC']:.6f}", flush=True)
    print(f"test MCC  : {test_metrics['MCC']:.6f}", flush=True)
    print(f"valid metrics  : {valid_metrics}", flush=True)
    print(f"test metrics  : {test_metrics}", flush=True)

In [None]:
# Test inference (reuse training loaders; pick highest val-MCC)

# ---------- model-aware checkpoint naming ----------
def _ckpt_stem(cfg: "RunConfig") -> str:
    # e.g., "vgg16_img"
    return f"{cfg.network.name}_{cfg.modality.name}"

def _ckpt_paths(cfg: "RunConfig") -> Dict[str, str]:
    """Map artifact keys to file paths under ``cfg.out_dir`` for this run's stem.

    Every path shares the prefix ``best_<stem>_val_loss`` and differs only in
    its suffix.
    """
    prefix = os.path.join(cfg.out_dir, f"best_{_ckpt_stem(cfg)}_val_loss")
    suffix_by_key = (
        ("best_pt", ".pt"),
        ("curves_png", "_curves.png"),
        ("curves_csv", "_history.csv"),
        ("curves_png_full", "_curves_full.png"),
        ("curves_csv_full", "_history_full.csv"),
        ("log_json", "_log.json"),
    )
    return {key: prefix + suffix for key, suffix in suffix_by_key}

# robust best-MCC reader
def _coerce_float(x: Any) -> Optional[float]:
    try:
        return float(x)
    except Exception:
        return None

def _read_best_mcc_from_log(log_json_path: str) -> Optional[float]:
    """Extract the best validation MCC from a run's JSON log, or ``None``.

    Lookup order:
      1. The first top-level key among best_val_mcc / best_val_MCC / val_mcc /
         val_MCC / best_mcc / best_MCC whose value converts to float.
      2. The NaN-ignoring maximum of ``history["val_mcc"]`` or
         ``history["val_MCC"]`` (first non-empty list that converts).

    ``None`` is returned when the file cannot be read or parsed, when the JSON
    root is not an object, or when nothing above yields a value.
    """
    try:
        with open(log_json_path, "r") as fh:
            payload = json.load(fh)
    except Exception:
        return None
    if not isinstance(payload, dict):
        return None

    scalar_keys = ("best_val_mcc", "best_val_MCC", "val_mcc", "val_MCC", "best_mcc", "best_MCC")
    for key in scalar_keys:
        value = _coerce_float(payload.get(key, None))
        if value is not None:
            return value

    history = payload.get("history", None)
    if not isinstance(history, dict):
        return None
    for key in ("val_mcc", "val_MCC"):
        series = history.get(key, None)
        if not isinstance(series, list) or len(series) == 0:
            continue
        try:
            return float(np.nanmax(np.asarray(series, dtype=float)))
        except Exception:
            continue
    return None

# scan out_dir and select best by MCC
def scan_runs_and_select_best(cfg: "RunConfig") -> Dict[str, str]:
    """Find every checkpoint for this stem under ``cfg.out_dir`` and pick the best.

    Candidates are ranked by validation MCC (read from the sibling JSON log),
    highest first, with ties broken by newest modification time; runs without
    a readable MCC sort last.

    Returns:
        The winning row as a dict with keys ``ckpt_path``, ``log_json``,
        ``mtime``, ``val_mcc`` and ``subdir``.

    Raises:
        FileNotFoundError: if no matching checkpoint exists.
    """
    stem = _ckpt_stem(cfg)
    ckpt_name = f"best_{stem}_val_loss.pt"
    log_name = f"best_{stem}_val_loss_log.json"

    def describe(ckpt_file: str) -> Dict[str, Any]:
        run_dir = os.path.dirname(ckpt_file)
        log_file = os.path.join(run_dir, log_name)
        modified = os.path.getmtime(ckpt_file)
        has_log = os.path.isfile(log_file)
        score = _read_best_mcc_from_log(log_file) if has_log else None
        return {
            "ckpt_path": ckpt_file,
            "log_json": log_file if has_log else None,
            "mtime": modified,
            "val_mcc": (None if score is None else float(score)),
            "subdir": run_dir,
        }

    matches = glob.glob(os.path.join(cfg.out_dir, "**", ckpt_name), recursive=True)
    candidates = [describe(found) for found in matches]
    if not candidates:
        raise FileNotFoundError(f"No checkpoints found under '{cfg.out_dir}' for stem '{ckpt_name}'.")

    table = pd.DataFrame(candidates)
    table["val_mcc"] = pd.to_numeric(table["val_mcc"], errors="coerce")
    ranked = table.sort_values(by=["val_mcc", "mtime"], ascending=[False, False], na_position="last")
    return ranked.iloc[0].to_dict()

# load image-only trained model; modify for loading text-only trained model
@torch.no_grad()
def load_img_model(cfg_mod: "RunConfig", ckpt_path: str, device: Optional[torch.device] = None) -> nn.Module:
    """Build an image-only ``ImgOrTxtClassifier`` and load weights from ``ckpt_path``.

    Args:
        cfg_mod: run config; ``cfg_mod.modality`` must be ``MOD.img``.
        ckpt_path: path to a saved ``state_dict``.
        device: target device. When omitted, a notebook-level ``device``
            variable is used if one is defined; otherwise CUDA is chosen when
            available and CPU if not.

    Returns:
        The model on the target device, in eval mode, with weights loaded
        strictly (missing or unexpected keys raise).
    """
    assert cfg_mod.modality == MOD.img, "This loader is for image-only models."
    if device is None:
        # Do not hard-depend on a global defined in another cell
        device = globals().get("device")
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ImgOrTxtClassifier(cfg_mod).to(device).eval()
    # Load onto CPU first; load_state_dict copies into the on-device parameters
    state = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(state, strict=True)
    return model

# print model name, arch, and parameter counts
def _print_model_summary_for_test(model: nn.Module, cfg_mod: "RunConfig", ckpt_path: str) -> None:
    # Detect backbone class name (timm or torchvision wrapper)
    try:
        if hasattr(model, "img_enc") and hasattr(model.img_enc, "m"):       # timm-backed wrapper
            backbone_detected = model.img_enc.m.__class__.__name__
        elif hasattr(model, "img_enc"):                                     # torchvision-backed wrapper
            backbone_detected = model.img_enc.__class__.__name__
        else:
            backbone_detected = model.__class__.__name__
    except Exception:
        backbone_detected = model.__class__.__name__

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    frozen_params = total_params - trainable_params

    print("\n========== Model Loaded for Test ==========", flush=True)
    print(f"Checkpoint           : {ckpt_path}", flush=True)
    print(f"Backbone (cfg)       : {getattr(cfg_mod.network, 'name', str(cfg_mod.network))}", flush=True)
    print(f"Backbone (detected)  : {backbone_detected}", flush=True)
    print(f"Modality             : {getattr(cfg_mod.modality, 'name', str(cfg_mod.modality))}", flush=True)
    print(f"Device               : {device.type}", flush=True)
    print(f"Total params         : {total_params:,}", flush=True)
    print(f"Trainable params     : {trainable_params:,}", flush=True)
    print(f"Frozen params        : {frozen_params:,}", flush=True)
    print("--------------- Full Architecture ---------------", flush=True)
    print(model, flush=True)
    print("=================================================\n", flush=True)


# evaluate using the SAME function as training
@torch.inference_mode()
def evaluate_image_only_with_training_stack(
    cfg_mod: "RunConfig",
    ckpt_path: str,
    save_subdir: str = "test_infer_unimodal_shenzhen_new",
) -> Dict[str, float]:
    """Evaluate an image-only checkpoint on the test split and save the results.

    Builds the test loader via ``make_loaders``, loads the checkpoint, computes
    metrics with ``evaluate_dataloader_classifier`` and writes, into
    ``<checkpoint dir>/<save_subdir>/``:
      - ``test_metrics.csv``: one row of metrics,
      - ``softmax_preds.csv``: per-sample name, true label, ``prob_0``,
        ``prob_1`` and predicted label.

    Sample names come from ``batch["filename"]``, else ``batch["image_path"]``
    (basenames); when neither is usable, ``idx_<k>`` is used with ``k`` being
    the running sample index so names stay unique across batches.

    Returns:
        The metrics dict.
    """
    assert cfg_mod.modality == MOD.img, "Use MOD.img for image-only evaluation."
    # Reuse loaders to get identical transforms & batching
    _, _, test_loader = make_loaders(cfg_mod)

    # Load the best-MCC checkpoint and immediately print summary
    model = load_img_model(cfg_mod, ckpt_path)
    model.eval()  # ensure eval mode
    _print_model_summary_for_test(model, cfg_mod, ckpt_path)

    # Move inputs to wherever the model lives instead of relying on a global
    model_device = next(model.parameters()).device

    # Evaluate with the exact same helper used during training
    metrics = evaluate_dataloader_classifier(model, test_loader, cfg_mod.modality)

    # Save metrics next to the checkpoint
    save_dir = os.path.join(os.path.dirname(ckpt_path), save_subdir)
    os.makedirs(save_dir, exist_ok=True)
    pd.DataFrame([metrics]).to_csv(os.path.join(save_dir, "test_metrics.csv"), index=False)

    def _as_names(value, batch_size):
        # Normalize a batch field into a list of basenames (no tensor truthiness)
        if value is None:
            return None
        if isinstance(value, (list, tuple, np.ndarray)):
            return [os.path.basename(str(s)) for s in list(value)]
        return [os.path.basename(str(value))] * batch_size

    # APPEND: save per-sample softmax predictions
    rows = []
    for batch in test_loader:
        # Strictly follow image-only evaluator: use "image" and "y_onehot"
        if not isinstance(batch, dict) or ("image" not in batch) or ("y_onehot" not in batch):
            continue
        if batch["y_onehot"].numel() == 0:
            continue

        x = batch["image"].to(model_device, non_blocking=True)
        logits = model.forward_image(x)  # image-only path
        probs = F.softmax(logits.float(), dim=1).detach().cpu().numpy()  # (B, n_classes)
        y_idx = torch.argmax(batch["y_onehot"], dim=1).detach().cpu().numpy().astype(int)
        pred = probs.argmax(axis=1).astype(int)

        # Filenames: prefer 'filename'; else basename of 'image_path'; else idx fallback
        B = probs.shape[0]
        names = None
        if "filename" in batch:
            names = _as_names(batch["filename"], B)
        if (names is None) and ("image_path" in batch):
            names = _as_names(batch["image_path"], B)
        if (names is None) or (len(names) != B):
            # Offset by the samples already recorded so names do not repeat per batch
            start = len(rows)
            names = [f"idx_{start + i}" for i in range(B)]

        for i in range(B):
            rows.append({
                "img": names[i],
                "true_label": int(y_idx[i]),
                "prob_0": float(probs[i, 0]),
                "prob_1": float(probs[i, 1]),
                "predicted_label": int(pred[i]),
            })

    out_csv = os.path.join(save_dir, "softmax_preds.csv")
    pd.DataFrame(rows).to_csv(out_csv, index=False)
    print(f"[SOFTMAX] Saved per-sample probabilities to: {out_csv}")
    return metrics

# runner
if __name__ == "__main__":
    # SAME modality/backbone/out_dir as image-only training run
    cfg_eval = RunConfig(
        modality=MOD.img,
        network=NETWORK.vgg11,
        n_classes=2,
        hidden_dim=256,
        img_size=224,
        batch_size=64,
        num_workers=2,
        pin_memory=True,
        dataset_root="/dataset",
        csv_test ="/label_test.csv",
        images_subdir="images",
        out_dir="/models",
        hf_offline=True
    )

    # OPTIONAL: set this to a path (e.g. "/path/to/your/best_model.pt") to force
    # a specific checkpoint. None means: scan cfg_eval.out_dir and auto-select.
    # It must be defined, otherwise the check below raises NameError.
    MANUAL_CKPT = None

    if MANUAL_CKPT is None:
        # Scan runs and list them (for visibility)
        def _fmt(x):
            if x is None or (isinstance(x, float) and (math.isnan(x))): return "—"
            return f"{x:.4f}" if isinstance(x, float) else str(x)

        print("\n================ Available runs for this stem ================")
        stem = _ckpt_stem(cfg_eval)
        for pt in glob.glob(os.path.join(cfg_eval.out_dir, "**", f"best_{stem}_val_loss.pt"), recursive=True):
            logp = os.path.join(os.path.dirname(pt), f"best_{stem}_val_loss_log.json")
            mcc = _read_best_mcc_from_log(logp) if os.path.isfile(logp) else None
            print(f"• {pt}\n    val_MCC={_fmt(mcc)}")
        print("==============================================================\n")

        best_row = scan_runs_and_select_best(cfg_eval)
        ckpt_path = best_row["ckpt_path"]
        print(">>> Selecting checkpoint:")
        print(f"    {ckpt_path}")
        reason = f"highest validation MCC={_fmt(best_row['val_mcc'])}"
        if best_row['val_mcc'] is None or (isinstance(best_row['val_mcc'], float) and math.isnan(best_row['val_mcc'])):
            reason = "no MCC recorded; selecting most recent checkpoint"
        print(f"    Because: {reason}\n")
    else:
        ckpt_path = MANUAL_CKPT
        print(">>> Using manual checkpoint:")
        print(f"    {ckpt_path}\n")

    # Evaluate on test set using the *same* evaluation helper as training
    save_subdir = f"test_infer_unimodal_tbx11k_external_new1_{_ckpt_stem(cfg_eval)}"
    res = evaluate_image_only_with_training_stack(cfg_eval, ckpt_path, save_subdir=save_subdir)

    print(f"\n=== Test results (selected run) [{cfg_eval.network.name}/{cfg_eval.modality.name}] ===")
    for k, v in res.items():
        print(f"{k:>18s}: {v}")

### GRAD-CAM VISUALIZATION USING THE SHENZHEN TEST SET

1. Scans all runs under cfg.out_dir, reads each run’s log to get validation MCC, and selects the highest MCC.

2. Builds YOLO GT txts from lung masks (normalized to the original image size), then scales GT boxes to 1024×1024 for drawing.

3. Runs Grad-CAM on the test images and saves: the base 1024 image, heatmaps (with red GT boxes), contours (with red GT boxes), and bboxes (blue model boxes + red GT boxes).

In [None]:
# Notebook-level compute device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Logging
# Configure the root logger at INFO with a timestamped "[time] LEVEL: message"
# format, then create the named logger `log`.
logging.basicConfig(
    level=logging.INFO,
    format="[%(asctime)s] %(levelname)s: %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger("gradcam-img-gt")

# Reproducibility
def set_deterministic(seed: int = 1337) -> None:
    """Seed the Python, NumPy and PyTorch (CPU + all CUDA devices) RNGs with ``seed``
    and switch cuDNN to deterministic, non-benchmarking mode."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,  # no-op on CPU
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

set_deterministic(1337)

# User config (Unimodal)
# The keyword arguments below must be wrapped in a RunConfig(...) call bound to
# `cfg`; the constants that follow read cfg.dataset_root / cfg.network / cfg.modality.
cfg = RunConfig(
    modality=MOD.img,
    network=NETWORK.vgg11,
    n_classes=2,
    hidden_dim=256,
    epochs=64,
    lr=5e-5,
    weight_decay=1e-4,
    img_size=224,
    batch_size=64,
    num_workers=2,
    pin_memory=True,
    dataset_root="/dataset",
    csv_test ="/dataset/label_test.csv",
    images_subdir="images",
    reports_subdir="reports", # unused here
    out_dir="/unimodal/models_img_vgg11",
    class_names=("normal","tb"),
    hf_offline=True,
    hf_local_dir=None,
)

# Original images (cropped) + masks (cropped) to derive GT from.
# All four directories live directly under cfg.dataset_root.
ORIG_DIR = os.path.join(cfg.dataset_root, "shen_orig_crop")
MASK_DIR = os.path.join(cfg.dataset_root, "shen_mask_crop")
YOLO_TXT_DIR = os.path.join(cfg.dataset_root, "shen_mask_crop_yolo_from_mask")  # will be created
RESIZED_1024_DIR = os.path.join(cfg.dataset_root, "images")  # 1024×1024 viz base

# Output folder name under the chosen run (derived from backbone + modality names)
SAVE_ROOT_NAME = f"gradcam_{cfg.network.name}_{cfg.modality.name}_internal_overlap"

# Grad-CAM + drawing params
CAM_METHOD  = "gradcam" # "gradcam", "gradcam++", "xgradcam", ...
HEATMAP_ALPHA = 0.5   # NOTE(review): presumably the heatmap overlay blend weight — confirm at use site
BIN_THR = 0.4         # for contour extraction

# Colors / thickness (BGR channel order)
CONTOUR_COLOR = (0,0,255) # red for contours
CONTOUR_THICK = 3
BB_MODEL_COLOR = (255,0,0) # blue for model boxes
BB_GT_COLOR = (0,0,255) # red  for GT boxes
BB_THICK = 4

TARGET_CLASS_IDX = 1 # "TB" class
RESIZE_TO = 1024 # visualization side
# Lower-case file extensions accepted as images by _is_img
IMG_EXTS = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff"}
# Set True to purge junk files; what counts as junk is decided by _is_junk_path
# (NOTE(review): the consumer of this flag is outside this view — confirm).
AUTO_CLEAN = False

# Path helpers & junk filtering
# Path components that mark a path as junk when they match exactly
JUNK_DIR_TOKENS = {"__pycache__", ".ipynb_checkpoints", ".git", ".svn", ".DS_Store"}
# File basenames (compared lower-cased) that are always junk
JUNK_FILE_BASENAMES = {"thumbs.db", "desktop.ini"}
# Case-insensitive "checkpoint" delimited by string edges or non-alphanumeric characters
_CHECKPOINT_RE = re.compile(r"(?i)(?:^|[^A-Za-z0-9])checkpoint(?:$|[^A-Za-z0-9])")
# Matches names that start with a dot
_DOTFILE_RE    = re.compile(r"^\.")

def _is_img(path_or_name: str) -> bool:
    """Return True when the (case-insensitive) file extension is listed in ``IMG_EXTS``."""
    _, extension = os.path.splitext(path_or_name)
    return extension.lower() in IMG_EXTS

def _basename_no_ext(path: str) -> str:
    return os.path.splitext(os.path.basename(path))[0]

def _is_junk_path(path: str) -> bool:
    """
    Decide whether `path` points at (or lives under) something that should be ignored.

    A path is junk when any of its components is listed in JUNK_DIR_TOKENS or starts
    with a dot, when its basename (lower-cased) is in JUNK_FILE_BASENAMES, or when its
    basename without extension contains a standalone "checkpoint" token.

    The navigation components "." and ".." are skipped: they start with a dot but are
    not hidden entries, and treating them as junk rejected every relative path that
    goes through a parent directory (e.g. "../data/x.png").
    """
    p = pathlib.Path(path)
    for part in p.parts:
        if part in (os.curdir, os.pardir):
            continue
        if part in JUNK_DIR_TOKENS or _DOTFILE_RE.search(part):
            return True
    # The dot-prefix test on the basename is already covered by the loop above
    # (the basename is the last component), so only the name-specific checks remain.
    base = p.name
    if base.lower() in JUNK_FILE_BASENAMES:
        return True
    return bool(_CHECKPOINT_RE.search(os.path.splitext(base)[0]))

def _list_images(dir_path: str) -> List[str]:
    """
    Recursively collect image files under `dir_path`, skipping junk entries.

    Only names found *inside* `dir_path` are tested with `_is_junk_path`. The previous
    version tested the full joined path, so a root that itself sits under a dot-named
    or otherwise junk-named ancestor (e.g. "~/.cache/data", "../data") produced an
    empty listing. Junk directories are pruned before descent, so testing each
    directory/file name on its own is sufficient.

    Returns:
        Sorted list of image file paths (each prefixed with the walked directory).
    """
    files: List[str] = []
    for root, dirs, fns in os.walk(dir_path):
        # In-place assignment so os.walk never descends into junk directories.
        dirs[:] = [d for d in dirs if not _is_junk_path(d)]
        files.extend(
            os.path.join(root, fn)
            for fn in fns
            if _is_img(fn) and not _is_junk_path(fn)
        )
    files.sort()
    return files

def _ensure_dir(p: str) -> None:
    os.makedirs(p, exist_ok=True)

def _safe_remove(path: str) -> bool:
    try:
        os.remove(path)
        return True
    except Exception:
        return False

def clean_tree_of_junk(root_dir: str) -> Dict[str, int]:
    """
    Delete every file under `root_dir` that `_is_junk_path` flags.

    Junk status is evaluated on the path *relative to* `root_dir`. The previous
    version evaluated the absolute/joined path, so a root located under a dot-named
    ancestor (or reached through "..") made every file look like junk and the whole
    tree was deleted. Files inside junk sub-directories of the tree (".git",
    "__pycache__", ...) are still removed, as before.

    The former second check of the checkpoint regex on image basenames was
    unreachable (`_is_junk_path` already applies the same regex to the same
    basename) and has been dropped.

    Returns:
        {"removed": <files deleted>, "skipped": <junk files that could not be deleted>}
    """
    removed = 0
    skipped = 0
    for cur_dir, _, files in os.walk(root_dir):
        rel_dir = os.path.relpath(cur_dir, root_dir)  # "." for the root itself
        for fname in files:
            if not _is_junk_path(os.path.join(rel_dir, fname)):
                continue
            if _safe_remove(os.path.join(cur_dir, fname)):
                removed += 1
            else:
                skipped += 1
    return {"removed": removed, "skipped": skipped}

# Checkpoint naming & scanner (IMG ONLY)
def _ckpt_stem(cfg_scan: "RunConfig") -> str:
    return f"{cfg_scan.network.name}_{cfg_scan.modality.name}"  # e.g., vgg11_img

def _read_val_mcc_from_log(log_json_path: str) -> Optional[float]:
    try:
        with open(log_json_path, "r") as f: meta = json.load(f)
    except Exception:
        return None
    # common keys
    for k in ("best_val_mcc", "best_val_MCC", "val_mcc", "val_MCC", "best_mcc", "best_MCC"):
        v = meta.get(k, None)
        if isinstance(v, (int, float)): return float(v)
    # nested dicts
    for nk in ("valid_metrics","validation_metrics","val_metrics","best_metrics","best_valid_metrics"):
        d = meta.get(nk, None)
        if isinstance(d, dict):
            for cand in ("MCC","mcc","val_MCC","val_mcc"):
                v = d.get(cand, None)
                if isinstance(v, (int, float)): return float(v)
    # history → max
    hist = meta.get("history", None)
    if isinstance(hist, dict):
        for hk in ("val_mcc","val_MCC"):
            arr = hist.get(hk, None)
            if isinstance(arr, list) and len(arr) > 0:
                try: return float(np.nanmax(np.asarray(arr, dtype=float)))
                except Exception: pass
    return None

def _guess_log_for_checkpoint(pt_path: str, stem: str) -> Optional[str]:
    """
    Map:
      best_{stem}_val_loss.pt                  -> best_{stem}_val_loss_log.json
      best_{stem}_val_loss_dimension_256.pt    -> best_{stem}_val_loss_log_dimension_256.json
    """
    subdir = os.path.dirname(pt_path)
    cand1  = re.sub(r"(_val_loss)(.*)\.pt$", r"\1_log\2.json", pt_path)
    cand2  = os.path.join(subdir, f"best_{stem}_val_loss_log.json")
    for c in (cand1, cand2):
        if c and os.path.isfile(c): return c
    return None

def scan_img_checkpoints_by_mcc(cfg_scan: "RunConfig") -> pd.DataFrame:
    """
    Recursively collect image-only checkpoints under `cfg_scan.out_dir`.

    Two file-name layouts are searched: `best_<stem>_val_loss.pt` and the older
    `best_<stem>_val_loss_dimension_*.pt`.

    Returns:
        One row per checkpoint with columns val_mcc (float or None), ckpt_path,
        log_json (path or None), mtime and subdir.

    Raises:
        AssertionError: if the config modality is not MOD.img.
        FileNotFoundError: if no checkpoint matches either pattern.
    """
    assert cfg_scan.modality == MOD.img, "scan_img_checkpoints_by_mcc expects MOD.img"
    stem = _ckpt_stem(cfg_scan)
    file_globs = (
        f"best_{stem}_val_loss.pt",              # current naming
        f"best_{stem}_val_loss_dimension_*.pt",  # older runs
    )
    patterns = [os.path.join(cfg_scan.out_dir, "**", fg) for fg in file_globs]

    found = set()
    for pat in patterns:
        found.update(glob.glob(pat, recursive=True))

    def _describe(pt: str) -> Dict[str, Any]:
        # Pair the checkpoint with its log (if any) and the MCC recorded there.
        log_json = _guess_log_for_checkpoint(pt, stem)
        val_mcc = _read_val_mcc_from_log(log_json) if log_json else None
        return {
            "val_mcc": None if val_mcc is None else float(val_mcc),
            "ckpt_path": pt,
            "log_json": log_json,
            "mtime": os.path.getmtime(pt) if os.path.isfile(pt) else 0.0,
            "subdir": os.path.dirname(pt),
        }

    df = pd.DataFrame([_describe(pt) for pt in sorted(found)])
    if df.empty:
        raise FileNotFoundError(
            f"No checkpoints found under '{cfg_scan.out_dir}'. Tried: {patterns}"
        )
    return df

def select_best_by_mcc(df: pd.DataFrame) -> pd.Series:
    """
    Pick the winning checkpoint row.

    Rows with a numeric MCC >= 0 compete on MCC (descending), ties broken by the
    newest mtime. When no row qualifies, the newest row by mtime is returned
    regardless of its MCC. The input frame is not modified.
    """
    scored = df.copy()
    scored["val_mcc"] = pd.to_numeric(scored["val_mcc"], errors="coerce")
    usable = scored["val_mcc"].notna() & (scored["val_mcc"] >= 0.0)

    if not usable.any():
        # Nothing to rank by metric: fall back to recency.
        return scored.sort_values(by=["mtime"], ascending=[False]).iloc[0]

    ranked = scored[usable].copy().sort_values(
        by=["val_mcc", "mtime"], ascending=[False, False]
    )
    return ranked.iloc[0]

# -------------------- Model utilities (IMG ONLY) -------------------- #
def print_model_overview_for_cam(model: nn.Module, img_size: int = 224) -> None:
    """Log the device, the nominal input size and total/trainable/frozen parameter counts."""
    total = 0
    train = 0
    for param in model.parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            train += count
    frozen = total - train

    report = (
        "Grad-CAM — Model Overview",
        f"Device: {device.type}",
        f"Input size (C,H,W): (3, {img_size}, {img_size})",
        f"Total params: {total:,}",
        f"Trainable params: {train:,}",
        f"Frozen params: {frozen:,}",
    )
    for line in report:
        log.info(line)

def _kernel_tuple(m: nn.Conv2d) -> Tuple[int,int]:
    k = m.kernel_size
    return (k if isinstance(k, tuple) else (k, k))

def _last_conv_kgt1(module: nn.Module) -> Optional[nn.Conv2d]:
    """
    Return the last Conv2d (in `modules()` order) whose kernel is larger than 1 in
    at least one dimension; if there is none, the last Conv2d of any size; if the
    module has no Conv2d at all, None.
    """
    convs = [m for m in module.modules() if isinstance(m, nn.Conv2d)]
    if not convs:
        return None
    spatial = [c for c in convs if max(_kernel_tuple(c)) > 1]
    return spatial[-1] if spatial else convs[-1]

def _vit_patch_embed_conv(module: nn.Module) -> Optional[nn.Conv2d]:
    target = None
    for name, m in module.named_modules():
        if isinstance(m, nn.Conv2d) and ("patch_embed" in name or "patch" in name):
            target = m
    return target

def _qualname_of_module(root: nn.Module, target: nn.Module) -> str:
    for n, m in root.named_modules():
        if m is target:
            return n or "<root>"
    return "<unknown>"

def select_target_layer_for_cam(model: nn.Module) -> nn.Conv2d:
    """
    Choose the Conv2d layer that Grad-CAM should hook.

    Preference order:
      1. the first 3×3 Conv2d found in `model.img_enc.post3x3` (Sequential or single conv);
      2. the last spatial conv of `model.img_enc` (see `_last_conv_kgt1`);
      3. a conv whose name contains "patch" inside `model.img_enc`;
      4. the last spatial conv anywhere in `model`.

    Raises:
        RuntimeError: if the model contains no Conv2d at all.
    """
    def _first_3x3(container: nn.Module) -> Optional[nn.Conv2d]:
        # For a bare Conv2d, modules() yields only the layer itself.
        for layer in container.modules():
            if isinstance(layer, nn.Conv2d) and _kernel_tuple(layer) == (3, 3):
                return layer
        return None

    enc = getattr(model, "img_enc", None)
    if enc is not None:
        p3 = getattr(enc, "post3x3", None)
        if isinstance(p3, (nn.Sequential, nn.Conv2d)):
            preferred = _first_3x3(p3)
            if preferred is not None:
                return preferred
        for finder in (_last_conv_kgt1, _vit_patch_embed_conv):
            candidate = finder(enc)
            if isinstance(candidate, nn.Conv2d):
                return candidate

    candidate = _last_conv_kgt1(model)
    if isinstance(candidate, nn.Conv2d):
        return candidate
    raise RuntimeError("No suitable Conv2d layer found for Grad-CAM.")

class CamImageWrapper(nn.Module):
    """
    Adapter that exposes a model's `forward_image` as a plain `forward(x)`, which is
    the call shape the CAM classes are given here.
    """

    def __init__(self, model_img_only: nn.Module):
        """Keep a reference to the wrapped model as `self.mm`."""
        super().__init__()
        self.mm = model_img_only

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Delegate to the wrapped model's `forward_image`."""
        out = self.mm.forward_image(x)
        return out

def build_preprocess(size: int) -> T.Compose:
    """
    Build the image → tensor pipeline used for CAM inputs: convert to PIL, resize to
    `size`×`size`, convert to a tensor, then normalize per channel.
    """
    # These values match the commonly used ImageNet channel statistics.
    norm_mean = [0.485,0.456,0.406]
    norm_std = [0.229,0.224,0.225]
    steps = [
        T.ToPILImage(),
        T.Resize((size, size)),
        T.ToTensor(),
        T.Normalize(mean=norm_mean, std=norm_std),
    ]
    return T.Compose(steps)

@torch.no_grad()
def _load_ckpt_model(cfg_load: "RunConfig", ckpt_path: str) -> "ImgOrTxtClassifier":
    """
    Instantiate the classifier for `cfg_load`, move it to `device` in eval mode and
    load the state dict stored at `ckpt_path` (strict key matching).

    NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
    """
    model = ImgOrTxtClassifier(cfg_load)
    model = model.to(device).eval()
    weights = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(weights, strict=True)
    return model

def _init_cam(CAMClass, model, target_layers):
    sig = inspect.signature(CAMClass.__init__)
    if "use_cuda" in sig.parameters:
        return CAMClass(model=model, target_layers=target_layers, use_cuda=(device.type == "cuda"))
    return CAMClass(model=model, target_layers=target_layers)

# CSV Reader
def _is_hdr(s: str) -> bool:
    s = (str(s) or "").strip().lower()
    return s in {"img","image","filename","file","path","label","class","target","y"} or s.startswith("img")

def _read_label_csv(csv_path: str, tag: str) -> pd.DataFrame:
    """
    Read a header-less (filename, label) CSV into a clean two-column frame.

    Steps: sniff the delimiter (falling back to pandas' default parser), keep the
    first two columns as "img"/"label", drop a first row that looks like a header,
    reduce "img" to its basename, keep only rows whose name has an image extension,
    and coerce "label" to int.

    Labels that are missing or non-numeric still default to 0, but this is now
    reported with a warning instead of happening silently.

    Args:
        csv_path: path of the CSV file.
        tag: short name used in log messages (e.g. "test").

    Returns:
        DataFrame with columns "img" (str) and "label" (int); empty when the file
        does not exist.

    Raises:
        ValueError: if the CSV has fewer than two columns.
    """
    if not os.path.isfile(csv_path):
        log.warning(f"[CSV] Missing {tag} CSV at: {csv_path}")
        return pd.DataFrame(columns=["img", "label"])
    try:
        # sep=None lets the python engine sniff the delimiter.
        df = pd.read_csv(csv_path, header=None, sep=None, engine="python", encoding="utf-8-sig")
    except Exception:
        df = pd.read_csv(csv_path, header=None, encoding="utf-8-sig")
    if df.shape[1] < 2:
        raise ValueError(f"{tag} CSV must have >= 2 columns (filename, label): {csv_path}")
    df = df.iloc[:, :2].copy(); df.columns = ["img", "label"]
    if len(df) and _is_hdr(df.iloc[0,0]) and _is_hdr(df.iloc[0,1]):
        df = df.iloc[1:].reset_index(drop=True)
    df["img"] = df["img"].astype(str).map(lambda s: os.path.basename(s.strip()))
    df = df[df["img"].map(_is_img)].reset_index(drop=True)

    labels = pd.to_numeric(df["label"], errors="coerce")
    n_bad = int(labels.isna().sum())
    if n_bad:
        log.warning(
            f"[CSV] {tag}: {n_bad} rows have a missing/non-numeric label; "
            f"defaulting them to 0 → {csv_path}"
        )
    df["label"] = labels.fillna(0).astype(int)
    log.info(f"[CSV] {tag}: {len(df)} rows after clean → {csv_path}")
    return df

def _read_splits(test_csv: str) -> pd.DataFrame:
    """
    Load the test CSV and keep the first row for each image basename.

    Logs how many rows were dropped when duplicates were present.
    """
    rows = _read_label_csv(test_csv,  "test")
    unique_rows = rows.drop_duplicates(subset=["img"], keep="first").reset_index(drop=True)
    n_before, n_after = len(rows), len(unique_rows)
    if n_after < n_before:
        log.info(f"[CSV] Dedup: {n_before} → {n_after} unique basenames")
    return unique_rows
    
# Build YOLO txts from grayscale masks
def _compute_yolo_from_mask(mask_gray: np.ndarray, W: int, H: int,
                            min_area_px: int = 4) -> List[Tuple[float,float,float,float]]:
    """
    Turn a grayscale mask into YOLO-style boxes (cx, cy, w, h), normalized by W and H.

    The mask is binarized (Otsu when its maximum exceeds 1, otherwise "> 0"),
    cleaned with a 3×3 elliptical open followed by a close, and one bounding box is
    taken per external contour. Boxes whose pixel area is below
    max(min_area_px, 0.00005 * W * H) are discarded; all values are clamped to [0, 1].
    """
    if mask_gray.max() > 1:
        _, binary = cv2.threshold(mask_gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    else:
        binary = (mask_gray > 0).astype(np.uint8) * 255

    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3,3))
    for morph_op in (cv2.MORPH_OPEN, cv2.MORPH_CLOSE):
        binary = cv2.morphologyEx(binary, morph_op, kernel, iterations=1)

    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    area_floor = max(min_area_px, int(0.00005 * W * H))

    def _clamp01(value: float) -> float:
        return min(max(value, 0.0), 1.0)

    boxes: List[Tuple[float,float,float,float]] = []
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        if w*h < area_floor:
            continue
        cx = _clamp01((x + w/2.0) / float(W))
        cy = _clamp01((y + h/2.0) / float(H))
        ww = _clamp01(w / float(W))
        hh = _clamp01(h / float(H))
        if ww > 0 and hh > 0:
            boxes.append((cx, cy, ww, hh))
    return boxes

def build_yolo_txts_from_masks(orig_dir: str, mask_dir: str, out_dir: str, strict_match: bool = True) -> None:
    """
    Write one YOLO label file per original image, with boxes derived from its mask.

    Originals and masks are paired by basename without extension. Each output file
    is named "<original file name>.txt" (extension kept) inside `out_dir` and holds
    one line "0 cx cy w h" per box; images without boxes get an empty file.

    Changes relative to the earlier version:
      * originals keep their full discovered path, so images found in
        sub-directories by the recursive listing are readable (the path used to be
        rebuilt as `orig_dir/<basename>`);
      * duplicate basenames among originals are reported and resolved the same way
        as for masks (first one wins) instead of silently keeping the last;
      * boxes are normalized by the mask's own size, with a warning when it differs
        from the image size, so a mask stored at another resolution still yields
        correct normalized coordinates.

    Args:
        orig_dir: directory with the original images.
        mask_dir: directory with the masks.
        out_dir: directory for the .txt files (created if missing).
        strict_match: when True, any unmatched original or mask raises ValueError;
            when False, only the matched pairs are processed.

    Raises:
        ValueError: in strict mode, when the two basename sets differ.
    """
    _ensure_dir(out_dir)
    orig_files = _list_images(orig_dir)
    mask_files = _list_images(mask_dir)
    log.info(f"[YOLO-from-mask] Candidates — orig={len(orig_files)} | mask={len(mask_files)} (after junk filtering)")

    def _to_base_map(paths: List[str], tag: str) -> Dict[str, str]:
        # basename-without-extension -> first path seen; later duplicates are reported.
        m: Dict[str, str] = {}
        dups: Dict[str, List[str]] = {}
        for p in paths:
            b = _basename_no_ext(p)
            if b in m:
                dups.setdefault(b, []).append(p)
                continue
            m[b] = p
        if dups:
            ex = [(k, [os.path.basename(x) for x in v[:2]]) for k, v in list(dups.items())[:3]]
            log.warning(f"[{tag}] {len(dups)} duplicate basenames detected (keeping first). Examples: {ex}")
        return m

    orig_by_base: Dict[str, str] = _to_base_map(orig_files, tag="ORIG")
    mask_by_base: Dict[str, str] = _to_base_map(mask_files, tag="MASK")

    orig_bases = set(orig_by_base.keys())
    mask_bases = set(mask_by_base.keys())
    missing_in_mask = sorted(orig_bases - mask_bases)
    missing_in_orig = sorted(mask_bases - orig_bases)

    if missing_in_mask or missing_in_orig:
        msg_parts = []
        if missing_in_mask:
            msg_parts.append(f"Missing masks for {len(missing_in_mask)} images (e.g., {missing_in_mask[:5]})")
        if missing_in_orig:
            msg_parts.append(f"Extra masks without originals: {len(missing_in_orig)} (e.g., {missing_in_orig[:5]})")
        msg = " | ".join(msg_parts)
        if strict_match:
            raise ValueError(f"[YOLO-from-mask] Filename sets must match after junk-filtering. {msg}")
        else:
            log.warning(f"[YOLO-from-mask] {msg} — proceeding with intersection.")

    # When the sets are equal the intersection is the full set, so one expression covers both cases.
    bases = sorted(orig_bases & mask_bases)

    n_written, n_empty = 0, 0
    for b in bases:
        img_path = orig_by_base[b]
        img_name = os.path.basename(img_path)  # keep extension; write "<name>.<ext>.txt"
        mask_path = mask_by_base[b]

        img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
        if img is None:
            log.warning(f"[YOLO-from-mask] Skipping (orig not readable): {img_path}")
            continue
        H, W = img.shape[:2]

        m = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if m is None:
            log.warning(f"[YOLO-from-mask] Skipping (mask not readable): {mask_path}")
            continue

        # Contours are measured in mask pixels, so normalize by the mask's own size.
        mask_H, mask_W = m.shape[:2]
        if (mask_W, mask_H) != (W, H):
            log.warning(
                f"[YOLO-from-mask] Mask size {mask_W}x{mask_H} differs from image size "
                f"{W}x{H} for {img_name}; normalizing boxes by the mask size."
            )

        boxes = _compute_yolo_from_mask(m, mask_W, mask_H)
        out_txt = os.path.join(out_dir, img_name + ".txt")
        with open(out_txt, "w", encoding="utf-8") as f:
            if not boxes:
                n_empty += 1
            for (cx, cy, ww, hh) in boxes:
                f.write(f"0 {cx:.6f} {cy:.6f} {ww:.6f} {hh:.6f}\n")
        n_written += 1

    log.info(f"[YOLO-from-mask] Wrote {n_written} .txt files to: {out_dir}  (empty={n_empty})")

# Phase B: Load-scaled GT boxes (to 1024)
def _load_original_size(name: str) -> Tuple[np.ndarray, int, int]:
    """
    Read `name` from ORIG_DIR with cv2.imread (BGR by default).

    Returns:
        (image, width, height)

    Raises:
        FileNotFoundError: if cv2 cannot read the file.
    """
    path = os.path.join(ORIG_DIR, name)
    image = cv2.imread(path)
    if image is None:
        raise FileNotFoundError(f"Original image not found: {path}")
    height, width = image.shape[:2]
    return image, width, height

def _load_resized_1024(name: str) -> Tuple[Optional[np.ndarray], bool]:
    """
    Try to read `name` from RESIZED_1024_DIR.

    Returns:
        (image, True) on success, (None, False) when cv2 cannot read the file.
    """
    image = cv2.imread(os.path.join(RESIZED_1024_DIR, name))
    return image, image is not None

def _ensure_1024_base_image(name: str) -> Tuple[np.ndarray, int, int, int, int]:
    """
    Produce the RESIZE_TO×RESIZE_TO visualization image for `name`.

    The pre-resized copy is used when it exists (and resized again if its size is
    off); otherwise the original is resized. The original is always loaded because
    its width/height are part of the return value.

    Returns:
        (base_bgr, RESIZE_TO, RESIZE_TO, orig_W, orig_H)

    Raises:
        FileNotFoundError: propagated from `_load_original_size` when the original
            is unreadable, even if a pre-resized copy exists.
    """
    cached, has_cached = _load_resized_1024(name)
    orig_bgr, orig_W, orig_H = _load_original_size(name)

    target = (RESIZE_TO, RESIZE_TO)
    source = cached if has_cached else orig_bgr
    if has_cached and source.shape[:2] == target:
        base_bgr = source
    else:
        base_bgr = cv2.resize(source, target, interpolation=cv2.INTER_AREA)
    return base_bgr, RESIZE_TO, RESIZE_TO, orig_W, orig_H

def _read_yolo_gt_scaled_to_1024(name: str, orig_W: int, orig_H: int) -> List[Tuple[int,int,int,int]]:
    """
    Load the ground-truth boxes of `name` as pixel corners on the RESIZE_TO canvas.

    The label file is YOLO_TXT_DIR/"<name>.txt" with lines "class cx cy w h"
    (whitespace- or comma-separated, normalized to the original W×H). Lines that are
    blank, too short or non-numeric are skipped. Corners are rounded, clamped to
    [0, RESIZE_TO-1], and degenerate boxes are dropped.

    Returns:
        List of (x1, y1, x2, y2); empty when the label file does not exist.
    """
    txt_path = os.path.join(YOLO_TXT_DIR, name + ".txt")
    if not os.path.isfile(txt_path):
        return []

    def _to_px(value: float) -> int:
        # Round, then clamp to the valid pixel range of the canvas.
        return int(max(0, min(RESIZE_TO-1, round(value))))

    boxes = []
    with open(txt_path, "r", encoding="utf-8", errors="ignore") as f:
        for raw in f:
            line = raw.strip()
            if not line:
                continue
            toks = re.split(r"[,\s]+", line)
            if len(toks) < 5:
                continue
            try:
                int(float(toks[0]))  # class id: validated but unused
                cx, cy, w, h = (float(t) for t in toks[1:5])
            except Exception:
                continue
            # Normalized → original pixels → canvas pixels (kept as two steps on purpose).
            sx = RESIZE_TO / float(orig_W)
            sy = RESIZE_TO / float(orig_H)
            X1 = _to_px((cx - w/2.0) * orig_W * sx)
            Y1 = _to_px((cy - h/2.0) * orig_H * sy)
            X2 = _to_px((cx + w/2.0) * orig_W * sx)
            Y2 = _to_px((cy + h/2.0) * orig_H * sy)
            if X2 > X1 and Y2 > Y1:
                boxes.append((X1, Y1, X2, Y2))
    return boxes

# Main Grad-CAM (image-only; GT overlay)
def run_gradcam_img_with_gt(
    cfg_cam: "RunConfig",
    ckpt_path: str,
    dfe: pd.DataFrame,
    save_root_overlap: str,
    cam_method: str = CAM_METHOD,
    heatmap_alpha: float = HEATMAP_ALPHA,
    bin_thr: float = BIN_THR,
) -> None:
    """
    Run a CAM method over the images listed in `dfe` and save overlays with GT boxes.

    For every image, four PNGs are written under `save_root_overlap`:
      images/    the RESIZE_TO×RESIZE_TO base image;
      heatmaps/  heatmap blended on the base + GT boxes;
      contours/  contours of the thresholded CAM + GT boxes;
      bboxes/    bounding boxes of those contours (model) + GT boxes.

    Args:
        cfg_cam: run configuration used to build the model and the preprocessing.
        ckpt_path: checkpoint file to load.
        dfe: frame with an "img" column of image file names.
        save_root_overlap: output root (created if missing).
        cam_method: one of gradcam, gradcam++, hirescam, xgradcam, layercam,
            eigencam, eigengradcam, scorecam, ablationcam (case-insensitive).
        heatmap_alpha: blend weight of the heatmap (the base image gets 1 - alpha).
        bin_thr: CAM threshold in [0, 1] used for contours and model boxes.

    Raises:
        AssertionError: if `ckpt_path` is not a file.
        KeyError: if `cam_method` is not a supported method name (raised before the
            model is loaded).

    Notes:
        Images that cannot be loaded are skipped with a warning. CAM hooks and the
        CUDA cache are released even when processing stops with an exception.
    """
    assert os.path.isfile(ckpt_path), f"Checkpoint not found: {ckpt_path}"

    # Resolve the CAM class first so a typo fails before the expensive model load.
    methods = {
        "gradcam": GradCAM, "gradcam++": GradCAMPlusPlus, "hirescam": HiResCAM,
        "xgradcam": XGradCAM, "layercam": LayerCAM, "eigencam": EigenCAM,
        "eigengradcam": EigenGradCAM, "scorecam": ScoreCAM, "ablationcam": AblationCAM
    }
    method_key = cam_method.lower()
    if method_key not in methods:
        raise KeyError(f"Unknown cam_method '{cam_method}'. Choose one of: {sorted(methods)}")
    CAMClass = methods[method_key]

    os.makedirs(save_root_overlap, exist_ok=True)
    out_dirs = {
        "images":   os.path.join(save_root_overlap, "images"),
        "heatmaps": os.path.join(save_root_overlap, "heatmaps"),
        "contours": os.path.join(save_root_overlap, "contours"),
        "bboxes":   os.path.join(save_root_overlap, "bboxes"),
    }
    for d in out_dirs.values():
        os.makedirs(d, exist_ok=True)

    def _draw_boxes(canvas, boxes, color):
        # Draw (x1, y1, x2, y2) rectangles on `canvas` in place and return it.
        for (bx1, by1, bx2, by2) in boxes:
            cv2.rectangle(canvas, (bx1, by1), (bx2, by2), color, BB_THICK, lineType=cv2.LINE_AA)
        return canvas

    def _save(kind, filename, image):
        # cv2.imwrite reports failure through its return value; surface it in the log.
        path = os.path.join(out_dirs[kind], filename)
        if not cv2.imwrite(path, image):
            log.warning(f"[CAM] Failed to write: {path}")

    # Load model & CAM
    img_size = int(getattr(cfg_cam, "img_size", 224))
    model_img = _load_ckpt_model(cfg_cam, ckpt_path)
    print_model_overview_for_cam(model_img, img_size=img_size)
    cam_model = CamImageWrapper(model_img).to(device).eval()
    target_layer = select_target_layer_for_cam(model_img)
    # Gradients must flow through the hooked layer even if it was frozen.
    for p in target_layer.parameters():
        p.requires_grad_(True)
    tl_name = _qualname_of_module(model_img, target_layer)
    log.info(f"[CAM] Target layer: {tl_name}  kernel={getattr(target_layer, 'kernel_size', None)}")

    cam = _init_cam(CAMClass, model=cam_model, target_layers=[target_layer])
    # Best effort: not every CAM class is guaranteed to accept this attribute.
    try: cam.batch_size = 1
    except Exception: pass

    preprocess = build_preprocess(img_size)

    try:
        # Iterate over the de-duplicated rows of the test CSV
        for _, row in dfe.iterrows():
            name = os.path.basename(str(row["img"]).strip())
            if not _is_img(name): continue

            # Build 1024 base + get original size for GT scaling
            try:
                base_bgr, W_vis, H_vis, orig_W, orig_H = _ensure_1024_base_image(name)
            except FileNotFoundError:
                log.warning(f"[CAM] Skipping (missing original and/or 1024 file): {name}")
                continue

            # Model input from 1024 base
            rgb = cv2.cvtColor(base_bgr, cv2.COLOR_BGR2RGB)
            x = preprocess(rgb).unsqueeze(0).to(device)

            # Forward Grad-CAM
            cam_model.zero_grad(set_to_none=True)
            with torch.enable_grad():
                if method_key == "eigencam":
                    cam_mask = cam(input_tensor=x)[0]
                else:
                    cam_mask = cam(
                        input_tensor=x,
                        targets=[ClassifierOutputTarget(int(TARGET_CLASS_IDX))],
                        aug_smooth=True, eigen_smooth=True
                    )[0]

            # Normalize & upsample to 1024
            mmin, mmax = float(np.min(cam_mask)), float(np.max(cam_mask))
            cam_mask = (cam_mask - mmin) / (mmax - mmin + 1e-8)
            mask_u8 = (np.clip(cv2.resize(cam_mask, (W_vis, H_vis), interpolation=cv2.INTER_NEAREST), 0, 1) * 255).astype(np.uint8)

            # Heatmap overlay
            heat = cv2.applyColorMap(mask_u8, cv2.COLORMAP_HOT)
            heat_overlay = cv2.addWeighted(heat, float(heatmap_alpha), base_bgr, 1.0 - float(heatmap_alpha), 0.0)

            # Contours (model areas) and their bounding boxes
            thr = int(255 * float(bin_thr))
            _, binm = cv2.threshold(mask_u8, thr, 255, cv2.THRESH_BINARY)
            contours, _ = cv2.findContours(binm, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            model_boxes: List[Tuple[int,int,int,int]] = []
            for c in contours:
                x0, y0, w, h = cv2.boundingRect(c)
                model_boxes.append((x0, y0, x0 + w, y0 + h))

            # GT boxes scaled to 1024 (red)
            gt_boxes_1024 = _read_yolo_gt_scaled_to_1024(name, orig_W, orig_H)

            stem = os.path.splitext(name)[0]
            overlay_name = f"{stem}__{cfg_cam.network.name}__{cam_method}.png"

            # Save base image
            _save("images", f"{stem}.png", base_bgr)

            # Heatmap + GT
            _save("heatmaps", overlay_name, _draw_boxes(heat_overlay, gt_boxes_1024, BB_GT_COLOR))

            # Contours (red) + GT
            cont_ov = base_bgr.copy()
            if len(contours) > 0:
                cv2.drawContours(cont_ov, contours, -1, CONTOUR_COLOR, CONTOUR_THICK, lineType=cv2.LINE_AA)
            _save("contours", overlay_name, _draw_boxes(cont_ov, gt_boxes_1024, BB_GT_COLOR))

            # BBoxes view: model (blue) + GT (red)
            box_ov = _draw_boxes(base_bgr.copy(), model_boxes, BB_MODEL_COLOR)
            _save("bboxes", overlay_name, _draw_boxes(box_ov, gt_boxes_1024, BB_GT_COLOR))
    finally:
        # Cleanup: remove the CAM hooks and free GPU memory, also on failure.
        try:
            if hasattr(cam, "activations_and_grads") and (cam.activations_and_grads is not None):
                cam.activations_and_grads.release()
        except Exception:
            pass
        del cam, cam_model, model_img
        if device.type == "cuda":
            torch.cuda.empty_cache()
    log.info(f"[CAM] Saved to: {save_root_overlap}")

# Scan → Select best (by val-MCC) → Build GT → Run CAM
# 1) Scan and list image-only candidates
#    scan_img_checkpoints_by_mcc raises FileNotFoundError when nothing matches, so the
#    cell stops here if cfg.out_dir holds no checkpoint for this network/modality.
cfg_scan = cfg  # already MOD.img
df_found = scan_img_checkpoints_by_mcc(cfg_scan)

def _fmt(x):
    if x is None or (isinstance(x, float) and math.isnan(x)): return "—"
    return f"{x:.4f}" if isinstance(x, float) else str(x)

# Only show non-negative, valid MCCs (the same filter select_best_by_mcc ranks on)
df_show = df_found.copy()
df_show["val_mcc"] = pd.to_numeric(df_show["val_mcc"], errors="coerce")
df_show = df_show[df_show["val_mcc"].notna() & (df_show["val_mcc"] >= 0.0)]

print("\n================ Available image-only checkpoints (sorted by folder) ================")
if len(df_show) == 0:
    print("• (no checkpoints with a valid non-negative MCC found; selection will fallback to newest mtime)")
else:
    for _, r in df_show.sort_values("subdir").iterrows():
        print(f"• {r['subdir']}")
        print(f"    val_MCC={_fmt(float(r['val_mcc']))}")
        print(f"    ckpt={r['ckpt_path']}")
        if r['log_json']: print(f"    log ={r['log_json']}")
print("======================================================================================\n")

row_best = select_best_by_mcc(df_found)
best_ckpt = row_best["ckpt_path"]
best_subdir = row_best["subdir"]
best_mcc = row_best["val_mcc"]

# select_best_by_mcc falls back to the newest file whenever no run has a valid,
# non-negative MCC. In that case the chosen row may still carry a negative MCC, and
# reporting it as "highest validation MCC" would misstate why it was picked.
if pd.isna(best_mcc):
    reason = "no MCC recorded; selecting most recent checkpoint"
elif float(best_mcc) < 0.0:
    reason = (
        f"no non-negative validation MCC available (this run: {_fmt(float(best_mcc))}); "
        "selecting most recent checkpoint"
    )
else:
    reason = f"highest validation MCC={_fmt(float(best_mcc))}"

print(">>> Selecting checkpoint:")
print(f" {best_ckpt}")
print(f" Because: {reason}\n")

# 2) Build YOLO GT txts from masks
#    There is no cache check: the .txt files are rewritten every time this cell runs.
log.info("================ Precompute YOLO from Masks ================")
log.info(f"Original dir : {ORIG_DIR}")
log.info(f"Mask dir     : {MASK_DIR}")
log.info(f"YOLO out dir : {YOLO_TXT_DIR}")
if AUTO_CLEAN:
    # Destructive: deletes the files flagged as junk inside both dataset folders.
    stats1 = clean_tree_of_junk(ORIG_DIR)
    stats2 = clean_tree_of_junk(MASK_DIR)
    log.info(f"[Clean ORIG] removed={stats1['removed']} skipped={stats1['skipped']}")
    log.info(f"[Clean MASK] removed={stats2['removed']} skipped={stats2['skipped']}")

# strict_match=True: raises ValueError unless every original has a mask and vice versa.
build_yolo_txts_from_masks(ORIG_DIR, MASK_DIR, YOLO_TXT_DIR, strict_match=True)

# 3) Read the test CSV for CAM (rows de-duplicated by image basename)
dfe = _read_splits(cfg.csv_test)
log.info(f"[CSV] test: {len(dfe)} unique images")

# 4) Run Grad-CAM with GT overlay; outputs beside the chosen run
save_root_overlap = os.path.join(best_subdir, SAVE_ROOT_NAME)
run_gradcam_img_with_gt(
    cfg_cam=cfg,
    ckpt_path=best_ckpt,
    dfe=dfe,
    save_root_overlap=save_root_overlap,
    cam_method=CAM_METHOD,
    heatmap_alpha=HEATMAP_ALPHA,
    bin_thr=BIN_THR,
)

### Computing GRAD-CAM visualizations with the external TBX11K test set

1. Annotations (512×512) → pre-crop (256×256): scale the authors’ ground-truth boxes by 0.5.
2. Pre-crop (256×256) → lung-crop: Use the provided 256×256 mask to recover the actual crop ROI (the tight bounding rectangle of non-zero mask, optionally with a small margin). Intersect the pre-crop box with this ROI and translate it to the crop coordinate frame.
3. Lung-crop → final saved TB image (256×256): the crop ROI was resized to 256×256 when the lung-cropped images were saved, so scale the translated box by `sx = 256 / roi_w` and `sy = 256 / roi_h`.
4. Use floor for x1,y1 and ceil for x2,y2 to preserve coverage; clip to [0,255].
5. Run Grad-CAM on those TB images and overlay the blue model CAM boxes and the red GT boxes.

In [None]:
# Compute device shared by the helpers below: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Logging
# NOTE: logging.basicConfig does nothing when the root logger already has handlers
# (e.g. configured by an earlier cell), so the format below may not take effect on re-runs.
logging.basicConfig(
    level=logging.INFO,
    format="[%(asctime)s] %(levelname)s: %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
# Module-wide logger used by every helper in this cell.
log = logging.getLogger("gradcam-tbx11k-img")

# Reproducibility
def set_deterministic(seed: int = 1337) -> None:
    """
    Seed Python's `random`, NumPy and PyTorch (CPU and all CUDA devices) with `seed`
    and switch cuDNN to deterministic, non-benchmarking mode.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,  # no-op on CPU
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

# Seed every RNG before anything else in this cell runs.
set_deterministic(1337)

# User config
# Mirrors the configuration of the internal Grad-CAM cell except for the test CSV and
# the images sub-directory, which point at TBX11K. The checkpoint root (out_dir) is the
# same, so the same trained model is evaluated on the external set.
# NOTE(review): this rebinds the global `cfg` used by the previous cell.
cfg = RunConfig(
    modality=MOD.img,
    network=NETWORK.vgg11,
    n_classes=2,
    hidden_dim=256,
    epochs=64,
    lr=5e-5,
    weight_decay=1e-4,
    img_size=224,
    batch_size=64,
    num_workers=2,
    pin_memory=True,
    dataset_root="/dataset",
    csv_test="/dataset/label_test_tbx11k.csv",  # TBX11K test CSV
    images_subdir="tbx11k/all",
    out_dir="/unimodal/models_img_vgg11",  # image-only runs root
    class_names=("normal","tb"),
    hf_offline=True,
    hf_local_dir=None,
)

# NOTE(review): several names below (SAVE_ROOT_NAME, CAM_METHOD, BIN_THR, the color and
# thickness constants, IMG_EXTS, the junk tokens/regexes) rebind globals defined in the
# previous cell. Functions from that cell which read them at call time (for example
# BB_THICK inside run_gradcam_img_with_gt) will see these values once this cell has run.

# TBX11K dataset specific paths
# Final lung-cropped TB images (256×256)
TB_256_DIR = "/dataset/tbx11k/tb"
# 256×256 masks aligned to the pre-crop 256×256 space
TB_MASKS_256_DIR = "/dataset/tbx11k/masks/tb"
# Authors' lesion boxes (512×512 coords)
BBOX_CSV_PATH = "/dataset/data_tbx11k.csv"
# Output folder name under the chosen run
SAVE_ROOT_NAME   = f"gradcam_tbx11k_{cfg.network.name}_{cfg.modality.name}_test_overlap"
# Grad-CAM & drawing params
CAM_METHOD  = "gradcam"   # "gradcam", "gradcam++", "xgradcam", "layercam", ...
HEATMAP_ALPHA = 0.5
BIN_THR = 0.4         # CAM threshold for contours/bboxes

# Colors / thickness (BGR)
CONTOUR_COLOR = (0, 0, 255) # red contours
CONTOUR_THICK = 2
BB_MODEL_COLOR = (255, 0, 0) # blue model boxes
BB_GT_COLOR = (0, 0, 255) # red GT boxes
BB_THICK = 3
TARGET_CLASS_IDX = 1  # index 1 of cfg.class_names, i.e. "tb"
FINAL_SIZE = 256      # presumably the side of the saved lung-cropped images — TODO confirm against its users
IMG_EXTS = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff"}
# Optional crop behavior
ROI_MARGIN_FRAC = 0.00        # e.g., 0.02 to add 2% per side
ROI_ENFORCE_SQUARE = False    # set True if padded to square before resizing

# Utility: paths, junk filtering
# Names compared against every path component (the final one included).
JUNK_DIR_TOKENS = {"__pycache__", ".ipynb_checkpoints", ".git", ".svn", ".DS_Store"}
JUNK_FILE_BASENAMES = {"thumbs.db", "desktop.ini"}  # compared against the lower-cased basename
# Case-insensitive "checkpoint" bounded by non-alphanumerics (or string start/end).
_CHECKPOINT_RE = re.compile(r"(?i)(?:^|[^A-Za-z0-9])checkpoint(?:$|[^A-Za-z0-9])")
# Any name starting with a dot.
_DOTFILE_RE = re.compile(r"^\.")

def _is_img(path_or_name: str) -> bool:
    """Return True when the extension of `path_or_name` (case-insensitive) is in IMG_EXTS."""
    _, ext = os.path.splitext(path_or_name)
    return ext.lower() in IMG_EXTS

def _is_junk_path(path: str) -> bool:
    """
    Decide whether `path` points at (or lives under) something that should be ignored.

    A path is junk when any of its components is listed in JUNK_DIR_TOKENS or starts
    with a dot, when its basename (lower-cased) is in JUNK_FILE_BASENAMES, or when its
    basename without extension contains a standalone "checkpoint" token.

    The navigation components "." and ".." are skipped: they start with a dot but are
    not hidden entries, and treating them as junk rejected every relative path that
    goes through a parent directory (e.g. "../data/x.png").
    """
    p = pathlib.Path(path)
    for part in p.parts:
        if part in (os.curdir, os.pardir):
            continue
        if part in JUNK_DIR_TOKENS or _DOTFILE_RE.search(part):
            return True
    # The dot-prefix test on the basename is already covered by the loop above
    # (the basename is the last component), so only the name-specific checks remain.
    base = p.name
    if base.lower() in JUNK_FILE_BASENAMES:
        return True
    return bool(_CHECKPOINT_RE.search(os.path.splitext(base)[0]))

def _list_images(dir_path: str) -> List[str]:
    """
    Recursively collect image files under `dir_path`, skipping junk entries.

    Only names found *inside* `dir_path` are tested with `_is_junk_path`. The previous
    version tested the full joined path, so a root that itself sits under a dot-named
    or otherwise junk-named ancestor (e.g. "~/.cache/data", "../data") produced an
    empty listing. Junk directories are pruned before descent, so testing each
    directory/file name on its own is sufficient.

    Returns:
        Sorted list of image file paths (each prefixed with the walked directory).
    """
    out: List[str] = []
    for root, dirs, files in os.walk(dir_path):
        # In-place assignment so os.walk never descends into junk directories.
        dirs[:] = [d for d in dirs if not _is_junk_path(d)]
        out.extend(
            os.path.join(root, fn)
            for fn in files
            if _is_img(fn) and not _is_junk_path(fn)
        )
    return sorted(out)

def _find_by_basename(dir_path: str, basename: str) -> Optional[str]:
    base, _ = os.path.splitext(basename)
    for ext in [".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff"]:
        p = os.path.join(dir_path, base + ext)
        if os.path.isfile(p):
            return p
    return None

# Checkpoint scanning 
def _ckpt_stem(cfg_scan: "RunConfig") -> str:
    return f"{cfg_scan.network.name}_{cfg_scan.modality.name}"  # e.g., vgg11_img

def _read_val_mcc_from_log(log_json_path: str) -> Optional[float]:
    try:
        with open(log_json_path, "r") as f:
            meta = json.load(f)
    except Exception:
        return None
    # direct keys (also allow *_MCC)
    for k in ("best_val_mcc","best_val_MCC","val_mcc","val_MCC","best_mcc","best_MCC"):
        v = meta.get(k, None)
        if isinstance(v, (int, float)): return float(v)
    # nested dicts
    for nk in ("valid_metrics","validation_metrics","val_metrics","best_metrics","best_valid_metrics"):
        d = meta.get(nk, None)
        if isinstance(d, dict):
            for cand in ("MCC","mcc","val_MCC","val_mcc"):
                v = d.get(cand, None)
                if isinstance(v, (int, float)): return float(v)
    # history → max
    hist = meta.get("history", None)
    if isinstance(hist, dict):
        for hk in ("val_mcc","val_MCC"):
            arr = hist.get(hk, None)
            if isinstance(arr, list) and len(arr) > 0:
                try: return float(np.nanmax(np.asarray(arr, dtype=float)))
                except Exception: pass
    return None

def _guess_log_for_checkpoint(pt_path: str, stem: str) -> Optional[str]:
    """
    Map:
      best_{stem}_val_loss.pt               -> best_{stem}_val_loss_log.json
      best_{stem}_val_loss_dimension_256.pt -> best_{stem}_val_loss_log_dimension_256.json
    """
    subdir = os.path.dirname(pt_path)
    cand1  = re.sub(r"(_val_loss)(.*)\.pt$", r"\1_log\2.json", pt_path)
    cand2  = os.path.join(subdir, f"best_{stem}_val_loss_log.json")
    for c in (cand1, cand2):
        if c and os.path.isfile(c):
            return c
    return None

def scan_img_checkpoints_by_mcc(cfg_scan: "RunConfig") -> pd.DataFrame:
    """Recursively list image-only checkpoints under ``cfg_scan.out_dir`` with their logged MCC.

    Matches ``best_<stem>_val_loss.pt`` and ``best_<stem>_val_loss_dimension_*.pt`` at any
    depth, where ``<stem>`` comes from ``_ckpt_stem``.

    Args:
        cfg_scan: Run config; its ``modality`` must be ``MOD.img``.

    Returns:
        One row per checkpoint with columns ``val_mcc`` (float or None), ``ckpt_path``,
        ``log_json`` (path or None), ``mtime`` and ``subdir``.

    Raises:
        AssertionError: If ``cfg_scan.modality`` is not ``MOD.img``.
        FileNotFoundError: If no checkpoint matches either pattern.
    """
    assert cfg_scan.modality == MOD.img, "scan_img_checkpoints_by_mcc expects MOD.img"
    stem = _ckpt_stem(cfg_scan)
    # Escape the literal parts so glob metacharacters ('[', ']', '*', '?') in the output
    # directory or the stem are matched verbatim; only "**" and the trailing "*" are wildcards.
    root = glob.escape(os.fspath(cfg_scan.out_dir))
    stem_glob = glob.escape(stem)
    patterns = [
        os.path.join(root, "**", f"best_{stem_glob}_val_loss.pt"),
        os.path.join(root, "**", f"best_{stem_glob}_val_loss_dimension_*.pt"),
    ]
    paths: List[str] = sorted({p for pat in patterns for p in glob.glob(pat, recursive=True)})
    if not paths:
        raise FileNotFoundError(f"No checkpoints found under '{cfg_scan.out_dir}'. Tried: {patterns}")

    rows = []
    for pt in paths:
        log_json = _guess_log_for_checkpoint(pt, stem)
        val_mcc = _read_val_mcc_from_log(log_json) if log_json else None
        # A checkpoint that vanished between glob and stat sorts as oldest.
        mtime = os.path.getmtime(pt) if os.path.isfile(pt) else 0.0
        rows.append({
            "val_mcc": (None if val_mcc is None else float(val_mcc)),
            "ckpt_path": pt, "log_json": log_json,
            "mtime": mtime, "subdir": os.path.dirname(pt)
        })
    return pd.DataFrame(rows)

def select_best_by_mcc(df: pd.DataFrame) -> pd.Series:
    """
    Pick one checkpoint row from the scan table.

    Rows whose ``val_mcc`` is numeric and >= 0 are ranked by MCC (descending), with the
    newest ``mtime`` breaking ties. When no row qualifies, the row with the newest
    ``mtime`` is returned instead. The input frame is not modified.
    """
    ranked = df.copy()
    ranked["val_mcc"] = pd.to_numeric(ranked["val_mcc"], errors="coerce")
    is_usable = ranked["val_mcc"].notna() & (ranked["val_mcc"] >= 0.0)
    usable = ranked[is_usable].copy()
    if len(usable) == 0:
        # Nothing with a valid, non-negative MCC: fall back to recency alone.
        newest_first = ranked.sort_values(by=["mtime"], ascending=[False])
        return newest_first.iloc[0]
    best_first = usable.sort_values(by=["val_mcc", "mtime"], ascending=[False, False])
    return best_first.iloc[0]

# TBX11K CSV helpers
def _read_test_csv(csv_path: str) -> pd.DataFrame:
    """Load the test CSV as a two-column frame ``img`` (file basename) / ``label`` (int).

    The file is read without a header, first with delimiter sniffing and, if that fails,
    with pandas' default delimiter. Only the first two columns are kept. A first row
    whose two cells both look like column titles is dropped. Labels that do not parse
    as numbers become 0, and rows whose ``img`` fails ``_is_img`` are removed.

    Raises:
        ValueError: If the file has fewer than two columns.
    """
    header_tokens = {"img","image","filename","file","path","label","class","target","y"}

    def _is_hdr(s: str) -> bool:
        token = str(s).strip().lower()
        if token in header_tokens:
            return True
        return token.startswith("img")

    read_opts = dict(header=None, encoding="utf-8-sig")
    try:
        table = pd.read_csv(csv_path, sep=None, engine="python", **read_opts)
    except Exception:
        # Sniffing failed; retry with the default delimiter.
        table = pd.read_csv(csv_path, **read_opts)
    if table.shape[1] < 2:
        raise ValueError("Test CSV must have >= 2 columns (filename, label).")
    table = table.iloc[:, :2]
    table.columns = ["img", "label"]

    # Drop a header row only when BOTH cells look like titles.
    if _is_hdr(table.iloc[0, 0]) and _is_hdr(table.iloc[0, 1]):
        table = table.iloc[1:].reset_index(drop=True)

    table["img"] = table["img"].astype(str).map(lambda s: os.path.basename(s.strip()))
    table["label"] = pd.to_numeric(table["label"], errors="coerce").fillna(0).astype(int)
    keep_rows = table["img"].map(_is_img)
    return table[keep_rows].reset_index(drop=True)

def crosscheck_tb(csv_tb_names: List[str], tb_dir: str) -> List[str]:
    """Compare the TB file names listed in the CSV with those found in ``tb_dir``.

    Only names starting with "tb" (case-insensitive) are considered on either side.
    Counts and up to five examples of each kind of mismatch are logged.

    Returns:
        Sorted list of names present in both the CSV and the directory.
    """
    def _is_tb(fname: str) -> bool:
        return fname.lower().startswith("tb")

    from_csv = set(filter(_is_tb, csv_tb_names))
    from_dir = set()
    for img_path in _list_images(tb_dir):
        fname = os.path.basename(img_path)
        if _is_tb(fname):
            from_dir.add(fname)

    matched = sorted(from_csv.intersection(from_dir))
    log.info("[TB CHECK] CSV tb count : %d", len(from_csv))
    log.info("[TB CHECK] Dir tb count : %d", len(from_dir))
    log.info("[TB CHECK] Match count  : %d", len(matched))

    csv_only = sorted(from_csv.difference(from_dir))
    dir_only = sorted(from_dir.difference(from_csv))
    if not (csv_only or dir_only):
        log.info("[TB CHECK] Filename sets MATCH ✅")
        return matched

    log.warning("[TB CHECK] Mismatch — only_in_csv=%d, only_in_dir=%d", len(csv_only), len(dir_only))
    if csv_only:
        log.warning("  e.g., only_in_csv: %s", csv_only[:5])
    if dir_only:
        log.warning("  e.g., only_in_dir: %s", dir_only[:5])
    return matched

# GT mapping: authors (512×512) → final 256×256 via mask ROI
def _parse_bbox_field(braw: Any) -> List[Dict[str, float]]:
    """'bbox' can be list[dict] or dict-as-string; return list of dicts with xmin,ymin,width,height."""
    if braw is None or (isinstance(braw, float) and math.isnan(braw)): return []
    try:
        b = ast.literal_eval(braw) if isinstance(braw, str) else braw
    except Exception:
        return []
    if isinstance(b, dict):  return [b]
    if isinstance(b, list):  return [x for x in b if isinstance(x, dict)]
    return []

def _mask_to_roi(mask_256: np.ndarray, margin_frac: float = ROI_MARGIN_FRAC,
                 enforce_square: bool = ROI_ENFORCE_SQUARE) -> Optional[Tuple[int,int,int,int]]:
    """Bounding ROI ``(x1, y1, x2, y2)`` around the foreground of a mask.

    The mask is binarised (Otsu threshold when its maximum exceeds 1, otherwise
    "non-zero is foreground"), the union of the bounding rectangles of all external
    contours is taken, optionally grown by ``margin_frac`` of its width/height on each
    side, optionally made square around its centre, and finally clipped to 0..256.

    Returns ``None`` for a ``None`` mask, a mask without contours, or a ROI that is
    empty after clipping.
    """
    if mask_256 is None:
        return None

    if mask_256.max() > 1:
        _, binary = cv2.threshold(mask_256, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)
    else:
        binary = (mask_256 > 0).astype(np.uint8)*255
    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return None

    # One (x, y, w, h) rectangle per contour; the ROI is their union.
    rects = [cv2.boundingRect(c) for c in contours]
    x1 = min(rx for rx, _, _, _ in rects)
    y1 = min(ry for _, ry, _, _ in rects)
    x2 = max(rx + rw for rx, _, rw, _ in rects)
    y2 = max(ry + rh for _, ry, _, rh in rects)

    # margin
    if margin_frac and margin_frac > 0:
        pad_x = int(round((x2 - x1) * margin_frac))
        pad_y = int(round((y2 - y1) * margin_frac))
        x1, y1 = x1 - pad_x, y1 - pad_y
        x2, y2 = x2 + pad_x, y2 + pad_y

    # enforce square (optional): grow the short side around the box centre
    if enforce_square:
        side = max(x2 - x1, y2 - y1)
        centre_x = (x1 + x2) // 2
        centre_y = (y1 + y2) // 2
        x1 = centre_x - side//2
        y1 = centre_y - side//2
        x2 = x1 + side
        y2 = y1 + side

    # clip: top-left to [0,255], bottom-right to [1,256]
    x1 = max(0, min(255, x1))
    y1 = max(0, min(255, y1))
    x2 = max(1, min(256, x2))
    y2 = max(1, min(256, y2))
    if x2 <= x1 or y2 <= y1:
        return None
    return (x1, y1, x2, y2)

def _scale_512_to_256_box(xmin: float, ymin: float, width: float, height: float,
                          src_w: int = 512, src_h: int = 512) -> Tuple[float,float,float,float]:
    """Return (x1,y1,x2,y2) in pre-crop 256×256 space."""
    sx = 256.0 / float(src_w); sy = 256.0 / float(src_h)
    x1 = (xmin) * sx; y1 = (ymin) * sy
    x2 = (xmin + width) * sx; y2 = (ymin + height) * sy
    return x1, y1, x2, y2

def _map_box_pre256_to_final256_via_roi(box_pre: Tuple[float,float,float,float],
                                        roi: Tuple[int,int,int,int]) -> Optional[Tuple[int,int,int,int]]:
    """Apply ROI crop (intersect) then scale to 256 final image; floor/ceil to preserve coverage."""
    x1, y1, x2, y2 = box_pre
    rx1, ry1, rx2, ry2 = roi
    ix1 = max(x1, rx1); iy1 = max(y1, ry1)
    ix2 = min(x2, rx2); iy2 = min(y2, ry2)
    if ix2 <= ix1 or iy2 <= iy1:
        return None
    # translate to crop coords
    cx1 = ix1 - rx1; cy1 = iy1 - ry1
    cx2 = ix2 - rx1; cy2 = iy2 - ry1
    rw = max(1, rx2 - rx1); rh = max(1, ry2 - ry1)
    sx = 256.0 / float(rw); sy = 256.0 / float(rh)
    fx1 = math.floor(cx1 * sx); fy1 = math.floor(cy1 * sy)
    fx2 = math.ceil (cx2 * sx); fy2 = math.ceil (cy2 * sy)
    # clip to [0,255], enforce non-zero size
    fx1 = max(0, min(255, fx1)); fy1 = max(0, min(255, fy1))
    fx2 = max(fx1+1, min(256, fx2)); fy2 = max(fy1+1, min(256, fy2))
    return (int(fx1), int(fy1), int(fx2), int(fy2))

def build_gt_box_map_final256(bbox_csv_path: str, masks_dir_256: str,
                              names_final_256: List[str]) -> Dict[str, List[Tuple[int,int,int,int]]]:
    """
    For each TB image name in names_final_256:
      • Read its mask (256×256) to derive the crop ROI
      • Load authors’ boxes (512×512), scale to pre-crop 256×256
      • Apply ROI crop+resize mapping to final 256×256
    Output: dict[name] -> list of (x1,y1,x2,y2) in final 256×256 coords

    Args:
        bbox_csv_path: Annotation CSV; must contain (case-insensitively) the columns
            fname, image_height, image_width, bbox, target and image_type. Only rows whose
            image_type and target are both "tb" are used.
        masks_dir_256: Directory searched (via ``_find_by_basename``) for each image's mask.
        names_final_256: Image file names to build boxes for.

    Returns:
        Mapping from image name to its mapped boxes. Images with a missing or unreadable
        mask, an empty ROI, no annotation, or no box surviving the crop are omitted.

    Raises:
        ValueError: If a required column is missing from the CSV.
    """
    df = pd.read_csv(bbox_csv_path)
    cols = {c.lower(): c for c in df.columns}
    needed = ["fname","image_height","image_width","bbox","target","image_type"]
    for k in needed:
        if k not in cols: raise ValueError(f"Column '{k}' missing in {bbox_csv_path}")

    df = df[
        (df[cols["image_type"]].astype(str).str.lower() == "tb") &
        (df[cols["target"]].astype(str).str.lower() == "tb")
    ].copy()

    # Collect authors' boxes keyed by basename; each entry keeps its source image size.
    authors_boxes_512: Dict[str, List[Tuple[float,float,float,float,int,int]]] = {}
    for _, r in df.iterrows():
        name = os.path.basename(str(r[cols["fname"]]).strip())
        boxes = _parse_bbox_field(r[cols["bbox"]])
        if not boxes:
            continue
        # Missing or non-positive dimensions fall back to 512.
        W = int(r[cols["image_width"]]) if not pd.isna(r[cols["image_width"]]) else 512
        H = int(r[cols["image_height"]]) if not pd.isna(r[cols["image_height"]]) else 512
        W = 512 if W <= 0 else W; H = 512 if H <= 0 else H
        pre_list = authors_boxes_512.setdefault(name, [])
        for b in boxes:
            try:
                xmin = float(b.get("xmin",0.0)); ymin = float(b.get("ymin",0.0))
                width= float(b.get("width",0.0)); height= float(b.get("height",0.0))
                if width <= 0 or height <= 0: continue
                pre_list.append((xmin, ymin, width, height, W, H))
            except Exception:
                # Malformed box entry (e.g. non-numeric field): skip just this box.
                continue

    # Lower-cased stem -> first annotation key with that stem. Lets the extension-agnostic
    # fallback below be a dict lookup instead of a scan over every key per image.
    key_by_stem: Dict[str, str] = {}
    for key in authors_boxes_512:
        key_by_stem.setdefault(os.path.splitext(key)[0].lower(), key)

    # Map to final 256×256 using the 256×256 mask ROI
    out: Dict[str, List[Tuple[int,int,int,int]]] = {}
    for name in names_final_256:
        base = os.path.splitext(name)[0]
        mask_path = _find_by_basename(masks_dir_256, name)
        if mask_path is None:
            log.warning("[GT MAP] Missing mask for %s — skipping GT boxes.", name)
            continue
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            log.warning("[GT MAP] Unreadable mask for %s — skipping.", name)
            continue
        roi = _mask_to_roi(mask, ROI_MARGIN_FRAC, ROI_ENFORCE_SQUARE)
        if roi is None:
            log.warning("[GT MAP] Empty ROI from mask for %s — skipping.", name)
            continue

        # Try exact filename; if not found, try the first key with same stem (different ext)
        auth_list = authors_boxes_512.get(name, [])
        if not auth_list:
            alt_key = key_by_stem.get(base.lower())
            if alt_key is not None:
                auth_list = authors_boxes_512[alt_key]
        if not auth_list:
            # no GT for this image — allowed
            continue

        final_boxes = []
        for (xmin, ymin, width, height, W, H) in auth_list:
            x1p, y1p, x2p, y2p = _scale_512_to_256_box(xmin, ymin, width, height, src_w=W, src_h=H)
            mapped = _map_box_pre256_to_final256_via_roi((x1p, y1p, x2p, y2p), roi)
            if mapped is not None:
                final_boxes.append(mapped)
        if final_boxes:
            out[name] = final_boxes
    return out

# Model / CAM helpers
def print_model_overview_for_cam(model: nn.Module, img_size: int = 224) -> None:
    """Log the device type, the nominal input size and total/trainable/frozen parameter counts."""
    n_total = 0
    n_trainable = 0
    for param in model.parameters():
        n_elems = param.numel()
        n_total += n_elems
        if param.requires_grad:
            n_trainable += n_elems
    n_frozen = n_total - n_trainable

    log.info("Grad-CAM — Model Overview")
    log.info(f"Device: {device.type}")
    log.info(f"Input size (C,H,W): (3, {img_size}, {img_size})")
    log.info(f"Total params: {n_total:,}")
    log.info(f"Trainable params: {n_trainable:,}")
    log.info(f"Frozen params: {n_frozen:,}")

def _kernel_tuple(m: nn.Conv2d) -> Tuple[int,int]:
    k = m.kernel_size
    return (k if isinstance(k, tuple) else (k, k))

def _last_conv_kgt1(module: nn.Module) -> Optional[nn.Conv2d]:
    """Return the last Conv2d (in ``modules()`` order) whose kernel is larger than 1 in
    some dimension; if every conv is 1×1, return the last Conv2d; ``None`` if there is none."""
    last_conv = None
    last_spatial_conv = None
    convs = (sub for sub in module.modules() if isinstance(sub, nn.Conv2d))
    for conv in convs:
        last_conv = conv
        if max(_kernel_tuple(conv)) > 1:
            last_spatial_conv = conv
    if last_spatial_conv is None:
        return last_conv
    return last_spatial_conv

def _vit_patch_embed_conv(module: nn.Module) -> Optional[nn.Conv2d]:
    target = None
    for name, m in module.named_modules():
        if isinstance(m, nn.Conv2d) and ("patch_embed" in name or "patch" in name):
            target = m
    return target

def _qualname_of_module(root: nn.Module, target: nn.Module) -> str:
    for n, m in root.named_modules():
        if m is target:
            return n or "<root>"
    return "<unknown>"

def select_target_layer_for_cam(model: nn.Module) -> nn.Conv2d:
    """Choose the Conv2d layer that CAM hooks into.

    Preference order:
      1. the first 3×3 Conv2d inside ``model.img_enc.post3x3`` (or ``post3x3`` itself);
      2. the result of ``_last_conv_kgt1`` on ``model.img_enc``;
      3. the result of ``_vit_patch_embed_conv`` on ``model.img_enc``;
      4. the result of ``_last_conv_kgt1`` on the whole model.

    Raises:
        RuntimeError: If the model contains no Conv2d at all.
    """
    def _is_conv3x3(layer) -> bool:
        return isinstance(layer, nn.Conv2d) and _kernel_tuple(layer) == (3,3)

    encoder = getattr(model, "img_enc", None)
    if encoder is not None:
        post = getattr(encoder, "post3x3", None)
        if isinstance(post, nn.Sequential):
            first_3x3 = next((sub for sub in post.modules() if _is_conv3x3(sub)), None)
            if first_3x3 is not None:
                return first_3x3
        elif _is_conv3x3(post):
            return post
        # NOTE: _last_conv_kgt1 already falls back to any Conv2d, so the patch-embed
        # lookup only runs when the encoder has no Conv2d — where it finds nothing either.
        for finder in (_last_conv_kgt1, _vit_patch_embed_conv):
            candidate = finder(encoder)
            if isinstance(candidate, nn.Conv2d):
                return candidate

    candidate = _last_conv_kgt1(model)
    if isinstance(candidate, nn.Conv2d):
        return candidate
    raise RuntimeError("No suitable Conv2d layer found for Grad-CAM.")

class CamImageWrapper(nn.Module):
    """Adapter exposing a model's ``forward_image`` as the plain ``forward`` that CAM calls."""

    def __init__(self, model_img_only: nn.Module):
        super().__init__()
        # Registered as a sub-module so .to()/.eval()/zero_grad() reach the wrapped model.
        self.mm = model_img_only

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Delegate to the wrapped model's image branch."""
        image_output = self.mm.forward_image(x)
        return image_output

def build_preprocess(size: int) -> T.Compose:
    """Build the image pipeline: to PIL → resize to ``size``×``size`` → tensor → per-channel
    normalisation (the mean/std values are the commonly used ImageNet statistics)."""
    channel_mean = [0.485,0.456,0.406]
    channel_std = [0.229,0.224,0.225]
    steps = [
        T.ToPILImage(),
        T.Resize((size, size)),
        T.ToTensor(),
        T.Normalize(mean=channel_mean, std=channel_std),
    ]
    return T.Compose(steps)

@torch.no_grad()
def _load_ckpt_model(cfg_load: "RunConfig", ckpt_path: str) -> "ImgOrTxtClassifier":
    """Instantiate ``ImgOrTxtClassifier`` from ``cfg_load``, move it to ``device`` in eval
    mode and load the state dict stored at ``ckpt_path`` (strict key matching)."""
    model = ImgOrTxtClassifier(cfg_load)
    model = model.to(device)
    model = model.eval()
    # NOTE(review): depending on the installed PyTorch version, torch.load may fully
    # unpickle the file — only point this at trusted checkpoints (consider weights_only=True).
    weights = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(weights, strict=True)
    return model

def _init_cam(CAMClass, model, target_layers):
    sig = inspect.signature(CAMClass.__init__)
    if "use_cuda" in sig.parameters:
        return CAMClass(model=model, target_layers=target_layers, use_cuda=(device.type == "cuda"))
    return CAMClass(model=model, target_layers=target_layers)

# Grad-CAM main (TBX11K with GT overlay)
def run_gradcam_tbx11k_img_with_gt(
    cfg_cam: "RunConfig",
    ckpt_path: str,
    csv_test_path: str,
    tb_256_dir: str,
    masks_256_dir: str,
    bbox_csv_path: str,
    save_root: str,
    cam_method: str = CAM_METHOD,
    heatmap_alpha: float = HEATMAP_ALPHA,
    bin_thr: float = BIN_THR,
) -> None:
    """Run a CAM method over the TB test images and save overlays with ground-truth boxes.

    For every TB image that is listed in the test CSV and present in ``tb_256_dir`` this
    writes, under ``save_root``:
      * ``images/<stem>.png``                       – the input at FINAL_SIZE×FINAL_SIZE;
      * ``heatmaps/<stem>__<network>__<method>.png`` – CAM heat-map overlay + GT boxes;
      * ``contours/<stem>__<network>__<method>.png`` – contours of the thresholded CAM + GT boxes;
      * ``bboxes/<stem>__<network>__<method>.png``   – boxes of those contours + GT boxes.

    Args:
        cfg_cam: Run config used to rebuild the model (``network``, optional ``img_size``).
        ckpt_path: State-dict checkpoint to load.
        csv_test_path: Test CSV (filename, label); rows whose name starts with "tb" are used.
        tb_256_dir: Directory holding the TB images.
        masks_256_dir: Directory holding the masks used to derive each image's crop ROI.
        bbox_csv_path: Annotation CSV with the authors' boxes.
        save_root: Output directory (created if missing).
        cam_method: One of gradcam, gradcam++, hirescam, xgradcam, layercam, eigencam,
            eigengradcam, scorecam, ablationcam (case-insensitive).
        heatmap_alpha: Weight of the heat-map in the overlay; the image gets ``1 - alpha``.
        bin_thr: Threshold in [0, 1] applied to the normalised CAM before contour extraction.

    Raises:
        AssertionError: If ``ckpt_path`` is not a file.
        FileNotFoundError: If any input path is missing.
        ValueError: If ``cam_method`` is unknown.
    """
    assert os.path.isfile(ckpt_path), f"Checkpoint not found: {ckpt_path}"
    for p in [csv_test_path, tb_256_dir, masks_256_dir, bbox_csv_path]:
        if not os.path.exists(p):
            raise FileNotFoundError(f"Required path missing: {p}")

    # Outputs
    os.makedirs(save_root, exist_ok=True)
    out_dirs = {
        "images": os.path.join(save_root, "images"),
        "heatmaps": os.path.join(save_root, "heatmaps"),
        "contours": os.path.join(save_root, "contours"),
        "bboxes": os.path.join(save_root, "bboxes"),
    }
    for d in out_dirs.values(): os.makedirs(d, exist_ok=True)

    # Validate the method before paying for the model load.
    methods = {
        "gradcam": GradCAM, "gradcam++": GradCAMPlusPlus, "hirescam": HiResCAM,
        "xgradcam": XGradCAM, "layercam": LayerCAM, "eigencam": EigenCAM,
        "eigengradcam": EigenGradCAM, "scorecam": ScoreCAM, "ablationcam": AblationCAM
    }
    mkey = cam_method.lower()
    if mkey not in methods:
        raise ValueError(f"Unknown CAM method '{cam_method}'")

    def _draw_boxes(canvas, boxes, color):
        """Draw (x1, y1, x2, y2) rectangles on ``canvas`` in place."""
        for (bx1, by1, bx2, by2) in boxes:
            cv2.rectangle(canvas, (bx1, by1), (bx2, by2), color, BB_THICK, lineType=cv2.LINE_AA)

    img_size = int(getattr(cfg_cam, "img_size", 224))
    alpha = float(heatmap_alpha)
    thr = int(255 * float(bin_thr))

    # Model & CAM
    model_img = _load_ckpt_model(cfg_cam, ckpt_path)
    cam_model = None
    cam = None
    try:
        print_model_overview_for_cam(model_img, img_size=img_size)
        cam_model = CamImageWrapper(model_img).to(device).eval()
        target_layer = select_target_layer_for_cam(model_img)
        # Gradient-based CAMs need gradients at the target layer even if it was frozen.
        for p in target_layer.parameters(): p.requires_grad_(True)
        tl_name = _qualname_of_module(model_img, target_layer)
        log.info(f"[CAM] Target layer: {tl_name}  kernel={getattr(target_layer, 'kernel_size', None)}")

        cam = _init_cam(methods[mkey], model=cam_model, target_layers=[target_layer])
        # Best effort: not every CAM class exposes a settable batch_size.
        try: cam.batch_size = 1
        except Exception: pass
        preprocess = build_preprocess(img_size)

        # TB selection via CSV + directory cross-check
        df_all = _read_test_csv(csv_test_path)
        df_tb  = df_all[df_all["img"].str.lower().str.startswith("tb")].copy()
        tb_names = crosscheck_tb(sorted(df_tb["img"].astype(str).unique().tolist()), tb_256_dir)

        # Build GT box map: final 256×256 coords
        gt_boxes_map = build_gt_box_map_final256(bbox_csv_path, masks_256_dir, tb_names)

        # Iterate TB images
        for name in tb_names:
            img_path = _find_by_basename(tb_256_dir, name) or os.path.join(tb_256_dir, name)
            bgr = cv2.imread(img_path, cv2.IMREAD_COLOR)
            if bgr is None:
                log.warning("[CAM] Skipping unreadable: %s", img_path)
                continue
            if bgr.shape[:2] != (FINAL_SIZE, FINAL_SIZE):
                bgr = cv2.resize(bgr, (FINAL_SIZE, FINAL_SIZE), interpolation=cv2.INTER_AREA)

            # Save base image
            stem = os.path.splitext(os.path.basename(name))[0]
            cv2.imwrite(os.path.join(out_dirs["images"], f"{stem}.png"), bgr)
            out_name = f"{stem}__{cfg_cam.network.name}__{mkey}.png"

            # Model input
            rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
            x = preprocess(rgb).unsqueeze(0).to(device)

            # CAM (EigenCAM is called without a class target)
            cam_model.zero_grad(set_to_none=True)
            with torch.enable_grad():
                if mkey == "eigencam":
                    mask = cam(input_tensor=x)[0]
                else:
                    mask = cam(
                        input_tensor=x,
                        targets=[ClassifierOutputTarget(int(TARGET_CLASS_IDX))],
                        aug_smooth=True, eigen_smooth=True
                    )[0]

            # Normalize to [0,1] (epsilon guards a constant map) & resize to the final size
            mmin, mmax = float(np.min(mask)), float(np.max(mask))
            mask = (mask - mmin) / (mmax - mmin + 1e-8)
            mask_u8 = (np.clip(cv2.resize(mask, (FINAL_SIZE, FINAL_SIZE), interpolation=cv2.INTER_NEAREST), 0, 1) * 255).astype(np.uint8)

            # Heatmap overlay
            heat = cv2.applyColorMap(mask_u8, cv2.COLORMAP_HOT)
            heat_overlay = cv2.addWeighted(heat, alpha, bgr, 1.0 - alpha, 0.0)

            # Contours & model boxes
            _, binm = cv2.threshold(mask_u8, thr, 255, cv2.THRESH_BINARY)
            contours,_ = cv2.findContours(binm, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

            cont_img = bgr.copy()
            if contours:
                cv2.drawContours(cont_img, contours, -1, CONTOUR_COLOR, CONTOUR_THICK, lineType=cv2.LINE_AA)

            model_boxes: List[Tuple[int,int,int,int]] = []
            for c in contours:
                x0, y0, w0, h0 = cv2.boundingRect(c)
                model_boxes.append((x0, y0, x0 + w0, y0 + h0))

            # GT boxes in final 256×256 space
            gt_boxes = gt_boxes_map.get(name, gt_boxes_map.get(os.path.basename(name), []))

            # 1) Heatmap + GT
            _draw_boxes(heat_overlay, gt_boxes, BB_GT_COLOR)
            cv2.imwrite(os.path.join(out_dirs["heatmaps"], out_name), heat_overlay)

            # 2) Contours + GT
            _draw_boxes(cont_img, gt_boxes, BB_GT_COLOR)
            cv2.imwrite(os.path.join(out_dirs["contours"], out_name), cont_img)

            # 3) BBoxes view: model boxes + GT boxes
            box_ov = bgr.copy()
            _draw_boxes(box_ov, model_boxes, BB_MODEL_COLOR)
            _draw_boxes(box_ov, gt_boxes, BB_GT_COLOR)
            cv2.imwrite(os.path.join(out_dirs["bboxes"], out_name), box_ov)
    finally:
        # Cleanup runs on failure too: in a notebook a stored traceback keeps this frame's
        # locals alive, so hooks and model references must be dropped explicitly.
        try:
            if cam is not None and hasattr(cam, "activations_and_grads") and (cam.activations_and_grads is not None):
                cam.activations_and_grads.release()
        except Exception as exc:
            log.warning("[CAM] Could not release CAM hooks: %s", exc)
        del cam, cam_model, model_img
        if device.type == "cuda":
            torch.cuda.empty_cache()
    log.info("[CAM] Saved to: %s", save_root)

# Scan → select best (by val-MCC) → run CAM on TBX11K
# Work on a copy of the run config with the modality forced to MOD.img:
# scan_img_checkpoints_by_mcc asserts that modality. `cfg` itself is left unchanged.
cfg_scan = dataclasses.replace(cfg, modality=MOD.img)
# One row per image-only checkpoint found under cfg_scan.out_dir (raises if none exist).
df_found = scan_img_checkpoints_by_mcc(cfg_scan)

def _fmt(x):
    if x is None or (isinstance(x, float) and math.isnan(x)): return "—"
    return f"{x:.4f}" if isinstance(x, float) else str(x)

# Print only valid, non-negative MCC entries
df_show = df_found.copy()
df_show["val_mcc"] = pd.to_numeric(df_show["val_mcc"], errors="coerce")
df_show = df_show[df_show["val_mcc"].notna() & (df_show["val_mcc"] >= 0.0)]

print("\n================ Available image-only checkpoints (sorted by folder) ================")
if len(df_show) == 0:
    print("• (no checkpoints with a valid non-negative MCC found; selection will fallback to newest mtime)")
else:
    for _, r in df_show.sort_values("subdir").iterrows():
        print(f"• {r['subdir']}")
        print(f" val_MCC={_fmt(float(r['val_mcc']))}")
        print(f" ckpt={r['ckpt_path']}")
        if r['log_json']: print(f" log ={r['log_json']}")
print("======================================================================================\n")

row_best = select_best_by_mcc(df_found)
best_ckpt = row_best["ckpt_path"]
best_subdir = row_best["subdir"]
best_mcc = row_best["val_mcc"]
# select_best_by_mcc only ranks by MCC when it is numeric and >= 0; a missing or negative
# value here means the newest-mtime fallback was used, so say so instead of claiming
# "highest validation MCC".
if pd.isna(best_mcc):
    reason = "no MCC recorded; selecting most recent checkpoint"
elif float(best_mcc) < 0.0:
    reason = (
        "no checkpoint with a non-negative validation MCC; selecting most recent checkpoint "
        f"(val_MCC={_fmt(float(best_mcc))})"
    )
else:
    reason = f"highest validation MCC={_fmt(float(best_mcc))}"
print(">>> Selecting checkpoint:")
print(f" {best_ckpt}")
print(f" Because: {reason}\n")
# Build save root under the chosen run folder (no _dimension_ tags written)
save_root = os.path.join(best_subdir, SAVE_ROOT_NAME)

# Use cfg_scan (modality forced to MOD.img) from here on: the checkpoint was selected by
# the "<network>_img" stem, so the model must be rebuilt as the image-only variant even
# if `cfg` itself is configured for another modality.
log.info("================ Grad-CAM (TBX11K, IMG-ONLY; mask-mapped GT) ================")
log.info(f"Checkpoint : {best_ckpt}")
log.info(f"Backbone   : {cfg_scan.network.name}")
log.info(f"Modality   : {cfg_scan.modality.name}")
log.info(f"csv_test   : {cfg_scan.csv_test}")
log.info(f"TB dir     : {TB_256_DIR} (final lung-cropped 256×256)")
log.info(f"Masks dir  : {TB_MASKS_256_DIR} (pre-crop masks 256×256)")
log.info(f"BBoxes csv : {BBOX_CSV_PATH} (authors' boxes @512×512)")
log.info(f"Save root  : {save_root}")
log.info("===========================================================================")

# Run Grad-CAM with GT overlay
run_gradcam_tbx11k_img_with_gt(
    cfg_cam=cfg_scan,
    ckpt_path=best_ckpt,
    csv_test_path=cfg_scan.csv_test,
    tb_256_dir=TB_256_DIR,
    masks_256_dir=TB_MASKS_256_DIR,
    bbox_csv_path=BBOX_CSV_PATH,
    save_root=save_root,
    cam_method=CAM_METHOD,
    heatmap_alpha=HEATMAP_ALPHA,
    bin_thr=BIN_THR,
)

## END OF CODE