 # Attribution Patching Demo
 **Read [the accompanying blog post here](https://neelnanda.io/attribution-patching) for more context**
 This is an interim research report, giving a whirlwind tour of some unpublished work I did at Anthropic (credit to the then team - Chris Olah, Catherine Olsson, Nelson Elhage and Tristan Hume for help, support, and mentorship!)

 The goal of this work is to run activation patching at an industrial scale, by using gradient-based attribution to approximate the technique - allowing an arbitrary number of patches to be made with two forward passes and a single backward pass.

 I have had less time than hoped to flesh out this investigation, but am writing up a rough investigation and comparison to standard activation patching on a few tasks to give a sense of the potential of this approach, and where it works vs falls down.

 <b style="color: red">To use this notebook, go to Runtime > Change Runtime Type and select GPU as the hardware accelerator.</b>

 **Tips for reading this Colab:**
 * You can run all this code for yourself!
 * The graphs are interactive!
 * Use the table of contents pane in the sidebar to navigate
 * Collapse irrelevant sections with the dropdown arrows
 * Search the page using the search in the sidebar, not CTRL+F

 ## Setup (Ignore)

In [1]:
import os

# Detect the execution environment: Colab is inferred from the string form of the active
# IPython shell, GitHub Actions from the GITHUB_ACTIONS environment variable.
IN_COLAB = 'google.colab' in str(get_ipython())
IN_GITHUB = os.getenv("GITHUB_ACTIONS") == "true"
DEBUG_MODE = False
# The settings below are reduced when running under GitHub Actions.
DO_SLOW_RUNS = not IN_GITHUB
# NOTE(review): "cuda" is selected without checking that a GPU is available - confirm for local runs.
TORCH_DEVICE = "cuda" if not IN_GITHUB else "cpu"
EPOCHS_SIZE = 4000 if not IN_GITHUB else 25

if IN_COLAB or IN_GITHUB:
    
    %pip install transformer_lens
    %pip install torchtyping
    # Install my janky personal plotting utils
    %pip install git+https://github.com/neelnanda-io/neel-plotly.git
    # Install another version of node that makes PySvelte work way faster
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    %pip install git+https://github.com/neelnanda-io/PySvelte.git
    # Needed for PySvelte to work, v3 came out and broke things...
    %pip install typeguard==2.13.3
    print("Running as a Colab or github notebook")
else:
    print("Running as a Jupyter notebook - intended for development only!")
    # NOTE(review): get_ipython() is already called at the top of this cell, before this import;
    # presumably that relies on the notebook kernel providing it as a builtin.
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

Running as a Jupyter notebook - intended for development only!


In [2]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

# Plotly graphics show up in Colab OR in VSCode/Jupyter depending on the renderer, which is
# awkward when developing demos. Only a local debugging session (DEBUG_MODE on and not in
# Colab) gets the notebook renderer; everything else uses the Colab one.
pio.renderers.default = (
    "notebook_connected" if (DEBUG_MODE and not IN_COLAB) else "colab"
)

In [3]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from torchtyping import TensorType as TT
from typing import List, Union, Optional, Callable
from functools import partial
import copy
import itertools
import json

from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML, Markdown

In [4]:
import pysvelte

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

 Plotting helper functions from a janky personal library of plotting utils. The library is not documented and I recommend against trying to read it, just use your preferred plotting library if you want to do anything non-obvious:

In [5]:
from neel_plotly import line, imshow, scatter

In [6]:
import transformer_lens.patching as patching

 ## IOI Patching Setup
 This just copies the relevant set up from Exploratory Analysis Demo, and isn't very important.

In [7]:
# Load pretrained GPT-2 small as a HookedTransformer.
model = HookedTransformer.from_pretrained("gpt2-small")
# Presumably needed so that per-head outputs are available to the stack_head_results calls
# used later in this notebook - TODO confirm against the TransformerLens docs.
model.set_use_attn_result(True)

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


In [8]:
# IOI prompts come in adjacent pairs that differ only in which name is the subject,
# so each prompt's neighbour serves as its corrupted counterpart.
prompts = [
    'When John and Mary went to the shops, John gave the bag to',
    'When John and Mary went to the shops, Mary gave the bag to',
    'When Tom and James went to the park, James gave the ball to',
    'When Tom and James went to the park, Tom gave the ball to',
    'When Dan and Sid went to the shops, Sid gave an apple to',
    'When Dan and Sid went to the shops, Dan gave an apple to',
    'After Martin and Amy went to the park, Amy gave a drink to',
    'After Martin and Amy went to the park, Martin gave a drink to',
]
# (correct answer, incorrect answer) for each prompt, in the same order as `prompts`.
answers = [
    (' Mary', ' John'),
    (' John', ' Mary'),
    (' Tom', ' James'),
    (' James', ' Tom'),
    (' Dan', ' Sid'),
    (' Sid', ' Dan'),
    (' Martin', ' Amy'),
    (' Amy', ' Martin'),
]

clean_tokens = model.to_tokens(prompts)
# Swap each adjacent pair of rows: flipping the lowest bit maps 0<->1, 2<->3, ...
corrupted_tokens = clean_tokens[[i ^ 1 for i in range(len(clean_tokens))]]
print("Clean string 0", model.to_string(clean_tokens[0]))
print("Corrupted string 0", model.to_string(corrupted_tokens[0]))

# One row per prompt: [token id of the correct answer, token id of the incorrect answer].
answer_token_indices = torch.tensor(
    [[model.to_single_token(name) for name in pair] for pair in answers],
    device=model.cfg.device,
)
print("Answer token indices", answer_token_indices)

Clean string 0 <|endoftext|>When John and Mary went to the shops, John gave the bag to
Corrupted string 0 <|endoftext|>When John and Mary went to the shops, Mary gave the bag to
Answer token indices tensor([[ 5335,  1757],
        [ 1757,  5335],
        [ 4186,  3700],
        [ 3700,  4186],
        [ 6035, 15686],
        [15686,  6035],
        [ 5780, 14235],
        [14235,  5780]])


In [9]:
def get_logit_diff(logits, answer_token_indices=answer_token_indices):
    """Return the mean (correct - incorrect) answer logit difference over the batch.

    Args:
        logits: Either per-position logits (3 dims, in which case only the final
            position is used) or already-final logits (2 dims).
        answer_token_indices: Per-prompt token ids; column 0 is the correct answer,
            column 1 the incorrect one.

    Returns:
        A scalar tensor: the logit difference averaged across prompts.
    """
    if logits.ndim == 3:
        # Only the prediction at the last position matters.
        logits = logits[:, -1, :]
    correct_ids = answer_token_indices[:, 0].unsqueeze(1)
    incorrect_ids = answer_token_indices[:, 1].unsqueeze(1)
    per_prompt_diff = logits.gather(1, correct_ids) - logits.gather(1, incorrect_ids)
    return per_prompt_diff.mean()

# Run the model once on each prompt set, keeping both the logits and the activation caches.
clean_logits, clean_cache = model.run_with_cache(clean_tokens)
corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)

# Average logit difference on each prompt set, as plain Python floats.
clean_logit_diff = get_logit_diff(clean_logits, answer_token_indices).item()
corrupted_logit_diff = get_logit_diff(corrupted_logits, answer_token_indices).item()

print(f"Clean logit diff: {clean_logit_diff:.4f}")
print(f"Corrupted logit diff: {corrupted_logit_diff:.4f}")

Clean logit diff: 3.5519
Corrupted logit diff: -3.5519


In [10]:
# Reference points used to rescale the metric.
CLEAN_BASELINE = clean_logit_diff
CORRUPTED_BASELINE = corrupted_logit_diff


def ioi_metric(logits, answer_token_indices=answer_token_indices):
    """Logit difference rescaled so the corrupted baseline maps to 0 and the clean baseline to 1.

    Args:
        logits: Model logits, passed straight through to get_logit_diff.
        answer_token_indices: Per-prompt (correct, incorrect) answer token ids.

    Returns:
        The normalised mean logit difference.
    """
    raw_diff = get_logit_diff(logits, answer_token_indices)
    return (raw_diff - CORRUPTED_BASELINE) / (CLEAN_BASELINE - CORRUPTED_BASELINE)

# Sanity check: the metric should be 1 on the clean logits and 0 on the corrupted logits.
clean_metric_value = ioi_metric(clean_logits).item()
print(f"Clean Baseline is 1: {clean_metric_value:.4f}")
corrupted_metric_value = ioi_metric(corrupted_logits).item()
print(f"Corrupted Baseline is 0: {corrupted_metric_value:.4f}")

Clean Baseline is 1: 1.0000
Corrupted Baseline is 0: 0.0000


 ## Patching
 In the following cells, we define attribution patching and use it in various ways on the model.

In [11]:
# Type alias for a patching metric: a callable applied to a tensor.
# NOTE(review): the alias declares a float return, but get_cache_fwd_and_bwd calls
# .backward() and .item() on the metric's result, so metrics passed there must return
# a scalar tensor rather than a Python float.
Metric = Callable[[TT["batch_and_pos_dims", "d_model"]], float]

In [12]:
def filter_not_qkv_input(name):
    """Hook-name filter: accept every hook whose name does not contain "_input"."""
    return "_input" not in name


def get_cache_fwd_and_bwd(model, tokens, metric):
    """Run one forward and one backward pass, caching activations and their gradients.

    Hooks are attached to every hook point accepted by filter_not_qkv_input. The
    forward hooks record each (detached) activation and the backward hooks record
    the (detached) gradient arriving at the same hook point.

    Args:
        model: The hooked model; any hooks already on it are removed first.
        tokens: Input passed to the model.
        metric: Callable applied to the model output. Its result must support
            ``.backward()`` and ``.item()``.

    Returns:
        A tuple ``(value, cache, grad_cache)``: the metric as a Python number, an
        ActivationCache of forward activations and an ActivationCache of gradients,
        both keyed by hook name.

    Raises:
        Whatever the model, the metric or the backward pass raises. The hooks are
        removed from the model in that case as well.
    """
    model.reset_hooks()

    cache = {}
    grad_cache = {}

    def forward_cache_hook(act, hook):
        cache[hook.name] = act.detach()

    def backward_cache_hook(act, hook):
        grad_cache[hook.name] = act.detach()

    model.add_hook(filter_not_qkv_input, forward_cache_hook, "fwd")
    model.add_hook(filter_not_qkv_input, backward_cache_hook, "bwd")

    try:
        value = metric(model(tokens))
        value.backward()
    finally:
        # Always detach the caching hooks, even on failure, so they cannot leak
        # into later runs of the model.
        model.reset_hooks()
    return value.item(), ActivationCache(cache, model), ActivationCache(grad_cache, model)

# Forward + backward caches for the clean prompts (this replaces the earlier clean_cache).
clean_value, clean_cache, clean_grad_cache = get_cache_fwd_and_bwd(
    model,
    clean_tokens,
    ioi_metric,
)
print("Clean Value:", clean_value)
print("Clean Activations Cached:", len(clean_cache))
print("Clean Gradients Cached:", len(clean_grad_cache))

# Same again for the corrupted prompts.
corrupted_value, corrupted_cache, corrupted_grad_cache = get_cache_fwd_and_bwd(
    model,
    corrupted_tokens,
    ioi_metric,
)
print("Corrupted Value:", corrupted_value)
print("Corrupted Activations Cached:", len(corrupted_cache))
print("Corrupted Gradients Cached:", len(corrupted_grad_cache))

Clean Value: 1.0
Clean Activations Cached: 220
Clean Gradients Cached: 220
Corrupted Value: 0.0
Corrupted Activations Cached: 220
Corrupted Gradients Cached: 220


 ### Attention Attribution
 The easiest thing to start with is to not even engage with the corrupted tokens/patching, but to look at the attribution of the attention patterns - that is, the linear approximation to what happens if you set each element of the attention pattern to zero. This, as it turns out, is a good proxy to what is going on with each head!
 Note that this is *not* the same as what we will later do with patching. In particular, this does not set up a careful counterfactual! It's a good tool for getting a sense of what's generally going on in this problem, but does not control for e.g. stuff that systematically boosts John > Mary in general, stuff that says "I should activate the IOI circuit", etc. Though using logit diff as our metric *does* at least cancel out anything that boosts both candidate names equally.
 Each element of the batch is independent and the metric is an average logit diff, so we can analyse each batch element independently here. We'll look at the first one, and then at the average across the whole batch (note - 4 prompts have indirect object before subject, 4 prompts have it the other way round, making the average pattern harder to interpret - I plot it over the first sequence of tokens as a mildly misleading reference).
 We can compare it to the interpretability in the wild diagram, and basically instantly recover most of the circuit!

In [13]:
def create_attention_attr(clean_cache, clean_grad_cache) -> TT["batch", "layer", "head_index", "dest", "src"]:
    """Elementwise attribution (gradient * activation) for every attention pattern entry.

    Args:
        clean_cache: Forward cache, indexed as ``cache["pattern", layer]``.
        clean_grad_cache: Gradient cache, indexed the same way.

    Returns:
        The per-layer products stacked over all layers, with batch moved to the front.
    """
    layers = range(model.cfg.n_layers)
    patterns = torch.stack([clean_cache["pattern", layer] for layer in layers], dim=0)
    pattern_grads = torch.stack([clean_grad_cache["pattern", layer] for layer in layers], dim=0)
    return einops.rearrange(
        pattern_grads * patterns,
        "layer batch head_index dest src -> batch layer head_index dest src",
    )

# Attention-pattern attribution on the clean run (clean activations times clean gradients).
attention_attr = create_attention_attr(clean_cache, clean_grad_cache)

In [14]:
# Labels for every head in (layer, head) order, plus two expanded variants:
# one entry per sign ("+", "-") and one entry per head input ("Q", "K", "V").
HEAD_NAMES = [
    f"L{layer}H{head}"
    for layer in range(model.cfg.n_layers)
    for head in range(model.cfg.n_heads)
]
HEAD_NAMES_SIGNED = [f"{name}{sign}" for name in HEAD_NAMES for sign in "+-"]
HEAD_NAMES_QKV = [f"{name}{act_name}" for name in HEAD_NAMES for act_name in "QKV"]
print(HEAD_NAMES[:5])
print(HEAD_NAMES_SIGNED[:5])
print(HEAD_NAMES_QKV[:5])

['L0H0', 'L0H1', 'L0H2', 'L0H3', 'L0H4']
['L0H0+', 'L0H0-', 'L0H1+', 'L0H1-', 'L0H2+']
['L0H0Q', 'L0H0K', 'L0H0V', 'L0H1Q', 'L0H1K']


 An extremely janky way to plot the attention attribution patterns. We scale them to be in [-1, 1], split each head into a positive and negative part (so all of it is in [0, 1]), and then plot the top 20 head-halves (a head can appear twice!) by the max value of the attribution pattern.

In [15]:
def plot_attention_attr(attention_attr, tokens, top_k=20, index=0, title=""):
    """Display the top-k signed attention-attribution patterns.

    Each head is split into a positive and a negative half, everything is divided by
    the global maximum, and the halves are ranked by their largest entry.

    Args:
        attention_attr: Attribution with dims (layer, head, dest, src), or with a
            leading batch dim, in which case element ``index`` is used.
        tokens: Tokens to label the plot with; if 2-dimensional, row ``index`` is used.
        top_k: Number of head-halves to show.
        index: Batch element to select when the inputs are batched.
        title: Optional title, rendered as a markdown heading when non-empty.
    """
    # Batched inputs: select the requested sequence.
    if tokens.ndim == 2:
        tokens = tokens[index]
    if attention_attr.ndim == 5:
        attention_attr = attention_attr[index]

    positive_half = attention_attr.clamp(min=-1e-5)
    negative_half = -attention_attr.clamp(max=1e-5)
    signed = einops.rearrange(
        torch.stack([positive_half, negative_half], dim=0),
        "sign layer head_index dest src -> (layer head_index sign) dest src",
    )
    signed = signed / signed.max()

    # Rank head-halves by their largest entry, biggest first.
    ranking = signed.max(-1).values.max(-1).values.argsort(descending=True)
    signed = signed[ranking, :, :]
    head_labels = [HEAD_NAMES_SIGNED[position] for position in ranking.tolist()]

    if title:
        display(Markdown("### " + title))
    if DO_SLOW_RUNS:
        display(
            pysvelte.AttentionMulti(
                tokens=model.to_str_tokens(tokens),
                attention=signed.permute(1, 2, 0)[:, :, :top_k],
                head_labels=head_labels[:top_k],
            )
        )


# First prompt only.
plot_attention_attr(
    attention_attr,
    clean_tokens,
    index=0,
    title="Attention Attribution for first sequence",
)

# Summed over the batch, labelled with the tokens of the first prompt.
plot_attention_attr(
    attention_attr.sum(0),
    clean_tokens[0],
    title="Summed Attention Attribution for all sequences",
)
print("Note: Plotted over first sequence for reference, but pairs have IO and S1 in different positions.")

### Attention Attribution for first sequence

### Summed Attention Attribution for all sequences

Note: Plotted over first sequence for reference, but pairs have IO and S1 in different positions.


 ## Attribution Patching
 In the following sections, I will implement various kinds of attribution patching, and then compare them to the activation patching patterns (activation patching code copied from [Exploratory Analysis Demo](https://neelnanda.io/exploratory-analysis-demo))
 ### Residual Stream Patching
 <details><summary>Note: We add up across both d_model and batch (Explanation).</summary>
 We add up along d_model because we're taking the dot product - the derivative *is* the linear map that locally linearly approximates the metric, and so we take the dot product of our change vector with the derivative vector. Equivalently, we look at the effect of changing each coordinate independently, and then combine them by adding it up - it's linear, so this totally works.
 We add up across batch because we're taking the average of the metric, so each individual batch element provides `1/batch_size` of the overall effect. Because each batch element is independent of the others and no information moves between activations for different inputs, the batched version is equivalent to doing attribution patching separately for each input, and then averaging - in this second version the metric per input is *not* divided by batch_size because we don't average.</details>

In [16]:
def attr_patch_residual(
    clean_cache: ActivationCache,
    corrupted_cache: ActivationCache,
    corrupted_grad_cache: ActivationCache,
) -> TT["component", "pos"]:
    """Attribution patching over the accumulated residual stream.

    Computes ``corrupted_grad * (clean - corrupted)`` on the outputs of
    ``accumulated_resid(-1, incl_mid=True)`` and sums over batch and d_model.

    Returns:
        A tuple ``(residual_attr, residual_labels)``: the (component, pos) attribution
        and the labels reported by the clean cache.
    """
    clean_residual, residual_labels = clean_cache.accumulated_resid(
        -1, incl_mid=True, return_labels=True
    )
    corrupted_residual = corrupted_cache.accumulated_resid(
        -1, incl_mid=True, return_labels=False
    )
    corrupted_grad_residual = corrupted_grad_cache.accumulated_resid(
        -1, incl_mid=True, return_labels=False
    )
    elementwise_attr = corrupted_grad_residual * (clean_residual - corrupted_residual)
    residual_attr = einops.reduce(
        elementwise_attr,
        "component batch pos d_model -> component pos",
        "sum",
    )
    return residual_attr, residual_labels

residual_attr, residual_labels = attr_patch_residual(
    clean_cache,
    corrupted_cache,
    corrupted_grad_cache,
)
imshow(
    residual_attr,
    y=residual_labels,
    yaxis="Component",
    xaxis="Position",
    title="Residual Attribution Patching",
)

 ### Layer Output Patching

In [17]:
def attr_patch_layer_out(
    clean_cache: ActivationCache,
    corrupted_cache: ActivationCache,
    corrupted_grad_cache: ActivationCache,
) -> TT["component", "pos"]:
    """Attribution patching over the components returned by ``decompose_resid(-1)``.

    Computes ``corrupted_grad * (clean - corrupted)`` per component and sums over
    batch and d_model.

    Returns:
        A tuple ``(layer_out_attr, labels)``: the (component, pos) attribution and the
        component labels reported by the clean cache.
    """
    clean_layer_out, labels = clean_cache.decompose_resid(-1, return_labels=True)
    corrupted_layer_out = corrupted_cache.decompose_resid(-1, return_labels=False)
    corrupted_grad_layer_out = corrupted_grad_cache.decompose_resid(-1, return_labels=False)
    elementwise_attr = corrupted_grad_layer_out * (clean_layer_out - corrupted_layer_out)
    layer_out_attr = einops.reduce(
        elementwise_attr,
        "component batch pos d_model -> component pos",
        "sum",
    )
    return layer_out_attr, labels

layer_out_attr, layer_out_labels = attr_patch_layer_out(
    clean_cache,
    corrupted_cache,
    corrupted_grad_cache,
)
imshow(
    layer_out_attr,
    y=layer_out_labels,
    yaxis="Component",
    xaxis="Position",
    title="Layer Output Attribution Patching",
)

In [18]:
def attr_patch_head_out(
    clean_cache: ActivationCache,
    corrupted_cache: ActivationCache,
    corrupted_grad_cache: ActivationCache,
) -> TT["component", "pos"]:
    """Attribution patching over each attention head's output.

    Computes ``corrupted_grad * (clean - corrupted)`` on ``stack_head_results(-1)``
    and sums over batch and d_model.

    Returns:
        A tuple ``(head_out_attr, labels)``: the (component, pos) attribution and
        HEAD_NAMES as labels.
    """
    clean_head_out = clean_cache.stack_head_results(-1, return_labels=False)
    corrupted_head_out = corrupted_cache.stack_head_results(-1, return_labels=False)
    corrupted_grad_head_out = corrupted_grad_cache.stack_head_results(-1, return_labels=False)
    elementwise_attr = corrupted_grad_head_out * (clean_head_out - corrupted_head_out)
    head_out_attr = einops.reduce(
        elementwise_attr,
        "component batch pos d_model -> component pos",
        "sum",
    )
    return head_out_attr, HEAD_NAMES

head_out_attr, head_out_labels = attr_patch_head_out(
    clean_cache,
    corrupted_cache,
    corrupted_grad_cache,
)
imshow(
    head_out_attr,
    y=head_out_labels,
    yaxis="Component",
    xaxis="Position",
    title="Head Output Attribution Patching",
)
# Collapse positions and unflatten the component axis into (layer, head).
sum_head_out_attr = einops.reduce(
    head_out_attr,
    "(layer head) pos -> layer head",
    "sum",
    layer=model.cfg.n_layers,
    head=model.cfg.n_heads,
)
imshow(
    sum_head_out_attr,
    yaxis="Layer",
    xaxis="Head Index",
    title="Head Output Attribution Patching Sum Over Pos",
)

 ### Head Activation Patching
 Intuitively, a head has three inputs, keys, queries and values. We can patch each of these individually to get a sense for where the important part of each head's input comes from!
 As a sanity check, we also do this for the mixed value. The result is a linear map of this (`z @ W_O == result`), so this is the same as patching the output of the head.
 We plot both the patch for each head over each position, and summed over position (it tends to be pretty sparse, so the latter is the same)

In [19]:
from typing_extensions import Literal
def stack_head_vector_from_cache(
    cache,
    activation_name: Literal["q", "k", "v", "z"],
) -> TT["layer_and_head_index", "batch", "pos", "d_head"]:
    """Stack one per-head activation (query, key, value or mixed value ``z``) over all layers.

    The layer and head axes are flattened into a single leading axis.
    """
    per_layer = [cache[activation_name, layer] for layer in range(model.cfg.n_layers)]
    return einops.rearrange(
        torch.stack(per_layer, dim=0),
        "layer batch pos head_index d_head -> (layer head_index) batch pos d_head",
    )

def attr_patch_head_vector(
    clean_cache: ActivationCache,
    corrupted_cache: ActivationCache,
    corrupted_grad_cache: ActivationCache,
    activation_name: Literal["q", "k", "v", "z"],
) -> TT["component", "pos"]:
    """Attribution patching over one per-head vector (query, key, value or ``z``).

    Computes ``corrupted_grad * (clean - corrupted)`` on the stacked head vectors and
    sums over batch and d_head.

    Returns:
        A tuple ``(head_vector_attr, labels)``: the (component, pos) attribution and
        HEAD_NAMES as labels.
    """
    clean_head_vector = stack_head_vector_from_cache(clean_cache, activation_name)
    corrupted_head_vector = stack_head_vector_from_cache(corrupted_cache, activation_name)
    corrupted_grad_head_vector = stack_head_vector_from_cache(corrupted_grad_cache, activation_name)
    elementwise_attr = corrupted_grad_head_vector * (clean_head_vector - corrupted_head_vector)
    head_vector_attr = einops.reduce(
        elementwise_attr,
        "component batch pos d_head -> component pos",
        "sum",
    )
    return head_vector_attr, HEAD_NAMES

# Attribution per head for each of key / query / value / mixed value, plotted both per
# position and summed over positions. Results are kept in head_vector_attr_dict by short name.
head_vector_attr_dict = {}
for activation_name, activation_name_full in [
    ("k", "Key"),
    ("q", "Query"),
    ("v", "Value"),
    ("z", "Mixed Value"),
]:
    display(Markdown(f"#### {activation_name_full} Head Vector Attribution Patching"))
    head_vector_attr_dict[activation_name], head_vector_labels = attr_patch_head_vector(
        clean_cache,
        corrupted_cache,
        corrupted_grad_cache,
        activation_name,
    )
    imshow(
        head_vector_attr_dict[activation_name],
        y=head_vector_labels,
        yaxis="Component",
        xaxis="Position",
        title=f"{activation_name_full} Attribution Patching",
    )
    sum_head_vector_attr = einops.reduce(
        head_vector_attr_dict[activation_name],
        "(layer head) pos -> layer head",
        "sum",
        layer=model.cfg.n_layers,
        head=model.cfg.n_heads,
    )
    imshow(
        sum_head_vector_attr,
        yaxis="Layer",
        xaxis="Head Index",
        title=f"{activation_name_full} Attribution Patching Sum Over Pos",
    )

#### Key Head Vector Attribution Patching

#### Query Head Vector Attribution Patching

#### Value Head Vector Attribution Patching

#### Mixed Value Head Vector Attribution Patching

In [20]:
from typing_extensions import Literal
def stack_head_pattern_from_cache(
    cache,
) -> TT["layer_and_head_index", "batch", "dest_pos", "src_pos"]:
    """Stack the attention patterns of all layers, flattening layer and head into one axis."""
    per_layer = [cache["pattern", layer] for layer in range(model.cfg.n_layers)]
    return einops.rearrange(
        torch.stack(per_layer, dim=0),
        "layer batch head_index dest_pos src_pos -> (layer head_index) batch dest_pos src_pos",
    )

def attr_patch_head_pattern(
    clean_cache: ActivationCache,
    corrupted_cache: ActivationCache,
    corrupted_grad_cache: ActivationCache,
) -> TT["component", "dest_pos", "src_pos"]:
    """Attribution patching over every entry of every head's attention pattern.

    Computes ``corrupted_grad * (clean - corrupted)`` on the stacked patterns and sums
    over the batch only.

    Returns:
        A tuple ``(head_pattern_attr, labels)``: the (component, dest_pos, src_pos)
        attribution and HEAD_NAMES as labels.
    """
    clean_head_pattern = stack_head_pattern_from_cache(clean_cache)
    corrupted_head_pattern = stack_head_pattern_from_cache(corrupted_cache)
    corrupted_grad_head_pattern = stack_head_pattern_from_cache(corrupted_grad_cache)
    elementwise_attr = corrupted_grad_head_pattern * (clean_head_pattern - corrupted_head_pattern)
    head_pattern_attr = einops.reduce(
        elementwise_attr,
        "component batch dest_pos src_pos -> component dest_pos src_pos",
        "sum",
    )
    return head_pattern_attr, HEAD_NAMES

head_pattern_attr, labels = attr_patch_head_pattern(
    clean_cache,
    corrupted_cache,
    corrupted_grad_cache,
)

# Unflatten the component axis into (layer, head) before plotting.
plot_attention_attr(
    einops.rearrange(
        head_pattern_attr,
        "(layer head) dest src -> layer head dest src",
        layer=model.cfg.n_layers,
        head=model.cfg.n_heads,
    ),
    clean_tokens,
    index=0,
    title="Head Pattern Attribution Patching",
)

### Head Pattern Attribution Patching

In [21]:
def get_head_vector_grad_input_from_grad_cache(
    grad_cache: ActivationCache,
    activation_name: Literal["q", "k", "v"],
    layer: int,
) -> TT["batch", "pos", "head_index", "d_model"]:
    """Map the gradient at a head's q/k/v vector back through the matching weight matrix.

    Contracts the cached gradient over d_head with W_Q / W_K / W_V of the given layer,
    scaled per (batch, pos) by the cached ``("scale", layer, "ln1")`` entry.

    Raises:
        ValueError: If ``activation_name`` is not one of "q", "k", "v".
    """
    vector_grad = grad_cache[activation_name, layer]
    # NOTE(review): the LayerNorm scale is read from the cache passed in, and the caller in
    # this notebook passes a *gradient* cache - confirm that this is the intended scale.
    ln_scales = grad_cache["scale", layer, "ln1"]
    attn_layer_object = model.blocks[layer].attn

    weight_attr_by_name = {"q": "W_Q", "k": "W_K", "v": "W_V"}
    if activation_name not in weight_attr_by_name:
        raise ValueError("Invalid activation name")
    W = getattr(attn_layer_object, weight_attr_by_name[activation_name])

    return einsum(
        "batch pos head_index d_head, batch pos, head_index d_model d_head -> batch pos head_index d_model",
        vector_grad,
        ln_scales.squeeze(-1),
        W,
    )

def get_stacked_head_vector_grad_input(grad_cache, activation_name: Literal["q", "k", "v"]) -> TT["layer", "batch", "pos", "head_index", "d_model"]:
    """Stack the per-layer input gradients for one of q/k/v along a new leading layer axis."""
    per_layer = [
        get_head_vector_grad_input_from_grad_cache(grad_cache, activation_name, layer)
        for layer in range(model.cfg.n_layers)
    ]
    return torch.stack(per_layer, dim=0)

def get_full_vector_grad_input(grad_cache) -> TT["qkv", "layer", "batch", "pos", "head_index", "d_model"]:
    """Stack the input gradients for q, k and v (in that order) along a new leading axis."""
    per_input = [
        get_stacked_head_vector_grad_input(grad_cache, activation_name)
        for activation_name in ("q", "k", "v")
    ]
    return torch.stack(per_input, dim=0)

def attr_patch_head_path(
    clean_cache: ActivationCache,
    corrupted_cache: ActivationCache,
    corrupted_grad_cache: ActivationCache,
) -> TT["qkv", "dest_component", "src_component", "pos"]:
    """Attribution patch along the path from each head's output to each head's q/k/v input.

    The clean-minus-corrupted head outputs are dotted (over batch and d_model) with the
    corrupted gradient mapped back to each destination head's input. Only paths whose
    destination layer is strictly later than the source layer are kept; every other
    entry, including same-layer pairs, is set to zero.

    Returns:
        A tuple ``(path_attr, end_labels, start_labels)``: ``path_attr`` has axes
        (destination head and q/k/v, source head, pos), labelled by HEAD_NAMES_QKV and
        HEAD_NAMES respectively.
    """
    start_labels = HEAD_NAMES
    end_labels = HEAD_NAMES_QKV

    full_vector_grad_input = get_full_vector_grad_input(corrupted_grad_cache)
    clean_head_result_stack = clean_cache.stack_head_results(-1)
    corrupted_head_result_stack = corrupted_cache.stack_head_results(-1)
    diff_head_result = einops.rearrange(
        clean_head_result_stack - corrupted_head_result_stack,
        "(layer head_index) batch pos d_model -> layer batch pos head_index d_model",
        layer=model.cfg.n_layers,
        head_index=model.cfg.n_heads,
    )

    path_attr = einsum(
        "qkv layer_end batch pos head_end d_model, layer_start batch pos head_start d_model -> qkv layer_end head_end layer_start head_start pos",
        full_vector_grad_input,
        diff_head_result,
    )

    # Broadcastable mask over (qkv, layer_end, head_end, layer_start, head_start, pos):
    # true exactly where layer_end > layer_start.
    end_layer = torch.arange(model.cfg.n_layers)[None, :, None, None, None, None]
    start_layer = torch.arange(model.cfg.n_layers)[None, None, None, :, None, None]
    correct_layer_order_mask = (end_layer > start_layer).to(path_attr.device)
    path_attr = torch.where(
        correct_layer_order_mask,
        path_attr,
        torch.zeros(1, device=path_attr.device),
    )

    path_attr = einops.rearrange(
        path_attr,
        "qkv layer_end head_end layer_start head_start pos -> (layer_end head_end qkv) (layer_start head_start) pos",
    )
    return path_attr, end_labels, start_labels

head_path_attr, end_labels, start_labels = attr_patch_head_path(
    clean_cache,
    corrupted_cache,
    corrupted_grad_cache,
)
# Summed over positions: rows are path ends (head inputs), columns are path starts (head outputs).
imshow(
    head_path_attr.sum(-1),
    y=end_labels,
    yaxis="Path End (Head Input)",
    x=start_labels,
    xaxis="Path Start (Head Output)",
    title="Head Path Attribution Patching",
)

 This is hard to parse. Here's an experiment with filtering for the most important heads and showing their paths.

In [22]:
# Rank heads by the absolute value of their position-summed output attribution.
head_out_values, head_out_indices = head_out_attr.sum(-1).abs().sort(descending=True)
line(head_out_values)

# Keep the 22 highest-ranked heads, restored to ascending head-index order.
top_head_indices = head_out_indices[:22].sort().values
top_start_indices = [head.item() for head in top_head_indices]
top_start_labels = [start_labels[head] for head in top_start_indices]
# Each head owns three consecutive path-end rows (Q, K, V).
top_end_indices = [3 * head + offset for head in top_start_indices for offset in range(3)]
top_end_labels = [end_labels[row] for row in top_end_indices]

imshow(
    head_path_attr[top_end_indices, :][:, top_start_indices].sum(-1),
    y=top_end_labels,
    yaxis="Path End (Head Input)",
    x=top_start_labels,
    xaxis="Path Start (Head Output)",
    title="Head Path Attribution Patching (Filtered for Top Heads)",
)

In [23]:
# One plot per head input: rows j, j+3, j+6, ... of the filtered matrix belong to input type j.
for j, composition_type in enumerate(["Query", "Key", "Value"]):
    imshow(
        head_path_attr[top_end_indices, :][:, top_start_indices][j::3].sum(-1),
        y=top_end_labels[j::3],
        yaxis="Path End (Head Input)",
        x=top_start_labels,
        xaxis="Path Start (Head Output)",
        title=f"Head Path to {composition_type} Attribution Patching (Filtered for Top Heads)",
    )

In [24]:
# Same filtered matrix, with the q/k/v axis pulled out in front so it can be faceted.
top_head_path_attr = einops.rearrange(
    head_path_attr[top_end_indices, :][:, top_start_indices].sum(-1),
    "(head_end qkv) head_start -> qkv head_end head_start",
    qkv=3,
)
imshow(
    top_head_path_attr,
    # Drop the trailing Q/K/V letter to recover plain head names.
    y=[i[:-1] for i in top_end_labels[::3]],
    yaxis="Path End (Head Input)",
    x=top_start_labels,
    xaxis="Path Start (Head Output)",
    title=f"Head Path Attribution Patching (Filtered for Top Heads)",
    facet_col=0,
    facet_labels=["Query", "Key", "Value"],
)

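 Let's now dive into 3 interesting heads: L5H5 (induction head), L8H6 (S-Inhibition Head), L9H9 (Name Mover) and look at their input and output paths (note - "Query (In)" means paths from other heads' outputs into this head's query input, while "Query (Out)" means paths from this head's output into other heads' query inputs; likewise for Key and Value)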

In [25]:
# Flat indices (layer * n_heads + head) of L5H5, L8H6 and L9H9.
interesting_heads = [
    layer * model.cfg.n_heads + head for layer, head in [(5, 5), (8, 6), (9, 9)]
]
interesting_head_labels = [HEAD_NAMES[i] for i in interesting_heads]
for head_index, label in zip(interesting_heads, interesting_head_labels):
    # Paths ending at this head: its three consecutive Q/K/V rows.
    in_paths = head_path_attr[3 * head_index : 3 * head_index + 3].sum(-1)
    # Paths starting at this head: its column, split by the destination's Q/K/V input.
    out_paths = einops.rearrange(
        head_path_attr[:, head_index].sum(-1),
        "(layer_head qkv) -> qkv layer_head",
        qkv=3,
    )
    all_paths = einops.rearrange(
        torch.cat([in_paths, out_paths], dim=0),
        "path_type (layer head) -> path_type layer head",
        layer=model.cfg.n_layers,
        head=model.cfg.n_heads,
    )
    imshow(
        all_paths,
        facet_col=0,
        facet_labels=["Query (In)", "Key (In)", "Value (In)", "Query (Out)", "Key (Out)", "Value (Out)"],
        title=f"Input and Output Paths for head {label}",
        yaxis="Layer",
        xaxis="Head",
    )

 ## Validating Attribution vs Activation Patching
 Let's now compare attribution and activation patching. Generally it's a decent approximation! The main places it fails are MLP0 and the residual stream.
 My fuzzy intuition is that attribution patching works badly for "big" things which are poorly modelled as linear approximations, and works well for "small" things which are more like incremental changes. Anything involving replacing the embedding is a "big" thing, which includes residual streams, and in GPT-2 small MLP0 seems to be used as an "extended embedding" (where later layers use MLP0's output instead of the token embedding), so I also count it as big.
 See more discussion in the accompanying blog post!


 First do some refactoring to make attribution patching more generic. We make an attribution cache, which is an ActivationCache where each element is (clean_act - corrupted_act) * corrupted_grad, so that it's the per-element attribution for each activation. Thanks to linearity, we just compute things by adding stuff up along the relevant dimensions!

In [26]:
# Per-element attribution for every cached activation: corrupted_grad * (clean - corrupted).
attribution_cache_dict = {
    name: grad * (clean_cache.cache_dict[name] - corrupted_cache.cache_dict[name])
    for name, grad in corrupted_grad_cache.cache_dict.items()
}
attr_cache = ActivationCache(attribution_cache_dict, model)

 By block: For each block (layer) we patch the starting residual stream, the attention output, and the MLP output.

In [27]:
# String tokens of the first clean prompt and their count, used below for hover labels
# and for repeating per-layer colours across positions.
str_tokens = model.to_str_tokens(clean_tokens[0])
context_length = len(str_tokens)

In [28]:
# Activation patching: patch clean activations into the corrupted run, block by block.
every_block_act_patch_result = patching.get_act_patch_block_every(
    model,
    corrupted_tokens,
    clean_cache,
    ioi_metric,
)
imshow(
    every_block_act_patch_result,
    facet_col=0,
    facet_labels=["Residual Stream", "Attn Output", "MLP Output"],
    title="Activation Patching Per Block",
    xaxis="Position",
    yaxis="Layer",
    zmax=1,
    zmin=-1,
    x=[f"{tok}_{i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],
)

  0%|          | 0/180 [00:00<?, ?it/s]

  0%|          | 0/180 [00:00<?, ?it/s]

  0%|          | 0/180 [00:00<?, ?it/s]

In [29]:
def get_attr_patch_block_every(attr_cache):
    """Attribution-patching result per block for resid_pre, attn_out and mlp_out.

    Args:
        attr_cache: Cache of per-element attributions; ``stack_activation`` is called
            on it once per activation name.

    Returns:
        A tensor stacking, in that order, the (layer, pos) attributions of
        "resid_pre", "attn_out" and "mlp_out", each summed over batch and d_model.
    """
    per_activation = [
        einops.reduce(
            attr_cache.stack_activation(activation_name),
            "layer batch pos d_model -> layer pos",
            "sum",
        )
        for activation_name in ("resid_pre", "attn_out", "mlp_out")
    ]
    return torch.stack(per_activation, dim=0)
every_block_attr_patch_result = get_attr_patch_block_every(attr_cache)
imshow(
    every_block_attr_patch_result,
    facet_col=0,
    facet_labels=["Residual Stream", "Attn Output", "MLP Output"],
    title="Attribution Patching Per Block",
    xaxis="Position",
    yaxis="Layer",
    zmax=1,
    zmin=-1,
    x=[f"{tok}_{i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],
)

In [30]:
# Attribution against activation patching, one point per (layer, position), coloured by layer.
if DO_SLOW_RUNS:
    scatter(
        y=every_block_attr_patch_result.reshape(3, -1),
        x=every_block_act_patch_result.reshape(3, -1),
        facet_col=0,
        facet_labels=["Residual Stream", "Attn Output", "MLP Output"],
        title="Attribution vs Activation Patching Per Block",
        xaxis="Activation Patch",
        yaxis="Attribution Patch",
        hover=[
            f"Layer {l}, Position {p}, |{str_tokens[p]}|"
            for l in range(model.cfg.n_layers)
            for p in range(context_length)
        ],
        color=einops.repeat(
            torch.arange(model.cfg.n_layers),
            "layer -> (layer pos)",
            pos=context_length,
        ),
        color_continuous_scale="Portland",
    )

 By head: For each head we patch the output, query, key, value or pattern. We do all positions at once so it's not super slow.

In [31]:
# Activation patching per head, patching all positions at once.
every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(
    model,
    corrupted_tokens,
    clean_cache,
    ioi_metric,
)
imshow(
    every_head_all_pos_act_patch_result,
    facet_col=0,
    facet_labels=["Output", "Query", "Key", "Value", "Pattern"],
    title="Activation Patching Per Head (All Pos)",
    xaxis="Head",
    yaxis="Layer",
    zmax=1,
    zmin=-1,
)

  0%|          | 0/144 [00:00<?, ?it/s]

  0%|          | 0/144 [00:00<?, ?it/s]

  0%|          | 0/144 [00:00<?, ?it/s]

  0%|          | 0/144 [00:00<?, ?it/s]

  0%|          | 0/144 [00:00<?, ?it/s]

In [32]:
def get_attr_patch_attn_head_all_pos_every(attr_cache):
    """Sum per-head attribution values over every position.

    Args:
        attr_cache: Object exposing ``stack_activation(name)``. Presumably it
            holds the attribution values for each activation; confirm against
            where ``attr_cache`` is built.

    Returns:
        A tensor stacking five ``(layer, head_index)`` results along a new
        leading axis, in the order: output ("z"), query, key, value, pattern.
    """
    # "z", "q", "k" and "v" share one axis layout, so they share one reduction.
    per_head_pattern = "layer batch pos head_index d_head -> layer head_index"
    summed_per_head = [
        einops.reduce(attr_cache.stack_activation(act_name), per_head_pattern, "sum")
        for act_name in ("z", "q", "k", "v")
    ]

    # The attention pattern has (dest_pos, src_pos) axes instead of (pos, d_head).
    summed_per_head.append(
        einops.reduce(
            attr_cache.stack_activation("pattern"),
            "layer batch head_index dest_pos src_pos -> layer head_index",
            "sum",
        )
    )

    return torch.stack(summed_per_head)
    
# Attribution-patching counterpart of the per-head (all positions) heatmaps.
every_head_all_pos_attr_patch_result = get_attr_patch_attn_head_all_pos_every(attr_cache)
imshow(
    every_head_all_pos_attr_patch_result,
    facet_col=0,
    facet_labels=["Output", "Query", "Key", "Value", "Pattern"],
    title="Attribution Patching Per Head (All Pos)",
    xaxis="Head",
    yaxis="Layer",
    zmax=1,
    zmin=-1,
)

In [33]:
# Compare attribution patching (y) with activation patching (x) for every head,
# one facet per activation type. The guard matches the other comparison cells:
# the plot is drawn on full runs and skipped when slow runs are disabled.
# (It was previously inverted, so the plot only appeared when slow runs were off.)
if DO_SLOW_RUNS:
    scatter(
        y=every_head_all_pos_attr_patch_result.reshape(5, -1),
        x=every_head_all_pos_act_patch_result.reshape(5, -1),
        facet_col=0,
        facet_labels=["Output", "Query", "Key", "Value", "Pattern"],
        title="Attribution vs Activation Patching Per Head (All Pos)",
        xaxis="Activation Patch",
        yaxis="Attribution Patch",
        include_diag=True,
        hover=head_out_labels,
        # Colour each point by its layer index, repeated once per head.
        color=einops.repeat(
            torch.arange(model.cfg.n_layers),
            "layer -> (layer head)",
            head=model.cfg.n_heads,
        ),
        color_continuous_scale="Portland",
    )

 We see pretty good results in general, but significant errors for head L5H5 on query, moderate errors for head L10H7 on query and key, and moderate errors for head L11H10 on key. Each of these heads is fine for pattern and output, however. My guess is that these heads have pretty saturated attention on a single token, so the linear approximation is not great for the attention calculation here, but I'm not sure. When we plot the attention patterns, we do see this!
 Note that the axis labels are for the *first* prompt's tokens, but each facet is a different prompt, so this is somewhat inaccurate. In particular, every odd facet has the indirect object and subject in the opposite order (IO first). But otherwise everything lines up between the prompts.

In [34]:
# Axis labels: the first clean prompt's tokens, suffixed with their position.
graph_tok_labels = [f"{tok}_{i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))]

# Attention patterns (one facet per prompt) for the three heads discussed above.
imshow(
    clean_cache["pattern", 5][:, 5],
    x=graph_tok_labels,
    y=graph_tok_labels,
    facet_col=0,
    title="Attention for Head L5H5",
    facet_name="Prompt",
)
imshow(
    clean_cache["pattern", 10][:, 7],
    x=graph_tok_labels,
    y=graph_tok_labels,
    facet_col=0,
    title="Attention for Head L10H7",
    facet_name="Prompt",
)
imshow(
    clean_cache["pattern", 11][:, 10],
    x=graph_tok_labels,
    y=graph_tok_labels,
    facet_col=0,
    title="Attention for Head L11H10",
    facet_name="Prompt",
)


# [markdown]

In [2]:
if IN_GITHUB:
    # Under GitHub Actions the expensive patching sweep is skipped; a 4-D
    # placeholder built from the token ids stands in so the rearrange below
    # still receives a tensor with four axes.
    every_head_by_pos_act_patch_result = torch.unsqueeze(torch.unsqueeze(torch.LongTensor(corrupted_tokens), 2), 0)
else:
    every_head_by_pos_act_patch_result = patching.get_act_patch_attn_head_by_pos_every(
        model,
        corrupted_tokens,
        clean_cache,
        ioi_metric,
    )

# Merge the layer and head axes so each heatmap row is a single head.
every_head_by_pos_act_patch_result = einops.rearrange(
    every_head_by_pos_act_patch_result,
    "act_type layer pos head -> act_type (layer head) pos",
)

if not IN_GITHUB:
    imshow(
        every_head_by_pos_act_patch_result,
        facet_col=0,
        facet_labels=["Output", "Query", "Key", "Value", "Pattern"],
        title="Activation Patching Per Head (By Pos)",
        xaxis="Position",
        yaxis="Layer & Head",
        zmax=1,
        zmin=-1,
        x=[f"{tok}_{i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],
        y=head_out_labels,
    )

NameError: name 'torch' is not defined

In [36]:
def get_attr_patch_attn_head_by_pos_every(attr_cache):
    """Sum per-head attribution values separately for each position.

    Args:
        attr_cache: Object exposing ``stack_activation(name)``. Presumably it
            holds the attribution values for each activation; confirm against
            where ``attr_cache`` is built.

    Returns:
        A tensor stacking five ``(layer, pos, head_index)`` results along a
        new leading axis, in the order: output ("z"), query, key, value,
        pattern. For the pattern, the kept position axis is the destination
        position (the source position is summed out).
    """
    # "z", "q", "k" and "v" share one axis layout, so they share one reduction.
    by_pos_pattern = "layer batch pos head_index d_head -> layer pos head_index"
    summed_by_pos = [
        einops.reduce(attr_cache.stack_activation(act_name), by_pos_pattern, "sum")
        for act_name in ("z", "q", "k", "v")
    ]

    # Attention patterns are indexed (head, dest, src); keep dest as "pos".
    summed_by_pos.append(
        einops.reduce(
            attr_cache.stack_activation("pattern"),
            "layer batch head_index dest_pos src_pos -> layer dest_pos head_index",
            "sum",
        )
    )

    return torch.stack(summed_by_pos)
# Compute per-head, per-position attributions and merge the layer and head axes
# so each heatmap row corresponds to a single head.
every_head_by_pos_attr_patch_result = einops.rearrange(
    get_attr_patch_attn_head_by_pos_every(attr_cache),
    "act_type layer pos head -> act_type (layer head) pos",
)
if DO_SLOW_RUNS:
    imshow(
        every_head_by_pos_attr_patch_result,
        facet_col=0,
        facet_labels=["Output", "Query", "Key", "Value", "Pattern"],
        title="Attribution Patching Per Head (By Pos)",
        xaxis="Position",
        yaxis="Layer & Head",
        zmax=1,
        zmin=-1,
        x=[f"{tok}_{i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],
        y=head_out_labels,
    )

In [37]:
# Compare attribution patching (y) with activation patching (x) for every
# (head, position) pair, one facet per activation type.
scatter(
    y=every_head_by_pos_attr_patch_result.reshape(5, -1),
    x=every_head_by_pos_act_patch_result.reshape(5, -1),
    facet_col=0,
    facet_labels=["Output", "Query", "Key", "Value", "Pattern"],
    title="Attribution vs Activation Patching Per Head (by Pos)",
    xaxis="Activation Patch",
    yaxis="Attribution Patch",
    include_diag=True,
    hover=[f"{label} {tok}" for label in head_out_labels for tok in graph_tok_labels],
    # Colour each point by its layer index. The position count is taken from
    # graph_tok_labels (rather than a hard-coded 15) so the colour vector
    # always has the same length as the hover labels above.
    color=einops.repeat(
        torch.arange(model.cfg.n_layers),
        "layer -> (layer head pos)",
        head=model.cfg.n_heads,
        pos=len(graph_tok_labels),
    ),
    color_continuous_scale="Portland",
)

 ## Factual Knowledge Patching Example
 Incomplete, but maybe of interest!
 Note that I get better results when the corrupted prompt uses random words rather than Colosseum.

In [None]:
# Clean/corrupted prompt pair for the factual-recall experiment.
clean_prompt = "The Eiffel Tower is located in the city of"
clean_answer = " Paris"
# corrupted_prompt = "The red brown fox jumps is located in the city of"
corrupted_prompt = "The Colosseum is located in the city of"
corrupted_answer = " Rome"

if not IN_GITHUB:
    # GPT-2 XL is only loaded outside GitHub Actions.
    gpt2_xl = HookedTransformer.from_pretrained("gpt2-xl", device=TORCH_DEVICE)
    utils.test_prompt(clean_prompt, clean_answer, gpt2_xl)
    utils.test_prompt(corrupted_prompt, corrupted_answer, gpt2_xl)

: 

: 

In [None]:
if not IN_GITHUB:
    # Vocabulary ids of the two single-token answers.
    clean_answer_index = gpt2_xl.to_single_token(clean_answer)
    corrupted_answer_index = gpt2_xl.to_single_token(corrupted_answer)


def factual_logit_diff(logits: TT["batch", "position", "d_vocab"]):
    """Return the clean-minus-corrupted answer logit at the final position.

    Only the first batch element is read. The answer ids come from the
    module-level ``clean_answer_index`` and ``corrupted_answer_index``.
    """
    final_position_logits = logits[0, -1]
    return final_position_logits[clean_answer_index] - final_position_logits[corrupted_answer_index]

In [None]:

def factual_metric(logits: TT["batch", "position", "d_vocab"]):
    """Rescale the factual logit difference against the two baselines.

    The result is 0 when the logit difference equals
    ``CORRUPTED_LOGIT_DIFF_FACTUAL`` and 1 when it equals
    ``CLEAN_LOGIT_DIFF_FACTUAL``; both are module-level values that must be
    set before this is called.
    """
    shifted = factual_logit_diff(logits) - CORRUPTED_LOGIT_DIFF_FACTUAL
    baseline_gap = CLEAN_LOGIT_DIFF_FACTUAL - CORRUPTED_LOGIT_DIFF_FACTUAL
    return shifted / baseline_gap
if not IN_GITHUB:
    # Run both prompts first; only the clean run's activation cache is kept.
    clean_logits, clean_cache = gpt2_xl.run_with_cache(clean_prompt)
    corrupted_logits, _ = gpt2_xl.run_with_cache(corrupted_prompt)

    # Baseline logit differences used by factual_metric for normalisation.
    CLEAN_LOGIT_DIFF_FACTUAL = factual_logit_diff(clean_logits).item()
    CORRUPTED_LOGIT_DIFF_FACTUAL = factual_logit_diff(corrupted_logits).item()

    print("Clean logit diff:", CLEAN_LOGIT_DIFF_FACTUAL)
    print("Corrupted logit diff:", CORRUPTED_LOGIT_DIFF_FACTUAL)
    # By construction the metric is 1 on the clean run and 0 on the corrupted run.
    print("Clean Metric:", factual_metric(clean_logits))
    print("Corrupted Metric:", factual_metric(corrupted_logits))

In [None]:
# corrupted_value, corrupted_cache, corrupted_grad_cache = get_cache_fwd_and_bwd(gpt2_xl, corrupted_prompt, factual_metric)

In [None]:
if not IN_GITHUB:
    # Token ids for both prompts.
    clean_tokens = gpt2_xl.to_tokens(clean_prompt)
    corrupted_tokens = gpt2_xl.to_tokens(corrupted_prompt)

    # String forms of the same tokens, for printing and plot labels.
    clean_str_tokens = gpt2_xl.to_str_tokens(clean_prompt)
    corrupted_str_tokens = gpt2_xl.to_str_tokens(corrupted_prompt)

    print("Clean:", clean_str_tokens)
    print("Corrupted:", corrupted_str_tokens)

In [None]:
def act_patch_residual(clean_cache, corrupted_tokens, model: HookedTransformer, metric):
    """Activation-patch the residual stream one (layer, position) at a time.

    For every layer and token position, the model is run on
    ``corrupted_tokens`` with a forward hook on that layer's ``resid_pre``
    that overwrites the single position with the value stored in
    ``clean_cache``. ``metric`` is then evaluated on the patched logits.

    Args:
        clean_cache: Cache indexable as ``clean_cache["resid_pre", layer]``.
        corrupted_tokens: Token tensor; if 2-D, only its first row is used.
        model: Model providing ``cfg.n_layers``, ``cfg.device`` and
            ``run_with_hooks``.
        metric: Callable mapping logits to a value supporting ``.item()``.

    Returns:
        Tensor of shape ``(n_layers, n_positions)`` on ``model.cfg.device``
        holding the metric for each patch.
    """
    # Drop the batch axis so that len() gives the number of positions.
    if corrupted_tokens.ndim == 2:
        corrupted_tokens = corrupted_tokens[0]

    n_layers = model.cfg.n_layers
    n_positions = len(corrupted_tokens)
    patch_results = torch.zeros((n_layers, n_positions), device=model.cfg.device)

    def make_patch_hook(layer, pos):
        """Build a hook that copies in the clean residual at one position."""

        def patch_hook(resid_pre, hook):
            # In-place overwrite of just this position; the rest stays corrupted.
            resid_pre[:, pos, :] = clean_cache["resid_pre", layer][:, pos, :]
            return resid_pre

        return patch_hook

    for layer in tqdm.tqdm(range(n_layers)):
        hook_name = f"blocks.{layer}.hook_resid_pre"
        for pos in range(n_positions):
            patched_logits = model.run_with_hooks(
                corrupted_tokens,
                fwd_hooks=[(hook_name, make_patch_hook(layer, pos))],
            )
            patch_results[layer, pos] = metric(patched_logits).item()
    return patch_results


if DO_SLOW_RUNS:
    # One forward pass per (layer, position), so this only runs on full runs.
    residual_act_patch = act_patch_residual(
        clean_cache,
        corrupted_tokens,
        gpt2_xl,
        factual_metric,
    )
    imshow(
        residual_act_patch,
        title="Factual Recall Patching (Residual)",
        xaxis="Position",
        yaxis="Layer",
        x=clean_str_tokens,
    )