 # Activation Patching in TransformerLens Demo
 This notebook accompanies the [Exploratory Analysis Demo](https://neelnanda.io/exploratory-analysis-demo). That notebook explains some basic techniques for mech interp of networks, including an overview of activation patching ([summary here](https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#z=qeWBvs-R-taFfcCq-S_hgMqx)). This one demonstrates how to use the activation patching utils in TransformerLens.


 <b style="color: red">To use this notebook, go to Runtime > Change Runtime Type and select GPU as the hardware accelerator.</b>

 **Tips for reading this Colab:**
 * You can run all this code for yourself!
 * The graphs are interactive!
 * Use the table of contents pane in the sidebar to navigate
 * Collapse irrelevant sections with the dropdown arrows
 * Search the page using the search in the sidebar, not CTRL+F

 ## Setup (Ignore)

In [1]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install transformer_lens
    # Install my janky personal plotting utils
    %pip install git+https://github.com/neelnanda-io/neel-plotly.git
except ImportError:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as it's edited without restarting the kernel
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

Running as a Jupyter notebook - intended for development only!


In [2]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

if IN_COLAB or not DEBUG_MODE:
    # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "png"

In [3]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from torchtyping import TensorType as TT
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [4]:
import pysvelte

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

 We turn automatic differentiation off to save GPU memory, as this notebook focuses on model inference, not model training.

In [5]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f6b307d0990>

 Plotting helper functions from a janky personal library of plotting utils. The library is not documented and I recommend against trying to read it; just use your preferred plotting library if you want to do anything non-obvious:

In [6]:
from neel_plotly import line, imshow, scatter

In [7]:
import transformer_lens.patching as patching

 ## Activation Patching Setup
 This just copies the relevant setup from the Exploratory Analysis Demo, and isn't very important.

In [8]:
model = HookedTransformer.from_pretrained("gpt2-small")

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


In [9]:
prompts = ['When John and Mary went to the shops, John gave the bag to', 'When John and Mary went to the shops, Mary gave the bag to', 'When Tom and James went to the park, James gave the ball to', 'When Tom and James went to the park, Tom gave the ball to', 'When Dan and Sid went to the shops, Sid gave an apple to', 'When Dan and Sid went to the shops, Dan gave an apple to', 'After Martin and Amy went to the park, Amy gave a drink to', 'After Martin and Amy went to the park, Martin gave a drink to']
answers = [(' Mary', ' John'), (' John', ' Mary'), (' Tom', ' James'), (' James', ' Tom'), (' Dan', ' Sid'), (' Sid', ' Dan'), (' Martin', ' Amy'), (' Amy', ' Martin')]

clean_tokens = model.to_tokens(prompts)
# Swap each adjacent pair, with a hacky list comprehension
corrupted_tokens = clean_tokens[
    [(i+1 if i%2==0 else i-1) for i in range(len(clean_tokens)) ]
    ]
print("Clean string 0", model.to_string(clean_tokens[0]))
print("Corrupted string 0", model.to_string(corrupted_tokens[0]))

answer_token_indices = torch.tensor([[model.to_single_token(answers[i][j]) for j in range(2)] for i in range(len(answers))], device=model.cfg.device)
print("Answer token indices", answer_token_indices)

Clean string 0 <|endoftext|>When John and Mary went to the shops, John gave the bag to
Corrupted string 0 <|endoftext|>When John and Mary went to the shops, Mary gave the bag to
Answer token indices tensor([[ 5335,  1757],
        [ 1757,  5335],
        [ 4186,  3700],
        [ 3700,  4186],
        [ 6035, 15686],
        [15686,  6035],
        [ 5780, 14235],
        [14235,  5780]], device='cuda:0')
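 The "hacky list comprehension" above builds an index that swaps each adjacent pair of prompts, so prompt 0 is corrupted into prompt 1 and vice versa. Here is a minimal sketch of just that index logic in plain Python (the helper name `swap_adjacent_pairs` is made up for illustration and is not part of TransformerLens):

```python
def swap_adjacent_pairs(n: int) -> list[int]:
    """Return an index list that swaps items 0<->1, 2<->3, ... (n must be even)."""
    if n % 2 != 0:
        raise ValueError("expected an even number of prompts")
    return [i + 1 if i % 2 == 0 else i - 1 for i in range(n)]


swap_index = swap_adjacent_pairs(8)
print(swap_index)  # [1, 0, 3, 2, 5, 4, 7, 6]

# Applying the swap twice gives back the original order
assert [swap_index[i] for i in swap_index] == list(range(8))
```

 This only works because the prompts are listed as adjacent clean/corrupted pairs that tokenize to the same length.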


In [10]:
def get_logit_diff(logits, answer_token_indices=answer_token_indices):
    if len(logits.shape)==3:
        # Get final logits only
        logits = logits[:, -1, :]
    correct_logits = logits.gather(1, answer_token_indices[:, 0].unsqueeze(1))
    incorrect_logits = logits.gather(1, answer_token_indices[:, 1].unsqueeze(1))
    return (correct_logits - incorrect_logits).mean()

clean_logits, clean_cache = model.run_with_cache(clean_tokens)
corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)

clean_logit_diff = get_logit_diff(clean_logits, answer_token_indices).item()
print(f"Clean logit diff: {clean_logit_diff:.4f}")

corrupted_logit_diff = get_logit_diff(corrupted_logits, answer_token_indices).item()
print(f"Corrupted logit diff: {corrupted_logit_diff:.4f}")

Clean logit diff: 3.5519
Corrupted logit diff: -3.5519
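 The corrupted logit diff is exactly the negative of the clean one, and this is by construction: the corrupted batch is the clean batch with adjacent rows swapped, while each answer pair is the mirror image of its neighbour's. A toy sketch in plain Python (made-up numbers, a two-token "vocabulary", and a hypothetical helper `mean_logit_diff`) shows the same symmetry:

```python
def mean_logit_diff(final_logits, answer_pairs):
    """Mean of (correct logit - incorrect logit) over the batch."""
    diffs = [
        row[correct] - row[incorrect]
        for row, (correct, incorrect) in zip(final_logits, answer_pairs)
    ]
    return sum(diffs) / len(diffs)


toy_logits = [[2.0, 0.5], [1.0, 4.0]]  # final-position logits for two prompts
toy_answers = [(0, 1), (1, 0)]         # (correct, incorrect); mirrored pairs

toy_clean_diff = mean_logit_diff(toy_logits, toy_answers)
# "Corrupt" by swapping the two rows while keeping the answers in place
toy_corrupted_diff = mean_logit_diff([toy_logits[1], toy_logits[0]], toy_answers)

print(toy_clean_diff)      # 2.25
print(toy_corrupted_diff)  # -2.25
```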


In [11]:
CLEAN_BASELINE = clean_logit_diff
CORRUPTED_BASELINE = corrupted_logit_diff
def ioi_metric(logits, answer_token_indices=answer_token_indices):
    return (get_logit_diff(logits, answer_token_indices) - CORRUPTED_BASELINE) / (CLEAN_BASELINE  - CORRUPTED_BASELINE)

print(f"Clean Baseline is 1: {ioi_metric(clean_logits).item():.4f}")
print(f"Corrupted Baseline is 0: {ioi_metric(corrupted_logits).item():.4f}")

Clean Baseline is 1: 1.0000
Corrupted Baseline is 0: 0.0000
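 To make the normalisation concrete, here is the same arithmetic as `ioi_metric` on plain floats, using the two baselines printed above. The function name `normalised_metric` is just for this sketch:

```python
def normalised_metric(logit_diff: float, clean: float, corrupted: float) -> float:
    """Rescale a logit diff so the corrupted run scores 0 and the clean run scores 1."""
    return (logit_diff - corrupted) / (clean - corrupted)


clean_baseline, corrupted_baseline = 3.5519, -3.5519

print(normalised_metric(clean_baseline, clean_baseline, corrupted_baseline))      # 1.0
print(normalised_metric(corrupted_baseline, clean_baseline, corrupted_baseline))  # 0.0
# A patch that leaves the model indifferent between the two names sits halfway
print(normalised_metric(0.0, clean_baseline, corrupted_baseline))                 # 0.5
```

 Values outside [0, 1] are possible: a patch that makes things worse than the corrupted run gives a negative score.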


 ## Patching
 In the following cells, we use the patching module to call the activation patching utilities.

In [12]:
# Whether to do the runs by head and by position, which are much slower
DO_SLOW_RUNS = False

 ### Patching Single Activation Types
 We start by patching single types of activation.
 The general syntax is that the functions are called `get_act_patch_...` and take in `(model, corrupted_tokens, clean_cache, patching_metric)`.

 We can patch the residual stream at the start of each block, over each layer and position.
 Swapping `resid_pre` in the function name for `attn_out`, `mlp_out` or `resid_mid` also works.

In [13]:
resid_pre_act_patch_results = patching.get_act_patch_resid_pre(model, corrupted_tokens, clean_cache, ioi_metric)
imshow(resid_pre_act_patch_results, 
       yaxis="Layer", 
       xaxis="Position", 
       x=[f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],
       title="resid_pre Activation Patching")

  0%|          | 0/180 [00:00<?, ?it/s]
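 If it's unclear what a scan like this is doing, here is a toy sketch of the idea in plain Python. It is *not* how TransformerLens implements it: the "model" is three made-up functions acting on a two-position state, and every name (`TOY_LAYERS`, `run_toy`, `toy_metric`) is hypothetical. We cache the clean run, then rerun the corrupted input once per (layer, position), overwriting that one value with its clean counterpart and recording the normalised metric.

```python
from typing import Callable, List, Optional, Tuple

State = List[float]  # one number per "position"

# Hypothetical toy layers: only layer 1 moves information between positions
TOY_LAYERS: List[Callable[[State], State]] = [
    lambda s: [2 * s[0], s[1]],
    lambda s: [s[0], s[1] + s[0]],
    lambda s: [s[0], 3 * s[1]],
]


def run_toy(inputs: State, patch: Optional[Tuple[int, int, float]] = None):
    """Return (output, cache), where cache[i] is the state entering layer i.

    patch=(layer, pos, value) overwrites one position just before `layer` runs.
    """
    state = list(inputs)
    cache = []
    for i, layer in enumerate(TOY_LAYERS):
        if patch is not None and patch[0] == i:
            state[patch[1]] = patch[2]
        cache.append(list(state))
        state = layer(state)
    return state[1], cache  # read the answer off the final position


toy_clean_out, toy_clean_cache = run_toy([1.0, 0.0])
toy_corrupted_out, _ = run_toy([-1.0, 0.0])


def toy_metric(out: float) -> float:
    return (out - toy_corrupted_out) / (toy_clean_out - toy_corrupted_out)


# One corrupted run per (layer, position), patching in the clean value
toy_results = [
    [
        toy_metric(run_toy([-1.0, 0.0], (layer, pos, toy_clean_cache[layer][pos]))[0])
        for pos in range(2)
    ]
    for layer in range(len(TOY_LAYERS))
]
for row in toy_results:
    print(row)
```

 The rows come out as `[1.0, 0.0]`, `[1.0, 0.0]`, `[0.0, 1.0]`: patching position 0 only helps before layer 1 moves its information to position 1, and after that only position 1 matters. That is the kind of "where does the information live at each layer" picture the heatmap gives for the real model.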

 We can patch head outputs over each head in each layer, patching across all positions at once.
 Swapping `out` in the function name for `q`, `k`, `v` or `pattern` also works.

In [14]:
attn_head_out_all_pos_act_patch_results = patching.get_act_patch_attn_head_out_all_pos(model, corrupted_tokens, clean_cache, ioi_metric)
imshow(attn_head_out_all_pos_act_patch_results, 
       yaxis="Layer", 
       xaxis="Head", 
       title="attn_head_out Activation Patching (All Pos)")

  0%|          | 0/144 [00:00<?, ?it/s]

 We can patch head outputs over each head in each layer, patching on each position in turn.
 Swapping `out` for `q`, `k`, `v` or `pattern` also works, though note that pattern has output shape [layer, pos, head].
 We reshape the result so that it plots nicely.

In [15]:
ALL_HEAD_LABELS = [f"L{i}H{j}" for i in range(model.cfg.n_layers) for j in range(model.cfg.n_heads)]
if DO_SLOW_RUNS:
    attn_head_out_act_patch_results = patching.get_act_patch_attn_head_out_by_pos(model, corrupted_tokens, clean_cache, ioi_metric)
    attn_head_out_act_patch_results = einops.rearrange(attn_head_out_act_patch_results, "layer pos head -> (layer head) pos")
    imshow(attn_head_out_act_patch_results, 
        yaxis="Head Label", 
        xaxis="Pos", 
        x=[f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],
        y=ALL_HEAD_LABELS,
        title="attn_head_out Activation Patching By Pos")
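 The rearrange pattern `(layer head)` puts layer on the outside, so the flattened row for a given head is `layer * n_heads + head`. That is the same order in which `ALL_HEAD_LABELS` is built, which is why the labels line up with the rows. A quick check in plain Python, with small made-up sizes so the output is readable:

```python
n_layers, n_heads = 2, 3  # hypothetical sizes; gpt2-small is 12 x 12

head_labels = [f"L{i}H{j}" for i in range(n_layers) for j in range(n_heads)]


def flat_row(layer: int, head: int, n_heads: int) -> int:
    """Row index after flattening (layer, head) with layer as the outer axis."""
    return layer * n_heads + head


for layer in range(n_layers):
    for head in range(n_heads):
        assert head_labels[flat_row(layer, head, n_heads)] == f"L{layer}H{head}"

print(head_labels)  # ['L0H0', 'L0H1', 'L0H2', 'L1H0', 'L1H1', 'L1H2']
```

 Writing the pattern as `(head layer)` instead would silently mislabel the rows.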

 ### Patching multiple activation types
 Some utilities are provided to patch multiple activation types *in turn*. Note that this is *not* a utility to patch multiple activations at once; it's just a useful scan to get a sense for what's going on in a model.

 By block: We patch the residual stream at the start of each block, the attention output and the MLP output, over each layer and position.

In [16]:
every_block_result = patching.get_act_patch_block_every(model, corrupted_tokens, clean_cache, ioi_metric)
imshow(every_block_result, facet_col=0, facet_labels=["Residual Stream", "Attn Output", "MLP Output"], title="Activation Patching Per Block", xaxis="Position", yaxis="Layer", zmax=1, zmin=-1, x= [f"{tok}_{i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))])

  0%|          | 0/180 [00:00<?, ?it/s]


 By head: For each head we patch the output, query, key, value or pattern. We do all positions at once so it's not super slow.

In [17]:
every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(model, corrupted_tokens, clean_cache, ioi_metric)
imshow(every_head_all_pos_act_patch_result, facet_col=0, facet_labels=["Output", "Query", "Key", "Value", "Pattern"], title="Activation Patching Per Head (All Pos)", xaxis="Head", yaxis="Layer", zmax=1, zmin=-1)
# We can also patch by head *and* by position (next cell). This is a bit slow, but it can give useful, fine-grained detail

  0%|          | 0/144 [00:00<?, ?it/s]


In [18]:
if DO_SLOW_RUNS:
    every_head_act_patch_result = patching.get_act_patch_attn_head_by_pos_every(model, corrupted_tokens, clean_cache, ioi_metric)
    every_head_act_patch_result = einops.rearrange(every_head_act_patch_result, "act_type layer pos head -> act_type (layer head) pos")
    imshow(every_head_act_patch_result, facet_col=0, facet_labels=["Output", "Query", "Key", "Value", "Pattern"], title="Activation Patching Per Head (By Pos)", xaxis="Position", yaxis="Layer & Head", zmax=1, zmin=-1, x= [f"{tok}_{i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))], y=ALL_HEAD_LABELS)

 ## Induction Patching
 To show how easy it is, let's do that again with induction heads in a 2L attention-only model.
 The input will be repeated random tokens, e.g. BOS 1 5 8 9 2 1 5 8 9 2, and we judge the model's ability to predict the second repetition with its induction heads.
 Let's call A, B and C different (non-repeated) random sequences. We'll start with clean tokens AA and corrupted tokens AB, and see how well the model can predict the second A given the first A.

 ### Setup

In [19]:
attn_only = HookedTransformer.from_pretrained("attn-only-2l")
batch = 4
seq_len = 20
rand_tokens_A = torch.randint(100, 10000, (batch, seq_len)).to(attn_only.cfg.device)
rand_tokens_B = torch.randint(100, 10000, (batch, seq_len)).to(attn_only.cfg.device)
rand_tokens_C = torch.randint(100, 10000, (batch, seq_len)).to(attn_only.cfg.device)
bos = torch.tensor([attn_only.tokenizer.bos_token_id]*batch)[:, None].to(attn_only.cfg.device)
clean_tokens_induction = torch.cat([bos, rand_tokens_A, rand_tokens_A], dim=1).to(attn_only.cfg.device)
corrupted_tokens_induction = torch.cat([bos, rand_tokens_A, rand_tokens_B], dim=1).to(attn_only.cfg.device)

Loaded pretrained model attn-only-2l into HookedTransformer


In [20]:
clean_logits_induction, clean_cache_induction = attn_only.run_with_cache(clean_tokens_induction)
corrupted_logits_induction, corrupted_cache_induction = attn_only.run_with_cache(corrupted_tokens_induction)

 We define our metric as negative loss on the second half (negative loss so that higher is better).
 This time we won't normalise our metric.

In [21]:
def induction_loss(logits, answer_token_indices=rand_tokens_A):
    seq_len = answer_token_indices.shape[1]

    # logits: batch x seq_len x vocab_size
    # Take the logits for the answers, cut off the final element to get the predictions for all but the first element of the answers (which can't be predicted)
    final_logits = logits[:, -seq_len:-1]
    final_log_probs = final_logits.log_softmax(-1)
    return final_log_probs.gather(-1, answer_token_indices[:, 1:].unsqueeze(-1)).mean()
CLEAN_BASELINE_INDUCTION = induction_loss(clean_logits_induction).item()
print("Clean baseline:", CLEAN_BASELINE_INDUCTION)
CORRUPTED_BASELINE_INDUCTION = induction_loss(corrupted_logits_induction).item()
print("Corrupted baseline:", CORRUPTED_BASELINE_INDUCTION)

Clean baseline: -2.2928695678710938
Corrupted baseline: -13.125859260559082
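 The slice `logits[:, -seq_len:-1]` is easy to get off by one, so here is the index bookkeeping on its own, in plain Python with a short made-up sequence. The logits at position p predict the token at position p + 1, so we want the positions whose *next* token is one of the second-half A tokens (skipping the first of them, which can't be predicted by induction):

```python
toy_seq_len = 4  # small hypothetical length; the notebook uses 20
A = [f"A{i}" for i in range(toy_seq_len)]
toy_tokens = ["BOS"] + A + A  # layout of the clean prompt
n_pos = len(toy_tokens)       # 1 + 2 * toy_seq_len

# Positions kept by the slice [-toy_seq_len:-1]
kept_positions = list(range(n_pos))[-toy_seq_len:-1]
# The token each kept position is asked to predict
predicted = [toy_tokens[p + 1] for p in kept_positions]
targets = A[1:]  # mirrors answer_token_indices[:, 1:]

print(kept_positions)  # [5, 6, 7]
print(predicted)       # ['A1', 'A2', 'A3']
assert predicted == targets
```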


 ### Patching

In [22]:
every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(attn_only, corrupted_tokens_induction, clean_cache_induction, induction_loss)
imshow(every_head_all_pos_act_patch_result, facet_col=0, facet_labels=["Output", "Query", "Key", "Value", "Pattern"], title="Activation Patching Per Head (All Pos)", xaxis="Head", yaxis="Layer", zmax=CLEAN_BASELINE_INDUCTION)

if DO_SLOW_RUNS:
    every_head_act_patch_result = patching.get_act_patch_attn_head_by_pos_every(attn_only, corrupted_tokens_induction, clean_cache_induction, induction_loss)
    every_head_act_patch_result = einops.rearrange(every_head_act_patch_result, "act_type layer pos head -> act_type (layer head) pos")
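    # Labels and tokens must come from attn_only and the induction prompts, not from the GPT-2 setup above
    attn_only_head_labels = [f"L{i}H{j}" for i in range(attn_only.cfg.n_layers) for j in range(attn_only.cfg.n_heads)]
    imshow(every_head_act_patch_result, facet_col=0, facet_labels=["Output", "Query", "Key", "Value", "Pattern"], title="Activation Patching Per Head (By Pos)", xaxis="Position", yaxis="Layer & Head", zmax=CLEAN_BASELINE_INDUCTION, x= [f"{tok}_{i}" for i, tok in enumerate(attn_only.to_str_tokens(clean_tokens_induction[0]))], y=attn_only_head_labels)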

  0%|          | 0/16 [00:00<?, ?it/s]


 ### Changing the Corrupted Baseline
 We can also change the corrupted baseline easily to check what things look like! We'll keep clean as AA, but rather than corrupted as AB, we'll try out:
 * BA - This has a corrupted first half, so we expect both keys *and* values to matter. Patching the head output should work, but patching the value, key or pattern on its own won't.
 * BB - This is still inductiony but with different tokens. So patching keys, queries or patterns shouldn't matter; patching the head output will work, and so will patching the value.
 * BC - This is just random tokens, so everything is corrupted! The induction head needs queries, keys *and* values, so only patching the output will work.
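 The reasoning in these bullets can be written down as a tiny rule-of-thumb model. This is a crude sketch resting on assumptions, not a claim about what the real heads do: it assumes an induction head takes its query from the second half and its key and value from the first half, that its attention pattern is "induction-like" exactly when query and key come from the same sequence, and that a patch restores performance only if the pattern is induction-like *and* the values come from A. The function name `patch_restores` is made up.

```python
def patch_restores(corrupted: str, patched: str) -> bool:
    """Toy prediction: does patching one clean (AA) activation fix the corrupted run?

    corrupted: two letters, e.g. "BA" = first half B, second half A.
    patched: one of "output", "q", "k", "v", "pattern".
    """
    first_half, second_half = corrupted
    if patched == "output":
        return True  # the whole head output is replaced by the clean one
    # Assumption: query comes from the second half, key and value from the first
    src = {"q": second_half, "k": first_half, "v": first_half}
    if patched in src:
        src[patched] = "A"
    pattern_ok = patched == "pattern" or src["q"] == src["k"]
    return pattern_ok and src["v"] == "A"


for baseline in ["BA", "BB", "BC"]:
    working = [p for p in ["output", "q", "k", "v", "pattern"] if patch_restores(baseline, p)]
    print(baseline, working)
```

 Under these assumptions the sketch agrees with the bullets: only output for BA and BC, and output plus value for BB. The plots below are the real test, and anywhere they disagree with this toy rule is a hint that the heads are doing something the rule leaves out.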

In [23]:
corrupted_tokens_induction_BA = torch.cat([bos, rand_tokens_B, rand_tokens_A], dim=1).to(attn_only.cfg.device)
corrupted_tokens_induction_BB = torch.cat([bos, rand_tokens_B, rand_tokens_B], dim=1).to(attn_only.cfg.device)
corrupted_tokens_induction_BC = torch.cat([bos, rand_tokens_B, rand_tokens_C], dim=1).to(attn_only.cfg.device)

In [24]:
every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(attn_only, corrupted_tokens_induction_BA, clean_cache_induction, induction_loss)
imshow(every_head_all_pos_act_patch_result, facet_col=0, facet_labels=["Output", "Query", "Key", "Value", "Pattern"], title="Activation Patching Per Head on BA (All Pos)", xaxis="Head", yaxis="Layer", zmax=CLEAN_BASELINE_INDUCTION)
every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(attn_only, corrupted_tokens_induction_BB, clean_cache_induction, induction_loss)
imshow(every_head_all_pos_act_patch_result, facet_col=0, facet_labels=["Output", "Query", "Key", "Value", "Pattern"], title="Activation Patching Per Head on BB (All Pos)", xaxis="Head", yaxis="Layer", zmax=CLEAN_BASELINE_INDUCTION)
every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(attn_only, corrupted_tokens_induction_BC, clean_cache_induction, induction_loss)
imshow(every_head_all_pos_act_patch_result, facet_col=0, facet_labels=["Output", "Query", "Key", "Value", "Pattern"], title="Activation Patching Per Head on BC (All Pos)", xaxis="Head", yaxis="Layer", zmax=CLEAN_BASELINE_INDUCTION)

  0%|          | 0/16 [00:00<?, ?it/s]
