# 0. Overview
Author: Darrin O'Brien, email darrinobrien5@gmail.com

1. Preparation.
2. Loads the base CLIP model and the CLIP model fine-tuned on MNIST.
3. Loads and formats the entire MNIST dataset, with an option to shrink the training set to a class-balanced subset.
4. Extracts CLS-token embedding vectors from both models on the training set: base transformer layers {1, 3, 5, 7, 8, 9, 11, 12} (1-indexed) and the fine-tuned model's last transformer layer (layer 12).
5. Applies the learned transformation (matrix W, translation vector b, or affine W and b) to augment base CLIP at each of those 8 layers.
6. Evaluates the performance of the augmented models in comparison to the base and fine-tuned models.

## 1. Installations

In [None]:
!pip install torch torchvision
!pip install -U transformers datasets
!pip install fifty regex tqdm
!pip install git+https://github.com/openai/CLIP.git
!pip install matplotlib
!pip install -U pillow
!pip install pandas

### 1. Runpod Installs Only

In [None]:
# RunPod-only: force a fresh, uncached reinstall of scipy and datasets (skip elsewhere).
!pip install --force-reinstall --no-cache-dir scipy datasets # Only needed within runpod environment

In [None]:
# RunPod-only: pin numpy to 1.26.4 (a 1.x release).
# NOTE(review): presumably keeps packages built against numpy 1.x working — confirm.
!pip install numpy==1.26.4 # only needed for runpod environment

## 2. Imports

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
import clip
import numpy as np
from datasets import load_dataset, concatenate_datasets
from tqdm import tqdm
import copy
import torch.nn.functional as F
import pandas as pd
import os

## 3. Setting up Device + Entire Test Set

In [None]:
# Select which mapping retrieve_avg() fits from base CLS tokens to fine-tuned CLS tokens.
# If several flags are True, the first in the order Affine, Transformation_Matrix,
# Translation_Vector wins; if none is True, `type` stays "".
Affine = False
Transformation_Matrix = True
Translation_Vector = False

# Label used in the results folder and file names.
# NOTE: this rebinds (shadows) Python's builtin `type` for the rest of the notebook.
_type_by_flag = [
    (Affine, "Affine_W_b"),
    (Transformation_Matrix, "Transformation_Matrix_W"),
    (Translation_Vector, "Translation_Vector_b"),
]
type = next((label for enabled, label in _type_by_flag if enabled), "")

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): clip_model itself is not referenced by later cells in this notebook
# (models are re-loaded from clip.load further down); kept to preserve the namespace.
clip_model, preprocess = clip.load("ViT-B/32", device=device) # https://github.com/openai/CLIP
clip_model = clip_model.float().to(device) # For fp-32 precision
mnist = load_dataset("ylecun/mnist") # https://huggingface.co/datasets/ylecun/mnist
split = mnist["train"].train_test_split(test_size=0.2, seed=66)

train_dataset = split["train"] # 48,000 examples (direct training data from training set)
val_dataset = split["test"] # 12,000 examples (validation set split from training set)
test_dataset = mnist["test"] # 10,000 examples (direct test set)

# Per-class caps for the training subset used to fit the transformation/translation:
# size_nums = [inf, 1, 6, 11, ..., 96]. size_nums[size_indice] is the maximum number of
# examples kept per digit; inf (index 0) keeps the whole training split.
size_nums = [i for i in range(1, 100, 5)]
size_nums.insert(0, float('inf'))
size_indice = 0
train_size = 0

# Build a class-balanced subset: the first `per_class_cap` examples of each digit, grouped
# digit 0 first, then 1, ..., 9 (the same rows and order that per-digit filter+select gives).
# The "label" column is read once. Ten separate `filter` passes would instead hand every
# full example, image included, to the predicate each time.
per_class_cap = size_nums[size_indice]
indices_by_digit = {digit: [] for digit in range(10)}
for row, label in enumerate(train_dataset["label"]):
    if label in indices_by_digit:
        indices_by_digit[label].append(row)

selected_rows = []
for digit in range(10):
    rows = indices_by_digit[digit]
    num = min(per_class_cap, len(rows))  # an int even when the cap is inf
    selected_rows.extend(rows[:num])
    train_size += num

train_dataset = train_dataset.select(selected_rows)

train_dataset.set_format(type="python", columns=["image", "label"])
val_dataset.set_format(type="python", columns=["image", "label"])
test_dataset.set_format(type="python", columns=["image", "label"])

## 4. Wrapper Class for Extracting Embeddings

In [None]:
class CLIPClassifier(nn.Module):  # for fine-tuned model
    """CLIP image encoder followed by a linear classification head.

    The attribute names ``clip`` and ``classifier`` set the key prefixes in
    ``state_dict()``, so they must match the checkpoint this wrapper loads.
    """

    def __init__(self, clip_model, num_classes=10):
        super().__init__()
        self.clip = clip_model
        feature_dim = self.clip.visual.output_dim
        self.classifier = nn.Linear(feature_dim, num_classes)

    def forward(self, images):
        """Return class logits for ``images`` (one row per image)."""
        return self.classifier(self.clip.encode_image(images))

In [None]:
# Returns All CLS Tokens + Final Embedding + Logits
class CLIPWithHooks(nn.Module):
    """Re-runs CLIP's visual forward pass by hand to collect every block's CLS token.

    No forward hooks are registered: the residual blocks are called one by one and the
    CLS token (sequence position 0) is recorded after each block.

    forward() returns a dict:
        "cls":    list with one detached [B, width] tensor per residual block,
                  taken before ``ln_post``.
        "logits": ``classifier_head`` applied to the detached final embedding.
    """

    def __init__(self, clip_model, classifier_head):
        super().__init__()
        self.clip = clip_model
        self.cls_tokens = []  # refreshed on every forward()
        self.classifier = classifier_head

    def forward(self, images):
        self.cls_tokens = []
        visual = self.clip.visual

        # Patch embedding: conv1 gives [B, width, grid, grid]; flatten to [B, patches, width].
        tokens = visual.conv1(images)
        batch, width = tokens.shape[0], tokens.shape[1]
        tokens = tokens.reshape(batch, width, -1).permute(0, 2, 1)

        # Prepend the learned class embedding as token 0, then add positions and normalise.
        cls_slot = visual.class_embedding.to(tokens.dtype) + torch.zeros(
            batch, 1, tokens.shape[-1], dtype=tokens.dtype, device=tokens.device
        )
        tokens = torch.cat([cls_slot, tokens], dim=1)
        tokens = tokens + visual.positional_embedding.to(tokens.dtype)
        tokens = visual.ln_pre(tokens)

        # The residual blocks are fed [sequence, batch, width].
        tokens = tokens.permute(1, 0, 2)
        for block in visual.transformer.resblocks:
            tokens = block(tokens)
            self.cls_tokens.append(tokens[0, :, :].detach())

        pooled = visual.ln_post(tokens.permute(1, 0, 2)[:, 0, :])
        # Optional linear projection to the joint embedding space.
        final_embed = pooled if visual.proj is None else pooled @ visual.proj
        final_embed = final_embed.detach()

        return {
            "cls": list(self.cls_tokens),
            "logits": self.classifier(final_embed),
        }

In [None]:
# Reference CLIP ViT-B/32 in fp32; both models below start from deep copies of it.
refer, preprocess = clip.load("ViT-B/32", device=device)
refer = refer.float().to(device)

# Base encoder. Only its per-layer CLS tokens are read when extracting embeddings, so the
# 2-way head is an untrained placeholder whose logits are not used there.
base = CLIPWithHooks(copy.deepcopy(refer), nn.Linear(refer.visual.output_dim, 2))
base = base.eval().to(device)

# Fine-tuned encoder + head. CLIPClassifier gives the checkpoint's key layout
# ("clip.*", "classifier.*"). map_location lets a checkpoint saved from a GPU run load
# on a CPU-only machine.
f_t = CLIPClassifier(clip_model=copy.deepcopy(refer)).to(device)
f_t.load_state_dict(torch.load("best_clip_mnist_fp32.pt", map_location=device))
fine_tuned = CLIPWithHooks(f_t.clip, classifier_head=f_t.classifier) # Load in fine-tuned model with fine-tuned visual encoder
fine_tuned = fine_tuned.eval().to(device)

In [None]:
def clip_collate_fn(batch):
    """Collate dataset rows into CLIP-ready tensors placed on ``device``.

    Each row needs an "image" (a PIL image, converted to RGB and passed through CLIP's
    ``preprocess``) and an integer "label".

    Returns:
        dict with "pixel_values" (stacked preprocessed images) and "labels" (int64
        tensor), both already moved to ``device``.

    NOTE(review): moving tensors to CUDA inside collate_fn presumably only works with
    DataLoader num_workers=0 (the default used here) — confirm before adding workers.
    """
    pixel_values = torch.stack([preprocess(row["image"].convert("RGB")) for row in batch])
    targets = torch.tensor([row["label"] for row in batch], dtype=torch.long)
    return {
        "pixel_values": pixel_values.to(device),
        "labels": targets.to(device),
    }

# Batches of 64 preprocessed images; only the training loader is shuffled.
_loader_options = {"batch_size": 64, "collate_fn": clip_collate_fn}
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_options)

## 5. Extracting Image Embedding Vectors

In [None]:
# Calculates Least Squares
def leastSquares(Z0, Z1):
    """Return X minimising ||Z0 @ X - Z1|| in the least-squares sense.

    Args:
        Z0: 2-D array of source vectors, one per row.
        Z1: array of target vectors, row-aligned with Z0.

    Returns:
        The solution array from ``numpy.linalg.lstsq`` (residuals, rank and singular
        values are discarded).
    """
    solution, _residuals, _rank, _singular_values = np.linalg.lstsq(Z0, Z1, rcond=None)
    return solution


def addBiasColumn(Z0):
    """Append a column of ones to Z0, so a least-squares fit also learns a bias row."""
    ones = np.ones((Z0.shape[0], 1))  # shape (N, 1): one bias entry per row
    return np.hstack([Z0, ones])


def retrieve_avg(Z0, Z1, mode=None):
    """Fit a mapping from source vectors Z0 to target vectors Z1, returned as (W, b).

    The pair is meant to be applied as ``z0 @ W + b``. A None component stands for the
    identity (W) or zero (b).

    Args:
        Z0: [N, D] source vectors (e.g. base-model CLS tokens), row-aligned with Z1.
        Z1: [N, D_out] target vectors (e.g. fine-tuned CLS tokens).
        mode: "Affine_W_b", "Transformation_Matrix_W" or "Translation_Vector_b".
            None (default) picks the mode from the notebook flags Affine,
            Transformation_Matrix and Translation_Vector, in that priority order.

    Returns:
        Affine: (W, b) from one least-squares fit with a bias column.
        Transformation_Matrix: (W, None).
        Translation_Vector: (None, mean(Z1 - Z0)).
        (None, None) when mode is None and no flag is set.

    Raises:
        ValueError: if ``mode`` is given but is not one of the labels above.
    """
    if mode is None:
        if Affine:
            mode = "Affine_W_b"
        elif Transformation_Matrix:
            mode = "Transformation_Matrix_W"
        elif Translation_Vector:
            mode = "Translation_Vector_b"
        else:
            return None, None

    if mode == "Affine_W_b":
        full = leastSquares(addBiasColumn(Z0), Z1)
        return full[:-1], full[-1]  # the last solution row multiplies the ones column: the bias
    if mode == "Transformation_Matrix_W":
        return leastSquares(Z0, Z1), None
    if mode == "Translation_Vector_b":
        # A pure translation is the bias term. Returning it as b (W=None -> identity) makes
        # `z0 @ W + b` equal `z0 + mean(Z1 - Z0)`. Returned as W, the vector turns
        # `cls @ W` into a per-row dot product, which then fails to broadcast with b.
        diff = np.asarray(Z1) - np.asarray(Z0)
        return None, diff.mean(axis=0)
    raise ValueError(
        f"Unknown mode {mode!r}; expected 'Affine_W_b', "
        "'Transformation_Matrix_W' or 'Translation_Vector_b'"
    )

In [None]:
# Base-encoder residual blocks (0-indexed) whose CLS tokens get mapped onto the
# fine-tuned encoder's last-block (index 11) CLS token.
_probe_layers = [0, 2, 4, 6, 7, 8, 10, 11]

Z0 = {layer: [] for layer in _probe_layers}  # layer -> list of per-batch CLS tensors (CPU, fp32)
_z1_batches = []  # fine-tuned block-11 CLS tensors, one per batch

with torch.no_grad():
    for batch in tqdm(train_loader, desc=f"Extracting Train Dataset Vectors"):
        pixels = batch["pixel_values"]
        base_cls = base(pixels)["cls"]
        tuned_cls = fine_tuned(pixels)["cls"]
        for layer in _probe_layers:
            Z0[layer].append(base_cls[layer].float().cpu())
        # Same batch, same row order as Z0, so rows stay aligned for the fit.
        _z1_batches.append(tuned_cls[11].float().cpu())

Z1_twelfth_lsr = torch.cat(_z1_batches).cpu().numpy()

# Fit one (W, b) pair per probed base layer against the fine-tuned targets.
W = {}
b = {}
for layer, chunks in Z0.items():
    W[layer], b[layer] = retrieve_avg(torch.cat(chunks).cpu().numpy(), Z1_twelfth_lsr)

## 6. Wrapper Classes for Augmentation

In [None]:
class AugmentedCLIP(nn.Module):
    """CLIP image encoder that can swap its CLS token for ``cls @ W + b`` partway through.

    * ``transform_stage == -1`` (default): the whole visual transformer runs unchanged.
    * ``transform_stage == k``: only residual blocks ``0..k`` run. The CLS token produced
      by block ``k`` is replaced with ``cls @ W + b`` and goes straight to ``ln_post``,
      ``proj`` and ``classifier``; the blocks after ``k`` are skipped.

    Args:
        clip: CLIP model whose ``visual`` tower is used.
        W: optional numpy matrix multiplied on the right of the CLS token; identity if None.
        b: optional numpy vector added after the product; zeros if None.
        transform_stage: 0-based residual-block index, or None / -1 for no transformation.
        classifier: head applied to the final embedding. If None, a new untrained
            10-way ``nn.Linear`` is created.

    Raises:
        ValueError: if ``transform_stage`` is neither -1 nor a valid block index.
    """

    def __init__(self, clip, W=None, b=None, transform_stage=None, classifier=None):
        super().__init__()
        self.clip = clip
        # Non-persistent buffers follow .to()/.cuda() with the module but stay out of
        # state_dict(), so checkpoints of the wrapped model are unaffected.
        self.register_buffer(
            "W", torch.from_numpy(W.astype(np.float32)) if W is not None else None, persistent=False
        )
        self.register_buffer(
            "b", torch.from_numpy(b.astype(np.float32)) if b is not None else None, persistent=False
        )
        self.transform_stage = transform_stage if transform_stage is not None else -1
        num_blocks = len(self.clip.visual.transformer.resblocks)
        if self.transform_stage != -1 and not 0 <= self.transform_stage < num_blocks:
            raise ValueError(
                f"transform_stage must be -1 or in [0, {num_blocks - 1}], got {self.transform_stage}"
            )
        self.classifier = classifier if classifier is not None else nn.Linear(self.clip.visual.output_dim, 10)

    def forward(self, image):
        """Run the (optionally augmented) encoder.

        Returns:
            dict with "logits" (classifier output) and "manipulated_cls". The latter is
            the transformed CLS token when a transform stage is set, otherwise the last
            block's CLS token (before ``ln_post``). Both keep the batch dimension, even
            for a batch of one.
        """
        visual = self.clip.visual
        # Put the input on whatever device the encoder lives on.
        x = visual.conv1(image.to(visual.conv1.weight.device))
        x = x.reshape(x.shape[0], x.shape[1], -1)  # [B, width, patches]
        x = x.permute(0, 2, 1)  # [B, patches, width]

        # Prepend the learned class embedding as token 0 of every sequence.
        cls_slot = visual.class_embedding.to(x.dtype) + torch.zeros(
            x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
        )
        x = torch.cat([cls_slot, x], dim=1)  # [B, 1 + patches, width]
        x = x + visual.positional_embedding.to(x.dtype)
        x = visual.ln_pre(x)  # Normalize for Stability

        x = x.permute(1, 0, 2)  # [sequence, B, width]: the layout the blocks are called with

        if self.transform_stage == -1:
            x = visual.transformer(x)
            cls_out = x[0, :, :]  # [B, width]; no squeeze(), so B == 1 keeps its batch dim
        else:
            for idx, block in enumerate(visual.transformer.resblocks):
                x = block(x)
                if idx == self.transform_stage:
                    break  # later blocks are intentionally skipped
            else:
                raise ValueError(f"transform_stage {self.transform_stage} is not a valid block index")

            cls = x[0, :, :].to(torch.float32)
            if self.W is not None:
                W = self.W.to(cls.device)
            else:
                W = torch.eye(cls.shape[-1], device=cls.device, dtype=cls.dtype)
            if self.b is not None:
                b = self.b.to(cls.device)
            else:
                b = torch.zeros(cls.shape[-1], device=cls.device, dtype=cls.dtype)
            cls_out = cls @ W + b
            # Replace the CLS token (sequence position 0) and keep the patch tokens.
            x = torch.cat([cls_out.unsqueeze(0), x[1:, :, :]], dim=0)

        x = x.permute(1, 0, 2)  # [B, sequence, width]
        x = visual.ln_post(x[:, 0, :])
        final_embed = x @ visual.proj if visual.proj is not None else x
        logits = self.classifier(final_embed)

        return {
            "logits": logits,
            "manipulated_cls": cls_out,
        }

In [None]:
# Fresh reference CLIP ViT-B/32 in fp32 for the augmentation experiments.
refer, _ = clip.load("ViT-B/32", device=device)
refer = refer.float()

# Fine-tuned encoder + head, run without augmentation (transform_stage defaults to -1).
# map_location lets a checkpoint saved from a GPU run load on a CPU-only machine.
f_t = CLIPClassifier(clip_model=copy.deepcopy(refer)).to(device) # Wrap in classifer to retrieve classifier head
f_t.load_state_dict(torch.load("best_clip_mnist_fp32.pt", map_location=device))
fine_tuned = AugmentedCLIP(f_t.clip, classifier=f_t.classifier)
fine_tuned = fine_tuned.eval().to(device)

# Base encoder with a freshly initialised (untrained, unseeded) 10-way head, so its
# reported accuracy reflects an untrained classifier on base CLIP features.
base_classifier = nn.Linear(refer.visual.output_dim, 10)
base = AugmentedCLIP(copy.deepcopy(refer), classifier=base_classifier)
base = base.eval().to(device)

In [None]:
# One augmented model per probed base layer (8 in total: 0, 2, 4, 6, 7, 8, 10, 11).
# Each starts from a fresh copy of the reference CLIP, applies that layer's (W, b) to the
# CLS token after the given block, and shares the fine-tuned classifier head.
aug = {
    layer: AugmentedCLIP(
        copy.deepcopy(refer),
        W=W[layer],
        b=b[layer],
        transform_stage=layer,
        classifier=f_t.classifier,
    ).eval().to(device)
    for layer in [0, 2, 4, 6, 7, 8, 10, 11]
}

## 7. Evaluating Augmented Performance

In [None]:
def calcPred(model, images):
    """Return predicted class indices: argmax over dim 1 of ``model(images)["logits"]``."""
    return torch.argmax(model(images)["logits"], dim=1)

def cosineSimilarity(aug, fine, images):
    """Mean cosine similarity between two models' CLS outputs on the same batch.

    Both models must return a dict whose "manipulated_cls" entry is a [B, D] tensor
    (for an unaugmented model this is just its final CLS token).

    Returns:
        Python float: the batch mean of per-sample cosine similarities.
    """
    eps = 1e-8  # keeps F.normalize finite for zero-norm vectors
    aug_cls = F.normalize(aug(images)["manipulated_cls"], dim=1, eps=eps)
    fine_cls = F.normalize(fine(images)["manipulated_cls"], dim=1, eps=eps)
    per_sample = torch.sum(aug_cls * fine_cls, dim=1)
    return per_sample.mean().item()

In [None]:
# Augmented layers to evaluate (must match the keys of `aug`).
eval_layers = [0, 2, 4, 6, 7, 8, 10, 11]
cos_eps = 1e-8  # same eps as cosineSimilarity(): keeps normalisation finite for zero vectors

correct_base = 0
correct_fine_tuned = 0
total_samples = 0

correct = {i: 0 for i in eval_layers}
# Running sum of per-sample cosine similarities, divided by total_samples at the end.
# This weights the short final batch by its size; a mean of per-batch means over-weights it.
sim_cls = {i: 0.0 for i in eval_layers}

with torch.no_grad():
    for batch in tqdm(test_loader, desc=f"Evaluating"):
        images = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)
        total_samples += labels.size(0)

        correct_base += (base(images)["logits"].argmax(dim=1) == labels).sum().item()

        # Run the fine-tuned model once per batch and reuse it for every layer's comparison
        # (previously it was re-run inside cosineSimilarity for each augmented layer).
        out_fine = fine_tuned(images)
        correct_fine_tuned += (out_fine["logits"].argmax(dim=1) == labels).sum().item()
        fine_cls = F.normalize(out_fine["manipulated_cls"], dim=1, eps=cos_eps)

        for i in eval_layers:
            out_aug = aug[i](images)  # one forward pass serves both accuracy and similarity
            correct[i] += (out_aug["logits"].argmax(dim=1) == labels).sum().item()
            aug_cls = F.normalize(out_aug["manipulated_cls"], dim=1, eps=cos_eps)
            sim_cls[i] += (aug_cls * fine_cls).sum(dim=1).sum().item()

for i in eval_layers:
    correct[i] = correct[i] / total_samples
    sim_cls[i] = sim_cls[i] / total_samples

correct_base = correct_base / total_samples
correct_fine_tuned = correct_fine_tuned / total_samples

In [None]:
# Assemble the whole report first, then print it in one call (same text, line for line).
report = ["Augmented CLIP on Entire MNIST Results"]
for layer in [0, 2, 4, 6, 7, 8, 10, 11]:
    report.append(f"\tAugmented {layer+1} - Last (12) Layer Accuracy: {correct[layer]}")
    report.append(f"\tAverage Cosine Similarity of CLS Token of Augmented {layer+1} Layer: {sim_cls[layer]:.4f}")
report.append(f"Base Accuracy: {correct_base:.4f}")
report.append(f"Fine-Tuned Accuracy: {correct_fine_tuned:.4f}")
print("\n".join(report))

## 8. Saving Results

In [None]:
folder = f"./Entire_{type}"
os.makedirs(folder, exist_ok=True)

indices = list(W.keys())

# NOTE: to_csv writes each W/b cell as str(ndarray). NumPy abbreviates large arrays with
# "..." in that text, so the CSV copy of a real-sized W (and b) is lossy. The exact
# arrays are saved alongside in an .npz file below.
data = {
    'Train_Data_Size': [train_size]*len(indices),
    "Transformation": [type if type != "" else "None"] * len(indices),
    'W': [W[i] for i in indices], # data.loc[0, "W"] -> First layer transformation
    'b': [b[i] for i in indices],
    'Accuracy': [correct[i] for i in indices],
    "Co_Sim_CLS": [sim_cls[i] for i in indices],
}

df = pd.DataFrame(data, index=indices)

name = f"{type}_Entire_Size_{train_size}_Augmentation_Results.csv"
path = os.path.join(folder, name)
df.to_csv(path)

# Full-precision W/b per augmented layer, keyed "W_<layer>" / "b_<layer>"; None entries
# (e.g. b for a pure matrix fit) are skipped.
arrays = {}
for i in indices:
    if W[i] is not None:
        arrays[f"W_{i}"] = np.asarray(W[i])
    if b[i] is not None:
        arrays[f"b_{i}"] = np.asarray(b[i])
if arrays:
    np.savez(os.path.join(folder, f"{type}_Entire_Size_{train_size}_W_b.npz"), **arrays)

In [None]:
# Hand cached-but-unused GPU memory back from PyTorch's caching allocator (nothing to do on CPU).
if torch.cuda.is_available():
    torch.cuda.empty_cache()