# Export Embeddings (Notebook Version)

This notebook mirrors `scripts/export_embeddings.py`: it precomputes image and text embeddings and saves them to disk so that retrieval can run faster.


In [None]:
# Install dependencies
# `%pip` (rather than `!pip`) installs into the environment of the running kernel.
# NOTE(review): no versions are pinned, so each fresh environment may resolve to
# different releases; pin them once a known-good set has been identified.
%pip install torch torchvision numpy pillow pyyaml tqdm scikit-learn transformers

# Mount Google Drive if needed
# from google.colab import drive
# drive.mount('/content/drive')


In [None]:
import sys
from pathlib import Path
import json
import numpy as np
import torch
import yaml
from tqdm import tqdm

# Add project to path
# Prefer the Colab checkout when it exists; otherwise fall back to the parent
# of the current working directory. The chosen root is put first on sys.path
# so the `src` package imported below resolves against it.
_COLAB_ROOT = Path('/content/CLIP_model')
if _COLAB_ROOT.exists():
    BASE_DIR = _COLAB_ROOT
else:
    BASE_DIR = Path.cwd().parent
sys.path.insert(0, str(BASE_DIR))

from src.data.coco_dataset import build_coco_dataloader
from src.models.clip_model import CLIPModel
from src.utils.tokenization import SimpleTokenizer


In [None]:
# Configuration
# Every path used by the notebook is resolved relative to BASE_DIR.
CONFIG_PATH = BASE_DIR / "configs/clip_coco_small.yaml"
CHECKPOINT_PATH = BASE_DIR / "checkpoints/best_model.pt"
OUTPUT_DIR = BASE_DIR / "embeddings"

# Load config
# Decode explicitly as UTF-8: without `encoding`, open() falls back to the
# platform's locale encoding, which can mis-read non-ASCII YAML content.
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
    # safe_load only builds plain Python objects (no arbitrary constructors).
    config = yaml.safe_load(f)

print(f"Config: {CONFIG_PATH}")
print(f"Checkpoint: {CHECKPOINT_PATH}")
print(f"Output directory: {OUTPUT_DIR}")


In [None]:
# Setup device
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(f"Using device: {device}")

# Load model
# NOTE(review): torch.load relies on pickle; only load checkpoints from a
# trusted source. map_location places the stored tensors on `device`.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)

# Rebuild the architecture from the config, then restore the trained weights
# stored under the checkpoint's "model_state_dict" entry.
model_config = config["model"]
model = CLIPModel(
    vision_config=model_config["vision"],
    text_config=model_config["text"],
)
model = model.to(device)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()
print("Model loaded successfully!")


In [None]:
# Build tokenizer
# NOTE(review): the vocabulary is rebuilt here from validation captions
# (loader created with max_samples=5000). The resulting token ids only match
# the checkpoint's text encoder if training built its vocabulary the same
# way; confirm against the training code.
text_config = config["model"]["text"]
tokenizer = SimpleTokenizer(vocab_size=text_config["vocab_size"], min_freq=2)

val_data_config = config["data"]["val"]
temp_loader = build_coco_dataloader(
    annotation_file=str(BASE_DIR / val_data_config["annotation_file"]),
    image_dir=str(BASE_DIR / val_data_config["image_dir"]),
    batch_size=32,
    shuffle=False,
    num_workers=2,
    max_samples=5000,
)

# Flatten the per-batch caption lists into a single list of captions.
all_captions = [
    caption
    for batch in tqdm(temp_loader, desc="Building vocab")
    for caption in batch["caption"]
]
tokenizer.build_vocab(all_captions)
print(f"Tokenizer vocabulary size: {len(tokenizer)}")


In [None]:
# Create data loader
def collate_fn(batch, tokenizer, max_seq_length):
    """Assemble a list of dataset samples into one batch dictionary.

    Args:
        batch: Sequence of samples; each is indexed with the keys "image",
            "caption" and "image_id".
        tokenizer: Object providing ``encode(text, max_length=...)`` and
            ``get_pad_token_id()``.
        max_seq_length: Passed to ``tokenizer.encode`` as ``max_length``.

    Returns:
        A dict with the keys "image" (the samples' images stacked into one
        tensor), "text_tokens" (tensor built from the encoded captions),
        "text_mask" (boolean tensor, True where a token equals the pad id),
        "caption" (list of raw captions) and "image_id" (list of ids).
    """
    pixels, texts, ids = [], [], []
    for sample in batch:
        pixels.append(sample["image"])
        texts.append(sample["caption"])
        ids.append(sample["image_id"])
    stacked_images = torch.stack(pixels)

    # torch.tensor needs rows of equal length; presumably encode() pads or
    # truncates every caption to max_length (confirm in SimpleTokenizer).
    encoded = torch.tensor(
        [tokenizer.encode(text, max_length=max_seq_length) for text in texts]
    )
    pad_id = tokenizer.get_pad_token_id()
    padding_mask = encoded == pad_id

    return dict(
        image=stacked_images,
        text_tokens=encoded,
        text_mask=padding_mask,
        caption=texts,
        image_id=ids,
    )

from torch.utils.data import DataLoader

# The project's loader builder is called only to obtain its `.dataset`; the
# loader it returns is discarded and replaced below by one that tokenizes
# captions through collate_fn.
val_data_config = config["data"]["val"]
eval_batch_size = config["eval"]["batch_size"]

val_dataset = build_coco_dataloader(
    annotation_file=str(BASE_DIR / val_data_config["annotation_file"]),
    image_dir=str(BASE_DIR / val_data_config["image_dir"]),
    batch_size=eval_batch_size,
    shuffle=False,
    num_workers=2,
    subset_percentage=val_data_config.get("subset_percentage"),
).dataset


def _collate_with_tokenizer(samples):
    """Call collate_fn with the notebook's tokenizer and configured max length."""
    return collate_fn(
        samples, tokenizer, config["model"]["text"]["max_seq_length"]
    )


# shuffle=False keeps the iteration order of the dataset, so saved embeddings
# stay aligned with the collected ids and captions.
val_loader = DataLoader(
    val_dataset,
    batch_size=eval_batch_size,
    shuffle=False,
    num_workers=2,
    collate_fn=_collate_with_tokenizer,
)

print(f"Validation batches: {len(val_loader)}")


In [None]:
# Compute embeddings
print("Computing embeddings...")
# Per-batch arrays are collected in the *_chunks lists and joined once at the
# end; ids and captions are collected in the same order as the embeddings.
image_chunks, text_chunks = [], []
image_ids, captions = [], []

# Gradients are not needed for export, so autograd tracking is disabled.
with torch.no_grad():
    for batch in tqdm(val_loader, desc="Encoding"):
        image_features = model.encode_image(batch["image"].to(device))
        text_features = model.encode_text(
            batch["text_tokens"].to(device),
            batch["text_mask"].to(device),
        )

        # Move results to host memory before converting to NumPy.
        image_chunks.append(image_features.cpu().numpy())
        text_chunks.append(text_features.cpu().numpy())
        image_ids += batch["image_id"]
        captions += batch["caption"]

# Join the per-batch arrays along the first (sample) axis.
image_embeddings = np.concatenate(image_chunks, axis=0)
text_embeddings = np.concatenate(text_chunks, axis=0)

print(f"Image embeddings shape: {image_embeddings.shape}")
print(f"Text embeddings shape: {text_embeddings.shape}")


In [None]:
# Save embeddings
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Write each embedding matrix to its own .npy file in the output directory.
for filename, array in (
    ("image_embeddings.npy", image_embeddings),
    ("text_embeddings.npy", text_embeddings),
):
    np.save(OUTPUT_DIR / filename, array)

# NOTE(review): "num_images" is len(image_ids), i.e. the number of encoded
# samples. If the dataset yields several captions per image, this counts
# image-caption pairs rather than distinct images; confirm with consumers.
metadata = dict(
    image_ids=image_ids,
    captions=captions,
    num_images=len(image_ids),
    embedding_dim=image_embeddings.shape[1],
)

metadata_path = OUTPUT_DIR / "metadata.json"
with metadata_path.open("w") as f:
    f.write(json.dumps(metadata, indent=2))

print(f"\nEmbeddings saved to {OUTPUT_DIR}")
print(f"  Image embeddings: {image_embeddings.shape}")
print(f"  Text embeddings: {text_embeddings.shape}")
print(f"  Metadata: {len(image_ids)} images")
