# LSDL CUB, Homework 1. Robust fine-tuning of CLIP [10 pts]

Your main goal in this home assignment is to implement the [Lipsum-FT](https://openreview.net/attachment?id=2JF8mJRJ7M&name=pdf) method for robust fine-tuning of CLIP.

Rules for the assignment:

- We will be using the same dataset, [DomainNet](https://ai.bu.edu/M3SDA/) (also can be found [here](https://huggingface.co/datasets/wltjr1007/DomainNet)), as in the original paper. Use the **Real** domain as in-distribution (ID) data and the rest of domains as out-of-distribution (OOD) data. Use the **Real** train split for training and test splits for evaluation on all of the domains.

- If training takes too much time, you may select a subset (i.e., 50% or 33%) of training data instead of the full split.

- `ViT-B/16` backbone is recommended.

- In order to **pass the assignment**, you need to plot a Pareto front (i.e., ID-OOD plot like Figure 1 from this [paper](https://arxiv.org/pdf/2109.01903)). We will be using 5 OOD domains, so you need to plot 5 Pareto fronts, one for each distribution shift.

- You may use any code from the [seminar](https://github.com/isadrtdinov/lsdl-cub/tree/2025/week01-finetune/seminar). Also, you may code your training pipelines either in pure PyTorch or combinine it with [huggingface](https://huggingface.co/) libraries. Additionally, you may find [`clip`](https://github.com/openai/CLIP) and [`wise-ft`](https://github.com/mlfoundations/wise-ft) repos useful.

- Do not use the implemention from the authors or any publicly available implementations of this method.

- It will be much easier to check your assigment if you maintain a clear code structure (e.g., put different blocks of code into separate files, add necessary coments, etc).

## 1. Zero-shot model [1 pts]

Create a zero-shot model on top of pre-trained CLIP and evaluate it on **Real** test set (ID accuracy) and on 5 distribution shifts: **Clipart**, **Infograph**, **Painting**, **Quickdraw**, **Sketch** (OOD accuracy).

In [1]:
import subprocess
import sys

try:
    import clip
except ImportError:
    print("Installing CLIP...")
    subprocess.check_call([sys.executable, "-m", "pip", "install", "git+https://github.com/openai/CLIP.git"])
    import clip

import torch
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
from datasets import load_dataset
import numpy as np
from tqdm import tqdm

# Load pre-trained CLIP model
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
model, preprocess = clip.load("ViT-B/16", device=device)

# Load DomainNet dataset
print("Loading DomainNet dataset...")
dataset = load_dataset("wltjr1007/DomainNet")

# Check available splits
print("Available dataset splits:", list(dataset.keys()))

# The dataset structure is different - let's check the actual structure
sample = dataset["train"][0]
print("Sample from train:", sample)

# Check domain mapping
domain_mapping = {
    0: "real",
    1: "clipart",
    2: "infograph",
    3: "quickdraw",
    4: "painting",
    5: "sketch"
}

# Get all unique domains in the dataset
unique_domains = set()
for split in ['train', 'test']:
    sample_batch = dataset[split][:100]  # Check first 100 samples
    for domain_id in sample_batch['domain']:
        unique_domains.add(domain_id)

print("Unique domain IDs found:", sorted(unique_domains))

# Map domain IDs to names
available_domain_names = []
for domain_id in sorted(unique_domains):
    if domain_id in domain_mapping:
        available_domain_names.append(domain_mapping[domain_id])
    else:
        available_domain_names.append(f"domain_{domain_id}")

print("Available domains:", available_domain_names)

# Define domains
id_domain = "real"  # domain_id = 0
ood_domains = [d for d in available_domain_names if d != id_domain]

print(f"ID domain: {id_domain}")
print(f"OOD domains: {ood_domains}")

# Get class names
class_names = dataset["train"].features["label"].names
print(f"Number of classes: {len(class_names)}")
print(f"First 10 classes: {class_names[:10]}")

# Create text prompts for zero-shot classification
text_prompts = [f"a photo of a {class_name}" for class_name in class_names]
text_inputs = clip.tokenize(text_prompts).to(device)

# Get text features
print("Encoding text prompts...")
with torch.no_grad():
    text_features = model.encode_text(text_inputs)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

def filter_dataset_by_domain(dataset_split, domain_name, max_samples=None):
    """Filter dataset by domain and optionally limit number of samples"""
    # Get domain ID from name
    domain_id = None
    for did, dname in domain_mapping.items():
        if dname == domain_name:
            domain_id = did
            break

    if domain_id is None:
        print(f"Domain {domain_name} not found in mapping")
        return None

    # Filter by domain
    filtered_indices = []
    for i, domain in enumerate(dataset_split['domain']):
        if domain == domain_id:
            filtered_indices.append(i)
            if max_samples and len(filtered_indices) >= max_samples:
                break

    print(f"Found {len(filtered_indices)} samples for domain {domain_name}")

    if len(filtered_indices) == 0:
        return None

    # Create subset
    return Subset(dataset_split, filtered_indices)

def custom_collate_fn(batch):
    """Custom collate function to handle PIL images"""
    images = []
    labels = []

    for item in batch:
        # Convert PIL image to tensor using preprocess
        image = preprocess(item['image'].convert('RGB'))
        images.append(image)
        labels.append(item['label'])

    return {
        'image': torch.stack(images),
        'label': torch.tensor(labels)
    }

def evaluate_zero_shot(dataset_subset, domain_name):
    """Evaluate zero-shot CLIP on a dataset subset"""
    if dataset_subset is None:
        print(f"No dataset available for {domain_name}")
        return 0.0

    correct = 0
    total = 0

    # Create dataloader with custom collate function
    dataloader = DataLoader(dataset_subset, batch_size=16, shuffle=False, collate_fn=custom_collate_fn)

    with torch.no_grad():
        for batch in tqdm(dataloader, desc=f"Evaluating {domain_name}"):
            try:
                images = batch['image'].to(device)
                labels = batch['label'].to(device)

                # Get image features
                image_features = model.encode_image(images)
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)

                # Calculate similarities and predictions
                similarities = (image_features @ text_features.T)
                predictions = similarities.argmax(dim=-1)

                correct += (predictions == labels).sum().item()
                total += labels.size(0)
            except Exception as e:
                print(f"Error processing batch in {domain_name}: {e}")
                continue

    accuracy = correct / total if total > 0 else 0.0
    return accuracy

# Evaluate on ID (Real) domain
print(f"\nEvaluating on ID domain ({id_domain})...")
real_test_subset = filter_dataset_by_domain(dataset["test"], id_domain, max_samples=5000)  # Limit for speed

if real_test_subset is not None:
    id_accuracy = evaluate_zero_shot(real_test_subset, "Real")
    print(f"ID Accuracy (Real): {id_accuracy:.4f}")
else:
    print("Could not find Real domain test data")
    id_accuracy = 0.0

# Evaluate on OOD domains
ood_accuracies = {}
print(f"\nEvaluating on OOD domains...")
for domain in ood_domains:
    try:
        domain_test_subset = filter_dataset_by_domain(dataset["test"], domain, max_samples=5000)  # Limit for speed
        if domain_test_subset is not None:
            ood_accuracy = evaluate_zero_shot(domain_test_subset, domain.capitalize())
            ood_accuracies[domain] = ood_accuracy
            print(f"OOD Accuracy ({domain.capitalize()}): {ood_accuracy:.4f}")
        else:
            print(f"Could not find test data for domain: {domain}")
            ood_accuracies[domain] = 0.0
    except Exception as e:
        print(f"Error evaluating {domain}: {e}")
        ood_accuracies[domain] = 0.0

print("\n" + "="*50)
print("ZERO-SHOT RESULTS SUMMARY")
print("="*50)
print(f"ID Accuracy (Real): {id_accuracy:.4f}")
print("OOD Accuracies:")
for domain, acc in ood_accuracies.items():
    print(f"  {domain.capitalize()}: {acc:.4f}")

avg_ood_accuracy = np.mean(list(ood_accuracies.values())) if ood_accuracies else 0.0
print(f"Average OOD Accuracy: {avg_ood_accuracy:.4f}")

print("\nZero-shot evaluation completed!")

# Store variables for next cells
real_train_subset = filter_dataset_by_domain(dataset["train"], id_domain, max_samples=10000)  # Limit training data
if real_train_subset is not None:
    print(f"Training subset size: {len(real_train_subset)}")

    # Store additional variables needed for training
    train_subset = real_train_subset
    real_test_subset = real_test_subset
else:
    train_subset = None
    print("Warning: Could not find Real domain training data")

Installing CLIP...
Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-qkucxyf8


  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-qkucxyf8


  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py): started
  Preparing metadata (setup.py): finished with status 'done'
Collecting ftfy (from clip==1.0)
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch->clip==1.0)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch->clip==1.0)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch->clip==1.0)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch->clip==1.0)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torc

100%|███████████████████████████████████████| 335M/335M [00:09<00:00, 38.8MiB/s]


Loading DomainNet dataset...


README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00003.parquet:   0%|          | 0.00/759M [00:00<?, ?B/s]

data/train-00001-of-00003.parquet:   0%|          | 0.00/7.21G [00:00<?, ?B/s]

data/train-00002-of-00003.parquet:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

data/test-00000-of-00001.parquet:   0%|          | 0.00/5.60G [00:00<?, ?B/s]

Generating train split:   0%|          | 0/409832 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/176743 [00:00<?, ? examples/s]

Loading dataset shards:   0%|          | 0/27 [00:00<?, ?it/s]

Available dataset splits: ['train', 'test']
Sample from train: {'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=300x300 at 0x7DE207F2FBD0>, 'label': 122, 'domain': 3, 'image_path': 'quickdraw/flying_saucer/4503624984035328.png'}
Unique domain IDs found: [3]
Available domains: ['quickdraw']
ID domain: real
OOD domains: ['quickdraw']
Number of classes: 345
First 10 classes: ['aircraft_carrier', 'airplane', 'alarm_clock', 'ambulance', 'angel', 'animal_migration', 'ant', 'anvil', 'apple', 'arm']
Encoding text prompts...

Evaluating on ID domain (real)...
Found 5000 samples for domain real


Evaluating Real: 100%|██████████| 313/313 [00:40<00:00,  7.80it/s]


ID Accuracy (Real): 0.7170

Evaluating on OOD domains...
Found 5000 samples for domain quickdraw


Evaluating Quickdraw: 100%|██████████| 313/313 [00:27<00:00, 11.41it/s]


OOD Accuracy (Quickdraw): 0.1254

ZERO-SHOT RESULTS SUMMARY
ID Accuracy (Real): 0.7170
OOD Accuracies:
  Quickdraw: 0.1254
Average OOD Accuracy: 0.1254

Zero-shot evaluation completed!
Found 10000 samples for domain real
Training subset size: 10000


## 2. Regular fine-tuning [2 pts]

Now, fine-tune the whole image encoder on the **Real** train split. Use the zero-shot classification head as an initialization for the last linear layer. Calculate ID and OOD accuracy of this model.

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import random

# Define evaluation function for fine-tuned model
def evaluate_finetuned(ft_model, dataset_subset, domain_name):
    if dataset_subset is None:
        print(f"No dataset available for {domain_name}")
        return 0.0

    correct = 0
    total = 0

    # Use custom collate
    dataloader = DataLoader(dataset_subset, batch_size=16, shuffle=False, collate_fn=custom_collate_fn)

    ft_model.eval()
    with torch.no_grad():
        for batch in tqdm(dataloader, desc=f"Evaluating {domain_name}"):
            try:
                images = batch['image'].to(device, dtype=torch.float32)
                labels = batch['label'].to(device)

                # Forward pass
                outputs = ft_model(images)
                predictions = outputs.argmax(dim=-1)

                correct += (predictions == labels).sum().item()
                total += labels.size(0)
            except Exception as e:
                print(f"Error processing batch in {domain_name}: {e}")
                continue

    accuracy = correct / total if total > 0 else 0.0
    ft_model.train()  # Set back to train mode
    return accuracy

# Fine-Tuning implementation
class FTCLIP(nn.Module):
    def __init__(self, clip_model, num_classes, text_features):
        super().__init__()
        self.visual = clip_model.visual.float()  # Ensure float32, make trainable
        self.num_classes = num_classes

        # Get the dimension of visual features
        with torch.no_grad():
            dummy_input = torch.randn(1, 3, 224, 224, dtype=torch.float32).to(device)
            visual_dim = self.visual(dummy_input).shape[-1]

        # Create classification head
        self.classifier = nn.Linear(visual_dim, num_classes).float()

        # Initialize with zero-shot text features
        with torch.no_grad():
            self.classifier.weight.data = text_features.clone().to(dtype=torch.float32)
            self.classifier.bias.data.zero_()

    def forward(self, x):
        features = self.visual(x)
        return self.classifier(features)

# Check if we have training data
if train_subset is None:
    print("Error: No training data available for fine-tuning.")
    ft_id_accuracy = id_accuracy
    ft_ood_accuracies = ood_accuracies.copy()
    print("Using zero-shot results as fine-tuning results (no training performed)")
else:
    # Ensure the CLIP model and text_features are in float32
    model = model.float()
    text_features = text_features.to(dtype=torch.float32)

    # Debugging: Check dtypes
    print(f"text_features dtype: {text_features.dtype}")

    # Create fine-tuned model
    ft_model = FTCLIP(model, len(class_names), text_features).to(device)

    # Debugging: Check model dtypes
    print(f"Visual conv1 weight dtype: {ft_model.visual.conv1.weight.dtype}")
    print(f"Classifier weight dtype: {ft_model.classifier.weight.dtype}")

    # Training setup
    optimizer = optim.AdamW(ft_model.parameters(), lr=1e-5, weight_decay=0.01)
    criterion = nn.CrossEntropyLoss()
    num_epochs = 3  # Same as before for consistency

    # Use custom collate function for DataLoader
    train_dataloader = DataLoader(train_subset, batch_size=16, shuffle=True, collate_fn=custom_collate_fn)

    print("Starting fine-tuning...")
    print(f"Training on {len(train_subset)} samples for {num_epochs} epochs")

    # Training loop
    ft_model.train()
    for epoch in range(num_epochs):
        total_loss = 0
        correct = 0
        total = 0

        for batch in tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            try:
                images = batch['image'].to(device, dtype=torch.float32)
                labels = batch['label'].to(device)

                optimizer.zero_grad()
                outputs = ft_model(images)

                loss = criterion(outputs, labels)

                loss.backward()
                optimizer.step()

                # Statistics
                total_loss += loss.item()
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

            except Exception as e:
                print(f"Error in training batch: {e}")
                continue

        train_acc = 100. * correct / total if total > 0 else 0
        avg_loss = total_loss / len(train_dataloader) if len(train_dataloader) > 0 else 0

        print(f"Epoch {epoch+1}: Total Loss: {avg_loss:.4f}, Train Acc: {train_acc:.2f}%")

    # Evaluate fine-tuned model
    print("\nEvaluating fine-tuned model...")

    # ID accuracy
    ft_id_accuracy = evaluate_finetuned(ft_model, real_test_subset, "Real")
    print(f"Fine-Tuned ID Accuracy (Real): {ft_id_accuracy:.4f}")

    # OOD accuracies
    ft_ood_accuracies = {}
    for domain in ood_domains:
        try:
            domain_test_split = dataset["test"]
            domain_test_subset = filter_dataset_by_domain(domain_test_split, domain, max_samples=5000)
            if domain_test_subset is not None:
                ft_ood_accuracy = evaluate_finetuned(ft_model, domain_test_subset, domain.capitalize())
                ft_ood_accuracies[domain] = ft_ood_accuracy
                print(f"Fine-Tuned OOD Accuracy ({domain.capitalize()}): {ft_ood_accuracy:.4f}")
            else:
                ft_ood_accuracies[domain] = 0.0
        except Exception as e:
            print(f"Error evaluating {domain}: {e}")
            ft_ood_accuracies[domain] = 0.0

    print("\nFine-tuning completed!")

print("\n" + "="*50)
print("FINE-TUNING RESULTS SUMMARY")
print("="*50)
print(f"Fine-Tuned ID Accuracy: {ft_id_accuracy:.4f}")
print("Fine-Tuned OOD Accuracies:")
for domain, acc in ft_ood_accuracies.items():
    print(f"  {domain.capitalize()}: {acc:.4f}")

text_features dtype: torch.float32
Visual conv1 weight dtype: torch.float32
Classifier weight dtype: torch.float32
Starting fine-tuning...
Training on 10000 samples for 3 epochs


Epoch 1/3: 100%|██████████| 625/625 [07:20<00:00,  1.42it/s]


Epoch 1: Total Loss: 2.0358, Train Acc: 83.77%


Epoch 2/3: 100%|██████████| 625/625 [07:28<00:00,  1.39it/s]


Epoch 2: Total Loss: 0.5819, Train Acc: 92.49%


Epoch 3/3: 100%|██████████| 625/625 [07:27<00:00,  1.40it/s]


Epoch 3: Total Loss: 0.2968, Train Acc: 95.45%

Evaluating fine-tuned model...


Evaluating Real: 100%|██████████| 313/313 [01:28<00:00,  3.54it/s]


Fine-Tuned ID Accuracy (Real): 0.7714
Found 5000 samples for domain quickdraw


Evaluating Quickdraw: 100%|██████████| 313/313 [01:18<00:00,  3.99it/s]

Fine-Tuned OOD Accuracy (Quickdraw): 0.0660

Fine-tuning completed!

FINE-TUNING RESULTS SUMMARY
Fine-Tuned ID Accuracy: 0.7714
Fine-Tuned OOD Accuracies:
  Quickdraw: 0.0660





## 3. Lipsum-FT [4 pts]

Implement the Lipsum-FT method from the [paper](https://openreview.net/attachment?id=2JF8mJRJ7M&name=pdf). You may use the hyperparameters from the paper. Calculate ID and OOD accuracy of this model.

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import copy

# Define evaluation function for fine-tuned model (same as before)
def evaluate_finetuned(ft_model, dataset_subset, domain_name):
    if dataset_subset is None:
        print(f"No dataset available for {domain_name}")
        return 0.0

    correct = 0
    total = 0

    # Use custom collate
    dataloader = DataLoader(dataset_subset, batch_size=16, shuffle=False, collate_fn=custom_collate_fn)

    ft_model.eval()
    with torch.no_grad():
        for batch in tqdm(dataloader, desc=f"Evaluating {domain_name}"):
            try:
                images = batch['image'].to(device, dtype=torch.float32)
                labels = batch['label'].to(device)

                # Forward pass
                outputs = ft_model(images)
                predictions = outputs.argmax(dim=-1)

                correct += (predictions == labels).sum().item()
                total += labels.size(0)
            except Exception as e:
                print(f"Error processing batch in {domain_name}: {e}")
                continue

    accuracy = correct / total if total > 0 else 0.0
    ft_model.train()  # Set back to train mode
    return accuracy

# Lipsum-FT implementation
class LipsumFTCLIP(nn.Module):
    def __init__(self, clip_model, num_classes, text_features):
        super().__init__()
        self.visual = clip_model.visual.float()  # Trainable visual encoder
        self.visual0 = copy.deepcopy(clip_model.visual).float().eval()  # Frozen copy
        self.num_classes = num_classes

        # Get the dimension of visual features
        with torch.no_grad():
            dummy_input = torch.randn(1, 3, 224, 224, dtype=torch.float32).to(device)
            visual_dim = self.get_features(dummy_input).shape[-1]

        # Create classification head
        self.classifier = nn.Linear(visual_dim, num_classes).float()

        # Initialize with zero-shot text features
        with torch.no_grad():
            self.classifier.weight.data = text_features.clone().to(dtype=torch.float32)
            self.classifier.bias.data.zero_()

    def get_features(self, x):
        return self.visual(x)

    def get_features0(self, x):
        return self.visual0(x)

    def forward(self, x):
        features = self.get_features(x)
        return self.classifier(features)

# Check if we have training data for Lipsum-FT
if train_subset is None:
    print("Error: No training data available for Lipsum-FT.")
    lipsum_id_accuracy = id_accuracy
    lipsum_ood_accuracies = ood_accuracies.copy()
    results["lipsum_ft"] = {"id": lipsum_id_accuracy, "ood": lipsum_ood_accuracies}
    print("Using zero-shot results as Lipsum-FT results (no training performed)")
else:
    # Ensure the CLIP model and text_features are in float32
    model = model.float()
    text_features = text_features.to(dtype=torch.float32)

    # Debugging: Check text_features dtype
    print(f"text_features dtype: {text_features.dtype}")

    # Create Lipsum-FT model
    lipsum_model = LipsumFTCLIP(model, len(class_names), text_features).to(device)

    # Debugging: Check model dtypes
    print(f"Visual conv1 weight dtype: {lipsum_model.visual.conv1.weight.dtype}")
    print(f"Classifier weight dtype: {lipsum_model.classifier.weight.dtype}")

    # Training setup
    optimizer_lipsum = optim.AdamW(lipsum_model.parameters(), lr=1e-5, weight_decay=0.01)
    criterion = nn.CrossEntropyLoss()
    num_epochs = 3
    lambda_reg = 1.0  # Assumed value based on similar methods; adjust if needed
    M = 80  # Number of random texts
    L = 8  # Text length
    context_length = 77
    sot_token = 49406
    eot_token = 49407
    vocab_size = 49408  # CLIP vocab size

    train_dataloader = DataLoader(train_subset, batch_size=16, shuffle=True, collate_fn=custom_collate_fn)

    print("Starting Lipsum-FT training...")
    print(f"Training on {len(train_subset)} samples for {num_epochs} epochs")
    print(f"Lipsum hyperparameters: M={M}, L={L}, lambda_reg={lambda_reg}")

    # Training loop
    lipsum_model.train()
    for epoch in range(num_epochs):
        total_loss = 0
        correct = 0
        total = 0

        for batch in tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            try:
                images = batch['image'].to(device, dtype=torch.float32)
                labels = batch['label'].to(device)

                optimizer_lipsum.zero_grad()

                # Forward pass
                outputs = lipsum_model(images)

                # CE loss
                ce_loss = criterion(outputs, labels)

                # Generate random tokens
                random_tokens = torch.zeros((M, context_length), dtype=torch.long, device=device)
                for i in range(M):
                    random_tokens[i, 0] = sot_token
                    random_tokens[i, 1:1+L] = torch.randint(0, vocab_size - 2, (L,), device=device)  # Avoid SOT/EOT
                    random_tokens[i, 1+L] = eot_token

                # Get random text features (without normalization)
                with torch.no_grad():
                    random_text_feats = model.encode_text(random_tokens)

                # Get features (without normalization)
                features = lipsum_model.get_features(images)
                features0 = lipsum_model.get_features0(images)

                # Compute v and v0
                v = features @ random_text_feats.T
                v0 = features0 @ random_text_feats.T

                # Regularization term
                reg = ((v - v0) ** 2).sum(dim=1) / (2 * M)
                reg = reg.mean()  # Average over batch

                # Total loss
                loss = ce_loss + lambda_reg * reg

                loss.backward()
                optimizer_lipsum.step()

                # Statistics
                total_loss += loss.item()
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

            except Exception as e:
                print(f"Error in Lipsum training batch: {e}")
                continue

        train_acc = 100. * correct / total if total > 0 else 0
        avg_loss = total_loss / len(train_dataloader) if len(train_dataloader) > 0 else 0

        print(f"Epoch {epoch+1}: Total Loss: {avg_loss:.4f}, Train Acc: {train_acc:.2f}%")

    # Evaluate Lipsum-FT model
    print("\nEvaluating Lipsum-FT model...")

    # ID accuracy
    lipsum_id_accuracy = evaluate_finetuned(lipsum_model, real_test_subset, "Real")
    print(f"Lipsum-FT ID Accuracy (Real): {lipsum_id_accuracy:.4f}")

    # OOD accuracies
    lipsum_ood_accuracies = {}
    for domain in ood_domains:
        try:
            domain_test_split = dataset["test"]
            if domain_test_split is not None:
                domain_test_subset = filter_dataset_by_domain(domain_test_split, domain, max_samples=5000)
                lipsum_ood_accuracy = evaluate_finetuned(lipsum_model, domain_test_subset, domain.capitalize())
                lipsum_ood_accuracies[domain] = lipsum_ood_accuracy
                print(f"Lipsum-FT OOD Accuracy ({domain.capitalize()}): {lipsum_ood_accuracy:.4f}")
            else:
                lipsum_ood_accuracies[domain] = 0.0
        except Exception as e:
            print(f"Error evaluating {domain}: {e}")
            lipsum_ood_accuracies[domain] = 0.0

    print("\nLipsum-FT training completed!")

print("\n" + "="*50)
print("LIPSUM-FT RESULTS SUMMARY")
print("="*50)
print(f"Lipsum-FT ID Accuracy: {lipsum_id_accuracy:.4f}")
print("Lipsum-FT OOD Accuracies:")
for domain, acc in lipsum_ood_accuracies.items():
    print(f"  {domain.capitalize()}: {acc:.4f}")

text_features dtype: torch.float32
Visual conv1 weight dtype: torch.float32
Classifier weight dtype: torch.float32
Starting Lipsum-FT training...
Training on 10000 samples for 3 epochs
Lipsum hyperparameters: M=80, L=8, lambda_reg=1.0


Epoch 1/3: 100%|██████████| 625/625 [23:39<00:00,  2.27s/it]


Epoch 1: Total Loss: 0.8013, Train Acc: 97.07%


Epoch 2/3: 100%|██████████| 625/625 [22:41<00:00,  2.18s/it]


Epoch 2: Total Loss: 0.7166, Train Acc: 96.73%


Epoch 3/3: 100%|██████████| 625/625 [23:12<00:00,  2.23s/it]


Epoch 3: Total Loss: 0.5655, Train Acc: 96.96%

Evaluating Lipsum-FT model...


Evaluating Real: 100%|██████████| 313/313 [01:33<00:00,  3.36it/s]


Lipsum-FT ID Accuracy (Real): 0.7846
Found 5000 samples for domain quickdraw


Evaluating Quickdraw: 100%|██████████| 313/313 [01:22<00:00,  3.79it/s]

Lipsum-FT OOD Accuracy (Quickdraw): 0.0588

Lipsum-FT training completed!

LIPSUM-FT RESULTS SUMMARY
Lipsum-FT ID Accuracy: 0.7846
Lipsum-FT OOD Accuracies:
  Quickdraw: 0.0588





## 4. WiSE-FT [1.5 pts]

Create a [weight-space ensemble](https://arxiv.org/pdf/2109.01903) for an ordinary fine-tuned model. Calculate ID and OOD accuracy of this model.

In [4]:
import torch
import torch.nn as nn
import copy

# Define evaluation function for fine-tuned model (same as before)
def evaluate_finetuned(ft_model, dataset_subset, domain_name):
    if dataset_subset is None:
        print(f"No dataset available for {domain_name}")
        return 0.0

    correct = 0
    total = 0

    # Use custom collate
    dataloader = DataLoader(dataset_subset, batch_size=16, shuffle=False, collate_fn=custom_collate_fn)

    ft_model.eval()
    with torch.no_grad():
        for batch in tqdm(dataloader, desc=f"Evaluating {domain_name}"):
            try:
                images = batch['image'].to(device, dtype=torch.float32)
                labels = batch['label'].to(device)

                # Forward pass
                outputs = ft_model(images)
                predictions = outputs.argmax(dim=-1)

                correct += (predictions == labels).sum().item()
                total += labels.size(0)
            except Exception as e:
                print(f"Error processing batch in {domain_name}: {e}")
                continue

    accuracy = correct / total if total > 0 else 0.0
    return accuracy  # No need to set train mode as we don't train here

# Assume ft_model is the fine-tuned model from the fine-tuning section
# Create zero-shot model using the same architecture
zs_model = FTCLIP(model, len(class_names), text_features.to(dtype=torch.float32)).to(device)

# Create WiSE-FT model by weight interpolation
alpha = 0.5  # As per general recommendation in the paper

wise_model = FTCLIP(model, len(class_names), text_features.to(dtype=torch.float32)).to(device)

# Interpolate weights
state_dict_ft = ft_model.state_dict()
state_dict_zs = zs_model.state_dict()
state_dict_wise = {}

for key in state_dict_ft.keys():
    state_dict_wise[key] = alpha * state_dict_ft[key] + (1 - alpha) * state_dict_zs[key]

wise_model.load_state_dict(state_dict_wise)

print(f"Created WiSE-FT model with alpha={alpha}")

# Evaluate WiSE-FT model
print("\nEvaluating WiSE-FT model...")

# ID accuracy
wise_id_accuracy = evaluate_finetuned(wise_model, real_test_subset, "Real")
print(f"WiSE-FT ID Accuracy (Real): {wise_id_accuracy:.4f}")

# OOD accuracies
wise_ood_accuracies = {}
for domain in ood_domains:
    try:
        domain_test_split = dataset["test"]
        if domain_test_split is not None:
            domain_test_subset = filter_dataset_by_domain(domain_test_split, domain, max_samples=5000)
            wise_ood_accuracy = evaluate_finetuned(wise_model, domain_test_subset, domain.capitalize())
            wise_ood_accuracies[domain] = wise_ood_accuracy
            print(f"WiSE-FT OOD Accuracy ({domain.capitalize()}): {wise_ood_accuracy:.4f}")
        else:
            wise_ood_accuracies[domain] = 0.0
    except Exception as e:
        print(f"Error evaluating {domain}: {e}")
        wise_ood_accuracies[domain] = 0.0

print("\nWiSE-FT evaluation completed!")

print("\n" + "="*50)
print("WiSE-FT RESULTS SUMMARY")
print("="*50)
print(f"WiSE-FT ID Accuracy: {wise_id_accuracy:.4f}")
print("WiSE-FT OOD Accuracies:")
for domain, acc in wise_ood_accuracies.items():
    print(f"  {domain.capitalize()}: {acc:.4f}")

Created WiSE-FT model with alpha=0.5

Evaluating WiSE-FT model...


Evaluating Real: 100%|██████████| 313/313 [01:33<00:00,  3.35it/s]


WiSE-FT ID Accuracy (Real): 0.7836
Found 5000 samples for domain quickdraw


Evaluating Quickdraw: 100%|██████████| 313/313 [01:22<00:00,  3.78it/s]

WiSE-FT OOD Accuracy (Quickdraw): 0.0590

WiSE-FT evaluation completed!

WiSE-FT RESULTS SUMMARY
WiSE-FT ID Accuracy: 0.7836
WiSE-FT OOD Accuracies:
  Quickdraw: 0.0590





## 5. WiSE for Lipsum-FT [1.5 pts]

Create a [weight-space ensemble](https://arxiv.org/pdf/2109.01903) for Lipsum-FT model. Calculate ID and OOD accuracy of this model.

In [5]:
import torch
import torch.nn as nn
import copy

# Define evaluation function for fine-tuned model (same as before)
def evaluate_finetuned(ft_model, dataset_subset, domain_name):
    if dataset_subset is None:
        print(f"No dataset available for {domain_name}")
        return 0.0

    correct = 0
    total = 0

    # Use custom collate
    dataloader = DataLoader(dataset_subset, batch_size=16, shuffle=False, collate_fn=custom_collate_fn)

    ft_model.eval()
    with torch.no_grad():
        for batch in tqdm(dataloader, desc=f"Evaluating {domain_name}"):
            try:
                images = batch['image'].to(device, dtype=torch.float32)
                labels = batch['label'].to(device)

                # Forward pass
                outputs = ft_model(images)
                predictions = outputs.argmax(dim=-1)

                correct += (predictions == labels).sum().item()
                total += labels.size(0)
            except Exception as e:
                print(f"Error processing batch in {domain_name}: {e}")
                continue

    accuracy = correct / total if total > 0 else 0.0
    return accuracy  # No need to set train mode as we don't train here

# Assume lipsum_model is the trained Lipsum-FT model from the previous section
# Create zero-shot model using the same architecture as Lipsum-FT
zs_lipsum = LipsumFTCLIP(model, len(class_names), text_features.to(dtype=torch.float32)).to(device)

# Create WiSE-Lipsum model by weight interpolation
alpha = 0.5  # As per general recommendation in the paper

wise_lipsum_model = LipsumFTCLIP(model, len(class_names), text_features.to(dtype=torch.float32)).to(device)

# Interpolate weights
state_dict_lipsum = lipsum_model.state_dict()
state_dict_zs = zs_lipsum.state_dict()
state_dict_wise = {}

for key in state_dict_lipsum.keys():
    state_dict_wise[key] = alpha * state_dict_lipsum[key] + (1 - alpha) * state_dict_zs[key]

wise_lipsum_model.load_state_dict(state_dict_wise)

print(f"Created WiSE for Lipsum-FT model with alpha={alpha}")

# Evaluate WiSE-Lipsum model
print("\nEvaluating WiSE for Lipsum-FT model...")

# ID accuracy
wise_lipsum_id_accuracy = evaluate_finetuned(wise_lipsum_model, real_test_subset, "Real")
print(f"WiSE-Lipsum ID Accuracy (Real): {wise_lipsum_id_accuracy:.4f}")

# OOD accuracies
wise_lipsum_ood_accuracies = {}
for domain in ood_domains:
    try:
        domain_test_split = dataset["test"]
        if domain_test_split is not None:
            domain_test_subset = filter_dataset_by_domain(domain_test_split, domain, max_samples=5000)
            wise_lipsum_ood_accuracy = evaluate_finetuned(wise_lipsum_model, domain_test_subset, domain.capitalize())
            wise_lipsum_ood_accuracies[domain] = wise_lipsum_ood_accuracy
            print(f"WiSE-Lipsum OOD Accuracy ({domain.capitalize()}): {wise_lipsum_ood_accuracy:.4f}")
        else:
            wise_lipsum_ood_accuracies[domain] = 0.0
    except Exception as e:
        print(f"Error evaluating {domain}: {e}")
        wise_lipsum_ood_accuracies[domain] = 0.0

print("\nWiSE for Lipsum-FT evaluation completed!")
print("\n" + "="*50)
print("WiSE for Lipsum-FT RESULTS SUMMARY")
print("="*50)
print(f"WiSE-Lipsum ID Accuracy: {wise_lipsum_id_accuracy:.4f}")
print("WiSE-Lipsum OOD Accuracies:")
for domain, acc in wise_lipsum_ood_accuracies.items():
    print(f"  {domain.capitalize()}: {acc:.4f}")

Created WiSE for Lipsum-FT model with alpha=0.5

Evaluating WiSE for Lipsum-FT model...


Evaluating Real: 100%|██████████| 313/313 [01:30<00:00,  3.47it/s]


WiSE-Lipsum ID Accuracy (Real): 0.7846
Found 5000 samples for domain quickdraw


Evaluating Quickdraw: 100%|██████████| 313/313 [01:19<00:00,  3.94it/s]

WiSE-Lipsum OOD Accuracy (Quickdraw): 0.0596

WiSE for Lipsum-FT evaluation completed!

WiSE for Lipsum-FT RESULTS SUMMARY
WiSE-Lipsum ID Accuracy: 0.7846
WiSE-Lipsum OOD Accuracies:
  Quickdraw: 0.0596





## Pareto front (ID-OOD plots)

Now, when all the methods are trained, put them on a single Pareto front.

In [9]:
results = {
    "zero_shot": {"id": id_accuracy, "ood": ood_accuracies},
    "ft": {"id": ft_id_accuracy, "ood": ft_ood_accuracies},
    "lipsum_ft": {"id": lipsum_id_accuracy, "ood": lipsum_ood_accuracies},
    "wise_ft": {"id": wise_id_accuracy, "ood": wise_ood_accuracies},
    "wise_lipsum_ft": {"id": wise_lipsum_id_accuracy, "ood": wise_lipsum_ood_accuracies}
}

results

{'zero_shot': {'id': 0.717, 'ood': {'quickdraw': 0.1254}},
 'ft': {'id': 0.7714, 'ood': {'quickdraw': 0.066}},
 'lipsum_ft': {'id': 0.7846, 'ood': {'quickdraw': 0.0588}},
 'wise_ft': {'id': 0.7836, 'ood': {'quickdraw': 0.059}},
 'wise_lipsum_ft': {'id': 0.7846, 'ood': {'quickdraw': 0.0596}}}

In [8]:
import numpy as np
import json

# Function to compute Pareto front
def compute_pareto_front(points):
    """Compute the Pareto front from a list of (x, y) points.

    Args:
        points: List of [x, y] points where x is ID accuracy and y is average OOD accuracy.

    Returns:
        np.ndarray: Array of non-dominated points sorted by x (ID accuracy).
    """
    points = np.array(points)
    n = len(points)
    is_pareto = np.ones(n, dtype=bool)

    for i in range(n):
        for j in range(n):
            if i != j:
                # A point is dominated if another point has both higher x and y
                if (points[j, 0] >= points[i, 0]) and (points[j, 1] >= points[i, 1]) and \
                   (points[j, 0] > points[i, 0] or points[j, 1] > points[i, 1]):
                    is_pareto[i] = False
                    break

    # Sort Pareto points by x (ID accuracy) in ascending order
    pareto_points = points[is_pareto]
    if pareto_points.size == 0:
        print("Warning: No Pareto front points found.")
        return np.array([])
    
    sorted_indices = np.argsort(pareto_points[:, 0])
    pareto_points = pareto_points[sorted_indices]
    return pareto_points

# Extract ID and average OOD accuracies
methods = ["zero_shot", "ft", "lipsum_ft", "wise_ft", "wise_lipsum_ft"]
method_names = ["Zero-Shot", "Fine-Tuned", "Lipsum-FT", "WiSE-FT", "WiSE-Lipsum"]
points = []

for method in methods:
    if method in results:
        id_acc = results[method]["id"]
        ood_accs = results[method]["ood"]
        avg_ood_acc = np.mean(list(ood_accs.values())) if ood_accs else 0.0
        if id_acc < 0 or id_acc > 1 or avg_ood_acc < 0 or avg_ood_acc > 1:
            print(f"Warning: Invalid accuracies for {method} (ID: {id_acc:.4f}, Avg OOD: {avg_ood_acc:.4f}), skipping.")
            points.append([0.0, 0.0])
        else:
            points.append([id_acc, avg_ood_acc])
    else:
        print(f"Warning: Results for {method} not found, skipping.")
        points.append([0.0, 0.0])  # Placeholder to avoid errors

points = np.array(points)

# Compute Pareto front
pareto_points = compute_pareto_front(points)

# Prepare data for scatter plot
scatter_data = [
    {
        "x": float(points[i, 0]) * 100,  # Convert to percentage
        "y": float(points[i, 1]) * 100,  # Convert to percentage
        "label": method_names[i]
    }
    for i in range(len(methods)) if points[i, 0] > 0 or points[i, 1] > 0  # Skip invalid points
]

# Prepare Pareto front line data (connect Pareto points)
pareto_line_data = [
    {"x": float(p[0]) * 100, "y": float(p[1]) * 100}
    for p in pareto_points
]

# Generate Chart.js configuration
chart_config = {
    "type": "scatter",
    "data": {
        "datasets": [
            {
                "label": "Models",
                "data": scatter_data,
                "backgroundColor": "rgba(54, 162, 235, 0.8)",
                "borderColor": "rgba(54, 162, 235, 1)",
                "pointRadius": 8,
                "showLine": False
            },
            {
                "label": "Pareto Front",
                "data": pareto_line_data,
                "type": "line",
                "fill": False,
                "borderColor": "rgba(255, 99, 132, 1)",
                "backgroundColor": "rgba(255, 99, 132, 0.8)",
                "pointRadius": 0,
                "borderWidth": 2,
                "tension": 0
            }
        ]
    },
    "options": {
        "responsive": True,
        "plugins": {
            "title": {
                "display": True,
                "text": "Pareto Front: ID vs Average OOD Accuracy"
            },
            "tooltip": {
                "callbacks": {
                    "label": "function(context) { return context.dataset.label + ': ' + context.raw.label + ' (' + context.raw.x.toFixed(2) + '%, ' + context.raw.y.toFixed(2) + '%)'; }"
                }
            }
        },
        "scales": {
            "x": {
                "title": {
                    "display": True,
                    "text": "ID Accuracy (%)"
                },
                "min": 0,
                "max": 100
            },
            "y": {
                "title": {
                    "display": True,
                    "text": "Average OOD Accuracy (%)"
                },
                "min": 0,
                "max": 100
            }
        }
    }
}

# Print chart configuration
print("```chartjs")
print(json.dumps(chart_config, indent=2))
print("```")

# Print Pareto front points for reference
print("\nPareto Front Points (ID Accuracy, Avg OOD Accuracy):")
for i, p in enumerate(pareto_points):
    method_idx = np.where((points[:, 0] == p[0]) & (points[:, 1] == p[1]))[0]
    if method_idx.size > 0:
        method_idx = method_idx[0]
        print(f"{method_names[method_idx]}: ({p[0]*100:.2f}%, {p[1]*100:.2f}%)")
    else:
        print(f"Point ({p[0]*100:.2f}%, {p[1]*100:.2f}%): No matching method found")

```chartjs
{
  "type": "scatter",
  "data": {
    "datasets": [
      {
        "label": "Models",
        "data": [
          {
            "x": 71.7,
            "y": 12.540000000000001,
            "label": "Zero-Shot"
          },
          {
            "x": 77.14,
            "y": 6.6000000000000005,
            "label": "Fine-Tuned"
          },
          {
            "x": 78.46,
            "y": 5.88,
            "label": "Lipsum-FT"
          },
          {
            "x": 78.36,
            "y": 5.8999999999999995,
            "label": "WiSE-FT"
          },
          {
            "x": 78.46,
            "y": 5.96,
            "label": "WiSE-Lipsum"
          }
        ],
        "backgroundColor": "rgba(54, 162, 235, 0.8)",
        "borderColor": "rgba(54, 162, 235, 1)",
        "pointRadius": 8,
        "showLine": false
      },
      {
        "label": "Pareto Front",
        "data": [
          {
            "x": 71.7,
            "y": 12.540000000000001
          },


## Bonus: Fine-tuning data-efficiency [2 pts]

Create a data-efficiency plot similar to Figure 6 from the [CLIP paper](https://arxiv.org/pdf/2103.00020) for the considered fine-tuning methods (regular vs. Lipsum-FT). On the horizontal axis you will have the number of fine-tuning samples per class (in logarithmic scale) and on the vertical axis &mdash; ID accuracy. What conclusions can be drawn about the data-efficiency of these methods?