# Reverse Image Search — Starter Lab (Beginner‑Friendly)

**Goal:** Build a minimal reverse image search pipeline that:
- Ingests a small local image folder
- Extracts image embeddings (deep features via ResNet50 or a simple color‑histogram fallback)
- Performs **top‑K** nearest‑neighbour search (starter: pure NumPy cosine similarity)
- Visualizes the results in a neat grid

**What *you* will implement later (TODOs in notebook):**
- Replace the baseline NumPy search with **FAISS / Annoy / HNSW** OR your own similarity function
- Store & fetch embeddings from **MongoDB** (schema stub provided)
- (Optional) Try a different embedding model (e.g., CLIP, ViT) and compare quality

## Folder structure

Put some test images in: `./data/images` (the notebook searches recursively). Any of these formats will work: `jpg, jpeg, png, webp, bmp`.

> Tip: 100–500 images is great for testing. More is fine, but the first run will take longer.

In [None]:
# --- Optional: installs (uncomment if needed) ---
# %pip install --quiet pillow numpy pandas pyarrow matplotlib tqdm torchvision torch
# If you plan to do the MongoDB extension:
# %pip install --quiet pymongo
# For FAISS/Annoy extensions (advanced):
# %pip install --quiet faiss-cpu annoy

In [None]:
import os, math, glob, hashlib, pathlib, io
from dataclasses import dataclass
from typing import List, Tuple, Dict

import numpy as np
import pandas as pd

from PIL import Image
import matplotlib.pyplot as plt

# Deep features (optional but recommended)
import torch
from tqdm import tqdm

try:
    import torchvision
    from torchvision import transforms
except Exception as e:
    torchvision = None
    transforms = None
    print("torchvision not available, deep features will be skipped. Fallback: color histogram. Error:", e)

# ---------- Config ----------
# Folder that is scanned (recursively) for images.
DATA_DIR = "./data/images"
# Where the computed embeddings are written as parquet.
EMBEDDINGS_PATH = "./embeddings.parquet"
# Embedding backend: 'resnet50' (deep features) or 'hist' (colour-histogram fallback).
EMBEDDING_MODEL = "resnet50"
# File extensions (with leading dot) treated as images.
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}

# Pick the compute device: CUDA if present, then Apple-Silicon MPS, otherwise CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"
elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
    DEVICE = "mps"  # Apple Silicon

print(f"Using device: {DEVICE}")
print(f"Images dir: {os.path.abspath(DATA_DIR)}")

In [None]:
def list_images(root: str, exts=None) -> List[str]:
    """Recursively collect image files under *root*.

    Extensions are matched case-insensitively, so ``photo.JPG`` is found
    alongside ``photo.jpg`` on case-sensitive filesystems too.

    Args:
        root: Directory to scan; a leading ``~`` is expanded.
        exts: Iterable of extensions including the leading dot
            (e.g. ``{".jpg", ".png"}``). Defaults to ``IMAGE_EXTS``.

    Returns:
        Sorted, de-duplicated list of matching file paths. As with ``glob``'s
        ``**``, hidden files and directories (names starting with ``.``)
        are not visited.
    """
    if exts is None:
        exts = IMAGE_EXTS
    wanted = {e.lower() for e in exts}
    root = os.path.expanduser(root)
    # Escape the root so folder names such as "photos [2024]" are taken
    # literally, then do a single "**/*" pass and filter by extension
    # instead of one glob per extension.
    pattern = os.path.join(glob.escape(root), "**", "*")
    return sorted({
        p
        for p in glob.glob(pattern, recursive=True)
        if os.path.splitext(p)[1].lower() in wanted and os.path.isfile(p)
    })

def sha256_of_file(path: str) -> str:
    """Return the hex SHA-256 digest of the file at *path*.

    The file is streamed in 8 KiB chunks so large images are never held in
    memory all at once.
    """
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        while True:
            block = fh.read(8192)
            if not block:
                break
            digest.update(block)
    return digest.hexdigest()

def load_image(path: str, max_size: int = 512) -> Image.Image:
    """Open an image as RGB and shrink it so neither side exceeds *max_size*.

    Args:
        path: Image file path.
        max_size: Upper bound on width and height. Aspect ratio is kept and
            smaller images are not enlarged (``Image.thumbnail`` semantics).

    Returns:
        A new, fully loaded RGB ``PIL.Image.Image``.
    """
    # Use the context manager so the underlying file handle is released
    # immediately; convert() returns a new loaded image, so closing the
    # source afterwards is safe.
    with Image.open(path) as src:
        img = src.convert("RGB")
    # thumbnail() resizes in place to bound the longest side.
    img.thumbnail((max_size, max_size))
    return img

def show_image_grid(paths: List[str], titles: List[str] = None, cols: int = 5, figsize: Tuple[int,int] = (12, 8)):
    """Display the images at *paths* in a matplotlib grid with *cols* columns.

    Args:
        paths: Image files to show, laid out row by row.
        titles: Optional per-image captions; extra images beyond
            ``len(titles)`` are left untitled.
        cols: Number of grid columns.
        figsize: Matplotlib figure size in inches.

    Images that fail to load are replaced by a short error message in their
    cell instead of aborting the whole grid.
    """
    if not paths:
        print("No images to display.")
        return

    n_rows = math.ceil(len(paths) / cols)
    n_titles = len(titles) if titles else 0

    plt.figure(figsize=figsize)
    for slot, img_path in enumerate(paths, start=1):
        plt.subplot(n_rows, cols, slot)
        try:
            plt.imshow(load_image(img_path, max_size=512))
        except Exception as e:
            plt.text(0.1, 0.5, f"Failed to load\n{os.path.basename(img_path)}\n{e}")
        plt.axis("off")
        if slot - 1 < n_titles:
            plt.title(titles[slot - 1], fontsize=8)
    plt.tight_layout()
    plt.show()

## Embedding extractors

We provide **two** options:

1) **ResNet50 (ImageNet)** deep features — better quality, needs `torch` + `torchvision` and will download weights on first run.
2) **HSV Color Histogram** (8 bins per H/S/V channel, concatenated into a 24‑dim vector) — a very simple baseline that works offline. Good as a teaching aid.

In [None]:
@dataclass
class EmbeddingResult:
    """One embedded image: its file path, content hash and feature vector."""

    path: str  # file path the embedding was computed from
    sha256: str  # hex SHA-256 digest of the file's bytes
    embedding: np.ndarray  # feature vector; both extractors L2-normalise it

def get_resnet50_extractor():
    """Build an ImageNet-pretrained ResNet50 feature extractor.

    Returns:
        ``(model, embed_one)``, where ``model`` is ResNet50 with its
        classification head replaced by ``Identity`` (in eval mode, on
        ``DEVICE``) and ``embed_one(img)`` maps a PIL image to an
        L2-normalised float32 feature vector (2048-dim for ResNet50).
        Returns ``(None, None)`` when torchvision could not be imported.
    """
    if torchvision is None:
        return None, None

    try:
        # Multi-weight API (torchvision >= 0.13): the weights enum also
        # supplies its matching preprocessing pipeline.
        imagenet_weights = torchvision.models.ResNet50_Weights.IMAGENET1K_V2
        backbone = torchvision.models.resnet50(weights=imagenet_weights)
        preprocess = imagenet_weights.transforms()
    except Exception:
        # Legacy API: `pretrained=True` plus the standard ImageNet transform.
        backbone = torchvision.models.resnet50(pretrained=True)
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    # Swap the classifier for an identity so forward() returns the pooled
    # features instead of class logits.
    backbone.fc = torch.nn.Identity()
    backbone.eval()
    backbone.to(DEVICE)

    @torch.no_grad()
    def embed_one(img: Image.Image) -> np.ndarray:
        # Preprocess, add a batch dimension of 1, and move to the model's device.
        batch = preprocess(img).unsqueeze(0).to(DEVICE)
        vec = backbone(batch).cpu().numpy().astype("float32")[0]
        # Unit length, so a plain dot product equals cosine similarity.
        return vec / (np.linalg.norm(vec) + 1e-12)

    return backbone, embed_one

def embed_histogram(img: "Image.Image", bins_per_channel: int = 8) -> np.ndarray:
    """Colour-histogram embedding built from per-channel HSV histograms.

    Each of the H, S and V channels gets its own 1-D histogram with
    ``bins_per_channel`` bins over the 0..255 range. The three histograms are
    concatenated, so the vector has ``3 * bins_per_channel`` entries
    (24 with the default of 8). This is *not* a joint 8x8x8 (512-dim)
    histogram.

    Args:
        img: PIL image in any mode; it is converted to HSV here.
        bins_per_channel: Number of bins per channel.

    Returns:
        float32 vector of length ``3 * bins_per_channel``, L2-normalised.
    """
    arr = np.array(img.convert("HSV"))
    # density=True makes each channel's histogram independent of image size.
    # Bins are equal-width, so after the final L2 normalisation the result is
    # proportional to raw counts.
    parts = [
        np.histogram(arr[..., ch], bins=bins_per_channel, range=(0, 256), density=True)[0]
        for ch in range(3)
    ]
    v = np.concatenate(parts).astype("float32")
    # Unit length, so a dot product equals cosine similarity.
    return v / (np.linalg.norm(v) + 1e-12)

In [None]:
def build_embeddings(paths: List[str], method: str = EMBEDDING_MODEL) -> List[EmbeddingResult]:
    """Embed every image in *paths* with the chosen *method*.

    Args:
        paths: Image files to embed.
        method: ``"resnet50"`` for deep features; any other value uses the
            colour histogram. ``"resnet50"`` falls back to ``"hist"`` when
            torchvision is unavailable.

    Returns:
        One ``EmbeddingResult`` per image that loaded and embedded
        successfully. Failures are printed and skipped (best effort).
    """
    embed_one = None
    if method == "resnet50":
        _, embed_one = get_resnet50_extractor()
        if embed_one is None:
            print("torchvision not available. Falling back to 'hist'.")
            method = "hist"

    use_resnet = method == "resnet50"
    # Deep features tolerate a larger source image; the preprocess step
    # resizes anyway.
    max_side = 1024 if use_resnet else 512

    out = []
    for path in tqdm(paths, desc=f"Embedding ({method})"):
        try:
            image = load_image(path, max_size=max_side)
            vec = embed_one(image) if use_resnet else embed_histogram(image)
            out.append(EmbeddingResult(path=path, sha256=sha256_of_file(path), embedding=vec))
        except Exception as e:
            print("Failed:", path, e)
    return out

## Build & persist embeddings

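This will scan `DATA_DIR`, compute embeddings, and save them to a parquet file (`EMBEDDINGS_PATH`). The notebook always recomputes them; loading the file back with `pd.read_parquet(EMBEDDINGS_PATH)` to skip recomputation on later runs is left to you.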

In [None]:
image_paths = list_images(DATA_DIR)
print(f"Found {len(image_paths)} images.")

embeds = build_embeddings(image_paths, method=EMBEDDING_MODEL)

# Pack into a DataFrame: one row per successfully embedded image.
# Every row records the first vector's length as its dimensionality.
dim = embeds[0].embedding.shape[0] if embeds else 0
df = pd.DataFrame({
    "path": [e.path for e in embeds],
    "sha256": [e.sha256 for e in embeds],
    # Plain Python lists so the column serialises cleanly to parquet.
    "embedding": [e.embedding.tolist() for e in embeds],
    "dim": [dim for _ in embeds],
})
print(df.head())

# Save (parquet needs pyarrow or fastparquet installed).
if len(df):
    df.to_parquet(EMBEDDINGS_PATH, index=False)
    print(f"Saved embeddings -> {EMBEDDINGS_PATH} (n={len(df)})")
else:
    # Report the directory actually scanned, not a hard-coded default.
    print(f"No embeddings to save. Did you place images in {DATA_DIR} ?")

## Baseline search (NumPy cosine similarity)

> **TODO for you later:** Replace this block with FAISS/Annoy/HNSW for speed and try different metrics (cosine/L2/dot). Also, implement batched & cached search.

In [None]:
def cosine_sim(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity of two vectors.

    A tiny epsilon in the denominator makes a zero vector score 0.0 instead of
    raising or returning NaN.
    """
    numerator = np.dot(a, b)
    denominator = np.linalg.norm(a) * np.linalg.norm(b) + 1e-12
    return float(numerator / denominator)

def search_topk(query_vec: np.ndarray, matrix: np.ndarray, k: int = 10) -> List[Tuple[int, float]]:
    """Brute-force cosine-similarity top-K search.

    Args:
        query_vec: Query embedding of length d.
        matrix: Database embeddings, one row per image (n rows of length d).
        k: Number of neighbours wanted. It is clipped to n, and ``k <= 0``
            returns an empty list.

    Returns:
        List of ``(row_index, cosine_score)`` pairs sorted by score, highest
        first. Empty when *matrix* has no rows.
    """
    n = matrix.shape[0]
    k = min(int(k), n)
    if k <= 0:
        return []

    q = np.asarray(query_vec)
    # The epsilon covers the whole denominator (as in cosine_sim), so an
    # all-zero row scores 0 instead of producing NaN via 0/0.
    denom = np.linalg.norm(matrix, axis=1) * np.linalg.norm(q) + 1e-12
    sims = (matrix @ q) / denom

    if k < n:
        # O(n) partial selection: the k best scores land in the first k
        # slots, in no particular order.
        idxs = np.argpartition(-sims, k - 1)[:k]
    else:
        idxs = np.arange(n)
    # Order just the selected candidates, best first.
    idxs = idxs[np.argsort(-sims[idxs])]
    return [(int(i), float(sims[i])) for i in idxs]

# Stack the stored embeddings into one float32 matrix of shape (n, d) for search_topk;
# with no embeddings, fall back to an empty (0, 0) matrix.
if not len(df):
    M = np.zeros((0, 0), dtype="float32")
else:
    M = np.asarray(df["embedding"].tolist(), dtype="float32")  # shape: (n, d)
    print("Embedding matrix:", M.shape)

## Query by example

Pick an image path from your dataset, or set `QUERY_IMAGE` to the path of any image file. The notebook will:
1. Compute its embedding
2. Run a top‑K search
3. Display the nearest neighbours

In [None]:
# Default query: the first image in the dataset (empty string if none were found).
# To query your own file instead, assign e.g. QUERY_IMAGE = "./data/images/myphoto.jpg"
QUERY_IMAGE = image_paths[0] if image_paths else ""

print("Query image:", QUERY_IMAGE)

In [None]:
def embed_query(path: str, method: str = EMBEDDING_MODEL) -> np.ndarray:
    """Embed a single query image the same way ``build_embeddings`` does.

    Args:
        path: Query image file.
        method: ``"resnet50"`` or any other value for the colour histogram.
            ``"resnet50"`` falls back to ``"hist"`` when torchvision is
            unavailable, matching ``build_embeddings``.

    Returns:
        float32 embedding vector.
    """
    embed_one = None
    if method == "resnet50":
        # Build the extractor once and reuse it below: constructing ResNet50
        # (and possibly downloading its weights) is expensive.
        _, embed_one = get_resnet50_extractor()
        if embed_one is None:
            print("torchvision not available. Falling back to 'hist'.")
            method = "hist"

    img = load_image(path, max_size=1024 if method == "resnet50" else 512)
    v = embed_one(img) if method == "resnet50" else embed_histogram(img)
    return v.astype("float32")

# Guard: need an existing query file and a non-empty embedding table.
if not (QUERY_IMAGE and os.path.exists(QUERY_IMAGE) and len(df)):
    print("Set QUERY_IMAGE to an existing file and ensure embeddings were built.")
else:
    q = embed_query(QUERY_IMAGE, method=EMBEDDING_MODEL)
    topk = search_topk(q, M, k=10)
    print("Top-K results (index, cosine):", topk[:5])

    # Map the hit row indices back to file paths, keeping rank order.
    result_paths = df["path"].iloc[[i for i, _ in topk]].tolist()
    result_scores = [f"{s:.3f}" for _, s in topk]

    print("\nQuery preview:")
    show_image_grid([QUERY_IMAGE], titles=["QUERY"], cols=1, figsize=(4,4))

    print("Results:")
    show_image_grid(result_paths, titles=result_scores, cols=5, figsize=(12, 8))

## (Extension) Store & fetch from MongoDB — *you implement this*

Below is a **schema stub** and sample code outline (commented). Fill this in your environment when ready.

**Document shape** (one per image):
```json
{
  "sha256": "…",
  "path": "./data/images/cats/pic123.jpg",
  "dim": 2048,
  "embedding": [0.01, -0.11, ...]   // float list
}
```

In [None]:
# from pymongo import MongoClient, UpdateOne

# def get_mongo_collection(uri: str, db_name: str = "reverse_image_search", coll_name: str = "embeddings"):
#     client = MongoClient(uri)
#     db = client[db_name]
#     return db[coll_name]

# def mongo_upsert_embeddings(coll, df: pd.DataFrame, batch_size: int = 500):
#     # Upsert by sha256 (identical files at different paths share one document; the last path wins).
#     # PyMongo's bulk_write expects operation objects (UpdateOne, InsertOne, ...), not mongo-shell style dicts.
#     ops = []
#     for _, row in df.iterrows():
#         doc = {
#             "sha256": row["sha256"],
#             "path": row["path"],
#             "dim": int(row["dim"]),
#             # Plain Python floats: BSON cannot encode NumPy arrays/scalars (e.g. after pd.read_parquet).
#             "embedding": [float(x) for x in row["embedding"]],
#         }
#         ops.append(UpdateOne({"sha256": row["sha256"]}, {"$set": doc}, upsert=True))
#         if len(ops) >= batch_size:
#             coll.bulk_write(ops)
#             ops = []
#     if ops:
#         coll.bulk_write(ops)

# def mongo_fetch_all_embeddings(coll) -> pd.DataFrame:
#     cur = coll.find({}, {"_id": 0, "sha256": 1, "path": 1, "dim": 1, "embedding": 1})
#     rows = list(cur)
#     return pd.DataFrame(rows)

## (Extension) Replace NumPy search with FAISS / Annoy — *you implement this*

**Hint:** FAISS cosine search can be done by normalizing vectors and using `IndexFlatIP` (inner product).
Annoy's `angular` metric is its cosine‑style distance (it also offers `euclidean`, `manhattan`, `hamming` and `dot`), and it is very easy to start with.

In [None]:
# import faiss
# def build_faiss_index(matrix: np.ndarray):
#     xb = matrix.astype("float32").copy()
#     # Normalize for cosine (so IP == cosine)
#     faiss.normalize_L2(xb)
#     index = faiss.IndexFlatIP(xb.shape[1])
#     index.add(xb)
#     return index
#
# def faiss_topk(index, query_vec: np.ndarray, k: int = 10):
#     q = query_vec.astype("float32")[None, :].copy()
#     faiss.normalize_L2(q)
#     sims, idxs = index.search(q, k)
#     return list(zip(idxs[0].tolist(), sims[0].tolist()))
#
# # Example:
# # index = build_faiss_index(M)
# # faiss_results = faiss_topk(index, q, k=10)

## Exercises & Discussion

1. **Quality check:** Compare results from **ResNet50** vs **color histogram**. Where does histogram fail? Why?
2. **Speed:** Time the baseline search (`search_topk`) vs FAISS/Annoy on 10K images (synthetic if needed).
3. **MongoDB:** Implement the upsert & fetch functions, rebuild the search from DB rows (not local parquet).
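4. **Robustness:** Add a **query-by-cropped-region** function (bonus: detect the object first with a `torchvision.models.detection` model, and use `torchvision.ops` utilities such as NMS to clean up the boxes).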
5. **UI:** Export a small **Flask/FastAPI** endpoint that accepts an uploaded image and returns top‑K paths.
6. **Evaluation:** If you have class labels, compute **precision@k** for a few queries.
7. **CLIP/Vision Transformers (optional):** Swap the encoder and compare.