## Installs

In [None]:
# Shell-escape install. The next cell repeats the same installs with the %pip magic,
# which targets the running kernel's environment; running one of the two cells is enough.
!pip install -U torch torchvision
!pip install transformers datasets tqdm ipywidgets numpy pandas hf_transfer

In [None]:
# %pip installs into the environment of the running kernel (preferred over `!pip`).
# Same packages as the previous cell; running one of the two is enough.
# ipywidgets is needed for the tqdm.notebook progress bars used below.
# NOTE(review): hf_transfer is only used by huggingface_hub when the environment
# variable HF_HUB_ENABLE_HF_TRANSFER=1 is set, which this notebook does not do.
%pip install -U torch torchvision
%pip install transformers datasets tqdm ipywidgets numpy pandas hf_transfer

In [None]:
import torch
from torch.utils.data import DataLoader
from transformers import CLIPModel, CLIPVisionModel
from datasets import load_from_disk, concatenate_datasets
from tqdm.notebook import tqdm
import copy
import os
import numpy as np
import pandas as pd
from collections import defaultdict

## Instantiate

In [None]:
# Prefer CUDA when it is available; otherwise fall back to CPU so the notebook still runs
# (slowly) without a GPU. On Apple silicon, torch.device("mps") is another option.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Pre-trained base CLIP; every task-matrix variant below is built on top of this model.
reference = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", dtype=torch.float32)
reference.to(device)

# Evaluation datasets and the matching fine-tuned checkpoints. The two lists are
# index-aligned: fine_tuned_name[i] is the checkpoint for dataset_name[i].
dataset_name = ["DTD", "EuroSAT", "GTSRB", "MNIST", "RESISC45", "Stanford_Cars", "SUN397", "SVHN"]
fine_tuned_name = ["tanganke/clip-vit-base-patch32_dtd", "tanganke/clip-vit-base-patch32_eurosat", "tanganke/clip-vit-base-patch32_gtsrb", "tanganke/clip-vit-base-patch32_mnist", "tanganke/clip-vit-base-patch32_resisc45", "tanganke/clip-vit-base-patch32_stanford-cars", "tanganke/clip-vit-base-patch32_sun397", "tanganke/clip-vit-base-patch32_svhn"]
if len(dataset_name) != len(fine_tuned_name):
    raise ValueError("dataset_name and fine_tuned_name must have the same length (index-aligned)")

# False -> use Tanganke's Hub checkpoints; True -> local checkpoints under fine_tuned/<dataset>/.
custom_models = False
# NOTE(review): `type` is not read anywhere in this notebook, its original comment was cut
# off ("How many"), and the name shadows the built-in `type`. Kept for compatibility;
# confirm its intent and rename it.
type = 1
# Number of fine-tuned checkpoints evaluated per dataset.
num_models = 5 if custom_models else 1
# Vision-encoder layer indices at which a task matrix is fitted. This is derived from the
# loaded model instead of hard-coding 12, so another CLIP size also works.
layers = list(range(reference.config.vision_config.num_hidden_layers))

In [None]:
class Task_Matrix(torch.nn.Module):
    """CLIP classifier that replaces the tail of the vision encoder with a linear map.

    The vision encoder runs up to and including layer ``transform_stage``. The
    CLS token at that layer is multiplied on the right by ``W`` (``cls @ W``).
    The result then goes straight into the model's ``post_layernorm`` and
    ``visual_projection``, and the remaining encoder layers are skipped. If
    ``W`` is None, the CLS token at that layer is used unchanged. When
    ``transform_stage`` matches no layer index (the default ``-1``), the full
    encoder runs and ``W`` is not used.

    Args:
        model: CLIPModel-like module. It must provide ``vision_model``,
            ``visual_projection`` and ``logit_scale``.
        W: optional matrix (NumPy array or tensor). The expected shape is
            (hidden, hidden) so the result fits ``post_layernorm``; confirm for
            your model. It is stored as float32.
        transform_stage: index of the encoder layer after which ``W`` is applied
            and the encoder stops.
    """

    def __init__(self, model, W=None, transform_stage=-1):
        super().__init__()
        self.model = model
        # Registered as a non-persistent buffer so ``.to(...)`` / ``.cuda()`` move W together
        # with the model, without adding a key to ``state_dict()``.
        w = None if W is None else torch.as_tensor(W, dtype=torch.float32).to(device)
        self.register_buffer("W", w, persistent=False)
        self.transform_stage = transform_stage

    def vision_manipulation(self, pixel_values):
        """Return the post-layernormed CLS embedding, with ``W`` applied at ``transform_stage``."""
        vision = self.model.vision_model
        hidden_states = vision.embeddings(pixel_values)
        hidden_states = vision.pre_layrnorm(hidden_states)

        cls_token = None
        # Mirrors CLIPEncoder's layer loop, but can stop early at transform_stage.
        for i, encoder_layer in enumerate(vision.encoder.layers):
            encoder_output = encoder_layer(hidden_states, attention_mask=None, causal_attention_mask=None, output_attentions=False)
            hidden_states = encoder_output[0]
            if i == self.transform_stage:
                cls_token = hidden_states[:, 0, :]
                if self.W is not None:
                    # Cast W to the activations' dtype so half-precision models also work.
                    cls_token = cls_token @ self.W.to(cls_token.dtype)
                break

        if cls_token is None:
            # No early exit: use the final layer's CLS token.
            cls_token = hidden_states[:, 0, :]
        # Use the CLS token directly rather than writing it back into the layer output in place.
        return vision.post_layernorm(cls_token)

    def forward(self, pixel_values, all_label_embeds):
        """Predict one class index per image.

        Image embeddings are L2-normalised and scored by dot product against the
        rows of ``all_label_embeds``. This is cosine similarity if those rows are
        unit-normalised; presumably they are, so confirm upstream.

        Returns:
            Tensor of argmax indices into the rows of ``all_label_embeds``.
        """
        image_embeds = self.model.visual_projection(self.vision_manipulation(pixel_values))
        image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)

        logits = image_embeds @ all_label_embeds.T
        # A positive scale does not change the argmax; kept so logits match CLIP's scale.
        logits = logits * self.model.logit_scale.exp()
        return logits.argmax(dim=1)

In [None]:
class EmbeddingHooks(torch.nn.Module):
    """Run a CLIP vision encoder and record the CLS token after every encoder layer.

    No PyTorch hooks are registered, despite the class name. The encoder loop is
    re-implemented here so that each layer's output can be captured.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        # CLS activations from the most recent forward call, one entry per encoder layer.
        self.CLS = []

    def forward(self, pixel_values):
        """Return a list holding ``hidden_states[:, 0, :]`` after each encoder layer.

        The same list object is also stored on ``self.CLS``.
        """
        collected = []
        self.CLS = collected
        vision = self.model.vision_model
        states = vision.pre_layrnorm(vision.embeddings(pixel_values))

        for layer in vision.encoder.layers:
            states = layer(
                states,
                attention_mask=None,
                causal_attention_mask=None,
                output_attentions=False,
            )[0]
            collected.append(states[:, 0, :])

        return collected

## Helper Functions

In [None]:
def task_matrix_solver(base_hook, fine_tuned_hook, train_loader):
    """Fit a least-squares map from each base layer's CLS token to the fine-tuned last-layer CLS.

    Each batch is passed through both hooks. The base CLS token at every layer in
    ``layers`` and the fine-tuned model's last-layer CLS token are collected on
    CPU as float32. Then, for each layer ``l``, ``W[l]`` solves
    ``min_W ||X_l @ W - Z||_F``. Here ``X_l`` stacks the base CLS rows at layer
    ``l`` and ``Z`` stacks the fine-tuned last-layer rows.

    Args:
        base_hook: ``EmbeddingHooks`` wrapping the base model.
        fine_tuned_hook: ``EmbeddingHooks`` wrapping the fine-tuned model.
        train_loader: iterable of batches that contain a ``"pixel_values"`` tensor.

    Returns:
        dict mapping each layer index to a NumPy array of shape (d_base, d_fine_tuned).

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    base_embeddings = {l: [] for l in layers}
    fine_tuned_embeddings = []

    with torch.no_grad():
        for batch in tqdm(train_loader, total=len(train_loader)):
            pixel_values = batch["pixel_values"].to(device, non_blocking=True)

            base_cls = base_hook(pixel_values)
            fine_tuned_cls = fine_tuned_hook(pixel_values)[-1]  # last layer only

            for l in base_embeddings:
                base_embeddings[l].append(base_cls[l].float().cpu())
            fine_tuned_embeddings.append(fine_tuned_cls.float().cpu())

    if not fine_tuned_embeddings:
        raise ValueError("train_loader yielded no batches; cannot fit task matrices")

    # The chunks are already on CPU, so no further .cpu() is needed.
    Z = torch.cat(fine_tuned_embeddings).numpy()
    del fine_tuned_embeddings

    W = {}
    for l in list(base_embeddings):
        # pop() releases this layer's activations once its fit is done, lowering peak memory.
        X = torch.cat(base_embeddings.pop(l)).numpy()
        W[l] = np.linalg.lstsq(X, Z, rcond=None)[0]

    # Note: the caller still owns fine_tuned_hook; freeing it is the caller's job.
    return W

In [None]:
def augment_model(W):
    """Build one eval-mode ``Task_Matrix`` classifier per layer in ``layers``.

    ``Task_Matrix`` only reads its model's weights, and each instance keeps its
    own ``W``. So all variants share a single copy of ``reference`` instead of
    holding one full deep copy per layer.

    Args:
        W: mapping from layer index to the matrix fitted for that layer, as
            returned by ``task_matrix_solver``.

    Returns:
        dict mapping layer index to a ``Task_Matrix`` on ``device``, in ``layers`` order.
    """
    # One private copy isolates the variants from later changes to ``reference``.
    # Sharing that copy across layers avoids len(layers) full-model copies on the GPU.
    shared_base = copy.deepcopy(reference)
    return {
        l: Task_Matrix(model=shared_base, W=W[l], transform_stage=l).eval().to(device)
        for l in layers
    }

In [None]:
def evaluate(augmented_models, test_loader, all_label_embeds):
    """Compute the top-1 accuracy of every model in ``augmented_models`` on ``test_loader``.

    Args:
        augmented_models: mapping from key to a model. Calling
            ``model(pixel_values=..., all_label_embeds=...)`` must return predicted
            label indices.
        test_loader: iterable of batches that contain ``"pixel_values"`` and ``"label"``.
        all_label_embeds: class embeddings passed through to each model unchanged.

    Returns:
        dict mapping each key of ``augmented_models`` to its accuracy (correct / total).

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    # Key the counters by the models actually evaluated, not by the global ``layers``.
    correct = {key: 0 for key in augmented_models}
    total = 0

    with torch.no_grad():
        for batch in tqdm(test_loader, total=len(test_loader)):
            pixel_values = batch["pixel_values"].to(device, non_blocking=True)
            labels = batch["label"].to(device, non_blocking=True)

            for key, model in augmented_models.items():
                pred = model(pixel_values=pixel_values, all_label_embeds=all_label_embeds)
                correct[key] += (pred == labels).sum().item()

            total += len(labels)

    if total == 0:
        raise ValueError("test_loader yielded no samples; accuracy is undefined")

    return {key: hits / total for key, hits in correct.items()}

## Augmentation Loop

In [None]:
BATCH_SIZE = 32
NUM_WORKERS = 12

# The base model is only read, never modified, so a single hook over it is shared by
# every dataset instead of deep-copying ``reference`` once per dataset.
base_hook = EmbeddingHooks(model=reference)
base_hook.eval().to(device)

for i, name in enumerate(dataset_name):
    # Only pixel_values are read when fitting the task matrices, so the text columns are
    # neither loaded nor collated. (The 'label' column could be used to subsample if needed.)
    train_dataset = load_from_disk(f"data/{name}/train")
    train_dataset.set_format(type="torch", columns=["pixel_values"])
    if not custom_models:
        # Hub checkpoints: fit on train + val combined.
        val_dataset = load_from_disk(f"data/{name}/val")
        val_dataset.set_format(type="torch", columns=["pixel_values"])
        train_dataset = concatenate_datasets([train_dataset, val_dataset])
    num_train_img = len(train_dataset)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True, persistent_workers=True)

    test_dataset = load_from_disk(f"data/{name}/test")
    test_dataset.set_format(type="torch", columns=["pixel_values", "label"])
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True, persistent_workers=True)

    # map_location lets a tensor saved on one device load on another.
    all_label_embeds = torch.load(f"data/{name}/all_label_embeds.pt", map_location=device).to(device)

    print(name)
    for model_idx in range(num_models):
        if custom_models:
            fine_tuned = copy.deepcopy(reference)
            fine_tuned.load_state_dict(torch.load(f"fine_tuned/{name}/CLIP_{name}_{model_idx}.pt", map_location=device))
        else:
            # The Hub checkpoint is loaded as a vision-only model; graft its vision tower
            # onto a copy of the base CLIP so the text side and projections are the base ones.
            vision_model = CLIPVisionModel.from_pretrained(fine_tuned_name[i])
            fine_tuned = copy.deepcopy(reference)
            fine_tuned.vision_model.load_state_dict(vision_model.vision_model.state_dict())
            del vision_model

        # ``fine_tuned`` is not used again, so wrap it directly instead of deep-copying it.
        fine_tuned_hook = EmbeddingHooks(model=fine_tuned)
        fine_tuned_hook.eval().to(device)

        W = task_matrix_solver(base_hook=base_hook, fine_tuned_hook=fine_tuned_hook, train_loader=train_loader)
        task_matrix_models = augment_model(W)
        task_matrix_acc = evaluate(augmented_models=task_matrix_models, test_loader=test_loader, all_label_embeds=all_label_embeds)

        express = f"Model {model_idx}" if custom_models else "Tanganke's Model"
        print(f"Task Matrix Accuracy | {express}")
        for layer, acc in task_matrix_acc.items():
            print(f"\t Layer {layer}: {acc}")
        print("\n")

        # makedirs also creates results/{name}; the results are one JSON record {layer: accuracy}.
        results_path = f"results/{name}/custom/{model_idx}" if custom_models else f"results/{name}/tanganke"
        os.makedirs(results_path, exist_ok=True)
        df = pd.DataFrame(task_matrix_acc, index=[0])
        df.to_json(f"{results_path}/{num_train_img}_accuracy.json", orient="records", indent=2)

        # Release this checkpoint's GPU models before loading the next one; otherwise two
        # generations of models are alive at once during the next iteration.
        del fine_tuned, fine_tuned_hook, task_matrix_models
        if torch.cuda.is_available():
            torch.cuda.empty_cache()