##0. Ideation
1. Find-tune CLIP on MNIST
2. Load fine-tuned CLIP and base-CLIP
3. Begin Inference
4. Register Hooks/Intercept Activations to get First layer CLS embedding, last layer CLS embedding, final projected CLIP image embedding
5. Use a classifier head to keep track of the first layer, last layer, and final embedding vectors from the base model and fine-tuned model
    1. base_first -> fine_first (Early low-level shift) - Not as likely. After getting transformation matrix, multiply base_first by it and continue through with rest of layers
    2. base_first -> fine_last (Deep vision changes) - Not as likely. After getting transformation matrix, multiply base_first by it and go through to the projection matrix.
    3. base_first -> fine_final (Total task-level transformation) - Possible. After getting transformation matrix, multiply base_first by it and go straight to classification.
    4. base_last -> fine_last (High-level internal difference) - Possible. After getting transformation matrix, multiply base_last by it and go straight to the projection matrix.
    5. base_last -> fine_final (Final classification alignment) - Possible. After getting transformation matrix, multiply base_last by it and go straight to classification.
    6. **base_final -> fine_final (End-to-end latent mapping) - Most Important. After getting transformation matrix, multiply base_final by it and go straight to classification. **
6. Average out each one, and use least squares regression to calculate it in relation to the fine-tuned model. Use affine transformation to include a bias term to shift to issues. If it doesn’t go well, can also switch to ridge regression to see if it matches it better.
7. Then, use the multiplicative matrix transformation and employ that into the base model itself’s layer right before classification
8. Then, run testing on it and benchmark performance.


##1. Quick Installs for Essential Libraries

In [1]:
!pip install torch torchvision
!pip install -U transformers datasets
!pip install fifty regex tqdm
!pip install git+https://github.com/openai/CLIP.git
!pip install matplotlib
!pip install -U pillow
%matplotlib inline

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.1[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.1[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.1[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m
Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-zy3bnwo1
  Running command git clone --filter=blob:none --quiet https://github.com

In [2]:
!pip install --force-reinstall --no-cache-dir scipy datasets # Only needed within runpod environment

Collecting scipy
  Downloading scipy-1.15.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.0/62.0 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting datasets
  Downloading datasets-3.6.0-py3-none-any.whl.metadata (19 kB)
Collecting numpy<2.5,>=1.23.5 (from scipy)
  Downloading numpy-2.2.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (62 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.0/62.0 kB[0m [31m162.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting filelock (from datasets)
  Downloading filelock-3.18.0-py3-none-any.whl.metadata (2.9 kB)
Collecting pyarrow>=15.0.0 (from datasets)
  Downloading pyarrow-20.0.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (3.3 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting pandas (from datasets)
  Downloading pandas-2.3.0-cp31

In [3]:
!pip install numpy==1.26.4 # only needed for runpod environment

Collecting numpy==1.26.4
  Using cached numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
Using cached numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.2 MB)
Installing collected packages: numpy
  Attempting uninstall: numpy
    Found existing installation: numpy 2.2.6
    Uninstalling numpy-2.2.6:
      Successfully uninstalled numpy-2.2.6
Successfully installed numpy-1.26.4
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.1[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m


##2. Importing the Libraries

In [4]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
import clip
import numpy as np
from datasets import load_dataset
from tqdm import tqdm
import matplotlib.pyplot as plt
from PIL import Image

##3. Setting Device, Preparing Data, and Loading CLIP

In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, preprocess = clip.load("ViT-B/32", device=device) # https://github.com/openai/CLIP
clip_model.float()
mnist = load_dataset("ylecun/mnist") # https://huggingface.co/datasets/ylecun/mnist
split = mnist["train"].train_test_split(test_size=0.2, seed=66)

train_dataset = split["train"] # 48,000 examples (direct training data from training set)
val_dataset = split["test"] # 12,000 examples (validation set split from training set)
test_dataset = mnist["test"] # 10,000 examples (direct test set)

100%|████████████████████████████████████████| 338M/338M [00:01<00:00, 181MiB/s]


README.md: 0.00B [00:00, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/15.6M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/2.60M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/60000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/10000 [00:00<?, ? examples/s]

##4. Pre-Processing MNIST for CLIP

In [6]:
def add_text_labels(example):
  '''
  Creates a new column of "text_label"
  Processes the MNIST labels into CLIP's prediction format.
  Adds the processed labels to "text_label"
  '''
  return {"text_label": f"a photo of the number {example['label']}"}

train_dataset = train_dataset.map(add_text_labels)
val_dataset = val_dataset.map(add_text_labels)
test_dataset = test_dataset.map(add_text_labels)

Map:   0%|          | 0/48000 [00:00<?, ? examples/s]

Map:   0%|          | 0/12000 [00:00<?, ? examples/s]

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

In [7]:
train_dataset.set_format(type="python", columns=["image", "label", "text_label"])
val_dataset.set_format(type="python", columns=["image", "label", "text_label"])
test_dataset.set_format(type="python", columns=["image", "label", "text_label"])

##5. Dataloaders

In [8]:
def clip_collate_fn(batch):
  images = []
  labels = []

  for item in batch:
    img = item["image"].convert("RGB")  # Already a PIL Image
    img = preprocess(img)
    images.append(img)
    labels.append(item["label"])

  images = torch.stack(images)
  labels = torch.tensor(labels, dtype=torch.long)

  return {
      "pixel_values": images.to(device),
      "labels": labels.to(device)
  }

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, collate_fn=clip_collate_fn)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, collate_fn=clip_collate_fn)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, collate_fn=clip_collate_fn)

##6. Fine-Tune Prep

In [9]:
class CLIPClassifier(nn.Module):
  def __init__(self, clip_model, num_classes=10):
    super().__init__()
    self.clip = clip_model
    self.classifier = nn.Linear(self.clip.visual.output_dim, num_classes)

  def forward(self, images):
    image_features = self.clip.encode_image(images)
    logits = self.classifier(image_features)
    return logits

model = CLIPClassifier(clip_model=clip_model).to(device)

In [10]:
if device == "cpu":
  model.float()

optimizer = optim.Adam(model.parameters(), lr=1e-5)

criterion = nn.CrossEntropyLoss() # Use Contrastive Loss later to see whether it does the same but better. Can test zero shot behavior where we can see if it can identify a photo of 7 without explicitly trainining for it.

EPOCHS = 5
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_loader) * EPOCHS)

##7. Fine-Tuning CLIP on MNIST

In [11]:
best_val_loss = float('inf')
best_epoch = -1

for epoch in range(EPOCHS):
  print(f"Epoch {epoch+1}/{EPOCHS} - Best Val Loss: {best_val_loss:.4f} (Epoch {best_epoch})")

  model.train()
  total_train_loss = 0
  train_steps = 0

  for batch in tqdm(train_loader, desc="Training"):
    optimizer.zero_grad()

    images = batch["pixel_values"]
    labels = batch["labels"]

    logits = model(images)
    loss = criterion(logits, labels)

    loss.backward()
    optimizer.step()

    total_train_loss += loss.item()
    train_steps += 1

  avg_train_loss = total_train_loss / train_steps

  # Validation
  model.eval()
  total_val_loss = 0
  val_steps = 0

  with torch.no_grad():
    for batch in tqdm(val_loader, desc="Validation"):
      images = batch["pixel_values"]
      labels = batch["labels"]

      logits = model(images)
      loss = criterion(logits, labels)

      total_val_loss += loss.item()
      val_steps += 1

  avg_val_loss = total_val_loss / val_steps

  print(f"[Epoch {epoch+1}] Train Loss: {avg_train_loss:.4f} | Validation Loss: {avg_val_loss:.4f}")

  if avg_val_loss < best_val_loss:
    best_val_loss = avg_val_loss
    best_epoch = epoch
    torch.save(model.state_dict(), "best_clip_mnist.pt")

  scheduler.step()

torch.save(model.state_dict(), "last_clip_mnist.pt")

Epoch 1/5 - Best Val Loss: inf (Epoch -1)


Training: 100%|██████████| 750/750 [05:03<00:00,  2.47it/s]
Validation: 100%|██████████| 188/188 [01:13<00:00,  2.57it/s]


[Epoch 1] Train Loss: 0.1190 | Validation Loss: 0.0371
Epoch 2/5 - Best Val Loss: 0.0371 (Epoch 0)


Training: 100%|██████████| 750/750 [05:13<00:00,  2.39it/s]
Validation: 100%|██████████| 188/188 [01:09<00:00,  2.70it/s]


[Epoch 2] Train Loss: 0.0289 | Validation Loss: 0.0344
Epoch 3/5 - Best Val Loss: 0.0344 (Epoch 1)


Training: 100%|██████████| 750/750 [05:00<00:00,  2.50it/s]
Validation: 100%|██████████| 188/188 [01:09<00:00,  2.70it/s]


[Epoch 3] Train Loss: 0.0233 | Validation Loss: 0.0474
Epoch 4/5 - Best Val Loss: 0.0344 (Epoch 1)


Training: 100%|██████████| 750/750 [05:18<00:00,  2.36it/s]
Validation: 100%|██████████| 188/188 [01:11<00:00,  2.62it/s]


[Epoch 4] Train Loss: 0.0164 | Validation Loss: 0.0264
Epoch 5/5 - Best Val Loss: 0.0264 (Epoch 3)


Training: 100%|██████████| 750/750 [04:30<00:00,  2.77it/s]
Validation: 100%|██████████| 188/188 [01:02<00:00,  3.01it/s]


[Epoch 5] Train Loss: 0.0141 | Validation Loss: 0.0387


##7.5 Testing if Fine-Tuning Actually Worked

In [15]:
base_CLIP, _ = clip.load("ViT-B/32", device=device)
base_CLIP.float()
model = CLIPClassifier(clip_model=base_CLIP).to(device)

best_CLIP, _ = clip.load("ViT-B/32", device=device)
best_CLIP.float()
best_CLIP_MNIST = CLIPClassifier(clip_model=best_CLIP).to(device)
best_CLIP_MNIST.load_state_dict(torch.load("best_clip_mnist.pt", map_location=device)) # map_location tells where to place the model's weights in memory

last_CLIP, _ = clip.load("ViT-B/32", device=device)
last_CLIP.float()
last_CLIP_MNIST = CLIPClassifier(clip_model=last_CLIP).to(device)
last_CLIP_MNIST.load_state_dict(torch.load("last_clip_mnist.pt", map_location=device)) # map_location tells where to place the model's weights in memory

model.eval()
best_CLIP_MNIST.eval()
last_CLIP_MNIST.eval()

total_test_loss_base = 0
total_base = 0
total_test_loss_best = 0
total_best = 0
total_test_loss_last = 0
total_last = 0

correct_base = 0
correct_best = 0
correct_last = 0
total_samples = 0

with torch.no_grad():
  for batch in tqdm(test_loader, desc="Testing"):
    images = batch["pixel_values"]
    labels = batch["labels"]
    total_samples += labels.size(0)

    # Base model
    logits_base = model(images)
    loss_base = criterion(logits_base, labels)
    total_test_loss_base += loss_base.item()
    total_base += 1

    # Best model
    logits_best = best_CLIP_MNIST(images)
    loss_best = criterion(logits_best, labels)
    total_test_loss_best += loss_best.item()
    total_best += 1

    # Last Model
    logits_last = last_CLIP_MNIST(images)
    loss_last = criterion(logits_last, labels)
    total_test_loss_last += loss_last.item()
    total_last += 1

    # Classification Accuracy
    pred_base = logits_base.argmax(dim=1)
    pred_best = logits_best.argmax(dim=1)
    pred_last = logits_last.argmax(dim=1)

    correct_base += (pred_base == labels).sum().item()
    correct_best += (pred_best == labels).sum().item()
    correct_last += (pred_last == labels).sum().item()

avg_base_loss = total_test_loss_base / total_base
avg_best_loss = total_test_loss_best / total_best
avg_last_loss = total_test_loss_last / total_last

accuracy_base = correct_base / total_samples
accuracy_best = correct_best / total_samples
accuracy_last = correct_last / total_samples
print(f"\nAverage base loss: {avg_base_loss:.4f}, Base Accuracy: {accuracy_base:.4f}")
print(f"Average best loss: {avg_best_loss:.4f}, Best Accuracy: {accuracy_best:.4f}")
print(f"Average last loss: {avg_last_loss:.4f}, Best Accuracy: {accuracy_best:.4f}")

Testing: 100%|██████████| 157/157 [00:54<00:00,  2.89it/s]


Average base loss: 2.3599, Base Accuracy: 0.1055
Average best loss: 0.0195, Best Accuracy: 0.9951
Average last loss: 0.0293, Best Accuracy: 0.9951



