# 0. Overview

Author: Darrin O'Brien, email: darrinobrien5@gmail.com

1. Fine-tunes CLIP ViT-B/32 on the subset of MNIST images whose label is 0.
2. Evaluates the performance of the fine-tuned model. 

# 1. Fine-Tuning Pipeline

## 1. Quick Installs for Essential Libraries

In [None]:
!pip install torch torchvision
!pip install -U transformers datasets
!pip install fifty regex tqdm
!pip install git+https://github.com/openai/CLIP.git
!pip install matplotlib
!pip install -U pillow
%matplotlib inline

In [None]:
# Reinstall scipy and datasets from scratch, bypassing pip's cache. Skip this cell outside RunPod.
!pip install --force-reinstall --no-cache-dir scipy datasets # Only needed within runpod environment

In [None]:
# Pin NumPy to 1.26.4 (presumably to stay on the 1.x series after the reinstall above -- confirm). Skip outside RunPod.
!pip install numpy==1.26.4 # only needed for runpod environment

## 2. Importing the Libraries

In [None]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
import clip 
import numpy as np 
from datasets import load_dataset 
from tqdm import tqdm 
import copy

## 3. Setting Up

In [None]:
# Run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load CLIP ViT-B/32 together with its image preprocessing transform, then
# cast the model's floating-point weights to float32.
clip_model, preprocess = clip.load("ViT-B/32", device=device)
clip_model = clip_model.float()

# Hold out 20% of the MNIST training split for validation (fixed seed).
mnist = load_dataset("ylecun/mnist")
split = mnist["train"].train_test_split(test_size=0.2, seed=66)


def _is_zero(example):
    """Return True when the example's label is 0."""
    return example["label"] == 0


# Keep only the label-0 examples in each of the three splits.
train_dataset = split["train"].filter(_is_zero)
val_dataset = split["test"].filter(_is_zero)
test_dataset = mnist["test"].filter(_is_zero)

In [None]:
# Apply the same output format ("python" type, image + label columns) to
# every split; set_format modifies each dataset in place.
for _dataset in (train_dataset, val_dataset, test_dataset):
    _dataset.set_format(type="python", columns=["image", "label"])

In [None]:
def clip_collate_fn(batch):
    """Collate dataset items into a batch of model inputs on ``device``.

    Args:
        batch: Iterable of items, each with an ``"image"`` and a ``"label"``
            entry.

    Returns:
        Dict with ``"pixel_values"`` (the preprocessed images stacked along a
        new leading batch dimension) and ``"labels"`` (a ``torch.long``
        tensor), both moved to the module-level ``device``.
    """
    # Run CLIP's preprocessing transform on every image, then stack the
    # results into a single tensor.
    pixel_values = torch.stack([preprocess(sample["image"]) for sample in batch])
    # CrossEntropyLoss expects integer class indices, hence torch.long.
    label_ids = torch.tensor([sample["label"] for sample in batch], dtype=torch.long)

    return {
        "pixel_values": pixel_values.to(device),
        "labels": label_ids.to(device),
    }

# Batch size and collate function are shared by all three loaders; only the
# training loader shuffles its data.
_loader_options = {"batch_size": 64, "collate_fn": clip_collate_fn}
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_options)

## 4. Wrapping CLIP for Classification

In [None]:
class CLIPClassifier0(nn.Module):
    """CLIP image encoder followed by a linear classification head.

    Args:
        clip_model: Module exposing ``visual.output_dim`` and an
            ``encode_image`` method.
        num_classes: Number of logits produced by the head.
    """

    def __init__(self, clip_model, num_classes=2):
        super().__init__()
        self.clip = clip_model
        embedding_dim = self.clip.visual.output_dim
        # Linear head: image-embedding width -> num_classes.
        self.classifier = nn.Linear(embedding_dim, num_classes)

    def forward(self, images):
        """Encode ``images`` with CLIP and return the classifier logits."""
        return self.classifier(self.clip.encode_image(images))

# Wrap the loaded CLIP model, move the wrapper to the target device and cast
# its floating-point parameters to float32.
model = CLIPClassifier0(clip_model=clip_model).to(device).float()

In [None]:
EPOCHS = 15

# Cross-entropy over the classifier logits.
criterion = nn.CrossEntropyLoss()

# Adam over every parameter of the wrapper (CLIP backbone and linear head).
optimizer = optim.Adam(model.parameters(), lr=1e-5)

# Cosine annealing. T_max is counted in scheduler.step() calls and is set to
# the total number of training batches across all epochs.
total_training_batches = len(train_loader) * EPOCHS
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_training_batches)

## 5. Fine-Tuning CLIP on MNIST 0 Labels

In [None]:
# Fine-tune for EPOCHS epochs. After every epoch the model is evaluated on the
# validation loader, and its weights are checkpointed whenever the mean
# validation loss improves on the best value seen so far.
best_val_loss = float('inf')
best_epoch = -1

for epoch in range(EPOCHS):
    print(f"Epoch {epoch+1}/{EPOCHS}")

    # Training
    model.train()
    total_train_loss = 0
    train_steps = 0

    for batch in tqdm(train_loader, desc="Training"):
        optimizer.zero_grad()

        images = batch["pixel_values"]
        labels = batch["labels"]

        logits = model(images)
        loss = criterion(logits, labels)

        loss.backward()
        optimizer.step()
        # The scheduler's T_max is len(train_loader) * EPOCHS, i.e. a count of
        # batches, so the learning rate has to advance once per batch for the
        # cosine curve to complete by the end of training.
        scheduler.step()

        total_train_loss += loss.item()
        train_steps += 1

    # Mean of the per-batch training losses for this epoch.
    avg_train_loss = total_train_loss / train_steps

    # Validation
    model.eval()
    total_val_loss = 0
    val_steps = 0
    total_val_correct = 0
    total_val = 0

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validation"):
            images = batch["pixel_values"]
            labels = batch["labels"]

            logits = model(images)
            loss = criterion(logits, labels)

            total_val_loss += loss.item()
            val_steps += 1

            # Predicted class = index of the largest logit.
            pred = logits.argmax(dim=1)

            total_val_correct += (pred == labels).sum().item()
            total_val += labels.size(0)

    avg_val_loss = total_val_loss / val_steps
    val_classification_acc = total_val_correct / total_val

    # Update the best checkpoint before reporting, so the "best" line below
    # already accounts for the epoch that just finished.
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        best_epoch = epoch + 1
        torch.save(model.state_dict(), "best_clip_mnist_0_fp32.pt")

    print(f"Epoch {epoch+1}: Train Loss: {avg_train_loss:.4f} | Validation Loss: {avg_val_loss:.4f} | Validation Accuracy {val_classification_acc:.4f}")
    print(f"Best Validation Loss: {best_val_loss:.4f} (Epoch {best_epoch})")


## 6. Evaluating Fine-Tuned Performance

In [None]:
# Baseline: a freshly loaded CLIP ViT-B/32 wrapped with a newly initialised
# linear head (no checkpoint is loaded into it).
# NOTE: this rebinds `model`, so the name no longer refers to the model
# trained above.
base, _ = clip.load("ViT-B/32", device=device)
base = base.float()
model = CLIPClassifier0(clip_model=base).to(device)

# Fine-tuned: same architecture, weights restored from the checkpoint written
# by the training loop.
fine_tuned, _ = clip.load("ViT-B/32", device=device)
fine_tuned = fine_tuned.float()
fine_tuned_model = CLIPClassifier0(clip_model=fine_tuned).to(device)
# The file name must match the one passed to torch.save (including ".pt").
# map_location places the tensors on the current device regardless of where
# the checkpoint was written.
fine_tuned_model.load_state_dict(
    torch.load("best_clip_mnist_0_fp32.pt", map_location=device)
)

model.eval()
fine_tuned_model.eval()

# Running sums of per-batch losses and the number of batches seen.
loss_base = 0
total_base = 0
loss_fine_tuned = 0
total_fine_tuned = 0

# Running counts of correct predictions and of evaluated samples.
correct_base = 0
correct_fine_tuned = 0
total_samples = 0

with torch.no_grad():
    for batch in tqdm(test_loader, desc="Testing"):
        images = batch["pixel_values"]
        labels = batch["labels"]
        total_samples += labels.size(0)

        # Base Model
        logits_base = model(images)
        # Accumulate (not overwrite) so the average below covers all batches.
        loss_base += criterion(logits_base, labels).item()
        total_base += 1
        correct_base += (logits_base.argmax(dim=1) == labels).sum().item()

        # Fine-Tuned Model
        logits_fine_tuned = fine_tuned_model(images)
        loss_fine_tuned += criterion(logits_fine_tuned, labels).item()
        total_fine_tuned += 1
        correct_fine_tuned += (logits_fine_tuned.argmax(dim=1) == labels).sum().item()

# Mean of the per-batch losses for each model.
avg_base_loss = loss_base / total_base
avg_fine_tuned_loss = loss_fine_tuned / total_fine_tuned

base_acc = correct_base / total_samples
fine_tuned_acc = correct_fine_tuned / total_samples

print(f"\nAverage Base Loss: {avg_base_loss:.4f}, Base Classification Accuracy: {base_acc:.4f}")
print(f"Average Fine-Tuned Loss: {avg_fine_tuned_loss:.4f}, Fine-Tuned Classification: {fine_tuned_acc:.4f}")