## 1. Installations

In [None]:
!pip install torch torchvision
!pip install -U transformers datasets
!pip install fifty regex tqdm
!pip install git+https://github.com/openai/CLIP.git
!pip install matplotlib
!pip install -U pillow
%matplotlib inline

### 1. Optional Installs (RunPod Only)

In [None]:
!pip install --force-reinstall --no-cache-dir scipy datasets # Only needed within runpod environment

In [None]:
!pip install numpy==1.26.4 # only needed for runpod environment

## 2. Imports

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
import clip
import numpy as np
from datasets import load_dataset
from tqdm import tqdm
import matplotlib.pyplot as plt
from PIL import Image
import copy
import torch.nn.functional as F

## 3. Setting Up Device, CLIP, and Dataset Splits

In [None]:
# Prefer the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Pre-trained CLIP ViT-B/32 (https://github.com/openai/CLIP), cast to fp32.
clip_model, preprocess = clip.load("ViT-B/32", device=device)
clip_model = clip_model.float()

# MNIST (https://huggingface.co/datasets/ylecun/mnist): a seeded 80/20 train/val split
# of the official training set, plus the official test set.
mnist = load_dataset("ylecun/mnist")
split = mnist["train"].train_test_split(test_size=0.2, seed=66)
train_dataset, val_dataset = split["train"], split["test"]  # 48,000 / 12,000 examples
test_dataset = mnist["test"]                                # 10,000 examples

# Index into the splits as plain Python objects (image + label only).
for _ds in (train_dataset, val_dataset, test_dataset):
    _ds.set_format(type="python", columns=["image", "label"])

## 4. Wrapping Models & Prepping Dataset

In [None]:
class CLIPClassifier(nn.Module):
    """Linear classification head on top of CLIP's image encoder.

    The attribute names ``clip`` and ``classifier`` define the state_dict keys
    (``clip.*`` / ``classifier.*``) that a saved checkpoint is loaded against,
    so they must not be renamed.
    """

    def __init__(self, clip_model, num_classes=10):
        super().__init__()
        self.clip = clip_model
        # Maps the image embedding (visual.output_dim wide) to one logit per class.
        self.classifier = nn.Linear(self.clip.visual.output_dim, num_classes)

    def forward(self, images):
        return self.classifier(self.clip.encode_image(images))

In [None]:
class CLIPWithHooks(nn.Module):
    """Re-runs CLIP's visual forward pass and records the CLS token after every resblock.

    No forward hooks are registered: the residual blocks are iterated manually so the
    CLS token (sequence position 0) can be captured after each one.

    Args:
        clip_model: OpenAI CLIP model; only ``clip_model.visual`` is used.
        classifier_head: Module applied to the (detached) final image embedding.

    ``forward`` returns a dict with ``"first_cls"`` ... ``"twelfth_cls"`` (one entry per
    captured block, at most twelve, each detached and shaped ``[B, width]``), followed by
    ``"final_embed"`` (detached) and ``"logits"``.
    """

    # Output keys for the per-block CLS tokens, in block order.
    _CLS_KEYS = (
        "first_cls", "second_cls", "third_cls", "fourth_cls",
        "fifth_cls", "sixth_cls", "seventh_cls", "eighth_cls",
        "ninth_cls", "tenth_cls", "eleventh_cls", "twelfth_cls",
    )

    def __init__(self, clip_model, classifier_head):
        super().__init__()
        self.clip = clip_model
        self.cls_tokens = []
        self.classifier = classifier_head

    def forward(self, images):
        visual = self.clip.visual
        self.cls_tokens = []

        # Patchify: conv1 turns each patch into a `width`-dim vector -> [B, width, grid, grid]
        # (e.g. ViT-B/32 at 224px: width 768, 7x7 grid = 49 patches).
        x = visual.conv1(images)
        x = x.reshape(x.shape[0], x.shape[1], -1)  # [B, width, grid*grid]
        x = x.permute(0, 2, 1)                     # [B, grid*grid, width]: one token per patch

        # Prepend the learned class embedding as token 0 -> [B, 1 + grid*grid, width].
        cls_slot = visual.class_embedding.to(x.dtype) + torch.zeros(
            x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
        )
        x = torch.cat([cls_slot, x], dim=1)
        x = x + visual.positional_embedding.to(x.dtype)  # add position information
        x = visual.ln_pre(x)                              # normalize for stability

        x = x.permute(1, 0, 2)  # sequence-first layout [seq, B, width] for the resblocks

        # Run the blocks one by one so the CLS token can be recorded after each of them.
        for resblock in visual.transformer.resblocks:
            x = resblock(x)
            self.cls_tokens.append(x[0, :, :].detach())

        x = x.permute(1, 0, 2)  # back to [B, seq, width]
        x = visual.ln_post(x[:, 0, :])

        # Project the CLS token into the joint embedding space when a projection exists.
        final_embed = (x @ visual.proj if visual.proj is not None else x).detach()
        logits = self.classifier(final_embed)

        # zip stops at the shorter sequence, so shallower models simply report fewer layers.
        outputs = dict(zip(self._CLS_KEYS, self.cls_tokens))
        outputs["final_embed"] = final_embed
        outputs["logits"] = logits
        return outputs

In [None]:
# Pre-trained reference CLIP in fp32; every model below works on its own deep copy.
refer, preprocess = clip.load("ViT-B/32", device=device)
refer = refer.float()

# Base model: pre-trained backbone. Its 2-way head is randomly initialised and exists only
# because CLIPWithHooks needs a head; this section reads its CLS tokens, not its logits.
# .to(device) moves the freshly created head next to the backbone.
base = CLIPWithHooks(copy.deepcopy(refer), nn.Linear(refer.visual.output_dim, 2)).to(device)
base = base.eval()

# Restore the fine-tuned classifier, then wrap its *fine-tuned* backbone (f_t.clip) and head.
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
f_t = CLIPClassifier(clip_model=copy.deepcopy(refer)).to(device)
f_t.load_state_dict(torch.load("best_clip_mnist_fp32.pt", map_location=device))
fine_tuned = CLIPWithHooks(copy.deepcopy(f_t.clip), classifier_head=f_t.classifier)
fine_tuned = fine_tuned.eval()

In [None]:
def clip_collate_fn(batch):
    """Collate MNIST rows into CLIP-ready tensors placed on the global ``device``.

    Each row carries a PIL ``"image"`` and an int ``"label"``. Images are converted to RGB
    and passed through CLIP's ``preprocess``. Returns ``{"pixel_values": [B, C, H, W],
    "labels": [B] long}``.
    """
    pixel_values = torch.stack([preprocess(row["image"].convert("RGB")) for row in batch])
    targets = torch.tensor([row["label"] for row in batch], dtype=torch.long)
    return {
        "pixel_values": pixel_values.to(device),
        "labels": targets.to(device),
    }

# Whole-split loaders: unshuffled batches of 64 built with the CLIP collate function.
_full_split_opts = {"batch_size": 64, "shuffle": False, "collate_fn": clip_collate_fn}
train_loader = DataLoader(train_dataset, **_full_split_opts)
test_loader = DataLoader(test_dataset, **_full_split_opts)

## 5. Collecting Embeddings

In [None]:
def make_digit_loaders(digit, batch_size=64):
    """Build (train, val, test) DataLoaders restricted to MNIST examples labelled ``digit``.

    Train/val come from ``split`` (the seeded split of the MNIST training set); test is
    ``mnist["test"]``. Only the train loader is shuffled.
    """
    def keep(label):
        return label == digit

    loaders = []
    for subset, shuffle in ((split["train"], True), (split["test"], False), (mnist["test"], False)):
        # input_columns=["label"] hands only the label to `keep`, so the image column
        # is not touched while filtering.
        filtered = subset.filter(keep, input_columns=["label"])
        filtered.set_format(type="python", columns=["image", "label"])
        loaders.append(DataLoader(filtered, batch_size=batch_size, shuffle=shuffle, collate_fn=clip_collate_fn))
    return tuple(loaders)


train_0, val_0, test_0 = make_digit_loaders(0)

In [None]:
# Least-squares regression buffers (one per-batch list each):
# inputs are base-model CLS tokens from layers 1-5, the target is the fine-tuned layer-12 CLS.
Z0_first_lsr, Z0_second_lsr, Z0_third_lsr, Z0_fourth_lsr, Z0_fifth_lsr = [], [], [], [], []

Z1_twelfth_lsr = []

In [None]:
# Collect CLS tokens over the digit-0 training loader. Base-model layers 1-5 are the
# regression inputs; the fine-tuned model's layer-12 CLS is the target. The lists are
# rebuilt here so re-running this cell never appends to already-concatenated tensors.
_input_keys = ("first_cls", "second_cls", "third_cls", "fourth_cls", "fifth_cls")
_inputs = {key: [] for key in _input_keys}
_targets = []

with torch.no_grad():
    for batch in tqdm(train_0, desc="Extracting Train Set Vectors"):
        images = batch["pixel_values"]

        out_base = base(images)
        out_fine_tuned = fine_tuned(images)

        for key in _input_keys:
            _inputs[key].append(out_base[key].float())
        _targets.append(out_fine_tuned["twelfth_cls"].float())

Z0_first_lsr, Z0_second_lsr, Z0_third_lsr, Z0_fourth_lsr, Z0_fifth_lsr = (
    torch.cat(_inputs[key]) for key in _input_keys
)
Z1_twelfth_lsr = torch.cat(_targets)

# Row i of every input matrix and of the target must come from the same image.
assert all(
    z.shape[0] == Z1_twelfth_lsr.shape[0]
    for z in (Z0_first_lsr, Z0_second_lsr, Z0_third_lsr, Z0_fourth_lsr, Z0_fifth_lsr)
)

## 6. Calculating Linear Transformation Matrix and Bias

In [None]:
# Move the collected CLS tensors to NumPy for np.linalg.lstsq
# (these were gathered from the digit-0 training loader, not the entire dataset).
(
    Z0_first_lsr_0,
    Z0_second_lsr_0,
    Z0_third_lsr_0,
    Z0_fourth_lsr_0,
    Z0_fifth_lsr_0,
    Z1_twelfth_lsr_0,
) = (
    t.cpu().numpy()
    for t in (Z0_first_lsr, Z0_second_lsr, Z0_third_lsr, Z0_fourth_lsr, Z0_fifth_lsr, Z1_twelfth_lsr)
)

In [None]:
def addBiasColumn(Z0):
    """Append a constant column of ones so a least-squares fit also learns a bias.

    ``Z0`` is an ``(N, D)`` array; the result is ``(N, D + 1)``. The ones column is
    float64, so float32 input is promoted to float64.
    """
    n_rows = Z0.shape[0]
    bias_col = np.ones((n_rows, 1))  # shape (N, 1)
    return np.hstack([Z0, bias_col])

def leastSquares(Z0, Z1):
    """Return the minimiser W of ||Z0 @ W - Z1||_2 via ``np.linalg.lstsq``.

    Residuals, rank and singular values are computed by NumPy but discarded.
    """
    solution, *_ = np.linalg.lstsq(Z0, Z1, rcond=None)
    return solution

In [None]:
def retrieve_avg(Z0, Z1):
    """Least-squares fit of the affine map Z1 ≈ Z0 @ W + b; returns ``(W, b)``.

    For ``Z0`` of shape ``(N, D_in)`` and 2-D ``Z1`` of shape ``(N, D_out)``, ``W`` is
    ``(D_in, D_out)`` and ``b`` is ``(D_out,)``.
    """
    solution = leastSquares(addBiasColumn(Z0), Z1)
    # The final row multiplies the appended ones column, i.e. it is the bias.
    W, b = solution[:-1], solution[-1]
    return W, b

# One affine map per base-model layer (1..5), each regressed onto the fine-tuned layer-12 CLS.
_fits = [
    retrieve_avg(layer_cls, Z1_twelfth_lsr_0)
    for layer_cls in (Z0_first_lsr_0, Z0_second_lsr_0, Z0_third_lsr_0, Z0_fourth_lsr_0, Z0_fifth_lsr_0)
]
(
    (W_first_0, b_first_0),
    (W_second_0, b_second_0),
    (W_third_0, b_third_0),
    (W_fourth_0, b_fourth_0),
    (W_fifth_0, b_fifth_0),
) = _fits

### 1. Extra Per-Digit Split Sets (digits 1–9, currently commented out)

In [None]:
# train_1 = split["train"].filter(lambda example: example["label"] == 1)
# val_1 = split["test"].filter(lambda example: example["label"] == 1)
# test_1 = mnist["test"].filter(lambda example: example["label"] == 1)

# train_2 = split["train"].filter(lambda example: example["label"] == 2)
# val_2 = split["test"].filter(lambda example: example["label"] == 2)
# test_2 = mnist["test"].filter(lambda example: example["label"] == 2)

# train_3 = split["train"].filter(lambda example: example["label"] == 3)
# val_3 = split["test"].filter(lambda example: example["label"] == 3)
# test_3 = mnist["test"].filter(lambda example: example["label"] == 3)

# train_4 = split["train"].filter(lambda example: example["label"] == 4)
# val_4 = split["test"].filter(lambda example: example["label"] == 4)
# test_4 = mnist["test"].filter(lambda example: example["label"] == 4)

# train_5 = split["train"].filter(lambda example: example["label"] == 5)
# val_5 = split["test"].filter(lambda example: example["label"] == 5)
# test_5 = mnist["test"].filter(lambda example: example["label"] == 5)

# train_6 = split["train"].filter(lambda example: example["label"] == 6)
# val_6 = split["test"].filter(lambda example: example["label"] == 6)
# test_6 = mnist["test"].filter(lambda example: example["label"] == 6)

# train_7 = split["train"].filter(lambda example: example["label"] == 7)
# val_7 = split["test"].filter(lambda example: example["label"] == 7)
# test_7 = mnist["test"].filter(lambda example: example["label"] == 7)

# train_8 = split["train"].filter(lambda example: example["label"] == 8)
# val_8 = split["test"].filter(lambda example: example["label"] == 8)
# test_8 = mnist["test"].filter(lambda example: example["label"] == 8)

# train_9 = split["train"].filter(lambda example: example["label"] == 9)
# val_9 = split["test"].filter(lambda example: example["label"] == 9)
# test_9 = mnist["test"].filter(lambda example: example["label"] == 9)

# train_1.set_format(type="python", columns=["image", "label"])
# val_1.set_format(type="python", columns=["image", "label"])
# test_1.set_format(type="python", columns=["image", "label"])

# train_2.set_format(type="python", columns=["image", "label"])
# val_2.set_format(type="python", columns=["image", "label"])
# test_2.set_format(type="python", columns=["image", "label"])

# train_3.set_format(type="python", columns=["image", "label"])
# val_3.set_format(type="python", columns=["image", "label"])
# test_3.set_format(type="python", columns=["image", "label"])

# train_4.set_format(type="python", columns=["image", "label"])
# val_4.set_format(type="python", columns=["image", "label"])
# test_4.set_format(type="python", columns=["image", "label"])

# train_5.set_format(type="python", columns=["image", "label"])
# val_5.set_format(type="python", columns=["image", "label"])
# test_5.set_format(type="python", columns=["image", "label"])

# train_6.set_format(type="python", columns=["image", "label"])
# val_6.set_format(type="python", columns=["image", "label"])
# test_6.set_format(type="python", columns=["image", "label"])

# train_7.set_format(type="python", columns=["image", "label"])
# val_7.set_format(type="python", columns=["image", "label"])
# test_7.set_format(type="python", columns=["image", "label"])

# train_8.set_format(type="python", columns=["image", "label"])
# val_8.set_format(type="python", columns=["image", "label"])
# test_8.set_format(type="python", columns=["image", "label"])

# train_9.set_format(type="python", columns=["image", "label"])
# val_9.set_format(type="python", columns=["image", "label"])
# test_9.set_format(type="python", columns=["image", "label"])

# train_1 = DataLoader(train_1, batch_size=64, shuffle=True, collate_fn=clip_collate_fn)
# val_1 = DataLoader(val_1, batch_size=64, shuffle=False, collate_fn=clip_collate_fn)
# test_1 = DataLoader(test_1, batch_size=64, shuffle=False, collate_fn=clip_collate_fn)

# train_2 = DataLoader(train_2, batch_size=64, shuffle=True, collate_fn=clip_collate_fn)
# val_2 = DataLoader(val_2, batch_size=64, shuffle=False, collate_fn=clip_collate_fn)
# test_2 = DataLoader(test_2, batch_size=64, shuffle=False, collate_fn=clip_collate_fn)

# train_3 = DataLoader(train_3, batch_size=64, shuffle=True, collate_fn=clip_collate_fn)
# val_3 = DataLoader(val_3, batch_size=64, shuffle=False, collate_fn=clip_collate_fn)
# test_3 = DataLoader(test_3, batch_size=64, shuffle=False, collate_fn=clip_collate_fn)

# train_4 = DataLoader(train_4, batch_size=64, shuffle=True, collate_fn=clip_collate_fn)
# val_4 = DataLoader(val_4, batch_size=64, shuffle=False, collate_fn=clip_collate_fn)
# test_4 = DataLoader(test_4, batch_size=64, shuffle=False, collate_fn=clip_collate_fn)

# train_5 = DataLoader(train_5, batch_size=64, shuffle=True, collate_fn=clip_collate_fn)
# val_5 = DataLoader(val_5, batch_size=64, shuffle=False, collate_fn=clip_collate_fn)
# test_5 = DataLoader(test_5, batch_size=64, shuffle=False, collate_fn=clip_collate_fn)

# train_6 = DataLoader(train_6, batch_size=64, shuffle=True, collate_fn=clip_collate_fn)
# val_6 = DataLoader(val_6, batch_size=64, shuffle=False, collate_fn=clip_collate_fn)
# test_6 = DataLoader(test_6, batch_size=64, shuffle=False, collate_fn=clip_collate_fn)

# train_7 = DataLoader(train_7, batch_size=64, shuffle=True, collate_fn=clip_collate_fn)
# val_7 = DataLoader(val_7, batch_size=64, shuffle=False, collate_fn=clip_collate_fn)
# test_7 = DataLoader(test_7, batch_size=64, shuffle=False, collate_fn=clip_collate_fn)

# train_8 = DataLoader(train_8, batch_size=64, shuffle=True, collate_fn=clip_collate_fn)
# val_8 = DataLoader(val_8, batch_size=64, shuffle=False, collate_fn=clip_collate_fn)
# test_8 = DataLoader(test_8, batch_size=64, shuffle=False, collate_fn=clip_collate_fn)

# train_9 = DataLoader(train_9, batch_size=64, shuffle=True, collate_fn=clip_collate_fn)
# val_9 = DataLoader(val_9, batch_size=64, shuffle=False, collate_fn=clip_collate_fn)
# test_9 = DataLoader(test_9, batch_size=64, shuffle=False, collate_fn=clip_collate_fn)

## 7. Manipulating Base CLIP

In [None]:
class CLIPClassifier(nn.Module):
    """Linear classification head on CLIP's image encoder (512 -> num_classes for ViT-B/32).

    Redefines the identical class from section 4 so this section can run on its own.
    The attribute names ``clip`` / ``classifier`` fix the state_dict keys used to load
    the fine-tuned checkpoint.
    """

    def __init__(self, clip_model, num_classes=10):
        super().__init__()
        self.clip = clip_model
        head_in = self.clip.visual.output_dim
        self.classifier = nn.Linear(head_in, num_classes)

    def forward(self, images):
        features = self.clip.encode_image(images)
        return self.classifier(features)

In [None]:
class AugmentedCLIP(nn.Module):
    """CLIP image encoder whose CLS token can be replaced by an affine map at a chosen block.

    With ``transform_stage=-1`` (or ``None``) the visual transformer runs end to end.
    With ``transform_stage=k`` (1-indexed), only the first ``k`` residual blocks run.
    The CLS token after block ``k`` is mapped through ``cls @ W + b`` and used as the
    final CLS token; the remaining blocks are skipped.

    Args:
        clip: OpenAI CLIP model; only ``clip.visual`` is used.
        W: NumPy weight of shape ``(width, width)``; required when ``transform_stage`` is set.
        b: Optional NumPy bias of shape ``(width,)``; omitted means no bias.
        transform_stage: 1-indexed block after which the CLS is replaced, or ``-1``/``None``.
        classifier: Head applied to the final embedding; defaults to a new 2-way Linear.

    Raises:
        ValueError: if ``transform_stage`` is set without ``W`` or is outside
            ``1..len(resblocks)``.

    ``forward`` returns ``{"logits", "final_embed", "manipulated_cls"}``. ``manipulated_cls``
    is the mapped CLS when a stage is set, otherwise the transformer's final CLS. It is
    always ``[B, width]``.
    """

    def __init__(self, clip, W=None, b=None, transform_stage=None, classifier=None):
        super().__init__()
        self.clip = clip
        # Keep all extra tensors on the same device as the CLIP weights.
        model_device = self.clip.visual.conv1.weight.device
        self.transform_stage = transform_stage if transform_stage is not None else -1

        if self.transform_stage != -1:
            n_blocks = len(self.clip.visual.transformer.resblocks)
            if not 1 <= self.transform_stage <= n_blocks:
                raise ValueError(f"transform_stage must be -1 or in 1..{n_blocks}, got {self.transform_stage}")
            if W is None:
                raise ValueError("W is required when transform_stage is set")

        self.W = torch.from_numpy(W.astype(np.float32)).to(model_device) if W is not None else None
        # b is optional on its own; a missing bias means a purely linear map.
        self.b = torch.from_numpy(b.astype(np.float32)).to(model_device) if b is not None else None
        self.classifier = (
            classifier
            if classifier is not None
            else nn.Linear(self.clip.visual.output_dim, 2).to(model_device)
        )

    def forward(self, image):
        visual = self.clip.visual
        image = image.to(visual.conv1.weight.device)

        x = visual.conv1(image)                    # [B, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # [B, width, grid*grid]
        x = x.permute(0, 2, 1)                     # [B, grid*grid, width]

        # Prepend the class embedding as token 0 -> [B, 1 + grid*grid, width].
        cls_slot = visual.class_embedding.to(x.dtype) + torch.zeros(
            x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
        )
        x = torch.cat([cls_slot, x], dim=1)
        x = x + visual.positional_embedding.to(x.dtype)
        x = visual.ln_pre(x)  # normalize for stability

        x = x.permute(1, 0, 2)  # [seq, B, width]

        manipulated = None
        if self.transform_stage == -1:
            x = visual.transformer(x)
        else:
            # Run blocks 1..transform_stage only (range validated in __init__).
            for i, block in enumerate(visual.transformer.resblocks):
                x = block(x)
                if i + 1 == self.transform_stage:
                    break
            cls = x[0, :, :].to(torch.float32)  # [B, width]
            manipulated = cls @ self.W
            if self.b is not None:
                manipulated = manipulated + self.b
            # Swap the mapped CLS in as token 0; patch tokens are kept as they are.
            x = torch.cat([manipulated.unsqueeze(0).to(x.dtype), x[1:, :, :]], dim=0)

        # No squeeze: a batch of one must keep its batch dimension.
        final_cls = x[0, :, :]  # [B, width]

        x = x.permute(1, 0, 2)  # [B, seq, width]
        x = visual.ln_post(x[:, 0, :])

        final_embed = x @ visual.proj if visual.proj is not None else x
        logits = self.classifier(final_embed)

        return {
            "logits": logits,
            "final_embed": final_embed,
            "manipulated_cls": manipulated if manipulated is not None else final_cls,
        }


In [None]:
# Pre-trained reference CLIP (fp32); each model below works on its own deep copy.
refer, _ = clip.load("ViT-B/32", device=device)
refer = refer.float()

# Fine-tuned classifier restored from the checkpoint (map_location allows CPU-only loading).
f_t = CLIPClassifier(clip_model=copy.deepcopy(refer)).to(device)
f_t.load_state_dict(torch.load("best_clip_mnist_fp32.pt", map_location=device))

# Fine-tuned reference: the *fine-tuned* backbone (f_t.clip) with its trained head.
fine_tuned = AugmentedCLIP(copy.deepcopy(f_t.clip), classifier=f_t.classifier).eval()

# Control: pre-trained backbone with the fine-tuned head.
base = AugmentedCLIP(copy.deepcopy(refer), classifier=f_t.classifier).eval()

# Augmented models: pre-trained backbone truncated after block k, with its CLS mapped by (W, b).
aug_0_1 = AugmentedCLIP(copy.deepcopy(refer), W=W_first_0, b=b_first_0, transform_stage=1, classifier=f_t.classifier).eval()
aug_0_2 = AugmentedCLIP(copy.deepcopy(refer), W=W_second_0, b=b_second_0, transform_stage=2, classifier=f_t.classifier).eval()
aug_0_3 = AugmentedCLIP(copy.deepcopy(refer), W=W_third_0, b=b_third_0, transform_stage=3, classifier=f_t.classifier).eval()
aug_0_4 = AugmentedCLIP(copy.deepcopy(refer), W=W_fourth_0, b=b_fourth_0, transform_stage=4, classifier=f_t.classifier).eval()
aug_0_5 = AugmentedCLIP(copy.deepcopy(refer), W=W_fifth_0, b=b_fifth_0, transform_stage=5, classifier=f_t.classifier).eval()

In [None]:
# Evaluate on the digit-0 subset: accuracy for every model, plus the cosine similarity of
# each augmented model's final embedding / CLS token to the fine-tuned model's.
# NOTE(review): train_0 is the data W/b were fitted on, so these are in-sample numbers.
# Point eval_loader at val_0 or test_0 for a held-out estimate.
eval_loader = train_0

correct_base = 0
correct_fine_tuned = 0
total_samples = 0

# Per-sample cosine similarities for each augmentation stage (lists of floats).
sim_final_1, sim_final_2, sim_final_3, sim_final_4, sim_final_5 = [], [], [], [], []
sim_cls_1, sim_cls_2, sim_cls_3, sim_cls_4, sim_cls_5 = [], [], [], [], []


def calcPred(model, images):
    """Return the argmax class of ``model(images)["logits"]`` for each image."""
    return model(images)["logits"].argmax(dim=1)


def _batch_cosine(out_aug, out_fine, eps=1e-8):
    """Per-sample cosine similarity of ``final_embed`` and ``manipulated_cls`` between two outputs.

    Vectors are L2-normalised along the last dim (``eps`` guards against zero norms).
    Returns two tensors ``(final_sims, cls_sims)``.
    """
    def cos(a, b):
        return (F.normalize(a, dim=-1, eps=eps) * F.normalize(b, dim=-1, eps=eps)).sum(dim=-1)

    return (
        cos(out_aug["final_embed"], out_fine["final_embed"]),
        cos(out_aug["manipulated_cls"], out_fine["manipulated_cls"]),
    )


def cosineSimilarity(aug, fine, images):
    """Batch-mean cosine similarity ``(final_embed, CLS)`` between ``aug`` and ``fine`` on ``images``."""
    sim_final, sim_cls = _batch_cosine(aug(images), fine(images))
    return sim_final.mean().item(), sim_cls.mean().item()


# Each augmented model paired with the lists that collect its similarities.
aug_runs = [
    (aug_0_1, sim_final_1, sim_cls_1),
    (aug_0_2, sim_final_2, sim_cls_2),
    (aug_0_3, sim_final_3, sim_cls_3),
    (aug_0_4, sim_final_4, sim_cls_4),
    (aug_0_5, sim_final_5, sim_cls_5),
]
correct_aug = [0] * len(aug_runs)

with torch.no_grad():
    for batch in tqdm(eval_loader, desc="Evaluating 0's"):
        images = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)
        total_samples += labels.size(0)

        # One forward pass per model per batch; outputs are reused for accuracy and similarity.
        out_fine = fine_tuned(images)
        correct_fine_tuned += (out_fine["logits"].argmax(dim=1) == labels).sum().item()
        correct_base += (calcPred(base, images) == labels).sum().item()

        for idx, (aug_model, final_sims, cls_sims) in enumerate(aug_runs):
            out_aug = aug_model(images)
            correct_aug[idx] += (out_aug["logits"].argmax(dim=1) == labels).sum().item()
            sim_final, sim_cls = _batch_cosine(out_aug, out_fine)
            # Store per-sample values so a short final batch does not skew the mean.
            final_sims.extend(sim_final.reshape(-1).tolist())
            cls_sims.extend(sim_cls.reshape(-1).tolist())

if total_samples == 0:
    raise ValueError("eval_loader yielded no samples; cannot compute accuracy")

correct_1, correct_2, correct_3, correct_4, correct_5 = correct_aug
acc_1, acc_2, acc_3, acc_4, acc_5 = (c / total_samples for c in correct_aug)
acc_base = correct_base / total_samples
acc_fine_tuned = correct_fine_tuned / total_samples

print()
for i, acc in enumerate([acc_1, acc_2, acc_3, acc_4, acc_5]):
    print(f"Augmented {i+1} - Last Layer CLIP Accuracy: {acc:.4f}")
print(f"Base CLIP Accuracy: {acc_base:.4f}")
print(f"Fine-Tuned CLIP Accuracy: {acc_fine_tuned:.4f}")

for i, fin in enumerate([sim_final_1, sim_final_2, sim_final_3, sim_final_4, sim_final_5]):
    print(f"Avg cos sim final embed for aug_0_{i+1}: {np.mean(fin):.4f}")

for i, cls in enumerate([sim_cls_1, sim_cls_2, sim_cls_3, sim_cls_4, sim_cls_5]):
    print(f"Avg cos sim CLS token for aug_0_{i+1}: {np.mean(cls):.4f}")
