In [None]:
# --- Imports ---
import os
from pathlib import Path

import torch
from torch import nn
import numpy as np

from transformers import CLIPModel, CLIPProcessor

print("PyTorch:", torch.__version__)

# Seed the global RNGs: projector weight init (nn.Linear) and DataLoader shuffling both
# draw from torch's global generator, so without a seed every Restart & Run All trains a
# different model. numpy is seeded too so any later numpy randomness is reproducible.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


### Step 1 — Load and freeze CLIP

In [None]:
# Pretrained CLIP checkpoint (ViT-B/32 image tower per the name); a good small default.
MODEL_NAME = "openai/clip-vit-base-patch32"

# Processor (tokenizer + image preprocessing) and model are loaded independently.
clip_processor = CLIPProcessor.from_pretrained(MODEL_NAME, use_fast=True)
clip_model = CLIPModel.from_pretrained(MODEL_NAME)

In [None]:
# Move CLIP to the selected device and switch it to inference mode; the cell's last
# expression is the model itself, so Jupyter still renders its summary.
clip_model.to(device).eval()

In [None]:
# Freeze CLIP: only the projectors trained later should receive gradients.
# Trailing ';' suppresses the module repr that requires_grad_ returns.
clip_model.requires_grad_(False);

In [None]:
# Report model size; after the freeze above the trainable count is expected to be 0.
numels = [(p.numel(), p.requires_grad) for p in clip_model.parameters()]
sum_params = sum(n for n, _ in numels)
sum_trainable = sum(n for n, trainable in numels if trainable)
print(f"Total params: {sum_params:,}")
print(f"Trainable params: {sum_trainable:,}")

### Load the PixMo dataset for alignment 

In [None]:
from datasets import load_dataset
# PixMo-Cap rows carry an "image_url" and a "caption" (the keys used by the fetch cells
# below). split="all" loads every available split as a single Dataset.
pixmo_ds = load_dataset(path="allenai/pixmo-cap", split="all")

print(pixmo_ds)
print(pixmo_ds[0])


#### PixMo provides image URLs rather than image files, so we first download the images into a list

In [None]:
import requests
from PIL import Image
from io import BytesIO
from multiprocessing import Pool, cpu_count


# ------------------------------------------------------------
# Fetch one image from URL
# ------------------------------------------------------------
def fetch_image(url: str) -> Image.Image:
    """Download ``url`` and decode the response body as an RGB PIL image.

    ``timeout=10`` is passed to ``requests.get``. It bounds the connect and each read, not
    the total transfer. Non-2xx responses raise ``requests.HTTPError``. Bodies PIL cannot
    decode raise from ``Image.open``.
    """
    response = requests.get(url, timeout=10)
    response.raise_for_status()
    buffer = BytesIO(response.content)
    return Image.open(buffer).convert("RGB")


In [None]:
# Build a small working subset: the first N_EXAMPLES rows, clamped to the dataset size so
# select() never asks for rows that do not exist.
N_EXAMPLES = 500
subset = pixmo_ds.select(range(min(N_EXAMPLES, len(pixmo_ds))))

In [None]:


# ------------------------------------------------------------
# Worker: fetch one (index, example) task for the parallel collector
# ------------------------------------------------------------
def _fetch_single(args):
    """Download the image for one ``(idx, example)`` task.

    Returns ``(idx, image, caption, None)`` on success, or ``(idx, None, None, exc)`` when
    fetching/decoding raises, so one bad URL does not abort the whole batch. Missing
    ``"image_url"``/``"caption"`` keys are read before the ``try`` and are NOT caught.
    """
    idx, example = args
    url, caption = example["image_url"], example["caption"]
    try:
        image = fetch_image(url)
    except Exception as err:
        return idx, None, None, err
    return idx, image, caption, None


# ------------------------------------------------------------
# Parallel image–caption collector
# ------------------------------------------------------------
def collect_image_caption_pairs_mp(subset, processes=None):
    """Download the images of ``subset`` concurrently and pair them with their captions.

    Downloads are I/O-bound, so a thread pool is used rather than a process pool. Threads
    need no pickling of the notebook-defined worker, which fails under the ``spawn`` start
    method. They also avoid forking a kernel that already holds the CLIP model.

    Args:
        subset: iterable of examples with ``"image_url"`` and ``"caption"`` keys.
        processes: number of concurrent download workers (name kept for backward
            compatibility); defaults to 16.

    Returns:
        ``(images, captions)``: parallel lists in ``subset`` order. Examples whose download
        fails are reported and skipped, so ``images[i]`` always pairs with ``captions[i]``.
    """
    from concurrent.futures import ThreadPoolExecutor

    from tqdm.auto import tqdm

    n_workers = processes or 16
    tasks = list(enumerate(subset))

    images = []
    captions = []

    with ThreadPoolExecutor(max_workers=n_workers) as pool:
        # Executor.map yields results in submission order, so the output order (and hence
        # the saved embeddings) is deterministic across runs.
        results = pool.map(_fetch_single, tasks)
        for idx, img, cap, err in tqdm(results, total=len(tasks), desc="fetching images"):
            if err is not None:
                print(f"Skipping index {idx} due to error: {err}")
                continue
            images.append(img)
            captions.append(cap)

    print("Collected", len(images), "image-caption pairs")
    return images, captions


In [None]:
images, captions = collect_image_caption_pairs_mp(subset)

# Later cells index images[0] and pair images with captions purely by position.
assert len(images) == len(captions), "images and captions must stay paired 1:1"
assert images, "no images could be downloaded; check network access and the URLs"

In [None]:
from IPython.display import display

# Spot-check the first downloaded pair: render the image, then the start of its caption.
sample_img, sample_cap = images[0], captions[0]
display(sample_img)
print(sample_cap[:400], "...")

### Encoding PixMo images and captions with CLIP

In [None]:
def encode_images_texts_from_pil(images, captions, batch_size=8, max_length=77):
    """Embed paired PIL images and captions with the frozen CLIP model.

    Uses the notebook-level ``clip_model``, ``clip_processor`` and ``device``.

    Args:
        images: sequence of PIL images.
        captions: sequence of caption strings; ``captions[i]`` pairs with ``images[i]``.
        batch_size: number of pairs encoded per forward pass.
        max_length: maximum caption length in tokens. Longer captions are truncated, so
            only their leading tokens contribute. 77 is the original hard-coded value
            (presumably CLIP's text context length; NOTE(review): confirm via
            ``clip_processor.tokenizer.model_max_length``).

    Returns:
        ``(img_embs, txt_embs)``: CPU tensors, one row per pair, each row L2-normalised.

    Raises:
        ValueError: if ``images`` and ``captions`` differ in length or are empty.
    """
    if len(images) != len(captions):
        raise ValueError(
            f"got {len(images)} images but {len(captions)} captions; they must be paired 1:1"
        )
    if len(images) == 0:
        raise ValueError("no image/caption pairs to encode")

    all_img_embs = []
    all_txt_embs = []

    for start in range(0, len(images), batch_size):
        batch_imgs = images[start:start + batch_size]
        batch_txts = captions[start:start + batch_size]

        inputs = clip_processor(
            text=batch_txts,
            images=batch_imgs,
            return_tensors="pt",
            padding=True,
            # Captions can exceed what the text encoder accepts; without truncation the
            # forward pass fails on long captions.
            truncation=True,
            max_length=max_length,
        ).to(device)

        with torch.no_grad():
            outputs = clip_model(**inputs)

        # Normalise explicitly so later dot products are cosine similarities.
        img_emb = outputs.image_embeds / outputs.image_embeds.norm(dim=-1, keepdim=True)
        txt_emb = outputs.text_embeds / outputs.text_embeds.norm(dim=-1, keepdim=True)

        all_img_embs.append(img_emb.cpu())
        all_txt_embs.append(txt_emb.cpu())

    return torch.cat(all_img_embs, dim=0), torch.cat(all_txt_embs, dim=0)


In [None]:
img_emb, txt_emb = encode_images_texts_from_pil(images, captions, batch_size=8)
print("Image embeddings shape:", img_emb.shape)
print("Text embeddings shape :", txt_emb.shape)

# Row i of img_emb must correspond to row i of txt_emb (and captions[i]) for the
# identity-pairing retrieval evaluation below to be meaningful.
assert img_emb.shape == txt_emb.shape, "image/text embedding shapes differ"
assert img_emb.shape[0] == len(captions), "one embedding row per caption expected"


### Similarity matrix + retrieval on PixMo

In [None]:
img_np, txt_np = img_emb.numpy(), txt_emb.numpy()

# Rows are images, columns are captions. The embeddings were L2-normalised when encoded,
# so each entry is a cosine similarity.
sim_matrix = np.matmul(img_np, txt_np.T)
print("Similarity matrix shape:", sim_matrix.shape)


In [None]:
def top1_retrieval(sim_matrix, captions, max_print=None):
    """Report image→text top-1 retrieval accuracy under identity pairing.

    Image ``i`` counts as correct when its most similar caption is caption ``i``.

    Args:
        sim_matrix: 2-D array-like; entry ``[i, j]`` is the similarity of image ``i`` to
            caption ``j``.
        captions: captions indexed like the columns of ``sim_matrix``. Used only for the
            printed per-image report.
        max_print: print the per-image report for at most this many images. ``None``
            (default) prints every image, as before; ``0`` prints only the summary.

    Returns:
        float: top-1 accuracy in ``[0, 1]``.

    Raises:
        ValueError: if ``sim_matrix`` has no rows.
    """
    sim_matrix = np.asarray(sim_matrix)
    N = sim_matrix.shape[0]
    if N == 0:
        raise ValueError("sim_matrix has no rows; nothing to evaluate")

    # Best caption for every image at once, and whether it is the paired one.
    best = sim_matrix.argmax(axis=1)
    hits = best == np.arange(N)
    correct = int(hits.sum())

    n_report = N if max_print is None else min(max_print, N)
    for i in range(n_report):
        j = int(best[i])
        is_correct = bool(hits[i])
        print(f"Image {i}: best caption idx = {j} | correct={is_correct}")
        print("  GT caption (truncated):", captions[i][:120].replace("\n", " "), "...")
        print("  Top caption (truncated):", captions[j][:120].replace("\n", " "), "...")
        print()

    accuracy = correct / N
    print(f"Top-1 accuracy (identity pairing): {correct}/{N} = {accuracy:.2f}")
    return accuracy

# Baseline: retrieval with raw (frozen) CLIP embeddings. ';' keeps any return value out of Out[].
top1_retrieval(sim_matrix=sim_matrix, captions=captions);


### Save PixMo embeddings for the projector step

In [None]:
SAVE_ROOT = Path("artifacts/clip_projector_poc_pixmo")
SAVE_ROOT.mkdir(parents=True, exist_ok=True)

np.save(SAVE_ROOT / "img_emb.npy", img_np)
np.save(SAVE_ROOT / "txt_emb.npy", txt_np)

# One caption per line, in embedding-row order.
# - Explicit UTF-8 so non-ASCII captions do not depend on the platform's locale encoding.
# - Both "\r" and "\n" are flattened, because text-mode reading treats either as a line
#   break, which would shift every later caption out of alignment with the embeddings.
with open(SAVE_ROOT / "captions.txt", "w", encoding="utf-8") as f:
    for cap in captions:
        f.write(cap.replace("\r", " ").replace("\n", " ") + "\n")

print("Saved embeddings and captions to:", SAVE_ROOT)


### Loading those saved embeddings

In [None]:
from pathlib import Path
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader

SAVE_ROOT = Path("artifacts/clip_projector_poc_pixmo")

# Reload the saved embeddings, presumably so this section can run without re-fetching
# images or re-running CLIP.
img_np, txt_np = (np.load(SAVE_ROOT / fname) for fname in ("img_emb.npy", "txt_emb.npy"))

print("Loaded shapes:", img_np.shape, txt_np.shape)


In [None]:
# Wrap the numpy embeddings as float32 tensors; row i of each is one image/caption pair.
img_tensor, txt_tensor = (torch.from_numpy(arr).float() for arr in (img_np, txt_np))

print(img_tensor.shape, img_tensor.dtype)
print(txt_tensor.shape, txt_tensor.dtype)

# Yields (image_embedding, text_embedding) pairs; TensorDataset rejects tensors whose
# first dimensions differ.
dataset = TensorDataset(img_tensor, txt_tensor)


In [None]:
import torch.nn as nn

d_clip = img_tensor.size(1)  # CLIP embedding width (typically 512)
d_proj = 1024                # projector output width; free to change

class ProjectorMLP(nn.Module):
    """Two-layer MLP (Linear -> GELU -> Linear) mapping ``in_dim`` features to ``out_dim``.

    ``hidden_dim`` defaults to ``out_dim``. The layers live in ``self.net`` (an
    ``nn.Sequential``), so parameter names are ``net.0.*`` and ``net.2.*``.
    """

    def __init__(self, in_dim, out_dim, hidden_dim=None):
        super().__init__()
        width = out_dim if hidden_dim is None else hidden_dim
        layers = [
            nn.Linear(in_dim, width),
            nn.GELU(),
            nn.Linear(width, out_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        return self.net(x)

# Two independent projectors with the same architecture, one per modality (no shared weights).
P_img = ProjectorMLP(d_clip, d_proj).to(device)
P_txt = ProjectorMLP(d_clip, d_proj).to(device)

sum_p_img, sum_p_txt = (
    sum(map(torch.numel, proj.parameters())) for proj in (P_img, P_txt)
)
print(f"P_img params: {sum_p_img:,}, P_txt params: {sum_p_txt:,}")


In [None]:
import torch.nn.functional as F

def contrastive_loss(z_img, z_txt, temperature=0.07):
    """Symmetric (CLIP-style) InfoNCE loss over a batch of paired embeddings.

    Row i of ``z_img`` and row i of ``z_txt`` form the positive pair. Every other row in
    the batch acts as a negative.

    Args:
        z_img, z_txt: (B, d_proj) projected image / text embeddings.
        temperature: divisor applied to the cosine-similarity logits.

    Returns:
        ``(loss, acc_i2t, acc_t2i)``:
        - ``loss``: scalar tensor, the mean of the image->text and text->image
          cross-entropies.
        - ``acc_i2t``, ``acc_t2i``: Python floats with the in-batch top-1 accuracy.
    """
    img_unit = F.normalize(z_img, dim=-1)
    txt_unit = F.normalize(z_txt, dim=-1)

    # (B, B) cosine-similarity logits: rows index images, columns index texts.
    sims = torch.matmul(img_unit, txt_unit.t()) / temperature
    labels = torch.arange(sims.size(0), device=sims.device)

    # Rows: pick the matching text per image; columns: pick the matching image per text.
    loss = 0.5 * (F.cross_entropy(sims, labels) + F.cross_entropy(sims.t(), labels))

    # Monitoring only; no gradients needed.
    with torch.no_grad():
        acc_i2t = (sims.argmax(dim=1) == labels).float().mean().item()
        acc_t2i = (sims.argmax(dim=0) == labels).float().mean().item()

    return loss, acc_i2t, acc_t2i


### DataLoader + optimizer

In [None]:
BATCH_SIZE = 64    # upper bound; clamped to the dataset size below
LR = 1e-3

# drop_last=True keeps every contrastive batch the same size (same number of in-batch
# negatives). Without clamping, a dataset smaller than BATCH_SIZE would yield zero batches
# and the training loop would divide by zero when averaging its metrics.
effective_batch_size = min(BATCH_SIZE, len(dataset))
loader = DataLoader(dataset, batch_size=effective_batch_size, shuffle=True, drop_last=True)

# Only the two projectors are optimised; CLIP itself stays frozen.
optimizer = torch.optim.Adam(
    list(P_img.parameters()) + list(P_txt.parameters()),
    lr=LR,
    weight_decay=1e-4,
)


In [None]:
from tqdm.auto import tqdm

EPOCHS = 10
TEMPERATURE = 0.07  # fixed (not learned) softmax temperature for the contrastive loss

# An empty loader (e.g. dataset smaller than the batch size with drop_last=True) would
# otherwise surface as an opaque ZeroDivisionError when averaging the epoch metrics.
if len(loader) == 0:
    raise ValueError(
        f"DataLoader yields no batches ({len(loader.dataset)} pairs, "
        f"batch_size={loader.batch_size}, drop_last=True); lower BATCH_SIZE"
    )

for epoch in range(1, EPOCHS + 1):
    P_img.train()
    P_txt.train()

    epoch_loss = 0.0
    epoch_acc_i2t = 0.0
    epoch_acc_t2i = 0.0
    n_batches = 0

    for img_batch, txt_batch in loader:
        img_batch = img_batch.to(device)
        txt_batch = txt_batch.to(device)

        optimizer.zero_grad()

        z_img = P_img(img_batch)      # (B, d_proj)
        z_txt = P_txt(txt_batch)      # (B, d_proj)

        loss, acc_i2t, acc_t2i = contrastive_loss(z_img, z_txt, temperature=TEMPERATURE)

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        epoch_acc_i2t += acc_i2t
        epoch_acc_t2i += acc_t2i
        n_batches += 1

    # Metrics are means over batches (all batches have equal size because of drop_last).
    print(
        f"Epoch {epoch:02d} | "
        f"loss={epoch_loss/n_batches:.4f} | "
        f"img→txt acc={epoch_acc_i2t/n_batches:.3f} | "
        f"txt→img acc={epoch_acc_t2i/n_batches:.3f}"
    )


In [None]:
P_img.eval()
P_txt.eval()

# Project and L2-normalise the full embedding set in one pass (no gradients needed).
# NOTE(review): these are the same pairs the projectors were trained on (the DataLoader was
# built from img_tensor/txt_tensor), so the retrieval accuracy below is a training-set number.
with torch.no_grad():
    all_img, all_txt = img_tensor.to(device), txt_tensor.to(device)
    z_img_all = F.normalize(P_img(all_img), dim=-1)
    z_txt_all = F.normalize(P_txt(all_txt), dim=-1)

sim_proj = torch.matmul(z_img_all, z_txt_all.t()).cpu().numpy()
print("Projected similarity matrix shape:", sim_proj.shape)


In [None]:
# reload captions for readability
caps = []
with open(SAVE_ROOT / "captions.txt") as f:
    for line in f:
        caps.append(line.strip())

top1_retrieval(sim_proj, caps)
