# 0. Overview

Author: Darrin O'Brien, email: darrinobrien5@gmail.com

1. Preparation
2. Loads the base CLIP model and the CLIP model fine-tuned on MNIST.
3. Extracts image embedding vectors from both models on the training set: the base model's CLS token after each of its 12 transformer layers, and the fine-tuned model's CLS token after its final (12th) layer.
4. Fits a weight matrix and bias term by least squares for each base layer, and applies them to augment base CLIP in 12 different ways (one per base layer).
5. Evaluates the augmented models against the base and fine-tuned models on the test set.

## 1. Installations

In [None]:
!pip install torch torchvision
!pip install -U transformers datasets
!pip install fifty regex tqdm
!pip install git+https://github.com/openai/CLIP.git
!pip install matplotlib
!pip install -U pillow
%matplotlib inline

### 1.1 Optional Installs (RunPod only)

In [None]:
!pip install --force-reinstall --no-cache-dir scipy datasets # Only needed within runpod environment

In [None]:
!pip install numpy==1.26.4 # only needed for runpod environment

## 2. Imports

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
import clip
import numpy as np
from datasets import load_dataset
from tqdm import tqdm
import matplotlib.pyplot as plt
from PIL import Image
import copy

## 3. Setting up Device + Train/Test Sets

In [None]:
# Run on the GPU when PyTorch can see one; later cells move models and batches to `device`.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# MNIST from the Hugging Face Hub: https://huggingface.co/datasets/ylecun/mnist
mnist = load_dataset("ylecun/mnist")
train_dataset = mnist["train"]  # 60,000 samples (embeddings for the least-squares fit come from here)
test_dataset = mnist["test"]    # 10,000 samples (held out for the final evaluation)

In [None]:
# Expose only the image and label columns as plain Python objects; clip_collate_fn calls
# .convert("RGB") on item["image"], so it expects a PIL image there.
# (The train split previously called the non-existent `sest_format`, raising AttributeError.)
for split in (train_dataset, test_dataset):
    split.set_format(type="python", columns=["image", "label"])

## 4. Wrapping Models & Prepping Dataset

In [None]:
class CLIPClassifier(nn.Module):
    """CLIP image encoder topped with a single linear classification layer.

    The sub-modules are registered as ``clip`` and ``classifier`` so that a state dict saved
    from this wrapper (keys ``clip.*`` / ``classifier.*``) loads back into it.

    Args:
        clip_model: CLIP model exposing ``encode_image`` and ``visual.output_dim``.
        num_classes: Number of output logits.
    """

    def __init__(self, clip_model, num_classes=10):
        super().__init__()
        self.clip = clip_model
        feature_dim = clip_model.visual.output_dim
        self.classifier = nn.Linear(feature_dim, num_classes)

    def forward(self, images):
        """Return ``num_classes`` logits per image: ``classifier(clip.encode_image(images))``."""
        return self.classifier(self.clip.encode_image(images))

In [None]:
class CLIPWithHooks(nn.Module):
    """CLIP ViT image encoder that records the CLS token after every residual block.

    ``forward`` returns a dict with one entry per residual block, ``"first_cls"``, ``"second_cls"``,
    ... ``"twelfth_cls"`` (blocks beyond the twelfth are keyed ``"layer<N>_cls"``). Each entry is the
    detached CLS token of shape [B, width] *before* ``ln_post``/projection. The dict also contains
    ``"final_embed"`` (detached ``ln_post`` + projection of the last CLS token) and ``"logits"``
    (``classifier_head(final_embed)``).

    Args:
        clip_model: OpenAI CLIP model with a ViT visual tower.
        classifier_head: Module applied to ``final_embed`` to produce ``logits``.
    """

    # Ordinal prefixes used for the per-block keys of forward()'s result.
    _ORDINALS = (
        "first", "second", "third", "fourth", "fifth", "sixth",
        "seventh", "eighth", "ninth", "tenth", "eleventh", "twelfth",
    )

    def __init__(self, clip_model, classifier_head):
        super().__init__()
        self.clip = clip_model
        self.cls_tokens = []  # CLS token after each block, refreshed on every forward call
        self.classifier = classifier_head

    @classmethod
    def _layer_key(cls, index):
        """Return the result key for the CLS token after block ``index`` (0-based)."""
        if index < len(cls._ORDINALS):
            return f"{cls._ORDINALS[index]}_cls"
        return f"layer{index + 1}_cls"

    def forward(self, images):
        visual = self.clip.visual
        self.cls_tokens = []

        # Patchify: [B, 3, H, W] -> [B, width, grid, grid] -> [B, grid*grid, width]
        # (ViT-B/32 on 224x224 input: [B, 768, 7, 7] -> [B, 49, 768]).
        x = visual.conv1(images)
        x = x.reshape(x.shape[0], x.shape[1], -1).permute(0, 2, 1)

        # Prepend the learned class embedding as token 0, add positions, normalise.
        cls_embedding = visual.class_embedding.to(x.dtype) + torch.zeros(
            x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
        )
        x = torch.cat([cls_embedding, x], dim=1)  # [B, 1 + grid*grid, width]
        x = x + visual.positional_embedding.to(x.dtype)
        x = visual.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND (sequence first), as the resblocks are fed here
        # Run the blocks one at a time (instead of visual.transformer(x)) so the CLS token can be
        # captured after each one. No forward hooks are involved.
        for block in visual.transformer.resblocks:
            x = block(x)
            self.cls_tokens.append(x[0].detach())  # CLS token: [B, width]

        # ln_post on the final CLS token, then the optional linear projection
        # (768 -> 512 for ViT-B/32).
        final_embed = visual.ln_post(x[0])
        if visual.proj is not None:
            final_embed = final_embed @ visual.proj
        final_embed = final_embed.detach()

        logits = self.classifier(final_embed)

        outputs = {self._layer_key(i): token for i, token in enumerate(self.cls_tokens)}
        outputs["final_embed"] = final_embed
        outputs["logits"] = logits
        return outputs

In [None]:
refer, preprocess = clip.load("ViT-B/32", device=device)
refer = refer.float()  # run everything in fp32

# Base model: pretrained CLIP with a random, untrained 10-way head. Only the per-layer CLS tokens
# are used downstream; the head just has to live on the same device as the backbone (it was
# previously left on the CPU, which fails on CUDA).
base = CLIPWithHooks(copy.deepcopy(refer), nn.Linear(refer.visual.output_dim, 10).to(device))
base = base.eval()

# Wrap in CLIPClassifier so the checkpoint's "clip.*"/"classifier.*" keys line up, then load it.
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
f_t = CLIPClassifier(clip_model=copy.deepcopy(refer)).to(device)
f_t.load_state_dict(torch.load("best_clip_mnist_0_fp32", map_location=device))
# Use the *fine-tuned* backbone (f_t.clip). A fresh deepcopy of `refer` would be the pretrained
# weights, making the "fine-tuned" CLS tokens identical to the base ones.
fine_tuned = CLIPWithHooks(f_t.clip, classifier_head=f_t.classifier)
fine_tuned = fine_tuned.eval()

In [None]:
def clip_collate_fn(batch):
    """Collate MNIST records into a CLIP-ready batch on ``device``.

    Each record's image is converted to RGB and passed through CLIP's ``preprocess``; the results
    are stacked into one tensor. Labels become a ``torch.long`` tensor.

    Returns:
        dict with ``"pixel_values"`` (stacked preprocessed images) and ``"labels"``, both on ``device``.
    """
    pixel_values = torch.stack([preprocess(record["image"].convert("RGB")) for record in batch])
    targets = torch.tensor([record["label"] for record in batch], dtype=torch.long)
    return {"pixel_values": pixel_values.to(device), "labels": targets.to(device)}


# Both loaders iterate in dataset order (no shuffling) with the same batching and collation.
loader_kwargs = {"batch_size": 64, "shuffle": False, "collate_fn": clip_collate_fn}
train_loader = DataLoader(train_dataset, **loader_kwargs)
test_loader = DataLoader(test_dataset, **loader_kwargs)

In [None]:
# Ordinal names of the 12 ViT-B/32 residual blocks; CLIPWithHooks.forward returns "<ordinal>_cls".
LAYER_ORDINALS = (
    "first", "second", "third", "fourth", "fifth", "sixth",
    "seventh", "eighth", "ninth", "tenth", "eleventh", "twelfth",
)

base_cls_chunks = {name: [] for name in LAYER_ORDINALS}  # base model: CLS token after every block
Z1_twelfth = []  # fine-tuned model: CLS token after the last block only (the regression target)

with torch.no_grad():
    for batch in tqdm(train_loader, desc="Extracting Train Set Vectors"):
        images = batch["pixel_values"]

        out_base = base(images)
        out_fine_tuned = fine_tuned(images)

        # Move each batch to host memory right away. Keeping 13 x 60,000 x 768 float32 matrices
        # on the GPU until the final torch.cat ties up roughly 2.4 GB of device memory.
        for name in LAYER_ORDINALS:
            base_cls_chunks[name].append(out_base[f"{name}_cls"].float().cpu())
        Z1_twelfth.append(out_fine_tuned["twelfth_cls"].float().cpu())

base_cls = {name: torch.cat(chunks) for name, chunks in base_cls_chunks.items()}
Z1_twelfth = torch.cat(Z1_twelfth)

# Individual names used by the cells below (layer 8 keeps its existing spelling, Z0_eight).
Z0_first = base_cls["first"]
Z0_second = base_cls["second"]
Z0_third = base_cls["third"]
Z0_fourth = base_cls["fourth"]
Z0_fifth = base_cls["fifth"]
Z0_sixth = base_cls["sixth"]
Z0_seventh = base_cls["seventh"]
Z0_eight = base_cls["eighth"]
Z0_ninth = base_cls["ninth"]
Z0_tenth = base_cls["tenth"]
Z0_eleventh = base_cls["eleventh"]
Z0_twelfth = base_cls["twelfth"]

In [None]:
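# NOTE: stale. Z0_last, Z0_final and Z1_best_* below come from an earlier version of this notebook
# and are not defined above; update the names (e.g. Z0_first ... Z0_twelfth, Z1_twelfth) before uncommenting.
# torch.save({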
#     "base_first": Z0_first,
#     "base_last": Z0_last,
#     "base_final": Z0_final,
#     "fine_tuned_first": Z1_best_first,
#     "fine_tuned_last": Z1_best_last,
#     "fine_tuned_final": Z1_best_final
# }, "embedding_pairs.pt")

# Code to load in
# pairs = torch.load("embedding_pairs.pt", map_location=device)
# Z0_first = pairs["base_first"]
# Z0_last = pairs["base_last"]
# Z0_final = pairs["base_final"]
# Z1_best_first = pairs["fine_tuned_first"]
# Z1_best_last = pairs["fine_tuned_last"]
# Z1_best_final = pairs["fine_tuned_final"]

## 5. Calculating Linear Transformation Matrix and Bias

In [None]:
# Bring every extracted CLS matrix to host memory as a NumPy array for the least-squares fits.
(
    Z0_first, Z0_second, Z0_third, Z0_fourth, Z0_fifth, Z0_sixth,
    Z0_seventh, Z0_eight, Z0_ninth, Z0_tenth, Z0_eleventh, Z0_twelfth,
    Z1_twelfth,
) = [
    tensor.cpu().numpy()
    for tensor in (
        Z0_first, Z0_second, Z0_third, Z0_fourth, Z0_fifth, Z0_sixth,
        Z0_seventh, Z0_eight, Z0_ninth, Z0_tenth, Z0_eleventh, Z0_twelfth,
        Z1_twelfth,
    )
]

In [None]:
def addBiasColumn(Z0):
    """Return ``Z0`` with a trailing column of ones appended.

    Fitting ``Z1 ~ [Z0, 1] @ W_full`` by least squares then learns the bias as the last row of
    ``W_full``. For an (N, D) input the result is (N, D + 1). The ones column is float64, so
    NumPy promotes float32 input to float64.
    """
    n_rows = Z0.shape[0]
    return np.concatenate((Z0, np.ones((n_rows, 1))), axis=1)

# Give every base-layer matrix a trailing constant-1 column so the least-squares fit learns a bias too.
Z0_first = addBiasColumn(Z0_first)
Z0_second = addBiasColumn(Z0_second)
Z0_third = addBiasColumn(Z0_third)
Z0_fourth = addBiasColumn(Z0_fourth)
Z0_fifth = addBiasColumn(Z0_fifth)
Z0_sixth = addBiasColumn(Z0_sixth)
Z0_seventh = addBiasColumn(Z0_seventh)
# Layer 8 is spelled both "Z0_eight" and "Z0_eighth" in this notebook. Bind both names to the
# augmented matrix so neither one refers to the raw, bias-less data (which would yield a
# mis-shaped W8 and a b8 that is really a weight row).
Z0_eight = addBiasColumn(Z0_eight)
Z0_eighth = Z0_eight
Z0_ninth = addBiasColumn(Z0_ninth)
Z0_tenth = addBiasColumn(Z0_tenth)
Z0_eleventh = addBiasColumn(Z0_eleventh)
Z0_twelfth = addBiasColumn(Z0_twelfth)

In [None]:
# Calculate Bias b and Linear Transformation Matrix W
def leastSquares(Z0, Z1):
    """Return the least-squares solution ``W_full`` of ``Z0 @ W_full ~ Z1``.

    When ``Z0`` carries the constant-1 column from ``addBiasColumn``, ``W_full[:-1]`` is the
    weight matrix and ``W_full[-1]`` the bias. Residuals, rank and singular values are discarded.
    (The previous ``np.linalg.lstq`` does not exist; the NumPy function is ``lstsq``.)
    """
    W_full, _residuals, _rank, _singular_values = np.linalg.lstsq(Z0, Z1, rcond=None)
    return W_full

def _fit_affine(Z0_aug, Z1):
    """Fit ``Z1 ~ Z0 @ W + b`` by least squares and return ``(W, b)``.

    ``Z0_aug`` must already carry the constant-1 column from ``addBiasColumn``, so the last row of
    the solution is the bias and the remaining rows are the weight matrix.
    """
    W_full = leastSquares(Z0_aug, Z1)
    return W_full[:-1], W_full[-1]


# Map the base model's CLS token after block k onto the fine-tuned model's block-12 CLS token.
W1, b1 = _fit_affine(Z0_first, Z1_twelfth)
W2, b2 = _fit_affine(Z0_second, Z1_twelfth)
W3, b3 = _fit_affine(Z0_third, Z1_twelfth)
W4, b4 = _fit_affine(Z0_fourth, Z1_twelfth)
W5, b5 = _fit_affine(Z0_fifth, Z1_twelfth)
W6, b6 = _fit_affine(Z0_sixth, Z1_twelfth)
W7, b7 = _fit_affine(Z0_seventh, Z1_twelfth)
# Z0_eighth is the name the bias cell binds to the augmented layer-8 matrix.
W8, b8 = _fit_affine(Z0_eighth, Z1_twelfth)
W9, b9 = _fit_affine(Z0_ninth, Z1_twelfth)
W10, b10 = _fit_affine(Z0_tenth, Z1_twelfth)
W11, b11 = _fit_affine(Z0_eleventh, Z1_twelfth)
W12, b12 = _fit_affine(Z0_twelfth, Z1_twelfth)

In [None]:
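# NOTE: stale. W_0_1 ... b_5_1 below are not defined in this notebook (the fits above produce
# W1..W12 / b1..b12); update the names before uncommenting.
# torch.save({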
#     "W_first_first": W_0_1,
#     "b_first_first": b_0_1,
#     "W_first_last": W_1_1,
#     "b_first_last": b_1_1,
#     "W_first_final": W_2_1,
#     "b_first_final": b_2_1,
#     "W_last_last": W_3_1,
#     "b_last_last": b_3_1,
#     "W_last_final": W_4_1,
#     "b_last_final": b_4_1,
#     "W_final_final": W_5_1,
#     "b_final_final": b_5_1
# }, "clip_mnist_transformations.pt")

# Code to Load:
# vals = torch.load("clip_mnist_transformations.pt")
# W_0_1 = vals["W_first_first"]
# b_0_1 = vals["b_first_first"]
# W_1_1 = vals["W_first_last"]
# b_1_1 = vals["b_first_last"]
# W_2_1 = vals["W_first_final"]
# b_2_1 = vals["b_first_final"]
# W_3_1 = vals["W_last_last"]
# b_3_1 = vals["b_last_last"]
# W_4_1 = vals["W_last_final"]
# b_4_1 = vals["b_last_final"]
# W_5_1 = vals["W_final_final"]
# b_5_1 = vals["b_final_final"]

## 6. Augmenting Base CLIP

In [None]:
# For fine-tuned classifier head
class AugmentedCLIP(nn.Module):
    """Base CLIP image classifier whose CLS token can be replaced by an affine map of an earlier layer.

    With ``transform_stage == -1`` (the default), this is the plain CLIP image encoder followed by
    ``classifier``. With ``transform_stage == k`` (1-based):
      1. the first ``k`` residual blocks run;
      2. the CLS token is replaced by ``cls @ W + b``;
      3. the remaining blocks are skipped (the mapped token stands in for the final-layer CLS);
      4. the result goes through ``ln_post``, the optional projection and ``classifier``.

    Embeddings are deliberately not L2-normalised: W and b were fitted on unnormalised CLS tokens,
    so their scale must be preserved. The text tower is unused (image classification only).

    NOTE(review): ``ln_post``/``proj`` come from ``clip_model``. If W/b target a fine-tuned model
    whose ``ln_post``/``proj`` differ, this mixes components; confirm that is intended.

    Args:
        clip_model: OpenAI CLIP model with a ViT visual tower.
        W: (width, width) weight matrix (NumPy array or tensor); required unless stage is -1.
        b: (width,) bias vector; required unless stage is -1.
        transform_stage: 1-based block index whose CLS output is mapped, or None/-1 for no mapping.
        classifier: Head applied to the final embedding. Defaults to a randomly initialised
            10-way linear layer (MNIST has 10 classes).

    Raises:
        ValueError: if ``transform_stage`` is out of range, or W/b are missing for a mapped stage.
    """

    def __init__(self, clip_model, W=None, b=None, transform_stage=None, classifier=None):
        super().__init__()
        self.clip = clip_model
        self.transform_stage = transform_stage if transform_stage is not None else -1

        # Keep W/b on the same device as the CLIP weights.
        model_device = next(clip_model.parameters()).device

        def _as_tensor(value):
            if value is None:
                return None
            return torch.as_tensor(value, dtype=torch.float32, device=model_device)

        # Non-persistent buffers follow .to(...) but stay out of state_dict().
        self.register_buffer("W", _as_tensor(W), persistent=False)
        self.register_buffer("b", _as_tensor(b), persistent=False)

        if self.transform_stage != -1:
            n_blocks = len(self.clip.visual.transformer.resblocks)
            if not 1 <= self.transform_stage <= n_blocks:
                raise ValueError(
                    f"transform_stage must be -1 or in 1..{n_blocks}, got {self.transform_stage}"
                )
            if self.W is None or self.b is None:
                raise ValueError("W and b are required when transform_stage is not -1")

        if classifier is None:
            # Random, untrained head placed on the model's device; its predictions are chance-level.
            classifier = nn.Linear(self.clip.visual.output_dim, 10).to(model_device)
        self.classifier = classifier

    def forward(self, image):
        visual = self.clip.visual

        # Patchify: [B, 3, H, W] -> [B, width, grid, grid] -> [B, grid*grid, width]
        x = visual.conv1(image.to(visual.conv1.weight.device))
        x = x.reshape(x.shape[0], x.shape[1], -1).permute(0, 2, 1)

        # Prepend the class embedding as token 0, add positions, normalise.
        cls_embedding = visual.class_embedding.to(x.dtype) + torch.zeros(
            x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
        )
        x = torch.cat([cls_embedding, x], dim=1)  # [B, 1 + grid*grid, width]
        x = x + visual.positional_embedding.to(x.dtype)
        x = visual.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND

        if self.transform_stage == -1:
            x = visual.transformer(x)
        else:
            for depth, block in enumerate(visual.transformer.resblocks, start=1):
                x = block(x)
                if depth == self.transform_stage:
                    mapped = x[0].to(torch.float32) @ self.W + self.b  # [B, width]
                    # Swap in the mapped CLS token; the patch tokens are left as they are.
                    x = torch.cat([mapped.to(x.dtype).unsqueeze(0), x[1:]], dim=0)
                    break  # the mapped CLS replaces the final-layer CLS, so skip the rest

        # CLS token (row 0 in LND layout) -> ln_post -> optional projection -> classifier.
        final_embed = visual.ln_post(x[0])
        if visual.proj is not None:
            final_embed = final_embed @ visual.proj

        return self.classifier(final_embed)

In [None]:
refer, _ = clip.load("ViT-B/32", device=device)
refer = refer.float()  # run everything in fp32

# Wrap in CLIPClassifier so the checkpoint's "clip.*"/"classifier.*" keys line up, then load it.
# The deepcopy matters here: load_state_dict overwrites the wrapped weights in place.
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
# NOTE(review): the extraction cell loads "best_clip_mnist_0_fp32" to produce the fitting targets,
# while this cell loads "best_clip_mnist_fp32". If these are different checkpoints, W/b were
# fitted against a different model than the one whose head is evaluated here; confirm.
f_t = CLIPClassifier(clip_model=copy.deepcopy(refer)).to(device)
f_t.load_state_dict(torch.load("best_clip_mnist_fp32", map_location=device))

# Fine-tuned reference: the fine-tuned backbone (f_t.clip) with its own head. A deepcopy of
# `refer` would be the pretrained backbone, making this model identical to `base` below.
fine_tuned = AugmentedCLIP(f_t.clip, classifier=f_t.classifier).eval()

# The base model and all augmented variants only read the pretrained weights, so they share
# `refer` instead of each holding its own deep copy of ViT-B/32.
# Base reference: pretrained backbone + the fine-tuned classifier head (not a random head).
base = AugmentedCLIP(refer, classifier=f_t.classifier).eval()

# (W_k, b_k) maps base block-k CLS -> fine-tuned block-12 CLS, for k = 1..12.
layer_maps = [
    (W1, b1), (W2, b2), (W3, b3), (W4, b4), (W5, b5), (W6, b6),
    (W7, b7), (W8, b8), (W9, b9), (W10, b10), (W11, b11), (W12, b12),
]
augmented_models = [
    AugmentedCLIP(refer, W_k, b_k, transform_stage=stage, classifier=f_t.classifier).eval()
    for stage, (W_k, b_k) in enumerate(layer_maps, start=1)
]
(
    first_last, second_last, third_last, fourth_last, fifth_last, sixth_last,
    seventh_last, eighth_last, ninth_last, tenth_last, eleventh_last, twelfth_last,
) = augmented_models

## 7. Evaluating Base vs. Fine-Tuned vs. Augmented

In [None]:
# Augmented variant k maps base block k -> fine-tuned block 12; listed in layer order.
augmented_models = [
    first_last, second_last, third_last, fourth_last, fifth_last, sixth_last,
    seventh_last, eighth_last, ninth_last, tenth_last, eleventh_last, twelfth_last,
]


def calcPred(model, images):
    """Return the arg-max class index that ``model`` predicts for each image in the batch."""
    logits = model(images)
    return logits.argmax(dim=1)


correct_augmented = [0] * len(augmented_models)
correct_base = 0
correct_fine_tuned = 0
total_samples = 0

with torch.no_grad():
    for batch in tqdm(test_loader, desc="Evaluating"):
        images = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)
        total_samples += labels.size(0)

        for k, model in enumerate(augmented_models):
            correct_augmented[k] += (calcPred(model, images) == labels).sum().item()

        correct_base += (calcPred(base, images) == labels).sum().item()
        correct_fine_tuned += (calcPred(fine_tuned, images) == labels).sum().item()

# Keep the per-layer names available to later cells.
(
    correct_1, correct_2, correct_3, correct_4, correct_5, correct_6,
    correct_7, correct_8, correct_9, correct_10, correct_11, correct_12,
) = correct_augmented
accuracies = [count / total_samples for count in correct_augmented]
(
    acc_1, acc_2, acc_3, acc_4, acc_5, acc_6,
    acc_7, acc_8, acc_9, acc_10, acc_11, acc_12,
) = accuracies

acc_base = correct_base / total_samples
acc_fine_tuned = correct_fine_tuned / total_samples

print("\n")
for stage, acc in enumerate(accuracies, start=1):
    print(f"Augmented {stage} - Last Layer CLIP Accuracy: {acc:.4f}")
# The two reference accuracies were previously computed but never reported.
print(f"Base CLIP Accuracy: {acc_base:.4f}")
print(f"Fine-Tuned CLIP Accuracy: {acc_fine_tuned:.4f}")