# Setup

Imports, device, and utils


In [1]:
from IPython import get_ipython

ip = get_ipython()
if not ip.extension_manager.loaded:
    ip.extension_manager.load("autoreload")
    %autoreload 2

In [2]:
import plotly.io as pio

pio.renderers.default = "notebook_connected"

In [3]:
import circuitsvis as cv

# Testing that the library works
cv.examples.hello("Andrew")

In [4]:
# Import stuff
import torch
import torch.nn as nn
import einops
import numpy as np
from circuitsvis.attention import attention_heads
from fancy_einsum import einsum
from IPython.display import HTML, IFrame
import tqdm.auto as tqdm
import plotly.express as px

from jaxtyping import Float
from functools import partial
from transformers import AutoModelForCausalLM, AutoTokenizer

In [5]:
# import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookPoint,
)  # Hooking utilities
from transformer_lens import ActivationCache, HookedTransformer, FactoredMatrix

In [6]:
# save GPU mem for inference
torch.set_grad_enabled(False)
print("Disabled automatic differentiation")

# clear cache
torch.cuda.empty_cache()

Disabled automatic differentiation


In [7]:
# plotting helpers
def imshow(tensor, renderer=None, **kwargs):
    px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        **kwargs,
    ).show(renderer)


def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.line(utils.to_numpy(tensor), labels={"x": xaxis, "y": yaxis}, **kwargs).show(
        renderer
    )


def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(
        y=y, x=x, labels={"x": xaxis, "y": yaxis, "color": caxis}, **kwargs
    ).show(renderer)


In [8]:
device = utils.get_device()
device

device(type='cuda')

In [9]:
seed = 1234
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

# Load model

HookedTransformer is a somewhat adapted GPT-2 architecture, but is computationally identical. The most significant changes are to the internal structure of the attention heads:

- The weights (W_K, W_Q, W_V) mapping the residual stream to queries, keys and values are 3 separate matrices, rather than big concatenated one.
- The weight matrices (W_K, W_Q, W_V, W_O) and activations (keys, queries, values, z (values mixed by attention pattern)) have separate head_index and d_head axes, rather than flattening them into one big axis.
  - The activations all have shape `[batch, position, head_index, d_head]`
  - W_K, W_Q, W_V have shape `[head_index, d_model, d_head]` and W_O has shape `[head_index, d_head, d_model]`

The various flags are simplifications that preserve the model's output but simplify its internals.
We verify this by comparing the logits of the original model and the HookedTransformer model.


In [10]:
def assert_hf_and_tl_model_are_close(
    hf_model,
    tl_model,
    tokenizer,
    prompt="12 x 34 = ",
    atol=1e-5,
):
    prompt_toks = tokenizer(prompt, return_tensors="pt").input_ids

    hf_logits = hf_model(prompt_toks.to(hf_model.device)).logits
    tl_logits = tl_model(prompt_toks).to(hf_logits)

    assert torch.allclose(
        torch.softmax(hf_logits, dim=-1), torch.softmax(tl_logits, dim=-1), atol=atol
    )


In [11]:
# NBVAL_IGNORE_OUTPUT

model_path = "Qwen/Qwen2.5-0.5B"

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

hf_model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map=device,
    trust_remote_code=True,
).eval()

tl_model = HookedTransformer.from_pretrained_no_processing(
    model_path,
    device=device,
    dtype=torch.float32,
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
).to(device)

assert_hf_and_tl_model_are_close(hf_model, tl_model, tokenizer)




Loaded pretrained model Qwen/Qwen2.5-0.5B into HookedTransformer
Moving model to device:  cuda


## Parameter Names

Here is a list of the parameters and shapes in the model. By convention, all weight matrices multiply on the right (ie `new_activation = old_activation @ weights + bias`).

Reminder of the key hyper-params:

- `n_layers`: 12. The number of transformer blocks in the model (a block contains an attention layer and an MLP layer)
- `n_heads`: 14. The number of attention heads per attention layer
- `d_model`: 768. The residual stream width.
- `d_head`: 64. The internal dimension of an attention head activation.
- `d_mlp`: 3072. The internal dimension of the MLP layers (ie the number of neurons).
- `d_vocab`: 50257. The number of tokens in the vocabulary.
- `n_ctx`: 1024. The maximum number of tokens in an input prompt.


In [12]:
for name, param in tl_model.named_parameters():
    if name.startswith("blocks.0."):
        print(name, param.shape)


blocks.0.attn.W_Q torch.Size([14, 896, 64])
blocks.0.attn.W_O torch.Size([14, 64, 896])
blocks.0.attn.b_Q torch.Size([14, 64])
blocks.0.attn.b_O torch.Size([896])
blocks.0.attn._W_K torch.Size([2, 896, 64])
blocks.0.attn._W_V torch.Size([2, 896, 64])
blocks.0.attn._b_K torch.Size([2, 64])
blocks.0.attn._b_V torch.Size([2, 64])
blocks.0.mlp.W_in torch.Size([896, 4864])
blocks.0.mlp.W_out torch.Size([4864, 896])
blocks.0.mlp.W_gate torch.Size([896, 4864])
blocks.0.mlp.b_in torch.Size([4864])
blocks.0.mlp.b_out torch.Size([896])


In [13]:
for name, param in tl_model.named_parameters():
    if not name.startswith("blocks"):
        print(name, param.shape)


embed.W_E torch.Size([151936, 896])
unembed.W_U torch.Size([896, 151936])
unembed.b_U torch.Size([151936])


## Activation + Hook Names

Let's get a list of all model activations/hook names by entering in a short prompt and add a hook function to each activations to print its name and shape. To avoid spam, let's just add this to activations in the first block or not in a block.

Note 1: Each LayerNorm has a hook for the scale factor (ie the standard deviation of the input activations for each token position & batch element) and for the normalized output (ie the input activation with mean 0 and standard deviation 1, but _before_ applying scaling or translating with learned weights). LayerNorm is applied every time a layer reads from the residual stream: `ln1` is the LayerNorm before the attention layer in a block, `ln2` the one before the MLP layer, and `ln_final` is the LayerNorm before the unembed.

Note 2: _Every_ activation apart from the attention pattern and attention scores has shape beginning with `[batch, position]`. The attention pattern and scores have shape `[batch, head_index, dest_position, source_position]` (the numbers are the same, unless we're using caching).


In [14]:
example_problem = "12345 x 54321 = "
print("Num tokens:", len(tl_model.to_tokens(example_problem)[0]))


def print_name_shape_hook_function(activation, hook):
    print(hook.name, activation.shape)


def not_in_late_block_filter(name):
    return name.startswith("blocks.0.") or not name.startswith("blocks")


tl_model.run_with_hooks(
    example_problem,
    return_type=None,
    fwd_hooks=[(not_in_late_block_filter, print_name_shape_hook_function)],
)

Num tokens: 14
hook_embed torch.Size([1, 14, 896])
blocks.0.hook_resid_pre torch.Size([1, 14, 896])
blocks.0.ln1.hook_scale torch.Size([1, 14, 1])
blocks.0.ln1.hook_normalized torch.Size([1, 14, 896])
blocks.0.ln1.hook_scale torch.Size([1, 14, 1])
blocks.0.ln1.hook_normalized torch.Size([1, 14, 896])
blocks.0.ln1.hook_scale torch.Size([1, 14, 1])
blocks.0.ln1.hook_normalized torch.Size([1, 14, 896])
blocks.0.attn.hook_q torch.Size([1, 14, 14, 64])
blocks.0.attn.hook_k torch.Size([1, 14, 2, 64])
blocks.0.attn.hook_v torch.Size([1, 14, 2, 64])
blocks.0.attn.hook_rot_q torch.Size([1, 14, 14, 64])
blocks.0.attn.hook_rot_k torch.Size([1, 14, 2, 64])
blocks.0.attn.hook_attn_scores torch.Size([1, 14, 14, 14])
blocks.0.attn.hook_pattern torch.Size([1, 14, 14, 14])
blocks.0.attn.hook_z torch.Size([1, 14, 14, 64])
blocks.0.hook_attn_out torch.Size([1, 14, 896])
blocks.0.hook_resid_mid torch.Size([1, 14, 896])
blocks.0.ln2.hook_scale torch.Size([1, 14, 1])
blocks.0.ln2.hook_normalized torch.Size(

# Task: Multi-digit multiplication

The next step is to verify that the model can _actually_ do the task!


In [15]:
# Generate an example multiplication problem
problems = []
answers = []
num_digits = 5

# Create numbers with num_digits digits
num1 = np.random.randint(10 ** (num_digits - 1), 10**num_digits)
num2 = np.random.randint(10 ** (num_digits - 1), 10**num_digits)
problems.append(f"{num1} x {num2} =")
answers.append(f" {num1 * num2}")

# Display the problems
for problem, answer in zip(problems, answers):
    print(problem + answer)


68067 x 44086 = 3000801762


In [16]:
utils.test_prompt(problems[0], answers[0], tl_model, prepend_bos=True)


Tokenized prompt: ['<|endoftext|>', '6', '8', '0', '6', '7', ' x', ' ', '4', '4', '0', '8', '6', ' =']
Tokenized answer: [' ', '3', '0', '0', '0', '8', '0', '1', '7', '6', '2']


Top 0th token. Logit: 20.62 Prob: 84.42% Token: | |
Top 1th token. Logit: 17.73 Prob:  4.66% Token: | ?|
Top 2th token. Logit: 17.67 Prob:  4.40% Token: | ?
|
Top 3th token. Logit: 16.26 Prob:  1.07% Token: | ?

|
Top 4th token. Logit: 15.32 Prob:  0.42% Token: | (|
Top 5th token. Logit: 15.22 Prob:  0.38% Token: | x|
Top 6th token. Logit: 14.43 Prob:  0.17% Token: | __|
Top 7th token. Logit: 14.39 Prob:  0.17% Token: | $|
Top 8th token. Logit: 14.35 Prob:  0.16% Token: |1|
Top 9th token. Logit: 14.16 Prob:  0.13% Token: | \|


Top 0th token. Logit: 21.82 Prob: 73.58% Token: |3|
Top 1th token. Logit: 20.34 Prob: 16.88% Token: |2|
Top 2th token. Logit: 18.31 Prob:  2.21% Token: |1|
Top 3th token. Logit: 18.24 Prob:  2.05% Token: |4|
Top 4th token. Logit: 17.93 Prob:  1.50% Token: |6|
Top 5th token. Logit: 17.39 Prob:  0.88% Token: |8|
Top 6th token. Logit: 17.27 Prob:  0.78% Token: |5|
Top 7th token. Logit: 17.19 Prob:  0.72% Token: |0|
Top 8th token. Logit: 17.11 Prob:  0.67% Token: |9|
Top 9th token. Logit: 17.06 Prob:  0.63% Token: |7|


Top 0th token. Logit: 20.86 Prob: 54.88% Token: |1|
Top 1th token. Logit: 19.79 Prob: 18.67% Token: |2|
Top 2th token. Logit: 19.41 Prob: 12.84% Token: |0|
Top 3th token. Logit: 18.36 Prob:  4.49% Token: |3|
Top 4th token. Logit: 18.33 Prob:  4.37% Token: |.|
Top 5th token. Logit: 17.63 Prob:  2.17% Token: |4|
Top 6th token. Logit: 16.80 Prob:  0.94% Token: |5|
Top 7th token. Logit: 15.82 Prob:  0.36% Token: |6|
Top 8th token. Logit: 15.24 Prob:  0.20% Token: |
|
Top 9th token. Logit: 15.23 Prob:  0.20% Token: |,|


Top 0th token. Logit: 18.31 Prob: 15.14% Token: |7|
Top 1th token. Logit: 18.05 Prob: 11.63% Token: |0|
Top 2th token. Logit: 17.83 Prob:  9.39% Token: |1|
Top 3th token. Logit: 17.72 Prob:  8.38% Token: |2|
Top 4th token. Logit: 17.70 Prob:  8.21% Token: |3|
Top 5th token. Logit: 17.64 Prob:  7.75% Token: |4|
Top 6th token. Logit: 17.63 Prob:  7.64% Token: |5|
Top 7th token. Logit: 17.61 Prob:  7.49% Token: |6|
Top 8th token. Logit: 17.60 Prob:  7.43% Token: |9|
Top 9th token. Logit: 17.58 Prob:  7.28% Token: |8|


Top 0th token. Logit: 18.35 Prob: 13.43% Token: |0|
Top 1th token. Logit: 18.10 Prob: 10.43% Token: |1|
Top 2th token. Logit: 17.97 Prob:  9.17% Token: |7|
Top 3th token. Logit: 17.96 Prob:  9.09% Token: |8|
Top 4th token. Logit: 17.94 Prob:  8.94% Token: |6|
Top 5th token. Logit: 17.92 Prob:  8.74% Token: |9|
Top 6th token. Logit: 17.90 Prob:  8.62% Token: |2|
Top 7th token. Logit: 17.89 Prob:  8.47% Token: |4|
Top 8th token. Logit: 17.85 Prob:  8.16% Token: |3|
Top 9th token. Logit: 17.78 Prob:  7.65% Token: |5|


Top 0th token. Logit: 19.61 Prob: 69.52% Token: |0|
Top 1th token. Logit: 16.73 Prob:  3.90% Token: |1|
Top 2th token. Logit: 16.51 Prob:  3.14% Token: |3|
Top 3th token. Logit: 16.39 Prob:  2.78% Token: |2|
Top 4th token. Logit: 16.30 Prob:  2.55% Token: |8|
Top 5th token. Logit: 16.29 Prob:  2.51% Token: |6|
Top 6th token. Logit: 16.25 Prob:  2.43% Token: |4|
Top 7th token. Logit: 16.17 Prob:  2.23% Token: |7|
Top 8th token. Logit: 16.09 Prob:  2.05% Token: |9|
Top 9th token. Logit: 15.97 Prob:  1.82% Token: |5|


Top 0th token. Logit: 17.25 Prob: 18.50% Token: |0|
Top 1th token. Logit: 16.59 Prob:  9.62% Token: |8|
Top 2th token. Logit: 16.55 Prob:  9.18% Token: |2|
Top 3th token. Logit: 16.54 Prob:  9.11% Token: |4|
Top 4th token. Logit: 16.51 Prob:  8.81% Token: |6|
Top 5th token. Logit: 16.30 Prob:  7.18% Token: |1|
Top 6th token. Logit: 16.26 Prob:  6.90% Token: |5|
Top 7th token. Logit: 16.24 Prob:  6.76% Token: |9|
Top 8th token. Logit: 16.19 Prob:  6.41% Token: |7|
Top 9th token. Logit: 16.16 Prob:  6.23% Token: |3|


Top 0th token. Logit: 18.17 Prob: 16.08% Token: |0|
Top 1th token. Logit: 18.06 Prob: 14.45% Token: |8|
Top 2th token. Logit: 17.84 Prob: 11.60% Token: |4|
Top 3th token. Logit: 17.74 Prob: 10.51% Token: |2|
Top 4th token. Logit: 17.59 Prob:  9.03% Token: |3|
Top 5th token. Logit: 17.44 Prob:  7.80% Token: |1|
Top 6th token. Logit: 17.25 Prob:  6.40% Token: |6|
Top 7th token. Logit: 17.10 Prob:  5.54% Token: |9|
Top 8th token. Logit: 16.88 Prob:  4.44% Token: |7|
Top 9th token. Logit: 16.82 Prob:  4.17% Token: |5|


Top 0th token. Logit: 18.89 Prob: 33.17% Token: |2|
Top 1th token. Logit: 18.11 Prob: 15.21% Token: |6|
Top 2th token. Logit: 17.34 Prob:  7.06% Token: |8|
Top 3th token. Logit: 17.34 Prob:  7.03% Token: |0|
Top 4th token. Logit: 17.25 Prob:  6.43% Token: |4|
Top 5th token. Logit: 17.25 Prob:  6.40% Token: |9|
Top 6th token. Logit: 17.01 Prob:  5.03% Token: |3|
Top 7th token. Logit: 16.95 Prob:  4.75% Token: |7|
Top 8th token. Logit: 16.68 Prob:  3.65% Token: |1|
Top 9th token. Logit: 16.67 Prob:  3.59% Token: |5|


Top 0th token. Logit: 18.13 Prob: 24.32% Token: |6|
Top 1th token. Logit: 17.96 Prob: 20.61% Token: |2|
Top 2th token. Logit: 16.87 Prob:  6.88% Token: |
|
Top 3th token. Logit: 16.84 Prob:  6.72% Token: |0|
Top 4th token. Logit: 16.78 Prob:  6.29% Token: |8|
Top 5th token. Logit: 16.64 Prob:  5.49% Token: |4|
Top 6th token. Logit: 16.50 Prob:  4.76% Token: |3|
Top 7th token. Logit: 16.43 Prob:  4.44% Token: |7|
Top 8th token. Logit: 16.41 Prob:  4.37% Token: |9|
Top 9th token. Logit: 16.25 Prob:  3.72% Token: |1|


Top 0th token. Logit: 18.39 Prob: 53.30% Token: |
|
Top 1th token. Logit: 16.77 Prob: 10.56% Token: |0|
Top 2th token. Logit: 15.91 Prob:  4.46% Token: |

|
Top 3th token. Logit: 15.34 Prob:  2.53% Token: |8|
Top 4th token. Logit: 15.11 Prob:  2.00% Token: | |
Top 5th token. Logit: 14.98 Prob:  1.75% Token: |6|
Top 6th token. Logit: 14.98 Prob:  1.75% Token: |\n|
Top 7th token. Logit: 14.78 Prob:  1.44% Token: |4|
Top 8th token. Logit: 14.54 Prob:  1.14% Token: |3|
Top 9th token. Logit: 14.49 Prob:  1.08% Token: |2|


We now want to find a reference prompt to run the model on. Even though our ultimate goal is to reverse engineer how this behaviour is done in general, often the best way to start out in mechanistic interpretability is by zooming in on a concrete example and understanding it in detail, and only _then_ zooming out and verifying that our analysis generalises.

We'll run the model on 4 instances of this task, each prompt given twice - one with the correct answer, one with an incorrect answer (e.g., random perturbation such as a single digit change). To make our lives easier, we'll carefully choose prompts with single token answers and the corresponding answers in the same token positions.

<details> <summary>(*) <b>Aside on tokenization</b></summary>

We want models that can take in arbitrary text, but models need to have a fixed vocabulary. So the solution is to define a vocabulary of **tokens** and to deterministically break up arbitrary text into tokens. Tokens are, essentially, subwords, and are determined by finding the most frequent substrings - this means that tokens vary a lot in length and frequency!

Tokens are a _massive_ headache and are one of the most annoying things about reverse engineering language models... Different names will be different numbers of tokens, different prompts will have the relevant tokens at different positions, different prompts will have different total numbers of tokens, etc. Language models often devote significant amounts of parameters in early layers to convert inputs from tokens to a more sensible internal format (and do the reverse in later layers). You really, really want to avoid needing to think about tokenization wherever possible when doing exploratory analysis (though, of course, it's relevant later when trying to flesh out your analysis and make it rigorous!). HookedTransformer comes with several helper methods to deal with tokens: `to_tokens, to_string, to_str_tokens, to_single_token, get_token_position`


In [17]:
prompts = []
answers = []
answer_tokens = []

num_digits = 5
num_examples = 4

def generate_perturbations(correct_answer):
    result = int(correct_answer.strip())
    perturbations = [
        result + 1,                      # Add 1
        result - 1,                      # Subtract 1
        result + 10,                     # Add 10
        result - 10,                     # Subtract 10
        result * 2,                      # Double
        result // 2 if result > 1 else 1,  # Halve (with min of 1)
        int(str(result)[::-1]) if len(str(result)) > 1 else result + 2,  # Reverse digits
        int(''.join(sorted(str(result)))),  # Sort digits
        int(str(result)[1:] + str(result)[0]) if len(str(result)) > 1 else result + 3,  # Rotate digits left
        int(str(result)[-1] + str(result)[:-1]) if len(str(result)) > 1 else result + 4,  # Rotate digits right
        int(''.join([str((int(d) + 1) % 10) for d in str(result)])),  # Add 1 to each digit
        int(''.join([str((int(d) - 1) % 10) for d in str(result)])),  # Subtract 1 from each digit
    ]
    return perturbations

max_len = 0
for _ in range(num_examples):
    num1 = np.random.randint(10 ** (num_digits - 1), 10**num_digits)
    num2 = np.random.randint(10 ** (num_digits - 1), 10**num_digits)
    multiplicands = (num1, num2)
    for j in range(2):
        prompts.append(f"{multiplicands[j]} x {multiplicands[1 - j]} =")
        correct_answer = f" {str(multiplicands[j] * multiplicands[1 - j])}"
        
        perturbations = generate_perturbations(correct_answer)
        incorrect_answer = f" {str(np.random.choice(perturbations))}"
        
        answers.append((correct_answer, incorrect_answer))
        
        correct_tokens = tl_model.to_tokens(answers[-1][0], prepend_bos=True).squeeze().tolist()
        incorrect_tokens = tl_model.to_tokens(answers[-1][1], prepend_bos=True).squeeze().tolist()
        
        max_len = max(max_len, len(correct_tokens), len(incorrect_tokens))
        correct_tokens = correct_tokens + [0] * (max_len - len(correct_tokens))
        incorrect_tokens = incorrect_tokens + [0] * (max_len - len(incorrect_tokens))
        
        answer_tokens.append((correct_tokens, incorrect_tokens))

answer_tokens = torch.tensor(answer_tokens).to(device)
prompts, answers, answer_tokens


(['70620 x 99460 =',
  '99460 x 70620 =',
  '42399 x 65985 =',
  '65985 x 42399 =',
  '33706 x 18222 =',
  '18222 x 33706 =',
  '89222 x 89728 =',
  '89728 x 89222 ='],
 [(' 7023865200', ' 238652007'),
  (' 7023865200', ' 702386520'),
  (' 2797698015', ' 125677899'),
  (' 2797698015', ' 5279769801'),
  (' 614190732', ' 503089621'),
  (' 614190732', ' 725201843'),
  (' 8005711616', ' 4002855808'),
  (' 8005711616', ' 8005711617')],
 tensor([[[151643,    220,     22,     15,     17,     18,     23,     21,
               20,     17,     15,     15],
          [151643,    220,     17,     18,     23,     21,     20,     17,
               15,     15,     22,      0]],
 
         [[151643,    220,     22,     15,     17,     18,     23,     21,
               20,     17,     15,     15],
          [151643,    220,     22,     15,     17,     18,     23,     21,
               20,     17,     15,      0]],
 
         [[151643,    220,     17,     22,     24,     22,     21,     24,
        

**Gotcha**: It's important that all of your prompts have the same number of tokens. If they're different lengths, then the position of the "final" logit where you can check logit difference will differ between prompts, and this will break the below code. The easiest solution is just to choose your prompts carefully to have the same number of tokens (you can eg add filler words like The, or newlines to start).

There's a range of other ways of solving this, eg you can index more intelligently to get the final logit. A better way is to just use left padding by setting `model.tokenizer.padding_side = 'left'` before tokenizing the inputs and running the model; this way, you can use something like `logits[:, -1, :]` to easily access the final token outputs without complicated indexing. TransformerLens checks the value of `padding_side` of the tokenizer internally, and if the flag is set to be `'left'`, it adjusts the calculation of absolute position embedding and causal masking accordingly.

In this demo, though, we stick to using the prompts of the same number of tokens because we want to show some visualisations aggregated along the batch dimension later in the demo.


In [18]:
for prompt in prompts:
    str_tokens = tl_model.to_str_tokens(prompt)
    print("Prompt length:", len(str_tokens))
    print("Prompt as tokens:", str_tokens)


Prompt length: 13
Prompt as tokens: ['7', '0', '6', '2', '0', ' x', ' ', '9', '9', '4', '6', '0', ' =']
Prompt length: 13
Prompt as tokens: ['9', '9', '4', '6', '0', ' x', ' ', '7', '0', '6', '2', '0', ' =']
Prompt length: 13
Prompt as tokens: ['4', '2', '3', '9', '9', ' x', ' ', '6', '5', '9', '8', '5', ' =']
Prompt length: 13
Prompt as tokens: ['6', '5', '9', '8', '5', ' x', ' ', '4', '2', '3', '9', '9', ' =']
Prompt length: 13
Prompt as tokens: ['3', '3', '7', '0', '6', ' x', ' ', '1', '8', '2', '2', '2', ' =']
Prompt length: 13
Prompt as tokens: ['1', '8', '2', '2', '2', ' x', ' ', '3', '3', '7', '0', '6', ' =']
Prompt length: 13
Prompt as tokens: ['8', '9', '2', '2', '2', ' x', ' ', '8', '9', '7', '2', '8', ' =']
Prompt length: 13
Prompt as tokens: ['8', '9', '7', '2', '8', ' x', ' ', '8', '9', '2', '2', '2', ' =']


## Cache layer activations

The first basic operation when doing mechanistic interpretability is to break open the black box of the model and look at all of the internal activations of a model. This can be done with `logits, cache = model.run_with_cache(tokens)`.


In [19]:
tokens = tl_model.to_tokens(prompts, prepend_bos=True)

# Run the model and cache all activations
original_logits, cache = tl_model.run_with_cache(tokens)

We'll later be evaluating how model performance differs upon performing various interventions, so it's useful to have a metric to measure model performance. Our metric here will be the **logit difference**, the difference in logit between the predicted and ground truth product.


In [20]:
def logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=False):
    # Only the final logits are relevant for the answer
    final_logits = logits[:, -1, :]
    # answer_tokens shape is [batch_size, 2, token_length]
    # Calculate logit diff across all tokens in the answer
    _, _, token_length = answer_tokens.shape
    
    all_token_logit_diffs = []
    for i in range(token_length):
        answer_tokens_at_pos = answer_tokens[:, :, i]
        answer_logits = final_logits.gather(dim=-1, index=answer_tokens_at_pos)
        token_logit_diff = answer_logits[:, 0] - answer_logits[:, 1]
        all_token_logit_diffs.append(token_logit_diff)
    stacked_diffs = torch.stack(all_token_logit_diffs, dim=1)
    answer_logit_diff = stacked_diffs.mean(dim=1)
    
    if per_prompt:
        return answer_logit_diff
    else:
        return answer_logit_diff.mean()


print(
    "Per prompt logit difference:",
    logits_to_ave_logit_diff(original_logits, answer_tokens, per_prompt=True)
    .detach()
    .cpu()
    .round(decimals=3),
)
original_average_logit_diff = logits_to_ave_logit_diff(original_logits, answer_tokens)
print(
    "Average logit difference:",
    round(logits_to_ave_logit_diff(original_logits, answer_tokens).item(), 3),
)


Per prompt logit difference: tensor([ 0.4090,  0.4020,  0.3710,  0.0000,  0.3600,  0.0820,  0.2270, -0.0210])
Average logit difference: 0.229


We see that the average logit difference is 0.229- for context, this represents putting an $e^{0.229}\approx 1.257 \times$ higher probability on the correct answer, which isn't a lot! Clearly, even though the model is able to do the task, it's is only able to do it with a small margin.


# Direct Logit Attribution

_Look up unfamiliar terms in the [mech interp explainer](https://neelnanda.io/glossary)_

Further, the easiest part of the model to understand is the output - this is what the model is trained to optimize, and so it can always be directly interpreted! Often the right approach to reverse engineering a circuit is to start at the end, understand how the model produces the right answer, and to then work backwards. The main technique used to do this is called **direct logit attribution**

**Background:** The central object of a transformer is the **residual stream**. This is the sum of the outputs of each layer and of the original token and positional embedding. Importantly, this means that any linear function of the residual stream can be perfectly decomposed into the contribution of each layer of the transformer. Further, each attention layer's output can be broken down into the sum of the output of each head (See [A Mathematical Framework for Transformer Circuits](https://transformer-circuits.pub/2021/framework/index.html) for details), and each MLP layer's output can be broken down into the sum of the output of each neuron (and a bias term for each layer).

The logits of a model are `logits=Unembed(LayerNorm(final_residual_stream))`. The Unembed is a linear map, and LayerNorm is approximately a linear map, so we can decompose the logits into the sum of the contributions of each component, and look at which components contribute the most to the logit of the correct token! This is called **direct logit attribution**. Here we look at the direct attribution to the logit difference!

<details> <summary>(*) <b>Background and motivation of the logit difference</b></summary>

Logit difference is actually a _really_ nice and elegant metric. In general, there are two natural ways to interpret the model's outputs: the output logits, or the output log probabilities (or probabilities).

The logits are much nicer and easier to understand, as noted above. However, the model is trained to optimize the cross-entropy loss (the average of log probability of the correct token). This means it does not directly optimize the logits, and indeed if the model adds an arbitrary constant to every logit, the log probabilities are unchanged.

But `log_probs == logits.log_softmax(dim=-1) == logits - logsumexp(logits)`, and so `log_probs(" 7") - log_probs(" 8") = logits(" 7") - logits(" 8")` - the ability to add an arbitrary constant cancels out!

Further, the metric helps us isolate the precise capability we care about - figuring out _which_ number is the product. There are many other components of the task - deciding what the task at hand is, etc. By taking the logit difference we control for all of that.

Our metric is further refined, because each prompt is repeated twice, for each possible product. This controls for irrelevant behaviour such as the model learning that 7 is a more frequent token than 8.

</details>

<details> <summary>Ignoring LayerNorm</summary>

LayerNorm is an analogous normalization technique to BatchNorm (that's friendlier to massive parallelization) that transformers use. Every time a transformer layer reads information from the residual stream, it applies a LayerNorm to normalize the vector at each position (translating to set the mean to 0 and scaling to set the variance to 1) and then applying a learned vector of weights and biases to scale and translate the normalized vector. This is _almost_ a linear map, apart from the scaling step, because that divides by the norm of the vector and the norm is not a linear function. (The `fold_ln` flag when loading a model factors out all the linear parts).

But if we fixed the scale factor, the LayerNorm would be fully linear. And the scale of the residual stream is a global property that's a function of _all_ components of the stream, while in practice there is normally just a few directions relevant to any particular component, so in practice this is an acceptable approximation. So when doing direct logit attribution we use the `apply_ln` flag on the `cache` to apply the global layernorm scaling factor to each constant. See [this clean GPT-2 implementation](https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/clean-transformer-demo/Clean_Transformer_Demo.ipynb#scrollTo=Clean_Transformer_Implementation) for more on LayerNorm.

</details>

Getting an output logit is equivalent to projecting onto a direction in the residual stream. We use `model.tokens_to_residual_directions` to map the answer tokens to that direction, and then convert this to a logit difference direction for each batch


In [21]:
answer_residual_directions = tl_model.tokens_to_residual_directions(answer_tokens)
print("Answer residual directions shape:", answer_residual_directions.shape)
logit_diff_directions = (
    answer_residual_directions[:, 0] - answer_residual_directions[:, 1]
)
print("Logit difference directions shape:", logit_diff_directions.shape)


Answer residual directions shape: torch.Size([8, 2, 12, 896])
Logit difference directions shape: torch.Size([8, 12, 896])


In [22]:
# cache syntax - resid_post is the residual stream at the end of the layer, -1 gets the final layer. The general syntax is [activation_name, layer_index, sub_layer_type].
final_residual_stream = cache["resid_post", -1]
print("Final residual stream shape:", final_residual_stream.shape)
final_token_residual_stream = final_residual_stream[:, -1, :]
# Apply LayerNorm scaling
# pos_slice is the subset of the positions we take - here the final token of each prompt
scaled_final_token_residual_stream = cache.apply_ln_to_stack(
    final_token_residual_stream, layer=-1, pos_slice=-1
)

# Fix the einsum to account for the extra dimension in logit_diff_directions
average_logit_diff = einsum(
    "batch d_model, batch head d_model -> ",
    scaled_final_token_residual_stream,
    logit_diff_directions,
) / len(prompts)
print("Calculated average logit diff:", round(average_logit_diff.item(), 3))
print("Original logit difference:", round(original_average_logit_diff.item(), 3))

Final residual stream shape: torch.Size([8, 14, 896])
Calculated average logit diff: 2.746
Original logit difference: 0.229


### Logit Lens

We can now decompose the residual stream! First we apply a technique called the [**logit lens**](https://www.alignmentforum.org/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens) - this looks at the residual stream after each layer and calculates the logit difference from that. This simulates what happens if we delete all subsequence layers.


In [23]:
def residual_stack_to_logit_diff(
    residual_stack: Float[torch.Tensor, "components batch d_model"],
    cache: ActivationCache,
) -> float:
    scaled_residual_stack = cache.apply_ln_to_stack(
        residual_stack, layer=-1, pos_slice=-1
    )
    return einsum(
        "... batch d_model, batch head d_model -> ...",
        scaled_residual_stack,
        logit_diff_directions,
    ) / len(prompts)


Interestingly, we see that the model essentially improves with each layer nearly exactly linearly.

**Note:** Hover over each data point to see what residual stream position it's from!

<details> <summary>Details on `accumulated_resid`</summary>
**Key:** `n_pre` means the residual stream at the start of layer n, `n_mid` means the residual stream after the attention part of layer n (`n_post` is the same as `n+1_pre` so is not included)

- `layer` is the layer for which we input the residual stream (this is used to identify _which_ layer norm scaling factor we want)
- `incl_mid` is whether to include the residual stream in the middle of a layer, ie after attention & before MLP
- `pos_slice` is the subset of the positions used. See `utils.Slice` for details on the syntax.
- return_labels is whether to return the labels for each component returned (useful for plotting)
</details>


In [24]:
accumulated_residual, labels = cache.accumulated_resid(
    layer=-1, incl_mid=True, pos_slice=-1, return_labels=True
)
logit_lens_logit_diffs = residual_stack_to_logit_diff(accumulated_residual, cache)
line(
    logit_lens_logit_diffs,
    x=np.arange(tl_model.cfg.n_layers * 2 + 1) / 2,
    hover_name=labels,
    title="Logit Difference From Accumulate Residual Stream",
)


### Layer Attribution

We can repeat the above analysis but for each layer (this is equivalent to the differences between adjacent residual streams)

Note: Annoying terminology overload - layer k of a transformer means the kth **transformer block**, but each block consists of an **attention layer** (to move information around) _and_ an **MLP layer** (to process information).

Interestingly, we see that while attn layer 23 matters, mostly it's the later MLP layers that matter! This gives some evidence that the model is passing around intermediate memorized values.


In [25]:
per_layer_residual, labels = cache.decompose_resid(
    layer=-1, pos_slice=-1, return_labels=True
)
per_layer_logit_diffs = residual_stack_to_logit_diff(per_layer_residual, cache)
line(per_layer_logit_diffs, hover_name=labels, title="Logit Difference From Each Layer")


## Head Attribution

We can further break down the output of each attention layer into the sum of the outputs of each attention head. Each attention layer consists of 14 heads, which each act independently and additively.

<details> <summary>Decomposing attention output into sums of heads</summary> 
The standard way to compute the output of an attention layer is by concatenating the mixed values of each head, and multiplying by a big output weight matrix. But as described in [A Mathematical Framework](https://transformer-circuits.pub/2021/framework/index.html) this is equivalent to splitting the output weight matrix into a per-head output (here `model.blocks[k].attn.W_O`) and adding them up (including an overall bias term for the entire layer)
</details>

We see that only a few heads really matter - head L23H1 contributes a lot positively, while head L22H7 contributes a lot negatively. These correspond to (some of) the name movers and negative name movers discussed in the paper. There are also several heads that matter positively or negatively but less strongly.

There are a few meta observations worth making here - our model has 24 layers \* 14 heads/layer = 336 heads, yet we could localise this behaviour to a handful of specific heads, using straightforward, general techniques. This supports the claim in [A Mathematical Framework](https://transformer-circuits.pub/2021/framework/index.html) that attention heads are the right level of abstraction to understand attention. It also really surprising that there are _negative_ heads - eg L22H7 makes the incorrect logit e^-0.137 = 0.872x _more_ likely. I'm not sure what's going on there, though the paper discusses some possibilities.


In [26]:
per_head_residual, labels = cache.stack_head_results(
    layer=-1, pos_slice=-1, return_labels=True
)
per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, cache)
per_head_logit_diffs = einops.rearrange(
    per_head_logit_diffs,
    "(layer head_index) -> layer head_index",
    layer=tl_model.cfg.n_layers,
    head_index=tl_model.cfg.n_heads,
)
imshow(
    per_head_logit_diffs,
    labels={"x": "Head", "y": "Layer"},
    title="Logit Difference From Each Head",
)


Tried to stack head results when they weren't cached. Computing head results now


## Attention Analysis

Attention heads are particularly easy to study because we can look directly at their attention patterns and study from what positions they move information from and two. This is particularly easy here as we're looking at the direct effect on the logits so we need only look at the attention patterns from the final token.

We use Alan Cooney's circuitsvis library to visualize the attention patterns! We visualize the top 3 positive and negative heads by direct logit attribution, and show these for the first prompt (as an illustration).

<details> <summary>Interpreting Attention Patterns</summary> 
An easy mistake to make when looking at attention patterns is thinking that they must convey information about the <i>token</i> looked at (maybe accounting for the context of the token). But actually, all we can confidently say is that it moves information from the *residual stream position* corresponding to that input token. Especially later on in the model, there may be components in the residual stream that are nothing to do with the input token! Eg the period at the end of a sentence may contain summary information for that sentence, and the head may solely move that, rather than caring about whether it ends in ".", "!" or "?"
</details>


In [27]:
def visualize_attention_patterns(
    heads: list[int] | int | Float[torch.Tensor, "heads"],
    local_cache: ActivationCache,
    local_tokens: torch.Tensor,
    title: str | None = "",
    max_width: int | None = 700,
) -> str:
    # If a single head is given, convert to a list
    if isinstance(heads, int):
        heads = [heads]

    # Create the plotting data
    labels: list[str] = []
    patterns: list[Float[torch.Tensor, "dest_pos src_pos"]] = []

    # Assume we have a single batch item
    batch_index = 0

    for head in heads:
        # Set the label
        layer = head // tl_model.cfg.n_heads
        head_index = head % tl_model.cfg.n_heads
        labels.append(f"L{layer}H{head_index}")

        # Get the attention patterns for the head
        # Attention patterns have shape [batch, head_index, query_pos, key_pos]
        patterns.append(local_cache["attn", layer][batch_index, head_index])

    # Convert the tokens to strings (for the axis labels)
    str_tokens = tl_model.to_str_tokens(local_tokens)

    # Combine the patterns into a single tensor
    patterns: Float[torch.Tensor, "head_index dest_pos src_pos"] = torch.stack(
        patterns, dim=0
    )

    # Circuitsvis Plot (note we get the code version so we can concatenate with the title)
    plot = attention_heads(
        attention=patterns, tokens=str_tokens, attention_head_names=labels
    ).show_code()

    # Display the title
    title_html = f"<h2>{title}</h2><br/>"

    # Return the visualisation as raw code
    return f"<div style='max-width: {str(max_width)}px;'>{title_html + plot}</div>"


Inspecting the patterns, we can see information from the residual stream position corresponding to the all tokens attending nearly exclusively to themselves. This gives some evidence that the model is primarily using self-attention to process information at each position. The top positive logit attribution heads show strong diagonal attention patterns, suggesting they're extracting or enhancing position-specific features. Meanwhile, the negative attribution heads appear to be attending to the first token, possibly to suppress or counteract certain information. 

In [28]:
top_k = 3

top_positive_logit_attr_heads = torch.topk(
    per_head_logit_diffs.flatten(), k=top_k
).indices

positive_html = visualize_attention_patterns(
    top_positive_logit_attr_heads,
    cache,
    tokens[0],
    f"Top {top_k} Positive Logit Attribution Heads",
)

top_negative_logit_attr_heads = torch.topk(
    -per_head_logit_diffs.flatten(), k=top_k
).indices

negative_html = visualize_attention_patterns(
    top_negative_logit_attr_heads,
    cache,
    tokens[0],
    title=f"Top {top_k} Negative Logit Attribution Heads",
)

HTML(positive_html + negative_html)


# Activation Patching

The obvious limitation to the techniques used above is that they only look at the very end of the circuit - the parts that directly affect the logits. Clearly this is not sufficient to understand the circuit! We want to understand how things compose together to produce this final output, and ideally to produce an end-to-end circuit fully explaining this behaviour.

The technique we'll use to investigate this is called **activation patching**. This was first introduced in [David Bau and Kevin Meng's excellent ROME paper](https://rome.baulab.info/), there called causal tracing.

The setup of activation patching is to take two runs of the model on two different inputs, the clean run and the corrupted run. The clean run outputs the correct answer and the corrupted run does not. The key idea is that we give the model the corrupted input, but then **intervene** on a specific activation and **patch** in the corresponding activation from the clean run (ie replace the corrupted activation with the clean activation), and then continue the run. And we then measure how much the output has updated towards the correct answer.

We can then iterate over many possible activations and look at how much they affect the corrupted run. If patching in an activation significantly increases the probability of the correct answer, this allows us to _localise_ which activations matter.


## Residual stream

One natural thing to patch in is the residual stream at a specific layer and specific position. For example, the model is likely initially doing some processing on the multiplicands' tokens to realise the task at hand, but then uses attention to move that information to the " =" token. So patching in the residual stream at the " =" token will likely matter a lot in later layers but not at all in early layers.

We can zoom in much further and patch in specific activations from specific layers. For example, we think that the output of head L9H9 on the final token is significant for directly connecting to the logits

We can patch in specific activations, and can zoom in as far as seems reasonable. For example, if we patch in the output of head L9H9 on the final token, we would predict that it will significantly affect performance.

Note that this technique does _not_ tell us how the components of the circuit connect up, just what they are.

<details> <summary>Technical details</summary> 
The choice of clean and corrupted prompt has both pros and cons. By carefully setting up the counterfactual, that <i>only</i> differs in the second subject, we avoid detecting the parts of the model doing irrelevant computation like detecting that the indirect object task is relevant at all or that it should be outputting a name rather than an article or pronoun. Or even context like that John and Mary are names at all.

However, it _also_ bakes in some details that _are_ relevant to the task. Such as finding the location of the second subject, and of the names in the first clause. Or that the name mover heads have learned to copy whatever they look at.

Some of these could be patched by also changing up the order of the names in the original sentence - patching in "After <b>John and Mary</b> went to the store, John gave a bottle of milk to" vs "After <b>Mary and John</b> went to the store, John gave a bottle of milk to".

In the ROME paper they take a different tack. Rather than carefully setting up counterfactuals between two different but related inputs, they **corrupt** the clean input by adding Gaussian noise to the token embedding for the subject. This is in some ways much lower effort (you don't need to set up a similar but different prompt) but can also introduce some issues, such as ways this noise might break things. In practice, you should take care about how you choose your counterfactuals and try out several. Try to reason beforehand about what they will and will not tell you, and compare the results between different counterfactuals.

</details>

We first create a set of corrupted tokens - where we reverse the order of the prompts to have a guaranteed wrong answer.


In [29]:
corrupted_prompts = []
for i in range(0, len(prompts)):
    corrupted_prompts.append(prompts[len(prompts) - i - 1])
corrupted_tokens = tl_model.to_tokens(corrupted_prompts, prepend_bos=True)
corrupted_logits, corrupted_cache = tl_model.run_with_cache(
    corrupted_tokens, return_type="logits"
)
corrupted_average_logit_diff = logits_to_ave_logit_diff(corrupted_logits, answer_tokens)
print("Corrupted Average Logit Diff", round(corrupted_average_logit_diff.item(), 2))
print("Clean Average Logit Diff", round(original_average_logit_diff.item(), 2))


Corrupted Average Logit Diff 0.27
Clean Average Logit Diff 0.23


In [30]:
tl_model.to_string(corrupted_tokens)


['<|endoftext|>89728 x 89222 =',
 '<|endoftext|>89222 x 89728 =',
 '<|endoftext|>18222 x 33706 =',
 '<|endoftext|>33706 x 18222 =',
 '<|endoftext|>65985 x 42399 =',
 '<|endoftext|>42399 x 65985 =',
 '<|endoftext|>99460 x 70620 =',
 '<|endoftext|>70620 x 99460 =']

We now intervene on the corrupted run and patch in the clean residual stream at a specific layer and position.

We do the intervention using TransformerLens's `HookPoint` feature. We can design a hook function that takes in a specific activation and returns an edited copy, and temporarily add it in with `model.run_with_hooks`.


In [31]:
def patch_residual_component(
    corrupted_residual_component: Float[torch.Tensor, "batch pos d_model"],
    hook,
    pos,
    clean_cache,
):
    corrupted_residual_component[:, pos, :] = clean_cache[hook.name][:, pos, :]
    return corrupted_residual_component


def normalize_patched_logit_diff(patched_logit_diff):
    # Subtract corrupted logit diff to measure the improvement, divide by the total improvement from clean to corrupted to normalise
    # 0 means zero change, negative means actively made worse, 1 means totally recovered clean performance, >1 means actively *improved* on clean performance
    return (patched_logit_diff - corrupted_average_logit_diff) / (
        original_average_logit_diff - corrupted_average_logit_diff
    )


patched_residual_stream_diff = torch.zeros(
    tl_model.cfg.n_layers, tokens.shape[1], device=device, dtype=torch.float32
)
for layer in range(tl_model.cfg.n_layers):
    for position in range(tokens.shape[1]):
        hook_fn = partial(patch_residual_component, pos=position, clean_cache=cache)
        patched_logits = tl_model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(utils.get_act_name("resid_pre", layer), hook_fn)],
            return_type="logits",
        )
        patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)

        patched_residual_stream_diff[layer, position] = normalize_patched_logit_diff(
            patched_logit_diff
        )


We can immediately see that all relevant computation happens on the final token. Moving the residual stream at the correct position near _exactly_ recovers performance!

For reference, tokens and their index from the first prompt are on the x-axis. In an abuse of notation, note that the difference here is averaged over _all_ 8 prompts, while the labels only come from the _first_ prompt.

To be easier to interpret, we normalise the logit difference, by subtracting the corrupted logit difference, and dividing by the total improvement from clean to corrupted to normalise
0 means zero change, negative means actively made worse, 1 means totally recovered clean performance, >1 means actively _improved_ on clean performance


In [32]:
prompt_position_labels = [
    f"{tok}_{i}" for i, tok in enumerate(tl_model.to_str_tokens(tokens[0]))
]
imshow(
    patched_residual_stream_diff,
    x=prompt_position_labels,
    title="Logit Difference From Patched Residual Stream",
    labels={"x": "Position", "y": "Layer"},
)


## Layers

We can apply exactly the same idea, but this time patching in attention or MLP layers. These are also residual components with identical shapes to the residual stream terms, so we can reuse the same hooks.


In [33]:
patched_attn_diff = torch.zeros(
    tl_model.cfg.n_layers, tokens.shape[1], device=device, dtype=torch.float32
)
patched_mlp_diff = torch.zeros(
    tl_model.cfg.n_layers, tokens.shape[1], device=device, dtype=torch.float32
)
for layer in range(tl_model.cfg.n_layers):
    for position in range(tokens.shape[1]):
        hook_fn = partial(patch_residual_component, pos=position, clean_cache=cache)
        patched_attn_logits = tl_model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(utils.get_act_name("attn_out", layer), hook_fn)],
            return_type="logits",
        )
        patched_attn_logit_diff = logits_to_ave_logit_diff(
            patched_attn_logits, answer_tokens
        )
        patched_mlp_logits = tl_model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(utils.get_act_name("mlp_out", layer), hook_fn)],
            return_type="logits",
        )
        patched_mlp_logit_diff = logits_to_ave_logit_diff(
            patched_mlp_logits, answer_tokens
        )

        patched_attn_diff[layer, position] = normalize_patched_logit_diff(
            patched_attn_logit_diff
        )
        patched_mlp_diff[layer, position] = normalize_patched_logit_diff(
            patched_mlp_logit_diff
        )


We see that several attention layers are significant but that, matching the residual stream results, early layers slightly matter on the multiplicands, but mostly everything matters on the final token. Extremely localised! As with direct logit attribution, layer 22 is positive and layer 3 is not, suggesting that the early layers only matter for direct logit effects.


In [34]:
imshow(
    patched_attn_diff,
    x=prompt_position_labels,
    title="Logit Difference From Patched Attention Layer",
    labels={"x": "Position", "y": "Layer"},
)


And as before, the MLP layer 7 does matter, which aligns with the idea that the model is simply recalling parts of the answer.

An interesting case is MLP 0, which also matters, but this is usually misleading and just a generally true statement about MLP 0 rather than being about the circuit on this task.

<details> <summary>Takes on MLP0</summary> 
It's often observed on GPT-2 XL that MLP0 matters a lot, and that ablating it utterly destroys performance. A current best guess is that the first MLP layer is essentially acting as an extension of the embedding (for whatever reason) and that when later layers want to access the input tokens they mostly read in the output of the first MLP layer, rather than the token embeddings. Within this frame, the first attention layer doesn't do much.

In this framing, it makes sense that MLP0 matters on the second subject token, because that's the one position with a different input token!

It's not entirely known why this happens, but a guess is that the embedding and unembedding matrices in GPT-2 XL are the same. This is pretty unprincipled, as the tasks of embedding and unembedding tokens are <i>not</i> inverses, but this is common practice, and plausibly models want to dedicate some parameters to overcoming this.

There's only suggestive evidence of this though.

</details>


In [35]:
imshow(
    patched_mlp_diff,
    x=prompt_position_labels,
    title="Logit Difference From Patched MLP Layer",
    labels={"x": "Position", "y": "Layer"},
)


## Heads

We can refine the above analysis by patching in individual heads! This is somewhat more annoying, because there are now three dimensions (head_index, position and layer), so for now lets patch in a head's output across all positions.

The easiest way to do this is to patch in the activation `z`, the "mixed value" of the attention head. That is, the average of all previous values weighted by the attention pattern, ie the activation that is then multiplied by `W_O`, the output weights.


In [36]:
def patch_head_vector(
    corrupted_head_vector: Float[torch.Tensor, "batch pos head_index d_head"],
    hook,
    head_index,
    clean_cache,
):
    corrupted_head_vector[:, :, head_index, :] = clean_cache[hook.name][
        :, :, head_index, :
    ]
    return corrupted_head_vector


patched_head_z_diff = torch.zeros(
    tl_model.cfg.n_layers, tl_model.cfg.n_heads, device=device, dtype=torch.float32
)
for layer in range(tl_model.cfg.n_layers):
    for head_index in range(tl_model.cfg.n_heads):
        hook_fn = partial(patch_head_vector, head_index=head_index, clean_cache=cache)
        patched_logits = tl_model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(utils.get_act_name("z", layer, "attn"), hook_fn)],
            return_type="logits",
        )
        patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)

        patched_head_z_diff[layer, head_index] = normalize_patched_logit_diff(
            patched_logit_diff
        )


We can now see that, in addition to the heads identified before, heads L0H11 and L22H7 matter and are presumably responsible for attending to the final token.


In [37]:
imshow(
    patched_head_z_diff,
    title="Logit Difference From Patched Head Output",
    labels={"x": "Head", "y": "Layer"},
)


## Decomposing Heads

Decomposing attention layers into patching in individual heads has already helped us localise the behaviour a lot. But we can understand it further by decomposing heads. An attention head consists of two semi-independent operations - calculating _where_ to move information from and to (represented by the attention pattern and implemented via the QK-circuit) and calculating _what_ information to move (represented by the value vectors and implemented by the OV circuit). We can disentangle which of these is important by patching in just the attention pattern _or_ the value vectors. (See [A Mathematical Framework](https://transformer-circuits.pub/2021/framework/index.html) or [a walkthrough video](https://www.youtube.com/watch?v=KV5gbOmHbjU) for more on this decomposition. If you're not familiar with the details of how attention is implemented, I recommend checking out [my clean transformer implementation](https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/clean-transformer-demo/Clean_Transformer_Demo.ipynb#scrollTo=3Pb0NYbZ900e) to see how the code works)

First let's patch in the value vectors, to measure when figuring out what to move is important.

In [None]:
patched_head_v_diff = torch.zeros(
    tl_model.cfg.n_layers, tl_model.cfg.n_heads, device=device, dtype=torch.float32
)

def patch_head_vector(
    corrupted_head_vector: Float[torch.Tensor, "batch head_index pos d_head"],
    hook,
    head_index,
    clean_cache,
):
    corrupted_head_vector[:, head_index, : , :] = clean_cache[hook.name][
        :, head_index, :, :
    ]
    return corrupted_head_vector


for layer in range(tl_model.cfg.n_layers):
    for head_index in range(tl_model.cfg.n_heads):
        hook_fn = partial(patch_head_vector, head_index=head_index, clean_cache=cache)
        patched_logits = tl_model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(utils.get_act_name("v", layer, "attn"), hook_fn)],
            return_type="logits",
        )
        patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)

        patched_head_v_diff[layer, head_index] = normalize_patched_logit_diff(
            patched_logit_diff
        )


We can plot this as a heatmap and it's initially hard to interpret.


In [63]:
imshow(
    patched_head_v_diff,
    title="Logit Difference From Patched Head Value",
    labels={"x": "Head", "y": "Layer"},
)


But it's very easy to interpret if we plot a scatter plot against patching head outputs. Here we see that the late heads matter most.

Meta lesson: Plot things early, often and in diverse ways as you explore a model's internals!


In [64]:
head_labels = [
    f"L{layer}H{head}"
    for layer in range(tl_model.cfg.n_layers)
    for head in range(tl_model.cfg.n_heads)
]
scatter(
    x=utils.to_numpy(patched_head_v_diff.flatten()),
    y=utils.to_numpy(patched_head_z_diff.flatten()),
    xaxis="Value Patch",
    yaxis="Output Patch",
    caxis="Layer",
    hover_name=head_labels,
    color=einops.repeat(
        np.arange(tl_model.cfg.n_layers),
        "layer -> (layer head)",
        head=tl_model.cfg.n_heads,
    ),
    range_x=(-0.5, 0.5),
    range_y=(-0.5, 0.5),
    title="Scatter plot of output patching vs value patching",
)


When we patch in attention patterns, we see a different effect - early and late heads matter a lot, middle heads don't. (In fact, the sum of value patching and pattern patching is approx the same as output patching)


In [66]:
def patch_head_pattern(
    corrupted_head_pattern: Float[torch.Tensor, "batch head_index query_pos d_head"],
    hook,
    head_index,
    clean_cache,
):
    corrupted_head_pattern[:, head_index, :, :] = clean_cache[hook.name][
        :, head_index, :, :
    ]
    return corrupted_head_pattern


patched_head_attn_diff = torch.zeros(
    tl_model.cfg.n_layers, tl_model.cfg.n_heads, device=device, dtype=torch.float32
)
for layer in range(tl_model.cfg.n_layers):
    for head_index in range(tl_model.cfg.n_heads):
        hook_fn = partial(patch_head_pattern, head_index=head_index, clean_cache=cache)
        patched_logits = tl_model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(utils.get_act_name("attn", layer, "attn"), hook_fn)],
            return_type="logits",
        )
        patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)

        patched_head_attn_diff[layer, head_index] = normalize_patched_logit_diff(
            patched_logit_diff
        )


In [67]:
imshow(
    patched_head_attn_diff,
    title="Logit Difference From Patched Head Pattern",
    labels={"x": "Head", "y": "Layer"},
)
head_labels = [
    f"L{layer}H{head}"
    for layer in range(tl_model.cfg.n_layers)
    for head in range(tl_model.cfg.n_heads)
]
scatter(
    x=utils.to_numpy(patched_head_attn_diff.flatten()),
    y=utils.to_numpy(patched_head_z_diff.flatten()),
    hover_name=head_labels,
    xaxis="Attention Patch",
    yaxis="Output Patch",
    title="Scatter plot of output patching vs attention patching",
)


# Consolidating Understanding

OK, let's zoom out and reconsolidate. At a high-level, we find that all the action is pretty much on the final token. And that attention layers matter much less than MLP layers.

We've further localised important behaviour to late heads (L23H1, L22H7, L0H11) whose output matters on the final token and whose behaviour is determined by their attention patterns.

A natural speculation is that the early heads are simply pointing to each subproblem of the multiplication, and the late heads are just copying the result.


## Visualizing Attention Patterns

We can validate this by looking at the attention patterns of these heads! Let's take the top 10 heads by output patching (in absolute value).

We see that early heads attend to every token equally, while later heads attend to just the bos token, which is completely consistent with the above speculation!

In [85]:
top_k = 10
top_heads_by_output_patch = torch.topk(
    patched_head_z_diff.abs().flatten(), k=top_k
).indices
first_layer = 0
late_heads = top_heads_by_output_patch[
    tl_model.cfg.n_heads * first_layer <= top_heads_by_output_patch
]

late = visualize_attention_patterns(
    late_heads, cache, tokens[0], title="Top Late Heads"
)

HTML(late)
