# Task 2: Vision-Language Alignment with CLIP-Style Contrastive Learning

## Installing libs

In [3]:
# !pip install scikit-learn

## Importing Libs

In [5]:
import os
import math
import random
from pathlib import Path
from typing import List, Tuple
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.models as models

from sklearn.manifold import TSNE
from sklearn.metrics.pairwise import cosine_similarity



KeyboardInterrupt: 

Type:        module
String form: <module 'numpy' from '/Users/imranabbas/Documents/videoModeling/.venv/lib/python3.11/site-packages/numpy/__init__.py'>
File:        ~/Documents/videoModeling/.venv/lib/python3.11/site-packages/numpy/__init__.py
Docstring:  
NumPy
=====

Provides
  1. An array object of arbitrary homogeneous items
  2. Fast mathematical operations over arrays
  3. Linear Algebra, Fourier Transforms, Random Number Generation

How to use the documentation
----------------------------
Documentation is available in two forms: docstrings provided
with the code, and a loose standing reference guide, available from
`the NumPy homepage <https://numpy.org>`_.

We recommend exploring the docstrings using
`IPython <https://ipython.org>`_, an advanced Python shell with
TAB-completion and introspection capabilities.  See below for further
instructions.

The docstring examples assume that `numpy` has been imported as ``np``::

  >>> import numpy

## Configs

In [None]:
# Dataset location and output directory.
DATA_ROOT = "./datasets/mini_caption"
CSV_FILE = "captions.csv"
OUT_DIR = "outputs_task2"  # checkpoints and t-SNE plots are written here

# Model / optimisation hyper-parameters.
IMAGE_SIZE = 224     # side length of the square crop fed to the image encoder
BATCH_SIZE = 64
EPOCHS = 10
LR = 1e-4
EMBED_DIM = 256      # size of the shared image/text embedding space
TEMPERATURE = 0.07   # divisor applied to the similarity logits in the loss
# torch is imported at the top of the notebook, so use it directly rather
# than through __import__.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NUM_WORKERS = 4

os.makedirs(OUT_DIR, exist_ok=True)

## Preparing Dataset

In [None]:
class ImageTextDataset(Dataset):
    """Image/caption pairs listed in a CSV file.

    The first CSV row is treated as a header and skipped. Each following row
    needs at least two columns: an image path (absolute, or relative to
    ``data_root``) and its caption; shorter rows are ignored. Items are
    returned as ``(image, caption, image_file_basename)``.
    """

    def __init__(self, csv_path: str, data_root: str = "", transform=None):
        import csv

        self.data = []
        self.data_root = data_root
        self.transform = transform
        # newline="" is what the csv module requires so that quoted fields
        # containing line breaks are parsed correctly.
        with open(csv_path, "r", encoding="utf-8", newline="") as f:
            reader = csv.reader(f)
            # Skip the header; the default keeps an empty file from raising
            # StopIteration.
            next(reader, None)
            for row in reader:
                if len(row) < 2:
                    continue  # blank or malformed line
                self.data.append((row[0].strip(), row[1].strip()))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path, caption = self.data[idx]
        path = img_path if os.path.isabs(img_path) else os.path.join(self.data_root, img_path)
        # Image.open is lazy and holds the file open; the context manager
        # closes it as soon as the RGB copy has been made.
        with Image.open(path) as im:
            img = im.convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, caption, os.path.basename(img_path)

## Transforms

In [None]:
# ImageNet channel statistics, matching the ImageNet-pretrained image backbone.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Training: random scale/crop jitter and horizontal flips for augmentation.
train_transform = T.Compose([
    T.RandomResizedCrop(IMAGE_SIZE, scale=(0.7, 1.0)),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])

# Evaluation: deterministic resize + centre crop. The resize keeps the usual
# 256/224 ratio, so it is exactly Resize(256) at IMAGE_SIZE=224 but still
# leaves a border around the crop if IMAGE_SIZE is changed.
val_transform = T.Compose([
    T.Resize(round(IMAGE_SIZE * 256 / 224)),
    T.CenterCrop(IMAGE_SIZE),
    T.ToTensor(),
    T.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])


## Text encoder

In [6]:
class SimpleTokenizer:
    """Lower-casing whitespace tokenizer with a frequency-filtered vocabulary.

    Ids 0 and 1 are reserved for ``<PAD>`` and ``<UNK>``. Every other word
    seen at least ``min_freq`` times in ``texts`` gets the next free id, in
    order of first appearance.
    """

    def __init__(self, texts: List[str], min_freq: int = 1):
        from collections import Counter

        freq = Counter(word for text in texts for word in self._tokenize(text))
        vocab = {"<PAD>": 0, "<UNK>": 1}
        for word, count in freq.items():
            if count < min_freq or word in vocab:
                continue
            vocab[word] = len(vocab)
        self.vocab = vocab
        self.inv_vocab = {idx: word for word, idx in vocab.items()}

    def _tokenize(self, text: str):
        """Lower-case ``text`` and split it on whitespace."""
        return text.lower().split()

    def encode(self, text: str, max_len: int = 32):
        """Map ``text`` to at most ``max_len`` ids, right-padded with ``<PAD>`` up to ``max_len``."""
        unk = self.vocab["<UNK>"]
        ids = [self.vocab.get(word, unk) for word in self._tokenize(text)[:max_len]]
        shortfall = max_len - len(ids)
        if shortfall > 0:
            ids += [self.vocab["<PAD>"]] * shortfall
        return ids

class TextEncoder(nn.Module):
    def __init__(self, vocab_size:int, emb_dim:int=256, max_len:int=32, nhead=4, nlayers=2):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        self.pos_embed = nn.Parameter(torch.randn(1, max_len, emb_dim))
        encoder_layer = nn.TransformerEncoderLayer(d_model=emb_dim, nhead=nhead)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=nlayers)
        self.max_len = max_len
        self.fc = nn.Linear(emb_dim, EMBED_DIM)

    def forward(self, x):
        # x: (B, L) tokens
        emb = self.token_embed(x) + self.pos_embed[:, :x.size(1), :]
        emb = emb.permute(1,0,2)  # (L,B,E) for transformer
        out = self.transformer(emb)  # (L,B,E)
        out = out.permute(1,0,2)  # (B,L,E)
        # simple pooling: mean excluding PAD
        mask = (x != 0).unsqueeze(-1).float()
        summed = (out * mask).sum(1)
        lengths = mask.sum(1).clamp(min=1.0)
        pooled = summed / lengths
        out = self.fc(pooled)
        return out

## Image encoder with projection head

In [None]:
class ImageEncoder(nn.Module):
    """ResNet-50 backbone with its classification layer removed, plus a linear projection.

    Args:
        embed_dim: output embedding size.
        pretrained: load the ImageNet-1k V1 weights (what torchvision's legacy
            ``pretrained=True`` flag loaded).
    """

    def __init__(self, embed_dim=EMBED_DIM, pretrained=True):
        super().__init__()
        weights_enum = getattr(models, "ResNet50_Weights", None)
        if weights_enum is not None:
            # torchvision >= 0.13 deprecates `pretrained` in favour of `weights`.
            backbone = models.resnet50(weights=weights_enum.IMAGENET1K_V1 if pretrained else None)
        else:
            backbone = models.resnet50(pretrained=pretrained)
        feat_dim = backbone.fc.in_features  # 2048 for ResNet-50
        # Drop the final fc layer; output is (B, feat_dim, 1, 1).
        self.backbone = nn.Sequential(*list(backbone.children())[:-1])
        self.fc = nn.Linear(feat_dim, embed_dim)

    def forward(self, x):
        """Map an image batch ``x`` to embeddings of shape (B, embed_dim)."""
        feats = self.backbone(x).flatten(1)  # (B, feat_dim)
        return self.fc(feats)

## MiniCLIP model: image/text encoders + L2 normalization

In [None]:
class MiniCLIP(nn.Module):
    """Dual encoder mapping images and token ids into a shared embedding space.

    ``forward`` returns L2-normalised (image, text) embeddings, so their dot
    product is a cosine similarity.
    """

    def __init__(self, vocab_size, max_len=32):
        super().__init__()
        # Submodule registration order (img_enc, then txt_enc) fixes state_dict key order.
        self.img_enc = ImageEncoder()
        self.txt_enc = TextEncoder(vocab_size, emb_dim=256, max_len=max_len)

    def forward(self, images, tokens):
        img_out, txt_out = self.img_enc(images), self.txt_enc(tokens)
        return F.normalize(img_out, dim=1), F.normalize(txt_out, dim=1)

## Contrastive Loss (symmetric InfoNCE / NT-Xent, as in CLIP)

In [None]:
def contrastive_loss(img_emb, txt_emb, temperature=TEMPERATURE):
    """Symmetric cross-entropy over in-batch image/text pairs.

    Row i of ``img_emb`` is the positive for row i of ``txt_emb``; all other
    rows in the batch act as negatives. Both inputs are expected to be
    L2-normalised (B, E) tensors.
    """
    scores = torch.matmul(img_emb, txt_emb.t()) / temperature  # (B, B)
    targets = torch.arange(scores.shape[0], device=scores.device)
    img_to_txt = F.cross_entropy(scores, targets)
    txt_to_img = F.cross_entropy(scores.t(), targets)
    return (img_to_txt + txt_to_img) / 2.0

## Utilities: collate, embedding extraction, retrieval metrics

In [None]:
from functools import partial

def collate_batch(batch, tokenizer, max_len=32, transform=None):
    """Stack a list of ``(image, caption, name)`` samples into batch tensors.

    Captions are converted to a (B, max_len) long tensor with
    ``tokenizer.encode``. ``transform`` is accepted but not used.

    Returns:
        (images, token_ids, names) where ``names`` is a tuple.
    """
    images, captions, filenames = zip(*batch)
    stacked = torch.stack(images)
    token_ids = torch.tensor(
        [tokenizer.encode(caption, max_len=max_len) for caption in captions],
        dtype=torch.long,
    )
    return stacked, token_ids, filenames

@torch.no_grad()
def compute_embeddings(model, loader, device):
    """Run ``model`` over every batch of ``loader`` and gather its outputs.

    The model is switched to eval mode and left there.

    Returns:
        (image_embeddings, text_embeddings, names): two numpy arrays stacked
        along the batch axis, plus the per-sample names in loader order.
    """
    model.eval()
    img_chunks, txt_chunks, names = [], [], []
    for imgs, toks, batch_names in tqdm(loader, desc="Embedding"):
        img_emb, txt_emb = model(imgs.to(device), toks.to(device))
        img_chunks.append(img_emb.cpu())
        txt_chunks.append(txt_emb.cpu())
        names.extend(batch_names)
    img_all = torch.cat(img_chunks)
    txt_all = torch.cat(txt_chunks)
    return img_all.numpy(), txt_all.numpy(), names

def _recall_at_k(sims, topk, prefix):
    """Fraction of rows of ``sims`` whose diagonal entry ranks within the top k, per k."""
    ranks = np.argsort(-sims, axis=1)
    target = np.arange(sims.shape[0])[:, None]
    return {
        f"{prefix}R@{k}": float((ranks[:, :k] == target).any(axis=1).mean())
        for k in topk
    }


def retrieval_metrics(img_embs, txt_embs, topk=(1,5,10)):
    """Recall@k for image->text and text->image retrieval over paired embeddings.

    Row ``i`` of ``img_embs`` and row ``i`` of ``txt_embs`` are treated as a
    matching pair. Similarity is the dot product, which equals cosine
    similarity when both inputs are L2-normalised.

    Returns:
        dict with ``"R@k"`` (image->text) keys followed by ``"t2i_R@k"``
        (text->image) keys, each a float in [0, 1].

    Raises:
        ValueError: if the inputs are empty or have different numbers of rows.
    """
    img_embs = np.asarray(img_embs)
    txt_embs = np.asarray(txt_embs)
    if img_embs.shape[0] != txt_embs.shape[0]:
        raise ValueError(
            f"expected paired embeddings, got {img_embs.shape[0]} images "
            f"and {txt_embs.shape[0]} texts"
        )
    if img_embs.shape[0] == 0:
        raise ValueError("retrieval_metrics needs at least one image/text pair")
    sims = img_embs @ txt_embs.T  # (N, N)
    recall = _recall_at_k(sims, topk, prefix="")
    recall.update(_recall_at_k(sims.T, topk, prefix="t2i_"))
    return recall

## Training loop

In [None]:
def train_loop(model, train_loader, val_loader, tokenizer, device, epochs=EPOCHS, lr=LR):
    """Train ``model`` with the contrastive loss, validating after every epoch.

    After each epoch, image->text R@1 on ``val_loader`` is computed. Whenever it
    improves, the model weights, tokenizer vocabulary, epoch number and
    validation metrics are saved to ``OUT_DIR/best_mini_clip.pth``.
    """
    model = model.to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    # Start below any attainable R@1 so the first epoch always writes a
    # checkpoint; starting at 0.0 saved nothing while R@1 stayed at zero.
    best_val = float("-inf")
    ckpt_path = os.path.join(OUT_DIR, 'best_mini_clip.pth')
    for epoch in range(1, epochs + 1):
        model.train()  # compute_embeddings below leaves the model in eval mode
        running_loss = 0.0
        for imgs, toks, _ in tqdm(train_loader, desc=f"Epoch {epoch}"):
            imgs = imgs.to(device)
            toks = toks.to(device)
            img_emb, txt_emb = model(imgs, toks)
            loss = contrastive_loss(img_emb, txt_emb)
            opt.zero_grad()
            loss.backward()
            opt.step()
            # Weight by batch size so the epoch average is per sample.
            running_loss += loss.item() * imgs.size(0)
        avg_loss = running_loss / len(train_loader.dataset)
        # Validate.
        img_embs, txt_embs, _ = compute_embeddings(model, val_loader, device)
        metrics = retrieval_metrics(img_embs, txt_embs)
        val_r1 = metrics['R@1']
        print(f"Epoch {epoch}: Train Loss {avg_loss:.4f} | Val R@1 {val_r1:.4f}")
        if val_r1 > best_val:
            best_val = val_r1
            torch.save({
                'model_state': model.state_dict(),
                'tokenizer_vocab': tokenizer.vocab,
                'epoch': epoch,
                'val_metrics': metrics,
            }, ckpt_path)
    print("Training finished.")



## t-SNE visualization

In [None]:
def plot_tsne(img_embs, txt_embs, names, save_path=os.path.join(OUT_DIR,'tsne.png')):
    """Embed image and text vectors jointly with t-SNE and save a 2-D scatter plot.

    Args:
        img_embs: 2-D array of image embeddings.
        txt_embs: 2-D array of text embeddings with the same feature size.
        names: not used by the plot; kept for interface compatibility.
        save_path: path of the PNG that is written.

    At least two embeddings in total are required.
    """
    n_img = img_embs.shape[0]
    feats = np.concatenate([img_embs, txt_embs], axis=0)
    # scikit-learn requires 0 < perplexity < n_samples; cap the usual 30 so
    # small validation sets do not make TSNE raise.
    perplexity = min(30, max(feats.shape[0] - 1, 1))
    tsne = TSNE(n_components=2, perplexity=perplexity, init='pca', random_state=42)
    low = tsne.fit_transform(feats)
    fig, ax = plt.subplots(figsize=(8, 8))
    try:
        ax.scatter(low[:n_img, 0], low[:n_img, 1], label='images', alpha=0.6)
        ax.scatter(low[n_img:, 0], low[n_img:, 1], label='texts', alpha=0.6)
        ax.legend()
        ax.set_title('t-SNE of image and text embeddings')
        fig.savefig(save_path, dpi=200)
    finally:
        plt.close(fig)  # release the figure even if saving fails

## Zero-shot classification using text prompts

In [None]:
@torch.no_grad()
def zero_shot_classify(model, class_prompts: List[str], dataloader, tokenizer, device):
    """Assign each image in ``dataloader`` to its most similar text prompt.

    Args:
        model: model exposing ``txt_enc`` and usable with ``compute_embeddings``.
        class_prompts: one text prompt per class.
        dataloader: yields ``(images, tokens, names)`` batches.
        tokenizer: provides ``encode(text, max_len=...)``.
        device: device the model runs on.

    Returns:
        (preds, names): numpy array of predicted indices into
        ``class_prompts`` and the image names in loader order.

    Raises:
        ValueError: if ``class_prompts`` is empty.
    """
    if not class_prompts:
        raise ValueError("class_prompts must contain at least one prompt")
    model.eval()
    txt_enc = model.txt_enc
    # Encode prompts with the text encoder alone, at the length it was built for.
    max_len = getattr(txt_enc, "max_len", 32)
    token_ids = torch.tensor(
        [tokenizer.encode(p, max_len=max_len) for p in class_prompts],
        dtype=torch.long,
        device=device,
    )
    prompt_embs = F.normalize(txt_enc(token_ids), dim=1).cpu().numpy()  # (C, E)
    # compute_embeddings also encodes each batch's captions; only images are used here.
    img_embs, _, names = compute_embeddings(model, dataloader, device)
    sims = img_embs @ prompt_embs.T  # (N, C)
    return sims.argmax(axis=1), names