# Assignment: Fine-tuning CLIP for Image–Text Retrieval (Part 2)

**Name (Student ID):**  
1. LI JIARU (A0332008U)

2. JIN YINAN (A0327317E)

3. SHI YANCHUN (A0328710J)

4. XIAO XIAO (A0332142W)

**Total Marks: 25**  
**Rubric:** Each step shows its mark weight. Auto-check cells provide partial verification. Keep the notebook runnable on a personal computer.


## Objective

In Part 1, CLIP was used as a fixed feature extractor. In this part, you will **fine-tune** CLIP for image–text retrieval on a small paired dataset (**Flickr8k**) using a **parameter-efficient** approach. We will **freeze** the image and text encoders and train only the **projection layers**.

**Tasks**
- Load and split Flickr8k, wrap it with a PyTorch `Dataset` and `DataLoader`.
- Implement two contrastive objectives: **InfoNCE** and **Triplet Loss**.
- Fine-tune only the projection layers and compare with a **zero-shot** baseline.
- Evaluate with **Recall@K** (I2T and T2I).
- Run small experiments and answer short questions.


> **Note**  
> This notebook is designed to be runnable on a typical laptop. If memory is tight, reduce `batch_size` and/or the number of training epochs. Avoid large image sizes. Keep the encoders frozen to limit compute.


### Step 1: Setup and Imports (Provided)

In [1]:
# Provided: core imports
import os
import random
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

from datasets import load_dataset
from transformers import CLIPModel, CLIPProcessor
from PIL import Image
from tqdm.auto import tqdm
import numpy as np

# Reproducibility
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
set_seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu"))
print("Device:", device)


Device: mps


### Step 2: Load CLIP and Enable PEFT [3 marks]

In [2]:
model_str = "openai/clip-vit-base-patch32"

# TODO: Load model and processor (similar to Part 1)
# Hints: use CLIPModel.from_pretrained / CLIPProcessor.from_pretrained
# Move model to device.
model = CLIPModel.from_pretrained(model_str).to(device)
processor = CLIPProcessor.from_pretrained(model_str)

# TODO: Freeze text and vision encoders; unfreeze only projection layers
def setup_peft_model(model: "CLIPModel") -> "CLIPModel":
    """Freeze encoders, train only projection layers."""
    # 1) Freeze encoders
    for p in model.text_model.parameters():
        p.requires_grad = False
    for p in model.vision_model.parameters():
        p.requires_grad = False
    # 2) Unfreeze model.text_projection and model.visual_projection
    for p in model.text_projection.parameters():
        p.requires_grad = True
    for p in model.visual_projection.parameters():
        p.requires_grad = True
    # 3) Return model
    return model

# After loading model and processor above, call:
model = setup_peft_model(model)

# Diagnostics (will be used by auto-checks below)
def count_params(model):
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable, 100.0 * trainable / max(1, total)

# Report parameter counts. The trainable total includes the scalar logit_scale,
# which setup_peft_model does not freeze.
total, trainable, pct = count_params(model)
print(f"Total parameters: {total:,}")
print(f"Trainable parameters: {trainable:,} ({pct:.2f}%)")



Total parameters: 151,277,313
Trainable parameters: 655,361 (0.43%)
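
The trainable count printed above can be reproduced by hand. The sketch below assumes the standard `openai/clip-vit-base-patch32` configuration (text width 512, vision width 768, shared projection dimension 512, bias-free projection layers) and that the scalar `logit_scale` stays trainable because `setup_peft_model` never touches it. The variable names are illustrative and nothing is read from the model itself.

```python
# Assumed dimensions for openai/clip-vit-base-patch32 (not read from the model here).
TEXT_HIDDEN = 512      # width of the text transformer
VISION_HIDDEN = 768    # width of the vision transformer
PROJECTION_DIM = 512   # shared embedding size

text_projection = TEXT_HIDDEN * PROJECTION_DIM        # bias-free linear layer
visual_projection = VISION_HIDDEN * PROJECTION_DIM    # bias-free linear layer
logit_scale = 1                                       # scalar, not frozen by setup_peft_model

trainable = text_projection + visual_projection + logit_scale
total = 151_277_313  # total reported by count_params above

print(f"text_projection:   {text_projection:,}")
print(f"visual_projection: {visual_projection:,}")
print(f"trainable:         {trainable:,} ({100 * trainable / total:.2f}%)")
```

Because the loss functions defined later in this notebook use a fixed `temperature` argument and never read `logit_scale`, that one extra parameter should receive no gradient and remain at its pretrained value.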


In [3]:
# === Auto-check: Step 2 ===
try:
    assert model is not None and processor is not None, "Model/processor not initialized."
    # encoders must be frozen
    enc_frozen = all(not p.requires_grad for p in model.text_model.parameters()) and all(not p.requires_grad for p in model.vision_model.parameters())
    assert enc_frozen, "Encoders should be frozen."
    # projections must be trainable
    assert any(p.requires_grad for p in model.text_projection.parameters()), "text_projection should be trainable."
    assert any(p.requires_grad for p in model.visual_projection.parameters()), "visual_projection should be trainable."
    print("Step 2 check passed.")
except Exception as e:
    print("Step 2 check failed:", e)


Step 2 check passed.


### Step 3: Data Preparation (Flickr8k) [4 marks]

In [4]:
# TODO: Load the Flickr8k dataset and split into train/test indices
# Use the split sizes: 6000 for training, remaining for test (about 2,000 pairs for this dataset).
# Keep a fixed random seed for reproducibility.
# Hints: flickr = load_dataset("Naveengo/flickr8k") ; all_data = flickr["train"]
flickr = load_dataset("Naveengo/flickr8k")
all_data = flickr["train"]
num_train = 6000
all_indices = list(range(len(all_data)))
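# Shuffle with the global RNG seeded by set_seed(42) in Step 1; re-running this cell
# without re-seeding produces a different split.
random.shuffle(all_indices)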
train_indices = all_indices[:num_train]
test_indices = all_indices[num_train:]

# TODO: Implement a custom Dataset that packs pixel_values, input_ids, attention_mask
class FlickrDataset(Dataset):
    def __init__(self, hf_dataset, indices, processor):
        self.hf_dataset = hf_dataset
        self.indices = indices
        self.processor = processor

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        ex = self.hf_dataset[self.indices[idx]]
        out = self.processor(images = ex["image"], text = ex["text"],
                             padding = "max_length", truncation = True, return_tensors = "pt")
        return {k: v.squeeze(0) for k, v in out.items()}

# TODO: Create train_dataset/test_dataset and DataLoaders (batch_size <= 32)
train_dataset = FlickrDataset(all_data, train_indices, processor)
test_dataset = FlickrDataset(all_data, test_indices, processor)
train_loader = DataLoader(train_dataset, batch_size = 32, shuffle = True)
test_loader = DataLoader(test_dataset, batch_size = 32, shuffle = False)


In [5]:
# === Auto-check: Step 3 ===
try:
    assert train_dataset is not None and test_dataset is not None, "Datasets not built."
    assert train_loader is not None and test_loader is not None, "DataLoaders not built."
    # quick sample
    sample = next(iter(train_loader))
    for k in ("pixel_values","input_ids","attention_mask"):
        assert k in sample, f"Missing key in batch: {k}"
    bs = sample["pixel_values"].shape[0]
    assert bs >= 1, "Empty batch."
    print("Step 3 check passed.")
except Exception as e:
    print("Step 3 check failed:", e)


Step 3 check passed.
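
The split in this step depends on the state of Python's global random generator at the moment the cell runs. A minimal sketch of a split that does not depend on execution order is shown below. It assumes a hypothetical helper `split_indices` and a dataset of 8,091 pairs (an assumed size, consistent with the roughly 2,000 test pairs used here).

```python
import random

def split_indices(n_items, n_train, seed=42):
    """Shuffle range(n_items) with a private RNG and split it into train/test lists."""
    if not 0 <= n_train <= n_items:
        raise ValueError("n_train must be between 0 and n_items")
    indices = list(range(n_items))
    random.Random(seed).shuffle(indices)   # private RNG: unaffected by other random calls
    return indices[:n_train], indices[n_train:]

# 8091 is an assumed dataset size used only for illustration.
train_idx, test_idx = split_indices(n_items=8091, n_train=6000)
print(len(train_idx), len(test_idx))
```

Switching to a helper like this changes which examples land in the test set, so the metrics reported in this notebook would shift slightly.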


### Step 4: Contrastive Losses and Recall@K [7 marks: InfoNCE 3, Triplet 3, Recall@K 1]
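
For reference, the two objectives implemented in this step can be written out as follows. Let $u_i$ and $v_i$ be the L2-normalised image and text embeddings of the $i$-th pair in a batch of $N$ pairs, $\tau$ the temperature, and $m$ the margin. This is a summary of what the code cell below computes, not a derivation; the notation is introduced here for illustration.

$$
s_{ij} = \frac{u_i^\top v_j}{\tau}, \qquad
\mathcal{L}_{\text{InfoNCE}} = -\frac{1}{2N}\sum_{i=1}^{N}\left[\log\frac{e^{s_{ii}}}{\sum_{j=1}^{N} e^{s_{ij}}} + \log\frac{e^{s_{ii}}}{\sum_{j=1}^{N} e^{s_{ji}}}\right]
$$

$$
\mathcal{L}_{\text{triplet}} = \frac{1}{N}\sum_{i=1}^{N}\max\bigl(0,\; d(u_i, v_i) - d(u_i, v_{i-1}) + m\bigr), \qquad d(x, y) = 1 - \cos(x, y), \qquad v_{0} := v_{N}
$$

The first term of the InfoNCE sum ranks all texts for image $i$, the second ranks all images for text $i$. In the triplet loss, the negative for image $i$ is the caption of the previous pair in the batch, with the index wrapping around.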

In [6]:
# TODO: Implement InfoNCE loss
def info_nce_loss(image_features: torch.Tensor, text_features: torch.Tensor, temperature: float = 0.07):
    """Symmetric InfoNCE over a batch of matched pairs; returns a scalar loss."""
    # Cosine similarity = dot product of L2-normalised embeddings
    image_features = F.normalize(image_features, dim=1)
    text_features = F.normalize(text_features, dim=1)
    logits = image_features @ text_features.T / temperature  # (B, B); row i = image i vs all texts
    # The matching text for image i sits on the diagonal
    labels = torch.arange(image_features.size(0), device=logits.device)
    # Average the image-to-text and text-to-image cross-entropies
    return (F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels)) / 2

# TODO: Implement simple triplet loss with cosine distance and a rolled negative
def triplet_loss(image_features: torch.Tensor, text_features: torch.Tensor, margin: float = 0.2):
    """Triplet loss: image anchors, matching captions as positives, rolled captions as negatives."""
    image_features = F.normalize(image_features, dim=1)
    text_features = F.normalize(text_features, dim=1)
    # Negative for image i is the caption of image i-1 (wraps around); needs batch size >= 2
    neg = torch.roll(text_features, shifts=1, dims=0)
    return F.triplet_margin_with_distance_loss(
        anchor=image_features, positive=text_features, negative=neg,
        distance_function=lambda x, y: 1.0 - F.cosine_similarity(x, y), margin=margin)

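# TODO: Implement Recall@K for I2T and T2I in a light-weight way (no gradients)
@torch.no_grad()
def calculate_recall_at_k(model, test_dataset, device, k_values=(1, 5, 10), batch_size=64):
    """Return a dict with keys: I2T_R@1, I2T_R@5, I2T_R@10, T2I_R@1, T2I_R@5, T2I_R@10."""
    model.eval()
    proc = test_dataset.processor  # the processor the dataset was built with

    # Group captions by dataset index so an image with several captions is embedded once.
    captions_by_idx = defaultdict(list)
    images = {}
    for idx in test_dataset.indices:
        item = test_dataset.hf_dataset[idx]
        captions_by_idx[idx].append(item["text"])
        images.setdefault(idx, item["image"])
    image_list = list(images.values())
    texts = [c for idx in images for c in captions_by_idx[idx]]
    # text_to_image[t] = row of the image that caption t describes
    text_to_image = [row for row, idx in enumerate(images) for _ in captions_by_idx[idx]]
    # texts_per_image[row] = caption columns that belong to that image
    texts_per_image = [[] for _ in image_list]
    for t, row in enumerate(text_to_image):
        texts_per_image[row].append(t)

    # Encode in mini-batches to keep peak memory low.
    image_embeds = torch.cat([
        model.get_image_features(**proc(images=image_list[s:s + batch_size], return_tensors="pt").to(device))
        for s in range(0, len(image_list), batch_size)])
    text_embeds = torch.cat([
        model.get_text_features(**proc(text=texts[s:s + batch_size], padding=True, truncation=True,
                                       return_tensors="pt").to(device))
        for s in range(0, len(texts), batch_size)])
    image_embeds = F.normalize(image_embeds, dim=-1)
    text_embeds = F.normalize(text_embeds, dim=-1)
    sim = image_embeds @ text_embeds.T  # (n_images, n_texts) cosine similarities

    metrics = {}
    for k in k_values:
        # I2T: hit if any ground-truth caption is among the top-k texts for the image
        topk = torch.topk(sim, k, dim=1).indices.tolist()
        hits = sum(bool(set(topk[i]) & set(texts_per_image[i])) for i in range(len(image_list)))
        metrics[f"I2T_R@{k}"] = hits / len(image_list)
    for k in k_values:
        # T2I: hit if the caption's own image is among the top-k images
        topk = torch.topk(sim.T, k, dim=1).indices.tolist()
        hits = sum(text_to_image[t] in topk[t] for t in range(len(texts)))
        metrics[f"T2I_R@{k}"] = hits / len(texts)
    return metrics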

In [7]:
# === Auto-check: Step 4 ===
try:
    # create random embeddings to sanity-check loss shapes
    a = torch.randn(8, 512)
    b = torch.randn(8, 512)
    l1 = info_nce_loss(a, b)
    l2 = triplet_loss(a, b)
    assert torch.is_tensor(l1) and l1.ndim == 0, "InfoNCE must return a scalar tensor."
    assert torch.is_tensor(l2) and l2.ndim == 0, "Triplet loss must return a scalar tensor."
    print("Loss functions basic shape check passed.")
except Exception as e:
    print("Step 4 loss checks failed:", e)


Loss functions basic shape check passed.
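
To make the metric concrete, here is a toy Recall@K computation in plain Python on a hand-written 3×3 similarity matrix with one caption per image. The helper `recall_at_k` and all numbers are made up for illustration; `calculate_recall_at_k` above applies the same ranking idea with `torch.topk` on real embeddings.

```python
def recall_at_k(sim, k):
    """Fraction of rows whose matching column (same index) is among the k highest scores."""
    hits = 0
    for i, row in enumerate(sim):
        ranked = sorted(range(len(row)), key=lambda j: row[j], reverse=True)
        hits += i in ranked[:k]
    return hits / len(sim)

# Rows are images, columns are captions; entry [i][j] is a made-up cosine similarity.
sim = [
    [0.9, 0.2, 0.1],   # image 0: its own caption ranks 1st
    [0.3, 0.4, 0.8],   # image 1: its own caption ranks 2nd
    [0.7, 0.6, 0.5],   # image 2: its own caption ranks 3rd
]
sim_t = [list(col) for col in zip(*sim)]   # transpose for text-to-image retrieval

for k in (1, 2, 3):
    print(f"I2T_R@{k} = {recall_at_k(sim, k):.2f}   T2I_R@{k} = {recall_at_k(sim_t, k):.2f}")
```

In this toy matrix I2T R@1 is 1/3 and R@2 is 2/3, while T2I already reaches 1.0 at K=2, which illustrates that the two directions need not agree even on the same similarity matrix.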


### Step 5: Trainer [4 marks]

In [8]:
# TODO: Implement a small trainer that accepts a loss function
class RetrievalTrainer:
    def __init__(self, model, train_loader, test_dataset, device):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.test_dataset = test_dataset
        self.device = device
        self.writer = None

    def train(self, epochs: int, lr: float, loss_fn, loss_name: str):
        # Hints: Use AdamW over trainable params only; normalize embeddings; log scalar loss each epoch
        optim = torch.optim.AdamW(filter(lambda p: p.requires_grad, self.model.parameters()), lr = lr)
        self.writer = SummaryWriter()
        for epoch in range(epochs):
            self.model.train()
            total_loss = 0.0
            for batch in self.train_loader:
                pixel_values = batch["pixel_values"].to(self.device)
                input_ids = batch["input_ids"].to(self.device)
                attention_mask = batch["attention_mask"].to(self.device)
                image_features = self.model.get_image_features(pixel_values = pixel_values)
                text_features = self.model.get_text_features(input_ids = input_ids, attention_mask = attention_mask)
                loss = loss_fn(image_features, text_features)
                optim.zero_grad()
                loss.backward()
                optim.step()
                total_loss += loss.item() * pixel_values.size(0)
            avg_loss = total_loss / len(self.train_loader.dataset)
            print(f"Epoch {epoch + 1}, Loss: {avg_loss:.4f}")
            self.writer.add_scalar(f"Loss/{loss_name}", avg_loss, epoch + 1)
            self.evaluate(epoch + 1)
        self.writer.close()

    def evaluate(self, epoch: int):
        # Hints: call calculate_recall_at_k and print metrics; log to TensorBoard as well
        metrics = calculate_recall_at_k(self.model, self.test_dataset, self.device)
        for k, v in metrics.items():
            # The writer is created in train(); skip logging if evaluate() is called before that
            if self.writer is not None:
                self.writer.add_scalar(f"Metrics/{k}", v, epoch)
            print(f"{k}: {v:.4f}")
        return metrics


In [9]:
# === Auto-check: Step 5 (structure only) ===
try:
    t = RetrievalTrainer
    assert callable(getattr(t, "train")), "train() missing."
    assert callable(getattr(t, "evaluate")), "evaluate() missing."
    print("Step 5 structure check passed.")
except Exception as e:
    print("Step 5 check failed:", e)


Step 5 structure check passed.
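
The trainer weights each batch loss by its batch size before dividing by the dataset length. This matters because the last batch is smaller: 6000 training pairs with `batch_size=32` give 187 full batches and one batch of 16. The sketch below uses made-up per-batch losses to show how the per-sample average differs from a plain mean over batches; the helper names are illustrative.

```python
def batch_sizes(n_samples, batch_size):
    """Sizes of the batches a DataLoader yields when the last partial batch is kept."""
    full, rest = divmod(n_samples, batch_size)
    return [batch_size] * full + ([rest] if rest else [])

def weighted_epoch_loss(batch_losses, sizes):
    """Per-sample average: each batch mean is weighted by its batch size."""
    return sum(loss * size for loss, size in zip(batch_losses, sizes)) / sum(sizes)

sizes = batch_sizes(6000, 32)
print(len(sizes), sizes[-1])   # number of batches, size of the last one

# Toy case: two batches of 32 and 16 samples with made-up mean losses.
toy_losses, toy_sizes = [1.0, 4.0], [32, 16]
weighted = weighted_epoch_loss(toy_losses, toy_sizes)   # per-sample average
unweighted = sum(toy_losses) / len(toy_losses)          # plain mean over batches
print(weighted, unweighted)
```

With these toy numbers the per-sample average is 2.0 while the unweighted batch mean is 2.5, because the smaller batch is over-counted in the latter.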


### Step 6: Experiments [Partial Provided, 2 marks for completion]

In [10]:
# Baseline: zero-shot metrics without fine-tuning
# Note: This will be slower on CPU. Consider reducing test set size temporarily for a quick smoke test.
print("--- Calculating Zero-Shot Baseline Performance ---")
try:
    baseline_metrics = calculate_recall_at_k(model, test_dataset, device)
    for k, v in baseline_metrics.items():
        print(f"{k}: {v:.4f}")
except Exception as e:
    print("Baseline evaluation failed (ok if you have not implemented calculate_recall_at_k yet):", e)


--- Calculating Zero-Shot Baseline Performance ---
I2T_R@1: 0.4653
I2T_R@5: 0.7303
I2T_R@10: 0.8183
T2I_R@1: 0.4558
T2I_R@5: 0.7044
T2I_R@10: 0.8101


In [11]:
# Experiment 1: InfoNCE fine-tuning
# Reset model to fresh PEFT state
try:
    model = CLIPModel.from_pretrained(model_str).to(device)
    model = setup_peft_model(model)
    trainer = RetrievalTrainer(model, train_loader, test_dataset, device)
    learning_rate = 1e-5
    num_epochs = 3  # keep small for resource limits
    trainer.train(epochs=num_epochs, lr=learning_rate, loss_fn=info_nce_loss, loss_name="InfoNCE")
except Exception as e:
    print("InfoNCE experiment not executed:", e)


Epoch 1, Loss: 1.3476
I2T_R@1: 0.4400
I2T_R@5: 0.7241
I2T_R@10: 0.8097
T2I_R@1: 0.4720
T2I_R@5: 0.7542
T2I_R@10: 0.8326
Epoch 2, Loss: 0.9132
I2T_R@1: 0.4472
I2T_R@5: 0.7288
I2T_R@10: 0.8173
T2I_R@1: 0.4720
T2I_R@5: 0.7504
T2I_R@10: 0.8350
Epoch 3, Loss: 0.7118
I2T_R@1: 0.4538
I2T_R@5: 0.7293
I2T_R@10: 0.8221
T2I_R@1: 0.4720
T2I_R@5: 0.7465
T2I_R@10: 0.8374


In [12]:
# Experiment 2: Triplet fine-tuning
try:
    model = CLIPModel.from_pretrained(model_str).to(device)
    model = setup_peft_model(model)
    trainer = RetrievalTrainer(model, train_loader, test_dataset, device)
    learning_rate = 1e-5
    num_epochs = 3
    trainer.train(epochs=num_epochs, lr=learning_rate, loss_fn=triplet_loss, loss_name="Triplet")
except Exception as e:
    print("Triplet experiment not executed:", e)


Epoch 1, Loss: 0.0373
I2T_R@1: 0.4543
I2T_R@5: 0.7250
I2T_R@10: 0.8140
T2I_R@1: 0.4605
T2I_R@5: 0.7255
T2I_R@10: 0.8216
Epoch 2, Loss: 0.0206
I2T_R@1: 0.4529
I2T_R@5: 0.7274
I2T_R@10: 0.8211
T2I_R@1: 0.4601
T2I_R@5: 0.7322
T2I_R@10: 0.8197
Epoch 3, Loss: 0.0156
I2T_R@1: 0.4596
I2T_R@5: 0.7331
I2T_R@10: 0.8274
T2I_R@1: 0.4586
T2I_R@5: 0.7303
T2I_R@10: 0.8168
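
For the comparison asked for in Step 7, the final-epoch metrics can be lined up against the zero-shot baseline. The snippet below simply copies the numbers printed above into dictionaries (the names `baseline`, `infonce` and `triplet` are illustrative) and prints the differences. It makes no claim about why they differ.

```python
KEYS = ["I2T_R@1", "I2T_R@5", "I2T_R@10", "T2I_R@1", "T2I_R@5", "T2I_R@10"]

# Values copied from the outputs above (epoch 3 for the fine-tuned runs).
baseline = dict(zip(KEYS, [0.4653, 0.7303, 0.8183, 0.4558, 0.7044, 0.8101]))
infonce = dict(zip(KEYS, [0.4538, 0.7293, 0.8221, 0.4720, 0.7465, 0.8374]))
triplet = dict(zip(KEYS, [0.4596, 0.7331, 0.8274, 0.4586, 0.7303, 0.8168]))

def deltas(run, reference):
    """Metric-wise difference run - reference, rounded to 4 decimals."""
    return {k: round(run[k] - reference[k], 4) for k in reference}

d_infonce = deltas(infonce, baseline)
d_triplet = deltas(triplet, baseline)

print(f"{'metric':<10}{'baseline':>10}{'InfoNCE':>10}{'Triplet':>10}")
for k in KEYS:
    print(f"{k:<10}{baseline[k]:>10.4f}{d_infonce[k]:>+10.4f}{d_triplet[k]:>+10.4f}")
```

In this run, InfoNCE moves T2I R@5 by +0.0421 and I2T R@1 by -0.0115 relative to the baseline, while the triplet run stays within about ±0.03 on every metric. These are single-run numbers, so small differences may be within run-to-run noise.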


### Step 7: Analysis [5 marks]

**Your Task:** After completing the code above, answer the following analysis questions. *Write your answers in the Markdown cell below AND in the PDF report.*

1. **Performance Comparison:** Compare your fine-tuned results (InfoNCE vs Triplet) with the zero-shot baseline. Which improved Recall@K more on Flickr8k and why?  
2. **Convergence:** Show training loss plots (TensorBoard) and discuss which objective converged more smoothly.  
3. **Learning Rate Sweep:** Try `1e-4` and `1e-6` for InfoNCE. Summarize changes in Recall@K and loss curves.  
4. **Overfitting:** Train longer with the better setup and discuss if/when overfitting arises.  
5. **Qualitative:** One success and one failure case. Speculate on causes.
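
For the overfitting question, one simple way to read a longer run is to find the epoch where a held-out metric peaks and check how long it has failed to improve since. The sketch below is a heuristic only; `best_epoch`, its `patience` parameter, and the history values are hypothetical and not taken from this notebook's results.

```python
def best_epoch(history, patience=2):
    """Return (best_epoch, best_value, stalled) for per-epoch scores (higher is better).

    stalled is True when the score has not improved for at least `patience` epochs,
    a common heuristic sign that further training is not helping held-out performance.
    """
    if not history:
        raise ValueError("history must contain at least one value")
    best_idx = max(range(len(history)), key=lambda i: history[i])  # first maximum on ties
    epochs_since_best = len(history) - 1 - best_idx
    return best_idx + 1, history[best_idx], epochs_since_best >= patience

# Made-up T2I R@10 values for a longer run, for illustration only.
history = [0.83, 0.84, 0.85, 0.84, 0.83]
print(best_epoch(history))
```

A falling training loss combined with a stalled or declining Recall@K after the best epoch is the pattern to look for in the TensorBoard curves.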

**(Your analysis goes here)**
