# Assignment: Fine-tuning CLIP for Image–Text Retrieval (Part 2)

**Name (Student ID):**  
1. __________________ (________________)
2. __________________ (________________)
3. __________________ (________________)
4. __________________ (________________)

**Total Marks: 25**  
**Rubric:** Each step shows its mark weight. Auto-check cells provide partial verification. Keep the notebook runnable on a personal computer.


## Objective

In Part 1, CLIP was used as a fixed feature extractor. In this part, you will **fine-tune** CLIP for image–text retrieval on a small paired dataset (**Flickr8k**) using a **parameter-efficient** approach. We will **freeze** the image and text encoders and train only the **projection layers**.
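
The worked count below shows how small this trainable subset is by multiplying out the two projection matrices. It assumes the `openai/clip-vit-base-patch32` configuration used in Step 2: text hidden size 512, vision hidden size 768, projection dimension 512, and bias-free projection layers. Treat these sizes as assumptions and confirm them against `model.config` after loading.

```python
# Assumed sizes for openai/clip-vit-base-patch32; confirm with
# model.config.text_config.hidden_size, model.config.vision_config.hidden_size
# and model.config.projection_dim.
text_hidden = 512
vision_hidden = 768
projection_dim = 512

# Each projection is assumed to be a bias-free linear map hidden -> projection_dim,
# so it holds hidden * projection_dim weights.
text_proj_params = text_hidden * projection_dim      # 262,144
visual_proj_params = vision_hidden * projection_dim  # 393,216
trainable = text_proj_params + visual_proj_params

print(f"Trainable projection weights: {trainable:,}")  # Trainable projection weights: 655,360
```

The full model has roughly 151M parameters, so this is well under 1%, and `count_params` should report a very small trainable percentage. If the learnable `logit_scale` is left unfrozen, expect one extra trainable parameter.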

**Tasks**
- Load and split Flickr8k, then wrap it in a PyTorch `Dataset` and `DataLoader`.
- Implement two contrastive objectives: **InfoNCE** and **Triplet Loss**.
- Fine-tune only the projection layers and compare with a **zero-shot** baseline.
- Evaluate with **Recall@K** (I2T and T2I).
- Run small experiments and answer short questions.


> **Note**  
> This notebook is designed to be runnable on a typical laptop. If memory is tight, reduce `batch_size`; if training is slow, reduce the number of training epochs. Avoid large image sizes. Keep the encoders frozen to limit compute.


### Step 1: Setup and Imports (Provided)

In [None]:
# Provided: core imports
import os
import random
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

from datasets import load_dataset
from transformers import CLIPModel, CLIPProcessor
from PIL import Image
from tqdm.auto import tqdm
import numpy as np

# Reproducibility
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
set_seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu"))
print("Device:", device)


### Step 2: Load CLIP and Enable PEFT [3 marks]

In [None]:
model_str = "openai/clip-vit-base-patch32"

# TODO: Load model and processor (similar to Part 1)
# Hints: use CLIPModel.from_pretrained / CLIPProcessor.from_pretrained
# Move model to device.
model = None
processor = None

# TODO: Freeze text and vision encoders; unfreeze only projection layers
def setup_peft_model(model: "CLIPModel") -> "CLIPModel":
    """Freeze encoders, train only projection layers."""
    # 1) Freeze encoders
    # 2) Unfreeze model.text_projection and model.visual_projection
    # 3) Return model
    raise NotImplementedError

# After loading model and processor above, call:
# model = setup_peft_model(model)

# Diagnostics (will be used by auto-checks below)
def count_params(model):
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable, 100.0 * trainable / max(1, total)

# Uncomment after you implement:
# total, trainable, pct = count_params(model)
# print(f"Total parameters: {total:,}")
# print(f"Trainable parameters: {trainable:,} ({pct:.2f}%)")
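
One way to structure `setup_peft_model` is to freeze everything and then re-enable only the heads. The sketch below shows that pattern on a hypothetical `ToyCLIP` module. `ToyCLIP` only mimics `CLIPModel`'s attribute names (`text_model`, `vision_model`, `text_projection`, `visual_projection`, `logit_scale`). `freeze_all_but_projections` is likewise an illustrative name, not a library function.

```python
import torch
import torch.nn as nn

class ToyCLIP(nn.Module):
    """Hypothetical stand-in that mirrors CLIPModel's attribute names only."""
    def __init__(self):
        super().__init__()
        self.text_model = nn.Linear(16, 8)        # plays the text encoder
        self.vision_model = nn.Linear(32, 12)     # plays the vision encoder
        self.text_projection = nn.Linear(8, 4, bias=False)
        self.visual_projection = nn.Linear(12, 4, bias=False)
        self.logit_scale = nn.Parameter(torch.tensor(2.6592))

def freeze_all_but_projections(model: nn.Module) -> nn.Module:
    # 1) Freeze every parameter: encoders, logit_scale, anything else.
    for p in model.parameters():
        p.requires_grad = False
    # 2) Re-enable gradients for the two projection heads only.
    for head in (model.text_projection, model.visual_projection):
        for p in head.parameters():
            p.requires_grad = True
    return model

toy = freeze_all_but_projections(ToyCLIP())
trainable_names = [n for n, p in toy.named_parameters() if p.requires_grad]
print(trainable_names)  # ['text_projection.weight', 'visual_projection.weight']
```

On the real model, this pattern also freezes `logit_scale` (CLIP's learned temperature). That matches "train only the projection layers", and the Step 2 auto-check does not inspect it, but it is a choice worth stating in your report.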


In [None]:
# === Auto-check: Step 2 ===
try:
    assert model is not None and processor is not None, "Model/processor not initialized."
    # encoders must be frozen
    enc_frozen = all(not p.requires_grad for p in model.text_model.parameters()) and all(not p.requires_grad for p in model.vision_model.parameters())
    assert enc_frozen, "Encoders should be frozen."
    # projections must be trainable
    assert any(p.requires_grad for p in model.text_projection.parameters()), "text_projection should be trainable."
    assert any(p.requires_grad for p in model.visual_projection.parameters()), "visual_projection should be trainable."
    print("Step 2 check passed.")
except Exception as e:
    print("Step 2 check failed:", e)


### Step 3: Data Preparation (Flickr8k) [4 marks]

In [None]:
# TODO: Load the Flickr8k dataset and split into train/test indices
# Use the split sizes: 6000 for training, remaining for test (~1000).
# Keep a fixed random seed for reproducibility.
# Hints: flickr = load_dataset("Naveengo/flickr8k") ; all_data = flickr["train"]

flickr = None
train_indices, test_indices = None, None

# TODO: Implement a custom Dataset that packs pixel_values, input_ids, attention_mask
class FlickrDataset(Dataset):
    def __init__(self, hf_dataset, indices, processor):
        self.hf_dataset = hf_dataset
        self.indices = indices
        self.processor = processor

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        # Map local index to original index
        raise NotImplementedError

# TODO: Create train_dataset/test_dataset and DataLoaders (batch_size <= 32)
train_dataset = None
test_dataset = None
train_loader = None
test_loader = None
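
For the split TODO, one reproducible option is to shuffle row indices with a private RNG that is seeded once. Other random calls in the notebook then cannot shift which rows land in the test set. `split_indices` is a hypothetical helper. In the notebook you would call it with `len(flickr["train"])`; the size used below is made up.

```python
import random

def split_indices(n_rows: int, n_train: int = 6000, seed: int = 42):
    """Shuffle 0..n_rows-1 with an isolated RNG and cut into train/test lists."""
    if not 0 < n_train < n_rows:
        raise ValueError("n_train must be between 1 and n_rows - 1")
    indices = list(range(n_rows))
    rng = random.Random(seed)  # private RNG: the split ignores global random state
    rng.shuffle(indices)
    return indices[:n_train], indices[n_train:]

# Made-up dataset size, only to show the call.
train_idx, test_idx = split_indices(8000)
print(len(train_idx), len(test_idx))  # 6000 2000
```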


In [None]:
# === Auto-check: Step 3 ===
try:
    assert train_dataset is not None and test_dataset is not None, "Datasets not built."
    assert train_loader is not None and test_loader is not None, "DataLoaders not built."
    # quick sample
    sample = next(iter(train_loader))
    for k in ("pixel_values","input_ids","attention_mask"):
        assert k in sample, f"Missing key in batch: {k}"
    bs = sample["pixel_values"].shape[0]
    assert bs >= 1, "Empty batch."
    print("Step 3 check passed.")
except Exception as e:
    print("Step 3 check failed:", e)
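
The `__getitem__` TODO can be sketched as below. The column names `"image"` and `"text"` are assumptions about `Naveengo/flickr8k`, so check `flickr["train"].column_names`. `FlickrPairs`, the stub processor and the stub image are hypothetical stand-ins that let the snippet run without downloads. Two details matter:

- Padding every caption to the same length (`padding="max_length"`) lets the default `DataLoader` collate stack the captions.
- `squeeze(0)` removes the size-1 batch dimension that the processor adds.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class FlickrPairs(Dataset):
    """Sketch: one (image, caption) row -> dict of unbatched tensors."""
    def __init__(self, hf_dataset, indices, processor, image_key="image", text_key="text"):
        self.hf_dataset = hf_dataset
        self.indices = indices
        self.processor = processor
        self.image_key = image_key  # assumed column name
        self.text_key = text_key    # assumed column name

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        row = self.hf_dataset[int(self.indices[idx])]  # local index -> original row
        enc = self.processor(
            text=row[self.text_key],
            images=row[self.image_key].convert("RGB"),
            return_tensors="pt",
            padding="max_length",  # equal-length captions so batches can be stacked
            truncation=True,
        )
        # The processor returns a batch of one; drop that leading dimension.
        return {k: enc[k].squeeze(0) for k in ("pixel_values", "input_ids", "attention_mask")}

# Hypothetical stubs standing in for PIL images and CLIPProcessor.
class StubImage:
    def convert(self, mode):
        return self

def stub_processor(text, images, return_tensors, padding, truncation):
    return {"pixel_values": torch.zeros(1, 3, 224, 224),
            "input_ids": torch.zeros(1, 77, dtype=torch.long),
            "attention_mask": torch.ones(1, 77, dtype=torch.long)}

rows = [{"image": StubImage(), "text": f"caption {i}"} for i in range(4)]
ds = FlickrPairs(rows, [3, 1, 0], stub_processor)
batch = next(iter(DataLoader(ds, batch_size=2)))
print({k: tuple(v.shape) for k, v in batch.items()})
# {'pixel_values': (2, 3, 224, 224), 'input_ids': (2, 77), 'attention_mask': (2, 77)}
```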


### Step 4: Contrastive Losses and Recall@K [7 marks: InfoNCE 3m, Triplet 3m, Recall@K 1m]

In [None]:
# TODO: Implement InfoNCE loss
def info_nce_loss(image_features: torch.Tensor, text_features: torch.Tensor, temperature: float = 0.07):
    """Return scalar loss."""
    raise NotImplementedError

# TODO: Implement simple triplet loss with cosine distance and a rolled negative
def triplet_loss(image_features: torch.Tensor, text_features: torch.Tensor, margin: float = 0.2):
    raise NotImplementedError

# TODO: Implement Recall@K for I2T and T2I in a lightweight way (no gradients)
@torch.no_grad()
def calculate_recall_at_k(model, test_dataset, device, k_values=(1,5,10)):
    """Return a dict with keys: I2T_R@1, I2T_R@5, I2T_R@10, T2I_R@1, T2I_R@5, T2I_R@10."""
    raise NotImplementedError
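
Below is a minimal sketch of the symmetric (CLIP-style) form of InfoNCE. It assumes that row `i` of both tensors is a matched image-caption pair and that every other row in the batch acts as a negative. It uses the fixed `temperature` from the signature above rather than CLIP's learned `logit_scale`. `info_nce_sketch` is a hypothetical name chosen so it does not overwrite your TODO.

```python
import torch
import torch.nn.functional as F

def info_nce_sketch(image_features, text_features, temperature=0.07):
    """Symmetric InfoNCE over a batch of B matched (image, text) pairs."""
    img = F.normalize(image_features, dim=-1)
    txt = F.normalize(text_features, dim=-1)
    logits = img @ txt.t() / temperature  # (B, B) cosine similarities scaled by 1/T
    targets = torch.arange(logits.size(0), device=logits.device)  # pair i matches i
    loss_i2t = F.cross_entropy(logits, targets)      # each image must pick its caption
    loss_t2i = F.cross_entropy(logits.t(), targets)  # each caption must pick its image
    return (loss_i2t + loss_t2i) / 2

torch.manual_seed(0)
x = torch.randn(8, 512)
aligned = info_nce_sketch(x, x)                   # every caption matches its image
shuffled = info_nce_sketch(x, x.roll(1, dims=0))  # every caption is wrong
print(aligned.item() < shuffled.item())  # True
```

With unrelated features, the logits cluster near zero and the loss is on the order of `log(B)`. A smaller `batch_size` therefore also means fewer negatives per step.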


In [None]:
# === Auto-check: Step 4 ===
try:
    # create random embeddings to sanity-check loss shapes
    a = torch.randn(8, 512)
    b = torch.randn(8, 512)
    l1 = info_nce_loss(a, b)
    l2 = triplet_loss(a, b)
    assert torch.is_tensor(l1) and l1.ndim == 0, "InfoNCE must return a scalar tensor."
    assert torch.is_tensor(l2) and l2.ndim == 0, "Triplet loss must return a scalar tensor."
    print("Loss functions basic shape check passed.")
except Exception as e:
    print("Step 4 loss checks failed:", e)
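
For the triplet TODO, one reading of "cosine distance with a rolled negative" is sketched below. Each image is an anchor, its own caption is the positive, and the caption of the next pair in the batch (obtained with `torch.roll`) is the negative. The hinge asks the negative to be at least `margin` farther away than the positive. `triplet_sketch` is a hypothetical name. Image-anchored triplets are one choice; a text-anchored term can be added symmetrically.

```python
import torch
import torch.nn.functional as F

def triplet_sketch(image_features, text_features, margin=0.2):
    """Image-anchored triplet loss: cosine distance, rolled in-batch negative."""
    img = F.normalize(image_features, dim=-1)
    txt = F.normalize(text_features, dim=-1)
    neg = txt.roll(shifts=1, dims=0)              # caption of a different pair
    d_pos = 1 - (img * txt).sum(dim=-1)           # cosine distance to the true caption
    d_neg = 1 - (img * neg).sum(dim=-1)           # cosine distance to the negative
    return F.relu(d_pos - d_neg + margin).mean()  # zero once d_neg >= d_pos + margin

torch.manual_seed(0)
x = torch.randn(8, 512)
print(triplet_sketch(x, x).item())  # 0.0 (positives coincide, negatives are far)
```

There are two caveats:

- With `batch_size=1`, the rolled negative is the positive itself, so the loss is stuck at `margin`.
- If a batch can contain two captions of the same image, the rolled "negative" may actually be a match.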


### Step 5: Trainer [4 marks]

In [None]:
# TODO: Implement a small trainer that accepts a loss function
class RetrievalTrainer:
    def __init__(self, model, train_loader, test_dataset, device):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.test_dataset = test_dataset
        self.device = device
        self.writer = None

    def train(self, epochs: int, lr: float, loss_fn, loss_name: str):
        # Hints: Use AdamW over trainable params only; normalize embeddings; log scalar loss each epoch
        raise NotImplementedError

    def evaluate(self, epoch: int):
        # Hints: call calculate_recall_at_k and print metrics; log to TensorBoard as well
        raise NotImplementedError
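
The core of `train` might look like the loop below, sketched as a free function so it runs on its own. It assumes the model exposes `get_image_features(pixel_values=...)` and `get_text_features(input_ids=..., attention_mask=...)`, as Hugging Face `CLIPModel` does. These methods apply the projection layers, so gradients reach them; check the return type in your installed `transformers` version. `train_epochs_sketch` and `ToyDualEncoder` are hypothetical, and a plain cosine-alignment loss stands in for your `info_nce_loss`/`triplet_loss` in the demo. TensorBoard logging (e.g. `writer.add_scalar(f"Loss/{loss_name}", epoch_loss, epoch)`) and the per-epoch `evaluate` call are left out.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def train_epochs_sketch(model, loader, loss_fn, epochs=1, lr=1e-5, device="cpu"):
    """AdamW over trainable parameters only; returns the mean loss of each epoch."""
    params = [p for p in model.parameters() if p.requires_grad]  # frozen weights excluded
    optimizer = torch.optim.AdamW(params, lr=lr)
    history = []
    model.train()
    for _ in range(epochs):
        running, n_batches = 0.0, 0
        for batch in loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            img = model.get_image_features(pixel_values=batch["pixel_values"])
            txt = model.get_text_features(input_ids=batch["input_ids"],
                                          attention_mask=batch["attention_mask"])
            loss = loss_fn(F.normalize(img, dim=-1), F.normalize(txt, dim=-1))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running += loss.item()
            n_batches += 1
        history.append(running / max(1, n_batches))
    return history

class ToyDualEncoder(nn.Module):
    """Hypothetical stand-in exposing CLIPModel-style feature methods."""
    def __init__(self):
        super().__init__()
        self.frozen_encoder = nn.Linear(12, 12)  # mimics a frozen encoder
        self.frozen_encoder.requires_grad_(False)
        self.visual_projection = nn.Linear(12, 4, bias=False)
        self.text_projection = nn.Linear(6, 4, bias=False)

    def get_image_features(self, pixel_values):
        return self.visual_projection(self.frozen_encoder(pixel_values))

    def get_text_features(self, input_ids, attention_mask):
        return self.text_projection(input_ids * attention_mask)  # toy: float "ids"

def alignment_loss(img, txt):
    # Stand-in loss: mean cosine distance between matched (already normalized) pairs.
    return (1 - (img * txt).sum(dim=-1)).mean()

torch.manual_seed(0)
toy_loader = [{"pixel_values": torch.randn(8, 12),
               "input_ids": torch.randn(8, 6),
               "attention_mask": torch.ones(8, 6)}]
toy = ToyDualEncoder()
frozen_before = toy.frozen_encoder.weight.clone()
history = train_epochs_sketch(toy, toy_loader, alignment_loss, epochs=20, lr=0.05)
print(f"first epoch {history[0]:.3f}, last epoch {history[-1]:.3f}")
```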


In [None]:
# === Auto-check: Step 5 (structure only) ===
try:
    t = RetrievalTrainer
    assert callable(getattr(t, "train")), "train() missing."
    assert callable(getattr(t, "evaluate")), "evaluate() missing."
    print("Step 5 structure check passed.")
except Exception as e:
    print("Step 5 check failed:", e)


### Step 6: Experiments [Partially Provided, 2 marks for completion]

In [None]:
# Baseline: zero-shot metrics without fine-tuning
# Note: This will be slower on CPU. Consider reducing test set size temporarily for a quick smoke test.
print("--- Calculating Zero-Shot Baseline Performance ---")
try:
    baseline_metrics = calculate_recall_at_k(model, test_dataset, device)
    for k, v in baseline_metrics.items():
        print(f"{k}: {v:.4f}")
except Exception as e:
    print("Baseline evaluation failed (ok if you have not implemented calculate_recall_at_k yet):", e)
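
Once `calculate_recall_at_k` has encoded the test set (in batches, under `torch.no_grad()`), the metric itself reduces to ranking a similarity matrix. The sketch below assumes one caption per test image, so the ground truth for row `i` is column `i`. If your split keeps several captions per image, a hit should count any caption of that image. `recall_at_k_from_embeddings` is a hypothetical helper whose keys follow the docstring in Step 4.

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def recall_at_k_from_embeddings(image_emb, text_emb, k_values=(1, 5, 10)):
    """Recall@K for I2T and T2I when row i of both inputs is a matched pair."""
    sim = F.normalize(image_emb, dim=-1) @ F.normalize(text_emb, dim=-1).t()  # (N, N)
    target = torch.arange(sim.size(0), device=sim.device).unsqueeze(1)        # (N, 1)
    metrics = {}
    for direction, scores in (("I2T", sim), ("T2I", sim.t())):
        for k in k_values:
            top = scores.topk(min(k, scores.size(1)), dim=1).indices  # (N, k) best matches
            hit = (top == target).any(dim=1).float()                  # correct item in top-k?
            metrics[f"{direction}_R@{k}"] = hit.mean().item()
    return metrics

torch.manual_seed(0)
emb = torch.randn(20, 16)
print(recall_at_k_from_embeddings(emb, emb)["I2T_R@1"])  # 1.0
```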


In [None]:
# Experiment 1: InfoNCE fine-tuning
# Reset model to fresh PEFT state
try:
    model = CLIPModel.from_pretrained(model_str).to(device)
    model = setup_peft_model(model)
    trainer = RetrievalTrainer(model, train_loader, test_dataset, device)
    learning_rate = 1e-5
    num_epochs = 3  # keep small for resource limits
    trainer.train(epochs=num_epochs, lr=learning_rate, loss_fn=info_nce_loss, loss_name="InfoNCE")
except Exception as e:
    print("InfoNCE experiment not executed:", e)


In [None]:
# Experiment 2: Triplet fine-tuning
try:
    model = CLIPModel.from_pretrained(model_str).to(device)
    model = setup_peft_model(model)
    trainer = RetrievalTrainer(model, train_loader, test_dataset, device)
    learning_rate = 1e-5
    num_epochs = 3
    trainer.train(epochs=num_epochs, lr=learning_rate, loss_fn=triplet_loss, loss_name="Triplet")
except Exception as e:
    print("Triplet experiment not executed:", e)


### Step 7: Analysis [5 marks]

**Your Task:** After completing the code above, answer the following analysis questions. *Write your answers in the Markdown cell below AND in the PDF report.*

1. **Performance Comparison:** Compare your fine-tuned results (InfoNCE vs Triplet) with the zero-shot baseline. Which improved Recall@K more on Flickr8k and why?  
2. **Convergence:** Show training loss plots (TensorBoard) and discuss which objective converged more smoothly.  
3. **Learning Rate Sweep:** Try `1e-4` and `1e-6` for InfoNCE. Summarize changes in Recall@K and loss curves.  
4. **Overfitting:** Train longer with the better setup and discuss if/when overfitting arises.  
5. **Qualitative:** Show one success case and one failure case, and speculate on their causes.
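
For question 4, one way to make "if/when overfitting arises" concrete is to record a retrieval metric (for example `I2T_R@1`) after every epoch. You can then report the epoch where it peaked and the epoch where it stopped improving for `patience` epochs, while the training loss keeps falling. `best_epoch_and_stop` is a hypothetical helper. The recall values in the usage line are made up for illustration and are not results. Choosing the epoch on the test split leaks test information; holding out part of the 6000 training indices as a validation set is cleaner.

```python
def best_epoch_and_stop(recall_history, patience=2):
    """Return (best_epoch, stop_epoch); stop_epoch is None if patience never ran out."""
    best_epoch, best_value, waited = 0, float("-inf"), 0
    for epoch, value in enumerate(recall_history):
        if value > best_value:
            best_epoch, best_value, waited = epoch, value, 0
        else:
            waited += 1  # one more epoch without a new best
            if waited >= patience:
                return best_epoch, epoch
    return best_epoch, None

# Made-up per-epoch I2T_R@1 values, only to show the call.
print(best_epoch_and_stop([0.50, 0.55, 0.57, 0.56, 0.55]))  # (2, 4)
```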

**(Your analysis goes here)**
