In [None]:
!nvidia-smi

In [None]:
from __future__ import annotations

# =========================
# Standard library imports
# =========================
import ast
import glob
import inspect
import io
import json
import logging
import math
import os
import pathlib
import random
import re
import shutil
import sys
import time
import warnings
import dataclasses
from collections import defaultdict
from contextlib import nullcontext
from dataclasses import dataclass
from enum import Enum, auto
from pathlib import Path
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple

# =========================
# Third-party imports
# =========================
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import numpy as np
import pandas as pd
from PIL import Image
import sklearn
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, Sampler, WeightedRandomSampler
from torchvision import transforms as T
from transformers import AutoModel, AutoTokenizer
from pytorch_grad_cam import (
    AblationCAM,
    EigenCAM,
    EigenGradCAM,
    GradCAM,
    GradCAMPlusPlus,
    HiResCAM,
    LayerCAM,
    ScoreCAM,
    XGradCAM,
)
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

In [None]:
# Print runtime details and set Hugging Face environment variables.

# Library version and interpreter in use, for reproducibility of the run.
print(torch.__version__)
print(sys.executable)
print(sys.version)

# Point the Hugging Face cache at a fixed location and request offline mode.
# NOTE(review): these are assigned after `transformers` was already imported
# in an earlier cell; confirm the library still honors values set this late.
os.environ["HF_HOME"] = "/gpfs/gsfs12/users/rajaramans2/.cache/huggingface"
os.environ["HF_HUB_CACHE"] = "/gpfs/gsfs12/users/rajaramans2/.cache/huggingface/hub"
os.environ["TRANSFORMERS_OFFLINE"] = "1"
os.environ["HF_HUB_OFFLINE"] = "1"

In [None]:
# Configuration

class MOD(Enum):
    """Which input branch(es) a run uses."""

    img = 1         # image only
    txt = 2         # text only
    multimodal = 3  # image and text together

class NETWORK(Enum):
    """Image backbones selectable through the run configuration."""

    # timm / torchvision families
    dpn68 = 1
    coatnet0 = 2
    convnext_nano = 3
    hrnet32 = 4
    resnet18 = 5
    densenet121 = 6
    mobilenet_v2 = 7
    # VGG family
    vgg11 = 8
    vgg13 = 9
    vgg16 = 10
    vgg19 = 11
 
@dataclass
class RunConfig:
    """
    Settings for one training/evaluation run.

    Any of csv_train / csv_valid / csv_test left as None is filled in after
    construction with "<dataset_root>/label_<split>.csv".
    """
    modality: MOD
    network: NETWORK
    n_classes: int = 2
    hidden_dim: int = 256
    epochs: int = 64
    lr: float = 5e-5
    weight_decay: float = 1e-4
    img_size: int = 224
    batch_size: int = 64
    num_workers: int = 2
    pin_memory: bool = True

    # classes
    class_names: Tuple[str, ...] = ("normal", "tb")

    # data locations
    dataset_root: str = "//dataset"
    csv_train: Optional[str] = None
    csv_valid: Optional[str] = None
    csv_test:  Optional[str] = None
    images_subdir: str = "images"
    reports_subdir: str = "reports"
    out_dir: str = "/models"

    # text/HF (offline-first)
    hf_offline: bool = True
    hf_local_dir: Optional[str] = "/huggingface/hub/models--microsoft--BiomedVLP-CXR-BERT-general"
    hf_text_model_name: str = "microsoft/BiomedVLP-CXR-BERT-general"
    max_tokens: int = 512
    hf_tokenizer_local_dir: Optional[str] = "/huggingface/hub/models--microsoft--BiomedVLP-CXR-BERT-general"

    # alignment losses (for MM only)
    use_align_losses: bool = True
    align_lambda: float = 0.5
    contrastive_temperature: float = 0.07

    def __post_init__(self):
        # Fill every unset split CSV with its default location under dataset_root.
        for split in ("train", "valid", "test"):
            attr = f"csv_{split}"
            if getattr(self, attr) is None:
                setattr(self, attr, os.path.join(self.dataset_root, f"label_{split}.csv"))

    def _under_root(self, subdir: str) -> str:
        """Join a sub-directory name onto dataset_root."""
        return os.path.join(self.dataset_root, subdir)

    @property
    def images_dir(self) -> str:
        """Directory holding the images: dataset_root/images_subdir."""
        return self._under_root(self.images_subdir)

    @property
    def reports_dir(self) -> str:
        """Directory holding the reports: dataset_root/reports_subdir."""
        return self._under_root(self.reports_subdir)

def ckpt_stem(cfg: RunConfig) -> str:
    """Return "<network name>_<modality name>", the stem shared by a run's output files."""
    return "{}_{}".format(cfg.network.name, cfg.modality.name)

def ckpt_paths(cfg: RunConfig) -> Dict[str, str]:
    """
    Map each artifact key to its output path under cfg.out_dir.

    Every filename has the form "best_<stem>_<suffix>", where <stem> comes
    from ckpt_stem(cfg).
    """
    stem = ckpt_stem(cfg)
    out_root = cfg.out_dir
    suffix_by_key = (
        ("best_pt",         "val_loss.pt"),
        ("curves_png",      "val_loss_curves.png"),
        ("curves_csv",      "val_loss_history.csv"),
        ("curves_png_full", "val_loss_curves_full.png"),
        ("curves_csv_full", "val_loss_history_full.csv"),
        ("log_json",        "val_loss_log.json"),
        ("mcc_png",         "val_mcc_curve.png"),
    )
    return {
        key: os.path.join(out_root, f"best_{stem}_{suffix}")
        for key, suffix in suffix_by_key
    }

In [None]:
# CSV dataset & loaders (Albumentations), loads the train, validation, and test CSV files; class labels: 0 for normal, and 1 for tb

def _read_split_csv_headerless(path: str) -> pd.DataFrame:
    """
    Read a split CSV with NO header into a DataFrame with canonical
    columns: ["filename", "label"].    
    """
    df = pd.read_csv(path, header=None, names=["filename", "label"])
    if df.shape[1] < 2:
        raise ValueError(f"{path} must have ≥2 columns (filename, label)")
    df["filename"] = df["filename"].astype(str)
    df["label"] = df["label"].astype(int)
    return df

def _read_csv_two_cols_no_header(path: str) -> Tuple[List[str], List[int]]:
    """
    Return the split CSV at `path` as two parallel lists: (filenames, labels).
    Thin wrapper over _read_split_csv_headerless for callers that do not
    need the DataFrame.
    """
    frame = _read_split_csv_headerless(path)
    filenames, labels = (frame[col].tolist() for col in ("filename", "label"))
    return filenames, labels

def _print_split_summary(cfg: "RunConfig", which: str, modality=None) -> None:
    """
    Print a per-split summary similar to the unimodal stack.

    Parameters
    ----------
    cfg : RunConfig
        Must expose csv_train / csv_valid / csv_test, n_classes, and optionally class_names.
    which : {"train","valid","test"}
    modality : optional
        Either an enum with .name or a string; used only for pretty-printing.
    """
    csv_attr = f"csv_{which}"
    if not hasattr(cfg, csv_attr):
        print(f"[WARN] _print_split_summary: cfg has no attribute '{csv_attr}'. Skipping.")
        return

    csv_path = getattr(cfg, csv_attr)
    if not csv_path or not os.path.isfile(csv_path):
        print(f"[WARN] _print_split_summary: '{csv_attr}'='{csv_path}' not found. Skipping.")
        return

    df = _read_split_csv_headerless(csv_path)
    total = len(df)
    vc = df["label"].value_counts().sort_index()

    # Try to get a readable modality name
    if modality is None:
        mod_name = "N/A"
    else:
        mod_name = getattr(modality, "name", str(modality))

    print(f"\n[Split summary] split='{which}' | modality={mod_name} | csv='{csv_path}'", flush=True)
    print(f"  total rows: {total}", flush=True)

    class_names = getattr(cfg, "class_names", None)
    for cls_idx, cnt in vc.items():
        if class_names is not None and 0 <= cls_idx < len(class_names):
            cname = class_names[cls_idx]
        else:
            cname = f"class_{cls_idx}"
        print(f"  label={cls_idx:2d} ({cname:>10s}) : n={int(cnt)}", flush=True)
    print("", flush=True)

# ============================
# ==== Path & label helpers
# ============================

def _abs_image_path(dataset_root: str, images_subdir: str, fname: str) -> Path:
    return Path(dataset_root) / images_subdir / fname

def _abs_report_path(dataset_root: str, reports_subdir: str, fname: str) -> Path:
    """
    For a given image filename 'xxx.ext', we first try '<stem>.txt' under
    reports_subdir. If not present,
    we fall back to trying the raw fname itself under reports_subdir.
    """
    root = Path(dataset_root)
    stem_txt = root / reports_subdir / (Path(fname).stem + ".txt")
    if stem_txt.exists():
        return stem_txt
    return root / reports_subdir / fname

def _one_hot(idx: int, n_classes: int) -> torch.Tensor:
    y = torch.zeros(n_classes, dtype=torch.float32)
    if 0 <= idx < n_classes:
        y[idx] = 1.0
    return y

# ============================
# ==== Multimodal Dataset Class
# ============================

class CSVMMImageTextDataset(Dataset):
    """
    Paired image + report dataset built from a headerless split CSV.

    Rows whose image file or report file does not exist are dropped at
    construction time, so len(dataset) can be smaller than the CSV.

    Each sample is a dict:
        {
            "image":     tensor produced by the Albumentations pipeline,
            "text":      {"input_ids", "attention_mask"} from the tokenizer,
            "y_onehot":  one-hot label (float32),
            "filename":  image file name (str)
        }
    """

    def __init__(self,
                 cfg: "RunConfig",
                 which: str,
                 n_classes: int,
                 tokenizer: AutoTokenizer,
                 max_len: int = 192) -> None:

        self.cfg = cfg
        self.n_classes = int(n_classes)
        self.tokenizer = tokenizer
        self.max_len = int(max_len)

        # --- Split CSV -> (filenames, labels); unknown `which` raises KeyError ---
        split_to_csv = {"train": cfg.csv_train,
                        "valid": cfg.csv_valid,
                        "test":  cfg.csv_test}
        filenames, targets = _read_csv_two_cols_no_header(split_to_csv[which])

        # --- Keep only rows where both the image and the report exist ---
        self.items: List[Tuple[Path, int, Path]] = []
        skipped_img = 0
        skipped_txt = 0
        for name, target in zip(filenames, targets):
            image_path = _abs_image_path(cfg.dataset_root, cfg.images_subdir, name)
            report_path = _abs_report_path(cfg.dataset_root, cfg.reports_subdir, name)
            if not image_path.exists():
                skipped_img += 1
            elif not report_path.exists():
                skipped_txt += 1
            else:
                self.items.append((image_path, int(target), report_path))

        print(f"[Multimodal {which}] paired (image+report) samples: {len(self.items)} "
              f"(missing_img={skipped_img}, missing_txt={skipped_txt})",
              flush=True)

        # --- Image pipelines: both share resize/normalize/to-tensor; train adds CLAHE ---
        side = int(cfg.img_size)

        def _common_steps() -> list:
            return [
                A.Resize(side, side, interpolation=cv2.INTER_AREA),
                A.Normalize(mean=(0.485, 0.456, 0.406),
                            std=(0.229, 0.224, 0.225)),
                ToTensorV2(),
            ]

        self.train_tfm = A.Compose(
            [A.CLAHE(clip_limit=2.0, tile_grid_size=(8, 8), p=0.5), *_common_steps()]
        )
        self.eval_tfm = A.Compose(_common_steps())

        if which == "train":
            self.tfm = self.train_tfm
        else:
            self.tfm = self.eval_tfm

    def __len__(self) -> int:
        return len(self.items)

    # --- internals: image & text reading ---

    def _read_img(self, path: Path) -> torch.Tensor:
        """Load `path` as grayscale, expand to RGB, and apply the active pipeline."""
        gray = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
        if gray is None:
            # Unreadable file: substitute a black square of the configured size
            side = self.cfg.img_size
            gray = np.zeros((side, side), dtype=np.uint8)
        rgb = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
        return self.tfm(image=rgb)["image"]

    def _read_text(self, path: Path) -> Dict[str, torch.Tensor]:
        """Read the report at `path` and tokenize it to a fixed length of max_len."""
        try:
            text = path.read_text(encoding="utf-8", errors="ignore").strip()
        except Exception:
            text = ""
        encoded = self.tokenizer(
            text or "[EMPTY]",  # placeholder for empty/unreadable reports
            padding="max_length",
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt",
        )
        # The tokenizer returns a batch of one; drop that leading dimension
        return {key: encoded[key][0] for key in ("input_ids", "attention_mask")}

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        image_path, label, report_path = self.items[idx]
        target = _one_hot(label, self.n_classes)
        sample = {
            "image": self._read_img(image_path),
            "text": self._read_text(report_path),
            "y_onehot": target,
            "filename": image_path.name,
        }
        return sample

# ============================
# ==== Tokenizer builder
# ============================
def build_tokenizer(cfg: "RunConfig") -> AutoTokenizer:
    """
    Build the HuggingFace tokenizer for the text branch.

    Resolution order:
      1. cfg.hf_tokenizer_local_dir, when it is an existing directory
         (loaded with local_files_only=True);
      2. cfg.hf_text_model_name, with local_files_only=cfg.hf_offline;
      3. if either attempt fails and offline mode is off, one more attempt
         by model name without the local-only restriction.

    Missing cfg attributes fall back to: model name
    "microsoft/BiomedVLP-CXR-BERT-general", hf_offline=True, no local dir,
    hf_trust_remote_code=False.

    Raises
    ------
    RuntimeError
        When loading fails while offline mode is on; the original error is
        attached as the cause.
    """
    model_id = getattr(cfg, "hf_text_model_name",
                       "microsoft/BiomedVLP-CXR-BERT-general")
    offline   = bool(getattr(cfg, "hf_offline", True))
    local_dir = getattr(cfg, "hf_tokenizer_local_dir", None)
    trust_remote = bool(getattr(cfg, "hf_trust_remote_code", False))

    try:
        if local_dir and os.path.isdir(local_dir):
            return AutoTokenizer.from_pretrained(
                local_dir,
                local_files_only=True,
                use_fast=True,
                trust_remote_code=trust_remote,
            )
        return AutoTokenizer.from_pretrained(
            model_id,
            local_files_only=offline,
            use_fast=True,
            trust_remote_code=trust_remote,
        )
    except Exception as err:
        if offline:
            # Chain the cause so the real failure (missing files, bad config,
            # remote-code refusal, ...) is visible in the traceback.
            raise RuntimeError(
                "Offline mode enabled but no local tokenizer found at "
                f"hf_tokenizer_local_dir={local_dir!r}"
            ) from err
        # Fallback to online fetch
        return AutoTokenizer.from_pretrained(
            model_id,
            use_fast=True,
            trust_remote_code=trust_remote,
        )

# ============================
# ==== Imbalanced sampler
# ============================

class ImbalancedDatasetSampler(Sampler[int]):
    """
    Inverse-frequency sampler over class indices.

    Each sample gets weight 1 / (number of samples sharing its label).
    One pass draws len(labels) indices with replacement according to
    those weights.

    Parameters
    ----------
    labels : torch.Tensor
        1-D tensor of non-negative integer class labels, one per dataset
        item and in dataset order. It is copied to CPU and is not modified.
    """

    def __init__(self, labels: torch.Tensor):
        self.labels = labels.clone().cpu().long()
        # Number of classes seen; 2 is only a placeholder for an empty label set
        k = int(self.labels.max().item()) + 1 if self.labels.numel() > 0 else 2
        # clamp_min(1) avoids division by zero for label values that never occur
        counts = torch.bincount(self.labels, minlength=k).clamp_min(1)
        weights = (1.0 / counts)[self.labels]
        self.sample_weights = weights.double()
        self.num_samples = len(self.labels)

    def __iter__(self):
        if self.num_samples == 0:
            # torch.multinomial cannot draw zero samples; an empty label set
            # simply yields nothing.
            return iter(())
        idx = torch.multinomial(
            self.sample_weights,
            num_samples=self.num_samples,
            replacement=True,
        )
        return iter(idx.tolist())

    def __len__(self):
        return self.num_samples

# ============================
# ==== Loader creation (multimodal)
# ============================

def make_loaders(cfg: "RunConfig") -> Tuple[DataLoader, DataLoader, DataLoader]:
    """
    Build train/valid/test DataLoaders for multimodal training.

    - Prints split summaries of the raw CSVs using _print_split_summary.
    - Uses CSVMMImageTextDataset for all splits.
    - Uses ImbalancedDatasetSampler for the training set. Its labels come
      from the rows the training dataset actually kept, so sampled indices
      always stay inside the dataset.
    - The training loader drops the last incomplete batch; valid/test keep it.

    Returns
    -------
    (train_loader, valid_loader, test_loader)
    """

    # 1) Split summaries *before* dataset construction (raw CSV)
    modality = getattr(cfg, "modality", None)
    for split in ("train", "valid", "test"):
        _print_split_summary(cfg, split, modality=modality)

    # 2) Tokenizer
    tok = build_tokenizer(cfg)

    # 3) Datasets (these will also print how many image+report pairs survive)
    max_len = int(cfg.max_tokens)
    ds_tr = CSVMMImageTextDataset(cfg, "train", cfg.n_classes, tokenizer=tok, max_len=max_len)
    ds_va = CSVMMImageTextDataset(cfg, "valid", cfg.n_classes, tokenizer=tok, max_len=max_len)
    ds_te = CSVMMImageTextDataset(cfg, "test",  cfg.n_classes, tokenizer=tok, max_len=max_len)

    # 4) Class-balanced sampler for training.
    # Labels must come from ds_tr.items, not the raw CSV: the dataset drops
    # rows with a missing image/report, so CSV positions neither line up with
    # dataset indices nor stay below len(ds_tr).
    labels_tr = torch.as_tensor([label for _, label, _ in ds_tr.items], dtype=torch.long)
    sampler = ImbalancedDatasetSampler(labels_tr)

    batch_size = int(cfg.batch_size)
    n_workers = int(cfg.num_workers)
    common = dict(
        batch_size=batch_size,
        shuffle=False,
        num_workers=n_workers,
        pin_memory=bool(cfg.pin_memory),
        persistent_workers=bool(n_workers > 0),
    )

    dl_tr = DataLoader(ds_tr, sampler=sampler, drop_last=True, **common)
    dl_va = DataLoader(ds_va, drop_last=False, **common)
    dl_te = DataLoader(ds_te, drop_last=False, **common)

    # 5) Batch count sanity check
    def _n_batches(n_items: int, bs: int, drop_last: bool) -> int:
        bs = max(1, bs)
        if drop_last:
            return max(0, n_items // bs)
        return (n_items + bs - 1) // bs

    print(
        f"[batches] train={_n_batches(len(ds_tr), batch_size, True)}, "
        f"valid={_n_batches(len(ds_va), batch_size, False)}, "
        f"test={_n_batches(len(ds_te), batch_size, False)}",
        flush=True,
    )

    return dl_tr, dl_va, dl_te

# Build the three loaders once, when this cell runs.
# NOTE(review): `cfg` is not defined in any cell shown above; it must be
# created (make_loaders expects a RunConfig) before this line executes.
train_loader, valid_loader, test_loader = make_loaders(cfg)

In [None]:
# Model architecture
# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------

def _net_name_like(x) -> str:
    """
    Return a lowercase string name for the image backbone:
      - Enum with .name  -> .name.lower()
      - str              -> .lower()
      - anything else    -> str(x).lower()
    """
    try:
        return str(getattr(x, "name")).lower()
    except Exception:
        return str(x).lower()

def _ends_with_3x3(module: nn.Module) -> bool:
    """True if the last Conv2d in `module` has a 3×3 kernel."""
    last_conv = None
    for m in module.modules():
        if isinstance(m, nn.Conv2d):
            last_conv = m
    return (last_conv is not None) and (tuple(getattr(last_conv, "kernel_size", (0, 0))) == (3, 3))

def _insert_post3x3_if_needed(ch: int) -> nn.Sequential:
    """Conv(3×3) + BN + SiLU block."""
    return nn.Sequential(
        nn.Conv2d(ch, ch, kernel_size=3, stride=1, padding=1, bias=False),
        nn.BatchNorm2d(ch),
        nn.SiLU(inplace=True),
    )

# HF-first timm model IDs (we try these first, then the plain timm IDs)
# Backbone key -> candidate timm model IDs, tried in order.
# Annotated with the `Dict`/`List` names the file imports from typing.
TIMM_NAME_MAP: Dict[str, List[str]] = {
    "dpn68":         ["dpn68.mx_in1k", "dpn68"],
    "coatnet0":      ["coatnet_0_rw_224.sw_in1k", "coatnet_0_rw_224"],
    "convnext_nano": ["convnext_nano.in12k_ft_in1k", "convnext_nano"],
    "hrnet32":       ["hrnet_w32.ms_in1k", "hrnet_w32"],
}

def _maybe_create_timm(candidates: _List[str]) -> nn.Module:
    """
    Try to create a timm model using HF variants first, then base IDs.
    We DO NOT pass 'img_size' here (some models don't accept it).
    We request feature-only (num_classes=0, global_pool="").
    """
    if not _HAS_TIMM:
        raise RuntimeError("timm is not installed but a timm backbone was requested.")
    kwargs = dict(num_classes=0, global_pool="")
    last_err = None
    for model_name in candidates:
        # 1) pretrained=True
        try:
            return timm.create_model(model_name, pretrained=True, **kwargs)
        except Exception as e:
            last_err = e
        # 2) pretrained=False
        try:
            return timm.create_model(model_name, pretrained=False, **kwargs)
        except Exception as e:
            last_err = e
    raise RuntimeError(f"Could not create any of timm models {candidates}. Last error: {last_err}")

# ------------------------------------------------------------------
# Image backbones
# ------------------------------------------------------------------

class TimmBackbone(nn.Module):
    """
    timm forward_features → (B,C,H,W) or (B,Tokens,C).
    If 4D: [optional post-3×3] → GAP → Dropout → (B,C)
    If 3D: take token 0 → Dropout → (B,C)

    Parameters
    ----------
    key_name : str
        Key into TIMM_NAME_MAP (case-insensitive).
    img_size : int
        Side length of the square dummy input used to probe the feature shape.
    insert_post3x3 : bool
        For 4D features, add a Conv3×3+BN+SiLU block unless the backbone's
        last Conv2d is already 3×3.
    p_drop : float
        Dropout probability applied to the pooled features.

    Attributes
    ----------
    out_dim : int
        Width of the feature vector returned by forward().

    Raises
    ------
    ValueError
        Unknown `key_name`.
    RuntimeError
        The backbone's features are neither 3D nor 4D.
    """
    def __init__(self, key_name: str, img_size: int, insert_post3x3: bool = True, p_drop: float = 0.3):
        super().__init__()
        key_name = key_name.lower()
        if key_name not in TIMM_NAME_MAP:
            raise ValueError(f"Unknown timm backbone key '{key_name}'. Options: {list(TIMM_NAME_MAP.keys())}")

        # Create the timm model (HF-first)
        self.m = _maybe_create_timm(TIMM_NAME_MAP[key_name])

        # Probe the feature shape in eval mode: a forward pass in training
        # mode would push the all-zero dummy batch into BatchNorm running
        # statistics. The previous mode is restored afterwards.
        was_training = self.m.training
        self.m.eval()
        try:
            with torch.no_grad():
                dummy = torch.zeros(1, 3, int(img_size), int(img_size))
                f = self.m.forward_features(dummy)
        finally:
            self.m.train(was_training)

        self._feat_is_4d = (f.ndim == 4)
        if self._feat_is_4d:
            ch = int(f.shape[1])
            # Register post-3×3 FIRST, so repr() shows the correct order
            self.post3x3 = _insert_post3x3_if_needed(ch) if (insert_post3x3 and not _ends_with_3x3(self.m)) else nn.Identity()
            # Now register GAP and Dropout (classification "head")
            self.gap  = nn.AdaptiveAvgPool2d((1, 1))
            self.drop = nn.Dropout(p=p_drop)
            self.out_dim = ch
        elif f.ndim == 3:
            # Token features: no post-3×3, just Dropout on token 0
            self.post3x3 = nn.Identity()
            self.gap  = nn.Identity()   # not used for 3D features
            self.drop = nn.Dropout(p=p_drop)
            self.out_dim = int(f.shape[-1])
        else:
            raise RuntimeError(f"Unexpected feature shape from timm model: {tuple(f.shape)}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        f = self.m.forward_features(x)
        if self._feat_is_4d:
            f = self.post3x3(f)
            f = self.gap(f).flatten(1)
            f = self.drop(f)
            return f
        # token features: take the token at index 0
        # NOTE(review): assumes index 0 is a class token — confirm for each backbone.
        f = f[:, 0]
        f = self.drop(f)
        return f

class TorchvisionBackbone(nn.Module):
    """
    ResNet18 / DenseNet121 / MobileNetV2 feature extractor followed by
    [optional post-3×3] → GAP → Dropout.

    ImageNet weights are requested first; if that fails for any reason the
    network is built with weights=None. `out_dim` is the width of the
    vector returned by forward(). Any other `which` raises ValueError.
    """
    def __init__(self, which: "NETWORK", p_drop: float = 0.3):
        super().__init__()
        import torchvision.models as tvm

        # Pick the torchvision constructor and weights enum by name
        if which == NETWORK.resnet18:
            factory_name, weights_name = "resnet18", "ResNet18_Weights"
        elif which == NETWORK.densenet121:
            factory_name, weights_name = "densenet121", "DenseNet121_Weights"
        elif which == NETWORK.mobilenet_v2:
            factory_name, weights_name = "mobilenet_v2", "MobileNet_V2_Weights"
        else:
            raise ValueError(which)

        factory = getattr(tvm, factory_name)
        try:
            net = factory(weights=getattr(tvm, weights_name).IMAGENET1K_V1)
        except Exception:
            net = factory(weights=None)

        # Strip the classifier and read the feature width from it
        if which == NETWORK.resnet18:
            # drop the final avgpool + fc children
            self.encoder = nn.Sequential(*list(net.children())[:-2])
            ch = net.fc.in_features
        elif which == NETWORK.densenet121:
            self.encoder = net.features
            ch = net.classifier.in_features
        else:
            self.encoder = net.features
            ch = net.classifier[1].in_features

        # post-3×3 is registered before GAP/Dropout so it prints first
        if _ends_with_3x3(self.encoder):
            self.post3x3 = nn.Identity()
        else:
            self.post3x3 = _insert_post3x3_if_needed(ch)
        # Classification head
        self.gap  = nn.AdaptiveAvgPool2d((1, 1))
        self.drop = nn.Dropout(p=p_drop)
        self.out_dim = int(ch)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        feats = self.post3x3(self.encoder(x))
        pooled = self.gap(feats).flatten(1)
        return self.drop(pooled)

class TorchvisionVGGBackbone(nn.Module):
    """
    VGG11/13/16/19 (BN) `features` followed by
    [optional post-3×3] → GAP → Dropout.

    ImageNet weights are used when the weights enum exists and loading
    succeeds; otherwise the network is built with weights=None. `out_dim`
    is the width of the vector returned by forward(). A `which` outside
    the four VGG members raises KeyError.
    """
    def __init__(self, which: "NETWORK", p_drop: float = 0.3):
        super().__init__()
        import torchvision.models as tvm

        depth_by_member = {
            NETWORK.vgg11: "11",
            NETWORK.vgg13: "13",
            NETWORK.vgg16: "16",
            NETWORK.vgg19: "19",
        }
        depth = depth_by_member[which]
        ctor = getattr(tvm, f"vgg{depth}_bn")
        weights_enum = getattr(tvm, f"VGG{depth}_BN_Weights", None)

        net = None
        if weights_enum is not None:
            try:
                net = ctor(weights=getattr(weights_enum, "IMAGENET1K_V1"))
            except Exception:
                net = None
        if net is None:
            net = ctor(weights=None)

        self.encoder = net.features

        # output width = out_channels of the last Conv2d inside the features
        conv_widths = [mod.out_channels for mod in self.encoder.modules()
                       if isinstance(mod, nn.Conv2d)]
        if not conv_widths:
            raise RuntimeError("VGG out channels not found")
        ch = conv_widths[-1]

        # post-3×3 is registered before GAP/Dropout so it prints first
        if _ends_with_3x3(self.encoder):
            self.post3x3 = nn.Identity()
        else:
            self.post3x3 = _insert_post3x3_if_needed(ch)
        # Classification head
        self.gap  = nn.AdaptiveAvgPool2d((1, 1))
        self.drop = nn.Dropout(p=p_drop)
        self.out_dim = int(ch)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        feats = self.post3x3(self.encoder(x))
        pooled = self.gap(feats).flatten(1)
        return self.drop(pooled)

# ------------------------------------------------------------------
# Text helpers
# ------------------------------------------------------------------
def _resolve_hf_local_dir(local_dir: _Optional[str]) -> _Optional[str]:
    if not local_dir or not os.path.isdir(local_dir):
        return None
    needed = {"config.json", "pytorch_model.bin", "model.safetensors", "rust_model.ot"}
    files = set(os.listdir(local_dir))
    if files & needed:
        return local_dir
    snaps = os.path.join(local_dir, "snapshots")
    refs  = os.path.join(local_dir, "refs", "main")
    if os.path.isdir(snaps):
        if os.path.isfile(refs):
            with open(refs, "r") as f:
                commit = f.read().strip()
            cand = os.path.join(snaps, commit)
            if os.path.isdir(cand):
                return cand
        subdirs = [os.path.join(snaps, d) for d in os.listdir(snaps) if os.path.isdir(os.path.join(snaps, d))]
        if subdirs:
            subdirs.sort(key=lambda p: os.path.getmtime(p), reverse=True)
            return subdirs[0]
    return None

@dataclass
class TextFTPolicy:
    """
    Which parts of the text encoder are trainable.

    mode:
      - "authors_default": freeze embeddings and pooler, train only the
        last encoder layer;
      - "freeze_all": nothing is trainable;
      - "last_n": like "authors_default" but train the last `last_n` layers;
      - "train_all": everything is trainable.
    last_n: number of trailing encoder layers to train (used by "last_n" only).
    """
    mode: Literal["authors_default", "freeze_all", "last_n", "train_all"] = "authors_default"
    last_n: int = 1

def configure_text_finetune(bert: nn.Module, policy: "TextFTPolicy") -> None:
    """
    Set `requires_grad` on the parameters of `bert` according to `policy`.

    Modes (policy.mode):
      - "freeze_all" / "train_all": every parameter is frozen / trainable.
      - "authors_default": embeddings and pooler are frozen; of
        `bert.encoder.layer` only the last layer is trainable.
      - "last_n": same, but the last `policy.last_n` layers are trainable.

    In the two layer-wise modes, parameters outside embeddings, pooler and
    encoder.layer are left untouched. A missing (or None) embeddings,
    pooler or encoder is skipped rather than raising.

    Raises
    ------
    ValueError
        Unknown `policy.mode`.
    """
    def _set(params, flag: bool) -> None:
        for p in params:
            p.requires_grad = flag

    def _train_only_last(n_last: int) -> None:
        # Freeze the non-layer parts that exist on this model
        for part_name in ("embeddings", "pooler"):
            part = getattr(bert, part_name, None)
            if part is not None:
                _set(part.parameters(), False)
        # Tolerate models without `encoder.layer` (treated as having no layers)
        layers = getattr(getattr(bert, "encoder", None), "layer", None)
        if layers is None:
            layers = []
        cutoff = max(0, len(layers) - n_last)
        for i, layer in enumerate(layers):
            _set(layer.parameters(), i >= cutoff)

    mode = policy.mode
    if mode == "authors_default":
        _train_only_last(1)
    elif mode == "freeze_all":
        _set(bert.parameters(), False)
    elif mode == "train_all":
        _set(bert.parameters(), True)
    elif mode == "last_n":
        _train_only_last(policy.last_n)
    else:
        raise ValueError(policy)

def build_text_encoder_from_cfg(cfg) -> Optional[nn.Module]:
    """
    Load the HuggingFace text encoder described by `cfg`.

    Returns None when cfg.modality is neither MOD.txt nor MOD.multimodal.
    Otherwise:
      - if cfg.hf_local_dir resolves (via _resolve_hf_local_dir) to a
        directory, the model is loaded from it with local_files_only=True;
      - else, when offline mode is off, the model is loaded by
        cfg.hf_text_model_name.

    Missing cfg attributes default to: hf_local_dir=None, hf_offline=True,
    hf_trust_remote_code=False.

    Raises
    ------
    RuntimeError
        Offline mode is on and no usable local snapshot was found.
    """
    if cfg.modality not in (MOD.txt, MOD.multimodal):
        return None
    local_dir = getattr(cfg, "hf_local_dir", None)
    offline   = bool(getattr(cfg, "hf_offline", True))
    trust_remote = bool(getattr(cfg, "hf_trust_remote_code", False))

    if local_dir:
        resolved = _resolve_hf_local_dir(local_dir)
        if resolved and os.path.isdir(resolved):
            return AutoModel.from_pretrained(
                resolved,
                local_files_only=True,
                trust_remote_code=trust_remote,
            )
        if offline:
            raise RuntimeError(f"cfg.hf_local_dir='{local_dir}' is not a valid HF snapshot.")
    if offline:
        raise RuntimeError("Offline but no valid local text snapshot found.")
    return AutoModel.from_pretrained(
        cfg.hf_text_model_name,
        trust_remote_code=trust_remote,
    )

# ------------------------------------------------------------------
# Multimodal model
# ------------------------------------------------------------------
class MultimodalNet(nn.Module):
    """
    Image + Text classifier with a shared head.

      - Image: backbone features → img_proj (Linear) → ReLU → mid (Linear) = z_img → cls (Linear)
      - Text : HF encoder pooled output → txt_proj (Linear) → ReLU → mid (Linear) = z_txt → cls (Linear)

    `mid` and `cls` are shared by both branches. Which branches exist is
    decided by cfg.modality.

    Parameters
    ----------
    cfg : RunConfig
        Reads modality, network, hidden_dim, n_classes, img_size, and the
        text-encoder settings used by build_text_encoder_from_cfg.
    text_ft : TextFTPolicy, optional
        Fine-tuning policy for the text encoder. None (the default) means a
        fresh default TextFTPolicy(); a new instance is created per model
        instead of sharing one default object across all constructions.

    Raises
    ------
    ValueError
        Unsupported image backbone.
    RuntimeError
        The text encoder could not be built.
    """
    def __init__(self, cfg: "RunConfig", text_ft: Optional["TextFTPolicy"] = None):
        super().__init__()
        self.modality  = cfg.modality
        self.hidden    = int(cfg.hidden_dim)
        self.n_classes = int(cfg.n_classes)
        self.img_enc   = None
        self.txt_enc   = None

        # ----- Image encoder -----
        img_out = 0
        if cfg.modality in (MOD.img, MOD.multimodal):
            net_name = _net_name_like(cfg.network)
            if net_name in ("resnet18", "densenet121", "mobilenet_v2"):
                self.img_enc = TorchvisionBackbone(cfg.network, p_drop=0.3)
            elif net_name in ("vgg11", "vgg13", "vgg16", "vgg19"):
                self.img_enc = TorchvisionVGGBackbone(cfg.network, p_drop=0.3)
            elif net_name in TIMM_NAME_MAP:
                self.img_enc = TimmBackbone(net_name, img_size=int(cfg.img_size), insert_post3x3=True, p_drop=0.3)
            else:
                raise ValueError(f"Unsupported backbone: {cfg.network} (name='{net_name}')")
            img_out = int(getattr(self.img_enc, "out_dim"))

        # ----- Text encoder -----
        txt_hid = 0
        if cfg.modality in (MOD.txt, MOD.multimodal):
            self.txt_enc = build_text_encoder_from_cfg(cfg)
            if self.txt_enc is None:
                raise RuntimeError("Could not build text encoder from cfg.")
            if text_ft is None:
                text_ft = TextFTPolicy()
            configure_text_finetune(self.txt_enc, text_ft)
            txt_hid = int(self.txt_enc.config.hidden_size)

        # ----- Heads -----
        if self.img_enc is not None:
            self.img_proj = nn.Linear(img_out, self.hidden)
        if self.txt_enc is not None:
            self.txt_proj = nn.Linear(txt_hid, self.hidden)

        self.act  = nn.ReLU(inplace=True)
        self.mid  = nn.Linear(self.hidden, self.hidden)
        self.cls  = nn.Linear(self.hidden, self.n_classes)

    def _head(self, feat: torch.Tensor) -> torch.Tensor:
        """Apply ReLU → mid → cls to an already-projected feature."""
        return self.cls(self.mid(self.act(feat)))

    def forward(self, image=None, text=None):
        """
        Run whichever branches exist AND received an input.

        `text`, when given, is a mapping with "input_ids" and "attention_mask".
        Returns a dict with keys logits_img, z_img, logits_txt, z_txt;
        entries of a branch that did not run are None.
        """
        out = {"logits_img": None, "z_img": None, "logits_txt": None, "z_txt": None}
        if (self.img_enc is not None) and (image is not None):
            f_img = self.img_enc(image)                 # (B, C)
            z_img = self.mid(self.act(self.img_proj(f_img)))
            out["z_img"] = z_img
            out["logits_img"] = self.cls(z_img)
        if (self.txt_enc is not None) and (text is not None):
            enc = self.txt_enc(input_ids=text["input_ids"], attention_mask=text["attention_mask"])
            # Prefer the encoder's pooled output; otherwise use the first token's hidden state
            pooled = enc.pooler_output if getattr(enc, "pooler_output", None) is not None else enc.last_hidden_state[:, 0]
            z_txt = self.mid(self.act(self.txt_proj(pooled)))
            out["z_txt"] = z_txt
            out["logits_txt"] = self.cls(z_txt)
        return out

    # image/text only entrypoints (used by CAM/eval helpers)
    def forward_image(self, image):
        """Return only the image-branch logits (None if there is no image branch)."""
        return self.forward(image=image, text=None)["logits_img"]

    def forward_text(self, text):
        """Return only the text-branch logits (None if there is no text branch)."""
        return self.forward(image=None, text=text)["logits_txt"]

In [None]:
# Losses: CE (image) + CE (text) + cosine(img, text) + supervised NT-Xent(img, text)

@_dataclass2
class LossCfg:
    """Switches and weights read by ``compute_total_loss`` for the alignment terms."""
    # When False, both the cosine and the contrastive terms are skipped.
    use_align_losses: bool = True
    # Multiplier applied to the contrastive (NT-Xent) term only; the cosine term is unweighted.
    align_lambda: float = 0.5
    # Temperature passed to ``sup_nt_xent``.
    contrastive_temperature: float = 0.5

# Module-level loss objects reused by compute_total_loss; both average over the batch.
_ce = nn.CrossEntropyLoss(reduction="mean")
_cos = nn.CosineEmbeddingLoss(reduction="mean")

def onehot_to_idx(y_onehot: torch.Tensor) -> torch.Tensor:
    """Collapse one-hot rows into class indices by taking the argmax along dim 1."""
    return y_onehot.argmax(dim=1)

def sup_nt_xent(z1: torch.Tensor, z2: torch.Tensor, temperature: float) -> torch.Tensor:
    """Symmetric NT-Xent loss between two batches of paired embeddings.

    Row ``i`` of ``z1`` and row ``i`` of ``z2`` form the positive pair; every other
    row in the batch acts as a negative. Class labels are not used, despite the
    ``sup_`` prefix in the name.

    Args:
        z1: first embedding batch, one row per sample.
        z2: second embedding batch, same shape as ``z1``.
        temperature: positive scaling applied to the cosine similarities.

    Returns:
        Scalar tensor: the mean of the z1→z2 and z2→z1 cross-entropy terms.

    Raises:
        ValueError: if ``temperature`` is not strictly positive or the shapes differ.
    """
    # A zero or negative temperature would produce inf/NaN or a sign-flipped objective.
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")
    if z1.shape != z2.shape:
        raise ValueError(f"z1 and z2 must have the same shape, got {tuple(z1.shape)} and {tuple(z2.shape)}")

    z1 = F.normalize(z1, dim=-1)
    z2 = F.normalize(z2, dim=-1)
    logits_12 = (z1 @ z2.t()) / temperature
    # z2 @ z1.T is the transpose of z1 @ z2.T, so reuse the first product.
    logits_21 = logits_12.t()
    targets = torch.arange(z1.size(0), device=z1.device)
    return 0.5 * (F.cross_entropy(logits_12, targets) + F.cross_entropy(logits_21, targets))

def compute_total_loss(out: Dict[str, Optional[torch.Tensor]], batch: Dict[str, torch.Tensor],
                       cfg: RunConfig, lcfg: LossCfg) -> Tuple[torch.Tensor, Dict[str, float]]:
    """Combine the classification and alignment losses for one batch.

    Args:
        out: model output dict with keys ``logits_img``, ``logits_txt``, ``z_img``, ``z_txt``
            (any of which may be ``None``).
        batch: collated batch; only ``batch["y_onehot"]`` is read.
        cfg: run configuration. Not read by this function; kept for call compatibility.
        lcfg: loss switches/weights (``use_align_losses``, ``align_lambda``,
            ``contrastive_temperature``).

    Returns:
        ``(total, scalars)`` where ``total`` is the loss tensor to back-propagate and
        ``scalars`` holds Python floats for ``ce_img``, ``ce_txt``, ``cosine``,
        ``contrastive`` and ``total`` (0.0 for any term that was not computed).
    """
    # Use the device of the first tensor the model produced; fall back to the labels' device.
    dev = next((v.device for v in out.values() if isinstance(v, torch.Tensor)), None)
    if dev is None:
        dev = batch["y_onehot"].device

    y_idx = onehot_to_idx(batch["y_onehot"].to(dev))
    total = torch.zeros((), device=dev)
    scalars = {"ce_img": 0.0, "ce_txt": 0.0, "cosine": 0.0, "contrastive": 0.0}

    logits_img, logits_txt = out["logits_img"], out["logits_txt"]
    if logits_img is not None:
        ce_img = _ce(logits_img, y_idx)
        total = total + ce_img
        scalars["ce_img"] = float(ce_img.item())
    if logits_txt is not None:
        ce_txt = _ce(logits_txt, y_idx)
        total = total + ce_txt
        scalars["ce_txt"] = float(ce_txt.item())

    z_img, z_txt = out["z_img"], out["z_txt"]
    if lcfg.use_align_losses and (z_img is not None) and (z_txt is not None):
        # Target +1 for every pair: pull matching image/text embeddings together.
        target = torch.ones(z_img.size(0), device=dev)
        cos = _cos(z_img, z_txt, target)
        con = sup_nt_xent(z_img, z_txt, lcfg.contrastive_temperature)
        # Only the contrastive term is scaled by align_lambda; the cosine term enters with weight 1.
        total = total + cos + lcfg.align_lambda * con
        scalars["cosine"] = float(cos.item())
        scalars["contrastive"] = float(con.item())

    scalars["total"] = float(total.item())
    return total, scalars

In [None]:
# Metrics

@torch.no_grad()
def logits_to_probs(logits: torch.Tensor) -> np.ndarray:
    """Apply softmax over dim 1 and return the result as a NumPy array on the host."""
    probs = torch.softmax(logits, dim=1)
    return probs.cpu().numpy()

@torch.no_grad()
def evaluate_logits(logits: torch.Tensor, y_onehot: torch.Tensor, class_names: Tuple[str,...]=("normal","tb")) -> Dict[str, object]:
    """Summarise a batch of logits against one-hot labels.

    Args:
        logits: raw scores, one row per sample and one column per class.
        y_onehot: one-hot labels with the same row count.
        class_names: accepted for call compatibility; not read here.

    Returns:
        Dict with two confusion matrices, an (empty) report string, and the macro and
        per-class ROC AUC. AUC entries are ``None`` when they cannot be computed.
    """
    probs = logits_to_probs(logits)
    y_true_idx = torch.argmax(y_onehot, dim=1).cpu().numpy()
    y_pred_idx = np.argmax(probs, axis=1)
    n_classes = probs.shape[1]

    # NOTE(review): sklearn's confusion_matrix puts TRUE labels on rows and PREDICTED
    # labels on columns, which is the opposite of what the result key below suggests.
    # The value is left untouched so existing consumers keep working; confirm intent.
    cm_std = confusion_matrix(y_true_idx, y_pred_idx, labels=[0,1])
    TP = int(((y_true_idx==1)&(y_pred_idx==1)).sum())
    FP = int(((y_true_idx==0)&(y_pred_idx==1)).sum())
    FN = int(((y_true_idx==1)&(y_pred_idx==0)).sum())
    TN = int(((y_true_idx==0)&(y_pred_idx==0)).sum())
    cm_clin = np.array([[TP, FP], [FN, TN]], dtype=int)

    try:
        if n_classes == 2:
            # Binary targets need a 1-D score vector; an (N, 2) matrix makes
            # roc_auc_score raise ValueError, which previously forced AUC to None.
            auc_macro = roc_auc_score(y_true_idx, probs[:, 1])
            # With two classes the one-vs-rest AUC is identical for both classes.
            auc_per_class = [auc_macro, auc_macro]
        else:
            auc_macro = roc_auc_score(y_true_idx, probs, multi_class="ovr", average="macro")
            auc_per_class = roc_auc_score(y_true_idx, probs, multi_class="ovr", average=None)
    except ValueError:
        # E.g. only one class present in y_true: AUC is undefined.
        auc_macro, auc_per_class = float("nan"), [float("nan")] * n_classes

    report = ""  # optional
    return {
        "confusion_pred_rows_true_cols": cm_std,
        "confusion_clinician_TPFP_FN_TN": cm_clin,
        "report": report,
        "auc_macro": None if math.isnan(auc_macro) else float(auc_macro),
        "auc_per_class": [None if math.isnan(a) else float(a) for a in auc_per_class],
    }

@torch.no_grad()
def evaluate_dataloader_classifier(model: nn.Module, loader: DataLoader, modality: MOD) -> Dict[str, float]:
    """Run ``model`` over ``loader`` and compute binary classification metrics.

    The image logits are used when the model returns them, otherwise the text logits.
    Class index 1 is treated as the positive class.

    Args:
        model: network returning a dict with ``logits_img`` / ``logits_txt``.
        loader: yields dict batches with ``y_onehot`` and optionally ``image`` / ``text``.
        modality: accepted for call compatibility; not read here.

    Returns:
        Dict with the confusion counts (TP/FP/FN/TN) and derived metrics as floats.
        ``ROC_AUC`` is NaN when it cannot be computed.

    Raises:
        RuntimeError: if the model returns neither image nor text logits.
    """
    device = next(model.parameters()).device
    # Evaluate in eval mode, then restore whatever mode the caller had.
    was_training = model.training
    model.eval()
    y_true, y_pred, y_prob = [], [], []
    try:
        for batch in loader:
            if batch["y_onehot"].numel() == 0:
                continue  # the collate function emits an empty batch when every sample was dropped
            img, txt = batch.get("image"), batch.get("text")
            if img is not None:
                img = img.to(device, non_blocking=True)
            if txt is not None:
                txt = {k: v.to(device, non_blocking=True) for k, v in txt.items()}
            out = model(img, txt)
            # Use image head for classification when available, else fallback to text
            logits = out["logits_img"] if out["logits_img"] is not None else out["logits_txt"]
            if logits is None:
                raise RuntimeError("Model returned neither 'logits_img' nor 'logits_txt'.")
            probs1 = F.softmax(logits.float(), dim=1)[:, 1].detach().cpu().numpy()
            preds = logits.argmax(dim=1).detach().cpu().numpy()
            y_idx = torch.argmax(batch["y_onehot"], dim=1).detach().cpu().numpy()
            y_true.extend(y_idx.tolist())
            y_pred.extend(preds.tolist())
            y_prob.extend(probs1.tolist())
    finally:
        model.train(was_training)

    y_true = np.asarray(y_true, dtype=int)
    y_pred = np.asarray(y_pred, dtype=int)
    y_prob = np.asarray(y_prob, dtype=float)
    TP = int(((y_true==1)&(y_pred==1)).sum())
    FP = int(((y_true==0)&(y_pred==1)).sum())
    FN = int(((y_true==1)&(y_pred==0)).sum())
    TN = int(((y_true==0)&(y_pred==0)).sum())

    has_samples = bool(y_true.size)
    bal_acc = balanced_accuracy_score(y_true, y_pred) if has_samples else 0.0
    sens = recall_score(y_true, y_pred, pos_label=1) if has_samples else 0.0
    spec = TN / (TN + FP) if (TN + FP) else 0.0
    prec = precision_score(y_true, y_pred, pos_label=1) if (TP + FP) else 0.0
    npv = TN / (TN + FN) if (TN + FN) else 0.0
    f1s = f1_score(y_true, y_pred, pos_label=1) if has_samples else 0.0
    youden = sens + spec - 1.0
    mcc = matthews_corrcoef(y_true, y_pred) if has_samples else 0.0
    kappa = cohen_kappa_score(y_true, y_pred) if has_samples else 0.0
    try:
        roc_auc = roc_auc_score(y_true, y_prob)
    except ValueError:
        # Undefined (e.g. empty input or a single class in y_true).
        roc_auc = float("nan")

    return {
        "TP": TP, "FP": FP, "FN": FN, "TN": TN,
        "balanced_accuracy": float(bal_acc), "sensitivity": float(sens), "specificity": float(spec),
        "precision": float(prec), "NPV": float(npv), "F1_score": float(f1s), "Youden_J": float(youden),
        "MCC": float(mcc), "Cohen_Kappa": float(kappa), "ROC_AUC": float(roc_auc)
    }

In [None]:
# Train loop (save best by highest validation MCC)

def _ckpt_stem(cfg: "RunConfig") -> str:
    """Return the ``<network>_<modality>`` stem built from the enum member names."""
    return "{}_{}".format(cfg.network.name, cfg.modality.name)

def _ckpt_paths(cfg: "RunConfig") -> Dict[str, str]:
    """Map each artifact key to its path under ``cfg.out_dir`` (no _dimension_### suffix)."""
    stem = _ckpt_stem(cfg)
    filenames = {
        "best_pt": f"best_{stem}_val_loss.pt",
        "curves_png": f"best_{stem}_val_loss_curves.png",
        "curves_csv": f"best_{stem}_val_loss_history.csv",
        "curves_png_full": f"best_{stem}_val_loss_curves_full.png",
        "curves_csv_full": f"best_{stem}_val_loss_history_full.csv",
        "log_json": f"best_{stem}_val_loss_log.json",
        "mcc_png": f"best_{stem}_val_mcc_curve.png",
    }
    return {key: os.path.join(cfg.out_dir, fname) for key, fname in filenames.items()}

# Public alias kept for code that calls ckpt_paths(...) without the underscore.
ckpt_paths = _ckpt_paths

# utils & printing

def _count_params(m: nn.Module) -> Tuple[int, int, int]:
    """Return ``(total, trainable, frozen)`` parameter element counts for ``m``."""
    total = 0
    trainable = 0
    for param in m.parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count
    return total, trainable, total - trainable

def _fmt_params(n: int) -> str:
    """Format a count compactly: ``x.xxM`` from one million, ``x.xxk`` from one thousand, else the plain number."""
    if n >= 1e6:
        return f"{n/1e6:.2f}M"
    if n >= 1e3:
        return f"{n/1e3:.2f}k"
    return str(n)

def _print_model_overview_full(model: nn.Module, cfg: "RunConfig") -> None:
    """Print the full module tree, key config values, encoder widths and parameter counts."""
    print("\n========== Model: FULL ARCHITECTURE ==========")
    print(model)
    print("==============================================")
    print(f"Modality   : {cfg.modality.name}")
    print(f"Backbone   : {cfg.network.name}")
    print(f"Hidden dim : {cfg.hidden_dim} | Classes: {cfg.n_classes}")

    # Encoders may be absent or set to None, depending on how the model was built.
    image_encoder = getattr(model, "img_enc", None)
    if image_encoder is not None:
        img_out = getattr(image_encoder, "out_dim", None)
        print(f"[Image] out_dim={img_out}")

    text_encoder = getattr(model, "txt_enc", None)
    if text_encoder is not None:
        text_config = getattr(text_encoder, "config", None)
        hid = getattr(text_config, "hidden_size", None)
        print(f"[Text ] hidden_size={hid}")

    tot, trn, frz = _count_params(model)
    print(f"Parameters : total={_fmt_params(tot)} | trainable={_fmt_params(trn)} | frozen={_fmt_params(frz)}")
    print("==============================================\n", flush=True)

def _save_curves_png_csv(history: Dict[str, List[float]], out_png: str, out_csv: str) -> None:
    """Write the training history to CSV and plot train/validation loss to a PNG.

    Args:
        history: column name -> per-epoch values. The loss curves are drawn only when
            both ``train_total`` and ``val_total`` are present (``epoch`` is the x axis).
        out_png: destination of the figure.
        out_csv: destination of the CSV dump of ``history``.
    """
    for target in (out_png, out_csv):
        parent = os.path.dirname(target)
        # A bare filename has an empty dirname, and os.makedirs("") raises.
        if parent:
            os.makedirs(parent, exist_ok=True)

    df = pd.DataFrame(history)
    df.to_csv(out_csv, index=False)

    fig, ax = plt.subplots(figsize=(8,6))
    try:
        has_curves = "train_total" in df and "val_total" in df
        if has_curves:
            ax.plot(df["epoch"], df["train_total"], label="train_total", linewidth=2)
            ax.plot(df["epoch"], df["val_total"],   label="val_total",   linewidth=2, linestyle="--")
        ax.set_xlabel("Epoch"); ax.set_ylabel("Loss"); ax.set_title("Training / Validation Loss")
        if has_curves:
            # legend() with no labelled artists only emits a warning, so skip it.
            ax.legend(loc="upper right")
        ax.grid(True, alpha=0.3)
        fig.tight_layout()
        fig.savefig(out_png, dpi=200)
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)

def _save_mcc_curve(history: Dict[str, List[float]], out_png: str) -> None:
    """Plot validation MCC per epoch to ``out_png`` (best effort).

    Plotting problems never abort the caller, but they are reported with a
    ``[WARN]`` line instead of being swallowed silently.

    Args:
        history: must contain the ``epoch`` and ``val_MCC`` series.
        out_png: destination of the figure.
    """
    missing = [key for key in ("epoch", "val_MCC") if key not in history]
    if missing:
        print(f"[WARN] MCC curve not saved: history is missing {missing}.", flush=True)
        return

    fig = None
    try:
        fig, ax = plt.subplots(figsize=(8,5))
        ax.plot(history["epoch"], history["val_MCC"], linewidth=2)
        ax.set_xlabel("Epoch"); ax.set_ylabel("Validation MCC"); ax.set_title("Validation MCC")
        ax.grid(True, alpha=0.3)
        fig.tight_layout()
        fig.savefig(out_png, dpi=200)
    except Exception as exc:  # deliberate best-effort boundary: report, do not raise
        print(f"[WARN] Could not save MCC curve to {out_png} ({exc}).", flush=True)
    finally:
        if fig is not None:
            plt.close(fig)

# ------------------ collate for multimodal ------------------

def make_mm_collate(modality: "MOD", n_classes: int) -> Callable:
    """Build a collate function that keeps only samples with image, text and label.

    Args:
        modality: accepted for call compatibility; not read here.
        n_classes: width of the empty ``y_onehot`` placeholder returned when every
            sample in a batch is dropped.

    Returns:
        A callable mapping a list of sample dicts to a batch dict with ``image``,
        ``text`` (``input_ids`` / ``attention_mask``) and ``y_onehot``, each stacked
        along dim 0. If no sample survives, only an empty ``y_onehot`` of shape
        ``(0, n_classes)`` is returned.
    """
    def _is_complete(sample: Dict) -> bool:
        text = sample.get("text")
        return (
            sample.get("image") is not None
            and text is not None
            and text.get("input_ids") is not None
            and sample.get("y_onehot") is not None
        )

    def _collate(batch: List[Dict]) -> Dict[str, torch.Tensor]:
        kept = [sample for sample in batch if _is_complete(sample)]
        if not kept:
            return {"y_onehot": torch.zeros(0, n_classes, dtype=torch.float32)}
        return {
            "image": torch.stack([s["image"] for s in kept], dim=0),
            "text": {
                "input_ids": torch.stack([s["text"]["input_ids"] for s in kept], dim=0),
                "attention_mask": torch.stack([s["text"]["attention_mask"] for s in kept], dim=0),
            },
            "y_onehot": torch.stack([s["y_onehot"] for s in kept], dim=0),
        }

    return _collate

def wrap_loader_with_collate(loader: DataLoader, modality: "MOD", n_classes: int) -> DataLoader:
    """Rebuild ``loader`` over the same dataset with the multimodal collate function.

    The existing sampling strategy is reused, so sample order and batching stay the
    same as in the input loader. Worker seeding hooks (``worker_init_fn``,
    ``generator``) and ``timeout`` are carried over as well.

    Args:
        loader: source DataLoader whose dataset and sampling settings are reused.
        modality: forwarded to ``make_mm_collate``.
        n_classes: forwarded to ``make_mm_collate``.

    Returns:
        A new DataLoader that differs from ``loader`` only in its ``collate_fn``.
    """
    shared = dict(
        num_workers=loader.num_workers,
        pin_memory=loader.pin_memory,
        persistent_workers=getattr(loader, "persistent_workers", False),
        prefetch_factor=getattr(loader, "prefetch_factor", None),
        timeout=getattr(loader, "timeout", 0),
        worker_init_fn=getattr(loader, "worker_init_fn", None),
        generator=getattr(loader, "generator", None),
        collate_fn=make_mm_collate(modality, n_classes),
    )

    # A loader built with a custom batch_sampler reports batch_size=None. Passing that
    # through would turn automatic batching off, so hand the batch_sampler over instead
    # (it is mutually exclusive with batch_size/sampler/shuffle/drop_last).
    if loader.batch_size is None and getattr(loader, "batch_sampler", None) is not None:
        return DataLoader(loader.dataset, batch_sampler=loader.batch_sampler, **shared)

    # A DataLoader always exposes a sampler, so it is reused as-is and shuffle stays off.
    return DataLoader(
        loader.dataset,
        batch_size=loader.batch_size,
        sampler=loader.sampler,
        shuffle=False,
        drop_last=loader.drop_last,
        **shared,
    )

# ------------------ runtime cfg ------------------

@dataclass
class TrainRuntimeCfg:
    """Runtime knobs for the training loop (read by ``run_epoch`` and the fit function)."""
    # Enable autocast; run_epoch only honours it when the model lives on a CUDA device.
    amp: bool = True
    # Consecutive non-improving epochs tolerated before early stopping.
    patience: int = 10
    # Max gradient norm for clipping; None disables clipping.
    grad_clip_norm: Optional[float] = 1.0
    # dtype handed to torch.autocast.
    amp_dtype: torch.dtype = torch.float16
    # NOTE(review): not referenced in the visible part of the file — confirm it is still needed.
    log_every: int = 1

# ------------------ one epoch ------------------

def _make_grad_scaler():
    """Build a CUDA GradScaler, preferring ``torch.amp`` over the older ``torch.cuda.amp``."""
    if hasattr(torch, "amp") and hasattr(torch.amp, "GradScaler"):
        return torch.amp.GradScaler("cuda")
    if hasattr(torch.cuda, "amp"):
        return torch.cuda.amp.GradScaler()
    return None


def run_epoch(model: nn.Module, loader: DataLoader, cfg: "RunConfig", lcfg: "LossCfg",
              *, train: bool, optim: Optional[torch.optim.Optimizer], trcfg: TrainRuntimeCfg) -> Dict[str, float]:
    """Run one pass over ``loader`` and return sample-weighted mean losses.

    Args:
        model: multimodal network called as ``model(image, text)``.
        loader: yields dict batches with ``image``, ``text`` and ``y_onehot``.
        cfg: run configuration, forwarded to ``compute_total_loss``.
        lcfg: loss configuration, forwarded to ``compute_total_loss``.
        train: True to back-propagate and step the optimizer, False to only evaluate.
        optim: optimizer; required when ``train`` is True.
        trcfg: AMP and gradient-clipping settings.

    Returns:
        Dict with the per-sample mean of ``total``, ``ce_img``, ``ce_txt``, ``cosine``
        and ``contrastive`` (all 0.0 when no sample was processed).

    Raises:
        ValueError: if ``train`` is True but no optimizer was supplied.
    """
    # Fail before any compute instead of with an AttributeError after the first forward pass.
    if train and optim is None:
        raise ValueError("run_epoch(train=True) requires an optimizer, got optim=None.")

    device = next(model.parameters()).device
    use_amp = trcfg.amp and (device.type == "cuda")
    model.train(mode=train)
    sums = {"total": 0.0, "ce_img": 0.0, "ce_txt": 0.0, "cosine": 0.0, "contrastive": 0.0}
    n_samples = 0

    # NOTE: a fresh scaler is created on every call, so its loss-scale state
    # is not carried over from one epoch to the next.
    scaler = _make_grad_scaler() if (train and use_amp) else None

    for batch in loader:
        if batch["y_onehot"].numel() == 0:
            continue  # empty placeholder batch from the collate function
        img = batch["image"].to(device, non_blocking=True)
        txt = {k: v.to(device, non_blocking=True) for k, v in batch["text"].items()}
        bsz = batch["y_onehot"].size(0)
        n_samples += bsz

        autocast_ctx = (torch.autocast("cuda", dtype=trcfg.amp_dtype) if use_amp else nullcontext())
        with torch.set_grad_enabled(train), autocast_ctx:
            out = model(img, txt)
            loss, scalars = compute_total_loss(out, batch, cfg, lcfg)

        if train:
            optim.zero_grad(set_to_none=True)
            if scaler is not None:
                scaler.scale(loss).backward()
                if trcfg.grad_clip_norm is not None:
                    # Gradients must be unscaled before their norm is measured.
                    scaler.unscale_(optim)
                    nn.utils.clip_grad_norm_(model.parameters(), trcfg.grad_clip_norm)
                scaler.step(optim)
                scaler.update()
            else:
                loss.backward()
                if trcfg.grad_clip_norm is not None:
                    nn.utils.clip_grad_norm_(model.parameters(), trcfg.grad_clip_norm)
                optim.step()

        # Weight each batch mean by its size so the epoch value is a per-sample mean.
        for k in sums:
            sums[k] += float(scalars.get(k, 0.0)) * bsz

    eps = 1e-12  # avoids division by zero when the loader produced no samples
    return {k: (v / max(n_samples, eps)) for k, v in sums.items()}

# ------------------ fit (select best by validation MCC) ------------------

def fit_with_val_loss_checkpointing(
    model: nn.Module,
    train_loader: DataLoader,
    valid_loader: DataLoader,
    cfg: "RunConfig",
    lcfg: Optional["LossCfg"] = None,
    trcfg: Optional["TrainRuntimeCfg"] = None,
) -> Dict[str, object]:
    """Train the multimodal model and keep the checkpoint with the best validation MCC.

    Despite the name (and the ``val_loss`` token in the artifact filenames), model
    selection and early stopping are driven by validation MCC.

    Args:
        model: multimodal network to train; the best weights are reloaded at the end.
        train_loader: training loader; re-wrapped with the multimodal collate function.
        valid_loader: validation loader; re-wrapped the same way.
        cfg: run configuration (epochs, lr, weight decay, out_dir, ...).
        lcfg: loss configuration; built from ``cfg`` when omitted.
        trcfg: runtime settings; a fresh ``TrainRuntimeCfg()`` is used when omitted.

    Returns:
        Dict with ``best_val_MCC``, ``best_epoch``, ``improvements``, ``early_stopped``,
        the last epoch's validation loss and MCC, the artifact ``paths`` and ``history``.

    Raises:
        ValueError: if ``cfg.modality`` is not ``MOD.multimodal``.
    """
    # Explicit raise instead of assert so the check survives `python -O`.
    if cfg.modality != MOD.multimodal:
        raise ValueError("This function trains the MULTIMODAL model.")
    # Build per call rather than sharing one mutable default instance across calls.
    if trcfg is None:
        trcfg = TrainRuntimeCfg()
    if lcfg is None:
        lcfg = LossCfg(use_align_losses=cfg.use_align_losses,
                       align_lambda=cfg.align_lambda,
                       contrastive_temperature=cfg.contrastive_temperature)

    train_loader = wrap_loader_with_collate(train_loader, cfg.modality, cfg.n_classes)
    valid_loader = wrap_loader_with_collate(valid_loader, cfg.modality, cfg.n_classes)

    _print_model_overview_full(model, cfg)
    # Report the values the loss will actually use (they differ from cfg when lcfg is passed in).
    print(f"[ALIGN] λ={lcfg.align_lambda}  |  τ={lcfg.contrastive_temperature}", flush=True)

    os.makedirs(cfg.out_dir, exist_ok=True)
    paths = _ckpt_paths(cfg)

    optim = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    history = {"epoch": [], "train_total": [], "val_total": [], "val_MCC": []}

    best_mcc = -1.0
    best_epoch = 0  # 0 means "no checkpoint saved yet"
    bad_epochs = 0
    improvements = 0
    stopped_early = False
    last_va_total = None
    last_va_mcc = None

    for ep in range(1, int(cfg.epochs) + 1):
        tr = run_epoch(model, train_loader, cfg, lcfg, train=True,  optim=optim, trcfg=trcfg)
        va = run_epoch(model, valid_loader, cfg, lcfg, train=False, optim=None, trcfg=trcfg)
        last_va_total = va["total"]

        val_metrics = evaluate_dataloader_classifier(model, valid_loader, cfg.modality)
        last_va_mcc = float(val_metrics["MCC"])

        # MCC can legitimately be negative, so "nothing saved yet" is tracked via best_epoch.
        best_str = ("-∞" if best_epoch == 0 else f"{best_mcc:.4f}")
        streak_str = f"{bad_epochs}/{trcfg.patience}"
        print(f"[{ep:03d}/{cfg.epochs} | best_val_MCC={best_str} | no_improve={streak_str}] "
              f"Train: total={tr['total']:.4f}  |  Valid: total={va['total']:.4f}  MCC={last_va_mcc:.4f}",
              flush=True)

        history["epoch"].append(ep)
        history["train_total"].append(tr["total"])
        history["val_total"].append(va["total"])
        history["val_MCC"].append(last_va_mcc)

        # Always checkpoint the first evaluated epoch so a restorable file exists
        # even when the MCC equals the -1.0 starting value.
        if best_epoch == 0 or last_va_mcc > best_mcc + 1e-8:
            print(f"  ↑ val_MCC improved: {best_mcc:.6f} → {last_va_mcc:.6f}  (saving checkpoint; reset no_improve=0)", flush=True)
            best_mcc = last_va_mcc; best_epoch = ep; bad_epochs = 0; improvements += 1
            torch.save(model.state_dict(), paths["best_pt"])
            print(f"  ↳ saved: {paths['best_pt']}", flush=True)
            with open(paths["log_json"], "w") as f:
                json.dump({"best_val_MCC": float(best_mcc), "best_epoch": int(best_epoch)}, f, indent=2)
        else:
            bad_epochs += 1
            print(f"  ↳ no MCC improvement (no_improve={bad_epochs}/{trcfg.patience})", flush=True)
            if bad_epochs >= trcfg.patience:
                print(f"EARLY STOP: validation MCC did not improve for {bad_epochs} consecutive epochs "
                      f"(patience={trcfg.patience}). Stopping at epoch {ep}.",
                      flush=True)
                stopped_early = True
                break

    # Both calls receive the same history; the "_full" files are written for
    # compatibility with the existing artifact layout.
    _save_curves_png_csv(history, paths["curves_png"], paths["curves_csv"])
    _save_curves_png_csv(history, paths["curves_png_full"], paths["curves_csv_full"])
    _save_mcc_curve(history, paths["mcc_png"])

    # Best effort: leave the model at its last-epoch weights if the reload fails.
    try:
        state = torch.load(paths["best_pt"], map_location="cpu")
        model.load_state_dict(state, strict=True)
        print(f"Restored best (by MCC) from {paths['best_pt']}  |  best_epoch={best_epoch}  best_val_MCC={best_mcc:.6f}",
              flush=True)
    except Exception as e:
        print(f"[WARN] Could not reload best checkpoint ({e}).", flush=True)

    return {
        "best_val_MCC": float(best_mcc),
        "best_epoch": int(best_epoch),
        "improvements": int(improvements),
        "early_stopped": bool(stopped_early),
        "last_epoch_val_total": float(last_va_total) if last_va_total is not None else None,
        "last_epoch_val_MCC": float(last_va_mcc) if last_va_mcc is not None else None,
        "paths": paths,
        "history": history,
    }

In [None]:
# RUN: Multimodal-only grid over λ and τ (saves best per run by MCC)

# Base config (MULTIMODAL ONLY). The grid search below copies this config per run and
# overrides out_dir, align_lambda and contrastive_temperature.
cfg = RunConfig(
    modality=MOD.multimodal,
    network=NETWORK.vgg11,
    n_classes=2,
    hidden_dim=256,
    epochs=64,
    lr=5e-5,
    weight_decay=1e-4,
    img_size=224,
    batch_size=64,
    num_workers=2,
    pin_memory=True,
    dataset_root="/dataset",
    csv_train="/label_train.csv",
    csv_valid="/label_valid.csv",
    csv_test ="/label_test.csv",
    images_subdir="images",
    reports_subdir="reports", # path to raw or structured reports, one per training and validation image
    out_dir="/models_str",
    class_names=("normal","tb"),
    use_align_losses=True,
    hf_text_model_name="microsoft/BiomedVLP-CXR-BERT-general",
    hf_local_dir="/huggingface/hub/models--microsoft--BiomedVLP-CXR-BERT-general",
    hf_offline=True
)
# The tokenizer is read from the same local directory as the text model.
cfg.hf_tokenizer_local_dir = cfg.hf_local_dir
# Fail early if the local Hugging Face directories are missing.
assert os.path.isdir(cfg.hf_local_dir), "hf_local_dir not found"
if cfg.hf_tokenizer_local_dir is not None:
    assert os.path.isdir(cfg.hf_tokenizer_local_dir), "hf_tokenizer_local_dir not found"

# Load CSVs → DataFrames
def _load_split(csv_path: str, cfg) -> pd.DataFrame:
    """Load a header-less ``filename,label`` CSV and derive labels and file paths.

    Args:
        csv_path: CSV with two columns and no header row: filename, integer label.
        cfg: config providing ``n_classes``, ``dataset_root``, ``images_subdir`` and
            ``reports_subdir`` (``images_dir`` / ``reports_dir`` take precedence when set).

    Returns:
        DataFrame with columns ``filename``, ``label``, ``y_index``, ``y_onehot``
        (float32 vector per row), ``image_path`` and ``report_path`` (the image's
        file name with a ``.txt`` suffix inside the reports directory).

    Raises:
        ValueError: if any label falls outside ``[0, n_classes)``.
    """
    df = pd.read_csv(csv_path, header=None, names=["filename", "label"])
    df["filename"] = df["filename"].astype(str)
    df["y_index"] = df["label"].astype(int)

    K = int(cfg.n_classes)
    # Negative labels would silently wrap around (eye[-1] is the LAST class) and
    # labels >= K would raise an opaque IndexError, so reject both up front.
    invalid = df.loc[(df["y_index"] < 0) | (df["y_index"] >= K), "y_index"]
    if not invalid.empty:
        raise ValueError(
            f"{csv_path}: labels must be in [0, {K}); found {sorted(set(invalid.tolist()))}"
        )
    eye = np.eye(K, dtype=np.float32)
    df["y_onehot"] = df["y_index"].apply(lambda i: eye[int(i)])

    img_dir = getattr(cfg, "images_dir", os.path.join(cfg.dataset_root, cfg.images_subdir))
    rep_dir = getattr(cfg, "reports_dir", os.path.join(cfg.dataset_root, cfg.reports_subdir))
    df["image_path"] = df["filename"].apply(lambda fn: os.path.join(img_dir, fn))
    df["report_path"] = df["filename"].apply(lambda fn: os.path.join(rep_dir, Path(fn).with_suffix(".txt").name))
    return df

# Build one DataFrame per split from the label CSVs.
df_train = _load_split(cfg.csv_train, cfg)
df_valid = _load_split(cfg.csv_valid, cfg)
df_test = _load_split(cfg.csv_test,  cfg)

# Tokenizer + DataLoaders (built once and shared by every grid run)
tokenizer = build_tokenizer(cfg)
train_loader, valid_loader, test_loader = make_loaders(cfg, df_train, df_valid, df_test, tokenizer=tokenizer)
assert train_loader is not None and valid_loader is not None

# Grid of (λ, τ) and training: 9 × 5 = 45 runs, one freshly initialised model per pair.
lambda_grid = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
tau_grid    = [0.05, 0.06, 0.07, 0.08, 0.09]

# Best run so far, chosen by VALIDATION MCC; the test metrics are stored for reporting only.
overall_best = {"val_MCC": -1.0, "lam": None, "tau": None, "ckpt": None, "subdir": None,
                "metrics_valid": None, "metrics_test": None}

for lam in lambda_grid:
    for tau in tau_grid:
        # Each (λ, τ) pair writes its artifacts into its own sub-directory of cfg.out_dir.
        subdir = f"mm_grid_lam{lam:.2f}_tau{tau:.2f}"
        run_cfg = RunConfig(
            **{**cfg.__dict__,
               "out_dir": os.path.join(cfg.out_dir, subdir),
               "align_lambda": lam,
               "contrastive_temperature": tau}
        )

        print("\n" + "="*80)
        print(f"[RUN] Multimodal training with λ={lam:.2f}, τ={tau:.2f}")
        print(f"→ artifacts dir: {run_cfg.out_dir}")
        print("="*80, flush=True)

        os.makedirs(run_cfg.out_dir, exist_ok=True)
        # NOTE(review): `device` is not assigned in this cell; presumably it is defined
        # by an earlier cell — confirm before running this cell on its own.
        model = MultimodalNet(run_cfg).to(device)
        # The overview is printed again inside fit_with_val_loss_checkpointing.
        try:
            _print_model_overview_full(model, run_cfg)
        except NameError:
            pass

        trcfg = TrainRuntimeCfg(amp=True, patience=10, grad_clip_norm=1.0, amp_dtype=torch.float16)
        out = fit_with_val_loss_checkpointing(model, train_loader, valid_loader, run_cfg, trcfg=trcfg)

        # fit_with_val_loss_checkpointing reloads the best checkpoint into `model` before
        # returning, so the metrics below are expected to describe the best-by-MCC weights.
        paths = out["paths"]
        valid_metrics = evaluate_dataloader_classifier(
            model,
            wrap_loader_with_collate(valid_loader, run_cfg.modality, run_cfg.n_classes),
            run_cfg.modality
        )
        test_metrics = evaluate_dataloader_classifier(
            model,
            wrap_loader_with_collate(test_loader, run_cfg.modality, run_cfg.n_classes),
            run_cfg.modality
        )

        print("\n[BEST CHECKPOINT (by MCC)]")
        print(f"λ={lam:.2f}  τ={tau:.2f}")
        print(f"checkpoint : {paths['best_pt']}")
        print(f"validation MCC : {valid_metrics['MCC']:.6f}")
        print(f"test MCC : {test_metrics['MCC']:.6f}")
        print(f"valid metrics : {valid_metrics}")
        print(f"test  metrics : {test_metrics}", flush=True)

        # Strict ">" keeps the earliest grid point when validation MCC ties.
        if valid_metrics["MCC"] > overall_best["val_MCC"]:
            overall_best.update({
                "val_MCC": valid_metrics["MCC"],
                "lam": lam, "tau": tau,
                "ckpt": paths["best_pt"],
                "subdir": subdir,
                "metrics_valid": valid_metrics,
                "metrics_test": test_metrics
            })

# Final summary
print("\n" + "#"*80)
print("[OVERALL BEST across grid (by VALIDATION MCC)]")
if overall_best["lam"] is None or overall_best["tau"] is None:
    # No run was recorded (every validation MCC stayed at or below the -1.0 start
    # value), and formatting None with :.2f would raise a TypeError.
    print("No grid run was recorded as best; nothing to summarise.")
else:
    print(f"λ={overall_best['lam']:.2f}  τ={overall_best['tau']:.2f}")
    print(f"best ckpt : {overall_best['ckpt']}")
    print(f"valid MCC : {overall_best['val_MCC']:.6f}")
    print(f"valid metrics : {overall_best['metrics_valid']}")
    print(f"test  metrics : {overall_best['metrics_test']}")
print("#"*80, flush=True)

In [None]:
# Test inference (manually provided checkpoint path)
# Save per-sample softmax with ORIGINAL filenames into softmax_preds.csv

# User config for test-time inference. The network must match the checkpoint; the
# modality may be overridden by the token parsed from the checkpoint filename.
cfg = RunConfig(
    modality=MOD.multimodal,          # default; can be overridden from ckpt name below
    network=NETWORK.vgg11,            # must match the trained checkpoint family
    n_classes=2,
    hidden_dim=256,
    epochs=64,
    lr=1e-4,
    weight_decay=1e-5,
    img_size=224,
    batch_size=64,
    num_workers=2,
    pin_memory=True,
    balance=False,
    dataset_root="/dataset",
    csv_train="/label_train.csv",
    csv_valid="/label_valid.csv",
    csv_test ="/label_test.csv",
    images_subdir="images",
    reports_subdir="reports",
    out_dir="/output",
    class_names=("normal","tb"),
    use_align_losses=True,
    hf_text_model_name="microsoft/BiomedVLP-CXR-BERT-general",
    hf_local_dir="/huggingface/hub/models--microsoft--BiomedVLP-CXR-BERT-general",
    hf_offline=True
)
# The tokenizer is read from the same local directory as the text model.
cfg.hf_tokenizer_local_dir = cfg.hf_local_dir
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -------------------- Manually set your checkpoint -------------------- #
# Path to the .pt file to evaluate. Outputs are written next to it, because the
# evaluation function sets out_dir to the checkpoint's directory.
BEST_CKPT = "best_vgg11_multimodal_val_loss.pt"
# Optional name of the output sub-directory; None lets the evaluation function pick one.
SAVE_SUBDIR_OVERRIDE = None 

# -------------------- Helpers -------------------- #
def _ckpt_stem_from_name(ckpt_path: str) -> Tuple[Optional[str], Optional[str]]:
    """Parse e.g. best_vgg16_img_val_loss.pt → (vgg16, img).

    The modality is the last underscore-free token before ``_val_loss.pt``; everything
    between ``best_`` and that token is the network name. Network names may therefore
    contain underscores (e.g. ``best_convnext_nano_multimodal_val_loss.pt`` →
    ``(convnext_nano, multimodal)``).

    Returns:
        ``(network, modality)`` in lower case, or ``(None, None)`` when the file name
        does not follow the pattern.
    """
    base = os.path.basename(ckpt_path)
    m = re.match(r"best_([a-zA-Z0-9_]+)_([a-zA-Z0-9]+)_val_loss\.pt$", base)
    if m is None:
        return None, None
    return m.group(1).lower(), m.group(2).lower()

def _infer_mod_from_token(tok: Optional[str]) -> Optional["MOD"]:
    """Map a filename token to a ``MOD`` member by case-insensitive substring match.

    Checks run in this order: ``img``, then ``multimodal``, then ``txt``/``text``.
    Returns ``None`` for a ``None`` token or when nothing matches.
    """
    if tok is None:
        return None
    lowered = tok.lower()
    if "img" in lowered:
        return MOD.img
    if "multimodal" in lowered:
        return MOD.multimodal
    if any(marker in lowered for marker in ("txt", "text")):
        return MOD.txt
    return None

def _load_split(csv_path: str, cfg) -> pd.DataFrame:
    """Same CSV→DF logic used during training.

    Reads a header-less ``filename,label`` CSV and adds ``y_index``, ``y_onehot``
    (float32 vector per row), ``image_path`` and ``report_path`` columns.

    Args:
        csv_path: CSV with two columns and no header row: filename, integer label.
        cfg: config providing ``n_classes``, ``dataset_root``, ``images_subdir`` and
            ``reports_subdir`` (``images_dir`` / ``reports_dir`` take precedence when set).

    Raises:
        ValueError: if any label falls outside ``[0, n_classes)``.
    """
    from pathlib import Path

    df = pd.read_csv(csv_path, header=None, names=["filename", "label"])
    df["filename"] = df["filename"].astype(str)
    df["y_index"]  = df["label"].astype(int)

    K = int(cfg.n_classes)
    # Negative labels would silently wrap around (eye[-1] is the LAST class) and
    # labels >= K would raise an opaque IndexError, so reject both up front.
    invalid = df.loc[(df["y_index"] < 0) | (df["y_index"] >= K), "y_index"]
    if not invalid.empty:
        raise ValueError(
            f"{csv_path}: labels must be in [0, {K}); found {sorted(set(invalid.tolist()))}"
        )
    eye = np.eye(K, dtype=np.float32)
    df["y_onehot"] = df["y_index"].apply(lambda i: eye[int(i)])

    img_dir = getattr(cfg, "images_dir", os.path.join(cfg.dataset_root, cfg.images_subdir))
    rep_dir = getattr(cfg, "reports_dir", os.path.join(cfg.dataset_root, cfg.reports_subdir))
    df["image_path"] = df["filename"].apply(lambda fn: os.path.join(img_dir, fn))
    df["report_path"] = df["filename"].apply(lambda fn: os.path.join(rep_dir, Path(fn).with_suffix(".txt").name))
    return df

def _count_params(m: nn.Module) -> Tuple[int, int, int]:
    """Return ``(total, trainable, frozen)`` parameter element counts for ``m``."""
    sizes = [(p.numel(), p.requires_grad) for p in m.parameters()]
    n_total = sum(size for size, _ in sizes)
    n_trainable = sum(size for size, needs_grad in sizes if needs_grad)
    return n_total, n_trainable, n_total - n_trainable

def _print_and_save_model_overview(model: nn.Module, cfg_run: "RunConfig", out_txt: str) -> None:
    """Print the model's architecture and parameter counts, and save them to ``out_txt``.

    Args:
        model: network to describe.
        cfg_run: run configuration; only the modality and network names are read.
        out_txt: destination text file. Its parent directory is created when needed.
    """
    print("\n========== Model (loaded for inference) ==========")
    print(f"Modality : {cfg_run.modality.name}")
    print(f"Backbone : {cfg_run.network.name}")
    print(model)  # full architecture
    tot, trn, frz = _count_params(model)
    print(f"Parameters : total={tot:,}  |  trainable={trn:,}  |  frozen={frz:,}")
    print("===============================================\n")

    parent = os.path.dirname(out_txt)
    # A bare filename has an empty dirname, and os.makedirs("") raises.
    if parent:
        os.makedirs(parent, exist_ok=True)
    # Explicit encoding so the result does not depend on the platform's locale default.
    with open(out_txt, "w", encoding="utf-8") as f:
        f.write(str(model) + "\n")
        f.write(f"total={tot:,} trainable={trn:,}  frozen={frz:,}\n")

# -------------------- NEW: plain (no-wrap) loader for filename-safe CSV -------------------- #
def _build_plain_test_loader_from_wrapped(wrapped_loader: DataLoader, cfg_run: "RunConfig") -> DataLoader:
    """
    Build a sequential, non-dropping DataLoader over the wrapped loader's dataset.

    No collate_fn is passed, so the DataLoader's default collate is used and keys
    such as 'filename' stay plain lists of strings (original filenames are kept).
    """
    n_workers = int(cfg_run.num_workers)
    loader_kwargs = {
        "batch_size": int(cfg_run.batch_size),
        "shuffle": False,
        "num_workers": n_workers,
        "pin_memory": bool(cfg_run.pin_memory),
        "drop_last": False,
        # Persistent workers are only valid when workers are actually used.
        "persistent_workers": n_workers > 0,
    }
    return DataLoader(wrapped_loader.dataset, **loader_kwargs)

@torch.inference_mode()
def evaluate_manual_ckpt_with_training_stack(cfg_base: "RunConfig",
                                             ckpt_path: str,
                                             save_subdir: Optional[str] = None) -> Dict[str, float]:
    """Evaluate a manually chosen checkpoint on the test split and save its outputs.

    Steps: infer the modality from the checkpoint file name, rebuild the data
    pipeline with the training helpers, load the weights, compute test metrics,
    then run a second pass to dump per-sample softmax probabilities.

    Files written under ``<checkpoint dir>/<save_subdir>``: ``model_summary.txt``,
    ``test_metrics.csv`` and ``softmax_preds.csv``.

    Args:
        cfg_base: base configuration; a copy is made with the modality inferred from
            the file name (if any) and ``out_dir`` set to the checkpoint's directory.
        ckpt_path: path to the ``.pt`` state dict; must exist.
        save_subdir: output sub-directory name; a default name is derived when omitted.

    Returns:
        The metrics dict produced by ``evaluate_dataloader_classifier``.

    Raises:
        AssertionError: if ``ckpt_path`` is not a file.
        RuntimeError: if the model output contains no ``(B, 2)`` logits tensor.
    """
    assert os.path.isfile(ckpt_path), f"Checkpoint not found: {ckpt_path}"

    # Infer modality token from filename (img/multimodal/txt); keep network as provided.
    net_tok, mod_tok = _ckpt_stem_from_name(ckpt_path)
    infer_mod = _infer_mod_from_token(mod_tok)
    cfg_run = dataclasses.replace(cfg_base, modality=(infer_mod or cfg_base.modality))
    # All outputs are written next to the checkpoint.
    cfg_run.out_dir = os.path.dirname(ckpt_path)

    # Rebuild DataFrames exactly like training
    df_train = _load_split(cfg_run.csv_train, cfg_run)
    df_valid = _load_split(cfg_run.csv_valid, cfg_run)
    df_test = _load_split(cfg_run.csv_test,  cfg_run)

    # Tokenizer + DataLoaders with the *same* function
    tok = build_tokenizer(cfg_run)
    _, _, test_loader_wrapped = make_loaders(cfg_run, df_train, df_valid, df_test, tokenizer=tok)

    # Safe MM collate identical to training eval (kept for metrics)
    test_loader_wrapped = wrap_loader_with_collate(test_loader_wrapped, cfg_run.modality, cfg_run.n_classes)

    # ALSO build a plain loader from the same dataset so 'filename' is preserved
    test_loader_plain = _build_plain_test_loader_from_wrapped(test_loader_wrapped, cfg_run)

    # Load model & weights
    model = MultimodalNet(cfg_run).to(device).eval()
    state = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(state, strict=True)

    # Print & save overview
    stem = f"{(net_tok or cfg_run.network.name)}_{(mod_tok or cfg_run.modality.name)}"
    save_subdir = save_subdir or f"test_infer_{stem}_raw_shenzhen_new"    
    save_dir = os.path.join(cfg_run.out_dir, save_subdir)
    os.makedirs(save_dir, exist_ok=True)
    _print_and_save_model_overview(model, cfg_run, os.path.join(save_dir, "model_summary.txt"))

    # Evaluate with the SAME helper as training (uses wrapped loader)
    metrics = evaluate_dataloader_classifier(model, test_loader_wrapped, cfg_run.modality)

    # Persist metrics (unchanged)
    pd.DataFrame([metrics]).to_csv(os.path.join(save_dir, "test_metrics.csv"), index=False)

    # ----------------------- Save per-sample softmax to CSV with ORIGINAL filenames -----------------------
    # NOTE(review): this pass uses the plain loader, while the metrics above use the
    # wrapped loader whose collate drops incomplete samples; the two may cover different
    # sample sets if the dataset can yield missing image/text — confirm.
    rows: List[dict] = []
    for batch in test_loader_plain:
        # Batch is a dict from TBMultimodalDataset with filename as a list[str]
        if ("image" not in batch) or ("y_onehot" not in batch) or ("filename" not in batch):
            continue
        if batch["y_onehot"].numel() == 0:
            continue

        # tensors
        x = batch["image"].to(device, non_blocking=True)
        txt = None
        if ("text" in batch) and (batch["text"] is not None):
            txt = {k: v.to(device, non_blocking=True) for k, v in batch["text"].items()}

        # forward
        out = model(x, txt)
        # Pick logits in priority order: image head, text head, generic "logits" key.
        if isinstance(out, dict):
            if out.get("logits_img", None) is not None:
                logits = out["logits_img"]
            elif out.get("logits_txt", None) is not None:
                logits = out["logits_txt"]
            elif out.get("logits", None) is not None:
                logits = out["logits"]
            else:
                # last resort: any (B,2) tensor in dict
                logits = None
                for v in out.values():
                    if isinstance(v, torch.Tensor) and v.dim() == 2 and v.size(-1) == 2:
                        logits = v
                        break
        else:
            logits = out  # some models return a tensor directly

        if not isinstance(logits, torch.Tensor) or logits.dim() != 2 or logits.size(-1) != 2:
            raise RuntimeError(f"Expected logits (B,2), got {type(logits)} with shape "
                               f"{tuple(logits.shape) if isinstance(logits, torch.Tensor) else 'N/A'}")

        probs = F.softmax(logits.float(), dim=1).detach().cpu().numpy()  # (B,2)
        pred = probs.argmax(axis=1).astype(int)
        y_idx = torch.argmax(batch["y_onehot"], dim=1).detach().cpu().numpy().astype(int)

        # filenames are already as provided by dataset (strings with extension)
        names = [os.path.basename(str(s)) for s in batch["filename"]]
        B = probs.shape[0]
        assert len(names) == B, "Filename list length must match batch size."

        for i in range(B):
            rows.append({
                "img": names[i],
                "true_label": int(y_idx[i]),
                "prob_0": float(probs[i, 0]),
                "prob_1": float(probs[i, 1]),
                "predicted_label": int(pred[i]),
            })

    out_csv = os.path.join(save_dir, "softmax_preds.csv")
    pd.DataFrame(rows).to_csv(out_csv, index=False)
    print(f"[SOFTMAX] Saved per-sample probabilities to: {out_csv}")
    # ------------------------------------------------------------------------------------------------------

    # Echo results (unchanged)
    print(f"\n=== Test results (manual ckpt) [{net_tok or cfg_run.network.name}/{cfg_run.modality.name}] ===")
    for k, v in metrics.items():
        print(f"{k:>18s}: {v}")

    return metrics

# Runner: check the checkpoint, echo the configuration, then evaluate.
assert os.path.isfile(BEST_CKPT), f"Checkpoint not found: {BEST_CKPT}"
print("\n================ Inference Configuration (manual ckpt) ================")
print(f"Checkpoint (pt)  : {BEST_CKPT}")
try:
    fsize_mb = os.path.getsize(BEST_CKPT) / (1024**2)
except OSError as exc:
    # The size is informational only: report the problem and carry on.
    print(f"Checkpoint size  : unavailable ({exc})")
else:
    print(f"Checkpoint size  : {fsize_mb:.2f} MB")
print("=======================================================================\n")
evaluate_manual_ckpt_with_training_stack(cfg, BEST_CKPT, save_subdir=SAVE_SUBDIR_OVERRIDE)

### GRAD-CAM VISUALIZATION WITH INTERNAL SHENZHEN TEST SET

1. Uses manual checkpoint.
2. Reads filenames from csv_test, so we won’t accidentally process extra images in the 1024×1024 directory.
3. For each filename, it prefers the 1024×1024 version (if missing, it resizes the original to 1024×1024 on the fly).
4. Loads YOLO GT boxes (normalized) from the original annotation .txt, converts to original pixels, then rescales to 1024×1024. GT parsing assumes YOLO normalized lines cls cx cy w h in shen_mask_crop/<name>.txt. Each box is converted from original pixels to 1024×1024 via (sx, sy) = (1024/orig_W, 1024/orig_H).
5. Saves into gradcam_<network>_<modality>_internal_overlap/ under the model directory, with subfolders images/, heatmaps/, contours/, bboxes/.
6. Uses blue for model bboxes and red for GT bboxes.

In [None]:
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# User config (the network must match the checkpoint below)
cfg = RunConfig(
    modality=MOD.multimodal,
    network=NETWORK.vgg11,
    n_classes=2,
    hidden_dim=256,
    epochs=64,
    lr=5e-5,
    weight_decay=1e-4,
    img_size=224,
    batch_size=64,
    num_workers=2,
    pin_memory=True,
    balance=False,
    dataset_root="/dataset",
    class_names=("normal", "tb"),
    use_align_losses=True,
    hf_text_model_name="microsoft/BiomedVLP-CXR-BERT-general",
    hf_local_dir="/huggingface/hub/models--microsoft--BiomedVLP-CXR-BERT-general",
    hf_offline=True
)

# Paths (manual checkpoint + data)
MANUAL_CKPT_PATH = "/best_vgg11_multimodal_val_loss.pt"  
CSV_TEST_PATH = "/dataset/label_test.csv"
ORIG_DIR = "/dataset/shen_orig_crop"      # original images
YOLO_GT_DIR = "/dataset/shen_mask_crop"   # YOLO-format ground-truth annotation .txt files
RESIZED_1024_DIR = "/dataset/images"  # contains many images; we will only use those present in csv_test

# Output folder name
SAVE_ROOT_NAME = f"gradcam_{cfg.network.name}_{cfg.modality.name}_internal_overlap"

# Grad-CAM method & drawing params
CAM_METHOD = "gradcam"             # "gradcam", "gradcam++", "xgradcam", "layercam", ...
HEATMAP_ALPHA = 0.5
BIN_THR = 0.4                   # threshold for binarizing CAM to extract contours/bboxes

# Colors / thickness (BGR channel order, so (0, 0, 255) is red and (255, 0, 0) is blue)
CONTOUR_COLOR = (0, 0, 255)           # red for contours
CONTOUR_THICK = 3
BB_MODEL_COLOR = (255, 0, 0)           # blue for model bboxes (requested)
BB_GT_COLOR = (0, 0, 255)           # red  for GT boxes (requested)
BB_THICK = 4
TARGET_CLASS_IDX = 1                    # "abnormal" class
IMG_EXTS = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff"}
RESIZE_TO = 1024                        # target side for visualization base

# Helpers
def _ckpt_stem(cfg: "RunConfig") -> str:
    return f"{cfg.network.name}_{cfg.modality.name}"

def _parse_lam_tau_from_path(path: str) -> Tuple[Optional[float], Optional[float]]:
    lam = tau = None
    m1 = re.search(r"lam([0-9]+(?:\.[0-9]+)?)", path)
    m2 = re.search(r"tau([0-9]+(?:\.[0-9]+)?)", path)
    if m1: lam = float(m1.group(1))
    if m2: tau = float(m2.group(1))
    return lam, tau

def print_model_overview_for_cam(model: nn.Module, img_size: int = 224):
    """Print a short parameter summary of ``model`` to stdout.

    Args:
        model: module whose parameters are counted.
        img_size: side length reported in the "Input size" line (display only).
    """
    total = 0
    train = 0
    for param in model.parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            train += count

    # Reads the module-level ``device``; it is only reported, not changed.
    report = (
        "Grad-CAM — Model Overview",
        f"Device: {device.type}",
        f"Input size (C,H,W): (3, {img_size}, {img_size})",
        f"Total params:  {total:,}",
        f"Trainable params: {train:,}",
        f"Frozen params: {total-train:,}",
        "="*80,
    )
    for line in report:
        print(line)

def _kernel_tuple(m: nn.Conv2d) -> Tuple[int,int]:
    k = m.kernel_size
    return (k if isinstance(k, tuple) else (k, k))

def _last_conv_kgt1(module: nn.Module) -> Optional[nn.Conv2d]:
    """Return the last ``nn.Conv2d`` in ``module`` whose kernel is larger than 1.

    Layers are visited in ``module.modules()`` order. If no conv has a kernel
    side greater than 1, the last conv of any size is returned; ``None`` when
    the module contains no ``nn.Conv2d`` at all.
    """
    convs = [layer for layer in module.modules() if isinstance(layer, nn.Conv2d)]
    if not convs:
        return None
    for conv in reversed(convs):
        if max(_kernel_tuple(conv)) > 1:
            return conv
    # Only 1x1 convolutions exist: fall back to the last one.
    return convs[-1]

def _vit_patch_embed_conv(module: nn.Module) -> Optional[nn.Conv2d]:
    target = None
    for name, m in module.named_modules():
        if isinstance(m, nn.Conv2d) and ("patch_embed" in name or "patch" in name):
            target = m
    return target

def _qualname_of_module(root: nn.Module, target: nn.Module) -> str:
    for n, m in root.named_modules():
        if m is target:
            return n or "<root>"
    return "<unknown>"

def select_target_layer_for_cam(model: nn.Module) -> nn.Conv2d:
    """Pick the convolution layer that Grad-CAM should hook.

    Candidates are tried in this order and the first ``nn.Conv2d`` wins:

    1. a 3x3 conv inside ``model.img_enc.post3x3`` (Sequential or single conv),
    2. the last conv with kernel > 1 in ``model.img_enc``,
    3. a patch-embedding conv in ``model.img_enc``,
    4. the last conv with kernel > 1 anywhere in ``model``.

    Raises:
        RuntimeError: if none of the candidates is an ``nn.Conv2d``.
    """
    def _candidates():
        # Lazy generator: later lookups only run if earlier ones fail.
        encoder = getattr(model, "img_enc", None)
        if encoder is not None:
            post = getattr(encoder, "post3x3", None)
            if isinstance(post, nn.Sequential):
                for layer in post.modules():
                    if isinstance(layer, nn.Conv2d) and _kernel_tuple(layer) == (3,3):
                        yield layer
            elif isinstance(post, nn.Conv2d) and _kernel_tuple(post) == (3,3):
                yield post
            yield _last_conv_kgt1(encoder)
            yield _vit_patch_embed_conv(encoder)
        yield _last_conv_kgt1(model)

    for candidate in _candidates():
        if isinstance(candidate, nn.Conv2d):
            return candidate
    raise RuntimeError("No suitable Conv2d layer found for Grad-CAM.")

class CamImageWrapper(nn.Module):
    """Adapter whose ``forward`` delegates to the wrapped model's ``forward_image``.

    It lets a CAM object call the model with a single image tensor.
    """

    def __init__(self, mm: nn.Module):
        super().__init__()
        self.mm = mm

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # image head
        image_output = self.mm.forward_image(x)
        return image_output

def build_preprocess(size: int) -> T.Compose:
    """Build the image transform applied before the CAM forward pass.

    The pipeline converts an array to PIL, resizes it to ``size`` x ``size``,
    converts it to a tensor and normalizes it with fixed per-channel mean/std.
    """
    channel_mean = [0.485,0.456,0.406]
    channel_std = [0.229,0.224,0.225]
    steps = [
        T.ToPILImage(),
        T.Resize((size, size)),
        T.ToTensor(),
        T.Normalize(mean=channel_mean, std=channel_std),
    ]
    return T.Compose(steps)

def _init_cam(CAMClass, model, target_layers):
    sig = inspect.signature(CAMClass.__init__)
    if "use_cuda" in sig.parameters:
        return CAMClass(model=model, target_layers=target_layers, use_cuda=(device.type == "cuda"))
    return CAMClass(model=model, target_layers=target_layers)

@torch.no_grad()
def _load_ckpt_model(cfg: "RunConfig", ckpt_path: str) -> "MultimodalNet":
    """Build a ``MultimodalNet`` from ``cfg`` and load weights from ``ckpt_path``.

    The model is moved to the module-level ``device`` and put in eval mode
    before loading. The checkpoint is read onto the CPU and applied with
    ``strict=True``, so missing or unexpected keys raise.
    """
    net = MultimodalNet(cfg)
    net = net.to(device).eval()
    weights = torch.load(ckpt_path, map_location="cpu")
    net.load_state_dict(weights, strict=True)
    return net

def _read_csv_test(csv_path: str) -> pd.DataFrame:
    try:
        df = pd.read_csv(csv_path, header=None, sep=None, engine="python", encoding="utf-8-sig")
    except Exception:
        df = pd.read_csv(csv_path, header=None, encoding="utf-8-sig")
    if df.shape[1] < 2:
        raise ValueError("label_test.csv must have >= 2 columns (filename, label).")
    df = df.iloc[:, :2].copy(); df.columns = ["img", "label"]

    def _is_hdr(s: str) -> bool:
        s = (str(s) or "").strip().lower()
        return s in {"img","image","filename","file","path","label","class","target","y"} or s.startswith("img")

    if _is_hdr(df.iloc[0,0]) and _is_hdr(df.iloc[0,1]):
        df = df.iloc[1:].reset_index(drop=True)

    df["img"] = df["img"].astype(str).str.strip()
    df["label"] = pd.to_numeric(df["label"], errors="coerce").fillna(0).astype(int)
    return df

def _is_image_name(name: str) -> bool:
    """Return True when ``name`` ends in one of the extensions in ``IMG_EXTS`` (case-insensitive)."""
    _, extension = os.path.splitext(name)
    return extension.lower() in IMG_EXTS

def _load_original_size(name: str) -> Tuple[np.ndarray, int, int]:
    """Read ``name`` from ``ORIG_DIR`` and return ``(image, width, height)``.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file.
    """
    p = os.path.join(ORIG_DIR, name)
    image = cv2.imread(p)
    if image is not None:
        height, width = image.shape[:2]
        return image, width, height
    raise FileNotFoundError(f"Original image not found: {p}")

def _load_resized_1024(name: str) -> Tuple[Optional[np.ndarray], bool]:
    """Try to read ``name`` from ``RESIZED_1024_DIR`` by exact filename.

    Returns:
        ``(image, True)`` when the file could be read, otherwise
        ``(None, False)``. The first element is therefore optional, which the
        return annotation now states.
    """
    image = cv2.imread(os.path.join(RESIZED_1024_DIR, name))
    # cv2.imread signals a missing/unreadable file by returning None.
    return image, image is not None

def _ensure_1024_base_image(name: str) -> Tuple[np.ndarray, int, int, int, int]:
    """Return the square visualization image for ``name`` plus the original size.

    Returns:
        ``(base_bgr, RESIZE_TO, RESIZE_TO, orig_W, orig_H)`` where ``base_bgr``
        is ``RESIZE_TO`` x ``RESIZE_TO``. The pre-resized file is preferred;
        if it is missing, the original is resized on the fly.

    Raises:
        FileNotFoundError: propagated from ``_load_original_size`` when the
            original image is missing; the original is always loaded because
            its dimensions are part of the result.
    """
    target_hw = (RESIZE_TO, RESIZE_TO)
    resized, found = _load_resized_1024(name)
    original, orig_W, orig_H = _load_original_size(name)

    if found:
        # Be safe: a file of the wrong size is force-resized.
        if resized.shape[:2] == target_hw:
            base_bgr = resized
        else:
            base_bgr = cv2.resize(resized, target_hw, interpolation=cv2.INTER_AREA)
    else:
        base_bgr = cv2.resize(original, target_hw, interpolation=cv2.INTER_AREA)
    return base_bgr, RESIZE_TO, RESIZE_TO, orig_W, orig_H

def _read_yolo_gt_scaled_to_1024(name: str, orig_W: int, orig_H: int) -> List[Tuple[int,int,int,int]]:
    """Read YOLO boxes for ``name`` and map them onto the ``RESIZE_TO`` canvas.

    Each line of ``<YOLO_GT_DIR>/<stem>.txt`` is expected to be
    ``cls cx cy w h`` (comma- or whitespace-separated) with coordinates
    normalized to the original image. Boxes are converted to original pixels,
    scaled by ``RESIZE_TO / orig_W`` and ``RESIZE_TO / orig_H``, rounded and
    clamped to ``[0, RESIZE_TO - 1]``.

    Malformed lines are skipped. That includes lines with too few fields,
    non-numeric fields, and non-finite values such as ``nan`` or ``inf``,
    which previously parsed as floats and then crashed in ``round()``.

    Args:
        name: image filename; only its stem is used to find the label file.
        orig_W: width of the original image in pixels (must be positive).
        orig_H: height of the original image in pixels (must be positive).

    Returns:
        List of ``(x1, y1, x2, y2)`` integer boxes with positive area, or an
        empty list when the label file does not exist.
    """
    import math  # local import keeps this block self-contained

    txt_path = os.path.join(YOLO_GT_DIR, os.path.splitext(name)[0] + ".txt")
    if not os.path.isfile(txt_path):
        return []

    # Scale factors do not depend on the line, so compute them once.
    sx = RESIZE_TO / float(orig_W)
    sy = RESIZE_TO / float(orig_H)
    upper = RESIZE_TO - 1

    boxes = []
    with open(txt_path, "r", encoding="utf-8", errors="ignore") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            toks = re.split(r"[,\s]+", line)
            if len(toks) < 5:
                continue
            try:
                # YOLO: cls cx cy w h (normalized); class id is validated but unused.
                int(float(toks[0]))
                cx = float(toks[1]); cy = float(toks[2])
                w  = float(toks[3]); h  = float(toks[4])
            except (ValueError, OverflowError):
                continue
            # original pixels, then scaled to the canvas
            scaled = (
                (cx - w/2.0) * orig_W * sx,
                (cy - h/2.0) * orig_H * sy,
                (cx + w/2.0) * orig_W * sx,
                (cy + h/2.0) * orig_H * sy,
            )
            # round() raises on nan/inf; treat such lines as malformed.
            if not all(math.isfinite(v) for v in scaled):
                continue
            X1, Y1, X2, Y2 = (int(max(0, min(upper, round(v)))) for v in scaled)
            if X2 > X1 and Y2 > Y1:
                boxes.append((X1, Y1, X2, Y2))
    return boxes

# Main: Grad-CAM over csv_test, overlay GT (scaled)
def run_gradcam_from_csv_resized_overlap(
    cfg: "RunConfig",
    ckpt_path: str,
    csv_test_path: str,
    save_root_overlap: str,
    cam_method: str = CAM_METHOD,
    heatmap_alpha: float = HEATMAP_ALPHA,
    bin_thr: float = BIN_THR,
):
    """Run a CAM method over the test CSV images and save overlays with GT boxes.

    For each image filename listed in ``csv_test_path``:

    * the ``RESIZE_TO`` x ``RESIZE_TO`` base image is loaded (or built from the original),
    * a CAM mask for class ``TARGET_CLASS_IDX`` is computed on the image branch,
    * the mask is min-max normalized, upsampled and thresholded at ``bin_thr``,
    * four PNGs are written under ``save_root_overlap``: the base image
      (``images/``), heatmap + GT boxes (``heatmaps/``), CAM contours + GT
      boxes (``contours/``), and model boxes + GT boxes (``bboxes/``).

    Images whose original file is missing are skipped with a warning.

    Args:
        cfg: run configuration used to build the model and name the outputs.
        ckpt_path: path to the checkpoint loaded into the model.
        csv_test_path: CSV of (filename, label) rows; only filenames are used.
        save_root_overlap: output directory (created if needed).
        cam_method: key of the CAM implementation, case-insensitive.
        heatmap_alpha: blend weight of the heatmap over the base image.
        bin_thr: threshold in [0, 1] applied to the normalized CAM.

    Raises:
        AssertionError: if ``ckpt_path`` is not an existing file.
        ValueError: if ``cam_method`` is not a supported key.
    """
    assert os.path.isfile(ckpt_path), f"Checkpoint not found: {ckpt_path}"

    # Validate the method name first so a typo fails before the model is loaded.
    methods = {
        "gradcam": GradCAM, "gradcam++": GradCAMPlusPlus, "hirescam": HiResCAM,
        "xgradcam": XGradCAM, "layercam": LayerCAM, "eigencam": EigenCAM,
        "eigengradcam": EigenGradCAM, "scorecam": ScoreCAM, "ablationcam": AblationCAM
    }
    mkey = cam_method.lower()
    if mkey not in methods:
        raise ValueError(f"Unknown CAM method '{cam_method}'")
    CAMClass = methods[mkey]

    os.makedirs(save_root_overlap, exist_ok=True)
    out_dirs = {
        "images": os.path.join(save_root_overlap, "images"),
        "heatmaps": os.path.join(save_root_overlap, "heatmaps"),
        "contours": os.path.join(save_root_overlap, "contours"),
        "bboxes": os.path.join(save_root_overlap, "bboxes"),
    }
    for d in out_dirs.values():
        os.makedirs(d, exist_ok=True)

    def _draw_boxes(canvas, boxes, color):
        """Draw ``boxes`` (x1, y1, x2, y2) on ``canvas`` in place and return it."""
        for (bx1, by1, bx2, by2) in boxes:
            cv2.rectangle(canvas, (bx1, by1), (bx2, by2), color, BB_THICK, lineType=cv2.LINE_AA)
        return canvas

    # Load model & CAM
    mm = _load_ckpt_model(cfg, ckpt_path)
    print_model_overview_for_cam(mm, img_size=int(getattr(cfg, "img_size", 224)))
    cam_model = CamImageWrapper(mm).to(device).eval()
    target_layer = select_target_layer_for_cam(mm)
    # The hooked layer needs gradients even if it was frozen.
    for p in target_layer.parameters():
        p.requires_grad_(True)
    tl_name = _qualname_of_module(mm, target_layer)
    print(f"[CAM] Target layer: {tl_name}  kernel={getattr(target_layer, 'kernel_size', None)}")
    cam = _init_cam(CAMClass, model=cam_model, target_layers=[target_layer])
    try:
        cam.batch_size = 1
    except (AttributeError, TypeError):
        # Best effort: some CAM objects may not allow setting this attribute.
        pass

    try:
        preprocess = build_preprocess(int(getattr(cfg, "img_size", 224)))
        df = _read_csv_test(csv_test_path)

        # Iterate only over test filenames from CSV (ignore any extra files in 1024 dir)
        for _, row in df.iterrows():
            name = str(row["img"]).strip()
            if not _is_image_name(name):
                continue

            # Load base visualization image, and original dims for GT scaling
            try:
                base_bgr, W_vis, H_vis, orig_W, orig_H = _ensure_1024_base_image(name)
            except FileNotFoundError:
                # If the original is missing, skip safely
                print(f"[WARN] Skipping (image missing): {name}")
                continue

            # Prepare model input from the base image
            rgb = cv2.cvtColor(base_bgr, cv2.COLOR_BGR2RGB)
            x = preprocess(rgb).unsqueeze(0).to(device)

            # Run CAM; EigenCAM is called without targets or smoothing.
            with torch.enable_grad():
                if mkey == "eigencam":
                    mask = cam(input_tensor=x)[0]
                else:
                    mask = cam(
                        input_tensor=x,
                        targets=[ClassifierOutputTarget(int(TARGET_CLASS_IDX))],
                        aug_smooth=True, eigen_smooth=True
                    )[0]

            # Normalize to [0, 1] & upsample to the visualization size
            mmin, mmax = float(np.min(mask)), float(np.max(mask))
            mask = (mask - mmin) / (mmax - mmin + 1e-8)
            mask_u8 = (np.clip(cv2.resize(mask, (W_vis, H_vis), interpolation=cv2.INTER_NEAREST), 0, 1) * 255).astype(np.uint8)

            # Heatmap overlay (addWeighted returns a new array)
            heat = cv2.applyColorMap(mask_u8, cv2.COLORMAP_HOT)
            heat_overlay = cv2.addWeighted(heat, float(heatmap_alpha), base_bgr, 1.0 - float(heatmap_alpha), 0.0)

            # Contours from thresholded mask
            thr = int(255 * float(bin_thr))
            _, binm = cv2.threshold(mask_u8, thr, 255, cv2.THRESH_BINARY)
            contours, _ = cv2.findContours(binm, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

            # Draw contours (red) on a copy so the base image stays clean
            cont_img = base_bgr.copy()
            if len(contours) > 0:
                cv2.drawContours(cont_img, contours, -1, CONTOUR_COLOR, CONTOUR_THICK, lineType=cv2.LINE_AA)

            # Model bounding boxes: one per external contour
            model_boxes = []
            for c in contours:
                x0, y0, w, h = cv2.boundingRect(c)
                model_boxes.append((x0, y0, x0 + w, y0 + h))

            # GT boxes scaled to the visualization size
            gt_boxes_1024 = _read_yolo_gt_scaled_to_1024(name, orig_W, orig_H)

            stem = os.path.splitext(name)[0]
            tagged_name = f"{stem}__{cfg.network.name}__{mkey}.png"

            # Base image
            cv2.imwrite(os.path.join(out_dirs["images"], f"{stem}.png"), base_bgr)

            # 1) Heatmaps + GT red boxes
            heat_ov = _draw_boxes(heat_overlay, gt_boxes_1024, BB_GT_COLOR)
            cv2.imwrite(os.path.join(out_dirs["heatmaps"], tagged_name), heat_ov)

            # 2) Contours + GT red boxes
            cont_ov = _draw_boxes(cont_img, gt_boxes_1024, BB_GT_COLOR)
            cv2.imwrite(os.path.join(out_dirs["contours"], tagged_name), cont_ov)

            # 3) BBoxes: blue model boxes first, then red GT boxes on top
            box_ov = _draw_boxes(base_bgr.copy(), model_boxes, BB_MODEL_COLOR)
            _draw_boxes(box_ov, gt_boxes_1024, BB_GT_COLOR)
            cv2.imwrite(os.path.join(out_dirs["bboxes"], tagged_name), box_ov)
    finally:
        # Cleanup runs even when an image fails, so hooks are released.
        try:
            if getattr(cam, "activations_and_grads", None) is not None:
                cam.activations_and_grads.release()
        except Exception:
            # Releasing hooks is best effort; it must not mask the real error.
            pass
        del cam, cam_model, mm
        if device.type == "cuda":
            torch.cuda.empty_cache()
    print(f"[CAM] Saved to: {save_root_overlap}")

# Runner (manual ckpt, csv_test)
# Fail before printing anything if the checkpoint path is wrong.
assert os.path.isfile(MANUAL_CKPT_PATH), f"Checkpoint not found: {MANUAL_CKPT_PATH}"
# Outputs are written into a folder next to the checkpoint file.
best_subdir = os.path.dirname(MANUAL_CKPT_PATH)
save_root_overlap = os.path.join(best_subdir, SAVE_ROOT_NAME)

# Echo the effective configuration so the run can be identified from the cell output.
print("\n================ Grad-CAM Configuration ================")
print(f"Checkpoint : {MANUAL_CKPT_PATH}")
print(f"Backbone   : {cfg.network.name}")
print(f"Modality   : {cfg.modality.name}")
print(f"csv_test   : {CSV_TEST_PATH}")
print(f"Orig dir   : {ORIG_DIR}")
print(f"GT (YOLO)  : {YOLO_GT_DIR}")
print(f"ResizedDir : {RESIZED_1024_DIR}  (will only use files present in csv_test)")
print(f"Save root  : {save_root_overlap}")
print("=======================================================\n")

# Run the CAM pass over every image listed in the test CSV.
run_gradcam_from_csv_resized_overlap(
    cfg=cfg,
    ckpt_path=MANUAL_CKPT_PATH,
    csv_test_path=CSV_TEST_PATH,
    save_root_overlap=save_root_overlap,
    cam_method=CAM_METHOD,
    heatmap_alpha=HEATMAP_ALPHA,
    bin_thr=BIN_THR,
)

### Computing GRAD-CAM visualizations with the external TBX11K test set

1. Annotations (512×512) → pre-crop (256×256); Scale the authors’ ground truth boxes by 0.5.
2. Pre-crop (256×256) → lung-crop: Use the provided 256×256 mask to recover the actual crop ROI (the tight bounding rectangle of non-zero mask, optionally with a small margin). Intersect the pre-crop box with this ROI and translate it to the crop coordinate frame.
3. Lung-crop → final saved TB image (256×256): The crop ROI is then resized to 256×256 when saving the lung-cropped images. So scale the translated box by: sx = 256 / (roi_w),  sy = 256 / (roi_h)
4. Use floor for x1,y1 and ceil for x2,y2 to preserve coverage; clip x1,y1 to [0,255] and x2,y2 to at most 256, keeping every box at least 1 px wide and tall.
5. Run Grad-CAM on those TB images and overlay blue model CAM boxes and red GT boxes.

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Logging
# NOTE: logging.basicConfig does nothing if the root logger already has
# handlers, so this format may not apply in an already-configured session.
logging.basicConfig(
    level=logging.INFO,
    format="[%(asctime)s] %(levelname)s: %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
# Logger used by the helpers in this cell.
log = logging.getLogger("gradcam-tbx11k-mask-mapped")

def set_deterministic(seed: int = 1337) -> None:
    """Seed Python, NumPy and PyTorch RNGs and switch cuDNN to deterministic mode.

    Args:
        seed: value passed to every seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
# Seed all RNGs once for this cell.
set_deterministic(1337)

# User config
# NOTE(review): the checkpoint below is loaded with strict=True, so these values
# presumably have to match the configuration used at training time — confirm.
cfg = RunConfig(
    modality=MOD.multimodal,
    network=NETWORK.vgg11,
    n_classes=2,
    hidden_dim=256,
    epochs=64,
    lr=5e-5,
    weight_decay=1e-4,
    img_size=224,
    batch_size=64,
    num_workers=2,
    pin_memory=True,
    dataset_root="/dataset",
    class_names=("normal", "tb"),
    use_align_losses=True,
    hf_text_model_name="microsoft/BiomedVLP-CXR-BERT-general",
    hf_local_dir="/huggingface/hub/models--microsoft--BiomedVLP-CXR-BERT-general",
    hf_offline=True
)

# Manual checkpoint + TBX11K paths
MANUAL_CKPT_PATH = "/best_vgg11_multimodal_val_loss.pt"
CSV_TEST_PATH    = "/dataset/label_test_tbx11k.csv"

# Final lung-cropped TB images (256×256)
TB_256_DIR = "/dataset/tbx11k/cropped/tb"
# Original TB images (512×512)
ORIG_TB_DIR = "/dataset/tbx11k/orig/tb"
# 256×256 masks aligned to the pre-crop 256×256 space
TB_MASKS_256_DIR = "/dataset/tbx11k/masks/tb"
# Authors' lesion boxes (512×512 coordinates)
BBOX_CSV_PATH = "/dataset/data_tbx11k.csv"
# Output folder name for this external-test run
SAVE_ROOT_NAME = f"gradcam_tbx11k_{cfg.network.name}_{cfg.modality.name}_external_overlap"

# Grad-CAM & drawing params
CAM_METHOD  = "gradcam"
HEATMAP_ALPHA = 0.5           # blend weight of the heatmap over the base image
BIN_THR = 0.4                 # threshold for binarizing the CAM

# Colors / thickness (BGR channel order, as used by OpenCV drawing calls)
CONTOUR_COLOR = (0,0,255)     # red
CONTOUR_THICK = 2
BB_MODEL_COLOR = (255,0,0)    # blue
BB_GT_COLOR = (0,0,255)       # red
BB_THICK = 3
TARGET_CLASS_IDX = 1          # class index explained by the CAM ("tb" in cfg.class_names)
IMG_EXTS = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff"}  # lowercase extensions treated as images
FINAL_SIZE = 256              # visualization base size (final TB images are 256×256)
# Optional crop behavior
ROI_MARGIN_FRAC = 0.00        # e.g., 0.02 to add 2% per side
ROI_ENFORCE_SQUARE = False    # set True if you padded to square before resizing

# Path/CSV utils
# Path components that mark a path as junk.
JUNK_DIR_TOKENS = {"__pycache__", ".ipynb_checkpoints", ".git", ".svn", ".DS_Store"}
# Lowercased file names that are always ignored.
JUNK_FILE_BASENAMES = {"thumbs.db", "desktop.ini"}
# Case-insensitive "checkpoint" delimited by non-alphanumerics or string ends.
_CHECKPOINT_RE = re.compile(r"(?i)(?:^|[^A-Za-z0-9])checkpoint(?:$|[^A-Za-z0-9])")
_DOTFILE_RE    = re.compile(r"^\.")  # dot files

def _is_img(path_or_name: str) -> bool:
    """Return True when the extension of ``path_or_name`` is in ``IMG_EXTS`` (case-insensitive)."""
    _, extension = os.path.splitext(path_or_name)
    return extension.lower() in IMG_EXTS

def _is_junk_path(path: str) -> bool:
    """Return True when ``path`` should be ignored while listing images.

    A path is junk when any of its components is a known junk token or starts
    with a dot, when its file name is a known OS artefact, or when the file
    stem contains the word "checkpoint".
    """
    candidate = pathlib.Path(path)
    if any(piece in JUNK_DIR_TOKENS or _DOTFILE_RE.search(piece) for piece in candidate.parts):
        return True

    filename = candidate.name
    stem = os.path.splitext(filename)[0]
    return bool(
        filename.lower() in JUNK_FILE_BASENAMES
        or _DOTFILE_RE.match(filename)
        or _CHECKPOINT_RE.search(stem)
    )

def _list_images(dir_path: str) -> List[str]:
    """Recursively list image files under ``dir_path``, skipping junk entries.

    Junk detection is applied only to the part of each path below
    ``dir_path``. The previous version tested the full path, so a root that
    lives under a dot-directory (e.g. ``~/.cache/data``) or is given as
    ``../data`` had every file rejected.

    Args:
        dir_path: root directory to walk.

    Returns:
        Sorted list of full paths (joined onto ``dir_path`` as given).
    """
    def _is_junk_below_root(full_path: str) -> bool:
        return _is_junk_path(os.path.relpath(full_path, dir_path))

    out = []
    for root, dirs, files in os.walk(dir_path):
        # Prune junk directories in place so os.walk does not descend into them.
        dirs[:] = [d for d in dirs if not _is_junk_below_root(os.path.join(root, d))]
        for fn in files:
            fp = os.path.join(root, fn)
            if _is_img(fn) and not _is_junk_below_root(fp):
                out.append(fp)
    return sorted(out)

def _read_test_csv(csv_path: str) -> pd.DataFrame:
    """Read the test CSV into a frame with ``img`` (basename) and ``label`` (int).

    The delimiter is sniffed first; if that read fails, the default parser is
    used. Only the first two columns are kept, a header-looking first row is
    dropped, paths are reduced to basenames, unparseable labels become 0, and
    rows whose name is not an image are removed.

    Raises:
        ValueError: if the file has fewer than two columns.
    """
    header_words = {"img","image","filename","file","path","label","class","target","y"}

    def _header_like(cell) -> bool:
        text = str(cell).strip().lower()
        return text in header_words or text.startswith("img")

    try:
        table = pd.read_csv(csv_path, header=None, sep=None, engine="python", encoding="utf-8-sig")
    except Exception:
        table = pd.read_csv(csv_path, header=None, encoding="utf-8-sig")

    if table.shape[1] < 2:
        raise ValueError("Test CSV must have >= 2 columns (filename, label).")

    table = table.iloc[:, :2]
    table.columns = ["img", "label"]

    # strip header row if present (both cells must look like header text)
    if _header_like(table.iloc[0,0]) and _header_like(table.iloc[0,1]):
        table = table.iloc[1:].reset_index(drop=True)

    table["img"] = table["img"].astype(str).map(lambda raw: os.path.basename(raw.strip()))
    table["label"] = pd.to_numeric(table["label"], errors="coerce").fillna(0).astype(int)

    keep = table["img"].map(_is_img)
    return table[keep].reset_index(drop=True)

def _find_by_basename(dir_path: str, basename: str) -> Optional[str]:
    base, _ = os.path.splitext(basename)
    for ext in [".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff"]:
        p = os.path.join(dir_path, base + ext)
        if os.path.isfile(p):
            return p
    return None

def crosscheck_tb(csv_tb_names: List[str], tb_dir: str) -> List[str]:
    """Compare "tb*" filenames from the CSV with those found under ``tb_dir``.

    Counts and up to five examples of mismatches are logged.

    Returns:
        Sorted list of filenames present in both the CSV and the directory.
    """
    def _is_tb(filename: str) -> bool:
        return filename.lower().startswith("tb")

    csv_tb_set = {n for n in csv_tb_names if _is_tb(n)}
    dir_basenames = (os.path.basename(p) for p in _list_images(tb_dir))
    dir_tb_set = {b for b in dir_basenames if _is_tb(b)}

    inter = sorted(csv_tb_set & dir_tb_set)
    log.info("[TB CHECK] CSV tb count : %d", len(csv_tb_set))
    log.info("[TB CHECK] Dir tb count : %d", len(dir_tb_set))
    log.info("[TB CHECK] Match count  : %d", len(inter))

    only_csv = sorted(csv_tb_set - dir_tb_set)
    only_dir = sorted(dir_tb_set - csv_tb_set)
    if not (only_csv or only_dir):
        log.info("[TB CHECK] Filename sets MATCH ✅")
        return inter

    log.warning("[TB CHECK] Mismatch — only_in_csv=%d, only_in_dir=%d", len(only_csv), len(only_dir))
    if only_csv[:5]:
        log.warning("  e.g., only_in_csv: %s", only_csv[:5])
    if only_dir[:5]:
        log.warning("  e.g., only_in_dir: %s", only_dir[:5])
    return inter

# BBoxes (authors @512) → final 256 via mask-driven crop
def _parse_bbox_field(braw: Any) -> List[Dict[str, float]]:
    """bbox can be list[dict] or dict-as-string; return list of dicts with xmin,ymin,width,height."""
    if braw is None or (isinstance(braw, float) and math.isnan(braw)): return []
    try:
        b = ast.literal_eval(braw) if isinstance(braw, str) else braw
    except Exception:
        return []
    if isinstance(b, dict):  return [b]
    if isinstance(b, list):  return [x for x in b if isinstance(x, dict)]
    return []

def _mask_to_roi(mask_256: np.ndarray, margin_frac: float = ROI_MARGIN_FRAC,
                 enforce_square: bool = ROI_ENFORCE_SQUARE) -> Optional[Tuple[int,int,int,int]]:
    """Compute the bounding ROI ``(x1, y1, x2, y2)`` of the mask foreground.

    The mask is binarized (Otsu when its maximum exceeds 1, otherwise
    ``> 0``), the union of the bounding rectangles of all external contours
    is taken, an optional margin and square padding are applied, and the
    result is clipped to the 256-pixel frame.

    Returns:
        The ROI, or ``None`` for a ``None`` mask, an empty foreground, or a
        degenerate box after clipping.
    """
    if mask_256 is None:
        return None

    if mask_256.max() > 1:
        _, binm = cv2.threshold(mask_256, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)
    else:
        binm = (mask_256 > 0).astype(np.uint8)*255

    cnts, _ = cv2.findContours(binm, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not cnts:
        return None

    # One boundingRect per contour: (x, y, w, h)
    rects = [cv2.boundingRect(c) for c in cnts]
    x1 = min(rx for rx, _, _, _ in rects)
    y1 = min(ry for _, ry, _, _ in rects)
    x2 = max(rx + rw for rx, _, rw, _ in rects)
    y2 = max(ry + rh for _, ry, _, rh in rects)

    # margin, as a fraction of the ROI size on each side
    if margin_frac and margin_frac > 0:
        pad_x = int(round((x2 - x1) * margin_frac))
        pad_y = int(round((y2 - y1) * margin_frac))
        x1, y1, x2, y2 = x1 - pad_x, y1 - pad_y, x2 + pad_x, y2 + pad_y

    # enforce square (optional): re-center a square of the longer side
    if enforce_square:
        side = max(x2 - x1, y2 - y1)
        cx = (x1 + x2) // 2
        cy = (y1 + y2) // 2
        x1 = cx - side//2
        y1 = cy - side//2
        x2 = x1 + side
        y2 = y1 + side

    # clip: start corner to [0,255], end corner to [1,256]
    x1 = max(0, min(255, x1))
    y1 = max(0, min(255, y1))
    x2 = max(1, min(256, x2))
    y2 = max(1, min(256, y2))
    if x2 <= x1 or y2 <= y1:
        return None
    return (x1, y1, x2, y2)

def _scale_512_to_256_box(xmin: float, ymin: float, width: float, height: float,
                          src_w: int = 512, src_h: int = 512) -> Tuple[float,float,float,float]:
    """Return (x1,y1,x2,y2) in pre-crop 256×256 space."""
    sx = 256.0 / float(src_w); sy = 256.0 / float(src_h)
    x1 = (xmin) * sx; y1 = (ymin) * sy
    x2 = (xmin + width) * sx; y2 = (ymin + height) * sy
    return x1, y1, x2, y2

def _map_box_pre256_to_final256_via_roi(box_pre: Tuple[float,float,float,float],
                                        roi: Tuple[int,int,int,int]) -> Optional[Tuple[int,int,int,int]]:
    """Apply crop by ROI then scale to 256 final image (floor/ceil to preserve coverage)."""
    x1, y1, x2, y2 = box_pre
    rx1, ry1, rx2, ry2 = roi
    # intersect with ROI
    ix1 = max(x1, rx1); iy1 = max(y1, ry1)
    ix2 = min(x2, rx2); iy2 = min(y2, ry2)
    if ix2 <= ix1 or iy2 <= iy1:
        return None
    # translate to crop coords
    cx1 = ix1 - rx1; cy1 = iy1 - ry1
    cx2 = ix2 - rx1; cy2 = iy2 - ry1
    rw = max(1, rx2 - rx1); rh = max(1, ry2 - ry1)
    sx = 256.0 / float(rw); sy = 256.0 / float(rh)
    fx1 = math.floor(cx1 * sx); fy1 = math.floor(cy1 * sy)
    fx2 = math.ceil (cx2 * sx); fy2 = math.ceil (cy2 * sy)
    # clip to [0,255], ensure min size
    fx1 = max(0, min(255, fx1)); fy1 = max(0, min(255, fy1))
    fx2 = max(fx1+1, min(256, fx2)); fy2 = max(fy1+1, min(256, fy2))
    return (int(fx1), int(fy1), int(fx2), int(fy2))

def build_gt_box_map_final256(bbox_csv_path: str, masks_dir_256: str,
                              names_final_256: List[str]) -> Dict[str, List[Tuple[int,int,int,int]]]:
    """Map the authors' lesion boxes onto the final 256×256 images.

    For every name in ``names_final_256`` the matching mask gives the crop
    ROI; each author box for that image is scaled to the pre-crop 256×256
    frame and then pushed through the ROI crop+resize mapping.

    Args:
        bbox_csv_path: CSV with columns fname, image_height, image_width,
            bbox, target, image_type (matched case-insensitively).
        masks_dir_256: directory holding one mask per image stem.
        names_final_256: image filenames to process.

    Returns:
        Dict from image name to its list of ``(x1, y1, x2, y2)`` boxes.

    Raises:
        ValueError: a required CSV column is missing.
        FileNotFoundError: no mask file exists for an image.
        RuntimeError: the mask is unreadable or empty, the image has no
            author boxes, or every box falls outside the ROI.
    """
    table = pd.read_csv(bbox_csv_path)
    cols = {c.lower(): c for c in table.columns}
    for k in ("fname","image_height","image_width","bbox","target","image_type"):
        if k not in cols:
            raise ValueError(f"Column '{k}' missing in {bbox_csv_path}")

    # Keep only rows where both the image type and the target are "tb".
    is_tb_image = table[cols["image_type"]].astype(str).str.lower() == "tb"
    is_tb_target = table[cols["target"]].astype(str).str.lower() == "tb"
    tb_rows = table[is_tb_image & is_tb_target].copy()

    def _dimension(value) -> int:
        """Image side from the CSV; missing or non-positive values fall back to 512."""
        size = 512 if pd.isna(value) else int(value)
        return size if size > 0 else 512

    # Collect authors' boxes keyed by lowercase basename (no extension)
    authors_boxes_512: Dict[str, List[Tuple[float,float,float,float,int,int]]] = {}
    for _, r in tb_rows.iterrows():
        name_raw = os.path.basename(str(r[cols["fname"]]).strip())
        key = os.path.splitext(name_raw)[0].lower()
        raw_boxes = _parse_bbox_field(r[cols["bbox"]])
        if not raw_boxes:
            continue
        W = _dimension(r[cols["image_width"]])
        H = _dimension(r[cols["image_height"]])
        bucket = authors_boxes_512.setdefault(key, [])
        for b in raw_boxes:
            try:
                xmin = float(b.get("xmin",0.0))
                ymin = float(b.get("ymin",0.0))
                width = float(b.get("width",0.0))
                height = float(b.get("height",0.0))
            except Exception:
                continue
            if width <= 0 or height <= 0:
                continue
            bucket.append((xmin, ymin, width, height, W, H))

    # Map to final 256×256 using the 256×256 mask ROI
    out: Dict[str, List[Tuple[int,int,int,int]]] = {}
    for name in names_final_256:
        base_key = os.path.splitext(name)[0].lower()

        mask_path = _find_by_basename(masks_dir_256, name)
        if mask_path is None:
            raise FileNotFoundError(f"[GT MAP] Missing mask for {name}")
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise RuntimeError(f"[GT MAP] Unreadable mask for {name}: {mask_path}")
        roi = _mask_to_roi(mask, ROI_MARGIN_FRAC, ROI_ENFORCE_SQUARE)
        if roi is None:
            raise RuntimeError(f"[GT MAP] Empty ROI from mask for {name}")

        auth_list = authors_boxes_512.get(base_key, [])
        if not auth_list:
            # Enforce: every processed TB image must have at least one GT
            raise RuntimeError(f"[GT MAP] No author GT boxes found for TB image {name}")

        # source frame → pre-crop 256 → final 256; boxes outside the ROI map to None
        mapped_boxes = (
            _map_box_pre256_to_final256_via_roi(
                _scale_512_to_256_box(xmin, ymin, width, height, src_w=W, src_h=H), roi
            )
            for (xmin, ymin, width, height, W, H) in auth_list
        )
        final_boxes = [box for box in mapped_boxes if box is not None]
        if not final_boxes:
            raise RuntimeError(f"[GT MAP] All author boxes cropped out for {name} (no intersection with ROI).")
        out[name] = final_boxes
    return out

# Model / CAM helpers
def _ckpt_stem(cfg: "RunConfig") -> str:
    return f"{cfg.network.name}_{cfg.modality.name}"

def print_model_overview_for_cam(model: nn.Module, img_size: int = 224) -> None:
    """Log a short parameter summary of ``model`` at INFO level.

    Args:
        model: module whose parameters are counted.
        img_size: side length reported in the "Input size" line (display only).
    """
    sizes = [(param.numel(), param.requires_grad) for param in model.parameters()]
    total = sum(count for count, _ in sizes)
    train = sum(count for count, trainable in sizes if trainable)

    # Reads the module-level ``device``; it is only reported, not changed.
    report = (
        "Grad-CAM — Model Overview",
        f"Device: {device.type}",
        f"Input size (C,H,W): (3, {img_size}, {img_size})",
        f"Total params: {total:,}",
        f"Trainable params: {train:,}",
        f"Frozen params: {total-train:,}",
    )
    for line in report:
        log.info(line)

def _kernel_tuple(m: nn.Conv2d) -> Tuple[int,int]:
    k = m.kernel_size
    return (k if isinstance(k, tuple) else (k, k))

def _last_conv_kgt1(module: nn.Module) -> Optional[nn.Conv2d]:
    """Return the last ``nn.Conv2d`` in ``module`` whose kernel is larger than 1.

    Falls back to the last conv of any kernel size, and to ``None`` when the
    module holds no ``nn.Conv2d``.
    """
    convs = [layer for layer in module.modules() if isinstance(layer, nn.Conv2d)]
    spatial = [conv for conv in convs if max(_kernel_tuple(conv)) > 1]
    if spatial:
        return spatial[-1]
    return convs[-1] if convs else None

def _vit_patch_embed_conv(module: nn.Module) -> Optional[nn.Conv2d]:
    target = None
    for name, m in module.named_modules():
        if isinstance(m, nn.Conv2d) and ("patch_embed" in name or "patch" in name):
            target = m
    return target

def _qualname_of_module(root: nn.Module, target: nn.Module) -> str:
    for n, m in root.named_modules():
        if m is target:
            return n or "<root>"
    return "<unknown>"

def select_target_layer_for_cam(model: nn.Module) -> nn.Conv2d:
    """Pick the convolution layer that Grad-CAM should hook.

    Preference order: a 3x3 conv in ``model.img_enc.post3x3``; the last conv
    with kernel > 1 in ``model.img_enc``; a patch-embedding conv in
    ``model.img_enc``; the last conv with kernel > 1 anywhere in ``model``.

    Raises:
        RuntimeError: if no ``nn.Conv2d`` candidate is found.
    """
    def _is_3x3_conv(layer) -> bool:
        return isinstance(layer, nn.Conv2d) and _kernel_tuple(layer) == (3,3)

    encoder = getattr(model, "img_enc", None)
    if encoder is not None:
        post = getattr(encoder, "post3x3", None)
        if isinstance(post, nn.Sequential):
            first_3x3 = next((layer for layer in post.modules() if _is_3x3_conv(layer)), None)
            if first_3x3 is not None:
                return first_3x3
        elif _is_3x3_conv(post):
            return post

        # Lambdas keep the second lookup lazy: it only runs if the first fails.
        for lookup in (lambda: _last_conv_kgt1(encoder), lambda: _vit_patch_embed_conv(encoder)):
            layer = lookup()
            if isinstance(layer, nn.Conv2d):
                return layer

    fallback = _last_conv_kgt1(model)
    if isinstance(fallback, nn.Conv2d):
        return fallback
    raise RuntimeError("No suitable Conv2d layer found for Grad-CAM.")

class CamImageWrapper(nn.Module):
    """Adapter whose ``forward`` delegates to the wrapped model's ``forward_image``.

    It lets a CAM object call the model with a single image tensor.
    """

    def __init__(self, mm: nn.Module) -> None:
        super().__init__()
        self.mm = mm

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # image head only
        image_branch = self.mm.forward_image
        return image_branch(x)

def build_preprocess(size: int) -> T.Compose:
    """Build the image transform applied before the CAM forward pass.

    Steps: array to PIL, resize to ``size`` x ``size``, to tensor, then
    per-channel normalization with fixed mean/std.
    """
    normalize = T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
    to_pil = T.ToPILImage()
    resize = T.Resize((size, size))
    to_tensor = T.ToTensor()
    return T.Compose([to_pil, resize, to_tensor, normalize])

def _init_cam(CAMClass, model, target_layers):
    sig = inspect.signature(CAMClass.__init__)
    if "use_cuda" in sig.parameters:
        return CAMClass(model=model, target_layers=target_layers, use_cuda=(device.type == "cuda"))
    return CAMClass(model=model, target_layers=target_layers)

@torch.no_grad()
def _load_ckpt_model(cfg: "RunConfig", ckpt_path: str) -> "MultimodalNet":
    """Build a ``MultimodalNet`` from ``cfg`` and load weights from ``ckpt_path``.

    The model is placed on the module-level ``device`` in eval mode; the
    checkpoint is read onto the CPU and applied with ``strict=True``, so
    missing or unexpected keys raise.
    """
    model = MultimodalNet(cfg).to(device)
    model = model.eval()
    state_dict = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(state_dict, strict=True)
    return model

# Grad-CAM main
def run_gradcam_tbx11k_tb_mask_mapped(
    cfg: "RunConfig",
    ckpt_path: str,
    csv_test_path: str,
    tb_256_dir: str,
    masks_256_dir: str,
    bbox_csv_path: str,
    save_root: str,
    cam_method: str = CAM_METHOD,
    heatmap_alpha: float = HEATMAP_ALPHA,
    bin_thr: float = BIN_THR,
) -> None:
    """Run a CAM method over the TB test images and dump overlays + YOLO labels.

    Pipeline, in order:
      1. Validate that all input paths exist and create the output folders
         ``images``, ``heatmaps``, ``contours``, ``bboxes`` and ``bbox_coord``
         under ``save_root``.
      2. Compare the stems of "tb*"-named images in ``tb_256_dir`` against the
         module-level ``ORIG_TB_DIR``; mismatches are only logged as warnings.
      3. Load the checkpoint, wrap the model in ``CamImageWrapper`` and build
         the requested CAM object on the layer chosen by
         ``select_target_layer_for_cam``.
      4. Determine the set of images to process: rows of the test CSV whose
         ``img`` value starts with "tb" (case-insensitive), filtered through
         ``crosscheck_tb`` against ``tb_256_dir``.
      5. Build the ground-truth box map with ``build_gt_box_map_final256`` and
         write one label file per image to ``bbox_coord`` with lines
         ``cx cy w h`` (normalized by ``FINAL_SIZE``, clamped to [0, 1], no
         class id).
      6. For every image: compute the CAM, min-max normalize it, resize it to
         ``FINAL_SIZE``, and write four PNGs (plain image, heatmap + GT,
         contours + GT, model boxes + GT).
      7. Verify that the label-file stems equal the processed set (strict) and
         compare them with ``ORIG_TB_DIR`` (warning only).
      8. Release CAM hooks and free the model.

    Args:
        cfg: Run configuration; ``img_size`` (default 224 when absent) and
            ``network.name`` are read here, and it is passed to the model
            constructor.
        ckpt_path: Path of the checkpoint file to load.
        csv_test_path: Test-split CSV; must provide an ``img`` column.
        tb_256_dir: Directory holding the images that are actually processed.
        masks_256_dir: Mask directory forwarded to ``build_gt_box_map_final256``.
        bbox_csv_path: Bounding-box CSV forwarded to ``build_gt_box_map_final256``.
        save_root: Root directory for all outputs (created if missing).
        cam_method: Case-insensitive key into the supported CAM classes
            (gradcam, gradcam++, hirescam, xgradcam, layercam, eigencam,
            eigengradcam, scorecam, ablationcam).
        heatmap_alpha: Blend weight of the heatmap in the overlay; the image
            gets ``1 - heatmap_alpha``.
        bin_thr: Threshold in [0, 1] applied (scaled to 0-255) to the CAM mask
            before contour extraction.

    Raises:
        AssertionError: if ``ckpt_path`` is not a file.
        FileNotFoundError: if any of the other input paths does not exist.
        ValueError: if ``cam_method`` is not a supported key.
        RuntimeError: if a processed image has no mapped GT box, or if the
            written label stems differ from the processed set.

    NOTE(review): the cleanup at the end is not inside a ``finally``; when any
    step above raises (including the strict verification), the CAM hooks are
    not released and ``torch.cuda.empty_cache()`` is not called.
    """
    assert os.path.isfile(ckpt_path), f"Checkpoint not found: {ckpt_path}"
    for p in [csv_test_path, tb_256_dir, masks_256_dir, bbox_csv_path]:
        if not os.path.exists(p):
            raise FileNotFoundError(f"Required path missing: {p}")

    # Output structure
    os.makedirs(save_root, exist_ok=True)
    out_dirs = {
        "images": os.path.join(save_root, "images"),
        "heatmaps": os.path.join(save_root, "heatmaps"),
        "contours": os.path.join(save_root, "contours"),
        "bboxes": os.path.join(save_root, "bboxes"),
        "bbox_coord": os.path.join(save_root, "bbox_coord"),  # YOLO label dump
    }
    for d in out_dirs.values(): os.makedirs(d, exist_ok=True)

    # Checks: TB_256_DIR vs ORIG_TB_DIR (WARN instead of raise)
    # Only files whose basename starts with "tb" (case-insensitive) are compared,
    # and the comparison is on extension-less stems.
    tb256_list = [p for p in _list_images(tb_256_dir) if os.path.basename(p).lower().startswith("tb")]
    orig_list = [p for p in _list_images(ORIG_TB_DIR) if os.path.basename(p).lower().startswith("tb")]
    tb256_stems = {os.path.splitext(os.path.basename(p))[0] for p in tb256_list}
    orig_stems = {os.path.splitext(os.path.basename(p))[0] for p in orig_list}

    if len(tb256_stems) == len(orig_stems) and tb256_stems == orig_stems:
        log.info("[DIR CHECK] TB_256_DIR and ORIG_TB_DIR counts & stems MATCH ✅ (%d images)", len(tb256_stems))
    else:
        missing_in_tb256 = sorted(orig_stems - tb256_stems)
        missing_in_orig = sorted(tb256_stems - orig_stems)
        log.warning("[DIR CHECK] Count or stem mismatch: TB_256_DIR=%d, ORIG_TB_DIR=%d",
                    len(tb256_stems), len(orig_stems))
        if missing_in_tb256:
            log.warning("[DIR CHECK] Present in ORIG_TB_DIR but missing in TB_256_DIR (showing up to 10): %s",
                        missing_in_tb256[:10])
        if missing_in_orig:
            log.warning("[DIR CHECK] Present in TB_256_DIR but missing in ORIG_TB_DIR (showing up to 10): %s",
                        missing_in_orig[:10])
        # Proceed with what we can actually process: images that exist in TB_256_DIR (+ listed in CSV below)

    # Model & CAM
    mm = _load_ckpt_model(cfg, ckpt_path)
    print_model_overview_for_cam(mm, img_size=int(getattr(cfg, "img_size", 224)))
    # The CAM object sees only the image branch via the wrapper, but the target
    # layer is looked up on (and named relative to) the unwrapped model.
    cam_model = CamImageWrapper(mm).to(device).eval()
    target_layer = select_target_layer_for_cam(mm)
    # Make sure the hooked layer's parameters take part in autograd.
    for p in target_layer.parameters(): p.requires_grad_(True)
    tl_name = _qualname_of_module(mm, target_layer)
    log.info(f"[CAM] Target layer: {tl_name}  kernel={getattr(target_layer, 'kernel_size', None)}")

    methods = {
        "gradcam": GradCAM, "gradcam++": GradCAMPlusPlus, "hirescam": HiResCAM,
        "xgradcam": XGradCAM, "layercam": LayerCAM, "eigencam": EigenCAM,
        "eigengradcam": EigenGradCAM, "scorecam": ScoreCAM, "ablationcam": AblationCAM
    }
    mkey = cam_method.lower()
    if mkey not in methods:
        raise ValueError(f"Unknown CAM method '{cam_method}'")
    cam = _init_cam(methods[mkey], model=cam_model, target_layers=[target_layer])
    # Best effort: request batch size 1; any failure to set it is ignored.
    try: cam.batch_size = 1
    except Exception: pass
    preprocess = build_preprocess(int(getattr(cfg, "img_size", 224)))

    # TB filenames from CSV & directory cross-check (works off TB_256_DIR presence)
    df_all = _read_test_csv(csv_test_path)
    df_tb  = df_all[df_all["img"].str.lower().str.startswith("tb")].copy()
    csv_tb_names = sorted(df_tb["img"].astype(str).unique().tolist())
    tb_names = crosscheck_tb(csv_tb_names, tb_256_dir)  # intersection(CSV TB, TB_256_DIR TB)
    # tb_names is the exact set we will process; typically 799 in your scenario.
    # expected_stems is what the post-run verification compares the label files against.
    expected_stems = {os.path.splitext(n)[0] for n in tb_names}
    log.info("[PROCESS SET] Will process %d TB images present in TB_256_DIR & CSV.", len(tb_names))

    # Build GT box map in final 256×256 space using mask-driven mapping
    # Enforces: every processed image must produce ≥1 mapped GT box
    # NOTE(review): boxes are presumably integer (x1, y1, x2, y2) pixel tuples,
    # since they are handed to cv2.rectangle below — confirm in the helper.
    gt_boxes_map = build_gt_box_map_final256(bbox_csv_path, masks_256_dir, tb_names)

    # Write YOLO label files (normalized [0,1], no class id)
    yolo_dir = out_dirs["bbox_coord"]
    img_side = float(FINAL_SIZE)  # 256.0
    n_written = 0
    for name in tb_names:
        # Look up by the name as given, then by its basename.
        gt_boxes = gt_boxes_map.get(name, gt_boxes_map.get(os.path.basename(name), []))
        if not gt_boxes:
            # Should not happen due to enforcement above; keep guard.
            raise RuntimeError(f"[YOLO] No mapped GT boxes for {name}, but all processed TB images must have GT.")
        stem = os.path.splitext(os.path.basename(name))[0]
        txt_path = os.path.join(yolo_dir, f"{stem}.txt")
        lines = []
        for (x1, y1, x2, y2) in gt_boxes:
            # Corner box -> center/size, each divided by the image side.
            w = max(0.0, (x2 - x1) / img_side)
            h = max(0.0, (y2 - y1) / img_side)
            cx = ((x1 + x2) / 2.0) / img_side
            cy = ((y1 + y2) / 2.0) / img_side
            # Clamp to [0,1]
            cx = min(1.0, max(0.0, cx))
            cy = min(1.0, max(0.0, cy))
            w = min(1.0, max(0.0, w))
            h = min(1.0, max(0.0, h))
            lines.append(f"{cx:.6f} {cy:.6f} {w:.6f} {h:.6f}")  # NO class id
        with open(txt_path, "w", encoding="utf-8") as f:
            f.write("\n".join(lines) + ("\n" if lines else ""))
        n_written += 1
    log.info("[YOLO] Wrote %d label files to: %s", n_written, yolo_dir)

    # Iterate TB images and draw CAM + GT
    for name in tb_names:
        img_path = _find_by_basename(tb_256_dir, name) or os.path.join(tb_256_dir, name)
        bgr = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if bgr is None:
            # Unreadable images are skipped here although their label file
            # has already been written above.
            log.warning("[CAM] Skipping unreadable: %s", img_path)
            continue
        H, W = bgr.shape[:2]
        if (H, W) != (FINAL_SIZE, FINAL_SIZE):
            bgr = cv2.resize(bgr, (FINAL_SIZE, FINAL_SIZE), interpolation=cv2.INTER_AREA)

        # model input
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        x = preprocess(rgb).unsqueeze(0).to(device)

        # CAM
        # "eigencam" is called without targets or smoothing; every other method
        # targets TARGET_CLASS_IDX with both smoothing options enabled.
        with torch.enable_grad():
            mask = cam(
                input_tensor=x,
                targets=[ClassifierOutputTarget(int(TARGET_CLASS_IDX))],
                aug_smooth=True, eigen_smooth=True
            )[0] if mkey != "eigencam" else cam(input_tensor=x)[0]

        # Normalize + upsample to 256
        # The epsilon keeps the division defined when the mask is constant.
        mmin, mmax = float(np.min(mask)), float(np.max(mask))
        mask = (mask - mmin) / (mmax - mmin + 1e-8)
        mask_u8 = (np.clip(cv2.resize(mask, (FINAL_SIZE, FINAL_SIZE), interpolation=cv2.INTER_NEAREST), 0, 1) * 255).astype(np.uint8)

        # Heatmap overlay
        heat = cv2.applyColorMap(mask_u8, cv2.COLORMAP_HOT)
        heat_overlay = cv2.addWeighted(heat, float(heatmap_alpha), bgr, 1.0 - float(heatmap_alpha), 0.0)

        # Contours & model boxes
        # Threshold is given in [0,1] and scaled to the uint8 mask range.
        thr_val = int(255 * float(bin_thr))
        _, binm = cv2.threshold(mask_u8, thr_val, 255, cv2.THRESH_BINARY)
        contours,_ = cv2.findContours(binm, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        cont_img = bgr.copy()
        if contours: cv2.drawContours(cont_img, contours, -1, CONTOUR_COLOR, CONTOUR_THICK, lineType=cv2.LINE_AA)

        # One axis-aligned bounding box per external contour.
        model_boxes = []
        for c in contours:
            x0, y0, w0, h0 = cv2.boundingRect(c)
            x1, y1, x2, y2 = x0, y0, x0+w0, y0+h0
            model_boxes.append((x1, y1, x2, y2))

        # GT boxes (already mapped to final 256×256)
        gt_boxes = gt_boxes_map.get(name, gt_boxes_map.get(os.path.basename(name), []))

        # Write outputs
        stem = os.path.splitext(os.path.basename(name))[0]
        cv2.imwrite(os.path.join(save_root, "images",   f"{stem}.png"), bgr)

        # heatmap + GT
        heat_ov = heat_overlay.copy()
        for (gx1, gy1, gx2, gy2) in gt_boxes:
            cv2.rectangle(heat_ov, (gx1, gy1), (gx2, gy2), BB_GT_COLOR, BB_THICK, lineType=cv2.LINE_AA)
        cv2.imwrite(os.path.join(save_root, "heatmaps", f"{stem}__{cfg.network.name}__{mkey}.png"), heat_ov)

        # contours + GT
        cont_ov = cont_img.copy()
        for (gx1, gy1, gx2, gy2) in gt_boxes:
            cv2.rectangle(cont_ov, (gx1, gy1), (gx2, gy2), BB_GT_COLOR, BB_THICK, lineType=cv2.LINE_AA)
        cv2.imwrite(os.path.join(save_root, "contours", f"{stem}__{cfg.network.name}__{mkey}.png"), cont_ov)

        # model boxes + GT
        box_ov = bgr.copy()
        for (x1, y1, x2, y2) in model_boxes:
            cv2.rectangle(box_ov, (x1, y1), (x2, y2), BB_MODEL_COLOR, BB_THICK, lineType=cv2.LINE_AA)
        for (gx1, gy1, gx2, gy2) in gt_boxes:
            cv2.rectangle(box_ov, (gx1, gy1), (gx2, gy2), BB_GT_COLOR, BB_THICK, lineType=cv2.LINE_AA)
        cv2.imwrite(os.path.join(save_root, "bboxes",   f"{stem}__{cfg.network.name}__{mkey}.png"), box_ov)

    # Post-run verification of bbox_coord vs image dirs
    # Every .txt file in the folder is counted, so label files left over from
    # an earlier run show up as "extra" in the strict check below.
    yolo_files = sorted([p for p in os.listdir(yolo_dir) if p.lower().endswith(".txt")])
    yolo_stems = {os.path.splitext(f)[0] for f in yolo_files}

    # Strict: bbox_coord must match the processed set (TB_256_DIR ∩ CSV)
    if yolo_stems != expected_stems:
        diff_missing = sorted(expected_stems - yolo_stems)[:10]
        diff_extra = sorted(yolo_stems - expected_stems)[:10]
        raise RuntimeError(
            f"[VERIFY] bbox_coord stems mismatch with processed set (TB_256_DIR ∩ CSV). "
            f"Missing labels (first10): {diff_missing} | Extra labels (first10): {diff_extra}"
        )

    # Lenient: vs ORIG_TB_DIR — only warn (known one missing in ORIG without GT crop)
    if yolo_stems != orig_stems:
        diff_missing_vs_orig = sorted(orig_stems - yolo_stems)[:10]
        diff_extra_vs_orig = sorted(yolo_stems - orig_stems)[:10]
        log.warning(
            "[VERIFY] bbox_coord stems differ from ORIG_TB_DIR (tolerated). "
            "Missing vs ORIG (first10): %s | Extra vs ORIG (first10): %s",
            diff_missing_vs_orig, diff_extra_vs_orig
        )
    else:
        log.info("[VERIFY] bbox_coord stems also match ORIG_TB_DIR ✅")

    log.info("[VERIFY] bbox_coord count & stems MATCH processed TB set ✅ (%d files)", len(yolo_stems))

    # Cleanup
    # Releasing the CAM hooks is best effort; any error while doing so is ignored.
    try:
        if hasattr(cam, "activations_and_grads") and (cam.activations_and_grads is not None):
            cam.activations_and_grads.release()
    except Exception:
        pass
    del cam, cam_model, mm
    if device.type == "cuda":
        torch.cuda.empty_cache()
    log.info("[CAM] Saved to: %s", save_root)

# Runner
# Fail fast if the manually selected checkpoint is missing.
assert os.path.isfile(MANUAL_CKPT_PATH), f"Checkpoint not found: {MANUAL_CKPT_PATH}"
# Outputs are written next to the checkpoint, in a SAVE_ROOT_NAME subfolder.
best_subdir = os.path.dirname(MANUAL_CKPT_PATH)
save_root = os.path.join(best_subdir, SAVE_ROOT_NAME)
# Log the full run configuration before starting.
log.info("================ Grad-CAM (TBX11K TB; mask-mapped GT + YOLO dump) ================")
log.info(f"Checkpoint : {MANUAL_CKPT_PATH}")
log.info(f"Backbone   : {cfg.network.name}")
log.info(f"Modality   : {cfg.modality.name}")
log.info(f"csv_test   : {CSV_TEST_PATH}")
log.info(f"TB dir     : {TB_256_DIR} (final lung-cropped 256×256)")
log.info(f"orig TB dir: {ORIG_TB_DIR} (original 512×512)")
log.info(f"Masks dir  : {TB_MASKS_256_DIR} (pre-crop masks 256×256)")
log.info(f"BBoxes csv : {BBOX_CSV_PATH} (authors' boxes @512×512)")
log.info(f"Save root  : {save_root}")
log.info("==================================================================================")

# Cross-checks and YOLO writing happen inside the function (logged to console)
# NOTE: ORIG_TB_DIR is not an argument; the function reads it as a module-level global.
run_gradcam_tbx11k_tb_mask_mapped(
    cfg=cfg,
    ckpt_path=MANUAL_CKPT_PATH,
    csv_test_path=CSV_TEST_PATH,
    tb_256_dir=TB_256_DIR,
    masks_256_dir=TB_MASKS_256_DIR,
    bbox_csv_path=BBOX_CSV_PATH,
    save_root=save_root,
    cam_method=CAM_METHOD,
    heatmap_alpha=HEATMAP_ALPHA,
    bin_thr=BIN_THR,
)

In [None]:
# === Sanity check: overlay YOLO bbox_coord labels on a random sample of cropped TB images ===

# Config
# NOTE(review): TB_256_DIR, FINAL_SIZE and `log` are rebound here; the Grad-CAM
# code earlier in this file reads the same names as module-level globals, so
# re-running that code after this cell will see these values.
TB_256_DIR = "/dataset/tbx11k/cropped/tb"
BBOX_COORD_DIR = "/gradcam_tbx11k_vgg11_multimodal_external_overlap/bbox_coord"
FINAL_SIZE = 256  # expected image side (H=W=256)

# Logging
# NOTE(review): logging.basicConfig has no effect when the root logger already
# has handlers — if logging was configured earlier in the session, the format
# set here is presumably ignored; verify.
logging.basicConfig(
    level=logging.INFO,
    format="[%(asctime)s] %(levelname)s: %(message)s",
    datefmt="%H:%M:%S",
)
log = logging.getLogger("sanity-yolo-overlay")
# File extensions (lower-case, with leading dot) treated as images by the helpers below.
IMG_EXTS = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff"}

def _is_img(path: str) -> bool:
    """Return True when ``path`` ends in one of ``IMG_EXTS`` (case-insensitive)."""
    _, extension = os.path.splitext(path)
    return extension.lower() in IMG_EXTS

def _list_images(dir_path: str) -> List[str]:
    """Return the sorted paths of all image files under ``dir_path``, recursively.

    A file counts as an image when ``_is_img`` accepts its path. Each result
    is the walked directory joined with the file name.
    """
    candidates = (
        os.path.join(folder, file_name)
        for folder, _subdirs, file_names in os.walk(dir_path)
        for file_name in file_names
    )
    return sorted(filter(_is_img, candidates))

def _stem(path: str) -> str:
    return os.path.splitext(os.path.basename(path))[0]

def _find_by_stem(dir_path: str, stem: str) -> str | None:
    for ext in IMG_EXTS:
        p = os.path.join(dir_path, stem + ext)
        if os.path.isfile(p):
            return p
    return None

def _read_yolo_no_class(txt_path: str) -> List[Tuple[float, float, float, float]]:
    """
    Returns list of (xc, yc, w, h) normalized to [0,1].
    Skips blank/comment lines; robust to extra spaces.
    """
    boxes = []
    with open(txt_path, "r", encoding="utf-8") as f:
        for line in f:
            s = line.strip()
            if not s or s.startswith("#"):
                continue
            parts = s.split()
            if len(parts) < 4:
                continue
            try:
                xc, yc, w, h = map(float, parts[:4])
                # clamp defensively
                xc = min(1.0, max(0.0, xc))
                yc = min(1.0, max(0.0, yc))
                w  = min(1.0, max(0.0, w))
                h  = min(1.0, max(0.0, h))
                boxes.append((xc, yc, w, h))
            except Exception:
                continue
    return boxes

def _yolo_to_xyxy(box: Tuple[float,float,float,float], img_w: int, img_h: int) -> Tuple[int,int,int,int]:
    xc, yc, w, h = box
    bw = w * img_w
    bh = h * img_h
    x1 = int(round((xc * img_w) - bw / 2.0))
    y1 = int(round((yc * img_h) - bh / 2.0))
    x2 = int(round((xc * img_w) + bw / 2.0))
    y2 = int(round((yc * img_h) + bh / 2.0))
    # clip
    x1 = max(0, min(img_w-1, x1))
    y1 = max(0, min(img_h-1, y1))
    x2 = max(0, min(img_w-1, x2))
    y2 = max(0, min(img_h-1, y2))
    if x2 <= x1: x2 = min(img_w-1, x1+1)
    if y2 <= y1: y2 = min(img_h-1, y1+1)
    return x1, y1, x2, y2

def sanity_check_yolo_overlays(
    tb_dir: str = TB_256_DIR,
    yolo_dir: str = BBOX_COORD_DIR,
    n_samples: int = 3,
    seed: int | None = None,
) -> None:
    """Plot a random sample of TB images with their YOLO label boxes overlaid.

    Images are the "tb*"-named files found (recursively) under ``tb_dir``;
    labels are the ``.txt`` files directly inside ``yolo_dir``. The two are
    matched by extension-less stem, mismatches are logged as warnings, and
    ``n_samples`` matched pairs are drawn with matplotlib (red rectangle per
    box, "x" marker at each box center).

    Args:
        tb_dir: Directory with the cropped TB images.
        yolo_dir: Directory with the class-less YOLO label files.
        n_samples: Number of image/label pairs to display.
        seed: When given, the sample is drawn from a private RNG seeded with
            this value (same picks as seeding the global RNG, but without
            changing global random state). When ``None``, the global RNG is used.

    Raises:
        AssertionError: if either directory is missing or fewer than
            ``n_samples`` matched pairs exist.
    """
    assert os.path.isdir(tb_dir), f"Missing TB dir: {tb_dir}"
    assert os.path.isdir(yolo_dir), f"Missing bbox_coord dir: {yolo_dir}"

    # Keep stem -> path maps taken straight from the listings. Re-deriving the
    # path later as "<dir>/<stem><ext>" breaks for images in sub-folders and
    # for upper-case extensions, both of which the listings accept.
    img_by_stem: Dict[str, str] = {}
    for img_file in _list_images(tb_dir):
        if os.path.basename(img_file).lower().startswith("tb"):
            img_by_stem.setdefault(_stem(img_file), img_file)

    txt_by_stem: Dict[str, str] = {}
    for fn in sorted(os.listdir(yolo_dir)):
        if fn.lower().endswith(".txt"):
            txt_by_stem.setdefault(_stem(fn), os.path.join(yolo_dir, fn))

    stems_img = set(img_by_stem)
    stems_yolo = set(txt_by_stem)
    only_in_imgs = sorted(stems_img - stems_yolo)
    only_in_yolo = sorted(stems_yolo - stems_img)
    log.info("Cropped TB images: %d | YOLO label files: %d | intersection: %d",
             len(stems_img), len(stems_yolo), len(stems_img & stems_yolo))
    if only_in_imgs:
        log.warning("There are %d TB images without a YOLO label file (first 5): %s",
                    len(only_in_imgs), only_in_imgs[:5])
    if only_in_yolo:
        log.warning("There are %d YOLO label files without a matching TB image (first 5): %s",
                    len(only_in_yolo), only_in_yolo[:5])
    common_stems = sorted(stems_img & stems_yolo)
    assert len(common_stems) >= n_samples, f"Need at least {n_samples} common items; got {len(common_stems)}."

    # A private generator avoids reseeding the process-wide RNG as a side effect.
    rng = random if seed is None else random.Random(seed)
    picks = rng.sample(common_stems, n_samples)

    # Show overlays
    for stem in picks:
        img_path = img_by_stem[stem]
        txt_path = txt_by_stem[stem]

        bgr = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if bgr is None:
            log.error("Unreadable image: %s", img_path)
            continue
        H, W = bgr.shape[:2]
        if (H, W) != (FINAL_SIZE, FINAL_SIZE):
            log.warning("Image %s is %dx%d (expected %dx%d); resizing for display.",
                        stem, W, H, FINAL_SIZE, FINAL_SIZE)
            bgr = cv2.resize(bgr, (FINAL_SIZE, FINAL_SIZE), interpolation=cv2.INTER_AREA)
            H, W = FINAL_SIZE, FINAL_SIZE

        boxes_norm = _read_yolo_no_class(txt_path)
        log.info("[%s] %d GT box(es) from %s", stem, len(boxes_norm), pathlib.Path(txt_path).name)

        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        # A fresh figure per image: reusing the current axes would stack the
        # images and boxes of earlier picks on backends where show() does not
        # close the figure.
        fig, ax = plt.subplots()
        ax.imshow(rgb)
        ax.set_axis_off()
        ax.set_title(f"{stem} — {len(boxes_norm)} box(es)")

        # draw boxes
        for (xc, yc, w, h) in boxes_norm:
            x1, y1, x2, y2 = _yolo_to_xyxy((xc, yc, w, h), W, H)
            rect = Rectangle((x1, y1), (x2 - x1), (y2 - y1),
                             linewidth=2, edgecolor="r", facecolor="none")
            ax.add_patch(rect)
            # mark center
            ax.plot([xc*W], [yc*H], marker="x", markersize=6)

        fig.tight_layout()
        plt.show()
        plt.close(fig)

# Run the check on 4 sampled images; the fixed seed (42) makes the sample
# reproducible from run to run. Pass seed=None for a different sample each time.
sanity_check_yolo_overlays(n_samples=4, seed=42)


## END OF CODE