# LSDL CUB, Homework 1. Robust fine-tuning of CLIP [10 pts]

Your main goal in this home assignment is to implement the [Lipsum-FT](https://openreview.net/attachment?id=2JF8mJRJ7M&name=pdf) method for robust fine-tuning of CLIP.

Rules for the assignment:

- We will be using the same dataset as in the original paper, [DomainNet](https://ai.bu.edu/M3SDA/) (also available [here](https://huggingface.co/datasets/wltjr1007/DomainNet)). Use the **Real** domain as in-distribution (ID) data and the remaining domains as out-of-distribution (OOD) data. Use the **Real** train split for training and the test splits of all domains for evaluation.

- If training takes too much time, you may use a subset (e.g., 50% or 33%) of the training data instead of the full split.

- The `ViT-B/16` backbone is recommended.

- To **pass the assignment**, you need to plot a Pareto front (i.e., an ID-OOD plot like Figure 1 from this [paper](https://arxiv.org/pdf/2109.01903)). We will be using 5 OOD domains, so you need to plot 5 Pareto fronts, one for each distribution shift.

- You may use any code from the [seminar](https://github.com/isadrtdinov/lsdl-cub/tree/2025/week01-finetune/seminar). You may write your training pipelines either in pure PyTorch or in combination with [huggingface](https://huggingface.co/) libraries. You may also find the [`clip`](https://github.com/openai/CLIP) and [`wise-ft`](https://github.com/mlfoundations/wise-ft) repos useful.

- Do not use the implementation from the authors or any other publicly available implementation of this method.

- It will be much easier to check your assignment if you keep a clear code structure (e.g., put different blocks of code into separate files, add the necessary comments, etc.).

## 1. Zero-shot model [1 pts]

Create a zero-shot model on top of pre-trained CLIP and evaluate it on the **Real** test set (ID accuracy) and on 5 distribution shifts: **Clipart**, **Infograph**, **Painting**, **Quickdraw**, **Sketch** (OOD accuracy).
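For reference, zero-shot classification with CLIP compares an image $x$ with one text prompt $t_c$ per class $c$ (below, `"a photo of a {class_name}"`), using the image encoder $f$ and the text encoder $g$, and predicts the class with the highest cosine similarity:

$$
\hat{y}(x) = \arg\max_{c \in \{1, \dots, C\}} \frac{\langle f(x), g(t_c) \rangle}{\lVert f(x) \rVert \, \lVert g(t_c) \rVert},
$$

where $C = 345$ is the number of DomainNet classes. Multiplying all similarities by CLIP's learned temperature does not change the prediction.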

In [None]:
# Install CLIP if not already installed
import subprocess
import sys

try:
    import clip
except ImportError:
    print("Installing CLIP...")
    subprocess.check_call([sys.executable, "-m", "pip", "install", "git+https://github.com/openai/CLIP.git"])
    import clip

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from datasets import load_dataset
import numpy as np
from tqdm import tqdm

# Load pre-trained CLIP model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/16", device=device)

# Load DomainNet dataset
dataset = load_dataset("wltjr1007/DomainNet")

# Define domains
id_domain = "real"
ood_domains = ["clipart", "infograph", "painting", "quickdraw", "sketch"]

# Get class names from the dataset
class_names = dataset["real_test"].features["label"].names

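# Create text prompts for zero-shot classification
# (DomainNet class names may contain underscores, e.g. "aircraft_carrier", so replace them with spaces)
text_prompts = [f"a photo of a {class_name.replace('_', ' ')}" for class_name in class_names]
text_inputs = clip.tokenize(text_prompts).to(device)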

# Get text features
with torch.no_grad():
    text_features = model.encode_text(text_inputs)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

def evaluate_zero_shot(dataset_split, domain_name):
    """Evaluate zero-shot CLIP on a dataset split"""
    correct = 0
    total = 0
    
    # collate_fn=list keeps each batch as a list of examples (the default collate_fn cannot batch PIL images)
    dataloader = DataLoader(dataset_split, batch_size=32, shuffle=False, collate_fn=list)
    
    with torch.no_grad():
        for batch in tqdm(dataloader, desc=f"Evaluating {domain_name}"):
            images = torch.stack([preprocess(ex["image"].convert("RGB")) for ex in batch]).to(device)
            labels = torch.tensor([ex["label"] for ex in batch], device=device)
            
            # Get image features
            image_features = model.encode_image(images)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            
            # Calculate similarities and predictions
            similarities = (image_features @ text_features.T)
            predictions = similarities.argmax(dim=-1)
            
            correct += (predictions == labels).sum().item()
            total += labels.size(0)
    
    accuracy = correct / total
    return accuracy

# Evaluate on ID (Real) test set
id_accuracy = evaluate_zero_shot(dataset["real_test"], "Real")
print(f"ID Accuracy (Real): {id_accuracy:.4f}")

# Evaluate on OOD domains
ood_accuracies = {}
for domain in ood_domains:
    ood_accuracy = evaluate_zero_shot(dataset[f"{domain}_test"], domain.capitalize())
    ood_accuracies[domain] = ood_accuracy
    print(f"OOD Accuracy ({domain.capitalize()}): {ood_accuracy:.4f}")

print("\nZero-shot evaluation completed!")

## 2. Regular fine-tuning [2 pts]

Now, fine-tune the whole image encoder on the **Real** train split. Use the zero-shot classification head as an initialization for the last linear layer. Calculate ID and OOD accuracy of this model.
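Initializing the last layer with the zero-shot head means that its weight matrix is the matrix of normalized class-prompt embeddings. The sketch below follows one common convention. It assumes that the head receives L2-normalized image features and that its weights are scaled by CLIP's temperature (`model.logit_scale.exp()` in the `clip` package). Under these assumptions, the head's logits coincide with the zero-shot logits at initialization, so fine-tuning starts exactly from the zero-shot model. Random tensors stand in for CLIP embeddings so that the sketch runs without downloading weights.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def zero_shot_head(text_features: torch.Tensor, logit_scale: float) -> nn.Linear:
    """Build a linear head whose logits are scaled cosine similarities.

    text_features: (num_classes, dim) text embeddings, one per class prompt.
    logit_scale: temperature, e.g. model.logit_scale.exp().item() for a CLIP model.
    The head expects L2-normalized image features as input (an assumption).
    """
    num_classes, dim = text_features.shape
    head = nn.Linear(dim, num_classes)
    with torch.no_grad():
        # nn.Linear stores its weight as (out_features, in_features) = (num_classes, dim),
        # which is already the layout of text_features, so no transpose is needed
        head.weight.copy_(logit_scale * F.normalize(text_features.float(), dim=-1))
        head.bias.zero_()
    return head


# Random stand-ins for CLIP embeddings (345 DomainNet classes, 512-dim ViT-B/16 features)
torch.manual_seed(0)
text_feats = torch.randn(345, 512)
image_feats = torch.randn(8, 512)
head = zero_shot_head(text_feats, logit_scale=100.0)
logits = head(F.normalize(image_feats, dim=-1))
cosine = F.normalize(image_feats, dim=-1) @ F.normalize(text_feats, dim=-1).T
```

Because the bias starts at zero and the features are normalized, `logits` equals `100 * cosine`, which means the fine-tuned model's predictions at step 0 match the zero-shot predictions.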

In [None]:
from torch.utils.data import DataLoader
import copy

import torch.nn as nn
import torch.optim as optim

# Fine-tunable model: a copy of CLIP's image encoder followed by a linear classification head
class FineTunedCLIP(nn.Module):
    def __init__(self, clip_model, text_features):
        super().__init__()
        # Deep-copy the encoder so fine-tuning does not overwrite the zero-shot model (needed for WiSE-FT);
        # cast to fp32 because clip.load returns fp16 weights on CUDA
        self.visual = copy.deepcopy(clip_model.visual).float()
        num_classes, visual_dim = text_features.shape
        
        # Classification head; its weight has shape (num_classes, visual_dim), same as text_features
        self.classifier = nn.Linear(visual_dim, num_classes)
        
        # Initialize with the zero-shot head: normalized text features scaled by CLIP's logit scale,
        # so that at initialization the logits equal the zero-shot logits (no transpose needed)
        with torch.no_grad():
            self.classifier.weight.copy_(clip_model.logit_scale.exp().float() * text_features.float())
            self.classifier.bias.zero_()
    
    def forward(self, x):
        features = self.visual(x)
        # Normalize image features, as in zero-shot classification
        features = features / features.norm(dim=-1, keepdim=True)
        return self.classifier(features)

# Create fine-tuned model
ft_model = FineTunedCLIP(model, text_features).to(device)

# Training setup
optimizer = optim.AdamW(ft_model.parameters(), lr=1e-5, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()
num_epochs = 5

# Prepare training data (collate_fn=list keeps each batch as a list of examples,
# since the default collate_fn cannot batch PIL images)
train_dataset = dataset["real_train"]
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=list)

print("Starting regular fine-tuning...")

# Training loop
ft_model.train()
for epoch in range(num_epochs):
    total_loss = 0
    correct = 0
    total = 0
    
    for batch in tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        images = torch.stack([preprocess(ex["image"].convert("RGB")) for ex in batch]).to(device)
        labels = torch.tensor([ex["label"] for ex in batch], device=device)
        
        optimizer.zero_grad()
        
        outputs = ft_model(images)
        loss = criterion(outputs, labels)
        
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
    
    train_acc = 100. * correct / total
    avg_loss = total_loss / len(train_dataloader)
    print(f"Epoch {epoch+1}: Loss: {avg_loss:.4f}, Train Acc: {train_acc:.2f}%")

def evaluate_finetuned(model, dataset_split, domain_name):
    """Evaluate fine-tuned model on a dataset split"""
    model.eval()
    correct = 0
    total = 0
    
    # collate_fn=list keeps each batch as a list of examples (the default collate_fn cannot batch PIL images)
    dataloader = DataLoader(dataset_split, batch_size=32, shuffle=False, collate_fn=list)
    
    with torch.no_grad():
        for batch in tqdm(dataloader, desc=f"Evaluating {domain_name}"):
            images = torch.stack([preprocess(ex["image"].convert("RGB")) for ex in batch]).to(device)
            labels = torch.tensor([ex["label"] for ex in batch], device=device)
            
            outputs = model(images)
            _, predicted = outputs.max(1)
            
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)
    
    accuracy = correct / total
    return accuracy

# Evaluate fine-tuned model
print("\nEvaluating fine-tuned model...")

# ID accuracy
ft_id_accuracy = evaluate_finetuned(ft_model, dataset["real_test"], "Real")
print(f"Fine-tuned ID Accuracy (Real): {ft_id_accuracy:.4f}")

# OOD accuracies
ft_ood_accuracies = {}
for domain in ood_domains:
    ft_ood_accuracy = evaluate_finetuned(ft_model, dataset[f"{domain}_test"], domain.capitalize())
    ft_ood_accuracies[domain] = ft_ood_accuracy
    print(f"Fine-tuned OOD Accuracy ({domain.capitalize()}): {ft_ood_accuracy:.4f}")

print("\nRegular fine-tuning completed!")

# Store results for later comparison
results = {
    "zero_shot": {"id": id_accuracy, "ood": ood_accuracies},
    "fine_tuned": {"id": ft_id_accuracy, "ood": ft_ood_accuracies}
}

## 3. Lipsum-FT [4 pts]

Implement the Lipsum-FT method from the [paper](https://openreview.net/attachment?id=2JF8mJRJ7M&name=pdf). You may use the hyperparameters from the paper. Calculate ID and OOD accuracy of this model.
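As we read the paper, the core of Lipsum-FT is an extra loss term. This term keeps the fine-tuned image encoder's similarities to *random* texts, encoded by the frozen text encoder, close to the similarities produced by the pre-trained image encoder. The sketch below rests on several assumptions. The token layout and ids follow CLIP's tokenizer (`context_length=77`, SOT id 49406, EOT id 49407). Features are L2-normalized and scaled before the inner product. The number and length of the random texts and the weight of the term are left open and should be taken from the paper. `F.mse_loss` averages over the batch and the texts, so it differs from a sum-based formulation only by a constant factor that can be absorbed into the weight.

```python
import torch
import torch.nn.functional as F


def random_text_tokens(num_texts, text_len, context_length=77, sot_token=49406,
                       eot_token=49407, generator=None):
    """Sample random 'lipsum' token sequences laid out like clip.tokenize output.

    Each row is [SOT, text_len random ids, EOT, 0, 0, ...]. Random ids are drawn
    below sot_token, so EOT stays the largest id in the row; CLIP's encode_text
    pools the token at text.argmax(dim=-1), i.e. at EOT.
    """
    tokens = torch.zeros(num_texts, context_length, dtype=torch.long)
    tokens[:, 0] = sot_token
    tokens[:, 1:text_len + 1] = torch.randint(
        0, sot_token, (num_texts, text_len), generator=generator
    )
    tokens[:, text_len + 1] = eot_token
    return tokens


def lipsum_regularizer(img_ft, img_zs, text_feats, scale=100.0):
    """MSE between the image-text similarity vectors of two image encoders.

    img_ft:     (B, D) features from the image encoder being fine-tuned.
    img_zs:     (B, D) features from the frozen pre-trained image encoder.
    text_feats: (M, D) frozen text-encoder features of the M random texts.
    """
    t = F.normalize(text_feats, dim=-1)
    v_ft = scale * F.normalize(img_ft, dim=-1) @ t.T  # (B, M), carries gradients
    v_zs = scale * F.normalize(img_zs, dim=-1) @ t.T  # (B, M), used as a target only
    return F.mse_loss(v_ft, v_zs.detach())


# Inside a training step (hypothetical names: zs_visual is a frozen fp32 copy of the
# pre-trained image encoder, ft_visual and head form the model being fine-tuned,
# lam is the regularization weight, M and L are the number and length of random texts):
#   tokens = random_text_tokens(M, L).to(device)
#   with torch.no_grad():
#       text_feats = model.encode_text(tokens).float()
#       img_zs = zs_visual(images)
#   img_ft = ft_visual(images)
#   loss = criterion(head(img_ft), labels) + lam * lipsum_regularizer(img_ft, img_zs, text_feats)
```

Fresh random texts are sampled at every step in this sketch, so the regularizer probes a different set of directions in the joint embedding space each time instead of a fixed prompt set.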

In [None]:
# YOUR SOLUTION HERE (⊃｡•́‿•̀｡)⊃━✿✿✿✿✿✿

## 4. WiSE-FT [1.5 pts]

Create a [weight-space ensemble](https://arxiv.org/pdf/2109.01903) for the regular fine-tuned model. Calculate the ID and OOD accuracy of this model.
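WiSE-FT linearly interpolates the zero-shot and fine-tuned weights, $\theta_\alpha = (1 - \alpha)\,\theta_{\text{zs}} + \alpha\,\theta_{\text{ft}}$, for a mixing coefficient $\alpha \in [0, 1]$. This requires both models to have the same architecture, including the head. A reasonable choice (our assumption) for the $\alpha = 0$ endpoint is an untouched instance of the fine-tuning wrapper, whose head is the zero-shot head. The fine-tuned model must also own its own copy of the image encoder; otherwise both endpoints share the same weights. The function name `wise_ft` below is hypothetical.

```python
import torch


def wise_ft(zs_state, ft_state, alpha):
    """Weight-space ensemble of two state dicts with identical keys and shapes.

    Floating-point entries become (1 - alpha) * zero-shot + alpha * fine-tuned;
    alpha = 0 recovers the zero-shot weights and alpha = 1 the fine-tuned ones.
    Non-floating entries (e.g. integer buffers) are copied from ft_state.
    """
    assert zs_state.keys() == ft_state.keys(), "architectures must match"
    merged = {}
    for name, w_ft in ft_state.items():
        if torch.is_floating_point(w_ft):
            merged[name] = (1 - alpha) * zs_state[name].float() + alpha * w_ft.float()
        else:
            merged[name] = w_ft.clone()
    return merged


# Usage on a tiny stand-in model; for CLIP, pass the state dicts of the two
# instances of the fine-tuning wrapper instead.
torch.manual_seed(0)
zs, ft, mixed = torch.nn.Linear(4, 2), torch.nn.Linear(4, 2), torch.nn.Linear(4, 2)
mixed.load_state_dict(wise_ft(zs.state_dict(), ft.state_dict(), alpha=0.5))
```

The merged dict can be loaded with `load_state_dict` into a third instance, which is then evaluated with the same code as any other model.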

In [None]:
# YOUR SOLUTION HERE (⊃｡•́‿•̀｡)⊃━✿✿✿✿✿✿

## 5. WiSE for Lipsum-FT [1.5 pts]

Create a [weight-space ensemble](https://arxiv.org/pdf/2109.01903) for Lipsum-FT model. Calculate ID and OOD accuracy of this model.
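Lipsum-FT only changes the training objective, so its weights can be interpolated with the zero-shot weights in the same way as for regular fine-tuning. Sweeping $\alpha$ over a grid traces one curve per distribution shift in the ID-OOD plane. The sketch below assumes a hypothetical user-supplied `evaluate(state_dict)` callable. This callable loads the weights into a model and returns accuracies in the same layout as the `results` dict built after regular fine-tuning. The 11-point grid is our choice, not a requirement.

```python
import torch


def wise_sweep(zs_state, ft_state, evaluate, alphas=None):
    """Evaluate weight-space ensembles for several mixing coefficients.

    evaluate(state_dict) -> {"id": acc, "ood": {domain: acc}} (hypothetical interface).
    Returns {alpha: evaluate(interpolated_state)} for each alpha.
    """
    if alphas is None:
        alphas = [i / 10 for i in range(11)]  # 0.0, 0.1, ..., 1.0
    results = {}
    for alpha in alphas:
        # Interpolate floating-point tensors; copy non-floating buffers from ft_state
        state = {
            name: (1 - alpha) * zs_state[name].float() + alpha * w.float()
            if torch.is_floating_point(w) else w.clone()
            for name, w in ft_state.items()
        }
        results[alpha] = evaluate(state)
    return results


# Toy check with a one-parameter "model" whose accuracy is a function of its weight
zs_toy = {"w": torch.tensor([0.0]), "step": torch.tensor(3)}
ft_toy = {"w": torch.tensor([1.0]), "step": torch.tensor(5)}
toy_eval = lambda s: {"id": s["w"].item(), "ood": {"sketch": 1.0 - s["w"].item()}}
res = wise_sweep(zs_toy, ft_toy, toy_eval)
```

Calling the same function with the regular fine-tuned weights gives the WiSE-FT curve, so both methods end up with points on the same $\alpha$ grid.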

In [None]:
# YOUR SOLUTION HERE (⊃｡•́‿•̀｡)⊃━✿✿✿✿✿✿

## Pareto front (ID-OOD plots)

Now that all the methods are trained, put them together on one ID-OOD plot with a Pareto front for each of the 5 distribution shifts.
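A model's (ID, OOD) point lies on the Pareto front if no other model is at least as accurate on both axes and strictly more accurate on one. Below is a minimal sketch for one distribution shift. `pareto_front` is pure Python. `plot_id_ood` draws onto a Matplotlib `Axes` supplied by the caller, using only `plot`, `set_xlabel`, `set_ylabel` and `legend`. Both function names and the `method_points` layout are hypothetical choices.

```python
def pareto_front(points):
    """Return the non-dominated (id_acc, ood_acc) pairs, sorted by ID accuracy."""
    front = {
        p for p in points
        # p is dominated if some other point is >= on both axes (and differs, hence > on one)
        if not any(q[0] >= p[0] and q[1] >= p[1] and q != p for q in points)
    }
    return sorted(front)


def plot_id_ood(ax, method_points, domain):
    """Draw every method's (ID, OOD) points and the overall Pareto front on `ax`.

    method_points: {method_name: [(id_acc, ood_acc), ...]}, points in the order they
    should be connected (e.g. increasing WiSE alpha); a single point is also fine.
    """
    all_points = [p for pts in method_points.values() for p in pts]
    for name, pts in method_points.items():
        xs, ys = zip(*pts)
        ax.plot(xs, ys, marker="o", label=name)
    fx, fy = zip(*pareto_front(all_points))
    ax.plot(fx, fy, "k--", label="Pareto front")
    ax.set_xlabel("ID accuracy (Real)")
    ax.set_ylabel(f"OOD accuracy ({domain})")
    ax.legend()
```

For the five shifts, one option is `fig, axes = plt.subplots(1, 5, figsize=(25, 5))` followed by one `plot_id_ood` call per domain.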

In [None]:
# YOUR SOLUTION HERE (⊃｡•́‿•̀｡)⊃━✿✿✿✿✿✿

## Bonus: Fine-tuning data-efficiency [2 pts]

Create a data-efficiency plot similar to Figure 6 from the [CLIP paper](https://arxiv.org/pdf/2103.00020) for the considered fine-tuning methods (regular vs. Lipsum-FT). On the horizontal axis, plot the number of fine-tuning samples per class (on a logarithmic scale); on the vertical axis, plot ID accuracy. What conclusions can be drawn about the data efficiency of these methods?
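One way to build the training sets for this plot is to keep at most $k$ examples per class for a grid of $k$ values that is evenly spaced on a log scale (e.g. 1, 2, 4, ..., 128; the grid is our assumption). In the sketch below, which uses a hypothetical `k_shot_indices` helper, each class is shuffled with a fixed seed and truncated to its first $k$ examples. The subsets are therefore nested: every example in the 2-shot subset is also in the 4-shot subset. This keeps differences between shot counts from being driven by which particular images happened to be drawn.

```python
import random
from collections import defaultdict


def k_shot_indices(labels, k, seed=0):
    """Return sorted indices of at most k examples per class.

    The per-class shuffle depends only on `seed`, not on `k`, so the subsets for
    different k are nested (the k=2 subset is contained in the k=4 subset).
    """
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    selected = []
    for label in sorted(by_class):  # fixed class order keeps the RNG sequence identical
        idxs = by_class[label]
        rng.shuffle(idxs)
        selected.extend(idxs[:k])
    return sorted(selected)


shots = [2 ** i for i in range(8)]  # 1, 2, 4, ..., 128 samples per class
```

With the Hugging Face dataset, a subset could then be built as `dataset["real_train"].select(k_shot_indices(dataset["real_train"]["label"], k))`, and the accuracy curves drawn with Matplotlib after calling `ax.set_xscale("log")`.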

In [None]:
# YOUR SOLUTION HERE (⊃｡•́‿•̀｡)⊃━✿✿✿✿✿✿