# Retrieval Script (Notebook Version)

This notebook mirrors `scripts/retrieve.py`: given a text query or a query image, it retrieves the most similar images (and their captions) from precomputed embeddings.


In [None]:
# Install the third-party packages into the kernel's environment (`%pip` targets the
# running kernel, unlike `!pip`).
# NOTE(review): versions are unpinned, so a fresh install may pull releases newer than
# the ones the checkpoint was produced with — consider pinning them.
%pip install torch torchvision numpy pillow pyyaml tqdm scikit-learn transformers

# Optional, Colab only: uncomment the two lines below to mount Google Drive if your
# files live there.
# from google.colab import drive
# drive.mount('/content/drive')


In [None]:
import sys
from pathlib import Path
import json
import numpy as np
import torch
import yaml
from PIL import Image
from torchvision import transforms

# Locate the project root so that the `src` package can be imported below.
# The Colab checkout location wins when it exists; otherwise the parent of the current
# working directory is used (this assumes the notebook is started from a sub-directory
# of the project — adjust BASE_DIR if that is not the case).
colab_root = Path('/content/CLIP_model')
BASE_DIR = colab_root if colab_root.exists() else Path.cwd().parent

# Guard the insert so re-running this cell does not stack duplicate sys.path entries.
if str(BASE_DIR) not in sys.path:
    sys.path.insert(0, str(BASE_DIR))

from src.models.clip_model import CLIPModel
from src.utils.tokenization import SimpleTokenizer


In [None]:
# Configuration - adjust these.
# All paths are resolved relative to BASE_DIR (the project root).
CONFIG_PATH = BASE_DIR / "configs/clip_coco_small.yaml"
CHECKPOINT_PATH = BASE_DIR / "checkpoints/best_model.pt"
EMBEDDINGS_DIR = BASE_DIR / "embeddings"
TOP_K = 5  # number of results printed per query

# Query options - set one of these. A search cell is skipped when its value is None.
TEXT_QUERY = "red sports car drifting on wet road"  # Set to None to skip text search
IMAGE_PATH = None  # Set to image path for image search, e.g., "images/val2017/000000000139.jpg"

# Load config. The encoding is given explicitly so the file is read the same way on
# every platform instead of depending on the locale's default encoding.
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# Echo the effective settings so the run is self-describing.
print(f"Config: {CONFIG_PATH}")
print(f"Checkpoint: {CHECKPOINT_PATH}")
print(f"Embeddings: {EMBEDDINGS_DIR}")
print(f"Text query: {TEXT_QUERY}")
print(f"Image path: {IMAGE_PATH}")


In [None]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Read the checkpoint straight onto the selected device.
# NOTE(review): torch.load unpickles the file, so only open checkpoints you trust.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)

# Build the model from the "model" section of the config, then restore its weights
# from the checkpoint's "model_state_dict" entry.
model_cfg = config["model"]
model = CLIPModel(vision_config=model_cfg["vision"], text_config=model_cfg["text"])
model = model.to(device)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()  # switch to evaluation mode before encoding queries
print("Model loaded successfully!")


In [None]:
# Load the precomputed image embeddings and the metadata exported alongside them.
# Raises FileNotFoundError when the embeddings directory is missing and ValueError when
# the metadata does not cover every embedding row.
if not EMBEDDINGS_DIR.exists():
    print(f"ERROR: Embeddings not found at {EMBEDDINGS_DIR}")
    print("Please run 07_export_embeddings.ipynb first!")
    raise FileNotFoundError(f"Embeddings directory not found: {EMBEDDINGS_DIR}")

image_embeddings = np.load(EMBEDDINGS_DIR / "image_embeddings.npy")
with open(EMBEDDINGS_DIR / "metadata.json", "r", encoding="utf-8") as f:
    metadata = json.load(f)

# The retrieval cells look up metadata["image_ids"][i] and metadata["captions"][i] for
# embedding row i, so fail early (with a clear message) if the two files are out of sync
# instead of hitting an IndexError or printing the wrong caption later.
num_embeddings = image_embeddings.shape[0]
if len(metadata["image_ids"]) != num_embeddings:
    raise ValueError(
        f"metadata.json lists {len(metadata['image_ids'])} image ids but "
        f"image_embeddings.npy has {num_embeddings} rows"
    )
if len(metadata["captions"]) < num_embeddings:
    raise ValueError(
        f"metadata.json lists {len(metadata['captions'])} captions but "
        f"image_embeddings.npy has {num_embeddings} rows"
    )

print(f"Loaded {len(metadata['image_ids'])} precomputed embeddings")


In [None]:
# Rebuild the tokenizer vocabulary from the captions stored in the embeddings metadata.
# NOTE(review): text queries are only meaningful if this yields the same token ids the
# model was trained with — presumably the export used the same captions; confirm.
text_cfg = config["model"]["text"]
tokenizer = SimpleTokenizer(vocab_size=text_cfg["vocab_size"], min_freq=2)
tokenizer.build_vocab(metadata["captions"])
vocab_len = len(tokenizer)
print(f"Tokenizer vocabulary size: {vocab_len}")


In [None]:
# Text-to-image retrieval: encode TEXT_QUERY with the model's text encoder and rank the
# precomputed image embeddings by their dot product with the query embedding.
# Skipped when TEXT_QUERY is None (or empty).
if TEXT_QUERY:
    print(f"\nQuery: '{TEXT_QUERY}'")
    token_ids = tokenizer.encode(
        TEXT_QUERY, max_length=config["model"]["text"]["max_seq_length"]
    )
    # Wrap in a list to add a leading batch dimension of size 1.
    token_tensor = torch.tensor([token_ids]).to(device)
    # True wherever the token is the tokenizer's padding id.
    mask = token_tensor == tokenizer.get_pad_token_id()

    with torch.no_grad():
        query_embedding = model.encode_text(token_tensor, mask).cpu().numpy()

    # reshape(-1) instead of squeeze(): squeeze() would collapse the scores to a 0-d
    # array when only one embedding is stored, which breaks the indexing below.
    similarities = np.dot(query_embedding, image_embeddings.T).reshape(-1)
    # Sort descending, then truncate. Slicing with [-TOP_K:] would return every row
    # when TOP_K is 0.
    top_k_indices = np.argsort(similarities)[::-1][:TOP_K]

    # Report the number of results actually available (may be fewer than TOP_K).
    print(f"\nTop {len(top_k_indices)} results:")
    for i, idx in enumerate(top_k_indices, 1):
        print(f"{i}. Image ID: {metadata['image_ids'][idx]}")
        print(f"   Score: {similarities[idx]:.4f}")
        print(f"   Caption: {metadata['captions'][idx]}")
else:
    print("Skipping text-to-image search (TEXT_QUERY is None)")


In [None]:
# Image-query retrieval: encode the image at IMAGE_PATH (resolved against BASE_DIR) and
# rank the precomputed *image* embeddings by their dot product with it. Each hit is
# printed with its stored caption, which is why the skip message calls this
# "image-to-text"; the ranking itself is image-to-image.
# Skipped when IMAGE_PATH is None (or empty).
if IMAGE_PATH:
    image_path = BASE_DIR / IMAGE_PATH
    print(f"\nQuery Image: {image_path}")

    if not image_path.exists():
        print(f"ERROR: Image not found: {image_path}")
    else:
        # Context manager releases the file handle as soon as the pixels are converted.
        with Image.open(image_path) as source_img:
            img = source_img.convert("RGB")

        # Resize to 224x224 and normalise with the commonly used ImageNet channel
        # mean/std.
        # NOTE(review): this must match the preprocessing used when the stored
        # embeddings were exported — confirm against the export notebook/config.
        img_transform = transforms.Compose(
            [
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )
        # unsqueeze(0) adds the batch dimension.
        img_tensor = img_transform(img).unsqueeze(0).to(device)

        with torch.no_grad():
            image_embedding = model.encode_image(img_tensor).cpu().numpy()

        # reshape(-1) instead of squeeze(): squeeze() would collapse the scores to a
        # 0-d array when only one embedding is stored, which breaks the indexing below.
        similarities = np.dot(image_embedding, image_embeddings.T).reshape(-1)
        # Sort descending, then truncate. Slicing with [-TOP_K:] would return every
        # row when TOP_K is 0.
        top_k_indices = np.argsort(similarities)[::-1][:TOP_K]

        # Report the number of results actually available (may be fewer than TOP_K).
        print(f"\nTop {len(top_k_indices)} similar images:")
        for i, idx in enumerate(top_k_indices, 1):
            print(f"{i}. Image ID: {metadata['image_ids'][idx]}")
            print(f"   Score: {similarities[idx]:.4f}")
            print(f"   Caption: {metadata['captions'][idx]}")
else:
    print("Skipping image-to-text search (IMAGE_PATH is None)")
