# Radiology CLIP Mini – Quick Demo

This notebook shows a small end-to-end demo for **radiology-clip-mini**:

1. Load config and a trained checkpoint.
2. Build dataloaders and the CLIPMini model.
3. Compute retrieval metrics (image → text, text → image).
4. Display a retrieval grid for one query.
5. Display a GradCAM-style overlay for one validation image.

> This notebook assumes:
> - You are running it from the `notebooks/` folder in the repo.
> - You have already trained a small model via:
>   `python -m rclip.train --config configs/tiny.yaml`
> - There is a valid checkpoint path in `results/latest.txt`.


In [None]:
import sys
import pathlib

import torch
import yaml
from IPython.display import display
from PIL import Image

# Make the in-repo package importable without installing it.
# The notebook is expected to be launched from `notebooks/`, so the repo root
# is the parent of the current working directory.
ROOT = pathlib.Path.cwd().parent  # repo root (notebooks/ -> root)
SRC = ROOT / "src"
# Guard the insert so that re-running this cell does not keep stacking
# duplicate entries at the front of sys.path.
if str(SRC) not in sys.path:
    sys.path.insert(0, str(SRC))

from rclip.data import build_dataloaders
from rclip.models import CLIPMini
from rclip.eval import recall_at_k, ndcg_at_k  # from your eval module
from rclip.viz import retrieval_grid, gradcam_last_block, overlay_heatmap


In [None]:
# Load the experiment config (the same YAML file named in the training
# command from the notebook introduction).
cfg_path = ROOT / "configs" / "tiny.yaml"
with open(cfg_path, "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

# results/latest.txt is expected to contain the path of the checkpoint to load.
latest_path = (ROOT / "results" / "latest.txt")
if not latest_path.exists():
    raise FileNotFoundError(
        "results/latest.txt not found. "
        "Run `python -m rclip.train --config configs/tiny.yaml` first."
    )

ckpt_entry = latest_path.read_text(encoding="utf-8").strip()
if not ckpt_entry:
    # An empty file would otherwise silently become Path("."), which exists.
    raise FileNotFoundError(
        "results/latest.txt is empty. "
        "Run `python -m rclip.train --config configs/tiny.yaml` first."
    )

ckpt_path = pathlib.Path(ckpt_entry)
# This notebook runs from notebooks/, so a relative entry would be resolved
# against the wrong directory. Prefer the repo-root-relative location when it
# exists; absolute paths and paths valid from the cwd are left untouched.
if not ckpt_path.is_absolute() and (ROOT / ckpt_path).exists():
    ckpt_path = ROOT / ckpt_path
if not ckpt_path.exists():
    # Fail here with a clear message instead of later inside torch.load.
    raise FileNotFoundError(
        f"Checkpoint listed in results/latest.txt does not exist: {ckpt_path}"
    )
ckpt_path


In [None]:
# Prefer the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device:", device)

# Dataloaders are built from the same config that parameterizes the model.
dls = build_dataloaders(cfg)

# Model hyperparameters are read from the "model" section of the config.
model_cfg = cfg["model"]
model = CLIPMini(
    embed_dim=model_cfg["embed_dim"],
    text_model=model_cfg["text_encoder"],
    tau_init=model_cfg["temperature_init"],
)
model = model.to(device)

# Restore the trained weights, which the checkpoint keeps under the "model" key.
state = torch.load(ckpt_path, map_location=device)
model.load_state_dict(state["model"])
model.eval();  # trailing semicolon keeps the module repr out of the output

# Cell output: number of parameters, in millions.
n_params = sum(p.numel() for p in model.parameters())
n_params / 1e6


In [None]:
@torch.no_grad()
def embed_split(model, loader, device):
    """Embed every batch of ``loader`` and return the stacked embeddings.

    Each batch is a mapping with an ``"images"`` tensor and a ``"texts"``
    entry. Texts are tokenized with ``model.text_enc.tokenize`` and the model
    is called as ``model(images, tokens)``; the second and third items of its
    return value are collected as the image and text embeddings.

    Args:
        model: Callable model exposing ``text_enc.tokenize``.
        loader: Iterable of batches as described above.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        ``(image_embeddings, text_embeddings)`` concatenated along dim 0 and
        held on the CPU, or ``(None, None)`` when the loader yields no batch.
    """
    image_chunks = []
    text_chunks = []
    for batch in loader:
        pixels = batch["images"].to(device, non_blocking=True)
        tokens = model.text_enc.tokenize(batch["texts"])
        tokens = tokens.to(device)
        outputs = model(pixels, tokens)
        # The first returned item is not needed here.
        _, image_emb, text_emb = outputs
        # Move results off the device so only one batch lives there at a time.
        image_chunks.append(image_emb.cpu())
        text_chunks.append(text_emb.cpu())

    if len(image_chunks) == 0:
        return None, None

    all_images = torch.cat(image_chunks, dim=0)
    all_texts = torch.cat(text_chunks, dim=0)
    return all_images, all_texts


def _retrieval_metrics(scores):
    """Return the recall@k results for ``scores`` updated with the nDCG@k results."""
    metrics = recall_at_k(scores)
    metrics.update(ndcg_at_k(scores))
    return metrics


zi_val, zt_val = embed_split(model, dls["val"], device)
if zi_val is None:
    raise RuntimeError("Validation loader produced no embeddings")

# Image-by-text similarity: rows follow the image embeddings, columns the
# text embeddings.
sim = torch.matmul(zi_val, zt_val.t())

# Same metrics in both retrieval directions; the transpose swaps the roles
# of query and candidate.
it = _retrieval_metrics(sim)
ti = _retrieval_metrics(sim.t())

print("Image → Text:", it)
print("Text → Image:", ti)


In [None]:
# Index of the validation image used as the query.
q = 0

# Scores of every text against image q, then the 5 best text indices,
# ordered from most to least similar.
row = sim[q]
topk = torch.argsort(row, descending=True)[:5]

# The raw report texts are not kept alongside the embeddings, so walk the
# validation loader once more to collect them.
# NOTE(review): the texts line up with the columns of `sim` only if the loader
# yields samples in the same order on every pass (no shuffling) — confirm.
texts_val = [text for b in dls["val"] for text in b["texts"]]
texts_val = texts_val[: sim.shape[0]]

print("Query image index:", q)
print("\nTop 5 retrieved reports:\n")
ranked = topk.tolist()
for rank, idx in enumerate(ranked, start=1):
    print(f"Rank {rank}:")
    print(texts_val[idx])
    print("-" * 60)


In [None]:
# Visualizations go in a "viz" folder next to the checkpoint file.
viz_dir = ckpt_path.parent.joinpath("viz")
viz_dir.mkdir(parents=True, exist_ok=True)

# retrieval_grid is handed the target path (presumably it renders the grid
# there — confirm in rclip.viz); the PNG is then opened and shown inline.
grid_path = viz_dir.joinpath("retrieval_grid_notebook.png")
retrieval_grid(model, dls["val"], device, grid_path, k=5)

grid_image = Image.open(grid_path)
display(grid_image)


In [None]:
# Example input: the first image of the first validation batch.
batch = next(iter(dls["val"]))
images = batch["images"]
img = images[0]

# Compute the heatmap for this image, then blend it over the image.
heat = gradcam_last_block(model, img, device)
overlay = overlay_heatmap(img, heat)

# Save a copy alongside the other visualizations and show it inline.
overlay_path = viz_dir.joinpath("gradcam_example_notebook.png")
overlay.save(overlay_path)

display(overlay)


In [None]:
# NOTE(review): this cell was a verbatim copy of the previous GradCAM cell
# (same batch, same image, same output file), so it only repeated the
# computation and rewrote an identical PNG. It now just reports where the
# overlay produced by the previous cell was saved.
overlay_path
