# Part 1: Data Preparation & Embedding

This notebook loads the provided dataset (images + captions), selects a pre-trained multimodal model (CLIP), generates vector embeddings for both text and images in batches, and saves the results to the `embeddings/` folder.

- Dataset paths are preconfigured for this repository layout.
- Uses Hugging Face Transformers `CLIPModel` and `AutoProcessor`.
- Outputs: `image_embeddings.pkl`, `text_embeddings.pkl`, `image_text_pairs.pkl`, `metadata.json`.

Run the cells top-to-bottom.


In [None]:
# Setup & Imports
import os
import json
import pickle
from pathlib import Path
from typing import List, Dict, Tuple

import torch
from PIL import Image
from tqdm.auto import tqdm

# Hugging Face
from transformers import AutoProcessor, CLIPModel

# Repository-relative paths
def find_repo_root(start: Path) -> Path:
    """Locate the directory that contains ``data/captions.txt``.

    Walks from *start* upwards through its ancestors and returns the first
    directory holding ``data/captions.txt``. This works whether the notebook
    is launched from the repository root or from a sub-folder of it.

    If no ancestor contains the captions file, falls back to the parent of
    *start* when *start* is a folder named ``notebooks``, otherwise to *start*
    itself.

    Args:
        start: Directory to begin the search from (typically the cwd).

    Returns:
        The resolved directory to use as the repository root.
    """
    start = start.resolve()
    for candidate in (start, *start.parents):
        if (candidate / "data" / "captions.txt").is_file():
            return candidate
    # NOTE: the previous expression used Path("..").resolve().parent here, which
    # is the *grandparent* of a "notebooks" cwd; one level up is the repo root.
    return start.parent if start.name == "notebooks" else start


REPO_ROOT = find_repo_root(Path.cwd())
DATA_DIR = REPO_ROOT / "data"
IMAGES_DIR = DATA_DIR / "Images"
CAPTIONS_PATH = DATA_DIR / "captions.txt"
EMBED_DIR = REPO_ROOT / "embeddings"
# Created eagerly so later cells can write into it without further checks
EMBED_DIR.mkdir(parents=True, exist_ok=True)

# Device selection: prefer CUDA, then Apple MPS, then CPU
DEVICE = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# Reproducibility
torch.manual_seed(42)

# Batch config
IMAGE_BATCH_SIZE = 64
TEXT_BATCH_SIZE = 256

# Model config (can be changed to other CLIP variants)
MODEL_NAME = "openai/clip-vit-base-patch32"



In [None]:
# Load captions and build image-text pairs

def read_captions(captions_file: Path) -> Dict[str, str]:
    """Parse a captions file into a ``{filename: caption}`` mapping.

    Every non-blank line is split once into a filename and a caption using the
    first delimiter present on the line, tried in this order: tab, comma,
    single space. Lines containing none of these are skipped. Both parts are
    whitespace-stripped.

    If the first parsed line is the header row ``image``/``caption``
    (case-insensitive), it is skipped. When a filename occurs on several
    lines, the caption from the last such line is kept.

    Args:
        captions_file: Path to the UTF-8 captions file (a BOM is tolerated).

    Returns:
        Mapping from filename to caption.

    Raises:
        FileNotFoundError: If *captions_file* does not exist.
    """
    if not captions_file.exists():
        raise FileNotFoundError(f"Captions file not found: {captions_file}")

    pairs: Dict[str, str] = {}
    seen_parsed_line = False
    # "utf-8-sig" drops a leading byte-order mark, which str.strip() would
    # otherwise leave attached to the first filename; BOM-less files are
    # decoded exactly as with "utf-8".
    with captions_file.open("r", encoding="utf-8-sig") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            # Expect lines like: filename.jpg\tcaption text, with CSV-like and
            # space-separated layouts as fallbacks.
            for delimiter in ("\t", ",", " "):
                fname, found, caption = line.partition(delimiter)
                if found:
                    break
            else:
                # No delimiter at all: nothing to pair the token with.
                continue
            fname, caption = fname.strip(), caption.strip()
            is_first_parsed_line = not seen_parsed_line
            seen_parsed_line = True
            if is_first_parsed_line and (fname.lower(), caption.lower()) == ("image", "caption"):
                # Column header, not a real image/caption pair.
                continue
            pairs[fname] = caption
    return pairs

caption_map = read_captions(CAPTIONS_PATH)
print(f"Loaded {len(caption_map)} captions.")

# Index every .jpg below the images folder by its bare filename
all_image_files = {}
for jpg_path in IMAGES_DIR.glob("**/*.jpg"):
    all_image_files[jpg_path.name] = jpg_path

# Keep only captioned filenames that also exist on disk (caption-file order)
kept_filenames = [name for name in caption_map if name in all_image_files]
print(f"Images with captions and present on disk: {len(kept_filenames)}")

# Three parallel, index-aligned views of the kept items
image_text_pairs = []
image_paths = []
texts = []
for name in kept_filenames:
    caption = caption_map[name]
    image_text_pairs.append((name, caption))
    image_paths.append(all_image_files[name])
    texts.append(caption)

# Quick peek at up to three pairs
for pair in image_text_pairs[:3]:
    print(pair)


In [None]:
# Load model and processor (CLIP)
# Processor used by the embedding helpers for both `images=` and `text=` inputs
processor = AutoProcessor.from_pretrained(MODEL_NAME)

# Load the checkpoint, move it to the selected device and put it in eval mode
model = CLIPModel.from_pretrained(MODEL_NAME).to(DEVICE).eval()

# Output widths of the two projection heads
text_emb_dim = model.text_projection.out_features
image_emb_dim = model.visual_projection.out_features
print(f"Image emb dim: {image_emb_dim}, Text emb dim: {text_emb_dim}")


In [None]:
# Batched embedding generation helpers

def chunked(iterable, n):
    """Yield consecutive chunks of at most *n* items from *iterable*.

    Sized, indexable inputs (lists, tuples, strings, ...) are sliced, so each
    chunk has the same type as the input. Any other iterable (for example a
    generator) is consumed lazily and each chunk is yielded as a list.

    Because this is a generator, argument validation happens when iteration
    starts, not when ``chunked`` is called.

    Args:
        iterable: The items to split.
        n: Maximum chunk size; must be a positive integer.

    Yields:
        Chunks of up to *n* items; only the last chunk may be shorter.

    Raises:
        ValueError: If *n* is not positive.
    """
    if n <= 0:
        # A negative step would silently yield nothing; fail loudly instead.
        raise ValueError(f"chunk size must be positive, got {n}")

    if hasattr(iterable, "__len__") and hasattr(iterable, "__getitem__"):
        for start in range(0, len(iterable), n):
            yield iterable[start:start + n]
        return

    from itertools import islice

    iterator = iter(iterable)
    while True:
        batch = list(islice(iterator, n))
        if not batch:
            return
        yield batch

@torch.no_grad()
def compute_image_embeddings(paths: List[Path]) -> torch.Tensor:
    """Embed image files with the CLIP image encoder, batch by batch.

    Images are read in batches of ``IMAGE_BATCH_SIZE``, converted to RGB,
    run through ``processor`` and ``model.get_image_features`` on ``DEVICE``,
    scaled to unit L2 norm along the last dimension, and moved to the CPU.

    Args:
        paths: Image file paths, in the order the rows should appear.

    Returns:
        The per-batch results concatenated along dim 0, or an empty
        ``(0, image_emb_dim)`` tensor when *paths* is empty.
    """
    embs = []
    for batch_paths in tqdm(list(chunked(paths, IMAGE_BATCH_SIZE)), desc="Images -> embeddings"):
        images = []
        for p in batch_paths:
            # Image.open keeps the file handle open; convert() produces an
            # independent in-memory image, so the file can be closed right away.
            with Image.open(p) as img:
                images.append(img.convert("RGB"))
        inputs = processor(images=images, return_tensors="pt", padding=True).to(DEVICE)
        outputs = model.get_image_features(**inputs)
        # Normalize to unit length as CLIP standard
        outputs = outputs / outputs.norm(dim=-1, keepdim=True)
        # Collect on the CPU so device memory is not held across batches
        embs.append(outputs.cpu())
    return torch.cat(embs, dim=0) if embs else torch.empty((0, image_emb_dim))

@torch.no_grad()
def compute_text_embeddings(texts: List[str]) -> torch.Tensor:
    """Embed caption strings with the CLIP text encoder, batch by batch.

    Captions are processed in batches of ``TEXT_BATCH_SIZE`` (padded and
    truncated by ``processor``), encoded with ``model.get_text_features`` on
    ``DEVICE``, scaled to unit L2 norm along the last dimension, and moved to
    the CPU.

    Args:
        texts: Caption strings, in the order the rows should appear.

    Returns:
        The per-batch results concatenated along dim 0, or an empty
        ``(0, text_emb_dim)`` tensor when *texts* is empty.
    """
    progress = tqdm(list(chunked(texts, TEXT_BATCH_SIZE)), desc="Texts -> embeddings")
    normalized_batches = []
    for caption_batch in progress:
        encoded = processor(text=caption_batch, return_tensors="pt", padding=True, truncation=True)
        encoded = encoded.to(DEVICE)
        features = model.get_text_features(**encoded)
        # Unit-length rows, matching the image embeddings
        lengths = features.norm(dim=-1, keepdim=True)
        normalized_batches.append((features / lengths).cpu())

    if not normalized_batches:
        return torch.empty((0, text_emb_dim))
    return torch.cat(normalized_batches, dim=0)

# Run both encoders over the index-aligned image and caption lists
image_embeddings = compute_image_embeddings(image_paths)
text_embeddings = compute_text_embeddings(texts)

# Each input item must have produced exactly one embedding row
for produced, expected in ((image_embeddings, image_paths), (text_embeddings, texts)):
    assert produced.shape[0] == len(expected)

print("Embeddings computed:")
for label, tensor in ((" - images:", image_embeddings), (" - texts:", text_embeddings)):
    print(label, tuple(tensor.shape))


In [None]:
# Persist embeddings and metadata

# Turn the CPU tensors into NumPy arrays before pickling them
image_embeddings_np = image_embeddings.numpy()
text_embeddings_np = text_embeddings.numpy()

# (filename, caption) tuples, in the same order as the embedding rows
pairs = [(name, caption) for name, caption in image_text_pairs]

# One pickle file per artifact
for out_name, payload in (
    ("image_embeddings.pkl", image_embeddings_np),
    ("text_embeddings.pkl", text_embeddings_np),
    ("image_text_pairs.pkl", pairs),
):
    with open(EMBED_DIR / out_name, "wb") as f:
        pickle.dump(payload, f)

# Embedding widths are recorded as 0 when the corresponding array is empty
image_dim = int(image_embeddings_np.shape[1]) if image_embeddings_np.size else 0
text_dim = int(text_embeddings_np.shape[1]) if text_embeddings_np.size else 0

# Small JSON summary describing how the embeddings were produced
metadata = {
    "model": MODEL_NAME,
    "device": DEVICE,
    "num_items": len(pairs),
    "image_embedding_dim": image_dim,
    "text_embedding_dim": text_dim,
}
with open(EMBED_DIR / "metadata.json", "w", encoding="utf-8") as f:
    json.dump(metadata, f, indent=2)

print("Saved embeddings and metadata to", EMBED_DIR)


In [None]:
# Quick verification

# List everything that was written, with file sizes
print("Files in embeddings dir:")
for entry in sorted(EMBED_DIR.glob("*")):
    print(" -", entry.name, entry.stat().st_size, "bytes")

# Read the three pickles back to confirm they round-trip
# NOTE: pickle.load is only appropriate here because these files were just
# written by this notebook; do not point it at untrusted files.
with (EMBED_DIR / "image_embeddings.pkl").open("rb") as fh:
    img_embs = pickle.load(fh)
with (EMBED_DIR / "text_embeddings.pkl").open("rb") as fh:
    txt_embs = pickle.load(fh)
with (EMBED_DIR / "image_text_pairs.pkl").open("rb") as fh:
    pairs = pickle.load(fh)

print("Loaded shapes:")
for label, value in (
    (" - images:", img_embs.shape),
    (" - texts:", txt_embs.shape),
    (" - pairs:", len(pairs)),
):
    print(label, value)

# Show the first pair, if there is one
if pairs:
    print("Example:", pairs[0], sep="\n")
