In [1]:
import vit_prisma
from vit_prisma.utils.data_utils.imagenet.imagenet_dict import IMAGENET_DICT
from vit_prisma.utils import prisma_utils

import numpy as np
import torch
from fancy_einsum import einsum
from collections import defaultdict

import plotly.graph_objs as go
import plotly.express as px

import matplotlib.colors as mcolors

from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from IPython.display import display, HTML

import pandas as pd

from tqdm import tqdm

In [2]:
# Helper function (ignore)
def plot_image(image):
  plt.figure()
  plt.axis('off')
  plt.imshow(image.permute(1,2,0))

class ConvertTo3Channels:
    def __call__(self, img):
        if img.mode != 'RGB':
            return img.convert('RGB')
        return img

transform = transforms.Compose([
    ConvertTo3Channels(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

In [3]:
from vit_prisma.models.base_vit import HookedViT

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
from transformers import CLIPModel, CLIPProcessor

In [5]:
model_name = "openai/clip-vit-large-patch14"

In [6]:
clip_model = CLIPModel.from_pretrained(model_name)

In [14]:
clip_model.config.vision_config

CLIPVisionConfig {
  "attention_dropout": 0.0,
  "dropout": 0.0,
  "hidden_act": "quick_gelu",
  "hidden_size": 1024,
  "image_size": 224,
  "initializer_factor": 1.0,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "layer_norm_eps": 1e-05,
  "model_type": "clip_vision_model",
  "num_attention_heads": 16,
  "num_channels": 3,
  "num_hidden_layers": 24,
  "patch_size": 14,
  "projection_dim": 768,
  "torch_dtype": "float32",
  "transformers_version": "4.50.0"
}

In [36]:
clip_model.vision_model

CLIPVisionTransformer(
  (embeddings): CLIPVisionEmbeddings(
    (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
    (position_embedding): Embedding(257, 1024)
  )
  (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  (encoder): CLIPEncoder(
    (layers): ModuleList(
      (0-23): 24 x CLIPEncoderLayer(
        (self_attn): CLIPSdpaAttention(
          (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
        )
        (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (mlp): CLIPMLP(
          (activation_fn): QuickGELUActivation()
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (fc2): Linear(in_features=4096, out_features=1024, 

In [17]:
model_hooked = HookedViT.from_pretrained(model_name, is_clip=True, is_timm=False)

Official model name openai/clip-vit-large-patch14
LayerNorm folded.
Centered weights writing to residual stream
Loaded pretrained model openai/clip-vit-large-patch14 into HookedTransformer


In [8]:
clip_model.vision_model.encoder.layers[-1].mlp

CLIPMLP(
  (activation_fn): QuickGELUActivation()
  (fc1): Linear(in_features=1024, out_features=4096, bias=True)
  (fc2): Linear(in_features=4096, out_features=1024, bias=True)
)

In [9]:
df = pd.read_csv("generated_images/data.csv")
df

Unnamed: 0,num_shapes,shape_0,color_0,position_0,filename,shape_1,color_1,position_1,shape_2,color_2,position_2,shape_3,color_3,position_3,shape_4,color_4,position_4
0,1,pentagon,red,BL,generated_images/image_1_shapes_0.png,,,,,,,,,,,,
1,1,pentagon,red,BR,generated_images/image_1_shapes_1.png,,,,,,,,,,,,
2,1,square,red,TR,generated_images/image_1_shapes_2.png,,,,,,,,,,,,
3,1,hexagon,black,M,generated_images/image_1_shapes_3.png,,,,,,,,,,,,
4,1,square,blue,M,generated_images/image_1_shapes_4.png,,,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
495,5,pentagon,blue,BR,generated_images/image_5_shapes_95.png,triangle,black,BL,hexagon,yellow,M,square,green,TL,circle,pink,TR
496,5,pentagon,pink,BR,generated_images/image_5_shapes_96.png,hexagon,yellow,BL,circle,red,TL,square,blue,M,triangle,green,TR
497,5,triangle,blue,M,generated_images/image_5_shapes_97.png,hexagon,pink,BR,circle,green,TL,pentagon,red,TR,square,yellow,BL
498,5,square,yellow,BL,generated_images/image_5_shapes_98.png,hexagon,red,BR,triangle,pink,TL,circle,blue,M,pentagon,black,TR


In [10]:
image_neuron_activations = []
for i, row in tqdm(df.iterrows(), total=len(df)):
    image_path = row.filename

    image = Image.open(image_path)
    image = transform(image)
    image = image.unsqueeze(0)
    # image = torch.cat([image, image], dim=0)

    outputs, cache = model_hooked.run_with_cache(image)

    # print(outputs.shape)

    # print(cache['blocks.12.mlp.hook_post'].shape)
    # print(cache['blocks.12.hook_mlp_out'].shape)
    
    keys = [k for k in cache.keys() if 'hook_mlp_out' in k]
    # print(keys)
    # print(keys[0], cache[keys[0]].shape)
    values = [cache[k] for k in keys]
    # print(values[0].shape)
    values = torch.cat(values)
    # print(values.shape)
    mean_activations = torch.mean(values, axis=1)
    max_activations, _ = torch.max(values, axis=1)

    # print(mean_activations.shape)

    layer_activations = {
        "filename": image_path
    }
    num_layers = len(keys)
    num_neurons = mean_activations.shape[1]
    # print(num_layers, num_neurons)
    for layer_num in range(num_layers):
        for neuron_num in range(num_neurons):
            layer_activations[f"mean_{layer_num}_{neuron_num}"]= mean_activations[layer_num, neuron_num].item()
            layer_activations[f"max_{layer_num}_{neuron_num}"]= max_activations[layer_num, neuron_num].item()
    image_neuron_activations.append(layer_activations)
    # print(len(layer_activations.keys()))
    # break

100%|██████████| 500/500 [09:44<00:00,  1.17s/it]


In [11]:
df = pd.DataFrame(image_neuron_activations)
df

Unnamed: 0,filename,mean_0_0,max_0_0,mean_0_1,max_0_1,mean_0_2,max_0_2,mean_0_3,max_0_3,mean_0_4,...,mean_23_1019,max_23_1019,mean_23_1020,max_23_1020,mean_23_1021,max_23_1021,mean_23_1022,max_23_1022,mean_23_1023,max_23_1023
0,generated_images/image_1_shapes_0.png,0.099164,0.257024,0.298203,0.333471,0.059541,0.090624,0.108643,0.212754,0.022055,...,-0.171294,0.034744,0.227626,0.437658,-0.130334,0.179163,-0.228286,0.008854,-0.119478,0.226579
1,generated_images/image_1_shapes_1.png,0.098868,0.236394,0.298659,0.333518,0.059578,0.093828,0.108109,0.198275,0.022334,...,-0.159002,0.128170,0.228707,0.454122,-0.121280,0.225495,-0.221970,0.040223,-0.119355,0.224534
2,generated_images/image_1_shapes_2.png,0.098371,0.198219,0.294899,0.335975,0.057270,0.078525,0.102543,0.166103,0.021096,...,-0.183971,0.081780,0.185437,0.404024,-0.099349,0.141499,-0.227266,-0.016253,-0.099449,0.229580
3,generated_images/image_1_shapes_3.png,0.105667,0.350031,0.310025,0.336265,0.058159,0.078626,0.111158,0.305242,0.023447,...,-0.175808,0.136730,0.170698,0.396673,-0.065466,0.173172,-0.210550,0.142581,-0.052900,0.315388
4,generated_images/image_1_shapes_4.png,0.102984,0.147678,0.304164,0.331992,0.063006,0.108819,0.106660,0.195209,0.021967,...,-0.175748,0.019588,0.204655,0.394590,-0.124101,0.192931,-0.233677,-0.016928,-0.109357,0.350109
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
495,generated_images/image_5_shapes_95.png,0.102451,0.214196,0.264764,0.334364,0.047601,0.123497,0.108762,0.324152,0.021233,...,-0.142009,0.198176,0.214031,0.480671,-0.096051,0.227638,-0.217189,0.180153,-0.043028,0.357173
496,generated_images/image_5_shapes_96.png,0.099824,0.218822,0.254907,0.332797,0.048234,0.107941,0.103577,0.216401,0.023336,...,-0.147976,0.131034,0.200914,0.431632,-0.085056,0.267447,-0.210242,0.221128,-0.054774,0.284328
497,generated_images/image_5_shapes_97.png,0.099764,0.249264,0.256887,0.336022,0.047266,0.098838,0.109927,0.263641,0.019197,...,-0.151952,0.087928,0.201548,0.435799,-0.080965,0.278892,-0.212384,0.200994,-0.059827,0.318426
498,generated_images/image_5_shapes_98.png,0.101962,0.418774,0.262032,0.412612,0.047812,0.127314,0.110888,0.263006,0.021316,...,-0.147790,0.160468,0.201102,0.431684,-0.064188,0.318274,-0.208776,0.170633,-0.040669,0.427834


In [None]:
df.to_csv("clip_224/clip_mlp_activations.csv", index=False)