In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from transformers import CLIPProcessor, CLIPModel

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import numpy as np

import tqdm
import gc

## Definitions

In [None]:
class CLIPClassificator():
  """Zero-shot image classifier built on a pretrained CLIP checkpoint.

  Class names are turned into prompts of the form ``"a photo of a <label>"``,
  embedded with the CLIP text tower and compared against image embeddings
  from the CLIP vision tower. Both embeddings are L2-normalised before the
  comparison.
  """

  def __init__(self, model_name):
    """Pick the device (CUDA when available, else CPU) and load ``model_name``."""
    self._device = "cuda" if torch.cuda.is_available() else "cpu"
    self.set_model_from_name(model_name)

  def set_model_from_name(self, name:str):
    """Load the CLIP model and its processor from the checkpoint ``name``.

    The model is moved to ``self._device``; the processor is kept as returned.
    """
    self._model = CLIPModel.from_pretrained(name).to(self._device)
    self._processor = CLIPProcessor.from_pretrained(name)

  def _compute_text_embeddings(self, text_labels):
    """Return unit-norm text embeddings, one row per entry of ``text_labels``."""
    text_inputs = self._processor(
      text=[f"a photo of a {label}" for label in text_labels],
      return_tensors="pt",
      padding=True
    ).to(self._device)

    with torch.no_grad():
      text_embeddings = self._model.get_text_features(**text_inputs)
      text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
    return text_embeddings

  def _compute_image_embeddings(self, image_batch):
    """Return unit-norm image embeddings, one row per image in ``image_batch``."""
    image_inputs = self._processor(
      images=image_batch,
      return_tensors="pt",
    ).to(self._device)

    with torch.no_grad():
      image_embeddings = self._model.get_image_features(**image_inputs)
      image_embeddings = image_embeddings / image_embeddings.norm(dim=-1, keepdim=True)
    return image_embeddings

  def _compute_similarity(self, image_embeddings, text_embeddings):
    """Return image-by-label logits scaled by the model's ``logit_scale``.

    Runs under ``torch.no_grad()``: ``logit_scale`` is a model parameter, so
    without it every inference batch would build an autograd graph.
    """
    with torch.no_grad():
      return (image_embeddings @ text_embeddings.T) * self._model.logit_scale.exp()

  def compute_probs(self, model_inputs:dict, batch_size:int=32):
    """Predict a label index for every image in the dataset.

    Despite the name, this returns the arg-max class index per image, not
    probabilities.

    Args:
      model_inputs: dict with keys ``"image_dataset"`` (a sliceable,
        ``len()``-able collection of images) and ``"text_labels"`` (the class
        names; predictions are indices into this sequence).
      batch_size: number of images embedded per forward pass; must be >= 1.

    Returns:
      A list of ints, one predicted label index per image, in dataset order.

    Raises:
      ValueError: if ``batch_size`` is smaller than 1.
    """
    # A negative step would make range() empty and silently return no
    # predictions, so reject bad batch sizes up front.
    if batch_size < 1:
      raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    image_dataset = model_inputs["image_dataset"]
    text_labels = model_inputs["text_labels"]
    predictions = []

    # The label prompts are identical for every batch, so embed them once.
    text_embeddings = self._compute_text_embeddings(text_labels)

    for i in tqdm.tqdm(range(0, len(image_dataset), batch_size)):
      batch = image_dataset[i:i + batch_size]

      image_embeddings = self._compute_image_embeddings(batch)

      similarity = self._compute_similarity(image_embeddings, text_embeddings)
      predictions.extend(similarity.argmax(dim=1).tolist())
    return predictions


In [None]:
class SparseAutoEncoder(nn.Module):
    """Autoencoder with a single ReLU latent layer wider than its input.

    The encoder is a bias-free linear map from ``input_dim`` to
    ``input_dim * latent_upsample`` features followed by ReLU; the decoder is
    a linear map (with bias) back to ``input_dim``. The module moves itself
    to CUDA when available, otherwise it stays on CPU.
    """

    def __init__(self, input_dim: int, latent_upsample: int):
        """Build the encoder/decoder pair.

        Args:
            input_dim: size of the vectors to reconstruct.
            latent_upsample: integer factor by which the latent layer is
                wider than the input.
        """
        super().__init__()
        self._device = "cuda" if torch.cuda.is_available() else "cpu"
        # get_metrics_names() reads this attribute; it was never assigned, so
        # the getter always raised AttributeError. No metric names are defined
        # in this module, hence the empty default.
        # NOTE(review): populate if the training code tracks named metrics.
        self._metrics_names = []
        self._encoder = nn.Sequential(
            nn.Linear(input_dim, input_dim * latent_upsample, bias=False),
            nn.ReLU()
        )
        self._decoder = nn.Linear(input_dim * latent_upsample, input_dim)
        self.to(self._device)

    def get_metrics_names(self):
        """Return the list of metric names registered on this model."""
        return self._metrics_names

    def forward(self, input):
        """Encode ``input`` and reconstruct it.

        Returns:
            A ``(recon, latent)`` tuple: the decoder output and the
            non-negative (post-ReLU) latent activations.
        """
        latent = self._encoder(input)
        recon = self._decoder(latent)
        return recon, latent

In [None]:
class CLIPwithSAE(CLIPClassificator):
  """CLIP zero-shot classifier whose image features pass through a sparse autoencoder.

  The raw CLIP image features are replaced by their SAE reconstruction before
  L2 normalisation; text embeddings are computed exactly as in the parent
  class.
  """

  def __init__(self, model_name, sae_st_dict_path, model_dim, latent_upsample):
    """Load CLIP, build the SAE and restore its weights.

    Args:
      model_name: CLIP checkpoint name passed to the parent class.
      sae_st_dict_path: path to the saved SAE ``state_dict``.
      model_dim: input/output width of the SAE.
      latent_upsample: latent width multiplier of the SAE.
    """
    super().__init__(model_name)

    self._sae = SparseAutoEncoder(model_dim, latent_upsample)
    self._load_sae_weights(sae_st_dict_path)
    self._sae.eval()

  def _load_sae_weights(self, state_dict_path):
    """Load the SAE parameters from ``state_dict_path`` into ``self._sae``."""
    # weights_only=True limits unpickling to tensors and plain containers, so a
    # tampered checkpoint cannot execute arbitrary code when loaded.
    state_dict = torch.load(state_dict_path, map_location="cpu", weights_only=True)
    self._sae.load_state_dict(state_dict)
    print(f"Loaded SAE weights from {state_dict_path}")

  def _compute_image_embeddings(self, image_batch):
    """Return unit-norm SAE reconstructions of the CLIP image features."""
    image_inputs = self._processor(
      images=image_batch,
      return_tensors="pt",
    ).to(self._device)

    with torch.no_grad():
      image_embeddings = self._model.get_image_features(**image_inputs)
      # Keep only the reconstruction; the latent code is not needed here.
      image_embeddings, _ = self._sae(image_embeddings)
      image_embeddings = image_embeddings / image_embeddings.norm(dim=-1, keepdim=True)

    return image_embeddings

In [None]:
def create_test_sae(clip_name:str, sae_params:dict):
  """Build the two classifiers once and return an evaluation function for them.

  Args:
    clip_name: CLIP checkpoint name used for both classifiers.
    sae_params: dict with keys ``"path_to_dict"``, ``"input_size"`` and
      ``"latent_upsample"`` describing the sparse autoencoder.

  Returns:
    A function ``test_sae(dataset, batch_size)`` that runs both classifiers
    on ``dataset`` and prints their metrics.
  """
  # Plain CLIP first, then CLIP with SAE-reconstructed image features.
  models_dict = {
      "Clip": CLIPClassificator(clip_name),
      "Clip_Sae": CLIPwithSAE(
          clip_name,
          sae_params["path_to_dict"],
          sae_params["input_size"],
          sae_params["latent_upsample"],
      ),
  }

  def report(labels, predictions):
    """Print accuracy and macro-averaged precision, recall and F1."""
    macro_kwargs = dict(average='macro', zero_division=0)
    accuracy = accuracy_score(labels, predictions)
    precision = precision_score(labels, predictions, **macro_kwargs)
    recall = recall_score(labels, predictions, **macro_kwargs)
    f1 = f1_score(labels, predictions, **macro_kwargs)

    print(f"\naccuracy - {accuracy:.2%}\nprecision - {precision:.2%}\nrecall - {recall:.2%}\nf1 - {f1:.2%}\n")

  def release_memory():
    """Run the garbage collector and, when CUDA is available, empty its cache."""
    gc.collect()
    if not torch.cuda.is_available():
      return
    torch.cuda.empty_cache()

  def test_sae(dataset:dict, batch_size:int):
    """Classify ``dataset["images"]`` with every model and print the metrics.

    ``dataset`` must provide ``"images"``, ``"true_values"`` and
    ``"text_labels"``.
    """
    expected = dataset["true_values"]
    # Both models receive exactly the same inputs.
    model_inputs = {
      "image_dataset": dataset["images"],
      "text_labels": dataset["text_labels"],
    }

    print(f"\n=== Zero-shot classification results by model ===\n")
    for label, classifier in models_dict.items():
      print(f"\n{label}:")

      predicted = classifier.compute_probs(model_inputs, batch_size=batch_size)
      release_memory()
      report(expected, predicted)

  return test_sae

## Actions

In [None]:
# CLIP checkpoint shared by the baseline and the SAE-augmented classifier.
clip_name = "openai/clip-vit-large-patch14"

# Sparse autoencoder settings: where its weights live and how it is shaped.
sae_params = {
    # change if running in google colab
    "path_to_dict": "../st_dicts/sae_for_clip.pth",
    "input_size": 768,
    "latent_upsample": 128,
}

# Loads both models up front; the returned function is reused for every dataset.
test_sae = create_test_sae(clip_name=clip_name, sae_params=sae_params)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Loaded SAE weights from sae_for_clip.pth


In [None]:
# First 8192 images of the CIFAR-10 test split.
cifar10 = load_dataset("uoft-cs/cifar10", split="test[:8192]")

# Predictions are indices into this list, so entry i must name integer label i.
cifar10_labels = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]

test_sae(
    dataset={
        "images": cifar10["img"],
        "true_values": cifar10["label"],
        "text_labels": cifar10_labels,
    },
    batch_size=128,
)


=== Zero-shot classification results by model ===


Clip:


100%|██████████| 64/64 [06:29<00:00,  6.09s/it]



accuracy - 95.35%
precision - 95.47%
recall - 95.35%
f1 - 95.36%


Clip_Sae:


100%|██████████| 64/64 [06:27<00:00,  6.06s/it]



accuracy - 95.39%
precision - 95.50%
recall - 95.38%
f1 - 95.39%



In [None]:
# First 8192 images of the CIFAR-100 test split.
cifar100 = load_dataset("uoft-cs/cifar100", split="test[:8192]")

# Fine-grained class names; predictions are indices into this list, so entry i
# must name integer fine label i. Entry 26 used to be the truncated "cra",
# which produced the prompt "a photo of a cra"; the class is "crab".
# NOTE(review): names with underscores (e.g. "aquarium_fish") go into the
# prompt verbatim; consider replacing "_" with a space when building prompts.
cifar100_labels = [
    "apple", "aquarium_fish", "baby", "bear", "beaver", "bed", "bee", "beetle",
    "bicycle", "bottle", "bowl", "boy", "bridge", "bus", "butterfly", "camel",
    "can", "castle", "caterpillar", "cattle", "chair", "chimpanzee", "clock",
    "cloud", "cockroach", "couch", "crab", "crocodile", "cup", "dinosaur",
    "dolphin", "elephant", "flatfish", "forest", "fox", "girl", "hamster",
    "house", "kangaroo", "keyboard", "lamp", "lawn_mower", "leopard", "lion",
    "lizard", "lobster", "man", "maple_tree", "motorcycle", "mountain", "mouse",
    "mushroom", "oak_tree", "orange", "orchid", "otter", "palm_tree", "pear",
    "pickup_truck", "pine_tree", "plain", "plate", "poppy", "porcupine", "possum",
    "rabbit", "raccoon", "ray", "road", "rocket", "rose", "sea", "seal", "shark",
    "shrew", "skunk", "skyscraper", "snail", "snake", "spider", "squirrel",
    "streetcar", "sunflower", "sweet_pepper", "table", "tank", "telephone",
    "television", "tiger", "tractor", "train", "trout", "tulip", "turtle",
    "wardrobe", "whale", "willow_tree", "wolf", "woman", "worm"
]
# Guard against accidentally dropped or duplicated entries shifting the indices.
assert len(cifar100_labels) == 100, "CIFAR-100 needs exactly 100 class names"

test_sae(
    {
        "images": cifar100["img"],
        "true_values": cifar100["fine_label"],
        "text_labels": cifar100_labels,
    },
    batch_size=128
)

README.md: 0.00B [00:00, ?B/s]

cifar100/train-00000-of-00001.parquet:   0%|          | 0.00/119M [00:00<?, ?B/s]

cifar100/test-00000-of-00001.parquet:   0%|          | 0.00/23.8M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/50000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/10000 [00:00<?, ? examples/s]


=== Zero-shot classification results by model ===


Clip:


100%|██████████| 64/64 [06:22<00:00,  5.97s/it]



accuracy - 72.46%
precision - 78.87%
recall - 72.45%
f1 - 73.22%


Clip_Sae:


100%|██████████| 64/64 [06:22<00:00,  5.97s/it]



accuracy - 72.49%
precision - 78.28%
recall - 72.48%
f1 - 73.18%

