In [None]:
!git clone https://github.com/soniajoseph/ViT-Prisma

In [11]:
!pip install -e ViT-Prisma

Defaulting to user installation because normal site-packages is not writeable
Obtaining file:///home/user/cv-proj2/ViT-Prisma
  Preparing metadata (setup.py) ... done
Collecting line_profiler (from vit-prisma==2.0.0)
  Downloading line_profiler-4.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (34 kB)
Collecting wandb (from vit-prisma==2.0.0)
  Downloading wandb-0.19.11-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)
Collecting kaleido (from vit-prisma==2.0.0)
  Downloading kaleido-0.2.1-py2.py3-none-manylinux1_x86_64.whl.metadata (15 kB)
Collecting open-clip-torch (from vit-prisma==2.0.0)
  Downloading open_clip_torch-2.32.0-py3-none-any.whl.metadata (31 kB)
Collecting ftfy (from open-clip-torch->vit-prisma==2.0.0)
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Collecting docker-pycreds>=0.4.0 (from wandb->vit-prisma==2.0.0)
  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl.metadata (1.8 kB)
Collecting sentry-sd

In [1]:
# -*- coding: utf-8 -*-
"""ViT Prisma Main Demo

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1TL_BY1huQ4-OTORKbiIg7XfTyUbmyToQ

by Sonia Joseph

Twitter: [@soniajoseph_](https://twitter.com/soniajoseph_)

Original introduction is [here](https://www.lesswrong.com/posts/kobJymvvcvhbjWFKe/laying-the-foundations-for-vision-and-multimodal-mechanistic).

# Introduction

The purpose of this notebook is to introduce readers to vision transformer (ViT) mechanistic interpretability.

To make ViT mechanistic interpretability easier, I built an [open source library Prisma](https://github.com/soniajoseph/ViT-Prisma). The library is based on Neel Nanda's fantastic [TransformerLens](https://github.com/neelnanda-io/TransformerLens) but adapted for vision transformers and text-image models like CLIP. This notebook serves as a demo of the library. I highly encourage readers to check out the library and request features that they'd like to see!

I hope this notebook builds the vision mech interp ecosystem and encourages researchers to pursue their own directions with Prisma!

## Audience
This notebook is geared toward two audiences. The first is familiar with language mech interp, but not vision mech interp. The second is somewhat new to mechanistic interpretability, but is familiar with basic deep learning and has some mild exposure to mech interp concepts. If you are *completely* new to all mech interp, I recommend getting the basics down with the [ARENA curriculum](https://github.com/callummcdougall/ARENA_2.0) first.


The structure of this notebook is based on the excellent notebook [Exploratory Analysis](https://colab.research.google.com/github/neelnanda-io/TransformerLens/blob/main/demos/Exploratory_Analysis_Demo.ipynb#scrollTo=lZgu7cH72kdd) from TransformerLens, with some detours. While this notebook acts as a stand-alone, I encourage readers to consult the original notebook when they would like a deeper explanation.

For unfamiliar terms, also check out the [mech interp explorer](https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#inline-images&theme=default).

## Structure

*See the sidebar for navigation.*

We'll run through the basic mech interp techniques on a vision transformer, including:

* Logit attribution
* Patch-level emoji logit lens
* Attention visualization
* Activation patching

**We will illustrate the last few techniques by changing the ViT's prediction from tabby cat to Border Collie with a minimum viable ablation.**

## Acknowledgements and contributors

Thank you to Noah MacCallum, Rob Graham, and Karolis Ramanauskas for giving feedback on an early draft of this notebook.

Further thank you to Neel Nanda for your feedback, and to the Prisma team and core contributors, the MATS community, and South Park Commons. Full acknowledgements are on the Prisma repo documentation.

### Differences between ViT and Language Interpretability

*This section is geared toward readers already familiar with language transformer mech interp. If you are new to mech interp in general, you don't have to dwell on this section too much.*


Vision mech interp is like language mech interp, but in a fun-house mirror. Both architectures are transformers, so many LLM techniques carry over. However, there are a few twists:

* **The typical ViT is not doing unidirectional sequence modeling.** ViTs use bidirectional attention and predict a global CLS token, rather than predicting the next token in an autoregressive manner. (Note: There are autoregressive vision transformers with basically the same architecture as language, such as [Image GPT](https://openai.com/research/image-gpt) and [Parti](https://sites.research.google/parti/), which do next-token image generation. However, as of February 2024, autoregressive vision transformers are not frequently used in the wild.)
* **Bidirectional attention vs causal attention.** Language transformers have causal (unidirectional) attention. This means the entries above the diagonal of the attention pattern are masked out, so that a token cannot attend to tokens that come after it. The classical ViT, with its bidirectional attention, does not have the same concept of "time." Thus, some of the original LLM mech interp techniques break. It can be unclear which direction information is flowing. Induction heads, if they are present in vision, would look different from those in language to account for bidirectional attention.
* **CLS token instead of next-token prediction / autoregressive loss.** In a ViT, a learnable CLS token is prepended to the input, and its final representation is what gets fed into the classification head; a language model instead reads its prediction off the final token position. The CLS token accrues global information from the other patches through self-attention as all the patches pass through the net.
* **No canonical dictionary matrix. Vision is more ambiguous.** Vision lacks a standard dictionary matrix like the 50k one for language, partially due to inherent ambiguity. For instance, a yellow patch on a goldfinch might represent "yellow," "wing," "goldfinch," "bird," or "animal," depending on the granularity, demonstrating hierarchical ambiguity. An animal might be identified specifically as a "Border collie" or more generally as a "dog." Beyond hierarchy, ambiguity in vision also stems from cultural interpretations and the imprecision of language. Practically, ImageNet's 1000 classes serve as a makeshift "dictionary," but it falls short of fully encompassing visual concepts.
* **Additional hyperparameters.** Patch size is a vision-specific hyperparameter, determining the size of the patches into which an image is divided. Smaller patches tend to improve accuracy but also raise the computational load, because the image is cut into more patches and attention scales quadratically with the number of patches.
* **There is a zoo of vision transformers.** Similar to language, vision transformers come in many forms. The most relevant are the vanilla ViT, which we'll be analyzing in this notebook; CLIP, which is co-trained with text using contrastive loss; and DINO, which is trained with self-supervision on unlabeled data. For a review, check out [this survey](https://arxiv.org/pdf/2101.01169.pdf).
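
To make the causal vs bidirectional contrast concrete, here is a minimal NumPy sketch. It is not Prisma code: `attention_weights` is a helper I made up for illustration, and the scores are random. The only difference between the two calls is whether the entries above the diagonal are masked before the softmax.

```python
import numpy as np

def attention_weights(scores, causal):
    # scores: (seq, seq) raw attention scores; row i is the query, column j the key.
    scores = scores.astype(float).copy()
    if causal:
        # Mask the strict upper triangle: query i may not see keys j > i.
        future = np.triu(np.ones_like(scores, dtype=bool), k=1)
        scores[future] = -np.inf
    # Row-wise softmax (subtract the row max for numerical stability).
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
scores = rng.normal(size=(4, 4))  # toy sequence of 4 tokens (or patches)

causal = attention_weights(scores, causal=True)          # language-model style
bidirectional = attention_weights(scores, causal=False)  # ViT style

print(np.round(causal, 2))         # zeros above the diagonal
print(np.round(bidirectional, 2))  # every entry is nonzero
```

In the causal pattern the first token can only attend to itself, which gives language models a built-in ordering. The bidirectional pattern has no such structure, so every patch can read from every other patch in a single layer.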

## Import libraries, data, and helper functions (ignore)
"""

# Install the Prisma repo library (update version number or clone from source for latest functionality)

!pip install vit_prisma

import vit_prisma
from vit_prisma.utils.data_utils.imagenet_dict import IMAGENET_DICT
from vit_prisma.utils import prisma_utils

import numpy as np
import torch
from fancy_einsum import einsum
from collections import defaultdict

import plotly.graph_objs as go
import plotly.express as px

import matplotlib.colors as mcolors

from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from IPython.display import display, HTML  # IPython.core.display is deprecated for these imports

# Get images we'll feed into the model
!wget https://github.com/soniajoseph/ViT-Prisma/blob/main/src/vit_prisma/sample_images/cat_dog.jpeg?raw=true -O cat_dog.jpeg --quiet
!wget https://github.com/soniajoseph/ViT-Prisma/blob/main/src/vit_prisma/sample_images/cat_crop.jpeg?raw=true -O crop_cat.png --quiet

"""**Helper Functions** (ignore)"""

# Helper function (ignore)
def plot_image(image):
  plt.figure()
  plt.axis('off')
  plt.imshow(image.permute(1,2,0))

class ConvertTo3Channels:
    def __call__(self, img):
        if img.mode != 'RGB':
            return img.convert('RGB')
        return img

transform = transforms.Compose([
    ConvertTo3Channels(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

def plot_logit_boxplot(average_logits, labels):
  # average_logits: (num_classes, num_layers); labels: one name per layer

  # if tensor, turn to numpy
  if isinstance(average_logits, torch.Tensor):
      average_logits = average_logits.detach().cpu().numpy()

  num_classes, num_layers = average_logits.shape

  # Hover text for each point is its ImageNet class name (same for every layer)
  hovertext = [IMAGENET_DICT[i] for i in range(num_classes)]

  fig = go.Figure()

  for i in range(num_layers):  # For each layer
      fig.add_trace(go.Box(
          y=average_logits[:, i],
          name=f'{labels[i]}',  # use the argument, not a global `layer_labels`
          text=hovertext,
          hoverinfo='y+text',
          boxpoints='suspectedoutliers'
      ))

  # Overlay the per-layer mean logit as red markers
  means = np.mean(average_logits, axis=0)
  fig.add_trace(go.Scatter(
      x=labels,
      y=means,
      mode='markers',
      name='Mean',
      marker=dict(size=4, color='red'),
  ))

  fig.update_layout(
      title='Raw Logit Values Per Layer (each dot is 1 ImageNet Class)',
      xaxis=dict(title='Layer'),
      yaxis=dict(title='Logit Values'),
      showlegend=False
  )

  fig.show()

def plot_patched_component(patched_head, title=''):
  """
  Use for plotting Activation Patching.
  """

  fig = go.Figure(data=go.Heatmap(
      z=patched_head.detach().cpu().numpy(),  # .cpu() so GPU tensors also work
      colorscale='RdBu',  # You can choose any colorscale
      colorbar=dict(title='Value'),  # Customize the color bar
      hoverongaps=False
  ))
  fig.update_layout(
      title=title,
      xaxis_title='Attention Head',
      yaxis_title='Patch Number',
  )

  return fig

def imshow(tensor, **kwargs):
    """
    Use for Activation Patching.
    """
    px.imshow(
          prisma_utils.to_numpy(tensor),
          color_continuous_midpoint=0.0,
          color_continuous_scale="RdBu",
          **kwargs,
      ).show()

"""# Load model and data

## ViT Architecture

![image](https://production-media.paperswithcode.com/methods/Screen_Shot_2021-01-26_at_9.43.31_PM_uI4jjMq.png)


A [vision transformer](https://arxiv.org/pdf/2010.11929.pdf) (ViT) is an architecture designed for image classification tasks, similar to the classic transformer architecture used in language models. A ViT consists of transformer blocks; each block consists of an Attention layer and an MLP layer.


Unlike language models, vision transformers do not have a dictionary embedding and unembedding matrix. Instead, images are divided into non-overlapping patches, similar to tokens in language models. These patches are flattened and linearly projected to embeddings via a Conv2D layer, similar to word embeddings in language models. A learnable class token (CLS token) is prepended to the sequence, and it accrues global information throughout the network. A learned position embedding is added to every token embedding so the model knows where each patch came from.
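
Here is a rough NumPy sketch of that embedding step, to make the shapes concrete. It is not Prisma's implementation. I'm assuming a 224x224 image (the size our `transform` resizes to), a patch size of 16, and ViT-Base's width of 768, and all the "learned" parameters are random stand-ins. A Conv2D whose kernel size and stride both equal the patch size computes the same thing as flattening each patch and applying one linear map, which is what the sketch does.

```python
import numpy as np

# Assumed sizes: 224x224 RGB image, 16x16 patches, ViT-Base width.
image_size, patch_size, channels, d_model = 224, 16, 3, 768

grid = image_size // patch_size  # patches per side
num_patches = grid * grid        # patches per image
seq_len = num_patches + 1        # plus the CLS token

rng = np.random.default_rng(0)
image = rng.random((channels, image_size, image_size), dtype=np.float32)

# Cut the image into non-overlapping patches and flatten each one.
patches = image.reshape(channels, grid, patch_size, grid, patch_size)
patches = patches.transpose(1, 3, 0, 2, 4).reshape(num_patches, -1)

# Random stand-ins for learned parameters (illustrative only).
W_embed = rng.normal(scale=0.02, size=(patches.shape[1], d_model)).astype(np.float32)
cls_token = np.zeros((1, d_model), dtype=np.float32)
pos_embed = rng.normal(scale=0.02, size=(seq_len, d_model)).astype(np.float32)

tokens = patches @ W_embed                    # linear projection of each patch
tokens = np.concatenate([cls_token, tokens])  # prepend CLS at position 0
tokens = tokens + pos_embed                   # add position information

print(grid, num_patches, seq_len)  # 14 196 197
print(tokens.shape)                # (197, 768)
print(seq_len ** 2)                # 38809 entries in one head's attention pattern
```

The last line is why patch size matters for compute: halving the patch size roughly quadruples the sequence length, and the attention pattern grows with the square of that.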

The patch embeddings then pass through the transformer blocks (each block consists of a LayerNorm, an attention layer, another LayerNorm, and an MLP layer). The attention and MLP outputs are each added back to their input rather than replacing it; the running sum that results is called the residual stream.
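
As a sketch of how one block writes into the residual stream, here is a toy version in NumPy. The sublayers are hypothetical stand-ins (fixed random linear maps, so `toy_attention` does not actually mix information across tokens), and the LayerNorm omits its learned scale and bias. The point is only the additive structure.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d_model = 197, 768  # assumed: 196 patches + CLS, ViT-Base width

def layer_norm(x, eps=1e-6):
    # Normalize each token vector (learned scale and bias omitted).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

# Hypothetical stand-ins for the real sublayers: fixed random linear maps.
W_attn = rng.normal(scale=0.01, size=(d_model, d_model))
W_mlp = rng.normal(scale=0.01, size=(d_model, d_model))

def toy_attention(x):
    return x @ W_attn

def toy_mlp(x):
    return np.maximum(x @ W_mlp, 0.0)

def block(resid):
    # Each sublayer reads a normalized copy of the stream and adds its output back.
    resid_mid = resid + toy_attention(layer_norm(resid))
    resid_post = resid_mid + toy_mlp(layer_norm(resid_mid))
    return resid_post

resid = rng.normal(size=(seq_len, d_model))  # embeddings entering the first block
out = block(resid)
print(out.shape)  # (197, 768)
```

Because every sublayer only adds to the stream, the final residual stream is a sum of contributions from each attention and MLP layer, which is what makes techniques like logit attribution possible.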

The final layer of this vision transformer is a classification head with 1000 logit values for ImageNet's 1000 classes. The CLS token is fed into the final layer for 1000-way classification.
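
In code, the classification step amounts to reading the CLS position and applying one linear map. This is a minimal sketch with random stand-in weights, not the trained head; a real ViT also applies a final LayerNorm before the head, which I've left out here.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d_model, num_classes = 197, 768, 1000

# Stand-in for the residual stream after the last block.
final_resid = rng.normal(size=(seq_len, d_model))

# Random stand-ins for the learned head parameters.
W_head = rng.normal(scale=0.02, size=(d_model, num_classes))
b_head = np.zeros(num_classes)

cls_vector = final_resid[0]            # CLS sits at position 0
logits = cls_vector @ W_head + b_head  # one logit per ImageNet class
predicted_class = int(logits.argmax())

print(logits.shape)  # (1000,)
```

Only the CLS position feeds the head, so whatever the patches know has to be moved into the CLS token by attention before this point.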

Like TransformerLens, we use HookedViT to easily capture intermediate activations with custom hook functions, instead of dealing with PyTorch's normal hook functionality.
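
For contrast, here is roughly what capturing one intermediate activation looks like with PyTorch's built-in hooks. The tiny `nn.Sequential` model and the `save_activation` helper are made up for this sketch; only `register_forward_hook` and the handle's `remove()` are standard PyTorch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# A tiny stand-in network; a real ViT block would go here.
model = nn.Sequential(
    nn.Linear(8, 16),
    nn.GELU(),
    nn.Linear(16, 4),
)

cache = {}

def save_activation(name):
    # Build a forward hook that stores the module's output under `name`.
    def hook(module, inputs, output):
        cache[name] = output.detach()
    return hook

# Register the hook, run the model, then remove the hook again.
handle = model[1].register_forward_hook(save_activation('gelu_out'))
with torch.no_grad():
    logits = model(torch.randn(2, 8))
handle.remove()

print(cache['gelu_out'].shape)  # torch.Size([2, 16])
print(logits.shape)             # torch.Size([2, 4])
```

Doing this by hand for every layer, and remembering to remove each handle, gets tedious quickly, which is the bookkeeping a hooked model is meant to take off your plate.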
"""

Defaulting to user installation because normal site-packages is not writeable
Collecting vit_prisma
  Downloading vit_prisma-0.1.4-py3-none-any.whl.metadata (10 kB)
Collecting jaxtyping (from vit_prisma)
  Downloading jaxtyping-0.3.2-py3-none-any.whl.metadata (7.0 kB)
Collecting einops (from vit_prisma)
  Downloading einops-0.8.1-py3-none-any.whl.metadata (13 kB)
Collecting fancy-einsum (from vit_prisma)
  Downloading fancy_einsum-0.0.3-py3-none-any.whl.metadata (1.2 kB)
Collecting plotly==5.19.0 (from vit_prisma)
  Downloading plotly-5.19.0-py3-none-any.whl.metadata (7.0 kB)
Collecting transformers (from vit_prisma)
  Downloading transformers-4.51.3-py3-none-any.whl.metadata (38 kB)
Collecting wadler-lindig>=0.1.3 (from jaxtyping->vit_prisma)
  Downloading wadler_lindig-0.1.6-py3-none-any.whl.metadata (17 kB)
Collecting tokenizers<0.22,>=0.21 (from transformers->vit_prisma)
  Downloading tokenizers-0.21.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.8 kB)
Downl


In [4]:
from vit_prisma.models.base_vit import HookedViT
from vit_prisma.configs.HookedViTConfig import HookedViTConfig

In [9]:
from vit_prisma.models.base_vit import HookedViT
from vit_prisma.configs.HookedViTConfig import HookedViTConfig

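# Caution: the `transform` defined earlier resizes images to 224x224, while a
# *_384 checkpoint expects 384x384 inputs, so one of the two likely has to change.
model = HookedViT.from_pretrained(
    "vit_base_patch16_384",
    center_writing_weights=True,
    center_unembed=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)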

config.json:   0%|          | 0.00/586 [00:00<?, ?B/s]

AttributeError: 'TimmWrapperConfig' object has no attribute 'num_hidden_layers'