##1. Installations

In [None]:
!pip install torch torchvision
!pip install -U transformers datasets
!pip install fifty regex tqdm
!pip install git+https://github.com/openai/CLIP.git
!pip install matplotlib
!pip install -U pillow
%matplotlib inline

1.1 Possible Installs - Runpod only

In [None]:
!pip install --force-reinstall --no-cache-dir scipy datasets # Only needed within runpod environment

In [None]:
!pip install numpy==1.26.4 # only needed for runpod environment

##2. Imports

In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader
import clip
import numpy as np
from datasets import load_dataset
from tqdm import tqdm
import matplotlib.pyplot as plt
from PIL import Image

##3. Setting up Device + Test Set

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
mnist = load_dataset("ylecun/mnist") # https://huggingface.co/datasets/ylecun/mnist
test_dataset = mnist["test"] # 10,000 examples (direct test set)

In [3]:
test_dataset.set_format(type="python", columns=["image", "label"])

##4. Wrapping Models & Prepping Test Set

In [4]:
class CLIPClassifier(nn.Module):
  def __init__(self, clip_model, num_classes=10):
    super().__init__()
    self.clip = clip_model
    self.classifier = nn.Linear(self.clip.visual.output_dim, num_classes)

  def forward(self, images):
    image_features = self.clip.encode_image(images)
    logits = self.classifier(image_features)
    return logits

In [5]:
base_CLIP, preprocess = clip.load("ViT-B/32", device=device)
model = CLIPClassifier(clip_model=base_CLIP).to(device)
model = model.float()

best_CLIP, _ = clip.load("ViT-B/32", device=device)
best_CLIP_MNIST = CLIPClassifier(clip_model=best_CLIP).to(device)
best_CLIP_MNIST = best_CLIP_MNIST.float()
best_CLIP_MNIST.load_state_dict(torch.load("best_clip_mnist.pt", map_location=device)) # map_location tells where to place the model's weights in memory

model.eval()
best_CLIP_MNIST.eval()

CLIPClassifier(
  (clip): CLIP(
    (visual): VisionTransformer(
      (conv1): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
      (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (transformer): Transformer(
        (resblocks): Sequential(
          (0): ResidualAttentionBlock(
            (attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
            )
            (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (mlp): Sequential(
              (c_fc): Linear(in_features=768, out_features=3072, bias=True)
              (gelu): QuickGELU()
              (c_proj): Linear(in_features=3072, out_features=768, bias=True)
            )
            (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          )
          (1): ResidualAttentionBlock(
            (attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantiza

In [6]:
def clip_collate_fn(batch):
  images = []
  labels = []

  for item in batch:
    img = item["image"].convert("RGB")  # Already a PIL Image
    img = preprocess(img)
    images.append(img)
    labels.append(item["label"])

  images = torch.stack(images)
  labels = torch.tensor(labels, dtype=torch.long)

  return {
      "pixel_values": images.to(device),
      "labels": labels.to(device)
  }

test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, collate_fn=clip_collate_fn)

In [7]:
class CLIPWithHooks(nn.Module):
  def __init__(self, clip_model):
    super().__init__()
    self.clip = clip_model

    self.first_cls = None
    self.last_cls = None

    self.clip.visual.transformer.resblocks[0].register_forward_hook(self.save_first_cls)
    self.clip.visual.transformer.resblocks[-1].register_forward_hook(self.save_last_cls)

  def save_first_cls(self, module, input, output):
    # input[0] is the full sequence: [batch, seq_len, dim]
    self.first_cls = output[:, 0, :].detach() # CLS token

  def save_last_cls(self, module, input, output):
    self.last_cls = output[:, 0, :].detach() # CLS token

  def forward(self, images):
    self.first_cls = None
    self.last_cls = None

    # B is batch
    x = self.clip.visual.conv1(images)  # Convert image into patch embeddings. Divided into 32*32 patches. Shape is [B, 768, 7, 7]. Each 32*32 batch becomes a 768 dimensional vector. For 224*224 input, get 7*7=49 patches. Now have 49 such vectors per image.
    x = x.flatten(2).transpose(1,2) # -> [B, 768, 49] -> [B, 49, 768]; Each image is a sequence of 49 token vectors each of size 768, ready for the transformer.
    x = torch.cat([self.clip.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # Adds learnable CLS token at the start of every image's token sequence. [1,768] -> [B,1,768] -> [B, 50, 768]
    x = x + self.clip.visual.positional_embedding.to(x.dtype) # Adds positional information so transformer knows order and position. [B, 50, 768] + [1, 50, 768]
    x = self.clip.visual.ln_pre(x) # Normalize to stablize it

    # Run resblocks manually, so hooks definitely trigger
    for i, resblock in enumerate(self.clip.visual.transformer.resblocks):
        x = resblock(x)
        if i == 0: # First layer
            self.first_cls = x[:, 0, :].detach()
        if i == len(self.clip.visual.transformer.resblocks) - 1: # Last layer
            self.last_cls = x[:, 0, :].detach()

    x = self.clip.visual.ln_post(x[:, 0, :])

    final_embed = x @ self.clip.visual.proj # Linear Projection from 768 CLS token to 512 dimension vector for compatability

    return {
        "first_cls": self.first_cls,
        "last_cls": self.last_cls,
        "final_embed": final_embed.detach()
    }

In [8]:
wrapped_base = CLIPWithHooks(base_CLIP)
wrapped_best = CLIPWithHooks(best_CLIP_MNIST.clip)

base_CLIP.eval()
best_CLIP_MNIST.eval()

def setNoGrad(model):
  for param in model.parameters():
    param.requires_grad = False

setNoGrad(base_CLIP)
setNoGrad(best_CLIP_MNIST)

In [10]:
def print_module_dtypes(model): # Sanity Check for fp32
    param_dtypes = {p.dtype for p in model.parameters()}
    buffer_dtypes = {b.dtype for b in model.buffers()}
    print(f"Parameter dtypes: {param_dtypes}")
    print(f"Buffer    dtypes: {buffer_dtypes}")

# Example usage:
print_module_dtypes(wrapped_base)
print_module_dtypes(wrapped_best)

Parameter dtypes: {torch.float32}
Buffer    dtypes: set()
Parameter dtypes: {torch.float32}
Buffer    dtypes: set()


In [12]:
Z0_first = [] # base model's first-layer embedding (CLS)
Z0_last = [] # base model's last-layer embedding (CLS)
Z0_final = [] # base model's final image embedding

Z1_best_first = [] # best finetuned model's first-layer embedding (CLS)
Z1_best_last = [] # best finetuned model's last-layer embedding (CLS)
Z1_best_final = [] # best finetuned model's final image embedding

with torch.no_grad():
  for i, batch in enumerate(tqdm(test_loader, desc="Extracting")):
    images = batch["pixel_values"]
      
    out_base = wrapped_base(images)
    out_best = wrapped_best(images)

    Z0_first.append(out_base["first_cls"].float())
    Z0_last.append(out_base["last_cls"].float())
    Z0_final.append(out_base["final_embed"].float())

    Z1_best_first.append(out_best["first_cls"].float())
    Z1_best_last.append(out_best["last_cls"].float())
    Z1_best_final.append(out_best["final_embed"].float())

Z0_first = torch.cat(Z0_first)  # shape: [N, D]
Z0_last = torch.cat(Z0_last)
Z0_final = torch.cat(Z0_final)

Z1_best_first = torch.cat(Z1_best_first)
Z1_best_last = torch.cat(Z1_best_last)
Z1_best_final = torch.cat(Z1_best_final)

Extracting: 100%|██████████| 157/157 [01:54<00:00,  1.37it/s]


In [13]:
torch.save({
    "base_first": Z0_first,
    "base_last": Z0_last,
    "base_final": Z0_final,
    "fine_tuned_first": Z1_best_first,
    "fine_tuned_last": Z1_best_last,
    "fine_tuned_final": Z1_best_final
}, "embedding_pairs.pt")

# Code to load in
# pairs = torch.load("embedding_pairs.pt", map_location=device)
# Z0_first = pairs["base_first"]
# Z0_last = pairs["base_last"]
# Z0_final = pairs["base_final"]
# Z1_best_first = pairs["fine_tuned_first"]
# Z1_best_last = pairs["fine_tuned_last"]
# Z1_best_final = pairs["fine_tuned_final"]

##5. Calculating Linear Transformation Matrix and Bias

In [14]:
Z0_first_np = Z0_first.cpu().numpy()
Z0_last_np = Z0_last.cpu().numpy()
Z0_final_np = Z0_final.cpu().numpy()

Z1_best_first_np = Z1_best_first.cpu().numpy()
Z1_best_last_np = Z1_best_last.cpu().numpy()
Z1_best_final_np = Z1_best_final.cpu().numpy()

In [15]:
# Add constant 1 column for bias term
def constant1_full(Z0):
  ones = np.ones((Z0.shape[0], 1)) # (1,1)
  return np.hstack([Z0, ones])

Z0_first_aug = constant1_full(Z0_first_np)
Z0_last_aug = constant1_full(Z0_last_np)
Z0_final_aug = constant1_full(Z0_final_np)

In [16]:
# Calculate Bias b and Linear Transformation Matrix W
def leastSquares(Z0, Z1):
  W_full, residuals, rank, s = np.linalg.lstsq(Z0, Z1, rcond=None)
  return W_full

W_full = leastSquares(Z0_first_aug, Z1_best_first_np)  # base_first -> fine-tuned first
W_0_1 = W_full[:-1]
b_0_1 = W_full[-1]

W_full = leastSquares(Z0_first_aug, Z1_best_last_np)   # base_first -> fine-tuned last
W_1_1 = W_full[:-1]
b_1_1 = W_full[-1]

W_full = leastSquares(Z0_first_aug, Z1_best_final_np)  # base_first -> fine-tuned final
W_2_1 = W_full[:-1]
b_2_1 = W_full[-1]

W_full = leastSquares(Z0_last_aug, Z1_best_last_np)    # base_last -> fine-tuned last
W_3_1 = W_full[:-1]
b_3_1 = W_full[-1]

W_full = leastSquares(Z0_last_aug, Z1_best_final_np)   # base_last -> fine-tuned final
W_4_1 = W_full[:-1]
b_4_1 = W_full[-1]

W_full = leastSquares(Z0_final_aug, Z1_best_final_np)  # base_final -> fine-tuned final
W_5_1 = W_full[:-1]
b_5_1 = W_full[-1]

In [17]:
torch.save({
    "W_first_first": W_0_1,
    "b_first_first": b_0_1,
    "W_first_last": W_1_1,
    "b_first_last": b_1_1,
    "W_first_final": W_2_1,
    "b_first_final": b_2_1,
    "W_last_last": W_3_1,
    "b_last_last": b_3_1,
    "W_last_final": W_4_1,
    "b_last_final": b_4_1,
    "W_final_final": W_5_1,
    "b_final_final": b_5_1
}, "clip_mnist_transformations.pt")

# Code to Load:
# vals = torch.load("clip_mnist_transformations.pt")
# W_0_1 = vals["W_first_first"]
# b_0_1 = vals["b_first_first"]
# W_1_1 = vals["W_first_last"]
# b_1_1 = vals["b_first_last"]
# W_2_1 = vals["W_first_final"]
# b_2_1 = vals["b_first_final"]
# W_3_1 = vals["W_last_last"]
# b_3_1 = vals["b_last_last"]
# W_4_1 = vals["W_last_final"]
# b_4_1 = vals["b_last_final"]
# W_5_1 = vals["W_final_final"]
# b_5_1 = vals["b_final_final"]

##6. Augmenting Base CLIP

In [24]:
class TransformedCLIP(nn.Module):
  def __init__(self, clip_model, W=None, b=None, transform_stage="", add_classifier=False):
    super().__init__()
    self.clip = clip_model
    self.W = torch.from_numpy(W.astype(np.float32)).to(device) if W is not None else None
    self.b = torch.from_numpy(b.astype(np.float32)).to(device) if b is not None else None
    self.transform_stage = transform_stage
    # base_first -> fine-tuned_first
    # base_first -> fine-tuned_last
    # base_first -> fine-tuned_final
    # base_last -> fine-tuned_last
    # base_last -> fine-tuned_final
    # base_final -> fine-tuned_final
    self.add_classifier = add_classifier
    if self.add_classifier:
        self.classifier = nn.Linear(512, 10) # Mapping 512 image embeddings to 10 for (0-9 MNIST Classification)

  def customPass(self, image):
    image = image.to(device) # [B, 3, 224, 224]
    x = self.clip.visual.conv1(image) # [B, 768, 7, 7]
    x = x.reshape(x.shape[0], x.shape[1], -1) # [B, width, 49]
    x = x.permute(0, 2, 1) # [B, 49, 768]

    cls_token = self.clip.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
    x = torch.cat([cls_token, x], dim=1)  # [B, 50, 768]
    x = x + self.clip.visual.positional_embedding.to(x.dtype)
    x = self.clip.visual.ln_pre(x)  # [B, 50, 768]

    x = x.permute(1,0,2) # [50, B, 768] # NLD -> LND
    for i, block in enumerate(self.clip.visual.transformer.resblocks):
      x = block(x)
      if i == 0 and (self.transform_stage in {"first-first", "first-last", "first-final"}): # After first transformer block
        if self.transform_stage == "first-final":  # Need to project first layer into final embedding space (512-d vector) before manipulation
          # x = x.permute(1,0,2) # [B, 50, 768]; LND -> NLD
          # cls_token = x[:,0,:]
          # cls_token = self.clip.visual.ln_post(cls_token)

          # if self.clip.visual.proj is not None:
          #   cls_token = cls_token @ self.clip.visual.proj.to(x.dtype)
          # else:
          #   raise ValueError("Projection layer is None. Expected [768, 512]")

          # cls_token = cls_token @ self.W + self.b

          # return cls_token
          cls = x[0, :, :] # [B, 768]; Raw CLS after first block (768 dimensions)
          cls = cls.to(torch.float32)
          out = cls @ self.W + self.b # [B, 512] # Apply learned 768-> 512 W and bias
          return out

        # For first-first
        cls_token = x[0,:,:] # [B, 768]
        cls_token = cls_token.to(torch.float32)
        cls_token = cls_token @ self.W + self.b
        x = torch.cat([cls_token.unsqueeze(0), x[1:]], dim=0) # Replace CLS Token

        if self.transform_stage == "first-last":
          break

      if i == len(self.clip.visual.transformer.resblocks) - 1 and (self.transform_stage == "last-last" or self.transform_stage == "last-final"):
        if self.transform_stage == "last-final": # Need to project last layer into final embedding space (512-d vector) before manipulation
          #  x = x.permute(1,0,2) # [B, 50, 768]; LND -> NLD
          #  cls_token = x[:,0,:]
          #  cls_token = self.clip.visual.ln_post(cls_token) 

          #  if self.clip.visual.proj is not None:
          #    cls_token = cls_token @ self.clip.visual.proj.to(x.dtype)
          #  else:
          #    raise ValueError("Projection layer is None. Expected [768, 512]")

          #  cls_token = cls_token @ self.W + self.b

          #  return cls_token
          cls = x[0, :, :] # [B, 768]; Raw CLS from last block (768 dimensions)
          cls = cls.to(torch.float32)
          out = cls @ self.W + self.b # [B, 512]
          return out

        # For last-last
        cls_token = x[0,:,:] # [B, 768]
        cls_token = cls_token.to(torch.float32)
        cls_token = cls_token @ self.W + self.b
        x = torch.cat([cls_token.unsqueeze(0), x[1:]], dim=0) # Replaces CLS Token and adds it to shape of this [49, B, 768] <- all tokens after CLS (excluding CLS) -> [50, B, 768]

    x = x.permute(1,0,2) # [B, 50, 768]; LND -> NLD
    x = self.clip.visual.ln_post(x[:, 0, :])

    if self.clip.visual.proj is not None:
      x = x @ self.clip.visual.proj

    return x

  def forward(self, image):
    image = image.to(device)
    if self.transform_stage == "final-final":
      image_embed = self.clip.encode_image(image)
      image_embed = image_embed @ self.W.T + self.b
    elif self.transform_stage in {"first-first", "first-last", "first-final", "last-last", "last-final"}:
      image_embed = self.customPass(image)
    else: # From CLIP GitHub -> model.py @ 359. https://github.com/openai/CLIP/blob/main/clip/model.py
      image_embed = self.clip.encode_image(image)

    # Don't need to do normalization from original code. Need to preserve scale information to accurately fit W, b. Keep the true structural nature
    # Can normalize if evaluating cosine similarities between embeddings. Normalize both base and fine-tuned if doign so before computing similarity/dot product.
    # Just doing classification so no need for text

    if self.add_classifier:
        logits = self.classifier(image_embed)
        return logits
    else:
        return image_embed

In [25]:
CLIP_aug_first_first, _ = clip.load("ViT-B/32", device=device)
CLIP_aug_first_first = CLIP_aug_first_first.float()
CLIP_aug_first_first = TransformedCLIP(CLIP_aug_first_first, W_0_1, b_0_1, transform_stage="first-first", add_classifier=True).to(device)
CLIP_aug_first_first.eval()

CLIP_aug_first_last, _ = clip.load("ViT-B/32", device=device)
CLIP_aug_first_last = CLIP_aug_first_last.float()
CLIP_aug_first_last = TransformedCLIP(CLIP_aug_first_last, W_1_1, b_1_1, transform_stage="first-last", add_classifier=True).to(device)
CLIP_aug_first_last.eval()

CLIP_aug_first_final, _ = clip.load("ViT-B/32", device=device)
CLIP_aug_first_final = CLIP_aug_first_final.float()
CLIP_aug_first_final = TransformedCLIP(CLIP_aug_first_final, W_2_1, b_2_1, transform_stage="first-final", add_classifier=True).to(device)
CLIP_aug_first_final.eval()

CLIP_aug_last_last, _ = clip.load("ViT-B/32", device=device)
CLIP_aug_last_last = CLIP_aug_last_last.float()
CLIP_aug_last_last.eval()
CLIP_aug_last_last = TransformedCLIP(CLIP_aug_last_last, W_3_1, b_3_1, transform_stage="last-last", add_classifier=True).to(device)

CLIP_aug_last_final, _ = clip.load("ViT-B/32", device=device)
CLIP_aug_last_final = CLIP_aug_last_final.float()
CLIP_aug_last_final = TransformedCLIP(CLIP_aug_last_final, W_4_1, b_4_1, transform_stage="last-final", add_classifier=True).to(device)
CLIP_aug_last_final.eval()

CLIP_aug_final_final, _ = clip.load("ViT-B/32", device=device)
CLIP_aug_final_final = CLIP_aug_final_final.float()
CLIP_aug_final_final = TransformedCLIP(CLIP_aug_final_final, W_5_1, b_5_1, transform_stage="final-final", add_classifier=True).to(device)
CLIP_aug_final_final.eval()

CLIP_base, _ = clip.load("ViT-B/32", device=device)
CLIP_base = CLIP_base.float()
CLIP_base = TransformedCLIP(CLIP_base, add_classifier=True).to(device)
CLIP_base.eval()

CLIP_fine_tuned, _ = clip.load("ViT-B/32", device=device)
CLIP_fine_tuned = CLIP_fine_tuned.float()
CLIP_fine_tuned = CLIPClassifier(clip_model=CLIP_fine_tuned).to(device)
CLIP_fine_tuned.load_state_dict(torch.load("best_clip_mnist.pt", map_location=device))
CLIP_fine_tuned.eval()

CLIPClassifier(
  (clip): CLIP(
    (visual): VisionTransformer(
      (conv1): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
      (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (transformer): Transformer(
        (resblocks): Sequential(
          (0): ResidualAttentionBlock(
            (attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
            )
            (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (mlp): Sequential(
              (c_fc): Linear(in_features=768, out_features=3072, bias=True)
              (gelu): QuickGELU()
              (c_proj): Linear(in_features=3072, out_features=768, bias=True)
            )
            (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          )
          (1): ResidualAttentionBlock(
            (attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantiza

##7. Evaluating Base v Fine-Tuned vs Augmented

In [26]:
correct_first_first = 0
correct_first_last = 0
correct_first_final = 0
correct_last_last = 0
correct_last_final = 0
correct_final_final = 0
correct_base = 0
correct_fine_tuned = 0

total_samples = 0

with torch.no_grad():
  for batch in tqdm(test_loader, desc="Testing"):
    images = batch["pixel_values"].to(device)
    labels = batch["labels"].to(device)
    total_samples += labels.size(0)

    # First-First Augmented Model
    logits_first_first = CLIP_aug_first_first(images)
    pred_first_first = logits_first_first.argmax(dim=1)
    correct_first_first += (pred_first_first == labels).sum().item()

    # First-Last Augmented Model
    logits_first_last = CLIP_aug_first_last(images)
    pred_first_last = logits_first_last.argmax(dim=1)
    correct_first_last += (pred_first_last == labels).sum().item()

    # First-Final Augmented Model
    logits_first_final = CLIP_aug_first_final(images)
    pred_first_final = logits_first_final.argmax(dim=1)
    correct_first_final += (pred_first_final == labels).sum().item()

    # Last-Last Augmented Model
    logits_last_last = CLIP_aug_last_last(images)
    pred_last_last = logits_last_last.argmax(dim=1)
    correct_last_last += (pred_last_last == labels).sum().item()

    # Last-Final Augmented Model
    logits_last_final = CLIP_aug_last_final(images)
    pred_last_final = logits_last_final.argmax(dim=1)
    correct_last_final += (pred_last_final == labels).sum().item()

    # Final-Final Augmented Model
    logits_final_final = CLIP_aug_final_final(images)
    pred_final_final = logits_final_final.argmax(dim=1)
    correct_final_final += (pred_final_final == labels).sum().item()

    # Base Model
    logits_base = CLIP_base(images)
    pred_base = logits_base.argmax(dim=1)
    correct_base += (pred_base == labels).sum().item()

    # Fine-Tuned Model
    logits_fine_tuned = CLIP_fine_tuned(images)
    pred_fine_tuned = logits_fine_tuned.argmax(dim=1)
    correct_fine_tuned += (pred_fine_tuned == labels).sum().item()

first_first_acc = correct_first_first / total_samples
first_last_acc = correct_first_last / total_samples
first_final_acc = correct_first_final / total_samples
last_last_acc = correct_last_last / total_samples
last_final_acc = correct_last_final / total_samples
final_final_acc = correct_final_final / total_samples
base_acc = correct_base / total_samples
fine_tuned_acc = correct_fine_tuned / total_samples

print(f"\nAugmented First-First Accuracy: {first_first_acc:.4f}")
print(f"Augmented First-Last Accuracy: {first_last_acc:.4f}")
print(f"Augmented First-Final Accuracy: {first_final_acc:.4f}")
print(f"Augmented Last-Last Accuracy: {last_last_acc:.4f}")
print(f"Augmented Last-Final Accuracy: {last_final_acc:.4f}")
print(f"Augmented Final-Final Accuracy: {final_final_acc:.4f}")
print(f"Base Accuracy: {base_acc:.4f}")
print(f"Fine-Tuned Accuracy: {fine_tuned_acc:.4f}")

Testing: 100%|██████████| 157/157 [03:56<00:00,  1.50s/it]


Augmented First-First Accuracy: 0.1010
Augmented First-Last Accuracy: 0.1135
Augmented First-Final Accuracy: 0.1009
Augmented Last-Last Accuracy: 0.0958
Augmented Last-Final Accuracy: 0.1009
Augmented Final-Final Accuracy: 0.0982
Base Accuracy: 0.0909
Fine-Tuned Accuracy: 0.9961





In [34]:
# For fine-tuned classifier head
class TransformedCLIP1(nn.Module):
  def __init__(self, clip_model, W=None, b=None, transform_stage="", classifier_head=None):
    super().__init__()
    self.clip = clip_model
    self.W = torch.from_numpy(W.astype(np.float32)).to(device) if W is not None else None
    self.b = torch.from_numpy(b.astype(np.float32)).to(device) if b is not None else None
    self.transform_stage = transform_stage
    # base_first -> fine-tuned_first
    # base_first -> fine-tuned_last
    # base_first -> fine-tuned_final
    # base_last -> fine-tuned_last
    # base_last -> fine-tuned_final
    # base_final -> fine-tuned_final
    if classifier_head is not None:
        self.classifier = classifier_head
    else:
        self.classifier = None

  def customPass(self, image):
    image = image.to(device) # [B, 3, 224, 224]
    x = self.clip.visual.conv1(image) # [B, 768, 7, 7]
    x = x.reshape(x.shape[0], x.shape[1], -1) # [B, width, 49]
    x = x.permute(0, 2, 1) # [B, 49, 768]

    cls_token = self.clip.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
    x = torch.cat([cls_token, x], dim=1)  # [B, 50, 768]
    x = x + self.clip.visual.positional_embedding.to(x.dtype)
    x = self.clip.visual.ln_pre(x)  # [B, 50, 768]

    x = x.permute(1,0,2) # [50, B, 768] # NLD -> LND
    for i, block in enumerate(self.clip.visual.transformer.resblocks):
      x = block(x)
      if i == 0 and (self.transform_stage in {"first-first", "first-last", "first-final"}): # After first transformer block
        if self.transform_stage == "first-final":  # Need to project first layer into final embedding space (512-d vector) before manipulation
          # x = x.permute(1,0,2) # [B, 50, 768]; LND -> NLD
          # cls_token = x[:,0,:]
          # cls_token = self.clip.visual.ln_post(cls_token)

          # if self.clip.visual.proj is not None:
          #   cls_token = cls_token @ self.clip.visual.proj.to(x.dtype)
          # else:
          #   raise ValueError("Projection layer is None. Expected [768, 512]")

          # cls_token = cls_token @ self.W + self.b

          # return cls_token
          cls = x[0, :, :] # [B, 768]; Raw CLS after first block (768 dimensions)
          cls = cls.to(torch.float32)
          out = cls @ self.W + self.b # [B, 512] # Apply learned 768-> 512 W and bias
          return out

        # For first-first
        cls_token = x[0,:,:] # [B, 768]
        cls_token = cls_token.to(torch.float32)
        cls_token = cls_token @ self.W + self.b
        x = torch.cat([cls_token.unsqueeze(0), x[1:]], dim=0) # Replace CLS Token

        if self.transform_stage == "first-last":
          break

      if i == len(self.clip.visual.transformer.resblocks) - 1 and (self.transform_stage == "last-last" or self.transform_stage == "last-final"):
        if self.transform_stage == "last-final": # Need to project last layer into final embedding space (512-d vector) before manipulation
          #  x = x.permute(1,0,2) # [B, 50, 768]; LND -> NLD
          #  cls_token = x[:,0,:]
          #  cls_token = self.clip.visual.ln_post(cls_token) 

          #  if self.clip.visual.proj is not None:
          #    cls_token = cls_token @ self.clip.visual.proj.to(x.dtype)
          #  else:
          #    raise ValueError("Projection layer is None. Expected [768, 512]")

          #  cls_token = cls_token @ self.W + self.b

          #  return cls_token
          cls = x[0, :, :] # [B, 768]; Raw CLS from last block (768 dimensions)
          cls = cls.to(torch.float32)
          out = cls @ self.W + self.b # [B, 512]
          return out

        # For last-last
        cls_token = x[0,:,:] # [B, 768]
        cls_token = cls_token.to(torch.float32)
        cls_token = cls_token @ self.W + self.b
        x = torch.cat([cls_token.unsqueeze(0), x[1:]], dim=0) # Replaces CLS Token and adds it to shape of this [49, B, 768] <- all tokens after CLS (excluding CLS) -> [50, B, 768]

    x = x.permute(1,0,2) # [B, 50, 768]; LND -> NLD
    x = self.clip.visual.ln_post(x[:, 0, :])

    if self.clip.visual.proj is not None:
      x = x @ self.clip.visual.proj

    return x

  def forward(self, image):
    image = image.to(device)
    if self.transform_stage == "final-final":
      image_embed = self.clip.encode_image(image)
      image_embed = image_embed @ self.W + self.b
    elif self.transform_stage in {"first-first", "first-last", "first-final", "last-last", "last-final"}:
      image_embed = self.customPass(image)
    else: # From CLIP GitHub -> model.py @ 359. https://github.com/openai/CLIP/blob/main/clip/model.py
      image_embed = self.clip.encode_image(image)

    # Don't need to do normalization from original code. Need to preserve scale information to accurately fit W, b. Keep the true structural nature
    # Can normalize if evaluating cosine similarities between embeddings. Normalize both base and fine-tuned if doign so before computing similarity/dot product.
    # Just doing classification so no need for text

    if self.classifier is not None:
        return self.classifier(image_embed)
    else:
        return image_embed

In [35]:
CLIP_base, _ = clip.load("ViT-B/32", device=device)
CLIP_base = CLIP_base.float()
CLIP_base = TransformedCLIP1(CLIP_base, classifier_head=nn.Linear(512, 10)).to(device)
CLIP_base.eval()

CLIP_fine_tuned, _ = clip.load("ViT-B/32", device=device)
CLIP_fine_tuned = CLIP_fine_tuned.float()
CLIP_fine_tuned = CLIPClassifier(clip_model=CLIP_fine_tuned).to(device)
CLIP_fine_tuned.load_state_dict(torch.load("best_clip_mnist.pt", map_location=device))
CLIP_fine_tuned.eval()

CLIP_aug_first_first, _ = clip.load("ViT-B/32", device=device)
CLIP_aug_first_first = CLIP_aug_first_first.float()
CLIP_aug_first_first = TransformedCLIP1(CLIP_aug_first_first, W_0_1, b_0_1, transform_stage="first-first", classifier_head = CLIP_fine_tuned.classifier).to(device)
CLIP_aug_first_first.eval()

CLIP_aug_first_last, _ = clip.load("ViT-B/32", device=device)
CLIP_aug_first_last = CLIP_aug_first_last.float()
CLIP_aug_first_last = TransformedCLIP1(CLIP_aug_first_last, W_1_1, b_1_1, transform_stage="first-last", classifier_head = CLIP_fine_tuned.classifier).to(device)
CLIP_aug_first_last.eval()

CLIP_aug_first_final, _ = clip.load("ViT-B/32", device=device)
CLIP_aug_first_final = CLIP_aug_first_final.float()
CLIP_aug_first_final = TransformedCLIP1(CLIP_aug_first_final, W_2_1, b_2_1, transform_stage="first-final", classifier_head = CLIP_fine_tuned.classifier).to(device)
CLIP_aug_first_final.eval()

CLIP_aug_last_last, _ = clip.load("ViT-B/32", device=device)
CLIP_aug_last_last = CLIP_aug_last_last.float()
CLIP_aug_last_last.eval()
CLIP_aug_last_last = TransformedCLIP1(CLIP_aug_last_last, W_3_1, b_3_1, transform_stage="last-last", classifier_head = CLIP_fine_tuned.classifier).to(device)

CLIP_aug_last_final, _ = clip.load("ViT-B/32", device=device)
CLIP_aug_last_final = CLIP_aug_last_final.float()
CLIP_aug_last_final = TransformedCLIP1(CLIP_aug_last_final, W_4_1, b_4_1, transform_stage="last-final", classifier_head = CLIP_fine_tuned.classifier).to(device)
CLIP_aug_last_final.eval()

CLIP_aug_final_final, _ = clip.load("ViT-B/32", device=device)
CLIP_aug_final_final = CLIP_aug_final_final.float()
CLIP_aug_final_final = TransformedCLIP1(CLIP_aug_final_final, W_5_1, b_5_1, transform_stage="final-final", classifier_head = CLIP_fine_tuned.classifier).to(device)
CLIP_aug_final_final.eval()

TransformedCLIP1(
  (clip): CLIP(
    (visual): VisionTransformer(
      (conv1): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
      (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (transformer): Transformer(
        (resblocks): Sequential(
          (0): ResidualAttentionBlock(
            (attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
            )
            (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (mlp): Sequential(
              (c_fc): Linear(in_features=768, out_features=3072, bias=True)
              (gelu): QuickGELU()
              (c_proj): Linear(in_features=3072, out_features=768, bias=True)
            )
            (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          )
          (1): ResidualAttentionBlock(
            (attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuanti

In [36]:
correct_first_first = 0
correct_first_last = 0
correct_first_final = 0
correct_last_last = 0
correct_last_final = 0
correct_final_final = 0
correct_base = 0
correct_fine_tuned = 0

total_samples = 0

with torch.no_grad():
  for batch in tqdm(test_loader, desc="Testing"):
    images = batch["pixel_values"].to(device)
    labels = batch["labels"].to(device)
    total_samples += labels.size(0)

    # First-First Augmented Model
    logits_first_first = CLIP_aug_first_first(images)
    pred_first_first = logits_first_first.argmax(dim=1)
    correct_first_first += (pred_first_first == labels).sum().item()

    # First-Last Augmented Model
    logits_first_last = CLIP_aug_first_last(images)
    pred_first_last = logits_first_last.argmax(dim=1)
    correct_first_last += (pred_first_last == labels).sum().item()

    # First-Final Augmented Model
    logits_first_final = CLIP_aug_first_final(images)
    pred_first_final = logits_first_final.argmax(dim=1)
    correct_first_final += (pred_first_final == labels).sum().item()

    # Last-Last Augmented Model
    logits_last_last = CLIP_aug_last_last(images)
    pred_last_last = logits_last_last.argmax(dim=1)
    correct_last_last += (pred_last_last == labels).sum().item()

    # Last-Final Augmented Model
    logits_last_final = CLIP_aug_last_final(images)
    pred_last_final = logits_last_final.argmax(dim=1)
    correct_last_final += (pred_last_final == labels).sum().item()

    # Final-Final Augmented Model
    logits_final_final = CLIP_aug_final_final(images)
    pred_final_final = logits_final_final.argmax(dim=1)
    correct_final_final += (pred_final_final == labels).sum().item()

    # Base Model
    logits_base = CLIP_base(images)
    pred_base = logits_base.argmax(dim=1)
    correct_base += (pred_base == labels).sum().item()

    # Fine-Tuned Model
    logits_fine_tuned = CLIP_fine_tuned(images)
    pred_fine_tuned = logits_fine_tuned.argmax(dim=1)
    correct_fine_tuned += (pred_fine_tuned == labels).sum().item()

first_first_acc = correct_first_first / total_samples
first_last_acc = correct_first_last / total_samples
first_final_acc = correct_first_final / total_samples
last_last_acc = correct_last_last / total_samples
last_final_acc = correct_last_final / total_samples
final_final_acc = correct_final_final / total_samples
base_acc = correct_base / total_samples
fine_tuned_acc = correct_fine_tuned / total_samples

print(f"\nAugmented First-First Accuracy: {first_first_acc:.4f}")
print(f"Augmented First-Last Accuracy: {first_last_acc:.4f}")
print(f"Augmented First-Final Accuracy: {first_final_acc:.4f}")
print(f"Augmented Last-Last Accuracy: {last_last_acc:.4f}")
print(f"Augmented Last-Final Accuracy: {last_final_acc:.4f}")
print(f"Augmented Final-Final Accuracy: {final_final_acc:.4f}")
print(f"Base Accuracy: {base_acc:.4f}")
print(f"Fine-Tuned Accuracy: {fine_tuned_acc:.4f}")

Testing: 100%|██████████| 157/157 [03:55<00:00,  1.50s/it]


Augmented First-First Accuracy: 0.1171
Augmented First-Last Accuracy: 0.1135
Augmented First-Final Accuracy: 0.1135
Augmented Last-Last Accuracy: 0.1135
Augmented Last-Final Accuracy: 0.1135
Augmented Final-Final Accuracy: 0.1135
Base Accuracy: 0.0806
Fine-Tuned Accuracy: 0.9961





In [None]:
mnist = load_dataset("ylecun/mnist")
train = mnist["train"]

