# CLIP Image Embeddings — Pre-compute and Cache

This notebook generates **CLIP image embeddings** for every image in the image bank
and caches them to disk for fast reuse during multimodal ASR rescoring.

## Why pre-compute?

During inference, CLIP's vision encoder is the slowest part of the multimodal pipeline.
Since our image bank is fixed (the same 20 therapeutic images are reused across sessions),
we encode them **once** and store the resulting 512-dimensional vectors in a `.npz` file.
At inference time, loading a cached vector is a single file read with no forward pass, so it is effectively instant.
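The compute-once, load-thereafter pattern can be sketched as a small helper. This is a minimal sketch, not part of the notebook: `get_or_compute_embeddings` is a hypothetical name, and `compute` stands in for any function that returns `{filename: embedding}` (in this notebook, a call to `encode_images`).

```python
from pathlib import Path
from typing import Callable

import numpy as np


def get_or_compute_embeddings(
    cache_path: Path,
    compute: Callable[[], dict[str, np.ndarray]],
) -> dict[str, np.ndarray]:
    """Hypothetical helper: return cached embeddings, computing them only on a miss."""
    cache_path = Path(cache_path)
    if cache_path.exists():
        # Cache hit: read every array into a plain dict, then close the file.
        with np.load(cache_path) as data:
            return {name: data[name] for name in data.files}
    # Cache miss: run the expensive encoder once and persist the result.
    embeddings = compute()
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    np.savez(cache_path, **embeddings)
    return embeddings
```

On the second and later runs `compute` is never called, so the CLIP model does not even need to be loaded.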

## What is CLIP?

**CLIP** (Contrastive Language-Image Pre-training) is a model from OpenAI that learns a
shared embedding space for images and text. Given an image and a set of text candidates,
CLIP can rank how well each text describes the image via **cosine similarity** between
their embeddings. We use `openai/clip-vit-base-patch32` (ViT-B/32), which produces
512-dimensional embeddings.
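As a concrete illustration of ranking by cosine similarity, the sketch below scores a few candidate vectors against one image vector. The toy 2-d vectors are assumptions standing in for real 512-d CLIP embeddings, and `rank_texts_by_similarity` is a hypothetical helper.

```python
import numpy as np


def rank_texts_by_similarity(
    image_emb: np.ndarray, text_embs: np.ndarray
) -> list[tuple[int, float]]:
    """Hypothetical helper: rank text candidates by cosine similarity to one image.

    image_emb: shape (D,); text_embs: shape (N, D).
    Returns (index, similarity) pairs, most similar first.
    """
    img = image_emb / np.linalg.norm(image_emb)
    txt = text_embs / np.linalg.norm(text_embs, axis=1, keepdims=True)
    sims = txt @ img  # (N,) cosine similarities
    order = np.argsort(-sims)  # descending similarity
    return [(int(i), float(sims[i])) for i in order]


# Toy stand-ins: the third "text" points in the same direction as the image.
image = np.array([1.0, 0.0])
texts = np.array([[0.0, 1.0], [1.0, 1.0], [2.0, 0.0]])
print(rank_texts_by_similarity(image, texts))  # indices come out as 2, 1, 0
```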

## 0) Imports and constants

- `torch` — runs CLIP on GPU/MPS/CPU
- `PIL.Image` — loads `.png` files into pixel arrays
- `transformers.CLIPModel / CLIPProcessor` — the HuggingFace CLIP implementation
- `numpy` — the cached embeddings are stored as NumPy arrays

In [None]:
import json
from pathlib import Path

import numpy as np
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

DEFAULT_CLIP_MODEL = "openai/clip-vit-base-patch32"
DEFAULT_IMAGE_DIR = Path("../imagegen/images")
DEFAULT_CACHE_DIR = Path("cache")

## 1) Device selection and model loading

CLIP runs fastest on a GPU (`cuda`). On Apple Silicon Macs, `mps` provides GPU-like
acceleration. We fall back to `cpu` otherwise.

`load_clip` downloads the model weights from HuggingFace (cached after the first call),
moves the model to the chosen device, and sets it to **eval mode** (disabling dropout
and other training-only behaviours).

In [None]:
def _pick_device() -> str:
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def load_clip(
    model_name: str = DEFAULT_CLIP_MODEL, device: str | None = None
) -> tuple[CLIPModel, CLIPProcessor, str]:
    """Load CLIP model and processor onto the chosen device."""
    if device is None:
        device = _pick_device()
    processor = CLIPProcessor.from_pretrained(model_name)
    model = CLIPModel.from_pretrained(model_name).to(device)
    model.eval()
    return model, processor, device

## 2) Feature extraction helpers

We use explicit calls to `model.vision_model()` → `model.visual_projection()` (and the
text equivalents) rather than the higher-level `get_image_features()` shortcut. This
ensures compatibility across different `transformers` versions.
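What `_get_image_features` computes can be mimicked in NumPy to make the shapes explicit. This sketch assumes ViT-B/32's sizes (a 768-d vision hidden state projected into the 512-d shared space) and uses random weights in place of the real checkpoint; `visual_projection` is a linear layer without a bias, so applying it amounts to one matrix multiply.

```python
import numpy as np

rng = np.random.default_rng(0)

# Assumed ViT-B/32 sizes: 768-d vision hidden state, 512-d shared embedding space.
VISION_HIDDEN, PROJ_DIM = 768, 512

# Random stand-in for vision_out.pooler_output for a batch of 2 images.
pooled = rng.standard_normal((2, VISION_HIDDEN)).astype(np.float32)
# Random stand-in for visual_projection.weight (PyTorch Linear stores (out, in)).
w_proj = rng.standard_normal((PROJ_DIM, VISION_HIDDEN)).astype(np.float32)

features = pooled @ w_proj.T  # (2, 512), like visual_projection(pooled)
features /= np.linalg.norm(features, axis=-1, keepdims=True)  # L2-normalise each row
print(features.shape, np.linalg.norm(features, axis=-1))
```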

### Why L2-normalise?

CLIP embeddings are compared via **cosine similarity**:

$$\text{sim}(a, b) = \frac{a \cdot b}{\|a\| \, \|b\|}$$

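If both vectors are unit-length (L2-normalised), the denominator is 1 and cosine similarity reduces to a plain **dot product**. Because the image vectors are normalised once before caching, each comparison at inference time skips the norm computation, and scoring many texts against many images becomes a single matrix multiplication.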
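A quick numerical check of this identity, using random vectors as stand-ins for CLIP embeddings (the data are synthetic; only the 512-d size matches the real model):

```python
import numpy as np


def cosine(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity computed directly from the formula above."""
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))


rng = np.random.default_rng(42)
a = rng.standard_normal(512)
b = rng.standard_normal(512)

a_unit = a / np.linalg.norm(a)
b_unit = b / np.linalg.norm(b)

# After L2-normalisation the plain dot product equals the cosine similarity.
print(np.isclose(cosine(a, b), a_unit @ b_unit))  # True

# With unit vectors stacked as rows, all pairwise similarities are one matmul.
images = rng.standard_normal((20, 512))
images /= np.linalg.norm(images, axis=1, keepdims=True)
texts = rng.standard_normal((3, 512))
texts /= np.linalg.norm(texts, axis=1, keepdims=True)
sim_matrix = images @ texts.T  # shape (20, 3)
```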

In [None]:
def _get_image_features(model: CLIPModel, pixel_values: torch.Tensor) -> torch.Tensor:
    """Extract L2-normalised image embeddings."""
    vision_out = model.vision_model(pixel_values=pixel_values)
    features = model.visual_projection(vision_out.pooler_output)
    return features / features.norm(dim=-1, keepdim=True)


def _get_text_features(
    model: CLIPModel, input_ids: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
    """Extract L2-normalised text embeddings."""
    text_out = model.text_model(input_ids=input_ids, attention_mask=attention_mask)
    features = model.text_projection(text_out.pooler_output)
    return features / features.norm(dim=-1, keepdim=True)

## 3) Batch encoding functions

### `encode_images`
Iterates over every `.png` in the image directory, encodes each one through CLIP's
vision encoder, and returns a dictionary mapping **filename → 512-d embedding**.

### `encode_texts`
Encodes a batch of text strings in one forward pass, padding them to a common length and truncating anything beyond CLIP's 77-token context. Used later by the multimodal
rescoring pipeline to encode ASR hypothesis candidates.

In [None]:
def encode_images(
    model: CLIPModel,
    processor: CLIPProcessor,
    image_dir: Path | str,
    device: str,
) -> dict[str, np.ndarray]:
    """Encode every PNG in *image_dir* and return {filename: embedding}."""
    image_dir = Path(image_dir)
    image_files = sorted(image_dir.glob("*.png"))
    if not image_files:
        raise FileNotFoundError(f"No .png files found in {image_dir}")

    embeddings: dict[str, np.ndarray] = {}
    for img_path in image_files:
        # Context manager closes the file handle once the RGB copy is made.
        with Image.open(img_path) as raw:
            image = raw.convert("RGB")
        inputs = processor(images=image, return_tensors="pt")
        pixel_values = inputs["pixel_values"].to(device)
        with torch.no_grad():
            features = _get_image_features(model, pixel_values)
        embeddings[img_path.name] = features.cpu().numpy().squeeze()
        print(f"  Encoded {img_path.name}")

    return embeddings


def encode_texts(
    model: CLIPModel,
    processor: CLIPProcessor,
    texts: list[str],
    device: str,
) -> np.ndarray:
    """Encode a batch of texts and return L2-normalised embeddings (N, D)."""
    inputs = processor(
        text=texts, return_tensors="pt", padding=True, truncation=True
    )
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)
    with torch.no_grad():
        features = _get_text_features(model, input_ids, attention_mask)
    return features.cpu().numpy()
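`encode_images` runs one forward pass per image, which is fine for a 20-image bank. For a larger bank, one option is to group images into chunks and encode each chunk in a single call, since `CLIPProcessor` also accepts a list of images. The sketch below shows only the chunking logic; `encode_in_batches` and `encode_batch` are hypothetical names, and `encode_batch` is assumed to return one embedding row per input, in order.

```python
from collections.abc import Callable, Iterator, Sequence

import numpy as np


def batched(items: Sequence, batch_size: int) -> Iterator[Sequence]:
    """Yield consecutive slices of at most batch_size items."""
    for start in range(0, len(items), batch_size):
        yield items[start : start + batch_size]


def encode_in_batches(
    names: Sequence[str],
    encode_batch: Callable[[Sequence[str]], np.ndarray],
    batch_size: int = 8,
) -> dict[str, np.ndarray]:
    """Hypothetical helper: map each name to its embedding, one call per chunk.

    encode_batch (hypothetical) takes a chunk of filenames and returns an
    (len(chunk), D) array, e.g. by opening the images, calling
    processor(images=[...], return_tensors="pt") and _get_image_features.
    """
    embeddings: dict[str, np.ndarray] = {}
    for chunk in batched(names, batch_size):
        feats = encode_batch(chunk)
        if feats.shape[0] != len(chunk):
            raise ValueError("encode_batch must return one row per input")
        for name, vec in zip(chunk, feats):
            embeddings[name] = vec
    return embeddings
```

The larger the batch, the fewer forward passes, at the cost of more device memory per call.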

## 4) Cache utilities

Embeddings are saved with `np.savez`, which writes an uncompressed `.npz` archive (`np.savez_compressed` is the compressed variant). The file maps each
image filename (e.g. `img_001.png`) to its 512-d float32 vector.

- **`cache_embeddings`** — write the dictionary to `.npz`
- **`load_cached_embeddings`** — read it back as a plain `dict`

In [None]:
def cache_embeddings(embeddings: dict[str, np.ndarray], cache_path: Path | str):
    """Save image embeddings to a .npz file."""
    cache_path = Path(cache_path)
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    np.savez(str(cache_path), **embeddings)
    print(f"Cached {len(embeddings)} embeddings \u2192 {cache_path}")


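def load_cached_embeddings(cache_path: Path | str) -> dict[str, np.ndarray]:
    """Load embeddings previously saved with *cache_embeddings*."""
    # The context manager closes the underlying .npz file once arrays are read.
    with np.load(str(cache_path)) as data:
        return {name: data[name] for name in data.files}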

## 5) Run: encode all images and cache

This cell loads the CLIP model, encodes all 20 images in `imagegen/images/`, saves
the embeddings to `cache/clip_image_embeddings.npz`, and writes a metadata JSON
summarising what was cached.
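Because the cache is only valid for one model and one set of images, the metadata written by the cell below can double as a staleness check before reusing `clip_image_embeddings.npz`. This is a hypothetical helper, not part of the notebook; it assumes the metadata keys `model` and `image_files` exactly as that cell writes them.

```python
import json
from pathlib import Path


def cache_is_stale(meta_path: Path, image_dir: Path, model_name: str) -> bool:
    """Hypothetical helper: True if cached metadata no longer matches model or images."""
    meta_path, image_dir = Path(meta_path), Path(image_dir)
    if not meta_path.exists():
        return True  # nothing cached yet
    metadata = json.loads(meta_path.read_text())
    # Same ordering as the metadata: sorted .png filenames in the image bank.
    current_files = sorted(p.name for p in image_dir.glob("*.png"))
    return (
        metadata.get("model") != model_name
        or metadata.get("image_files") != current_files
    )
```

Comparing filenames catches added or removed images; an image edited in place under the same name would still need a content hash to detect.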

In [None]:
model, processor, device = load_clip()
print(f"Device: {device}  |  CLIP model: {DEFAULT_CLIP_MODEL}")

embeddings = encode_images(model, processor, DEFAULT_IMAGE_DIR, device)

cache_path = DEFAULT_CACHE_DIR / "clip_image_embeddings.npz"
cache_embeddings(embeddings, cache_path)

metadata = {
    "model": DEFAULT_CLIP_MODEL,
    "num_images": len(embeddings),
    "image_files": sorted(embeddings.keys()),
    "embedding_dim": int(next(iter(embeddings.values())).shape[0]),
}
meta_path = DEFAULT_CACHE_DIR / "clip_metadata.json"
with open(meta_path, "w") as f:
    json.dump(metadata, f, indent=2)
print(f"Metadata \u2192 {meta_path}")

## 6) Quick sanity check

Load the cached embeddings back and verify their dimensions. Then test CLIP's
text-image alignment by comparing a few text prompts against `img_001.png`
(a cat sleeping on a windowsill); the matching caption should score highest.

In [None]:
cached = load_cached_embeddings(cache_path)
print(f"Loaded {len(cached)} embeddings, each {next(iter(cached.values())).shape}")

test_texts = [
    "a cat sleeping on a windowsill",
    "a dog catching a ball in a park",
    "two children building a sandcastle",
]
text_embs = encode_texts(model, processor, test_texts, device)
img_emb = cached["img_001.png"].reshape(1, -1)

sims = (img_emb @ text_embs.T).squeeze()
print("\nSimilarity of img_001.png to:")
for txt, sim in zip(test_texts, sims):
    print(f"  {sim:.4f}  {txt}")
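Raw CLIP cosine similarities tend to sit in a narrow band, so the gap between the right and wrong captions can look small. CLIP itself multiplies similarities by a learned temperature before a softmax; the sketch below applies the same idea to `sims`. The fixed `logit_scale=100.0` is an assumption standing in for `model.logit_scale.exp().item()`, and `similarities_to_probs` is a hypothetical helper.

```python
import numpy as np


def similarities_to_probs(sims: np.ndarray, logit_scale: float = 100.0) -> np.ndarray:
    """Hypothetical helper: softmax over temperature-scaled cosine similarities.

    logit_scale is an assumed stand-in for model.logit_scale.exp().item().
    """
    logits = logit_scale * np.asarray(sims, dtype=np.float64)
    logits -= logits.max()  # subtract the max for numerical stability
    exp = np.exp(logits)
    return exp / exp.sum()
```

Calling `similarities_to_probs(sims)` on the output of the cell above gives one probability per prompt, which makes the winning caption easier to read off.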

## Summary

This notebook:

1. Loaded **CLIP ViT-B/32** and selected the best available device
2. Encoded all **20 therapeutic images** into 512-d embeddings
3. Cached them to `cache/clip_image_embeddings.npz` for instant loading at inference time
4. Spot-checked on `img_001.png` that CLIP's text-image similarities align with the image's content

The cached embeddings are used by the multimodal ASR pipeline (`multimodal_asr`) to
rescore Whisper transcription candidates based on visual context.
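How `multimodal_asr` combines these scores is not shown in this notebook. Purely as an assumed illustration, one simple pattern is a weighted sum of each hypothesis's ASR score and its CLIP similarity to the image on screen; `rescore` and `weight` below are hypothetical, and the real pipeline may fuse the scores differently.

```python
import numpy as np


def rescore(
    asr_scores: np.ndarray,
    text_embs: np.ndarray,
    image_emb: np.ndarray,
    weight: float = 0.5,
) -> np.ndarray:
    """Hypothetical fusion: ASR score + weight * CLIP similarity to the image.

    asr_scores: (N,) per-hypothesis scores, e.g. Whisper log-probabilities.
    text_embs:  (N, D) L2-normalised embeddings, as returned by encode_texts.
    image_emb:  (D,) L2-normalised cached image embedding.
    Returns hypothesis indices, best first.
    """
    sims = text_embs @ image_emb  # cosine similarities, since inputs are unit-length
    combined = np.asarray(asr_scores, dtype=np.float64) + weight * sims
    return np.argsort(-combined)
```

With `weight=0` this falls back to the ASR ranking alone; larger weights let the visual context override close calls between hypotheses.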