# Build Image Index (DeepLesion Kaggle)

This notebook:
1) Loads config and metadata  
2) Preprocesses CT slices (3-window)  
3) Encodes images with CLIP  
4) Builds and saves a FAISS index (plus the embeddings `.npy`, an ids parquet, and a run manifest)

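Requires the project's `src/` modules (`pgr`, `pgr_dl`). Either install the repo with `pip install -e .` or add its `src/` directory to `PYTHONPATH`; the first code cell also appends `<repo root>/src` to `sys.path` as a fallback.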


In [None]:
# --- Imports & setup
import sys, json
from pathlib import Path

import numpy as np
import pandas as pd
from tqdm import tqdm
import torch

# Add repo root to path if needed.
# When the kernel runs from `notebooks/`, the repo root is the parent directory;
# otherwise the current working directory is taken to be the repo root.
cwd = Path.cwd().resolve()
repo_root = cwd.parent if cwd.name == "notebooks" else cwd

# Only append once so that re-running this cell does not keep growing sys.path.
src_dir = str(repo_root / "src")
if src_dir not in sys.path:
    sys.path.append(src_dir)

from pgr import encoders, index
from pgr.utils import to_tensor_and_norm, get_device, seed_everything
from pgr_dl import io_deeplesion as io, windowing

# --- Config (edit paths here or load from YAML) ---
# All run settings live in this one nested mapping; the cells below read
# from it rather than hard-coding values.
CFG = dict(
    seed=42,
    paths=dict(
        data_root="/data/deeplesion_kaggle",   # <<< set your path
        results_dir=str(repo_root / "results" / "kaggle_v1"),
    ),
    preprocess=dict(
        input_size=224,
        use_ct3ch=True,
        normalize="clip",
    ),
    model=dict(
        image_encoder="ViT-B/16",
        pretrained="openai",
    ),
    index=dict(
        kind="flat",
        metric="ip",
    ),
    data=dict(
        split="test",
        max_samples=5000,          # set None for all
    ),
    batch=64,
)

# Seed through the project helper before any sampling happens
# (which RNGs it covers is defined in pgr.utils — not visible here).
seed_everything(CFG.get("seed", 42))

# Make sure the output directory exists before anything tries to write to it.
res_dir = Path(CFG["paths"]["results_dir"])
res_dir.mkdir(parents=True, exist_ok=True)
print("Results dir:", res_dir)


In [None]:
# --- Load metadata
# Read the full metadata table, optionally narrow it to one split, then
# optionally draw a seeded random subset so quick runs stay repeatable.
meta = io.load_metadata(CFG["paths"]["data_root"])

wanted_split = CFG["data"]["split"]
if wanted_split:
    meta = meta.loc[meta["split"] == wanted_split].copy()

sample_cap = CFG["data"]["max_samples"]
if sample_cap:
    # Never ask for more rows than are available.
    n_keep = min(int(sample_cap), len(meta))
    meta = meta.sample(n_keep, random_state=CFG.get("seed", 42))
    meta = meta.reset_index(drop=True)

print(meta.head(3))
print("Rows:", len(meta))

# Fail early: an empty table would only surface later as a confusing error.
if not len(meta):
    raise RuntimeError("No rows after filtering—check your split/max_samples/settings.")


In [None]:
# --- Encoder init
# Passing None lets the project helper choose the device itself
# (presumably CUDA when available — see pgr.utils.get_device).
device = get_device(None)

model_cfg = CFG["model"]
encoder_name = model_cfg["image_encoder"]
enc = encoders.ClipEncoder(
    model_name=encoder_name,
    pretrained=model_cfg.get("pretrained", "openai"),
    device=str(device),
)

# Embedding width reported by the encoder; later cells validate against it.
D = enc.embed_dim
print(f"Encoder: {encoder_name}  embed_dim={D}  device={device}")


In [None]:
# --- Embed images (batched)
# Walk `meta` in row order, BATCH rows at a time. Each row's slice is loaded,
# optionally converted with windowing.ct3ch, turned into a tensor, and the
# stacked batch is passed through the encoder. Embeddings and their
# (study_id, slice_idx) keys are collected in the same order as `meta`.
BATCH = int(CFG["batch"])
size = int(CFG["preprocess"]["input_size"])
use_ct3ch = bool(CFG["preprocess"]["use_ct3ch"])  # loop-invariant, read once

all_vecs = []
ids = []

# Inference only: disable autograd so no graph is kept alive across batches
# and the encoder output can be converted to NumPy safely.
with torch.no_grad():
    for i in tqdm(range(0, len(meta), BATCH), desc="Embedding"):
        batch_df = meta.iloc[i:i+BATCH]
        xs = []
        # itertuples avoids building a Series per row and keeps column dtypes.
        for r in batch_df.itertuples(index=False):
            img = io.load_slice(r.img_path)
            x3 = windowing.ct3ch(img) if use_ct3ch else img
            # Expected to yield a (1,3,H,W) tensor per the original note — TODO confirm.
            xs.append(to_tensor_and_norm(x3, size=size))
            ids.append((str(r.study_id), int(r.slice_idx)))
        X = torch.cat(xs, dim=0).to(device, non_blocking=True)  # batch along dim 0
        # Expected (B,D), L2-normalised by the encoder per the original note — not verified here.
        V = enc.encode_images(X)
        # detach() guards against an encoder that returns grad-tracking tensors.
        all_vecs.append(V.detach().cpu().numpy())

image_embs = np.vstack(all_vecs).astype("float32")
ids_arr = np.array(ids, dtype=object)
print("Embeddings:", image_embs.shape, "IDs:", ids_arr.shape)

# Sanity checks: one embedding per id, and the width the encoder advertised.
assert image_embs.shape[0] == ids_arr.shape[0], "Row count mismatch"
assert image_embs.shape[1] == D, "Embedding dim mismatch"


In [None]:
# --- Persist embeddings & ids
# Writes the embedding matrix as .npy and a parquet table whose row i
# describes embedding row i.
ids_df = pd.DataFrame(ids_arr, columns=["study_id","slice_idx"])
# ids_arr is an object array, so slice_idx would otherwise be stored as an
# object column; give it an explicit integer dtype for a stable schema.
ids_df["slice_idx"] = ids_df["slice_idx"].astype("int64")

# The extra columns below are copied from `meta` by position. Guard against
# hidden state: if the metadata cell was re-run after embedding, `meta` may no
# longer line up with the ids that were actually embedded.
meta_matches_ids = (
    len(meta) == len(ids_df)
    and bool((meta["study_id"].astype(str).to_numpy() == ids_df["study_id"].to_numpy()).all())
    and bool((meta["slice_idx"].astype("int64").to_numpy() == ids_df["slice_idx"].to_numpy()).all())
)
if not meta_matches_ids:
    raise RuntimeError(
        "`meta` no longer matches the embedded ids; re-run the metadata and embedding cells in order."
    )

np.save(res_dir / "image_embs.npy", image_embs)

# Carry over optional descriptive columns when the metadata provides them.
for col in ["img_path","body_part","lesion_type"]:
    if col in meta.columns:
        ids_df[col] = meta[col].values
ids_df.to_parquet(res_dir / "ids.parquet", index=False)

print("Saved:", res_dir / "image_embs.npy")
print("Saved:", res_dir / "ids.parquet")


In [None]:
# --- Build FAISS index
# Index settings fall back to a flat inner-product index when the config
# has no "index" section or leaves a key out.
index_cfg = CFG.get("index", {})
index_kind = index_cfg.get("kind", "flat")
index_metric = index_cfg.get("metric", "ip")

fa = index.FaissIndex(dim=image_embs.shape[1], kind=index_kind, metric=index_metric)
fa.add(image_embs)

index_path = res_dir / "index.faiss"
fa.save(index_path)
print("Saved:", index_path)


In [None]:
# --- Save a small run manifest
# Records the config, basic counts, and where each artifact was written so
# the run can be identified later.
artifact_files = (
    ("embeddings", "image_embs.npy"),
    ("ids", "ids.parquet"),
    ("index", "index.faiss"),
)
manifest = dict(
    cfg=CFG,
    counts=dict(rows=int(len(meta)), dim=int(image_embs.shape[1])),
    paths={label: str(res_dir / fname) for label, fname in artifact_files},
    seed=int(CFG.get("seed", 42)),
)

manifest_path = res_dir / "manifest.json"
with open(manifest_path, "w", encoding="utf-8") as f:
    # default=str stringifies anything json cannot encode natively.
    json.dump(manifest, f, indent=2, default=str)
print("Saved:", manifest_path)
print("Done.")
