# 0. Overview
Author: Darrin O'Brien, email darrinobrien5@gmail.com
1. Preparation
2. Loads the base CLIP model and the CLIP model fine-tuned on MNIST.
3. Separates the dataset splits (including the test set) into 10 different subsets, one per digit label {0, 1, 2, ..., 9}, so each subset contains only a single label, e.g. label 0.
4. Extracts image embedding (CLS token) vectors from both models on these per-label subsets: from base transformer layers {1, 2, ..., 5} and from the fine-tuned model's last transformer layer (layer 12).
5. Applies the learned translation vector b to augment base CLIP in 5 different ways (one per base layer).
6. Evaluates the performance of the augmented models in comparison to the base and fine-tuned models.

## 1. Installations

In [None]:
!pip install torch torchvision
!pip install -U transformers datasets
!pip install fifty regex tqdm
!pip install git+https://github.com/openai/CLIP.git
!pip install matplotlib
!pip install -U pillow
%matplotlib inline

### Extra Installs for Runpod

In [None]:
!pip install --force-reinstall --no-cache-dir scipy datasets # Only needed within runpod environment

In [None]:
!pip install numpy==1.26.4 # only needed for runpod environment

## 2. Imports

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
import clip
import numpy as np
from datasets import load_dataset
from tqdm import tqdm
import matplotlib.pyplot as plt
from PIL import Image
import copy
import torch.nn.functional as F

## 3. Setting up Device + Loading Dataset

In [None]:
# Prefer the GPU when one is visible to PyTorch; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Pretrained CLIP ViT-B/32 and its matching image preprocessing (https://github.com/openai/CLIP).
clip_model, preprocess = clip.load("ViT-B/32", device=device)
clip_model = clip_model.float()  # fp32 weights
clip_model = clip_model.to(device)

# MNIST from the Hugging Face hub (https://huggingface.co/datasets/ylecun/mnist).
mnist = load_dataset("ylecun/mnist")
# Hold out 20% of the official training split (fixed seed) as a validation set.
split = mnist["train"].train_test_split(test_size=0.2, seed=66)

## 4. Wrapper Class for Extracting Embeddings

In [None]:
class CLIPClassifier(nn.Module): # for fine-tuned model
  """CLIP image encoder followed by a linear classification head.

  The attribute names `clip` and `classifier` determine the state-dict keys,
  so checkpoints saved from this wrapper load back into it unchanged.
  """

  def __init__(self, clip_model, num_classes=10):
    """Wrap `clip_model` and add a `num_classes`-way linear head on its image embedding."""
    super().__init__()
    self.clip = clip_model
    embed_dim = self.clip.visual.output_dim
    self.classifier = nn.Linear(embed_dim, num_classes)

  def forward(self, images):
    """Return class logits for a batch of images."""
    return self.classifier(self.clip.encode_image(images))

In [None]:
class CLIPWithHooks(nn.Module):
  """Run CLIP's vision transformer block by block and record the CLS token after each block.

  The forward pass returns the detached CLS token of the first twelve
  transformer blocks, the detached projected image embedding, and the logits
  produced by `classifier_head` on that embedding.
  """

  # Output keys are f"{ordinal}_cls", in block order.
  _LAYER_ORDINALS = (
    "first", "second", "third", "fourth", "fifth", "sixth",
    "seventh", "eighth", "ninth", "tenth", "eleventh", "twelfth",
  )

  def __init__(self, clip_model, classifier_head):
    super().__init__()
    self.clip = clip_model
    self.cls_tokens = []
    self.classifier = classifier_head

  def forward(self, images):
    """Return a dict with per-layer CLS tokens, `final_embed` and `logits`."""
    visual = self.clip.visual
    self.cls_tokens = []  # start fresh on every call

    # Patch embedding: [B, width, grid, grid] -> [B, grid*grid, width]
    # (for ViT-B/32 at 224x224 this is [B, 768, 7, 7] -> [B, 49, 768]).
    patches = visual.conv1(images)
    batch_size, width = patches.shape[0], patches.shape[1]
    tokens = patches.reshape(batch_size, width, -1).permute(0, 2, 1)

    # Prepend the class embedding to every sequence: [B, patches + 1, width].
    cls_slot = visual.class_embedding.to(tokens.dtype) + torch.zeros(
      batch_size, 1, tokens.shape[-1], dtype=tokens.dtype, device=tokens.device
    )
    tokens = torch.cat([cls_slot, tokens], dim=1)
    tokens = tokens + visual.positional_embedding.to(tokens.dtype)
    tokens = visual.ln_pre(tokens)

    # The residual blocks take sequence-first input: [sequence, B, width].
    hidden = tokens.permute(1, 0, 2)

    # Step through the blocks one at a time so the CLS token (sequence
    # position 0) can be captured after each of them.
    for resblock in visual.transformer.resblocks:
      hidden = resblock(hidden)
      self.cls_tokens.append(hidden[0, :, :].detach())

    hidden = hidden.permute(1, 0, 2)  # back to [B, sequence, width]
    pooled = visual.ln_post(hidden[:, 0, :])

    # Optional linear projection of the CLS token into the joint embedding space.
    final_embed = pooled if visual.proj is None else pooled @ visual.proj
    final_embed = final_embed.detach()

    logits = self.classifier(final_embed)

    outputs = {}
    for position, ordinal in enumerate(self._LAYER_ORDINALS):
      outputs[f"{ordinal}_cls"] = self.cls_tokens[position]
    outputs["final_embed"] = final_embed
    outputs["logits"] = logits
    return outputs

In [None]:
# Reference CLIP backbone. This also rebinds `preprocess`, which clip_collate_fn reads.
refer, preprocess = clip.load("ViT-B/32", device=device)
refer = refer.float().to(device)

# Base model: untouched CLIP weights plus a randomly initialised 2-way head.
# Only the per-layer CLS tokens of this model are read during extraction; its logits are not used.
base = CLIPWithHooks(copy.deepcopy(refer), nn.Linear(refer.visual.output_dim, 2))
base = base.eval().to(device)

# Fine-tuned model: restore the checkpoint into the classifier wrapper it was saved from,
# then re-wrap its encoder and trained head so per-layer CLS tokens can be captured.
f_t = CLIPClassifier(clip_model=copy.deepcopy(refer)).to(device)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine (and vice versa).
f_t.load_state_dict(torch.load("best_clip_mnist_fp32.pt", map_location=device))
fine_tuned = CLIPWithHooks(f_t.clip, classifier_head=f_t.classifier)
fine_tuned = fine_tuned.eval().to(device)

In [None]:
def clip_collate_fn(batch):
  """Collate dataset items into a batch of preprocessed images and integer labels.

  Each item is expected to provide a PIL image under "image" and a label under
  "label". Images are converted to RGB and passed through the module-level
  `preprocess`; both output tensors are moved to the module-level `device`.
  """
  pairs = [
    (preprocess(sample["image"].convert("RGB")), sample["label"])
    for sample in batch
  ]

  pixel_values = torch.stack([tensor for tensor, _ in pairs])
  targets = torch.tensor([target for _, target in pairs], dtype=torch.long)

  return {"pixel_values": pixel_values.to(device), "labels": targets.to(device)}

## 5. Setting up Individual Label Test Sets

In [None]:
train_loader, val_loader, test_loader = {}, {}, {}

# Build one DataLoader per digit label (0-9) for each of the train / validation / test partitions.
for i in range(10):
    train_dataset = split["train"].filter(lambda example: example["label"] == i)
    val_dataset = split["test"].filter(lambda example: example["label"] == i)
    test_dataset = mnist["test"].filter(lambda example: example["label"] == i)

    # Optional: shrink the training subset used to learn the transformation/translation, e.g.
    #   train_dataset = train_dataset.select(range(min(600, len(train_dataset))))

    # Return plain Python objects restricted to the two columns the collate function reads.
    for subset in (train_dataset, val_dataset, test_dataset):
        subset.set_format(type="python", columns=["image", "label"])

    # Only the training loader is shuffled.
    train_loader[i] = DataLoader(train_dataset, batch_size=64, shuffle=True, collate_fn=clip_collate_fn)
    val_loader[i] = DataLoader(val_dataset, batch_size=64, shuffle=False, collate_fn=clip_collate_fn)
    test_loader[i] = DataLoader(test_dataset, batch_size=64, shuffle=False, collate_fn=clip_collate_fn)

In [None]:
# Returns Translation Vector b
def retrieve_avg(Z0, Z1):
    """Return the translation vector b: the mean over axis 0 of (Z1 - Z0).

    Z0 and Z1 are array-likes of matching shape, one row per sample.
    """
    source = np.array(Z0)
    target = np.array(Z1)
    return (target - source).mean(axis=0)

In [None]:
b_first, b_second, b_third, b_fourth, b_fifth = {}, {}, {}, {}, {}

# Base-model output key -> dict that stores that layer's translation vector per label.
translation_store = {
    "first_cls": b_first,
    "second_cls": b_second,
    "third_cls": b_third,
    "fourth_cls": b_fourth,
    "fifth_cls": b_fifth,
}

# For Labels 0-9
for i in range(10):
    base_chunks = {key: [] for key in translation_store}  # base CLS tokens, layers 1-5
    target_chunks = []  # fine-tuned CLS tokens after the twelfth block

    with torch.no_grad():
        for batch in tqdm(train_loader[i], desc=f"Extracting Train Set Label = {i} Vectors"):
            images = batch["pixel_values"]

            # Both models see the same batch, so rows stay paired sample by sample.
            out_base = base(images)
            out_fine_tuned = fine_tuned(images)

            for key, chunks in base_chunks.items():
                chunks.append(out_base[key].float())
            target_chunks.append(out_fine_tuned["twelfth_cls"].float())

    target = torch.cat(target_chunks).cpu().numpy()

    # b = mean(fine-tuned twelfth-layer CLS - base layer-k CLS) over this label's training samples.
    for key, store in translation_store.items():
        source = torch.cat(base_chunks[key]).cpu().numpy()
        store[i] = retrieve_avg(source, target)

## 6. Wrapper Classes for Augmentation

In [None]:
class AugmentedCLIP(nn.Module):
    """CLIP vision tower whose CLS token can be affinely transformed after an early block.

    With ``transform_stage == -1`` (the default) the full transformer runs
    unchanged. With ``transform_stage == k`` (1-based) only the first ``k``
    residual blocks run; the CLS token is then mapped to ``cls @ W + b`` and
    the remaining blocks are skipped, so the transformed token is what gets
    layer-normalised, projected and classified.

    Args:
        clip: CLIP model exposing ``visual`` (conv1, class_embedding,
            positional_embedding, ln_pre, transformer, ln_post, proj).
        W: optional square matrix applied to the CLS token (array-like or
            tensor). ``None`` means identity and the multiply is skipped.
        b: optional translation vector added to the CLS token (array-like or
            tensor). ``None`` means no shift.
        transform_stage: 1-based index of the block after which the transform
            is applied, or ``None`` / ``-1`` to disable it.
        classifier: head applied to the final embedding; defaults to a fresh
            10-way linear layer.

    Raises:
        ValueError: if ``transform_stage`` is neither -1 nor a valid block index.
    """

    def __init__(self, clip, W=None, b=None, transform_stage=None, classifier=None):
        super().__init__()
        self.clip = clip
        self.transform_stage = transform_stage if transform_stage is not None else -1

        # Fail early: an out-of-range stage would otherwise never trigger the
        # transform and forward() would crash on an unbound variable.
        if self.transform_stage != -1:
            num_blocks = len(self.clip.visual.transformer.resblocks)
            if not 1 <= self.transform_stage <= num_blocks:
                raise ValueError(
                    f"transform_stage must be -1 or in [1, {num_blocks}], got {self.transform_stage!r}"
                )

        # Non-persistent buffers follow the module through .to()/.cuda() without
        # adding keys to the state dict. They start on the backbone's device.
        anchor_device = self.clip.visual.conv1.weight.device
        self.register_buffer("W", self._as_float_tensor(W, anchor_device), persistent=False)
        self.register_buffer("b", self._as_float_tensor(b, anchor_device), persistent=False)

        self.classifier = classifier if classifier is not None else nn.Linear(self.clip.visual.output_dim, 10)

    @staticmethod
    def _as_float_tensor(value, target_device):
        """Copy an array-like/tensor into a float32 tensor on `target_device` (None passes through)."""
        if value is None:
            return None
        return torch.as_tensor(value).detach().clone().to(device=target_device, dtype=torch.float32)

    def forward(self, image):
        """Return ``logits``, ``final_embed`` and the (possibly transformed) CLS token.

        ``manipulated_cls`` always has shape [batch, width], including for a
        batch of one.
        """
        visual = self.clip.visual
        image = image.to(visual.conv1.weight.device)

        x = visual.conv1(image)
        x = x.reshape(x.shape[0], x.shape[1], -1)
        x = x.permute(0, 2, 1)

        # Prepend the class embedding: [B, patches + 1, width].
        cls_slot = visual.class_embedding.to(x.dtype) + torch.zeros(
            x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
        )
        x = torch.cat([cls_slot, x], dim=1)
        x = x + visual.positional_embedding.to(x.dtype)
        x = visual.ln_pre(x)  # Normalize for Stability

        x = x.permute(1, 0, 2)  # [sequence_length, batch_size, embedding_dim]

        if self.transform_stage == -1:
            x = visual.transformer(x)
            # Index instead of squeeze() so a single-image batch stays 2-D.
            cls_out = x[0, :, :]
        else:
            # Run only the first `transform_stage` blocks; later blocks are skipped on purpose.
            for depth, block in enumerate(visual.transformer.resblocks, start=1):
                x = block(x)
                if depth == self.transform_stage:
                    break

            cls_out = x[0, :, :].to(torch.float32)
            if self.W is not None:
                cls_out = cls_out @ self.W
            if self.b is not None:
                cls_out = cls_out + self.b

            # Replace the CLS row for the whole batch at once.
            x = torch.cat([cls_out.unsqueeze(0), x[1:, :, :]], dim=0)

        x = x.permute(1, 0, 2)  # [batch_size, sequence_length, embedding_dim]

        x = visual.ln_post(x[:, 0, :])

        final_embed = x @ visual.proj if visual.proj is not None else x

        logits = self.classifier(final_embed)

        return {
            "logits": logits,
            "final_embed": final_embed,
            "manipulated_cls": cls_out,
        }

In [None]:
# Fresh fp32 reference backbone for the evaluation wrappers.
refer, _ = clip.load("ViT-B/32", device=device)
refer = refer.float()

# Fine-tuned model: restore the checkpoint into the classifier wrapper, then expose its
# encoder and trained head through AugmentedCLIP with no transform stage (plain forward pass).
f_t = CLIPClassifier(clip_model=copy.deepcopy(refer)).to(device)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine (and vice versa).
f_t.load_state_dict(torch.load("best_clip_mnist_fp32.pt", map_location=device))
fine_tuned = AugmentedCLIP(f_t.clip, classifier=f_t.classifier)
fine_tuned = fine_tuned.eval().to(device)

# Base model: untouched CLIP weights with a randomly initialised (untrained) 10-way head.
base_classifier = nn.Linear(refer.visual.output_dim, 10)
base = AugmentedCLIP(copy.deepcopy(refer), classifier=base_classifier)
base = base.eval().to(device)

In [None]:
aug_1, aug_2, aug_3, aug_4, aug_5 = {}, {}, {}, {}, {}

# AugmentedCLIP only reads the backbone weights, so all 50 augmented models can share one
# frozen copy of base CLIP instead of holding a separate deep copy each. The copy is detached
# from `refer`, so later changes to `refer` cannot leak into the augmented models.
shared_backbone = copy.deepcopy(refer).eval().to(device)

# (transform stage, translation vectors for that stage, destination dict)
stage_specs = (
    (1, b_first, aug_1),
    (2, b_second, aug_2),
    (3, b_third, aug_3),
    (4, b_fourth, aug_4),
    (5, b_fifth, aug_5),
)

# Creating Augmented Models for each of the labels
for i in range(10):
    for stage, translations, registry in stage_specs:
        # Every augmented model reuses the fine-tuned classification head.
        model = AugmentedCLIP(shared_backbone, b=translations[i], transform_stage=stage, classifier=f_t.classifier)
        registry[i] = model.eval().to(device)


## 7. Evaluating Augmented Performance

In [None]:
def calcPred(model, images):
    """Return the predicted class index per image (argmax over the model's "logits" output)."""
    outputs = model(images)
    return outputs["logits"].argmax(dim=1)

def cosineSimilarity(aug, fine, images):
    """Return the batch-mean cosine similarity between two models' outputs on `images`.

    Both models are called on the same batch and must return a dict with
    "final_embed" and "manipulated_cls". The result is a pair of Python floats:
    (similarity of the final embeddings, similarity of the CLS tokens).
    """
    eps = 1e-8  # Prevent NaNs

    def unit_rows(tensor):
        return F.normalize(tensor, dim=1, eps=eps)

    aug_out = aug(images)
    aug_embed, aug_cls = aug_out["final_embed"], aug_out["manipulated_cls"]

    fine_out = fine(images)
    fine_embed, fine_cls = fine_out["final_embed"], fine_out["manipulated_cls"]

    aug_embed, fine_embed, aug_cls, fine_cls = (
        unit_rows(tensor) for tensor in (aug_embed, fine_embed, aug_cls, fine_cls)
    )

    # Row-wise dot product of unit vectors is the cosine similarity per sample.
    cos_sim_final = torch.sum(aug_embed * fine_embed, dim=1).mean().item()
    cos_sim_cls = torch.sum(aug_cls * fine_cls, dim=1).mean().item()
    return cos_sim_final, cos_sim_cls

In [None]:
accuracy_1, accuracy_2, accuracy_3, accuracy_4, accuracy_5 = {}, {}, {}, {}, {}

accuracy_base = {}
accuracy_fine_tuned = {}

co_sim_final_1, co_sim_cls_1 = {}, {}
co_sim_final_2, co_sim_cls_2 = {}, {}
co_sim_final_3, co_sim_cls_3 = {}, {}
co_sim_final_4, co_sim_cls_4 = {}, {}
co_sim_final_5, co_sim_cls_5 = {}, {}

# Per transform stage: (augmented models by label, accuracy store,
# final-embedding similarity store, CLS-token similarity store).
stage_table = (
    (aug_1, accuracy_1, co_sim_final_1, co_sim_cls_1),
    (aug_2, accuracy_2, co_sim_final_2, co_sim_cls_2),
    (aug_3, accuracy_3, co_sim_final_3, co_sim_cls_3),
    (aug_4, accuracy_4, co_sim_final_4, co_sim_cls_4),
    (aug_5, accuracy_5, co_sim_final_5, co_sim_cls_5),
)


def _replay(outputs):
    """Return a stand-in "model" that ignores its input and yields the precomputed outputs.

    calcPred and cosineSimilarity only call their model argument, so this lets
    one forward pass per model feed both metrics.
    """
    return lambda _images: outputs


for i in range(10):
    num_stages = len(stage_table)
    correct = [0] * num_stages
    final_sim_sum = [0.0] * num_stages
    cls_sim_sum = [0.0] * num_stages
    correct_base = 0
    correct_fine_tuned = 0
    total_samples = 0

    with torch.no_grad():
        for batch in tqdm(test_loader[i], desc=f"Evaluating {i}'s"):
            images = batch["pixel_values"].to(device)
            labels = batch["labels"].to(device)
            batch_size = labels.size(0)
            total_samples += batch_size

            # The fine-tuned model runs once per batch and is reused for every comparison.
            fine_outputs = _replay(fine_tuned(images))

            correct_base += (calcPred(base, images) == labels).sum().item()
            correct_fine_tuned += (calcPred(fine_outputs, images) == labels).sum().item()

            for k, (models, _, _, _) in enumerate(stage_table):
                aug_outputs = _replay(models[i](images))
                correct[k] += (calcPred(aug_outputs, images) == labels).sum().item()

                # cosineSimilarity returns batch means. Weighting by batch size makes the
                # final figure a per-sample mean even when the last batch is smaller.
                cs_final, cs_cls = cosineSimilarity(aug_outputs, fine_outputs, images)
                final_sim_sum[k] += cs_final * batch_size
                cls_sim_sum[k] += cs_cls * batch_size

    for k, (_, accuracy, sim_final, sim_cls) in enumerate(stage_table):
        accuracy[i] = correct[k] / total_samples
        sim_final[i] = final_sim_sum[k] / total_samples
        sim_cls[i] = cls_sim_sum[k] / total_samples

    accuracy_base[i] = correct_base / total_samples
    accuracy_fine_tuned[i] = correct_fine_tuned / total_samples

In [None]:
# Result dictionaries in stage order, paired with the ordinal used in the report text.
stage_ordinals = ("1st", "2nd", "3rd", "4th", "5th")
stage_accuracies = (accuracy_1, accuracy_2, accuracy_3, accuracy_4, accuracy_5)
stage_final_sims = (co_sim_final_1, co_sim_final_2, co_sim_final_3, co_sim_final_4, co_sim_final_5)
stage_cls_sims = (co_sim_cls_1, co_sim_cls_2, co_sim_cls_3, co_sim_cls_4, co_sim_cls_5)

# One report per digit label: accuracies first, then the two cosine-similarity groups.
for i in range(10):
    print(f"Augmented CLIP on MNIST Label {i} Results")
    for ordinal, accuracy in zip(stage_ordinals, stage_accuracies):
        print(f"\tAugmented {ordinal} - Last Layer Accuracy: {accuracy[i]:.4f}")
    print(f"\tBase Accuracy: {accuracy_base[i]:.4f}")
    print(f"\tFine-Tuned Accuracy: {accuracy_fine_tuned[i]:.4f}")
    print("\n")
    for ordinal, similarity in zip(stage_ordinals, stage_final_sims):
        print(f"\tAverage Cosine Similarity for final embedding of Augmented {ordinal} Layer: {similarity[i]:.4f}")
    print("\n")
    for ordinal, similarity in zip(stage_ordinals, stage_cls_sims):
        print(f"\tAverage Cosine Similarity for CLS token of Augmented {ordinal} Layer: {similarity[i]:.4f}")
    print("\n\n\n")

In [None]:
# Ask PyTorch to release cached GPU memory it is no longer using now that evaluation is done.
torch.cuda.empty_cache()