In [None]:
# Show GPU / driver status for this session (IPython `!` shell escape).
!nvidia-smi

## ROC Comparison

The code:

1. loads three checkpoints (unimodal VGG-11; multimodal trained with raw text; multimodal trained with structured text),

2. prints each model architecture,

3. runs image-only inference on the user-selected test set (Shenzhen / Montgomery / TBX11K),

4. computes AUROC and best-MCC (with its threshold),

5. and plots all ROC curves in one figure (each curve is linearly interpolated onto a common FPR grid for smooth lines), using customizable colors (here, red/blue/green).

The code includes loaders that reconstruct the unimodal VGG-11 and the image head of the MultimodalNet, so that the image-branch weights of the multimodal checkpoints can be loaded with `strict=False`.

In [None]:
from __future__ import annotations

import os, math, random, logging
from pathlib import Path
from typing import Dict, List, Tuple, Optional

import numpy as np
import os
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.models as tvm
import torchvision.transforms as T
import torch.nn.functional as F
import torch._utils 

from PIL import Image
from sklearn.metrics import roc_auc_score, roc_curve, matthews_corrcoef
import matplotlib.pyplot as plt
import pickle
import torch._utils  # force-load the private module so torch._weights_only_unpickler can see it
import cv2, albumentations as A
from albumentations.pytorch import ToTensorV2

os.getcwd()

In [None]:
# ROC comparison for VGG11-based unimodal & multimodal models
# - Prints model architectures
# - Evaluates on selected test CSV(s): Shenzhen / Montgomery / TBX11K (headerless "filename,label")
# - Reports AUROC & best-MCC (with its threshold) for each model and plots the ROC curves

# -------------------------- Reproducibility -------------------------- #
def set_deterministic(seed: int = 42) -> None:
    """Seed every RNG the notebook touches and pin cuDNN to deterministic kernels.

    Seeds Python's ``random``, NumPy's global RNG, PyTorch's CPU RNG and all CUDA
    device RNGs with the same value, then disables cuDNN autotuning
    (``benchmark``) and requests deterministic cuDNN algorithms.

    Args:
        seed: value handed to every seeding call.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

# Seed all RNGs once, before any data or model objects are created.
set_deterministic(seed=42)

# Notebook-wide logging: timestamped, level-tagged records at INFO and above.
logging.basicConfig(
    format="[%(asctime)s] %(levelname)s: %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)
log = logging.getLogger(name="roc-compare")

# Use the default CUDA device when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Dataset config

# NOTE: only the Shenzhen test split is wired up here. To evaluate Montgomery or
# TBX11K, point DATASET_ROOT / IMAGES_DIR / CSV_PATHS at that dataset instead.
DATASET_ROOT = Path("/gpfs/gsfs12/users/rajaramans2/projects/omsakthi_multimodal/multimodal_shenzhen/dataset")

# Shenzhen: images in <root>/images, labels in a headerless "filename,label" CSV.
IMAGES_DIR = DATASET_ROOT / "images"
CSV_PATHS = [DATASET_ROOT / "label_test.csv"]
IMG_SIZE = 224          # square side length every image is resized to
BATCH_SIZE = 64
NUM_WORKERS = 2
PIN_MEMORY = True
NUM_CLASSES = 2         # heads emit 2 logits; inference takes column 1 as the positive class

# Model checkpoints.
# NOTE(review): these absolute paths sit directly under "/" and look like redacted
# placeholders — set them to the real checkpoint locations before running.
CKPT_UNIMODAL_VGG11 = Path("/best_vgg11_img_val_loss.pt")
CKPT_MM_RAW = Path("/mm_grid1/best_vgg11_multimodal_val_loss.pt")
CKPT_MM_STRUCT = Path("/mm_grid2/best_vgg11_multimodal_val_loss.pt")

# name -> (ckpt_path, color, pipeline)
# pipeline ∈ {"uni", "mm"} selects which dataset pipeline and model builder are used.
# Checkpoint paths are stored as Path objects, as the annotation declares.
MODEL_STYLES: Dict[str, Tuple[Path, str, str]] = {
    "Unimodal": (CKPT_UNIMODAL_VGG11, "red", "uni"),
    "Multimodal (Raw Text)": (CKPT_MM_RAW, "blue", "mm"),
    "Multimodal (Structured Text)": (CKPT_MM_STRUCT, "green", "mm"),
}
# Font sizes / colors for the ROC figure.
FONT = {"title": 14, "axis": 12, "tick": 10, "legend": 11,
        "legend_color": "black", "axis_color": "black"}

# CSV
def _is_image_name(name: str) -> bool:
    return Path(name).suffix.lower() in {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff"}

def _load_csv_headerless(csv_path: Path) -> pd.DataFrame:
    """Read a headerless ``filename,label`` CSV into a two-column DataFrame.

    Only the first two columns are used. Filenames are whitespace-stripped and reduced
    to their basename. Rows whose filename has no image extension are dropped, which
    includes an accidental header row. Rows whose label is missing or non-numeric are
    dropped with a warning. They are not treated as label 0, which would silently turn
    them into negatives and bias the metrics.

    Args:
        csv_path: path to the CSV file.

    Returns:
        DataFrame with columns ``img`` (basename str) and ``label`` (int), index reset.

    Raises:
        FileNotFoundError: if *csv_path* is not a file.
        ValueError: if the CSV has fewer than two columns.
    """
    csv_path = Path(csv_path)
    if not csv_path.is_file():
        raise FileNotFoundError(f"CSV not found: {csv_path}")
    raw = pd.read_csv(csv_path, header=None, encoding="utf-8-sig")
    if raw.shape[1] < 2:
        raise ValueError(f"CSV must have >=2 columns (filename,label): {csv_path}")
    df = raw.iloc[:, :2].copy()
    df.columns = ["img", "label"]
    df["img"] = df["img"].astype(str).map(lambda s: os.path.basename(s.strip()))
    # Filter by filename first so a header row is dropped before label parsing.
    df = df[df["img"].map(_is_image_name)]
    labels = pd.to_numeric(df["label"], errors="coerce")
    bad = labels.isna()
    if bad.any():
        log.warning(
            f"[DATA] {csv_path.name}: dropping {int(bad.sum())} row(s) with a missing/non-numeric label"
        )
    df = df.loc[~bad].assign(label=labels[~bad].astype(int))
    return df.reset_index(drop=True)

def union_csvs(paths: List[Path]) -> pd.DataFrame:
    """Concatenate several headerless label CSVs into one table.

    Args:
        paths: CSV files to read. Earlier files win when the same image appears twice.

    Returns:
        DataFrame with columns ``img`` and ``label``, one row per image name (first
        occurrence kept), with a fresh RangeIndex.
    """
    combined = pd.concat(
        [_load_csv_headerless(csv) for csv in paths], axis=0, ignore_index=True
    )
    deduped = combined.drop_duplicates(subset=["img"], keep="first")
    return deduped.reset_index(drop=True)

# UNIMODAL PIPELINE
#  - Albumentations + cv2: GRAY→RGB, Resize(INTER_AREA), Normalize(ImageNet)
#  - ImageOnlyHead with forward_image()

# Per-channel RGB mean / std of ImageNet, used to normalise inputs to the VGG backbone.
IMAGENET_MEAN, IMAGENET_STD = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)

class UniCsvImageDataset(Dataset):
    """Test-time dataset for the unimodal model (cv2 + Albumentations pipeline).

    Each image is read as grayscale, replicated to three channels, resized to
    ``size x size`` with area interpolation, ImageNet-normalised and converted to a
    tensor. Items are ``(image_tensor, label, filename)``.
    """

    def __init__(self, images_dir: Path, df: pd.DataFrame, size: int = 224):
        super().__init__()
        self.root = Path(images_dir)
        self.names = df["img"].tolist()
        self.labels = df["label"].astype(int).tolist()
        self.tf = A.Compose([
            A.Resize(size, size, interpolation=cv2.INTER_AREA),
            A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
            ToTensorV2(),
        ])

    def __len__(self) -> int:
        return len(self.names)

    def __getitem__(self, i: int):
        name, label = self.names[i], self.labels[i]
        path = self.root / name
        gray = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
        if gray is None:
            raise FileNotFoundError(f"Image not found or unreadable: {path}")
        rgb = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
        return self.tf(image=rgb)["image"], int(label), name

def _ends_with_3x3(module: nn.Module) -> bool:
    last = None
    for m in module.modules():
        if isinstance(m, nn.Conv2d):
            last = m
    return (last is not None) and (tuple(getattr(last, "kernel_size", (0,0))) == (3,3))

def _post3x3(ch: int) -> nn.Sequential:
    return nn.Sequential(nn.Conv2d(ch, ch, 3, 1, 1, bias=False), nn.BatchNorm2d(ch), nn.SiLU(inplace=True))

class VGG11ImageEncoder(nn.Module):
    """VGG-11-BN convolutional trunk followed by global average pooling and dropout.

    The output is a flat ``(batch, out_dim)`` feature tensor, where ``out_dim`` is the
    channel count of the trunk's last conv layer. ``post3x3`` is a Conv/BN/SiLU block
    appended only when the trunk does not already end in a 3x3 conv; otherwise it is
    an identity. The trunk is randomly initialised (``weights=None``).
    """

    def __init__(self, p_drop: float = 0.3):
        super().__init__()
        self.encoder = tvm.vgg11_bn(weights=None).features
        conv_channels = [m.out_channels for m in self.encoder.modules() if isinstance(m, nn.Conv2d)]
        ch = conv_channels[-1] if conv_channels else None
        if _ends_with_3x3(self.encoder):
            self.post3x3 = nn.Identity()
        else:
            self.post3x3 = _post3x3(ch)
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        self.drop = nn.Dropout(p=p_drop)
        self.out_dim = int(ch)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        feats = self.post3x3(self.encoder(x))
        pooled = torch.flatten(self.gap(feats), 1)
        return self.drop(pooled)

class ImageOnlyHead(nn.Module):
    """Unimodal classifier: VGG-11 encoder -> Linear ``proj`` -> ReLU -> Linear ``mid`` -> Linear ``cls``.

    The attribute names (``img_enc``, ``proj``, ``act``, ``mid``, ``cls``) define the
    state-dict keys, so they must match the unimodal checkpoint. There is no
    activation between ``mid`` and ``cls``.
    """

    def __init__(self, hidden: int = 256, n_classes: int = 2):
        super().__init__()
        self.img_enc = VGG11ImageEncoder(p_drop=0.3)
        in_features = self.img_enc.out_dim
        self.proj = nn.Linear(in_features, hidden)
        self.act = nn.ReLU(inplace=True)
        self.mid = nn.Linear(hidden, hidden)
        self.cls = nn.Linear(hidden, n_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.act(self.proj(self.img_enc(x)))
        return self.cls(self.mid(h))

    def forward_image(self, x: torch.Tensor) -> torch.Tensor:
        """Image-only entry point used by the shared inference loop; same as ``forward``."""
        return self.forward(x)

def _strip_prefix(sd: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    if not sd: return sd
    keys = list(sd.keys())
    prefixes = [pref for k in keys for pref in ("module.","model.","net.") if k.startswith(pref)]
    if not prefixes: return sd
    pref = max(set(prefixes), key=prefixes.count)
    return { (k[len(pref):] if k.startswith(pref) else k): v for k,v in sd.items() }

def _load_checkpoint(ckpt_path: Path):
    try:
        return torch.load(ckpt_path, map_location="cpu", weights_only=False)
    except TypeError:
        return torch.load(ckpt_path, map_location="cpu")
    except Exception as e:
        log.warning(f"[CKPT] load failed ({e}); retry with explicit pickle_module")
        return torch.load(ckpt_path, map_location="cpu", weights_only=False, pickle_module=pickle)

def build_unimodal_from_ckpt(ckpt_path: Path) -> nn.Module:
    """Rebuild the unimodal ``ImageOnlyHead`` and load its weights from *ckpt_path*.

    Accepted checkpoint layouts:
      * a pickled ``nn.Module``, returned as-is (also when stored under ``"state_dict"``/``"model"``);
      * a dict wrapping the weights under ``"state_dict"`` or ``"model"``;
      * a bare state dict.
    A wrapper prefix (``module.``/``model.``/``net.``) is stripped before loading.

    Loading uses ``strict=False`` so extra entries are tolerated. Key mismatches are
    reported: missing and ignored keys are logged. A checkpoint that shares no key with
    the model raises, rather than silently evaluating a randomly initialised network.

    Raises:
        RuntimeError: no state dict found, or no checkpoint key matches the model.
    """
    chk = _load_checkpoint(ckpt_path)
    if isinstance(chk, nn.Module):
        return chk
    if not isinstance(chk, dict):
        raise RuntimeError("No state_dict in unimodal ckpt.")
    sd = chk.get("state_dict", chk.get("model", chk))
    if isinstance(sd, nn.Module):
        # A whole model pickled under "state_dict"/"model": use it directly.
        return sd
    if not isinstance(sd, dict):
        raise RuntimeError("No state_dict in unimodal ckpt.")
    sd = _strip_prefix(sd)

    model = ImageOnlyHead(hidden=256, n_classes=NUM_CLASSES)
    if not set(model.state_dict().keys()) & set(sd.keys()):
        raise RuntimeError(
            f"Unimodal ckpt {ckpt_path} shares no parameter names with ImageOnlyHead; "
            "refusing to evaluate a randomly initialised model."
        )
    res = model.load_state_dict(sd, strict=False)
    if res.missing_keys:
        log.warning(
            f"[CKPT/UNI] {len(res.missing_keys)} model key(s) missing from checkpoint "
            f"(kept at init values), e.g. {res.missing_keys[:5]}"
        )
    if res.unexpected_keys:
        log.info(
            f"[CKPT/UNI] {len(res.unexpected_keys)} checkpoint key(s) ignored, "
            f"e.g. {res.unexpected_keys[:5]}"
        )
    return model

# MULTIMODAL PIPELINE
#  - PIL + torchvision.transforms.Resize → ToTensor → Normalize(ImageNet)
#  - MultimodalImageHead mapping + forward_image()

class MmCsvImageDataset(Dataset):
    """Test-time dataset for the multimodal models (PIL + torchvision pipeline).

    Each image is resized to ``size x size`` (torchvision's default interpolation),
    converted to RGB if it is not already, turned into a tensor and ImageNet-normalised.
    Resizing happens before the RGB conversion. This ordering is intended to mirror the
    multimodal training code's test-time transforms, which are not visible here.
    Items are ``(image_tensor, label, filename)``.
    """

    def __init__(self, images_dir: Path, df: pd.DataFrame, size: int = 224):
        super().__init__()
        self.root, self.names = Path(images_dir), df["img"].tolist()
        self.labels = df["label"].astype(int).tolist()
        self.tf = T.Compose([
            T.Resize((size, size)),
            T.Lambda(lambda x: x.convert("RGB") if x.mode != "RGB" else x),
            T.ToTensor(),
            T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
        ])

    def __len__(self) -> int:
        return len(self.names)

    def __getitem__(self, i: int):
        name = self.names[i]; y = self.labels[i]
        p = self.root / name
        if not p.is_file(): raise FileNotFoundError(f"Image not found: {p}")
        # Context manager closes the file handle deterministically instead of leaving
        # it to garbage collection. The transform output is a new tensor, so it remains
        # valid after the file is closed.
        with Image.open(p) as img:
            x = self.tf(img)
        return x, int(y), name

def _ends_with_3x3_tv(module: nn.Module) -> bool:
    last_conv = None
    for m in module.modules():
        if isinstance(m, nn.Conv2d): last_conv = m
    return (last_conv is not None) and (tuple(getattr(last_conv, "kernel_size",(0,0)))==(3,3))

def _insert_post3x3_if_needed_tv(ch: int) -> nn.Sequential:
    return nn.Sequential(nn.Conv2d(ch, ch, 3, 1, 1, bias=False),
                         nn.BatchNorm2d(ch), nn.SiLU(inplace=True))

class TorchvisionVGGBackbone(nn.Module):
    """Image backbone used to rebuild the multimodal checkpoints.

    Pipeline: torchvision VGG ``features`` trunk -> optional 3x3 Conv/BN/SiLU block
    (only when the trunk does not already end in a 3x3 conv) -> global average pool ->
    dropout. Produces ``(batch, out_dim)`` features, where ``out_dim`` is the last conv's
    channel count. The trunk is randomly initialised (``weights=None``).

    Args:
        which: name of the ``torchvision.models`` constructor to build (default
            ``"vgg11_bn"``). Only models exposing a ``.features`` trunk are supported.
        p_drop: dropout probability applied to the pooled features.

    Raises:
        ValueError: *which* is not a torchvision constructor, or its model has no ``.features``.
    """

    def __init__(self, which: str = "vgg11_bn", p_drop: float = 0.3):
        super().__init__()
        # Honour `which` and avoid shadowing the module-level `tvm` alias.
        builder = getattr(torchvision.models, which, None)
        if not callable(builder):
            raise ValueError(f"Unknown torchvision model constructor: {which!r}")
        m = builder(weights=None)
        if not hasattr(m, "features"):
            raise ValueError(f"torchvision model {which!r} has no '.features' trunk")
        self.encoder = m.features
        ch = None
        for mod in self.encoder.modules():
            if isinstance(mod, nn.Conv2d):
                ch = mod.out_channels  # ends as the channel count of the last conv
        self.post3x3 = _insert_post3x3_if_needed_tv(ch) if not _ends_with_3x3_tv(self.encoder) else nn.Identity()
        self.gap = nn.AdaptiveAvgPool2d((1,1))
        self.drop = nn.Dropout(p=p_drop)
        self.out_dim = int(ch)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        f = self.encoder(x); f = self.post3x3(f); f = self.gap(f).flatten(1); f = self.drop(f)
        return f

class MultimodalImageHead(nn.Module):
    """Image branch of the multimodal network, rebuilt for image-only inference.

    Layout: ``img_enc`` (VGG backbone) -> ``img_proj`` (Linear) -> ReLU -> ``mid``
    (Linear) -> ``cls`` (Linear). The attribute names match the state-dict key prefixes
    the multimodal checkpoints are expected to carry, so their image-side weights load
    with ``strict=False``. NOTE(review): whether ``mid`` has the same input width as in
    the trained fusion model cannot be confirmed from here.
    """

    def __init__(self, hidden: int = 256, n_classes: int = 2):
        super().__init__()
        self.img_enc = TorchvisionVGGBackbone("vgg11_bn", p_drop=0.3)
        self.img_proj = nn.Linear(self.img_enc.out_dim, hidden)
        self.act = nn.ReLU(inplace=True)
        self.mid = nn.Linear(hidden, hidden)
        self.cls = nn.Linear(hidden, n_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        projected = self.act(self.img_proj(self.img_enc(x)))
        return self.cls(self.mid(projected))

    def forward_image(self, image: torch.Tensor) -> torch.Tensor:
        """Image-only entry point used by the shared inference loop; same as ``forward``."""
        return self.forward(image)

def _strip_prefix_mm(sd: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    if not sd: return sd
    keys = list(sd.keys())
    prefixes = [pref for k in keys for pref in ("module.","model.","net.") if k.startswith(pref)]
    if not prefixes: return sd
    pref = max(set(prefixes), key=prefixes.count)
    return { (k[len(pref):] if k.startswith(pref) else k): v for k,v in sd.items() }

def build_multimodal_from_ckpt(ckpt_path: Path) -> nn.Module:
    """Rebuild the image branch of a multimodal checkpoint for image-only inference.

    Resolution order:
      1. a pickled ``nn.Module`` (top level or under ``"model"``) is returned unchanged;
      2. otherwise a state dict is taken from ``"state_dict"``, ``"model"``, the whole dict
         (when every value is a tensor), or the tensor-valued subset of the dict;
      3. after stripping a wrapper prefix, keys starting with ``img_enc.``/``img_proj``/
         ``mid``/``cls`` are loaded into ``MultimodalImageHead``. Anything else falls back
         to a torchvision ``vgg11_bn`` whose last layer has ``NUM_CLASSES`` outputs.

    Loading is ``strict=False``, so checkpoint entries the target model does not own
    are ignored (presumably the text branch). Missing/ignored keys are logged, and a
    checkpoint matching none of the target model's keys raises.

    Raises:
        RuntimeError: no tensor state dict found, or no checkpoint key matches the model.
    """
    ckpt_path = Path(ckpt_path)
    # Shared loader: same torch.load fallbacks as the unimodal path.
    chk = _load_checkpoint(ckpt_path)

    if isinstance(chk, nn.Module):
        log.info(f"[CKPT/MM] Pickled nn.Module loaded: {ckpt_path.name}")
        return chk
    if not isinstance(chk, dict):
        raise RuntimeError("No tensor state_dict in multimodal ckpt.")
    if isinstance(chk.get("model"), nn.Module):
        log.info(f"[CKPT/MM] Pickled nn.Module loaded: {ckpt_path.name}")
        return chk["model"]

    if isinstance(chk.get("state_dict"), dict):
        sd = chk["state_dict"]
    elif isinstance(chk.get("model"), dict):
        sd = chk["model"]
    elif all(isinstance(v, torch.Tensor) for v in chk.values()):
        sd = chk
    else:
        sd = {k: v for k, v in chk.items() if isinstance(v, torch.Tensor)}
        if not sd:
            raise RuntimeError("No tensor state_dict in multimodal ckpt.")

    sd = _strip_prefix_mm(sd)

    def _load_reporting(model: nn.Module, state: Dict[str, torch.Tensor], tag: str) -> nn.Module:
        # Load non-strictly, but refuse a total mismatch and report partial ones.
        if not set(model.state_dict().keys()) & set(state.keys()):
            raise RuntimeError(
                f"[CKPT/MM] {ckpt_path.name}: no checkpoint key matches {tag}; "
                "refusing to evaluate a randomly initialised model."
            )
        res = model.load_state_dict(state, strict=False)
        if res.missing_keys:
            log.warning(
                f"[CKPT/MM] {ckpt_path.name}: {len(res.missing_keys)} {tag} key(s) missing "
                f"(kept at init values), e.g. {res.missing_keys[:5]}"
            )
        if res.unexpected_keys:
            log.info(
                f"[CKPT/MM] {ckpt_path.name}: {len(res.unexpected_keys)} checkpoint key(s) "
                f"not used by {tag}, e.g. {res.unexpected_keys[:5]}"
            )
        return model

    keys = list(sd.keys())
    looks_multimodal = any(k.startswith("img_enc.") for k in keys) or \
                       any(k.startswith("img_proj") for k in keys) or \
                       any(k.startswith("mid") for k in keys) or \
                       any(k.startswith("cls") for k in keys)

    if looks_multimodal:
        return _load_reporting(MultimodalImageHead(hidden=256, n_classes=NUM_CLASSES), sd, "MultimodalImageHead")

    # Fallback (rare): treat as a plain torchvision vgg11_bn classifier.
    base = torchvision.models.vgg11_bn(weights=None)
    base.classifier[-1] = nn.Linear(base.classifier[-1].in_features, NUM_CLASSES)
    return _load_reporting(base, sd, "vgg11_bn")

# Shared inference helpers
@torch.no_grad()
def predict_scores(
    model: nn.Module,
    loader: DataLoader,
    device: Optional[torch.device] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """Run image-only inference and return positive-class probabilities with labels.

    The model is switched to eval mode and moved to *device*. When *device* is None,
    the notebook-wide ``DEVICE`` is used. ``forward_image`` is called when the model
    defines it, otherwise ``model(x)``. A single-logit output goes through a sigmoid.
    Multi-logit outputs go through a softmax, and column 1 is taken as the positive
    class.

    Args:
        model: classifier to evaluate.
        loader: iterable of ``(images, labels, names)`` batches.
        device: target device; ``None`` means ``DEVICE``.

    Returns:
        ``(scores, labels)`` as 1-D float32 / int64 arrays. Both are empty when the
        loader yields no batches, where ``np.concatenate`` would otherwise raise.
    """
    dev = DEVICE if device is None else device
    model.eval().to(dev)
    probs: List[np.ndarray] = []
    labels: List[np.ndarray] = []
    for x, y, _ in loader:
        x = x.to(dev, non_blocking=True)
        logits = model.forward_image(x) if hasattr(model, "forward_image") else model(x)
        logits = logits.float()
        if logits.ndim == 1 or logits.shape[-1] == 1:
            p1 = torch.sigmoid(logits).reshape(-1)
        else:
            p1 = torch.softmax(logits, dim=-1)[:, 1]
        probs.append(p1.cpu().numpy())
        labels.append(np.asarray(y))
    if not probs:
        return np.empty(0, dtype=np.float32), np.empty(0, dtype=np.int64)
    return np.concatenate(probs), np.concatenate(labels).astype(np.int64)

def smooth_for_plot(fpr: np.ndarray, tpr: np.ndarray, n: int = 1200) -> Tuple[np.ndarray, np.ndarray]:
    """Resample an ROC curve onto ``n`` evenly spaced FPR values in [0, 1].

    TPR is linearly interpolated at each grid point. This is for drawing only; AUROC
    is computed from the raw curve.

    Returns:
        ``(fpr_grid, tpr_on_grid)``.
    """
    fpr_grid = np.linspace(0.0, 1.0, num=n)
    tpr_on_grid = np.interp(fpr_grid, fpr, tpr)
    return fpr_grid, tpr_on_grid

# Build two loaders over the same CSV rows: cv2/Albumentations for the unimodal model and
# PIL/torchvision for the multimodal models, each matching its model's own pipeline.
assert IMAGES_DIR.is_dir(), f"Images directory not found: {IMAGES_DIR}"
df_all = union_csvs(CSV_PATHS)
# Fail fast here rather than with an obscure error from an empty DataLoader later.
assert len(df_all) > 0, f"No usable rows (image filename + numeric label) in: {CSV_PATHS}"
log.info(f"[DATA] Using {len(df_all)} images from {len(CSV_PATHS)} CSV file(s).")
log.info(f"[DATA] Label counts: {df_all['label'].value_counts().sort_index().to_dict()}")

uni_ds = UniCsvImageDataset(IMAGES_DIR, df_all, size=IMG_SIZE)   # cv2/Albumentations
mm_ds = MmCsvImageDataset(IMAGES_DIR, df_all, size=IMG_SIZE)     # PIL/torchvision

uni_loader = DataLoader(uni_ds, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, drop_last=False)
mm_loader = DataLoader(mm_ds, batch_size=BATCH_SIZE, shuffle=False,
                       num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, drop_last=False)

# Evaluate models, each with its own pipeline. For every model, report AUROC and the
# best MCC over all score thresholds, together with the threshold that achieves it.

def best_mcc_threshold(y_true: np.ndarray, y_score: np.ndarray) -> Tuple[float, float]:
    """Return ``(best_mcc, threshold)`` maximising MCC for the rule ``score >= threshold``.

    Candidate thresholds are all distinct scores (``roc_curve`` with
    ``drop_intermediate=False``). MCC is computed in closed form at every ROC operating
    point, and the winner is re-scored exactly with ``matthews_corrcoef``. An undefined
    MCC (zero denominator) counts as 0.
    """
    fpr_all, tpr_all, thr_all = roc_curve(y_true, y_score, pos_label=1, drop_intermediate=False)
    n_pos = float(np.sum(y_true == 1))
    n_neg = float(y_true.size) - n_pos
    tp, fp = tpr_all * n_pos, fpr_all * n_neg
    fn, tn = n_pos - tp, n_neg - fp
    denom = np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    with np.errstate(divide="ignore", invalid="ignore"):
        mcc_all = np.where(denom > 0, (tp * tn - fp * fn) / denom, 0.0)
    best = int(np.argmax(mcc_all))
    thr = float(thr_all[best])
    mcc = float(matthews_corrcoef(y_true, (y_score >= thr).astype(np.int64)))
    return mcc, thr


results = []
for name, (ckpt_path_str, color, pipe) in MODEL_STYLES.items():
    ckpt_path = Path(ckpt_path_str)
    assert ckpt_path.is_file(), f"Checkpoint not found: {ckpt_path}"
    log.info(f"\n===== {name} | pipeline={pipe} =====")
    if pipe == "uni":
        model = build_unimodal_from_ckpt(ckpt_path)
        loader = uni_loader
    elif pipe == "mm":
        model = build_multimodal_from_ckpt(ckpt_path)
        loader = mm_loader
    else:
        raise ValueError(f"Unknown pipeline tag: {pipe}")

    # Print architecture
    print(f"\n================ Architecture: {name} ================\n{model}\n=====================================================\n")
    y_score, y_true = predict_scores(model, loader)
    present = np.unique(y_true)
    if present.size < 2:
        raise ValueError(f"{name}: AUROC/MCC need both classes in the test labels; found only {present.tolist()}")
    fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=1)
    auroc = roc_auc_score(y_true, y_score)
    best_mcc, best_thr = best_mcc_threshold(y_true, y_score)
    log.info(f"[EVAL] {name}: AUROC={auroc:.4f} | best MCC={best_mcc:.4f} @ threshold={best_thr:.4f}")
    fpr_s, tpr_s = smooth_for_plot(fpr, tpr, n=1200)
    results.append({"name": name, "color": color, "fpr": fpr_s, "tpr": tpr_s,
                    "auroc": float(auroc), "mcc": best_mcc, "mcc_thr": best_thr})

# One-row-per-model summary of the scalar metrics.
summary = pd.DataFrame([{k: r[k] for k in ("name", "auroc", "mcc", "mcc_thr")} for r in results])
print(summary.to_string(index=False, float_format=lambda v: f"{v:.4f}"))

# Plot all (interpolated) ROC curves on one axis, with the chance diagonal for reference.
fig, ax = plt.subplots(figsize=(8, 7), dpi=120)
for r in results:
    ax.plot(r["fpr"], r["tpr"], color=r["color"], lw=2.2,
            label=f"{r['name']}  (AUROC={r['auroc']:.4f})")
ax.plot([0, 1], [0, 1], color="gray", lw=1.0, linestyle="--", alpha=0.6)

# FONT["title"] was configured but never used; the figure now carries a title.
ax.set_title("ROC Comparison", fontsize=FONT["title"], color=FONT["axis_color"])
ax.set_xlabel("False Positive Rate", fontsize=FONT["axis"], color=FONT["axis_color"])
ax.set_ylabel("True Positive Rate", fontsize=FONT["axis"], color=FONT["axis_color"])
ax.tick_params(axis="both", labelsize=FONT["tick"], labelcolor=FONT["axis_color"])
leg = ax.legend(loc="lower right", fontsize=FONT["legend"])
for txt in leg.get_texts():
    txt.set_color(FONT["legend_color"])
ax.grid(alpha=0.2, linestyle="--")
fig.tight_layout()

out_png = DATASET_ROOT / "roc_compare.png"
fig.savefig(str(out_png), dpi=400, bbox_inches="tight")
log.info(f"[ROC] Saved: {out_png}")
plt.show()

## END OF CODE