# Assignment: Exploring CLIP for Zero-Shot Learning and Linear Probing

**Total Marks: 30**

**Group Name:** __________________

**Student Name (Student ID):**

1. __________________ (__________________)
2. __________________ (__________________)
3. __________________ (__________________)
4. __________________ (__________________)

## Introduction

This assignment aims to deepen your understanding of large-scale vision-language models like CLIP (Contrastive Language-Image Pre-Training). You will explore two primary methods for applying CLIP to image classification tasks:

1.  **Zero-Shot Prediction**: Using CLIP's ability to associate images with arbitrary text descriptions without any task-specific training. You will implement the core prediction logic and investigate how different "prompts" and similarity metrics affect performance.
2.  **Linear Probing**: Using CLIP as a feature extractor. You will freeze the powerful pre-trained image encoder and **build and train** a simple linear classifier on top of these features.

You will compare the performance, trade-offs, and characteristics of these two approaches using the ImageNet dataset.

---

## **Question 1: Setting up CLIP** [6 Marks]

### **Step 1: Environment Setup and Model Loading (Provided)**

This section loads the pre-trained CLIP model and processor. This part is provided so you can focus on the core logic.

In [None]:
import os
import io
import json
from datasets import Features, Sequence, Value
from pathlib import Path
from typing import Optional, Tuple, List, Dict
from PIL import Image
from tqdm.auto import tqdm
import datasets
import torch
import torch.nn as nn
from torch.utils.data import Dataset as TorchDataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from sklearn.model_selection import train_test_split
from transformers import CLIPModel, CLIPProcessor
import torch.nn.functional as F
import numpy as np
import requests

CACHE_DIR = "./cache" # Cache directory to save CLIP features

model_str = "openai/clip-vit-base-patch32"
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")

# Load the CLIP model with sdpa attention implementation for efficiency
model = CLIPModel.from_pretrained(
    pretrained_model_name_or_path=model_str, attn_implementation="sdpa"
).to(device)
model.eval()  # Set the model to evaluation mode

# Load the processor for preparing image and text data
processor = CLIPProcessor.from_pretrained(model_str)

print(f"Model loaded on device: {device}")

### **Step 2: A Quick Demonstration (Provided)**

This section loads a demonstration image from the web to show a simple example of what CLIP can do.

In [None]:
demo_image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
print(f"Loading a demo image from: {demo_image_url}")
demo_image = Image.open(requests.get(demo_image_url, stream=True).raw).convert("RGB")
demo_image

### **Step 3: Implementing the Core CLIP Encoder [3 Marks]**

**Your Task:** Implement the core functionalities of the `CLIPEncoder` class. This class is central to performing zero-shot predictions.

In [None]:
class CLIPEncoder:
    """
    A class that encapsulates the CLIP model's encoding functionalities for images and text.
    """
    def __init__(self, model: CLIPModel, processor: CLIPProcessor, device: torch.device):
        self.model = model
        self.processor = processor
        self.device = device
    
    @torch.no_grad()
    def encode_image(self, images: List[Image.Image]) -> torch.Tensor:
        """
        Encodes a batch of PIL images into feature vectors.

        Instructions:
        1. Use `self.processor` to process the list of images. Ensure tensors are returned ("pt").
        2. Move the processed inputs to `self.device`.
        3. Use `self.model.get_image_features` to get the embeddings for the batch.

        Helpful Documentation 📚:
        - CLIPProcessor: https://huggingface.co/docs/transformers/main/en/model_doc/clip#transformers.CLIPProcessor
        - CLIPModel.get_image_features: https://huggingface.co/docs/transformers/main/en/model_doc/clip#transformers.CLIPModel.get_image_features
        """
        # ================================
        # TODO: YOUR CODE GOES HERE (1 Mark)
        # ================================
        raise NotImplementedError("Please implement the `encode_image` method for batch processing.")

    @torch.no_grad()
    def encode_text(self, text: List[str]) -> torch.Tensor:
        """
        Encodes a list of texts into feature vectors.
        
        Instructions:
        1. Use `self.processor` to tokenize the text. Ensure tensors are returned ("pt"),
           and that padding and truncation are enabled.
        2. Move the processed inputs to `self.device`.
        3. Use `self.model.get_text_features` to get the embeddings.
        
        Helpful Documentation 📚:
        - CLIPProcessor: https://huggingface.co/docs/transformers/main/en/model_doc/clip#transformers.CLIPProcessor
        - CLIPModel.get_text_features: https://huggingface.co/docs/transformers/main/en/model_doc/clip#transformers.CLIPModel.get_text_features
        """
        # ================================
        # TODO: YOUR CODE GOES HERE (1 Mark)
        # ================================
        raise NotImplementedError("Please implement the `encode_text` method.")

    @staticmethod
    @torch.no_grad()
    def predict(image_features: torch.Tensor, text_features: torch.Tensor, criterion: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Calculates the similarity between image and text features and returns predictions.

        (... documentation omitted for brevity ...)
        """
        assert criterion in ["dot_product", "cosine_similarity"], "Invalid criterion"
        
        # ================================
        # TODO: YOUR CODE GOES HERE (1 Mark)
        # ================================
        raise NotImplementedError("Please implement the `predict` method.")

# Initialize the encoder (this will use your implementation)
encoder = CLIPEncoder(model, processor, device)

### **Step 4: Test Case for the Encoder (Provided)**

This section provides a simple test to check if your encoder implementation is working correctly.

In [None]:
# A simple test on the demo image
try:
    demo_labels = ["a photo of a cat", "a photo of a dog", "a photo of two cats"]
    image_features = encoder.encode_image([demo_image]) 
    text_features = encoder.encode_text(demo_labels)
    predicted_idx, logits = encoder.predict(image_features, text_features, "cosine_similarity")
    print(f"Predicted Label: '{demo_labels[predicted_idx.item()]}'")
    print(f"Probabilities: {logits.cpu().numpy().flatten()}")
except NotImplementedError:
    print("One or more methods in CLIPEncoder are not yet implemented.")

### **Step 5: Data Loading and Preparation (Provided)**

In [None]:
def load_imagenet_val_clip_features(
    encoder: "CLIPEncoder",
    cache_path: str = os.path.join(CACHE_DIR, "imagenet_val_clip"),
    batch_size: int = 256,
    tqdm_total: int | None = 50000  # ImageNet-1k val has 50,000 images
) -> datasets.Dataset:
    """
    Stream the full ImageNet validation split, encode images into CLIP vectors,
    cache features+labels to disk, and return a Dataset that yields torch tensors.
    """
    if os.path.exists(cache_path):
        print("Loading CLIP features from local cache...")
        ds = datasets.Dataset.load_from_disk(cache_path)
        ds.set_format(type="torch", columns=["features", "label"])
        return ds

    print("Downloading ImageNet validation (streaming mode)...")
    stream_dataset = datasets.load_dataset(
        "benjamin-paine/imagenet-1k-256x256",
        split="validation",
        streaming=True
    )

    feats_chunks: list[torch.Tensor] = []
    labels_all: list[int] = []
    buf_images, buf_labels = [], []

    @torch.no_grad()
    def flush():
        nonlocal feats_chunks, labels_all, buf_images, buf_labels
        if not buf_images:
            return
        
        # Encode the buffered batch with your `encode_image` implementation.
        feats_on_gpu = encoder.encode_image(buf_images)
        feats = feats_on_gpu.to("cpu", dtype=torch.float32)
        
        feats_chunks.append(feats)
        labels_all.extend(buf_labels)
        buf_images, buf_labels = [], []

    print("Encoding full ImageNet validation set with CLIP...")
    pbar = tqdm(stream_dataset, desc="Encoding", total=tqdm_total)
    for row in pbar:
        buf_images.append(row["image"])
        buf_labels.append(int(row["label"]))
        if len(buf_images) >= batch_size:
            flush()
    
    flush() # Process any remaining images
    pbar.close()

    all_features = torch.cat(feats_chunks, dim=0).contiguous().numpy().astype(np.float32)
    labels_np = np.asarray(labels_all, dtype=np.int64)

    feature_dim = all_features.shape[1]
    schema = Features({
        "features": Sequence(Value("float32"), length=feature_dim),
        "label": Value("int64"),
    })
    
    ds = datasets.Dataset.from_dict({"features": all_features, "label": labels_np}, features=schema)

    ds.save_to_disk(cache_path)
    print("CLIP features cached locally!")

    ds.set_format(type="torch", columns=["features", "label"])
    return ds

# The code below runs the loader above (which relies on your `encode_image`) and prepares the final training and test sets.
dataset = load_imagenet_val_clip_features(
    encoder,
    cache_path=os.path.join(CACHE_DIR, "imagenet_val_clip"),
    batch_size=256
)

with Path("./id_to_label.json").open("r") as f:
    idx_to_label = json.load(f)

class_labels = list(idx_to_label.values())

all_features, all_labels = zip(*[(item["features"], int(item["label"])) for item in dataset])
all_features = torch.stack(all_features).to(device)
all_labels = torch.tensor(all_labels, dtype=torch.int64).to(device)
train_images, test_images, train_labels, test_labels = train_test_split(
    all_features, all_labels, test_size=0.2, shuffle=True, random_state=42
)

print(f"Train set shape: {train_images.shape}")
print(f"Test set shape: {test_images.shape}")

### **Step 6: Implementing Evaluation Metrics [3 Marks]**

**Your Task:** Implement the `accuracy_score` and `f1_score` methods. Correctly calculating metrics is fundamental to any machine learning task.
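
For reference, the standard textbook definitions of the two metrics are sketched below. The symbols are notation introduced here, not taken from the omitted docstrings: $N$ is the number of samples, $C$ the number of classes, $\hat{y}_i$ and $y_i$ the predicted and true labels, and $TP_c$, $FP_c$, $FN_c$ the true positives, false positives, and false negatives for class $c$. How to treat a class whose denominator is zero is a design choice that these formulas leave open.

$$
\text{Accuracy} = \frac{1}{N} \sum_{i=1}^{N} \mathbb{1}\left[\hat{y}_i = y_i\right]
$$

$$
P_c = \frac{TP_c}{TP_c + FP_c}, \qquad
R_c = \frac{TP_c}{TP_c + FN_c}, \qquad
F1_c = \frac{2 \, P_c \, R_c}{P_c + R_c}, \qquad
\text{Macro-F1} = \frac{1}{C} \sum_{c=1}^{C} F1_c
$$

Macro averaging gives every class equal weight regardless of how many samples it has, which is why it can differ noticeably from accuracy on imbalanced data.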

In [None]:
class Metrics:
    """
    Provides static methods for calculating evaluation metrics.
    """
    @staticmethod
    def _ensure_tensor(x, device=None) -> torch.Tensor:
        # This is a helper function, provided for convenience.
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x)
        if device is not None:
            x = x.to(device)
        return x

    @staticmethod
    def accuracy_score(predictions: torch.Tensor | List, targets: torch.Tensor | List, device=None) -> float:
        """
        Calculates the classification accuracy. (2 Marks)
        (... documentation omitted for brevity ...)
        """
        # ================================
        # TODO: YOUR CODE GOES HERE
        # ================================
        raise NotImplementedError("Please implement the `accuracy_score` method.")

    @staticmethod
    def f1_score(predictions: torch.Tensor | List, targets: torch.Tensor | List,
                 num_classes: Optional[int] = None, device=None) -> float:
        """
        Calculates the macro-average F1 score. (2 Marks)
        (... documentation omitted for brevity ...)
        """
        # ================================
        # TODO: YOUR CODE GOES HERE
        # ================================
        raise NotImplementedError("Please implement the `f1_score` method.")

In [None]:
# This block will check your metric implementations against numpy/sklearn for a sample case.
try:
    # Generate some dummy data
    sample_preds = [0, 1, 2, 0, 1, 2, 0, 1, 2]
    sample_targets = [0, 2, 1, 0, 2, 1, 0, 0, 1]
    metrics = Metrics()
    
    # Test accuracy
    accuracy_torch = metrics.accuracy_score(sample_preds, sample_targets)
    accuracy_np = np.mean(np.array(sample_preds) == np.array(sample_targets))
    print(f"Accuracy (yours): {accuracy_torch:.4f}, Accuracy (numpy): {accuracy_np:.4f}")
    assert np.isclose(accuracy_torch, accuracy_np), "Accuracy mismatch!"

    # Test F1 score
    from sklearn.metrics import f1_score as f1_sklearn
    f1_torch = metrics.f1_score(sample_preds, sample_targets)
    f1_np = f1_sklearn(sample_targets, sample_preds, average='macro')
    print(f"F1 Score (yours): {f1_torch:.4f}, F1 Score (sklearn): {f1_np:.4f}")
    assert np.isclose(f1_torch, f1_np), "F1 Score mismatch!"

    print("\n✅ All metrics passed the sanity check!")
except (NotImplementedError, AssertionError) as e:
    print(f"Metrics check failed: {e}")

--- 

## **Question 2: Zero-Shot Classification Analysis** [6 Marks]

### **Step 7: The Art of Prompt Engineering** [2 Marks]

**Your Task:** Investigate how different text prompts affect classification accuracy.
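
As a small illustration of how a prompt template turns class names into text inputs, the sketch below fills the `{}` placeholder with Python's `str.format`. The two class names are hypothetical stand-ins, not entries from `id_to_label.json`, and the sketch stops before any encoding or evaluation.

```python
# Two of the templates used in this step; "{}" is the slot for the class name.
templates = ["a photo of a {}", "{}"]

# Hypothetical class names used only for illustration.
labels = ["tabby cat", "golden retriever"]

# Build one list of prompts per template, in the same order as `labels`,
# so that prompt index i still corresponds to class index i.
prompts_per_template = {
    template: [template.format(label) for label in labels]
    for template in templates
}

for template, prompts in prompts_per_template.items():
    print(f"{template!r} -> {prompts}")
```

Keeping the prompt order aligned with the label order matters, because the index of the best-matching prompt is later interpreted as the predicted class index.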

In [None]:
def evaluate_zero_shot_prompts(
    prompt_templates: List[str], 
    class_labels: List[str], 
    test_images: torch.Tensor, 
    test_labels: torch.Tensor, 
    encoder: "CLIPEncoder",
    metrics: "Metrics"
) -> Dict[str, float]:
    results = {}
    print("Evaluating different prompts...")
    for template in tqdm(prompt_templates, desc="Prompts"):
        # ================================
        # TODO: YOUR CODE GOES HERE
        # ================================
        raise NotImplementedError("Please implement the prompt evaluation loop.")
    return results

prompts_to_test = [
    "a photo of a {}", 
    "a picture of a {}", 
    "an image of a {}", 
    "{}",
    "a wild animal: {}",
    "Beneath the fading sunset, the curious child wandered along the winding path. This is an image of {}."
]

prompt_accuracies = evaluate_zero_shot_prompts(prompts_to_test, class_labels, test_images, test_labels, encoder, metrics)
for prompt, accuracy in prompt_accuracies.items():
    print(f'Prompt: "{prompt}" -> Accuracy: {accuracy:.4f}')

### **Step 8: Analysis for Prompt Engineering** [1 Mark]
**(Your analysis goes here)**

### **Step 9: Similarity Metric Comparison** [2 Marks]

**Your Task:** Compare the performance of `dot_product` and `cosine_similarity`.
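
To see why the two criteria can disagree, here is a minimal toy sketch in plain Python. The 2-D vectors and their names are made up for illustration and are not CLIP features; the point is only that the dot product is sensitive to vector length, while cosine similarity depends on direction alone.

```python
import math

def dot(u, v):
    """Plain dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))

def cosine(u, v):
    """Dot product of the two vectors after scaling each to unit length."""
    return dot(u, v) / (math.sqrt(dot(u, u)) * math.sqrt(dot(v, v)))

# Toy "image" vector and two toy "text" vectors (hypothetical values).
image = [1.0, 0.0]
texts = {
    "aligned_short": [1.0, 0.0],  # same direction as the image, small norm
    "tilted_long": [3.0, 3.0],    # 45 degrees off, but a much larger norm
}

dot_scores = {name: dot(image, vec) for name, vec in texts.items()}
cos_scores = {name: cosine(image, vec) for name, vec in texts.items()}

best_by_dot = max(dot_scores, key=dot_scores.get)
best_by_cos = max(cos_scores, key=cos_scores.get)

print("dot product:", dot_scores, "->", best_by_dot)
print("cosine     :", cos_scores, "->", best_by_cos)
```

In this toy case the dot product prefers the longer vector and cosine similarity prefers the better-aligned one. Whether a similar effect shows up with real CLIP features is what this step asks you to measure.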

In [None]:
best_prompt, accuracy_dot, accuracy_cos = None, None, None
"""
Instructions:
1. Choose your best prompt from Step 7.
2. Generate and encode the corresponding text labels.
3. Predict using 'dot_product' and calculate accuracy.
4. Predict using 'cosine_similarity' and calculate accuracy.
5. Print the results clearly.
"""
# ================================
# TODO: YOUR CODE GOES HERE [1 Mark]
# ================================

print(f"Using prompt: '{best_prompt}'\n")
print(f"Accuracy with dot_product: {accuracy_dot:.4f}")
print(f"Accuracy with cosine_similarity: {accuracy_cos:.4f}")

### **Step 10: Analysis for Similarity Metric Comparison** [1 Mark]
**(Your analysis goes here)**

--- 

## **Question 3: Linear Probing** [8 Marks]

Linear probing is a technique for training a simple classification head on top of a frozen backbone. During this stage, the encoder (CLIP in our case) is frozen and not updated; only the downstream classification head is trained.
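
The sketch below illustrates what "frozen" means in PyTorch terms. The backbone is a hypothetical stand-in (a single `nn.Linear`, not CLIP), and the dimensions, class count, and random data are arbitrary toy values. It only checks where gradients flow after one backward pass and is not a training loop.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for a pre-trained encoder (NOT CLIP): 8-dim input -> 4-dim features.
backbone = nn.Linear(8, 4)
for param in backbone.parameters():
    param.requires_grad = False  # freeze: exclude these weights from gradient computation
backbone.eval()

# Trainable classification head on top of the frozen features (3 toy classes).
head = nn.Linear(4, 3)

inputs = torch.randn(16, 8)            # toy batch
targets = torch.randint(0, 3, (16,))   # toy labels

# Features come from the frozen backbone; no graph is recorded for this part.
with torch.no_grad():
    features = backbone(inputs)

loss = nn.CrossEntropyLoss()(head(features), targets)
loss.backward()

trainable = sum(p.numel() for p in head.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in backbone.parameters() if not p.requires_grad)

print(f"trainable head parameters: {trainable}, frozen backbone parameters: {frozen}")
print("backbone received gradients:", backbone.weight.grad is not None)
print("head received gradients:", head.weight.grad is not None)
```

In this assignment the same effect is achieved by pre-computing and caching the CLIP features in Step 5, so the encoder never appears in the training graph at all.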

### **Step 11: Implementing the Classification Network** [2 Marks]

**Your Task:** Implement the `ClsNetwork`.

In [None]:
class ResidualFCNBlock(nn.Module): # Provided
    def __init__(self, input_dim: int, hidden_dim: int, p: float = 0.2):
        super().__init__()
        self.norm = nn.LayerNorm(input_dim)
        self.ff = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(p),
            nn.Linear(hidden_dim, input_dim),
            nn.Dropout(p),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.ff(self.norm(x))

class ClsNetwork(nn.Module):
    def __init__(self, input_dim: int, num_classes: int):
        super().__init__()
        hidden_dim = 4 * input_dim
        self.res_blk = ResidualFCNBlock(input_dim, hidden_dim, p=0.2)
        # ================================
        # TODO: YOUR CODE GOES HERE (1 Mark)
        # ================================
        raise NotImplementedError("Please define the layers in ClsNetwork `__init__`")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # ================================
        # TODO: YOUR CODE GOES HERE (1 Mark)
        # ================================
        raise NotImplementedError("Please implement the `forward` method of ClsNetwork.")

classifier_model = ClsNetwork(input_dim=512, num_classes=len(class_labels)).to(device)

### **Step 12: Implementing the Training Step** [2 Marks]

**Your Task:** Implement the core training step inside the `Trainer.train` method.

In [None]:
class Trainer:
    def __init__(self, model, train_loader, test_loader, device):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.device = device
        self.writer = SummaryWriter('runs/linear_probe_experiment')

    def train(self, epochs: int, lr: float):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()

        for epoch in range(epochs):
            self.model.train()
            total_loss = 0.0
            pbar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{epochs}", leave=False)
            for features, labels in pbar:
                features, labels = features.to(self.device), labels.to(self.device)
                # ================================
                # TODO: YOUR CODE GOES HERE (1 Mark): The core training step
                # ================================
                raise NotImplementedError("Please implement the core training step.")
                
                total_loss += loss.item()
                pbar.set_postfix(loss=f"{total_loss / (pbar.n + 1):.4f}")

            self.writer.add_scalar('Loss/train', total_loss / len(self.train_loader), epoch)
            self.evaluate(epoch)
        self.writer.close()
        print("\nTraining finished!")

    def evaluate(self, epoch: int):
        self.model.eval()
        all_preds, all_labels = [], []
        with torch.no_grad():
            for features, labels in self.test_loader:
                features, labels = features.to(self.device), labels.to(self.device)
                outputs = self.model(features)
                preds = outputs.argmax(dim=1)
                all_preds.extend(preds.cpu().tolist())
                all_labels.extend(labels.cpu().tolist())
        metrics = Metrics()
        acc = metrics.accuracy_score(all_preds, all_labels)
        f1  = metrics.f1_score(all_preds, all_labels)
        print(f"Epoch {epoch+1} Test Metrics: accuracy={acc:.4f}, f1={f1:.4f}")
        self.writer.add_scalar('Accuracy/test', acc, epoch)
        self.writer.add_scalar('F1_Score/test', f1, epoch)

class ClsDataset(TorchDataset):
    def __init__(self, features, labels):
        self.features = features
        self.labels = labels
    def __len__(self) -> int:
        return len(self.features)
    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.features[idx], self.labels[idx]

train_dataset = ClsDataset(train_images, train_labels)
test_dataset = ClsDataset(test_images, test_labels)

train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

# ================================
# TODO: YOUR CODE (1 Mark): Instantiate and run the trainer
# ================================
raise NotImplementedError("Please instantiate and call the Trainer.")

### **Step 13: Performance Analysis and Comparison** [4 Marks]

**Your Task:** After completing the code above, answer the following analysis questions. *Write your answers in the Markdown cell below AND in the PDF report.*

1.  Report the **final test accuracy and F1-score** of your trained linear classifier. (1 Mark)
2.  Create a summary table comparing the performance (Accuracy) of Zero-Shot Classification and the Trained Linear Classifier. (1 Mark)
3.  Analyze the results. Why does the trained classifier significantly outperform the zero-shot approach? (1 Mark)
4.  Discuss the **trade-offs** between these two methods (Performance, Data Requirement, Flexibility, Computational Cost). (1 Mark)

**(Your analysis goes here)**

## **Question 4: Paper Reading** [10 Marks]

### **Step 14: Questions for Paper Reading** 

Read the paper titled "Learning Transferable Visual Models From Natural Language Supervision" (https://arxiv.org/abs/2103.00020) and answer the following questions. The answers must be included in your report.

1. What does the CLIP model learn? [1 mark]
2. Explain in at most 3 sentences what "contrastive learning" means. [1 mark]
3. Why do you think CLIP's zero-shot performance can sometimes surpass supervised baselines? What does this say about the generalization abilities of representation learning? [2 marks]
4. How do the labels in CLIP-based zero-shot classification differ from the labels used by traditional supervised image classification models? [2 marks]
5. Where do you think CLIP falls short, and why do you think this happens? [4 marks]

