In [None]:
!nvidia-smi

In [None]:
# ====================== UMAP (GRID) — Save ONLY IMAGES into model-specific dir ======================
# - Loads image-head from unimodal OR multimodal checkpoint
# - Embeddings taken from the **deepest pooling layer** in img_enc.features (pre-GAP)
# - Two eval pipelines: "mm" (PIL/torchvision) or "uni" (cv2/Albumentations)
# - Grid search over n_neighbors × min_dist
# - Saves figures into checkpoint-derived directory
# ================================================================================================

from __future__ import annotations
import os, random, logging, pickle
from pathlib import Path
from typing import List, Tuple, Dict

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import torchvision
import torchvision.transforms as T
from PIL import Image
import cv2, albumentations as A
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

try:
    import umap   # pip install umap-learn
except ImportError:
    import umap.umap_ as umap

from sklearn.metrics import silhouette_score, davies_bouldin_score
from sklearn.neighbors import KNeighborsClassifier

# --------------------------- Repro & logging --------------------------- #
def set_seed(seed: int = 1337) -> None:
    """Seed every random number generator used here and pin cuDNN behaviour.

    Args:
        seed: Value passed to the Python, NumPy and PyTorch (CPU + all CUDA
            devices) generators.
    """
    # Same order as before: Python, NumPy, torch CPU, torch CUDA.
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for apply_seed in seeders:
        apply_seed(seed)

    # Ask cuDNN for deterministic kernels and switch off its autotuner.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
# Seed all RNGs once, as soon as this cell runs.
set_seed(1337)

# Configure the root logger: INFO level, timestamped "[date time] LEVEL: message" lines.
logging.basicConfig(level=logging.INFO, format="[%(asctime)s] %(levelname)s: %(message)s",
                    datefmt="%Y-%m-%d %H:%M:%S")
# Logger shared by the rest of the notebook.
log = logging.getLogger("umap-save-exact")
# Run on the GPU when CUDA is available, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Last expression of the cell: the notebook displays the current working directory.
os.getcwd()

In [None]:
# User config: every tunable value for this notebook lives in this cell.

# Checkpoint to visualise (swap between the RAW / STRUCTURED / UNIMODAL runs).
# NOTE(review): the default path names a multimodal checkpoint while EXPECT below
# is "unimodal" — confirm the two are meant to be set independently.
CKPT_PATH = Path("/mm_grid1/best_vgg11_multimodal_val_loss.pt") #RAW, STRUCTURED, UNIMODAL
DATASET_ROOT = Path("/dataset")
IMAGES_DIR   = DATASET_ROOT / "images"
CSV_PATH   = DATASET_ROOT / "label_test.csv"  # headerless CSV: filename,label

# Which eval pipeline to use:
#   "mm"  -> PILImageDataset (PIL / torchvision transforms)
#   "uni" -> CV2ImageDataset (cv2 / Albumentations transforms)
PIPELINE = "uni"              # set "uni" for unimodal checkpoints
# Only used to prefix checkpoint-loading error messages and in the data log line.
EXPECT = "unimodal"           # "auto" | "unimodal" | "multimodal"
ARCH_NAME = "vgg11"
ARCH_TAG = ARCH_NAME.lower()  # lower-cased tag embedded in output names
IMG_SIZE = 224                # images are resized to IMG_SIZE x IMG_SIZE
BATCH_SIZE = 64
NUM_WORKERS = 2
PIN_MEMORY = True
NUM_CLASSES = 2

# UMAP grid: every (n_neighbors, min_dist) pair is fitted, scored and plotted.
UMAP_N_NEIGHBORS_GRID = [5, 10, 15, 20, 50, 100, 200]
UMAP_MIN_DIST_GRID = [0.0, 0.05, 0.1, 0.2, 0.25, 0.4, 0.5]
UMAP_METRIC  = "euclidean"
UMAP_RANDOM_STATE = 42
# NOTE(review): TOPK_PLOTS is not referenced anywhere in the visible code.
TOPK_PLOTS = 0

# Plot styling: class index -> legend name and colour.
CLASS_NAMES = ["Normal", "TB"]                 # 0 -> Normal, 1 -> TB
CMAP = ListedColormap(["blue", "red"])         # 0 -> blue, 1 -> red

# Output dir: derived from the checkpoint location.
CKPT_DIR: Path = CKPT_PATH if CKPT_PATH.is_dir() else CKPT_PATH.parent
ckpt_dir_str = CKPT_DIR.as_posix()
# Heuristic: a checkpoint directory whose path contains "models_img_" is treated as unimodal.
is_unimodal = "models_img_" in ckpt_dir_str

if is_unimodal:
    # If the checkpoint already sits inside an output folder (gradcam_/cam_/umap_),
    # step up one level so output folders are not nested.
    base_dir = CKPT_DIR.parent if CKPT_DIR.name.startswith(("gradcam_", "cam_", "umap_")) else CKPT_DIR
    SAVE_DIR = base_dir / f"umap_{ARCH_TAG}_unimodal"
else:
    SAVE_DIR = CKPT_DIR / f"umap_{ARCH_TAG}_multimodal_structured"

# Side effect: creates the output directory (and parents) when this cell runs.
SAVE_DIR.mkdir(parents=True, exist_ok=True)
log.info(f"[UMAP] Output directory: {SAVE_DIR}")

# Filenames for saving all combos & best-by-silhouette.
# The doubled braces leave literal "{nn}" / "{md:.2f}" placeholders that are filled
# in later with str.format(nn=..., md=...).
# NOTE(review): both filenames always contain "unimodal", even when SAVE_DIR is the
# multimodal directory — confirm this is intended.
ALL_PNG_PATTERN = SAVE_DIR / f"umap_{ARCH_TAG}_unimodal_nn{{nn}}_md{{md:.2f}}.png"
BEST_SIL_PNG = SAVE_DIR / f"umap_best_{ARCH_TAG}_unimodal_by_silhouette.png"

In [None]:
# CSV loader

def _is_image_name(name: str) -> bool:
    return Path(name).suffix.lower() in {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff"}

def load_headerless_csv(csv_path: Path) -> pd.DataFrame:
    """Load a headerless ``filename,label`` CSV into a clean two-column frame.

    Only the first two columns are used. File names are stripped of whitespace
    and reduced to their base name, and rows whose name does not look like an
    image file are dropped (this also removes an accidental header row).

    Args:
        csv_path: Path to the CSV file.

    Returns:
        DataFrame with columns ``img`` (str base name) and ``label`` (int),
        re-indexed from 0.

    Raises:
        FileNotFoundError: ``csv_path`` does not exist.
        ValueError: The CSV has fewer than two columns, or an image row carries
            a label that cannot be parsed as a number.
    """
    if not csv_path.is_file():
        raise FileNotFoundError(f"CSV not found: {csv_path}")
    df = pd.read_csv(csv_path, header=None, encoding="utf-8-sig")
    if df.shape[1] < 2:
        raise ValueError(f"CSV must have >=2 columns: filename,label. Got {df.shape}")
    df = df.iloc[:, :2].copy()
    df.columns = ["img", "label"]
    df["img"] = df["img"].astype(str).map(lambda s: os.path.basename(s.strip()))

    # Filter to image rows BEFORE parsing labels, so that non-image rows (such as
    # a stray header line) are not reported as label errors.
    is_image = df["img"].map(_is_image_name).astype(bool)
    df = df[is_image].reset_index(drop=True)

    # Previously unparsable/missing labels were silently turned into class 0,
    # which mislabels samples; fail loudly instead.
    parsed = pd.to_numeric(df["label"], errors="coerce")
    invalid = parsed.isna()
    if invalid.any():
        examples = df.loc[invalid, "img"].head(5).tolist()
        raise ValueError(
            f"{int(invalid.sum())} row(s) in {csv_path} have a missing or non-numeric label "
            f"(first few: {examples})"
        )
    df["label"] = parsed.astype(int)
    return df

# Image head (shared topology)
def _ends_with_3x3(module: nn.Module) -> bool:
    last = None
    for m in module.modules():
        if isinstance(m, nn.Conv2d): last = m
    return (last is not None) and (tuple(getattr(last, "kernel_size",(0,0)))==(3,3))

def _post3x3(ch: int) -> nn.Sequential:
    # If training used this stage, keep it; else set to nn.Identity()
    return nn.Sequential(
        nn.Conv2d(ch, ch, 3, 1, 1, bias=False),
        nn.BatchNorm2d(ch),
        nn.SiLU(inplace=True),
    )

class VGG11BN_ImageEnc(nn.Module):
    """VGG11-BN convolutional encoder followed by global average pooling and dropout.

    Sub-modules (registration order matters for ``state_dict`` keys):
        features: ``torchvision.models.vgg11_bn(weights=None).features``.
        post3x3:  extra Conv3x3-BN-SiLU stage, or ``nn.Identity``.
        gap:      ``AdaptiveAvgPool2d((1, 1))``.
        drop:     ``Dropout(p_drop)``.

    Attributes:
        out_dim: Output feature width, i.e. ``out_channels`` of the last conv
            found in ``features``.
    """

    def __init__(self, p_drop: float = 0.3, use_post3x3: bool = True):
        super().__init__()
        backbone = torchvision.models.vgg11_bn(weights=None)
        self.features = backbone.features

        # Width of the encoder output = out_channels of the last conv layer.
        conv_widths = [
            layer.out_channels
            for layer in self.features.modules()
            if isinstance(layer, nn.Conv2d)
        ]
        if not conv_widths:
            raise RuntimeError("Could not infer channels from VGG11-BN features.")
        ch = conv_widths[-1]

        # The extra stage is only added when requested AND the backbone does not
        # already end in a 3x3 conv.
        # NOTE(review): torchvision's VGG convs are presumably all 3x3, in which
        # case this always resolves to nn.Identity — confirm against training code.
        wants_extra_stage = use_post3x3 and not _ends_with_3x3(self.features)
        if wants_extra_stage:
            self.post3x3 = _post3x3(ch)
        else:
            self.post3x3 = nn.Identity()

        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        self.drop = nn.Dropout(p=p_drop)
        self.out_dim = int(ch)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` and return the pooled, flattened features after dropout."""
        fmap = self.post3x3(self.features(x))
        pooled = self.gap(fmap).flatten(1)
        return self.drop(pooled)

class ImageHead(nn.Module):
    """Image classifier: img_enc → proj → ReLU → mid → cls.

    Used by BOTH the unimodal and the multimodal image paths.

    Args:
        hidden: Width of the ``proj`` and ``mid`` linear layers.
        n_classes: Number of output logits.
        use_post3x3: Forwarded to ``VGG11BN_ImageEnc``.
    """

    def __init__(self, hidden: int = 256, n_classes: int = 2, use_post3x3: bool = True):
        super().__init__()
        self.img_enc = VGG11BN_ImageEnc(p_drop=0.3, use_post3x3=use_post3x3)
        encoder_width = self.img_enc.out_dim
        # `proj` is the local name for checkpoint keys stored as 'img_proj.*'.
        self.proj = nn.Linear(encoder_width, hidden)
        self.act = nn.ReLU(inplace=True)
        self.mid = nn.Linear(hidden, hidden)
        self.cls = nn.Linear(hidden, n_classes)

    def forward_image(self, image: torch.Tensor) -> torch.Tensor:
        """Return class logits for a batch of images."""
        pooled = self.img_enc(image)
        projected = self.proj(pooled)
        hidden_repr = self.mid(self.act(projected))
        return self.cls(hidden_repr)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Alias of :meth:`forward_image`."""
        return self.forward_image(x)

# Checkpoint loader
def _strip_prefix(sd: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    if not sd: return sd
    keys = list(sd.keys())
    prefixes = [pref for k in keys for pref in ("module.","model.","net.") if k.startswith(pref)]
    if not prefixes: return sd
    pref = max(set(prefixes), key=prefixes.count)
    return { (k[len(pref):] if k.startswith(pref) else k): v for k,v in sd.items() }

def _load_sd(ckpt_path: Path) -> Dict[str, torch.Tensor]:
    """Load a checkpoint from disk and return a prefix-stripped state dict.

    Supported checkpoint layouts:
        * a pickled ``nn.Module`` (its ``state_dict()`` is used);
        * a dict wrapping the weights under ``"state_dict"``, ``"model"`` or
          ``"model_state_dict"`` (the wrapped value may be a dict or a module);
        * a flat dict of tensors (non-tensor entries are ignored).

    In every case the result goes through ``_strip_prefix`` (the pickled-module
    path previously skipped that step).

    SECURITY: the file is loaded with ``weights_only=False``, i.e. it is fully
    unpickled and can execute arbitrary code. Only load trusted checkpoints.

    Args:
        ckpt_path: Path to the checkpoint file.

    Returns:
        Mapping of parameter names to tensors.

    Raises:
        RuntimeError: No tensors could be located in the checkpoint.
    """
    try:
        chk = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    except TypeError:
        # Older torch releases do not know the `weights_only` argument.
        chk = torch.load(ckpt_path, map_location="cpu")
    # Any other failure propagates: retrying with pickle_module=pickle would repeat
    # the same call, since that is already the default when weights_only=False.

    sd: Dict[str, torch.Tensor] = {}
    if isinstance(chk, nn.Module):
        sd = chk.state_dict()
    elif isinstance(chk, dict):
        for wrapper_key in ("state_dict", "model", "model_state_dict"):
            inner = chk.get(wrapper_key)
            if isinstance(inner, nn.Module):
                inner = inner.state_dict()
            if isinstance(inner, dict):
                sd = inner
                break
        else:
            # No wrapper key found: treat the top level as the state dict and keep
            # only the tensor entries.
            sd = {k: v for k, v in chk.items() if isinstance(v, torch.Tensor)}

    if not sd:
        raise RuntimeError("No tensor state_dict found in checkpoint.")
    return _strip_prefix(sd)

def _has_image_head_keys(sd: Dict[str, torch.Tensor]) -> bool:
    ks = sd.keys()
    return any(k.startswith(("img_enc.","proj.","img_proj.","mid.","cls.","image_head.")) for k in ks)

def build_image_head_from_ckpt(ckpt_path: Path,
                               n_classes: int = 2,
                               expect: str = "auto",
                               use_post3x3: bool = True) -> nn.Module:
    """Rebuild an ``ImageHead`` and load its weights from a checkpoint.

    Keys are renamed before loading: a leading ``image_head.`` is removed and a
    leading ``img_proj.`` becomes ``proj.``. Weights are loaded with
    ``strict=False``; missing/unexpected keys are logged as warnings.

    Args:
        ckpt_path: Checkpoint file to read.
        n_classes: Number of logits of the rebuilt head.
        expect: Label describing the expected checkpoint kind. It only prefixes
            error messages (no prefix for ``"auto"``).
        use_post3x3: Forwarded to ``ImageHead``.

    Returns:
        The ``ImageHead`` with the checkpoint weights applied.

    Raises:
        RuntimeError: The checkpoint has no image-head keys at all, or it has no
            ``img_enc.features.*`` tensors. In the second case the encoder would
            stay randomly initialised and every embedding taken from it would be
            meaningless.
    """
    def _fail(msg: str) -> RuntimeError:
        return RuntimeError(msg if expect == "auto" else f"[{expect}] {msg}")

    sd = _load_sd(ckpt_path)
    if not _has_image_head_keys(sd):
        raise _fail(f"Checkpoint lacks image-head keys ('img_enc.', 'proj.'/'img_proj.', 'mid.', 'cls.' or 'image_head.'). "
                    f"File: {ckpt_path}")

    # Unwrap common wrapper prefixes, alias img_proj -> proj
    remap = {}
    for k, v in sd.items():
        nk = k
        if nk.startswith("image_head."): nk = nk.replace("image_head.", "", 1)
        if nk.startswith("img_proj."): nk = nk.replace("img_proj.", "proj.", 1)
        remap[nk] = v

    # strict=False would otherwise accept a checkpoint that only shares head keys
    # (e.g. 'cls.*') and leave the whole encoder at its random initialisation.
    if not any(k.startswith("img_enc.features.") for k in remap):
        raise _fail(f"Checkpoint has no 'img_enc.features.*' weights; the image encoder "
                    f"would remain randomly initialised. File: {ckpt_path}")

    model = ImageHead(hidden=256, n_classes=n_classes, use_post3x3=use_post3x3)
    missing, unexpected = model.load_state_dict(remap, strict=False)
    if missing: log.warning(f"[CKPT] Missing keys (first few): {missing[:8]}")
    if unexpected: log.warning(f"[CKPT] Unexpected keys (first few): {unexpected[:8]}")
    log.info(f"[CKPT] Image head reconstructed from {ckpt_path.name} (strict=False).")
    return model

# Datasets
class PILImageDataset(Dataset):
    """Evaluation dataset for the "mm" pipeline (PIL + torchvision transforms).

    Transform order: resize to ``size`` x ``size`` → convert to RGB if needed →
    tensor → ImageNet mean/std normalisation.

    Args:
        images_dir: Directory containing the image files.
        df: Frame with an ``img`` column (file names) and a ``label`` column.
        size: Side length of the resized image.
    """

    def __init__(self, images_dir: Path, df: pd.DataFrame, size: int = 224) -> None:
        super().__init__()
        self.root   = Path(images_dir)
        self.names  = df["img"].tolist()
        self.labels = df["label"].astype(int).tolist()
        self.tf = T.Compose([
            T.Resize((size, size)),
            T.Lambda(lambda x: x.convert("RGB") if x.mode != "RGB" else x),
            T.ToTensor(),
            T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
        ])

    def __len__(self) -> int:
        return len(self.names)

    def __getitem__(self, i: int):
        """Return ``(image_tensor, label, file_name)`` for sample ``i``."""
        name = self.names[i]; y = self.labels[i]
        p = self.root / name
        # Image.open is lazy and keeps the file open; the context manager closes
        # the handle as soon as the tensor has been produced.
        with Image.open(p) as img:
            x = self.tf(img)
        return x, int(y), name

# ImageNet channel statistics used by the Albumentations normalisation step.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD  = (0.229, 0.224, 0.225)

class CV2ImageDataset(Dataset):
    """Evaluation dataset for the "uni" pipeline (cv2 + Albumentations).

    Each image is read as grayscale, expanded to three channels, resized with
    ``cv2.INTER_AREA``, normalised with the ImageNet statistics and converted
    to a tensor.

    Args:
        images_dir: Directory containing the image files.
        df: Frame with an ``img`` column (file names) and a ``label`` column.
        size: Side length of the resized image.
    """

    def __init__(self, images_dir: Path, df: pd.DataFrame, size: int = 224) -> None:
        super().__init__()
        self.root = Path(images_dir)
        self.names = df["img"].tolist()
        self.labels = df["label"].astype(int).tolist()
        pipeline_steps = [
            A.Resize(size, size, interpolation=cv2.INTER_AREA),
            A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
            ToTensorV2(),
        ]
        self.tf = A.Compose(pipeline_steps)

    def __len__(self) -> int:
        return len(self.names)

    def __getitem__(self, i: int):
        """Return ``(image_tensor, label, file_name)`` for sample ``i``."""
        file_name = self.names[i]
        target = self.labels[i]
        image_path = self.root / file_name

        gray = cv2.imread(str(image_path), cv2.IMREAD_GRAYSCALE)
        # cv2.imread signals failure by returning None rather than raising.
        if gray is None:
            raise FileNotFoundError(f"Image not found/unreadable: {image_path}")

        rgb = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
        tensor = self.tf(image=rgb)["image"]
        return tensor, int(target), file_name

# --------------------------- Deepest pooling hook --------------------------- #
def find_deepest_pool(module: nn.Module) -> nn.Module:
    """Return the last ``nn.MaxPool2d`` found inside ``module``.

    "Last" follows the traversal order of ``module.modules()``.

    Raises:
        RuntimeError: ``module`` contains no ``nn.MaxPool2d``.
    """
    pools = [layer for layer in module.modules() if isinstance(layer, nn.MaxPool2d)]
    if pools:
        return pools[-1]
    raise RuntimeError("No MaxPool2d found under img_enc.features.")

class FeatureHook:
    """Forward hook that records a layer's outputs on the CPU.

    Attributes:
        feats: Detached CPU copies of the layer output, one entry per forward call.
        h: Handle returned by ``register_forward_hook``.
    """

    def __init__(self, layer: nn.Module):
        self.feats = []
        self.h = layer.register_forward_hook(self._hook)

    def _hook(self, module, inp, out):
        """Hook callback: store a detached CPU copy of ``out``."""
        captured = out.detach()
        self.feats.append(captured.cpu())

    def close(self):
        """Unregister the hook; already recorded outputs stay in ``feats``."""
        self.h.remove()

def global_avg_pool_4d(x: torch.Tensor) -> torch.Tensor:
    """Average a 4-D feature map over its two trailing axes: (B,C,H,W) -> (B,C)."""
    spatial_axes = (2, 3)
    return torch.mean(x, dim=spatial_axes)

@torch.no_grad()
def collect_embeddings(model: nn.Module, loader: DataLoader) -> Tuple[np.ndarray, np.ndarray]:
    """Run ``model`` over ``loader`` and return pooled deepest-pool features.

    A forward hook on the last ``MaxPool2d`` of ``model.img_enc.features``
    captures the feature maps, which are then averaged over their spatial axes.

    Two changes from the earlier version:
        * the hook is removed in a ``finally`` block, so an exception inside the
          loop no longer leaves it attached to the model (a re-run would then
          record every batch twice);
        * feature maps are pooled batch by batch instead of keeping all
          (N,C,H,W) maps in memory until the end.

    Args:
        model: Module exposing ``img_enc.features``; ``forward_image`` is used
            when present, otherwise the module is called directly.
        loader: Yields ``(images, labels, names)`` batches.

    Returns:
        ``(emb, y)``: ``emb`` has one row per captured sample, ``y`` holds the
        labels as ``int64``.

    Raises:
        AttributeError: ``model`` has no ``img_enc.features``.
        RuntimeError: No pooling layer was found, or the hook captured nothing.
    """
    model.eval().to(DEVICE)
    if not hasattr(model, "img_enc") or not hasattr(model.img_enc, "features"):
        raise AttributeError("Model lacks img_enc.features; not an image head.")
    target_layer = find_deepest_pool(model.img_enc.features)
    run_forward = model.forward_image if hasattr(model, "forward_image") else model

    hk = FeatureHook(target_layer)
    pooled, labels = [], []
    try:
        for x, y, _ in loader:
            x = x.to(DEVICE, non_blocking=True)
            run_forward(x)  # output is discarded; only the hook's capture is needed
            # Reduce (B,C,H,W) -> (B,C) right away and drop the large maps.
            pooled.extend(global_avg_pool_4d(fmap) for fmap in hk.feats)
            hk.feats.clear()
            labels.append(y.numpy())
    finally:
        hk.close()

    if not pooled:
        raise RuntimeError("No features captured; wrong hook/layer?")
    emb = torch.cat(pooled, dim=0).numpy()  # (N,C)
    y = np.concatenate(labels).astype(np.int64)
    return emb, y
    
# --------------------------- UMAP + grid (SAVE ALL @ 400 dpi + BEST-by-silhouette) --------------------------- #
def run_umap(X: np.ndarray, n_neighbors: int, min_dist: float,
             metric: str = UMAP_METRIC, random_state: int = UMAP_RANDOM_STATE) -> np.ndarray:
    """Fit a UMAP reducer on ``X`` and return the embedded coordinates.

    Args:
        X: Feature matrix, one row per sample.
        n_neighbors: UMAP ``n_neighbors`` setting.
        min_dist: UMAP ``min_dist`` setting.
        metric: Distance metric name passed to UMAP.
        random_state: Seed passed to UMAP.
    """
    umap_settings = {
        "n_neighbors": n_neighbors,
        "min_dist": min_dist,
        "metric": metric,
        "random_state": random_state,
    }
    return umap.UMAP(**umap_settings).fit_transform(X)

def score_2d(Z: np.ndarray, labels: np.ndarray) -> Dict[str, float]:
    """Score how well an embedding separates the labelled classes.

    Metrics (each is NaN when it cannot be computed, and all are NaN when fewer
    than two distinct labels are present):
        silhouette: ``sklearn.metrics.silhouette_score``.
        db_index:   ``sklearn.metrics.davies_bouldin_score``.
        knn1_acc:   leave-one-out 1-nearest-neighbour accuracy. Each point is
            classified by its nearest OTHER point. The earlier version fit and
            scored on the same points, so every point was its own nearest
            neighbour and the score was trivially ~1.0.

    Args:
        Z: Embedded coordinates, one row per sample.
        labels: Class label per row of ``Z``.

    Returns:
        Dict with the keys ``silhouette``, ``db_index`` and ``knn1_acc``.
    """
    nan = float("nan")
    m = {"silhouette": nan, "db_index": nan, "knn1_acc": nan}
    labels = np.asarray(labels)
    if len(np.unique(labels)) <= 1:
        return m

    # Best effort: a failing metric is reported as NaN instead of aborting the grid.
    try:
        m["silhouette"] = float(silhouette_score(Z, labels))
    except Exception as exc:
        log.warning(f"[SCORE] silhouette failed: {exc}")
    try:
        m["db_index"] = float(davies_bouldin_score(Z, labels))
    except Exception as exc:
        log.warning(f"[SCORE] davies-bouldin failed: {exc}")
    try:
        knn = KNeighborsClassifier(n_neighbors=1)
        knn.fit(Z, labels)
        # Without a query matrix, kneighbors() returns neighbours of the fitted
        # points and excludes each point from its own neighbour list.
        nearest = knn.kneighbors(n_neighbors=1, return_distance=False)[:, 0]
        m["knn1_acc"] = float(np.mean(labels[nearest] == labels))
    except Exception as exc:
        log.warning(f"[SCORE] leave-one-out 1-NN failed: {exc}")
    return m

# Plot helper: fixed per-class colours (CMAP), saves the PNG to disk and shows the figure inline
def plot_umap_png(Z: np.ndarray, labels: np.ndarray, title: str, out_path: Path,
                  display_dpi: int = 120, save_dpi: int = 400) -> None:
    """Scatter-plot a 2-D embedding by class, save it as an image and display it.

    One scatter series is drawn per entry of ``CLASS_NAMES`` (colour taken from
    ``CMAP``); classes with no samples are skipped. The figure is always closed,
    even when saving or showing fails, so a long grid search does not pile up
    open figures.

    Args:
        Z: Embedded coordinates; columns 0 and 1 are plotted.
        labels: Integer class index per row of ``Z``.
        title: Figure title.
        out_path: Destination image path.
        display_dpi: Resolution of the on-screen figure.
        save_dpi: Resolution of the saved file.
    """
    labels = np.asarray(labels)
    fig, ax = plt.subplots(figsize=(7.5, 6.5), dpi=display_dpi)   # <-- console DPI
    try:
        for lab, name in enumerate(CLASS_NAMES):
            idx = (labels == lab)
            if np.any(idx):
                ax.scatter(Z[idx, 0], Z[idx, 1], s=10, alpha=0.85, label=name, c=[CMAP(lab)])
        ax.legend(frameon=True)
        ax.set_title(title)
        ax.set_xlabel("UMAP-1")
        ax.set_ylabel("UMAP-2")
        ax.grid(alpha=0.15, linestyle="--")
        fig.tight_layout()
        fig.savefig(str(out_path), dpi=save_dpi, bbox_inches="tight")  # <-- file DPI
        plt.show()
    finally:
        plt.close(fig)

# Main
# Fail fast when the configured inputs are missing.
assert IMAGES_DIR.is_dir(), f"Images dir not found: {IMAGES_DIR}"
assert CKPT_PATH.is_file(), f"Checkpoint not found: {CKPT_PATH}"
df = load_headerless_csv(CSV_PATH)
log.info(f"[DATA] rows={len(df)} | PIPELINE={PIPELINE} | EXPECT={EXPECT}")

# Choose the preprocessing pipeline: "mm" -> PIL/torchvision, "uni" -> cv2/Albumentations.
if PIPELINE == "mm":
    ds = PILImageDataset(IMAGES_DIR, df, size=IMG_SIZE)
elif PIPELINE == "uni":
    ds = CV2ImageDataset(IMAGES_DIR, df, size=IMG_SIZE)
else:
    raise ValueError("PIPELINE must be 'mm' or 'uni'.")

# shuffle=False / drop_last=False: every sample is visited once, in CSV order.
dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=False,
                num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, drop_last=False)

model = build_image_head_from_ckpt(CKPT_PATH, n_classes=NUM_CLASSES, expect=EXPECT, use_post3x3=True)

# Print the reconstructed architecture for visual inspection.
print("\n================== MODEL (ImageHead\n")
print(model)  # shows ImageHead(...) with img_enc/proj/mid/cls
print("\n======================================================================\n")

# 1) Collect embeddings (deepest pooling layer)
emb, y = collect_embeddings(model, dl)
log.info(f"[FEATS] emb={emb.shape}, labels={y.shape}")

# 2) Grid search: save ALL combos @400 dpi + track BEST-by-silhouette
best_sil = -float("inf")
best_tuple = None  # (nn, md, Z, met)

for nn_ in UMAP_N_NEIGHBORS_GRID:
    for md_ in UMAP_MIN_DIST_GRID:
        Z   = run_umap(emb, n_neighbors=nn_, min_dist=md_, metric=UMAP_METRIC, random_state=UMAP_RANDOM_STATE)
        met = score_2d(Z, y)
        # Save each combination (400 dpi) and DISPLAY.
        # The output name is the pattern's file name with nn/md filled in.
        combo_png = ALL_PNG_PATTERN.with_name(ALL_PNG_PATTERN.name.format(nn=nn_, md=md_))
        combo_title = (f"UMAP — nn={nn_}, md={md_:.2f} | "
                       f"sil={met.get('silhouette', np.nan):.3f}, "
                       f"db={met.get('db_index', np.nan):.3f}, "
                       f"1NN={met.get('knn1_acc', np.nan):.3f}")
        plot_umap_png(Z, y, title=combo_title, out_path=combo_png, display_dpi=60, save_dpi=400)
        # Update best-by-silhouette.
        # A non-finite silhouette is treated as -1.0, so the first combination is
        # kept when every score is NaN; the comparison is strict, so earlier
        # combinations win ties.
        sil_val = met.get("silhouette", float("nan"))
        sil_val = sil_val if np.isfinite(sil_val) else -1.0
        if sil_val > best_sil:
            best_sil = sil_val
            best_tuple = (nn_, md_, Z, met)

# 3) Save BEST-by-silhouette with separate name (and DISPLAY)
# best_tuple is only None when the grid produced no combination at all.
if best_tuple is None:
    raise RuntimeError("No valid UMAP embedding found to select by silhouette.")

best_nn, best_md, best_Z, best_met = best_tuple
best_title = (f"UMAP — BEST by Silhouette | nn={best_nn}, md={best_md:.2f} | "
              f"sil={best_met.get('silhouette', np.nan):.3f}, "
              f"db={best_met.get('db_index', np.nan):.3f}, "
              f"1NN={best_met.get('knn1_acc', np.nan):.3f}")
plot_umap_png(best_Z, y, title=best_title, out_path=BEST_SIL_PNG, display_dpi=60, save_dpi=400)
log.info(f"[UMAP] Saved BEST-by-silhouette: {BEST_SIL_PNG}")

## END OF CODE