In [10]:
# If needed, uncomment and run:
! uv pip install -q torch torchvision timm sentence-transformers datasets pillow tqdm


In [11]:
import os, math, json, time, random, pathlib
from typing import List
from dataclasses import dataclass

import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import timm
from torchvision import transforms
from sentence_transformers import SentenceTransformer
from datasets import load_dataset

def set_seed(seed=42):
    random.seed(seed); np.random.seed(seed)
    torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = True

set_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)


Device: cuda


In [12]:
# Data/output
PROJECT_ROOT = "freeze_align_poc"
DATA_ROOT    = f"{PROJECT_ROOT}/data/flickr30k"
IMG_DIR      = f"{DATA_ROOT}/images"
CAP_FILE     = f"{DATA_ROOT}/captions.json"
CKPT_PATH    = f"{PROJECT_ROOT}/checkpoints/pair_img_txt.pt"

os.makedirs(IMG_DIR, exist_ok=True)
os.makedirs(f"{PROJECT_ROOT}/checkpoints", exist_ok=True)

# Flickr30k subset size for quick POC
NUM_PAIRS    = 10_000   # try 30_000 if you want a bigger run

# Training
BATCH_SIZE   = 64
ACCUM_STEPS  = 1        # set >1 for larger effective batch
MAX_STEPS    = 2000     # increase for better alignment (e.g., 5k)
LR           = 1e-3
WD           = 0.01
WARMUP       = 100
LOG_EVERY    = 50


In [None]:
# Loads a split from HF and materializes to local images + captions.json
ds = load_dataset("lmms-lab/flickr30k", split=f"test[:{NUM_PAIRS}]")


In [20]:
ds

Dataset({
    features: ['image', 'caption', 'sentids', 'img_id', 'filename'],
    num_rows: 10000
})

In [21]:
records = []
for i, row in tqdm(enumerate(ds), total=len(ds), desc="Saving images"):
    image: Image.Image = row["image"]
    caption = row["caption"]
    img_path = f"{IMG_DIR}/{i:06d}.jpg"
    image.save(img_path, quality=90)
    records.append({"image": f"images/{i:06d}.jpg", "caption": caption})

with open(CAP_FILE, "w") as f:
    json.dump(records, f)
print(f"Saved {len(records)} pairs to {DATA_ROOT}")


Saving images: 100%|██████████| 10000/10000 [00:40<00:00, 249.44it/s]


Saved 10000 pairs to freeze_align_poc/data/flickr30k


In [22]:
class ImgTxtDataset(Dataset):
    def __init__(self, root, captions_file, transform):
        self.root = pathlib.Path(root)
        with open(captions_file) as f:
            self.items = json.load(f)
        self.transform = transform
    def __len__(self): return len(self.items)
    def __getitem__(self, i):
        rec = self.items[i]
        img = Image.open(self.root / rec["image"]).convert("RGB")
        return self.transform(img), rec["caption"]

img_tf = transforms.Compose([
    transforms.Resize(256, interpolation=Image.BICUBIC),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225)),
])

ds = ImgTxtDataset(DATA_ROOT, CAP_FILE, img_tf)

def collate(batch):
    imgs, caps = zip(*batch)
    return torch.stack(imgs), list(caps)

loader = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True,
                    num_workers=4, pin_memory=True, collate_fn=collate)

len(ds), next(iter(loader))[0].shape, next(iter(loader))[1][0]


(10000,
 torch.Size([64, 3, 224, 224]),
 ['Two little girls are sitting on a yellow rubber ball toy .',
  'The two girls are playing on a yellow sit-and-bounce .',
  'Two children sitting atop a large yellow bounce toy .',
  'Two girls bounce on a large yellow ball indoors .',
  'Two smiling girls sit on a yellow bouncy ball .'])

In [27]:
# Fix for vision encoder selection (DINOv2 via timm)
import timm, torch

print("timm version:", timm.__version__)
print("available dinov2 models:", timm.list_models("*dinov2*"))

# Prefer large → base → small, depending on what's installed
CANDIDATES = [
    "vit_large_patch14_dinov2.lvd142m",
    "vit_base_patch14_dinov2.lvd142m",
    "vit_small_patch14_dinov2.lvd142m",
]
for name in CANDIDATES:
    if name in timm.list_models("*dinov2*"):
        VISION_MODEL_NAME = name
        break
else:
    # last-resort fallback if your timm build lacks dinov2
    VISION_MODEL_NAME = "vit_base_patch16_224"

print("Using vision model:", VISION_MODEL_NAME)


timm version: 1.0.21
available dinov2 models: ['vit_base_patch14_dinov2', 'vit_base_patch14_reg4_dinov2', 'vit_giant_patch14_dinov2', 'vit_giant_patch14_reg4_dinov2', 'vit_large_patch14_dinov2', 'vit_large_patch14_reg4_dinov2', 'vit_small_patch14_dinov2', 'vit_small_patch14_reg4_dinov2']
Using vision model: vit_base_patch16_224


In [None]:
# Vision encoder: DINOv2-L, pooled features
vision_model = timm.create_model("vit_base_patch16_224", pretrained=True, num_classes=0)
vision_model.eval().to(device)
for p in vision_model.parameters(): p.requires_grad = False
vision_dim = vision_model.num_features

# Text encoder: SentenceTransformer all-roberta-large-v1
text_model = SentenceTransformer("sentence-transformers/all-roberta-large-v1", device=device)
text_model.eval()
for p in text_model.parameters(): p.requires_grad = False
text_dim = text_model.get_sentence_embedding_dimension()


In [31]:
print("Dims:", dict(vision_dim=vision_dim, text_dim=text_dim))
assert vision_dim == text_dim == 1024, "If dims differ, add a linear mapper to unify."

Dims: {'vision_dim': 768, 'text_dim': 1024}


AssertionError: If dims differ, add a linear mapper to unify.

In [37]:
 # Replace your "Dims + assert" cell with this
print("Dims:", dict(vision_dim=vision_dim, text_dim=text_dim))

# Pick a shared target dim automatically (no information loss on the larger side)
TARGET_DIM = max(vision_dim, text_dim)     # e.g., 1024 if text=1024 and vision=768
print("TARGET_DIM:", TARGET_DIM)


Dims: {'vision_dim': 768, 'text_dim': 1024}
TARGET_DIM: 1024


In [38]:
# Drop-in replacement for the projectors + wrapper cell

def l2_normalize(x, dim=-1, eps=1e-8):
    return x / (x.norm(dim=dim, keepdim=True) + eps)

def maybe_linear(in_dim, out_dim):
    if in_dim == out_dim:
        return nn.Identity()
    return nn.Linear(in_dim, out_dim)

class TokenProjector(nn.Module):
    def __init__(self, dim, hidden=None):
        super().__init__()
        hidden = hidden or max(1024, dim * 2)
        self.net = nn.Sequential(nn.Linear(dim, hidden), nn.ReLU(), nn.Linear(hidden, dim))
    def forward(self, x):  # [B, D]
        return x + self.net(x)

class GlobalProjector(nn.Module):
    def __init__(self, dim, hidden=None):
        super().__init__()
        hidden = hidden or max(1024, dim * 2)
        self.net = nn.Sequential(nn.Linear(dim, hidden), nn.ReLU(), nn.Linear(hidden, dim))
    def forward(self, x):  # [B, D]
        return self.net(x)

class FreezeAlignIT(nn.Module):
    def __init__(self, vision_backbone, text_backbone, vision_dim, text_dim, target_dim):
        super().__init__()
        self.vision = vision_backbone
        self.text   = text_backbone

        # NEW: light mappers to a shared TARGET_DIM
        self.v_map = maybe_linear(vision_dim, target_dim)
        self.t_map = maybe_linear(text_dim,   target_dim)

        # Projectors operate in TARGET_DIM
        self.v_token  = TokenProjector(target_dim)
        self.v_global = GlobalProjector(target_dim)
        self.t_token  = TokenProjector(target_dim)
        self.t_global = GlobalProjector(target_dim)

        self.temperature = nn.Parameter(torch.tensor(0.07))
        self.target_dim  = target_dim

    @torch.no_grad()
    def _vision(self, images):           # [B, Dv]
        return self.vision(images)

    @torch.no_grad()
    def _text(self, captions, device):   # [B, Dt]
        return self.text.encode(captions, convert_to_tensor=True, device=device)

    def encode_image(self, images):
        v = self._vision(images)                 # [B, Dv]
        v = self.v_map(v)                        # [B, TARGET_DIM]
        v = self.v_token(v)
        v = self.v_global(v)
        return l2_normalize(v)

    def encode_text(self, captions, device):
        t = self._text(captions, device)         # [B, Dt]
        t = self.t_map(t)                        # [B, TARGET_DIM]
        t = self.t_token(t)
        t = self.t_global(t)
        return l2_normalize(t)

    def forward(self, images, captions, device):
        zi = self.encode_image(images)
        zt = self.encode_text(captions, device)
        return zi, zt, self.temperature.clamp(0.01, 1.0)

model = FreezeAlignIT(vision_model, text_model, vision_dim, text_dim, TARGET_DIM).to(device)
print("Trainable params (M):", sum(p.numel() for p in model.parameters() if p.requires_grad)/1e6)


Trainable params (M): 17.576961


In [43]:
def clip_loss(img_z, txt_z, temperature):
    logits = (img_z @ txt_z.t()) / temperature
    labels = torch.arange(len(img_z), device=img_z.device)
    return (F.cross_entropy(logits, labels) + F.cross_entropy(logits.t(), labels)) / 2


# Recreate optimizer & scheduler after redefining model
optim = torch.optim.AdamW(
    [p for n,p in model.named_parameters()
     if p.requires_grad and not n.startswith("vision") and not n.startswith("text")],  # includes v_map/t_map
    lr=LR, weight_decay=WD
)

class CosineAnnealWarmup(torch.optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer, warmup_steps, max_steps, last_epoch=-1):
        self.warmup_steps = warmup_steps; self.max_steps = max_steps
        super().__init__(optimizer, last_epoch)
    def get_lr(self):
        step = self.last_epoch + 1
        if step <= self.warmup_steps:
            s = step / max(1, self.warmup_steps)
            return [base * s for base in self.base_lrs]
        progress = (step - self.warmup_steps) / max(1, self.max_steps - self.warmup_steps)
        factor = 0.5 * (1 + math.cos(math.pi * progress))
        return [base * factor for base in self.base_lrs]

sched  = CosineAnnealWarmup(optim, warmup_steps=WARMUP, max_steps=MAX_STEPS)
scaler = torch.cuda.amp.GradScaler(enabled=(device.type=="cuda"))


  scaler = torch.cuda.amp.GradScaler(enabled=(device.type=="cuda"))


In [40]:
@torch.no_grad()
def encode_dataset(model, dataset, batch=128):
    model.eval()
    all_img, all_txt = [], []
    for i in range(0, len(dataset), batch):
        imgs = [dataset[j][0] for j in range(i, min(i+batch, len(dataset)))]
        caps = [dataset[j][1] for j in range(i, min(i+batch, len(dataset)))]
        imgs = torch.stack(imgs).to(device, non_blocking=True)
        zi = model.encode_image(imgs)
        zt = model.encode_text(caps, device)
        all_img.append(zi.cpu()); all_txt.append(zt.cpu())
    return torch.cat(all_img), torch.cat(all_txt)

@torch.no_grad()
def recall_at_k(A, B, ks=(1,5,10)):
    sims  = A @ B.t()
    ranks = torch.argsort(sims, dim=1, descending=True)
    target = torch.arange(A.size(0)).unsqueeze(1)
    out = {}
    for k in ks:
        hit = (ranks[:, :k] == target).any(dim=1).float().mean().item()
        out[f"R@{k}"] = hit
    return out


In [41]:
img_z0, txt_z0 = encode_dataset(model, ds, batch=256)
print("Baseline  Image→Text:", recall_at_k(img_z0, txt_z0))
print("Baseline  Text→Image:", recall_at_k(txt_z0, img_z0))

Baseline  Image→Text: {'R@1': 0.0, 'R@5': 9.999999747378752e-05, 'R@10': 0.0006000000284984708}
Baseline  Text→Image: {'R@1': 9.999999747378752e-05, 'R@5': 0.0006000000284984708, 'R@10': 0.00139999995008111}


In [42]:
model.train()
step, t0 = 0, time.time()

for epoch in range(9999):
    for imgs, caps in loader:
        imgs = imgs.to(device, non_blocking=True)

        with torch.cuda.amp.autocast(enabled=(device.type=="cuda")):
            zi, zt, temp = model(imgs, caps, device)
            loss = clip_loss(zi, zt, temp) / ACCUM_STEPS

        scaler.scale(loss).backward()

        if (step + 1) % ACCUM_STEPS == 0:
            scaler.unscale_(optim)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optim); scaler.update()
            optim.zero_grad(set_to_none=True)
            sched.step()

        step += 1
        if step % LOG_EVERY == 0:
            print(f"step {step:5d} | loss {loss.item()*ACCUM_STEPS:.4f} | T={temp.item():.3f} | lr={sched.get_last_lr()[0]:.2e}")

        if step >= MAX_STEPS:
            break
    if step >= MAX_STEPS:
        break

print(f"Training done in {time.time()-t0:.1f}s")

  with torch.cuda.amp.autocast(enabled=(device.type=="cuda")):


step    50 | loss 1.3210 | T=0.059 | lr=5.10e-04
step   100 | loss 1.3711 | T=0.054 | lr=1.00e-03
step   150 | loss 1.5962 | T=0.056 | lr=9.98e-04
step   200 | loss 1.0407 | T=0.052 | lr=9.93e-04
step   250 | loss 0.9373 | T=0.051 | lr=9.84e-04
step   300 | loss 0.6604 | T=0.050 | lr=9.73e-04
step   350 | loss 0.4601 | T=0.041 | lr=9.58e-04
step   400 | loss 0.7422 | T=0.042 | lr=9.39e-04
step   450 | loss 0.5671 | T=0.045 | lr=9.18e-04
step   500 | loss 0.5670 | T=0.033 | lr=8.94e-04
step   550 | loss 0.4150 | T=0.036 | lr=8.67e-04
step   600 | loss 0.4277 | T=0.038 | lr=8.38e-04
step   650 | loss 0.1910 | T=0.030 | lr=8.06e-04
step   700 | loss 0.2739 | T=0.037 | lr=7.73e-04
step   750 | loss 0.3360 | T=0.035 | lr=7.37e-04
step   800 | loss 0.1929 | T=0.033 | lr=7.00e-04
step   850 | loss 0.1935 | T=0.031 | lr=6.62e-04
step   900 | loss 0.3501 | T=0.030 | lr=6.22e-04
step   950 | loss 0.1497 | T=0.029 | lr=5.81e-04
step  1000 | loss 0.1032 | T=0.026 | lr=5.40e-04
step  1050 | loss 0.

In [44]:
labels  = ["cat", "dog", "car", "airplane", "flower", "person"]
prompts = [f"A photo of a {c}." for c in labels]

@torch.no_grad()
def zero_shot_scores(model, images, label_prompts):
    txt = model.encode_text(label_prompts, device)      # [C, D]
    vi  = model.encode_image(images.to(device))          # [B, D]
    sims = vi @ txt.t()
    preds = sims.argmax(dim=1)
    return preds, sims

N = min(16, len(ds))
sample_imgs = torch.stack([ds[i][0] for i in range(N)])
preds, sims = zero_shot_scores(model, sample_imgs, prompts)
[(i, labels[preds[i].item()]) for i in range(N)]


[(0, 'person'),
 (1, 'airplane'),
 (2, 'flower'),
 (3, 'car'),
 (4, 'cat'),
 (5, 'car'),
 (6, 'person'),
 (7, 'person'),
 (8, 'person'),
 (9, 'person'),
 (10, 'airplane'),
 (11, 'car'),
 (12, 'dog'),
 (13, 'car'),
 (14, 'car'),
 (15, 'cat')]

In [45]:
torch.save({
    "state_dict": model.state_dict(),
    "vision_dim": vision_dim,
    "text_dim": text_dim,
    "config": {
        "LR": LR, "WD": WD, "MAX_STEPS": MAX_STEPS,
        "BATCH_SIZE": BATCH_SIZE, "ACCUM_STEPS": ACCUM_STEPS
    }
}, CKPT_PATH)
CKPT_PATH


'freeze_align_poc/checkpoints/pair_img_txt.pt'

In [46]:
# Sanity reload
ckpt = torch.load(CKPT_PATH, map_location="cpu")
model.load_state_dict(ckpt["state_dict"], strict=True)
model.eval()
img_z_r, txt_z_r = encode_dataset(model, ds, batch=256)
print("Reloaded Image→Text:", recall_at_k(img_z_r, txt_z_r))
print("Reloaded Text→Image:", recall_at_k(txt_z_r, img_z_r))


Reloaded Image→Text: {'R@1': 0.5141000151634216, 'R@5': 0.8799999952316284, 'R@10': 0.9595000147819519}
Reloaded Text→Image: {'R@1': 0.49570000171661377, 'R@5': 0.8748000264167786, 'R@10': 0.9567999839782715}
