# Phrase Query → Top-k Panel + CAMs

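Loads the FAISS index and the ID table produced by `00_build_index_dl.ipynb` and lets you query them by phrase. Outputs:
- A 3×3 panel of the top-k results (raw, CAM overlay, optional point/box)
- An optional ~15-second GIF cycling through the top-k overlays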


In [None]:
# --- Imports & setup
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import torch

# Resolve the repository root: one level above the working directory when the
# notebook is launched from a folder named "notebooks", otherwise the working
# directory itself.
if Path.cwd().name == "notebooks":
    repo_root = Path.cwd().resolve().parents[0]
else:
    repo_root = Path.cwd()

# Put <repo_root>/src on the import path so the project imports that follow
# this line can be resolved.
sys.path.append(str(repo_root / "src"))

from pgr import encoders, index, viz, cam
from pgr.utils import seed_everything, get_device
from pgr_dl import io_deeplesion as io, adapters

In [None]:
# --- Config (point to results from 00)
RESULTS_DIR = repo_root / "results" / "kaggle_v1"
INDEX_PATH  = RESULTS_DIR / "index.faiss"
IDS_PATH    = RESULTS_DIR / "ids.parquet"
MODEL_NAME  = "ViT-B/16"
PANEL_ROWS, PANEL_COLS = 3, 3
K = PANEL_ROWS * PANEL_COLS  # one retrieved hit per panel tile
ALPHA = 0.5  # passed as `alpha` to the CAM overlay

# Fail fast with an explicit exception instead of `assert`: asserts are
# stripped when Python runs with -O, and the message now names exactly which
# artifact is missing.
_missing = [str(p) for p in (INDEX_PATH, IDS_PATH) if not p.exists()]
if _missing:
    raise FileNotFoundError(
        f"Run 00_build_index_dl.ipynb first. Missing: {', '.join(_missing)}"
    )
seed_everything(42)

In [None]:
# --- Load artifacts
# ID table written alongside the index; one row per indexed item.
ids_df = pd.read_parquet(IDS_PATH)
# Project wrapper around the on-disk FAISS index.
fa = index.FaissIndex.load(str(INDEX_PATH))
# `get_device(None)` lets the project helper choose the device.
device_name = str(get_device(None))
enc = encoders.ClipEncoder(model_name=MODEL_NAME, device=device_name)

# Quick sanity summary: table shape, index dimensionality, and the device the
# encoder's first parameter actually lives on.
ids_shape = ids_df.shape
index_dim = fa.dim
model_device = next(enc.model.parameters()).device
print(f"IDs: {ids_shape}, Index dim: {index_dim}, Device: {model_device}")


In [None]:
# --- Query phrase → top-k search
phrase = "liver lesion"   # <<< change here

# Text embedding for the phrase, moved to host memory for the index search.
# (Original author's note: (1,D) float32, L2-normed.)
q_vec = adapters.encode_phrase(enc, phrase).cpu().numpy()
scores, I, _ = fa.search(q_vec, k=K)

# Join each returned row index back to its metadata row.
# FAISS pads the result with label -1 when fewer than k neighbours exist;
# `iloc[-1]` would silently return the LAST row of ids_df as a bogus hit, so
# negative indices are skipped. `rank` stays the 1-based position in the
# search result.
hits = [
    {**ids_df.iloc[int(idx)].to_dict(), "rank": rank, "score": float(scores[0, rank - 1])}
    for rank, idx in enumerate(I[0], start=1)
    if int(idx) >= 0
]
hits_df = pd.DataFrame(hits)
hits_df.head(3)

In [None]:
# --- Build 3×3 panel with CAM overlays
from pgr.viz import overlay_cam, draw_point_or_box, grid

# Wrap the image encoder so CAM sees a module whose forward -> image embeddings
import torch.nn as nn
class ImageEmbedder(nn.Module):
    """Expose a model's ``encode_image`` method as a module's ``forward``.

    The wrapped object is stored on ``self.clip``; calling the wrapper simply
    delegates to ``self.clip.encode_image``.
    """

    def __init__(self, clip_model):
        """Store ``clip_model`` (anything providing ``encode_image``)."""
        super().__init__()
        self.clip = clip_model

    def forward(self, x):
        """Return ``self.clip.encode_image(x)`` unchanged."""
        embeddings = self.clip.encode_image(x)
        return embeddings

img_encoder = ImageEmbedder(enc.model).eval()

# --- Loop invariants, computed once instead of once per hit ---------------
from pgr_dl import windowing

# The phrase embedding is the same for every tile. The query cell converts
# the output of this same call straight to NumPy, so it carries no autograd
# graph that a per-tile backward pass could consume.
pvec = adapters.encode_phrase(enc, phrase)     # (1,D) per original note

_BOX_COLS = ("bbox_x1", "bbox_y1", "bbox_x2", "bbox_y2")
has_box_cols = set(_BOX_COLS).issubset(hits_df.columns)

tiles = []
for _, r in hits_df.iterrows():
    # Model input for the CAM pass.
    # (Original author's note: (1,3,H,W) CLIP-normalized tensor.)
    rgb = adapters.prepare_slice(r, size=224)

    # Background for the overlay: load the slice once (it was previously read
    # twice per hit, with the first copy unused) and run it through the
    # project's 3-channel CT windowing when it is 2-D or not already uint8.
    rgb_uint8 = io.load_slice(r.img_path)
    if rgb_uint8.ndim == 2 or rgb_uint8.dtype != np.uint8:
        rgb_uint8 = windowing.ct3ch(rgb_uint8)

    # CAM for this slice against the phrase embedding.
    # (Original author's note: (H,W) map in [0,1].)
    cam_map = cam.gradcam_vit(rgb, pvec, img_encoder)

    # Overlay + optional GT box. All four coordinates must be present:
    # int() on a NaN coordinate would raise.
    over = overlay_cam(rgb_uint8, cam_map, alpha=ALPHA)
    gt_box = None
    if has_box_cols and all(pd.notna(r.get(c)) for c in _BOX_COLS):
        gt_box = tuple(int(r[c]) for c in _BOX_COLS)
    over = draw_point_or_box(over, point=None, box=gt_box)
    tiles.append(over)

# Assemble the tiles into one PANEL_ROWS × PANEL_COLS image and save it under
# a file name derived from the phrase (spaces replaced by underscores).
panel = grid(tiles, rows=PANEL_ROWS, cols=PANEL_COLS)
panel_path = RESULTS_DIR / f"panel_{phrase.replace(' ','_')}.png"
panel.save(panel_path)
panel_path


In [None]:
# --- (Optional) 15-sec GIF
from pgr.viz import save_gif

# Reuse the per-hit overlays built for the panel as animation frames.
# (Original author's note: these are already PIL Images — confirm against
# what draw_point_or_box returns.)
gif_frames = tiles
gif_path = RESULTS_DIR / f"demo_{phrase.replace(' ','_')}.gif"

# Integer frame rate aimed at a ~15 s clip, floored at 1 fps.
# NOTE(review): with fewer than 30 frames this always evaluates to 1, so a
# 9-frame panel plays for about 9 s rather than 15 s. Hitting 15 s would need
# a fractional fps — check whether save_gif accepts a float before changing.
frames_per_second = max(1, int(len(gif_frames) / 15.0))
save_gif(gif_frames, str(gif_path), fps=frames_per_second)
gif_path
