## Install

In [None]:
# Shell-escape install (`!`): runs whichever `pip` is first on PATH, which may
# not belong to the interpreter backing this kernel. The following cell does the
# same installs via the `%pip` magic -- run only one of these two cells.
!pip install -U torch torchvision 
!pip install transformers datasets tqdm ipywidgets hf_transfer

In [None]:
# `%pip` installs into the environment of the running kernel (the preferred form
# inside notebooks). Same package set as the preceding `!pip` cell -- run only one.
# Restart the kernel after upgrading torch so the new version is actually imported.
# NOTE(review): huggingface_hub only uses hf_transfer when the environment variable
# HF_HUB_ENABLE_HF_TRANSFER=1 is set; none of the cells shown here set it.
%pip install -U torch torchvision 
%pip install transformers datasets tqdm ipywidgets hf_transfer

In [None]:
import torch
from torch.utils.data import DataLoader
from transformers import CLIPModel, CLIPVisionModel
from datasets import load_from_disk
from tqdm.notebook import tqdm
import copy
import os
import pandas as pd
import numpy as np

## Instantiate

In [None]:
# Pick the best available accelerator instead of hard-coding "cuda":
# CUDA first, then Apple-silicon MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Pre-trained CLIP ViT-B/32. It is the zero-shot baseline and also the template
# that fine-tuned weights are loaded into (via copy.deepcopy in the evaluation cell).
# NOTE(review): `dtype=` is the newer transformers spelling; older releases use `torch_dtype=`.
reference = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", dtype=torch.float32)
reference.to(device)

# Parallel lists: dataset_name[i] is evaluated with the checkpoint fine_tuned_name[i].
dataset_name = ["DTD", "EuroSAT", "GTSRB", "MNIST", "RESISC45", "Stanford_Cars", "SUN397", "SVHN"]
fine_tuned_name = [
    "tanganke/clip-vit-base-patch32_dtd",
    "tanganke/clip-vit-base-patch32_eurosat",
    "tanganke/clip-vit-base-patch32_gtsrb",
    "tanganke/clip-vit-base-patch32_mnist",
    "tanganke/clip-vit-base-patch32_resisc45",
    "tanganke/clip-vit-base-patch32_stanford-cars",
    "tanganke/clip-vit-base-patch32_sun397",
    "tanganke/clip-vit-base-patch32_svhn",
]
# Fail fast instead of mis-pairing datasets and checkpoints later.
if len(dataset_name) != len(fine_tuned_name):
    raise ValueError("dataset_name and fine_tuned_name must have the same length")

# custom_models=True evaluates 5 local checkpoints per dataset
# (fine_tuned/<name>/CLIP_<name>_<j>.pt); otherwise the single hub checkpoint above.
custom_models = False
num_models = 5 if custom_models else 1

In [None]:
class Task_Matrix(torch.nn.Module):
    """Wrap a CLIP model and optionally apply a linear map to the CLS token mid-encoder.

    The vision tower is run layer by layer. Right after encoder layer
    ``transform_stage`` the CLS token (sequence position 0) is multiplied on the
    right by ``W`` (when ``W`` is given) and the remaining encoder layers are
    skipped. The resulting CLS token goes through ``post_layernorm`` as the
    pooled output.

    Args:
        model: CLIP model exposing ``vision_model``, ``visual_projection`` and
            ``logit_scale`` (as used below).
        W: Optional matrix applied as ``cls @ W``. Accepts a NumPy array, a tensor
            or a nested sequence. It is copied to float32 and stored as a
            non-persistent buffer on the device of ``model``'s first parameter
            (CPU if it has none), so it follows later ``.to()`` calls and is not
            written into ``state_dict()``.
        transform_stage: Index of the encoder layer after which ``W`` is applied
            and the encoder stops. Layer indices start at 0, so any negative value
            (including the default ``-1``) never matches. In that case all layers
            run and ``W`` is never applied.
    """

    def __init__(self, model, W=None, transform_stage=-1):
        super().__init__()
        self.model = model
        if W is not None:
            first_param = next(model.parameters(), None)
            target = first_param.device if first_param is not None else torch.device("cpu")
            # detach().clone(): W is a constant, and we never alias the caller's buffer.
            W = torch.as_tensor(W, dtype=torch.float32, device=target).detach().clone()
        # A buffer (even when None) moves with .to()/.cuda(); persistent=False keeps
        # state_dict() identical to that of the wrapped model.
        self.register_buffer("W", W, persistent=False)
        self.transform_stage = transform_stage

    def vision_manipulation(self, pixel_values):
        """Return the post-layernorm CLS embedding, optionally transformed by ``W``.

        Args:
            pixel_values: Input accepted by ``model.vision_model.embeddings``.

        Returns:
            The CLS token after ``post_layernorm``. It comes either from the last
            encoder layer or from layer ``transform_stage`` (multiplied by ``W``
            when ``W`` is set).
        """
        vision = self.model.vision_model
        hidden_states = vision.pre_layrnorm(vision.embeddings(pixel_values))

        # Mirrors CLIPEncoder's layer loop so we can intervene after a chosen layer.
        for i, encoder_layer in enumerate(vision.encoder.layers):
            hidden_states = encoder_layer(
                hidden_states,
                attention_mask=None,
                causal_attention_mask=None,
                output_attentions=False,
            )[0]
            if i == self.transform_stage:
                if self.W is not None:
                    cls = hidden_states[:, 0, :] @ self.W
                    # Rebuild the sequence out of place instead of writing into the
                    # layer output, so no tensor that may be aliased or saved for
                    # autograd is mutated.
                    hidden_states = torch.cat((cls.unsqueeze(1), hidden_states[:, 1:, :]), dim=1)
                break

        pooled_output = hidden_states[:, 0, :]
        return vision.post_layernorm(pooled_output)

    def forward(self, pixel_values, all_label_embeds):
        """Predict a class index per image by cosine similarity to label embeddings.

        Args:
            pixel_values: Image batch for ``vision_manipulation``.
            all_label_embeds: Label embeddings, one row per class. Presumably
                (num_classes, projection_dim) and already L2-normalized; this is
                not checked here.

        Returns:
            Tensor of argmax class indices, one per image.
        """
        image_embeds = self.model.visual_projection(self.vision_manipulation(pixel_values))
        # Out-of-place normalisation: norm() saves its input for backward.
        image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)

        # Scaling by exp(logit_scale) > 0 does not change the argmax; kept so the
        # logits match CLIP's.
        logits = (image_embeds @ all_label_embeds.T) * self.model.logit_scale.exp()
        return logits.argmax(dim=1)

## Evaluate

In [None]:
# Zero-shot baseline: the untouched pre-trained CLIP (no W, all encoder layers).
base = Task_Matrix(model=copy.deepcopy(reference))
base.eval().to(device)

# Per-dataset results are written to fine_tuned_acc/<custom|tanganke>/<dataset>.json.
# The output folder does not depend on the dataset, so create it once up front.
split_folder = "custom" if custom_models else "tanganke"
full_save = f"fine_tuned_acc/{split_folder}"
os.makedirs(full_save, exist_ok=True)

for i, name in enumerate(dataset_name):
    fine_tuned_acc = {j: 0 for j in range(num_models)}
    base_acc = 0
    total = 0

    # Build each fine-tuned model on top of a fresh copy of `reference`.
    fine_tuned_models = {}
    for j in range(num_models):
        fine_tuned = copy.deepcopy(reference)
        if custom_models:
            # map_location lets checkpoints saved on another device load here.
            state_dict = torch.load(f"fine_tuned/{name}/CLIP_{name}_{j}.pt", map_location=device)
            fine_tuned.load_state_dict(state_dict)
        else:
            # Hub checkpoints are loaded as CLIPVisionModel. Only the vision tower
            # is replaced; the text tower and visual_projection stay from `reference`.
            vision_model = CLIPVisionModel.from_pretrained(f"{fine_tuned_name[i]}")
            fine_tuned.vision_model.load_state_dict(vision_model.vision_model.state_dict())
        fine_tuned = Task_Matrix(model=fine_tuned)
        fine_tuned.eval().to(device)
        fine_tuned_models[j] = fine_tuned

    test_dataset = load_from_disk(f"data/{name}/test")
    test_dataset.set_format(type="torch", columns=["pixel_values", "label"])
    # Single pass per loader, so no persistent_workers: the workers exit when
    # iteration ends instead of lingering in the kernel.
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=12, pin_memory=True)

    all_label_embeds = torch.load(f"data/{name}/all_label_embeds.pt", map_location=device)

    test_progress = tqdm(test_loader, total=len(test_loader))
    with torch.no_grad():
        for batch in test_progress:
            pixel_values = batch["pixel_values"].to(device, non_blocking=True)
            labels = batch["label"].to(device, non_blocking=True)

            base_preds = base(pixel_values=pixel_values, all_label_embeds=all_label_embeds)
            base_acc += (base_preds == labels).sum().item()

            for j in range(num_models):
                fine_tuned_preds = fine_tuned_models[j](pixel_values=pixel_values, all_label_embeds=all_label_embeds)
                fine_tuned_acc[j] += (fine_tuned_preds == labels).sum().item()

            total += len(labels)

    if total == 0:
        raise ValueError(f"Test split for {name} is empty; cannot compute accuracy")

    data = {}

    base_acc /= total
    data["Base_Acc"] = [base_acc]
    print(f"{name}:")
    print(f"\tBase Accuracy: {base_acc}")
    for j in range(num_models):
        fine_tuned_acc[j] /= total
        print(f"\tFine-Tuned Accuracy {j}: {fine_tuned_acc[j]}")
        data[f"Fine-Tuned_Acc_{j}"] = [fine_tuned_acc[j]]
    print("\n")

    df = pd.DataFrame(data)
    # orient="records" never writes the index. The former index=False was
    # redundant and raises ValueError on pandas < 2.0.
    df.to_json(f"{full_save}/{name}.json", orient="records")