In [None]:
# Install/upgrade the notebook's dependencies. Only torch is pinned (2.6); every other package resolves to its latest release.
# NOTE(review): unpinned transformers/datasets/torchvision can change between runs — consider pinning them for reproducibility.
!pip install -U torch==2.6 torchvision tqdm
!pip install -U transformers datasets
!pip install --upgrade Pillow

In [None]:
# Force a clean reinstall of Pillow: remove the installed copy, then install the latest release.
# NOTE(review): presumably a workaround for a broken Pillow after the upgrade cell — confirm it is still needed.
!pip uninstall -y Pillow
!pip install Pillow

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
import numpy as np
from datasets import load_dataset, concatenate_datasets
from tqdm import tqdm
import copy
import torch.nn.functional as F
import pandas as pd
import os
from transformers import CLIPVisionModel, CLIPModel, CLIPProcessor
from PIL import ExifTags
import PIL

In [None]:
# Which map to fit between base CLS tokens and fine-tuned CLS tokens.
# When several flags are set, the first enabled one in this order wins:
# Affine, Transformation_Matrix, Translation_Vector.
Affine, Transformation_Matrix, Translation_Vector = False, True, False

# Label of the selected mode ("" when no flag is enabled).
# NOTE(review): this name shadows the builtin `type`; it is kept because other
# cells may read it — rename with care.
type = (
    "Affine_W_b" if Affine
    else "Transformation_Matrix_W" if Transformation_Matrix
    else "Translation_Vector_b" if Translation_Vector
    else ""
)

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Pretrained CLIP checkpoint loaded in float32, plus its processor (used below for both
# image preprocessing and prompt tokenisation).
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", torch_dtype=torch.float32).to(device)
preprocess = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# MNIST from the Hugging Face Hub. 20% of the official training split is held out as a
# validation set; the fixed seed keeps the split identical across runs.
mnist = load_dataset("ylecun/mnist") # https://huggingface.co/datasets/ylecun/mnist
split = mnist["train"].train_test_split(test_size=0.2, seed=66)
train_dataset = split["train"] # 48,000 examples (direct training data from training set)
val_dataset = split["test"] # 12,000 examples (validation set split from training set)
test_dataset = mnist["test"] # 10,000 examples (direct test set)

def clip_collate_fn(batch):
    """Collate dataset rows into CLIP-ready tensors placed on the global `device`.

    Args:
        batch: sequence of rows, each indexable with "image" (an object with
            `.convert("RGB")`) and "label".

    Returns:
        dict with "pixel_values" (output of the global `preprocess`) and
        "labels" (int64 tensor), both moved to `device`.
    """
    rgb_images, targets = [], []
    for row in batch:
        rgb_images.append(row["image"].convert("RGB"))
        targets.append(row["label"])

    processed = preprocess(images=rgb_images, padding=True, return_tensors="pt")
    pixel_values = processed["pixel_values"].to(device)
    label_tensor = torch.tensor(targets, dtype=torch.long).to(device)
    return {"pixel_values": pixel_values, "labels": label_tensor}

# Return plain Python objects for the two columns the collate function reads.
train_dataset.set_format(type="python", columns=["image", "label"])
val_dataset.set_format(type="python", columns=["image", "label"])
test_dataset.set_format(type="python", columns=["image", "label"])

# Only the training loader shuffles; validation/test keep dataset order.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, collate_fn=clip_collate_fn)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, collate_fn=clip_collate_fn)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, collate_fn=clip_collate_fn)

text_prompts = [f"a photo of {i}" for i in range(10)] # The 10 classes fed to CLIP for every image
text_inputs = preprocess(text=text_prompts, padding=True, return_tensors="pt").to(device)
# Prompt embeddings are computed once and reused for every batch at evaluation time.
with torch.no_grad():
  text_feats = model.get_text_features(**text_inputs) # (10, D): one embedding per prompt in the joint image-text space (presumably D=512 for this checkpoint, not 768 — confirm).
  text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True) # unit-normalise so dot products are cosine similarities

In [None]:
class CLIPWithHooks(nn.Module):
  """Wraps a Hugging Face CLIP model and records the CLS token after every vision encoder layer.

  The vision tower is run layer by layer so the intermediate hidden states can be
  captured. All tensors stay in the Hugging Face layout [batch, tokens, dim];
  the encoder layers attend over dim 1, so the tensor must NOT be permuted to
  sequence-first (doing so makes attention mix different images of the batch).
  """

  def __init__(self, clip_model):
    super().__init__()
    self.clip = clip_model
    self.cls_tokens = []      # per-layer CLS tokens from the most recent forward pass
    self.final_embed = None   # image embedding from the most recent forward pass

  def forward(self, images):
    """Run the vision tower on `images` (pixel values).

    Returns:
      dict with
        "cls": list with one [batch, dim] CLS tensor per encoder layer, in layer order;
        "final_embedding": post-layernorm CLS token passed through `visual_projection`
          when the wrapped model has one, otherwise the post-layernorm CLS token itself.
    """
    self.cls_tokens = []
    vision = self.clip.vision_model

    x = vision.embeddings(images)   # [batch, tokens, dim]; token 0 is the CLS token
    x = vision.pre_layrnorm(x)      # (sic) attribute name as spelled in the Hugging Face implementation

    # Run the encoder layers manually so every intermediate CLS token can be recorded.
    for resblock in vision.encoder.layers:
      x = resblock(hidden_states=x, attention_mask=None, causal_attention_mask=None)[0]
      self.cls_tokens.append(x[:, 0, :])

    x = vision.post_layernorm(x)
    pooled_output = x[:, 0, :]

    if self.clip.visual_projection is not None: # linear projection of the CLS token into the joint image-text space
      self.final_embed = self.clip.visual_projection(pooled_output)
    else:
      self.final_embed = pooled_output

    return {
      "cls": self.cls_tokens,
      "final_embedding": self.final_embed,
      }

In [None]:
# Pristine pretrained CLIP; every model below starts from a deep copy of it.
refer = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", torch_dtype=torch.float32).to(device)

# Base model wrapped so per-layer CLS tokens can be collected.
base = CLIPWithHooks(copy.deepcopy(refer)).eval().to(device)

# Fine-tuned model: only the vision tower weights are replaced with the MNIST checkpoint;
# everything else (text tower, projections, logit scale) is still the copy of `refer`.
vision_model = CLIPVisionModel.from_pretrained('tanganke/clip-vit-base-patch32_mnist')
f_t = copy.deepcopy(refer)
f_t.vision_model.load_state_dict(vision_model.vision_model.state_dict())
fine_tuned = CLIPWithHooks(f_t).eval().to(device)

In [None]:
# Ordinary least-squares fit
def leastSquares(Z0, Z1):
    """Return the matrix W that minimises ||Z0 @ W - Z1|| in the least-squares sense.

    Only the solution is returned; the residuals, rank and singular values
    reported by `np.linalg.lstsq` are discarded.
    """
    solution = np.linalg.lstsq(Z0, Z1, rcond=None)
    return solution[0]

def addBiasColumn(Z0):
    """Return Z0 with a trailing column of ones appended (shape (N, D) -> (N, D + 1)).

    The extra column lets a least-squares fit learn a bias term as its last row.
    """
    bias_column = np.ones((Z0.shape[0], 1))  # (N, 1): one entry per row of Z0
    return np.concatenate([Z0, bias_column], axis=1)

def retrieve_avg(Z0, Z1):
    """Fit the map from source features Z0 to target features Z1 for the enabled mode.

    The mode is chosen by the module-level flags, checked in this order:
    `Affine`, `Transformation_Matrix`, `Translation_Vector`.

    Args:
        Z0: (N, D) source features.
        Z1: (N, D') target features.

    Returns:
        A `(W, b)` pair intended to be applied as `Z0 @ W + b`:
          * Affine:                W is the (D, D') matrix, b the (D',) bias.
          * Transformation_Matrix: W is the (D, D') matrix, b is None.
          * Translation_Vector:    W is None, b is the mean offset `mean(Z1 - Z0)`.
          * no flag enabled:       (None, None).
    """
    if Affine:
        # The last row of the solution is the bias learned through the appended ones column.
        solution = leastSquares(addBiasColumn(Z0), Z1)
        return solution[:-1], solution[-1]
    if Transformation_Matrix:
        return leastSquares(Z0, Z1), None
    if Translation_Vector:
        # The offset is additive, so it belongs in the bias slot. Returning it as W
        # would make the consumer compute `cls @ vector`, collapsing the feature axis.
        offset = np.mean(np.asarray(Z1) - np.asarray(Z0), axis=0)
        return None, offset
    return None, None

In [None]:
# Z0[layer]: base-model CLS tokens after encoder layer `layer`, one tensor per batch.
Z0 = {layer: [] for layer in range(12)}

# Fine-tuned CLS tokens after the last encoder layer (index 11): the regression target.
Z1_11 = []

with torch.no_grad():
  for batch in tqdm(train_loader, desc=f"Extracting Train Dataset Vectors"):
    pixels = batch["pixel_values"]
    out_base = base(pixels)
    out_fine_tuned = fine_tuned(pixels)

    # Keep everything on the CPU so the accumulated features do not occupy GPU memory.
    for layer, bucket in Z0.items():
      bucket.append(out_base["cls"][layer].float().cpu())

    Z1_11.append(out_fine_tuned["cls"][11].float().cpu())

Z1_11 = torch.cat(Z1_11).cpu().numpy()

# One fitted map per source layer: base layer `key` -> fine-tuned final layer.
W, b = {}, {}

for key, chunks in Z0.items():
    features = torch.cat(chunks).cpu().numpy()
    W[key], b[key] = retrieve_avg(features, Z1_11)

In [None]:
class AugmentedCLIP(nn.Module):
  """CLIP whose vision tower is cut short at one layer and patched with a linear map.

  The vision encoder runs layer by layer up to and including `transform_stage`.
  There the CLS token is replaced by `cls @ W + b` and the remaining encoder
  layers are skipped, so the map stands in for them (in this notebook W/b are
  fitted against the fine-tuned model's last-layer CLS token). With
  `transform_stage=-1` no layer index matches and the full encoder runs unchanged.

  All tensors stay in the Hugging Face layout [batch, tokens, dim]; the encoder
  layers attend over dim 1, so the tensor must NOT be permuted to sequence-first.

  Args:
    clip: CLIP model providing `vision_model`, `visual_projection`,
      `logit_scale` and `get_text_features`.
    W: optional (dim, dim) numpy array; None means identity.
    b: optional (dim,) numpy array; None means no offset.
    transform_stage: index of the encoder layer after which the map is applied.
  """

  def __init__(self, clip, W=None, b=None, transform_stage=-1):
    super().__init__()
    self.clip = clip # Augmented Model
    self.logit_scale = clip.logit_scale
    # Non-persistent buffers: they follow the module through .to()/.cuda() instead of
    # depending on a global `device`, and they stay out of the state_dict.
    self.register_buffer(
        "W", torch.from_numpy(W.astype(np.float32)) if W is not None else None, persistent=False)
    self.register_buffer(
        "b", torch.from_numpy(b.astype(np.float32)) if b is not None else None, persistent=False)
    self.transform_stage = transform_stage

  def encode_image(self, image):
    """Encode pixel values.

    Returns:
      (image_embeds, cls): the projected image embedding and the CLS token that
      was fed to the post-layernorm (the mapped CLS token when the transform fired).
    """
    vision = self.clip.vision_model
    x = vision.embeddings(image)   # [batch, tokens, dim]; token 0 is the CLS token
    x = vision.pre_layrnorm(x)     # (sic) attribute name as spelled in the Hugging Face implementation

    for i, block in enumerate(vision.encoder.layers):
      x = block(hidden_states=x, attention_mask=None, causal_attention_mask=None)[0]

      if i == self.transform_stage:
        # Apply the map in float32, then cast back so it can be re-inserted into x.
        manipulated = x[:, 0, :].to(torch.float32)
        if self.W is not None:
          manipulated = manipulated @ self.W.to(manipulated.device)
        if self.b is not None:
          manipulated = manipulated + self.b.to(manipulated.device)
        manipulated = manipulated.to(x.dtype)

        # Swap the CLS token for its mapped version; patch tokens are left untouched.
        x = torch.cat([manipulated.unsqueeze(1), x[:, 1:, :]], dim=1)
        break # the map replaces the remaining encoder layers

    twelfth_cls = x[:, 0, :]

    x = vision.post_layernorm(x)

    pooled_output = x[:, 0, :]

    if self.clip.visual_projection is not None:
      image_embeds = self.clip.visual_projection(pooled_output)
    else:
      image_embeds = pooled_output

    return image_embeds, twelfth_cls # Proj, Final CLS Embedding

  def encode_text(self, text):
    """Embed tokenised prompts; `text` must provide "input_ids" and "attention_mask"."""
    return self.clip.get_text_features(input_ids=text["input_ids"], attention_mask=text["attention_mask"])

  def forward(self, image, text):
    """Score every image against every prompt.

    Returns:
      dict with
        "logits": (logits_per_image [num_images, num_texts], logits_per_text [num_texts, num_images]);
        "last_embed": the CLS token returned by `encode_image`.
    """
    image_features, final_embed = self.encode_image(image)
    text_features = self.encode_text(text)

    # normalized features
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    # cosine similarity as logits
    logit_scale = self.logit_scale.exp()
    logits_per_image = logit_scale * image_features @ text_features.t()
    logits_per_text = logits_per_image.t()

    return {
        "logits": (logits_per_image, logits_per_text),
        "last_embed": final_embed,
    }

In [None]:
# Plain (unwrapped) base model for evaluation.
# NOTE: this rebinds `base` and `fine_tuned`, which earlier cells use for the hooked
# wrappers; re-run those cells before re-running the feature extraction.
base = copy.deepcopy(refer)
base = base.eval().to(device)

# Plain fine-tuned model: pretrained CLIP with the MNIST vision tower weights loaded.
f_t = copy.deepcopy(refer)
f_t.vision_model.load_state_dict(vision_model.vision_model.state_dict())
fine_tuned = f_t.eval().to(device)

# aug[i]: base CLIP truncated after vision layer i, with that layer's fitted map applied.
aug = {}

for i in range(12):
  # Built in place rather than through a temporary named `model`, so the CLIP model
  # bound to the global `model` (used to compute the prompt features) is not clobbered.
  aug[i] = AugmentedCLIP(copy.deepcopy(refer), W=W[i], b=b[i], transform_stage=i).eval().to(device)

In [None]:
# Debug check: shows which branch calcPred will take for `fine_tuned`.
# Expected to print False, since `fine_tuned` was rebound to a plain CLIP model rather than an AugmentedCLIP.
print(isinstance(fine_tuned, AugmentedCLIP))

In [None]:
def calcPred(model, images, text_inputs):
  """Return the predicted class index for every image.

  AugmentedCLIP instances are called with (images, text_inputs) and their
  image-to-text logits are used. Any other model is treated as a plain CLIP
  model: its image features are normalised and scored against the
  module-level, pre-normalised `text_feats` (so `text_inputs` is unused there).
  """
  if not isinstance(model, AugmentedCLIP):
    feats = model.get_image_features(pixel_values=images)
    feats = feats / feats.norm(dim=-1, keepdim=True)
    scores = model.logit_scale.exp() * feats @ text_feats.T
    return scores.argmax(dim=1)

  scores = model(images, text_inputs)["logits"][0]
  return scores.argmax(dim=1)

def cosineSimilarity(aug, fine, images, text_inputs):
    """Mean cosine similarity between the "last_embed" outputs of two models.

    Both `aug` and `fine` are called as `model(images, text_inputs)` and must
    return a mapping with a "last_embed" entry of shape [batch, dim].
    Returns a Python float: the per-sample cosine similarity averaged over the batch.
    """
    with torch.no_grad():
        embeds = [net(images, text_inputs)["last_embed"] for net in (aug, fine)]

    # Unit-normalise each row, then the row-wise dot product is the cosine similarity.
    eps = 1e-8
    unit_aug, unit_fine = (F.normalize(e, dim=1, eps=eps) for e in embeds)
    per_sample = (unit_aug * unit_fine).sum(dim=1)
    return per_sample.mean().item()

In [None]:
# Running counts of correct predictions; they are turned into accuracies after the loop.
correct_base = correct_fine_tuned = total_samples = 0

correct = {i: 0 for i in range(12)}    # correct[i]: hits of the model augmented at layer i
sim_cls = {i: [] for i in range(12)}   # per-batch CLS similarities (collection currently disabled)

with torch.no_grad():
  for batch in tqdm(test_loader, desc=f"Evaluating"):
    images = batch["pixel_values"].to(device)
    labels = batch["labels"].to(device)
    total_samples += labels.size(0)

    base_hits = calcPred(base, images, text_inputs) == labels
    fine_tuned_hits = calcPred(fine_tuned, images, text_inputs) == labels
    correct_base += base_hits.sum().item()
    correct_fine_tuned += fine_tuned_hits.sum().item()

    for i in range(12):
        aug_hits = calcPred(aug[i], images, text_inputs) == labels
        correct[i] += aug_hits.sum().item()
        # sim_cls[i].append(cosineSimilarity(aug[i], fine_tuned, images, text_inputs))

In [None]:
# Convert raw counts into accuracies and per-batch similarities into their mean.
# NOTE: this cell overwrites the counts in place, so it must run exactly once after the
# evaluation loop; running it again divides the accuracies by total_samples a second time.
for i in range(12):
  correct[i] = correct[i] / total_samples
  # Similarity collection may be disabled, leaving the list empty. np.mean([]) returns nan
  # and emits a RuntimeWarning, so report nan explicitly in that case.
  sim_cls[i] = np.mean(sim_cls[i]) if np.size(sim_cls[i]) else float("nan")

correct_base = correct_base / total_samples
correct_fine_tuned = correct_fine_tuned / total_samples

In [None]:
# Report accuracy for every augmented model (1-based layer number), then the two reference models.
print(f"Augmented CLIP on Entire MNIST Results")
for layer in range(12):
    layer_accuracy = correct[layer]
    print(f"\tAugmented {layer+1} - Last (12) Layer Accuracy: {layer_accuracy}")
    # print(f"\tAverage Cosine Similarity of CLS Token of Augmented {layer+1} Layer: {sim_cls[layer]:.4f}")
print(f"Base Accuracy: {correct_base:.4f}")
print(f"Fine-Tuned Accuracy: {correct_fine_tuned:.4f}")

Known issue: the manual layer-by-layer forward pass in AugmentedCLIP (and possibly CLIPWithHooks) does not reproduce the model's own forward pass. When the fine-tuned model is evaluated directly, without either wrapper, it achieves the expected fine-tuned accuracy.