# Imports

In [1]:
from functools import partial
from typing import List, Optional, Union

import random
from tqdm import tqdm
import einops
import math
import numpy as np
import plotly.express as px
import plotly.io as pio
import plotly.graph_objects as go
from neel_plotly import line, imshow, scatter # pip install git+https://github.com/neelnanda-io/neel-plotly.git
import torch
from torch import Tensor
from circuitsvis.attention import attention_heads
from fancy_einsum import einsum
from IPython.display import HTML, IFrame
from jaxtyping import Float
from typing import List, Optional, Tuple, Dict, Literal, Set
from rich.table import Table, Column
from rich import print as rprint
import re
import transformer_lens.utils as utils
from transformer_lens import ActivationCache, HookedTransformer
from transformer_lens.hook_points import HookPoint

torch.set_grad_enabled(False)
print("Disabled automatic differentiation")

# def imshow(tensor, **kwargs):
#     px.imshow(
#         utils.to_numpy(tensor),
#         color_continuous_midpoint=0.0,
#         color_continuous_scale="RdBu",
#         **kwargs,
#     ).show()


# all these are taken from https://github.com/callummcdougall/ARENA_2.0/blob/main/chapter0_fundamentals/exercises/plotly_utils.py
update_layout_set = {"xaxis_range", "yaxis_range", "yaxis2_range", "hovermode", "xaxis_title", "yaxis_title", "colorbar", "colorscale", "coloraxis", "title_x", "bargap", "bargroupgap", "xaxis_tickformat", "yaxis_tickformat", "title_y", "legend_title_text", "xaxis_showgrid", "xaxis_gridwidth", "xaxis_gridcolor", "yaxis_showgrid", "yaxis_gridwidth", "yaxis_gridcolor", "showlegend", "xaxis_tickmode", "yaxis_tickmode", "margin", "xaxis_visible", "yaxis_visible", "bargap", "bargroupgap", "xaxis_tickangle"}
from transformer_lens.utils import to_numpy

def bar(tensor, renderer=None, **kwargs):
    '''
    '''
    kwargs_post = {k: v for k, v in kwargs.items() if k in update_layout_set}
    kwargs_pre = {k: v for k, v in kwargs.items() if k not in update_layout_set}
    if "margin" in kwargs_post and isinstance(kwargs_post["margin"], int):
        kwargs_post["margin"] = dict.fromkeys(list("tblr"), kwargs_post["margin"])
    px.bar(y=to_numpy(tensor), **kwargs_pre).update_layout(**kwargs_post).show(renderer)

def line(tensor, **kwargs):
    px.line(
        y=utils.to_numpy(tensor),
        **kwargs,
    ).show()


def scatter(x, y, xaxis="", yaxis="", caxis="", **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(
        y=y,
        x=x,
        labels={"x": xaxis, "y": yaxis, "color": caxis},
        **kwargs,
    ).show()

%load_ext autoreload
%autoreload 2

Disabled automatic differentiation


In [2]:
import sys
sys.path.append("/home/iustin/Mech-Interp/Easy-Transformer")
from easy_transformer.ioi_utils import (
    path_patching,
    show_attention_patterns,
    show_pp,
    logit_diff
    )

# Model Loading

The various flags are simplifications that preserve the model's output but simplify its internals.

Args from `HookedTransformer.from_pretrained` and `load_and_process_state_dict`: 

- `fold_ln` (bool, optional): Whether to fold in the LayerNorm weights to the subsequent linear layer. This does not change the computation. Defaults to True.
    - This applies LayerNorm folding in the weights of the next layer, efficient for post-hoc analysis
    - More about this in [TransformerLens/further_comments.md](https://github.com/TransformerLensOrg/TransformerLens/blob/main/further_comments.md#what-is-layernorm-folding-fold_ln)
- `center_writing_weights` (bool, optional): Whether to center weights writing to the residual stream (ie set mean to be zero). Due to LayerNorm this doesn't change the computation. Defaults to True.
    - every component reading an input from the residual stream is preceded by a LayerNorm, which means that the mean of a residual stream vector (ie the component in the direction of all ones) never matters. This means we can remove the all ones component of weights and biases whose output writes to the residual stream. Mathematically, `W_writing -= W_writing.mean(dim=1, keepdim=True)`

- `center_unembed` (bool, optional): Whether to center W_U (ie set mean to be zero). Softmax is translation invariant so this doesn't affect log probs or loss, but does change logits. Defaults to True.
    - The logits are fed into a softmax. Softmax is translation invariant (eg, adding 1 to every logit doesn't change the output), so we can simplify things by setting the mean of the logits to be zero. This is equivalent to setting the mean of every output vector of W_U to zero. In code, `W_U -= W_U.mean(dim=-1, keepdim=True)`
- `refactor_factored_attn_matrices` (bool, optional): Whether to convert the factored matrices (W_Q & W_K, and W_O & W_V) to be "even". Defaults to False
    - check `load_and_process_state_dict.refactor_factored_attn_matrices`

    
! Note: `prepend_bos` is a flag to add a BOS (beginning of sequence) to the start of the prompt. GPT-2 was not trained with this, but it often makes model behaviour more stable, as the first token is treated weirdly.
</details>

In [3]:
model = HookedTransformer.from_pretrained(
    "gpt2-small",   
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=False,
)                              

# when using acdc_new env
device: torch.device = utils.get_device()



Loaded pretrained model gpt2-small into HookedTransformer


# Importing the dataset

In [22]:
from data.ioi_dataset import IOIDataset, NAMES
from data.homonymy_dataset import zip_and_tokenize_all_answers

N = 100
ioi_dataset = IOIDataset(
    prompt_type="mixed",
    N=N,
    tokenizer=model.tokenizer,
    prepend_bos=False,
    seed=1,
    device=str(device)
)

abc_dataset = ioi_dataset.gen_flipped_prompts("ABB->XYZ, BAB->XYZ")

clean_prompts = []
for prompt in ioi_dataset.sentences:
    clean_prompts.append(" ".join(prompt.split(" ")[:-1]))

abc_prompts = []
for prompt in abc_dataset.sentences:
    abc_prompts.append(" ".join(prompt.split(" ")[:-1]))

clean_answers = []
for i in range(len(ioi_dataset)):
    clean_answer = " " + ioi_dataset.ioi_prompts[i]['IO']
    wrong_answer = " " + ioi_dataset.ioi_prompts[i]['S']
    answers = (clean_answer, wrong_answer)
    clean_answers.append(answers)

abc_answers = []
for i in range(len(abc_dataset)):
    clean_answer = " " + abc_dataset.ioi_prompts[i]['IO']
    wrong_answer = " " + abc_dataset.ioi_prompts[i]['S']
    answers = (clean_answer, wrong_answer)
    abc_answers.append(answers)

clean_answer_tokens = []
for answer in clean_answers:
    token_pair = []
    for ans in answer:
        # Convert the answer string to token and then to its corresponding token ID
        token = model.to_single_token(ans)
        token_pair.append(token)
    
    clean_answer_tokens.append(token_pair)
# Convert the list of token pairs to a PyTorch tensor
clean_answers_tokens_ids = torch.tensor(clean_answer_tokens).to(device)
# Both clean and correct prompts
# all_answer_tok_ids, all_answer_strings = zip_and_tokenize_all_answers(model, clean_answers, abc_answers, device)

all_prompts_abc = list(zip(clean_prompts, abc_prompts))
all_prompts_strings_abc = [prompt for prompt_pair in all_prompts_abc for prompt in prompt_pair]

In [23]:
def swap_names_in_prompt(prompt: str) -> str:
    match_when = re.search(r"When (\w+) and (\w+)", prompt)
    match_then = re.search(r"Then, (\w+) and (\w+)", prompt)
    
    if match_when:
        name1, name2 = match_when.groups()
    elif match_then:
        name1, name2 = match_then.groups()
    else:
        # If the expected format is not found, return the prompt unchanged.
        return prompt

    # Use a temporary marker to avoid double replacement.
    temp = "__TEMP_SWAP__"
    prompt_swapped = prompt.replace(name1, temp)
    prompt_swapped = prompt_swapped.replace(name2, name1)
    prompt_swapped = prompt_swapped.replace(temp, name2)
    return prompt_swapped

corr_prompts = [swap_names_in_prompt(prompt) for prompt in clean_prompts]

correct_answers = []
wrong_answers = []
corr_answers = []
for i in range(len(corr_prompts)):
    correct_answer = clean_answers[i][1]
    wrong_answer = clean_answers[i][0]
    corr_answers.append((correct_answer, wrong_answer))

corr_answer_tokens = []
for answer in corr_answers:
    token_pair = []
    for ans in answer:
        # Convert the answer string to token and then to its corresponding token ID
        token = model.to_single_token(ans)
        token_pair.append(token)

    corr_answer_tokens.append(token_pair)

corr_answers_tok_ids = torch.tensor(corr_answer_tokens, device=device)

for i, (orig, corrupt, answers) in enumerate(zip(clean_prompts, corr_prompts, corr_answers)):
    print("Original:", orig)
    print("Corrupt: ", corrupt)
    print("Corr Answers:", answers)
    print("-" * 40)
    if i > 3:
        break


all_answer_tok_ids, all_answer_strings = zip_and_tokenize_all_answers(model, clean_answers, corr_answers, device)

all_prompts_corr = list(zip(clean_prompts, corr_prompts))
all_prompts_strings_corr = [prompt for prompt_pair in all_prompts_corr for prompt in prompt_pair]


Original: When Victoria and Jane got a snack at the store, Jane decided to give it to
Corrupt:  When Jane and Victoria got a snack at the store, Victoria decided to give it to
Corr Answers: (' Jane', ' Victoria')
----------------------------------------
Original: When Sullivan and Rose got a necklace at the garden, Sullivan decided to give it to
Corrupt:  When Rose and Sullivan got a necklace at the garden, Rose decided to give it to
Corr Answers: (' Sullivan', ' Rose')
----------------------------------------
Original: When Alan and Alex got a drink at the store, Alex decided to give it to
Corrupt:  When Alex and Alan got a drink at the store, Alan decided to give it to
Corr Answers: (' Alex', ' Alan')
----------------------------------------
Original: Then, Jessica and Crystal had a long argument, and afterwards Jessica said to
Corrupt:  Then, Crystal and Jessica had a long argument, and afterwards Crystal said to
Corr Answers: (' Jessica', ' Crystal')
-------------------------------

In [24]:
def format_prompt(sentence: str) -> str:
    '''Format a prompt by underlining names (for rich print)'''
    return re.sub("(" + "|".join(NAMES) + ")", lambda x: f"[u bold dark_orange]{x.group(0)}[/]", sentence) + "\n"


def make_table(cols, colnames, title="", n_rows=5, decimals=4):
    '''Makes and displays a table, from cols rather than rows (using rich print)'''
    table = Table(*colnames, title=title)
    rows = list(zip(*cols))
    f = lambda x: x if isinstance(x, str) else f"{x:.{decimals}f}"
    for row in rows[:n_rows]:
        table.add_row(*list(map(f, row)))
    rprint(table)

make_table(
    colnames = ["IOI prompt", "IOI subj", "IOI indirect obj", "ABC prompt"],
    cols = [
        map(format_prompt, ioi_dataset.sentences),
        model.to_string(ioi_dataset.s_tokenIDs).split(),
        model.to_string(ioi_dataset.io_tokenIDs).split(),
        map(format_prompt, abc_dataset.sentences),
    ],
    title = "Sentences from IOI vs ABC distribution",
)

### Defining the metric

We use the **Logit Difference** between the wanted and unwanted answer token. The reason logit difference is suitable for in-context learning tasks is that we can measure which token the model promotes as opposed to a something we don't want, as an unwanted behaviour, and because we ultimately care about the final predictions, which happens very localised on the final token in the sequence (appropriate for next-token prediction). 

Average logit difference vs Per prompt logit difference (zoomed out perspective):

- with the first we can get an approximation on how the model performs overall
- with the second we can see it's behaviour somehow clearly, inspecting what the model is doing in extreme cases, or how the model learns to do the task better (most of the time by increasing the number of prompts)

In [25]:
def logits_to_ave_logit_diff_2(logits: Float[Tensor, "batch seq d_vocab"], ioi_dataset: IOIDataset = ioi_dataset, per_prompt=False):
    '''
    Returns logit difference between the correct and incorrect answer.

    If per_prompt=True, return the array of differences rather than the average.
    '''

    # Only the final logits are relevant for the answer
    # Get the logits corresponding to the indirect object / subject tokens respectively
    io_logits: Float[Tensor, "batch"] = logits[range(logits.size(0)), ioi_dataset.word_idx["end"], ioi_dataset.io_tokenIDs]
    s_logits: Float[Tensor, "batch"] = logits[range(logits.size(0)), ioi_dataset.word_idx["end"], ioi_dataset.s_tokenIDs]
    # Find logit difference
    answer_logit_diff = io_logits - s_logits
    return answer_logit_diff if per_prompt else answer_logit_diff.mean()

model.reset_hooks(including_permanent=True)

ioi_logits_original, ioi_cache = model.run_with_cache(ioi_dataset.toks)

corr_tokens = model.to_tokens(corr_prompts)
corr_logits, corr_cache = model.run_with_cache(corr_tokens)

# abc_logits_original, abc_cache = model.run_with_cache(abc_dataset.toks)

ioi_per_prompt_diff = logits_to_ave_logit_diff_2(ioi_logits_original, per_prompt=True)
corr_per_prompt_diff = logits_to_ave_logit_diff_2(corr_logits, per_prompt=True)

ioi_average_logit_diff = logits_to_ave_logit_diff_2(ioi_logits_original).item()
corr_average_logit_diff = logits_to_ave_logit_diff_2(corr_logits).item()

print(f"Average logit diff (IOI dataset): {ioi_average_logit_diff:.4f}")
print(f"Average logit diff (Corr dataset): {corr_average_logit_diff:.4f}")

make_table(
    colnames = ["IOI prompt", "IOI logit diff", "Corr prompt", "Corr logit diff"],
    cols = [
        map(format_prompt, ioi_dataset.sentences),
        ioi_per_prompt_diff,
        map(format_prompt, corr_prompts),
        corr_per_prompt_diff,
    ],
    title = "Sentences from IOI vs Corr distribution",
)

Average logit diff (IOI dataset): 3.1490
Average logit diff (Corr dataset): 0.0544


In [26]:
exponential = math.exp(-0.0003)
exponential

0.9997000449955004

Average logit difference for 500 prompts from IOI is 3.1490, this represents putting an $e^{3.1490}\approx 23\times$ higher probability on the correct answer.

Average logit difference for 500 prompts from ABC is -0.0291, this represents putting an $e^{-0.0291}\approx 1\times$ higher probability on the correct answer.

## Direct Logit Attribution

The logits of a model are computed by `logits=Unembed(LayerNorm(final_residual_stream))`. The Unembed is a linear map, and LayerNorm is approximately a linear map, so we can decompose the logits into the sum of the contributions of each component, and look at which components contribute the most to the logit of the correct token! This is called **direct logit attribution**. Here we look at the direct attribution to the logit difference!

The residual stream (of dimension `d_model`) acts as a linear trajectory taking in the input and applying it some transformations which result in the output. We say that each layer reads from this stream, processes the data through Attention and MLP, and writes the processed information back to the stream. Before applying transformations the transformer layers are applying Layer Normalization to each vector at each position in the sequence, ie translating to set the mean to 0 and scaling to set the variance to 1 and then applying a learned vector of weights and biases to scale and translate the normalized vector (considering LayerNorm folding is applied).

`!` This is almost a linear map, apart from the scaling step, because that divides by the norm (norm = the magnitude of the vector) of the vector (division by the standard deviation, which computes the square root of the variance) and the norm part is not a linear function. And this is why `fold_ln` comes in handy, by factoring out all the linear parts of the LayerNorm in the weight matrices. 
But if we fixed the scale factor, the LayerNorm would be fully linear. And the scale of the residual stream is a global property that's a function of all components of the stream, while in practice there is normally just a few directions relevant to any particular component, so in practice this is an acceptable approximation. So when doing direct logit attribution we use the `apply_ln` flag on the cache to apply the global LayerNorm scaling factor to each constant.



Getting an **output logit** is equivalent to projecting onto a direction in the residual stream. We use `model.tokens_to_residual_directions` to map the answer tokens to that direction, and then convert this to a logit difference direction for each batch.

From docstring:
- "Maps tokens to a tensor with the unembedding vector for those tokens, ie the vector in the residual stream that we dot with to the get the logit for that token."

To verify that this works, we can apply this to the final residual stream for our cached prompts (after applying LayerNorm scaling) and verify that we get the same answer.

<details> <summary>Technical details</summary>

`logits = Unembed(LayerNorm(final_residual_stream))`, so we technically need to account for the centering, and then learned translation and scaling of the layernorm, not just the variance 1 scaling.

The centering is accounted for with the preprocessing flag `center_writing_weights` which ensures that every weight matrix writing to the residual stream has mean zero.

The learned scaling is folded into the unembedding weights `model.unembed.W_U` via `W_U_fold = layer_norm.weights[:, None] * unembed.W_U`

The learned translation is folded to `model.unembed.b_U`, a bias added to the logits (note that GPT-2 is not trained with an existing `b_U`). This roughly represents unigram statistics. But we can ignore this because each prompt occurs twice with names in the opposite order, so this perfectly cancels out.

Note that rather than using layernorm scaling we could just study cache["ln_final.hook_normalised"]

</details>

In [27]:
clean_tokens = model.to_tokens(clean_prompts, prepend_bos=True)

# Run the model and cache all activations
clean_logits, clean_cache = model.run_with_cache(clean_tokens)

In [None]:
from src.observational.direct_logit_attribution import logits_to_ave_logit_diff

# clean_per_prompt_diff = logits_to_ave_logit_diff(clean_logits, clean_answers_tokens_ids, per_prompt=True)
# print("Per prompt logit difference:", clean_per_prompt_diff)
# clean_average_logit_diff = logits_to_ave_logit_diff(clean_logits, clean_answers_tokens_ids)
# print("Average logit difference:", clean_average_logit_diff)

# cols = [
#     "Prompt",
#     Column("Correct", style="rgb(0,200,0) bold"),
#     Column("Incorrect", style="rgb(255,0,0) bold"),
#     Column("Logit Difference", style="bold")
# ]
# table = Table(*cols, title="Logit differences")

# for prompt, answer, logit_diff in zip(clean_prompts, clean_answers, clean_per_prompt_diff):
#     table.add_row(prompt, repr(answer[0]), repr(answer[1]), f"{logit_diff.item():.3f}")

# # rprint(table)

Per prompt logit difference: tensor([-0.2589, -2.1226, -1.2909, -2.2932, -0.2543,  0.0560,  1.6297, -1.8960,
        -0.9090,  0.0815, -0.6578, -1.4054,  2.4049,  1.1419,  0.4785, -2.6344,
         2.9288,  1.6651,  0.2637,  2.7172,  2.3388,  1.9674,  1.0920, -0.2760,
         0.5367,  0.1426,  0.5303,  0.1809, -1.9326,  0.3946, -0.7012,  0.3405,
        -0.5173, -1.9425,  1.7553,  0.3393,  0.4020, -1.4718,  0.9667,  0.9380,
        -1.0242, -1.5441,  0.9880,  6.1403,  1.8447,  2.0769,  0.6899,  0.1454,
        -0.8425,  1.1229, -1.0020, -0.5689, -0.1760, -1.0492,  0.7381,  2.4418,
        -0.6310,  0.8369,  0.6708,  2.0926,  1.7965,  2.5346,  1.1738, -0.0726,
         0.6780,  0.3580,  0.1420,  0.5761,  0.9868, -0.8368, -0.2341,  0.4986,
         0.1352, -1.2296,  6.0055, -1.0857, -0.3579,  3.6290, -2.0707, -0.6516,
         3.1673,  1.5446, -0.9390,  2.6314, -2.7418, -1.5044, -2.5641,  0.5598,
        -1.1270, -1.2051, -0.0211, -0.0134, -0.2154,  2.2469,  0.5946,  0.5420,
        -1.

### Logit Lens


We can now decompose the residual stream! First we apply a technique called the [**logit lens**](https://www.alignmentforum.org/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens) - this looks at the residual stream after each layer and calculates the logit difference from that. This simulates what happens if we delete all subsequent layers.


In [29]:
from src.observational.direct_logit_attribution import logit_lens

logit_lens_logit_diffs, labels = logit_lens(model, clean_prompts, clean_cache, clean_answers_tokens_ids, decomposition="residual")

line(
    logit_lens_logit_diffs,
    x=np.arange(model.cfg.n_layers * 2 + 1) / 2,
    hover_name=labels,
    labels={"x": "Layer", "y": "Logit Diff"},
    title="Logit Difference From Accumulate Residual Stream",
)

Answer residual directions shape: torch.Size([100, 2, 768])


### Layer Attribution


In [30]:
per_layer_logit_diffs, labels = logit_lens(model, clean_prompts, clean_cache, clean_answers_tokens_ids, decomposition="layer_blocks")

per_layer_logit_diffs = per_layer_logit_diffs.cpu().numpy()  # Move tensor to CPU and convert to NumPy array
line(per_layer_logit_diffs, hover_name=labels, title="Logit Difference From Each Layer", labels={"x": "Layer Activation (Attention/MLP)", "y": "Logit Diff"})

Answer residual directions shape: torch.Size([100, 2, 768])


### Head Attribution
We can further break down the output of each attention layer into the sum of the outputs of each attention head. Each attention layer consists of 12 heads, which each act independently and additively.

<details> <summary>Decomposing attention output into sums of heads</summary>

The standard way to compute the output of an attention layer is by concatenating the mixed values `z` of each head, and multiplying by a big output weight matrix. But as described in [A Mathematical Framework](https://transformer-circuits.pub/2021/framework/index.html) this is equivalent to splitting the output weight matrix into a per-head output (here `model.blocks[k].attn.W_O`) and adding them up (including an overall bias term for the entire layer)

</details>

In [31]:
per_head_logit_diffs, labels = logit_lens(model, clean_prompts, clean_cache, clean_answers_tokens_ids, decomposition="attention_heads")

per_head_logit_diffs_np = per_head_logit_diffs.cpu().numpy()
imshow(
    per_head_logit_diffs_np,
    labels={"x": "Head", "y": "Layer"},
    title="Logit Difference From Each Head",
    width=800,
)

Answer residual directions shape: torch.Size([100, 2, 768])
Tried to stack head results when they weren't cached. Computing head results now
torch.Size([144, 100, 768])
torch.Size([144])


### Visual Logit Lens

In [14]:
from src.observational.direct_logit_attribution import visual_logit_lens

str_tokens = model.to_str_tokens(clean_tokens[4])
titles = [f"{token1} -> {token2}" for token1, token2 in zip(str_tokens[:-1], str_tokens[1:])]

correct_ids = []
for i in range(1, len(str_tokens) + 1):

    correct_ids.extend(model.tokenizer.encode(f" {str_tokens[i % len(str_tokens)]}")) # (len(days))

# visual_logit_lens(model, all_prompts_strings[4], titles)

## Attention Analysis

Attention heads are particularly easy to study because we can look directly at their attention patterns and study from what positions they move information from and two. This is particularly easy here as we're looking at the direct effect on the logits so we need only look at the attention patterns from the final token.

We use Alan Cooney's **circuitsvis** library to visualize the attention patterns! We visualize the top 3 positive and negative heads by direct logit attribution, and show these for the first prompt (as an illustration).

<details> <summary>Interpreting Attention Patterns</summary>

An easy mistake to make when looking at attention patterns is thinking that they must convey information about the <i>token</i> looked at (maybe accounting for the context of the token). But actually, all we can confidently say is that it moves information from the _residual stream position_ corresponding to that input token, implemented by the QK circuit. Especially later on in the model, there may be components in the residual stream that are nothing to do with the input token! Eg the period at the end of a sentence may contain summary information for that sentence, and the head may solely move that, rather than caring about whether it ends in ".", "!" or "?"

Terminology:

- Source (Query) tokens: The token in the input sequence that is being attended to
- Destination (Key) Tokens: The token in the input sequence that is attending to other tokens. The destination token is essentially querying the other tokens (the source tokens) to calculate attention weights

The attention is calculated on the source token: Specifically, it reflects how much attention a given destination token (row) is placing on each source token (column) in an attention pattern.

</details>


In [52]:
def visualize_attention_patterns(
    heads: Union[List[int], int, Float[torch.Tensor, "heads"]],
    local_cache: ActivationCache,
    local_tokens: torch.Tensor,
    title: Optional[str] = "",
    max_width: Optional[int] = 700,
) -> str:
    # If a single head is given, convert to a list
    if isinstance(heads, int):
        heads = [heads]

    # Create the plotting data
    labels: List[str] = []
    patterns: List[Float[torch.Tensor, "dest_pos src_pos"]] = []

    # Assume we have a single batch item
    batch_index = 0

    for head in heads:
        # Set the label
        layer = head // model.cfg.n_heads
        head_index = head % model.cfg.n_heads
        labels.append(f"L{layer}H{head_index}")

        # Get the attention patterns for the head
        # Attention patterns have shape [batch, head_index, query_pos, key_pos]
        patterns.append(local_cache["attn", layer][batch_index, head_index])

    # Convert the tokens to strings (for the axis labels)
    str_tokens = model.to_str_tokens(local_tokens)

    # Combine the patterns into a single tensor
    patterns: Float[torch.Tensor, "head_index dest_pos src_pos"] = torch.stack(
        patterns, dim=0
    )

    # Circuitsvis Plot (note we get the code version so we can concatenate with the title)
    plot = attention_heads(
        attention=patterns, tokens=str_tokens, attention_head_names=labels
    ).show_code()

    # Display the title
    title_html = f"<h2>{title}</h2><br/>"

    # Return the visualisation as raw code
    return f"<div style='max-width: {str(max_width)}px;'>{title_html + plot}</div>"

In [55]:
from src.utils import strip_eot

top_k = 3

top_positive_logit_attr_heads = torch.topk(
    per_head_logit_diffs.flatten(), k=top_k
).indices

positive_html = visualize_attention_patterns(
    top_positive_logit_attr_heads,
    clean_cache,
    strip_eot(clean_tokens[0]),
    f"Top {top_k} Positive Logit Attribution Heads",
)

top_negative_logit_attr_heads = torch.topk(
    -per_head_logit_diffs.flatten(), k=top_k
).indices

negative_html = visualize_attention_patterns(
    top_negative_logit_attr_heads,
    clean_cache,
    strip_eot(clean_tokens[0]),
    title=f"Top {top_k} Negative Logit Attribution Heads",
)

HTML(positive_html + negative_html)

## Understanding Vector Norms in Attention

<details>
<summary>1. The Definition of the Norm</summary>

The **Euclidean (L2) norm** of a vector $(\mathbf{v} \in \mathbb{R}^d)$ is defined as:

$
\|\mathbf{v}\| = \sqrt{v_1^2 + v_2^2 + \cdots + v_d^2}.
$

In a Transformer, many operations (e.g., linear transformations, layer normalization, residual connections) can change the norms of the intermediate vectors. Thus, two different tokens $\mathbf{x}_1$ and $\mathbf{x}_2$ may, after these transformations, end up with very different $\|\mathbf{V}_1\|$ and $|\mathbf{V}_2\|$.

</details>

<details>
<summary>2. Why Does the Norm Matter in Attention?</summary>

$|\mathbf{v}\| = \sqrt{v_1^2 + v_2^2 + \cdots + v_d^2}$

In a Transformer, many operations (e.g., linear transformations, layer normalization, residual connections) can change the norms of the intermediate vectors. Thus, two different tokens $\mathbf{x}_1$ and $\mathbf{x}_2$ may, after these transformations, end up with very different $\|\mathbf{V}_1\|$ and $|\mathbf{V}_2\|$.

---

Consider a single attention head where the output vector $\mathbf{y}_i$ for token $(i)$ is:

$
\mathbf{y}_i = \sum_{j=1}^{n} \alpha_{i,j} \, f(\mathbf{x}_j),
$

where
- $\alpha_{i,j}$ is the **attention weight** for how much the head attends to token $(j)$ when producing the output for token $(i)$.
- $f(\mathbf{x}_j)$ (or $\mathbf{V}_j$) is the **value vector** for token $(j)$.

#### Key Point

- **Attention weights alone $(\alpha_{i,j})$ can be misleading** if you ignore the magnitude of $f(\mathbf{x}_j)$.  $(f(\mathbf{x}_j))$
- Even if $\alpha_{i,j}$ is large for some $(j)$, if $\|f(\mathbf{x}_j)\|$ is near zero, then that term contributes very little to $\mathbf{y}_i$.  
- Conversely, if $\alpha_{i,k}$ is moderate but $\|f(\mathbf{x}_k)\|$ is huge, token $(k)$ can dominate the output sum.

Hence, the product $\alpha_{i,j} \cdot \|\mathbf{V}_j\|$ is a more direct measure of how strongly token $(j)$ influences the final output.

---

<details>
<summary>3. Kobayashi et al. (2020): "Attention is Not Only a Weight"</summary>

In the paper **[Attention is Not Only a Weight: Analyzing Transformers with Vector Norms](https://aclanthology.org/2020.emnlp-main.574.pdf)** (Kobayashi et al., EMNLP 2020)
, the authors highlight exactly this phenomenon. They note that **solely analyzing attention weights $\alpha_{i,j}$** can lead to incorrect conclusions about how much each token contributes.

#### 3.1 Rewriting the Attention Output

They start with the usual attention output (Equation 1 in many Transformer descriptions):

$
\mathbf{y}_i = \sum_{j=1}^{n} \alpha_{i,j} \, \mathbf{V}_j.
$

By **linearity**, you can group terms or factor out the attention weights. Kobayashi et al. rewrite this to emphasize that the **norm** of the weighted sum is crucial:

$
\left\lVert \sum_{j=1}^{n} \alpha_{i,j} \, \mathbf{V}_j \right\rVert.
$

This is (roughly) their **Equation 3**, which they use to analyze “how large” the resulting output vector is, due to each token.

#### 3.2 Why Look at the Norm?

If an input token $\mathbf{x}_j$ has a large transformation vector $\mathbf{V}_j$, it can dominate $\mathbf{y}_i$ *even if* $\alpha_{i,j}$ is not the largest attention weight. Conversely, a token that receives a high attention weight might have very little impact if its $\mathbf{V}_j$ is extremely small. 

Thus, **a high attention weight does not necessarily mean high impact** on $(\mathbf{y}_i)$. Kobayashi et al. argue that we should consider $(\|\mathbf{V}_j\|)$ (the **value norm**) or even the full product $(\alpha_{i,j} \cdot \mathbf{V}_j)$ to assess how each token truly contributes to the head’s output.

---

### 4. Practical Takeaways

1. **When analyzing attention heads**, do not rely solely on the softmax attention distribution $(\{\alpha_{i,j}\})$.  
2. **Look at the magnitude of the value vectors** to see whether certain tokens with moderate attention weights might be exerting a strong influence.  
3. This can be done by computing:
   $
   \alpha_{i,j} \cdot \|\mathbf{V}_j\|
   \quad\text{or}\quad
   \left\lVert \sum_{j=1}^{n} \alpha_{i,j}\,\mathbf{V}_j \right\rVert,
   $
   which reveals how “strong” each token’s **actual** contribution is.

By combining attention weights **and** vector norms, you get a more accurate picture of which tokens affect the model's final representation.

In [108]:
from src.observational.utils import show_attention_patterns
model.reset_hooks()
prompt = "When Mary and John went to the bar, Mary gave a drink to John"
prompt = ioi_dataset.sentences[2]

# Value-Weighted Attention Patterns by Norm of the Value vector
show_attention_patterns(model, [(9, 9), (9, 6), (10, 0)], prompts=prompt, mode="val")

In [101]:
from easy_transformer.ioi_utils import show_attention_patterns
ys = []
average_attention = {}
heads_raw = [(9, 9), (9, 6), (10, 0)]
for idx, dataset in enumerate([ioi_dataset, abc_dataset]):
    fig = go.Figure()
    for head_raw in heads_raw:
        heads = [head_raw]
        average_attention[head_raw] = {}
        cur_ys = []
        cur_stds = []
        att = torch.zeros(size=(dataset.N, dataset.max_len, dataset.max_len))
        for head in tqdm(heads):
            att += show_attention_patterns(
                model, [head], dataset, return_mtx=True, mode="scores"
            )
        att /= len(heads)

        vals = att[torch.arange(dataset.N), ioi_dataset.word_idx["end"][: dataset.N], :]
        evals = torch.exp(vals)
        val_sum = torch.sum(evals, dim=1)
        assert val_sum.shape == (dataset.N,), val_sum.shape

        for key in ioi_dataset.word_idx.keys():
            end_to_s2 = att[
                torch.arange(dataset.N),
                ioi_dataset.word_idx["end"][: dataset.N],
                ioi_dataset.word_idx[key][: dataset.N],
            ]
            cur_ys.append(end_to_s2.mean().item())
            cur_stds.append(end_to_s2.std().item())
            average_attention[head_raw][key] = end_to_s2.mean().item()
        fig.add_trace(
            go.Bar(
                x=list(ioi_dataset.word_idx.keys()),
                y=cur_ys,
                error_y=dict(type="data", array=cur_stds),
                name=str(head_raw),
            )
        )
        fig.update_layout(
            title_text=f'Attention of NMs from END to various positions on {["ioi_dataset", "abc_dataset"][idx]}'
        )
    fig.show()



100%|██████████| 1/1 [00:02<00:00,  2.69s/it]
100%|██████████| 1/1 [00:02<00:00,  2.80s/it]
100%|██████████| 1/1 [00:02<00:00,  2.69s/it]


100%|██████████| 1/1 [00:02<00:00,  2.82s/it]
100%|██████████| 1/1 [00:02<00:00,  2.70s/it]
100%|██████████| 1/1 [00:02<00:00,  2.66s/it]


In [19]:
import gc
gc.collect()

try: 
    torch.cuda.empty_cache()
except:
    pass

## Activation Patching


Activation Patching is an efficient method for determining the importance of a component in the Transformer for the final prediction.

It works by running the model on a clean prompt and taking a specific activation and then running the model on a corrupted prompt. We then patch the clean activation and measure the Logit Difference (patching means replacing the corrupted activation with the clean activation in the corrupted prompt). If patching in that specific activations promotes significantly the desired token on a corrupted prompt where that token is unlikely to be considered, then the component whose activation we patched is significant for the right outcome and we can consider it for our circuit. 

In [35]:
corrupted_tokens = model.to_tokens(corr_prompts, prepend_bos=True)
corrupted_logits, corrupted_cache = model.run_with_cache(
    corrupted_tokens, return_type="logits"
)

corrupted_average_logit_diff = logits_to_ave_logit_diff(corrupted_logits, corr_answers_tok_ids)
print("Corrupted Average Logit Diff", round(corrupted_average_logit_diff.item(), 2))
print("Clean Average Logit Diff", round(ioi_average_logit_diff, 2))

Corrupted Average Logit Diff 0.37
Clean Average Logit Diff 3.15


### Residual Stream Patching

We first define a metric for our task, which is **normalized logit difference**, by subtracting the corrupted logit difference, and dividing by the total improvement from clean to corrupted to normalise. 

- 0 means zero change 

- negative means actively made worse

- 1 means totally recovered clean performance

- `>` 1 means actively improved on clean performance

We can patch in the residual stream in 2 ways: 

1. by using the `transformer_lens.patching` helper module. Here we use `patching.get_act_patch_resid_pre`, a function built on top of the more general `generic_activation_patch`. 

2. we do the intervention using TransformerLens's `HookPoint` feature. We can design a hook function that takes in a specific activation and returns an edited copy, and temporarily add it in with `model.run_with_hooks`.


In [32]:
# 1
from transformer_lens import patching

# define the metric
def ioi_metric(logits=corrupted_logits, 
                 answer_tokens=clean_answers_tokens_ids, 
                 corr_logit_diff=corrupted_average_logit_diff, 
                 clean_logit_diff=ioi_average_logit_diff):
    # Normalized logit diff
    patched_logit_diff = logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=False)
    return (patched_logit_diff - corr_logit_diff) / (clean_logit_diff - corr_logit_diff)

act_patch_resid_pre = patching.get_act_patch_resid_pre(
    model=model,
    corrupted_tokens=corrupted_tokens,
    clean_cache=ioi_cache,
    patching_metric=ioi_metric,
)

labels = [f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))]

# plot
imshow(
    act_patch_resid_pre,
    labels={"x": "Position", "y": "Layer"},
    x=labels,
    title="resid_pre Activation Patching",
    width=1200
)

  0%|          | 0/252 [00:00<?, ?it/s]

The model promotes significantly the correct answer, ie the `S` token "Alice" from layer 0 up to layer 9, and then moves the information to the `END` token " - " in the final two layers.

Interestingly, the model also looks at incorrect token `NON_S` token "Alice" but it promotes it negatively, ie writing in the opposite direction in the residual stream.

- uncomment bellow to see difference between th clean and corrupted IOI prompts for a clearer picture of how we can set counterfactual prompts.

clean and corrupted ioi tokens: ...

<!--
['<|endoftext|>', 'When', ' John', ' and', ' Mary', ' went', ' to', ' the', ' shops', ',', ' John', ' gave', ' the', ' bag', ' to']

['<|endoftext|>', 'When', ' John', ' and', ' Mary', ' went', ' to', ' the', ' shops', ',', ' Mary', ' gave', ' the', ' bag', ' to']

['<|endoftext|>', 'When', ' Tom', ' and', ' James', ' went', ' to', ' the', ' park', ',', ' James', ' gave', ' the', ' ball', ' to']

['<|endoftext|>', 'When', ' Tom', ' and', ' James', ' went', ' to', ' the', ' park', ',', ' Tom', ' gave', ' the', ' ball', ' to']

['<|endoftext|>', 'When', ' Dan', ' and', ' Sid', ' went', ' to', ' the', ' shops', ',', ' Sid', ' gave', ' an', ' apple', ' to']

['<|endoftext|>', 'When', ' Dan', ' and', ' Sid', ' went', ' to', ' the', ' shops', ',', ' Dan', ' gave', ' an', ' apple', ' to']

['<|endoftext|>', 'After', ' Martin', ' and', ' Amy', ' went', ' to', ' the', ' park', ',', ' Amy', ' gave', ' a', ' drink', ' to']

['<|endoftext|>', 'After', ' Martin', ' and', ' Amy', ' went', ' to', ' the', ' park', ',', ' Martin', ' gave', ' a', ' drink', ' to']


corrupted ioi tokens:

['<|endoftext|>When John and Mary went to the shops, Mary gave the bag to',

 '<|endoftext|>When John and Mary went to the shops, John gave the bag to',

 '<|endoftext|>When Tom and James went to the park, Tom gave the ball to',

 '<|endoftext|>When Tom and James went to the park, James gave the ball to',

 '<|endoftext|>When Dan and Sid went to the shops, Dan gave an apple to',

 '<|endoftext|>When Dan and Sid went to the shops, Sid gave an apple to',

 '<|endoftext|>After Martin and Amy went to the park, Martin gave a drink to',

 '<|endoftext|>After Martin and Amy went to the park, Amy gave a drink to'] -->

We now intervene on the corrupted run and patch in the clean residual stream at a specific layer and position.

In [36]:
# 2
def patch_residual_component(
    corrupted_residual_component: Float[torch.Tensor, "batch pos d_model"],
    hook,
    pos,
    clean_cache,
):
    corrupted_residual_component[:, pos, :] = clean_cache[hook.name][:, pos, :]
    return corrupted_residual_component


def normalize_patched_logit_diff(patched_logit_diff):
    # Subtract corrupted logit diff to measure the improvement, divide by the total improvement from clean to corrupted to normalise
    # 0 means zero change, negative means actively made worse, 1 means totally recovered clean performance, >1 means actively *improved* on clean performance
    return (patched_logit_diff - corrupted_average_logit_diff) / (
        ioi_average_logit_diff - corrupted_average_logit_diff
    )

patched_residual_stream_diff = torch.zeros(
    model.cfg.n_layers, clean_tokens.shape[1], device=device, dtype=torch.float32
)
for layer in range(model.cfg.n_layers):
    for position in range(clean_tokens.shape[1]): 
        hook_fn = partial(patch_residual_component, pos=position, clean_cache=ioi_cache)
        patched_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(utils.get_act_name("resid_pre", layer), hook_fn)],
            return_type="logits",
        )
        patched_logit_diff = logits_to_ave_logit_diff(patched_logits, clean_answers_tokens_ids)

        patched_residual_stream_diff[layer, position] = normalize_patched_logit_diff(
            patched_logit_diff
        )

For reference, tokens and their index from the first prompt are on the x-axis. In an abuse of notation, note that the difference here is averaged over _all_ prompts, while the labels only come from the _first_ prompt.

In [37]:
prompt_position_labels = [
    f"{tok}_{i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))
]
imshow(
    patched_residual_stream_diff,
    x=prompt_position_labels,
    title="Logit Difference From Patched Residual Stream",
    labels={"x": "Position", "y": "Layer"},
)

This plot is the same as before and it has the same interpretation., but with more code to it. 

### Layers


1. We can apply exactly the same idea, but this time patching in attention or MLP layers. These are also residual components with identical shapes to the residual stream terms, so we can reuse the same hooks.

2. We can reproduce this with `patching.get_act_patch_block_every`, that patches to `resid_pre`, `attn_out` and `mlp_out`.  


In [38]:
# 1
patched_attn_diff = torch.zeros(
    model.cfg.n_layers, clean_tokens.shape[1], device=device, dtype=torch.float32
)
patched_mlp_diff = torch.zeros(
    model.cfg.n_layers, clean_tokens.shape[1], device=device, dtype=torch.float32
)
for layer in range(model.cfg.n_layers):
    for position in range(clean_tokens.shape[1]):
        hook_fn = partial(patch_residual_component, pos=position, clean_cache=ioi_cache)
        patched_attn_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(utils.get_act_name("attn_out", layer), hook_fn)],
            return_type="logits",
        )
        patched_attn_logit_diff = logits_to_ave_logit_diff(
            patched_attn_logits, clean_answers_tokens_ids
        )
        patched_mlp_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(utils.get_act_name("mlp_out", layer), hook_fn)],
            return_type="logits",
        )
        patched_mlp_logit_diff = logits_to_ave_logit_diff(
            patched_mlp_logits, clean_answers_tokens_ids
        )

        patched_attn_diff[layer, position] = normalize_patched_logit_diff(
            patched_attn_logit_diff
        )
        patched_mlp_diff[layer, position] = normalize_patched_logit_diff(
            patched_mlp_logit_diff
        )

In [39]:
imshow(
    patched_attn_diff,
    x=prompt_position_labels,
    title="Logit Difference From Patched Attention Layer",
    labels={"x": "Position", "y": "Layer"},
)

We can see that `MLP0` matters a lot, acting as an extention of the Embedding layer and that when later layers want to access the input tokens they mostly read in the output of the first MLP layer, rather than the token embeddings.


In [40]:
imshow(
    patched_mlp_diff,
    x=prompt_position_labels,
    title="Logit Difference From Patched MLP Layer",
    labels={"x": "Position", "y": "Layer"},
)

<details><summary>Note on `patching.get_act_patch_block_every`</summary>

One important thing to note - we're cycling through the `resid_pre`, `attn_out` and `mlp_out` and only patching one of them at a time, rather than patching all three at once.

</details>

In [41]:
# 2
act_patch_block_every = patching.get_act_patch_block_every(model, corrupted_tokens, ioi_cache, ioi_metric)

imshow(
    act_patch_block_every,
    x=labels,
    facet_col=0, # This argument tells plotly which dimension to split into separate plots
    facet_labels=["Residual Stream", "Attn Output", "MLP Output"], # Subtitles of separate plots
    title="Logit Difference From Patched Attn Head Output",
    labels={"x": "Sequence Position", "y": "Layer"},
    width=1400,
)

  0%|          | 0/252 [00:00<?, ?it/s]

  0%|          | 0/252 [00:00<?, ?it/s]

  0%|          | 0/252 [00:00<?, ?it/s]

### Heads


We can refine the above analysis by patching in individual heads! This is somewhat more annoying, because there are now three dimensions (head_index, position and layer), so for now lets patch in a head's output across all positions.

1. Coding it up, the easiest way to do this is to patch in the activation `z`, the "mixed value" of the attention head. That is, the average of all previous values weighted by the attention pattern, ie the activation that is then multiplied by `W_O`, the output weights.

2. We can also do it just by `patching.act_patch_attn_head_out_all_pos`.


In [42]:
# 1
def patch_head_vector(
    corrupted_head_vector: Float[torch.Tensor, "batch pos head_index d_head"],
    hook,
    head_index,
    clean_cache,
):
    corrupted_head_vector[:, :, head_index, :] = clean_cache[hook.name][
        :, :, head_index, :
    ]
    return corrupted_head_vector
   
patched_head_z_diff = torch.zeros(
    model.cfg.n_layers, model.cfg.n_heads, device=device, dtype=torch.float32
)
for layer in range(model.cfg.n_layers):
    for head_index in range(model.cfg.n_heads):
        hook_fn = partial(patch_head_vector, head_index=head_index, clean_cache=ioi_cache)
        patched_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(utils.get_act_name("z", layer, "attn"), hook_fn)],
            return_type="logits",
        )
        patched_logit_diff = logits_to_ave_logit_diff(patched_logits, clean_answers_tokens_ids)

        patched_head_z_diff[layer, head_index] = normalize_patched_logit_diff(
            patched_logit_diff
        )

In [43]:
imshow(
    patched_head_z_diff,
    title="Logit Difference From Patched Head Output",
    labels={"x": "Head", "y": "Layer"},
)

In [44]:
# 2
act_patch_attn_head_out_all_pos = patching.get_act_patch_attn_head_out_all_pos(
    model,
    corrupted_tokens,
    ioi_cache,
    ioi_metric
)

imshow(
    act_patch_attn_head_out_all_pos,
    labels={"y": "Layer", "x": "Head"},
    title="attn_head_out Activation Patching (All Pos)",
    width=1200
)

  0%|          | 0/144 [00:00<?, ?it/s]

### Head Decomposition


Decomposing attention layers into patching in individual heads has already helped us localise the behaviour a lot. But we can understand it further by decomposing heads. An attention head consists of two semi-independent operations - calculating _where_ to move information _from and to_ (represented by the attention pattern and implemented via the `QK-circuit`) and calculating _what_ information to move (represented by the value vectors and implemented by the `OV-circuit`). We can disentangle which of these is important by patching in just the attention pattern or the value vectors.

1. First let's patch in the value vectors, to measure when figuring out what to move is important. This has the same shape as z ([batch, pos, head_index, d_head]) so we can reuse the same hook.

2. Or with `patching.act_patch_attn_head_all_pos_every`

In [45]:
# 2
act_patch_attn_head_all_pos_every = patching.get_act_patch_attn_head_all_pos_every(
    model,
    corrupted_tokens,
    ioi_cache,
    ioi_metric
)

imshow(
    act_patch_attn_head_all_pos_every,
    facet_col=0,
    facet_labels=["Output", "Query", "Key", "Value", "Attn Pattern"],
    title="Activation Patching Per Head (All Pos)",
    labels={"x": "Head", "y": "Layer"},
)

  0%|          | 0/144 [00:00<?, ?it/s]

  0%|          | 0/144 [00:00<?, ?it/s]

  0%|          | 0/144 [00:00<?, ?it/s]

  0%|          | 0/144 [00:00<?, ?it/s]

  0%|          | 0/144 [00:00<?, ?it/s]

In [46]:
# 1
patched_head_v_diff = torch.zeros(
    model.cfg.n_layers, model.cfg.n_heads, device=device, dtype=torch.float32
)
for layer in range(model.cfg.n_layers):
    for head_index in range(model.cfg.n_heads):
        hook_fn = partial(patch_head_vector, head_index=head_index, clean_cache=ioi_cache)
        patched_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(utils.get_act_name("v", layer, "attn"), hook_fn)],
            return_type="logits",
        )
        patched_logit_diff = logits_to_ave_logit_diff(patched_logits, clean_answers_tokens_ids)

        patched_head_v_diff[layer, head_index] = normalize_patched_logit_diff(
            patched_logit_diff
        )

In [47]:
imshow(
    patched_head_v_diff,
    title="Logit Difference From Patched Head Value",
    labels={"x": "Head", "y": "Layer"},
)

In [48]:
head_labels = [
    f"L{l}H{h}" for l in range(model.cfg.n_layers) for h in range(model.cfg.n_heads)
]
scatter(
    x=utils.to_numpy(patched_head_v_diff.flatten()),
    y=utils.to_numpy(patched_head_z_diff.flatten()),
    xaxis="Value Patch",
    yaxis="Output Patch",
    caxis="Layer",
    hover_name=head_labels,
    color=einops.repeat(
        np.arange(model.cfg.n_layers), "layer -> (layer head)", head=model.cfg.n_heads
    ),
    range_x=(-0.5, 0.5),
    range_y=(-0.5, 0.5),
    title="Scatter plot of output patching vs value patching",
)

In [49]:
def patch_head_pattern(
    corrupted_head_pattern: Float[torch.Tensor, "batch head_index query_pos d_head"],
    hook,
    head_index,
    clean_cache,
):
    corrupted_head_pattern[:, head_index, :, :] = clean_cache[hook.name][:, head_index, :, :]
    return corrupted_head_pattern


patched_head_attn_diff = torch.zeros(
    model.cfg.n_layers, model.cfg.n_heads, device=device, dtype=torch.float32
)
for layer in range(model.cfg.n_layers):
    for head_index in range(model.cfg.n_heads):
        hook_fn = partial(patch_head_pattern, head_index=head_index, clean_cache=ioi_cache)
        patched_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(utils.get_act_name("attn", layer, "attn"), hook_fn)],
            return_type="logits",
        )
        patched_logit_diff = logits_to_ave_logit_diff(patched_logits, clean_answers_tokens_ids)

        patched_head_attn_diff[layer, head_index] = normalize_patched_logit_diff(
            patched_logit_diff
        )

In [50]:
imshow(
    patched_head_attn_diff,
    title="Logit Difference From Patched Head Attn Pattern",
    labels={"x": "Head", "y": "Layer"},
)
head_labels = [
    f"L{l}H{h}" for l in range(model.cfg.n_layers) for h in range(model.cfg.n_heads)
]
scatter(
    x=utils.to_numpy(patched_head_attn_diff.flatten()),
    y=utils.to_numpy(patched_head_z_diff.flatten()),
    hover_name=head_labels,
    xaxis="Attention Patch",
    yaxis="Output Patch",
    title="Scatter plot of output patching vs attention patching",
)

## Zooming out

We can now visualize attention patterns of different important heads; in terms of their patched attention pattern (QK circuit), their value vectors or their output vectors. We take the top 10 heads by output patching (in absolute value) and split it into early, middle and late.  

In [56]:
top_k = 10
top_heads_by_output_patch = torch.topk(
    patched_head_z_diff.abs().flatten(), k=top_k
).indices
first_mid_layer = 7
first_late_layer = 9
early_heads = top_heads_by_output_patch[
    top_heads_by_output_patch < model.cfg.n_heads * first_mid_layer
]
mid_heads = top_heads_by_output_patch[
    torch.logical_and(
        model.cfg.n_heads * first_mid_layer <= top_heads_by_output_patch,
        top_heads_by_output_patch < model.cfg.n_heads * first_late_layer,
    )
]
late_heads = top_heads_by_output_patch[
    model.cfg.n_heads * first_late_layer <= top_heads_by_output_patch
]

early = visualize_attention_patterns(
    early_heads, ioi_cache, strip_eot(clean_tokens[0]), title=f"Top Early Heads"
)
mid = visualize_attention_patterns(
    mid_heads, ioi_cache, strip_eot(clean_tokens[0]), title=f"Top Middle Heads"
)
late = visualize_attention_patterns(
    late_heads, ioi_cache, strip_eot(clean_tokens[0]), title=f"Top Late Heads"
)

# HTML(early + mid + late)
HTML(early + late)

### Investigating Early Heads

From the IOI paper, the circuit that solves the prompt "When Mary and John went to the store, John gave a drink to Mary" includes Attention Heads from three categories, early, middle and late heads. From the early heads, which are Duplicate Token Heads, Previous Token Heads and Induction Heads we now focus on the Induction Heads.

Induction Heads are a special class of Attention Heads, discovered in early Mech-Interp research that allegedly solve most of the In-Context Learning tasks, including our own and IOI. They essentially continue repeated sequences of tokens, done in two steps, as described in the Anthropic paper [In-Context Learning and Induction Heads](https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html):

1. Prefix-Matching: this is done via the QK circuit, which looks for the previous occurence of the current token and attends to it. That is, it attends to the token which induction would suggest comes next [AB] ... [A] -> [B]

2. Copying: The head's output increases the logit corresponding to the attended-to token via the OV circuit, which is characterized by it's positive eigenvalues (similar direction in the embedding space).

![Move image demo](https://pbs.twimg.com/media/FNWAzXjVEAEOGRe.jpg)

We can inspect induction heads by passing in repeated text and plotting the attention pattern that implements the prefix-matching step.

In [63]:
example_text = ioi_dataset.sentences[0]
example_repeated_text = example_text + example_text
example_repeated_tokens = model.to_tokens(example_repeated_text, prepend_bos=True)
example_repeated_logits, example_repeated_cache = model.run_with_cache(
    example_repeated_tokens
)
induction_head_labels = [81, 65]

In [64]:
code = visualize_attention_patterns(
    induction_head_labels,
    example_repeated_cache,
    example_repeated_tokens,
    title="Induction Heads",
    max_width=800,
)
HTML(code)

The OV Circuit for a head (the factorised matrix $W_OV = W_V W_O$) is a linear map that determines what information is moved from the source position to the destination position. Because this is low rank, it can be thought of as *reading in* some low rank subspace of the source residual stream and *writing to* some low rank subspace of the destination residual stream (with maybe some processing happening in the middle).

A common operation for this will just be to *copy*, ie to have the same reading and writing subspace, and to do minimal processing in the middle. Empirically, this tends to coincide with the OV Circuit having (approximately) positive real eigenvalues. I mostly assert this as an empirical fact, but intuitively, operations that involve mapping eigenvectors to different directions (eg rotations) tend to have complex eigenvalues. And operations that preserve eigenvector direction but negate it tend to have negative real eigenvalues. And "what happens to the eigenvectors" is a decent proxy for what happens to an arbitrary vector.

We can get a score for "how positive real the OV circuit eigenvalues are" with $\frac{\sum \lambda_i}{\sum |\lambda_i|}$, where $\lambda_i$ are the eigenvalues of the OV circuit. This is a bit of a hack, but it seems to work well in practice.

Let's use FactoredMatrix to compute this for every head in the model! We use the helper `model.OV` to get the concatenated OV circuits for all heads across all layers in the model. This has the shape `[n_layers, n_heads, d_model, d_model]`, where `n_layers` and `n_heads` are batch dimensions and the final two dimensions are factorised as `[n_layers, n_heads, d_model, d_head]` and `[n_layers, n_heads, d_head, d_model]` matrices.

We can then get the eigenvalues for this, where there are separate eigenvalues for each element of the batch (a `[n_layers, n_heads, d_head]` tensor of complex numbers), and calculate the copying score.

In [65]:
OV_circuit_all_heads = model.OV
OV_circuit_all_heads_eigenvalues = OV_circuit_all_heads.eigenvalues 
print(OV_circuit_all_heads_eigenvalues.shape)
OV_copying_score = OV_circuit_all_heads_eigenvalues.sum(dim=-1).real / OV_circuit_all_heads_eigenvalues.abs().sum(dim=-1)
imshow(utils.to_numpy(OV_copying_score), xaxis="Head", yaxis="Layer", title="OV Copying Score for each head in GPT-2 Small", zmax=1.0, zmin=-1.0)

torch.Size([12, 12, 64])


Now we can visualize the eigenvalues of a selected head.

In [66]:
scatter(x=OV_circuit_all_heads_eigenvalues[-3, -3, :].real, y=OV_circuit_all_heads_eigenvalues[-3, -3, :].imag, title="Eigenvalues of Head L9H9 of GPT-2 Small", xaxis="Real", yaxis="Imaginary")

We can even look at the full OV circuit, from the input tokens to output tokens: $W_E W_V W_O W_U$. This is a `[d_vocab, d_vocab]==[50257, 50257]` matrix, so absolutely enormous, even for a single head. But with the FactoredMatrix class, we can compute the full eigenvalue copying score of every head in a few seconds.

In [67]:
full_OV_circuit = model.embed.W_E @ OV_circuit_all_heads @ model.unembed.W_U
print(full_OV_circuit)
full_OV_circuit_eigenvalues = full_OV_circuit.eigenvalues

full_OV_copying_score = full_OV_circuit_eigenvalues.sum(dim=-1).real / full_OV_circuit_eigenvalues.abs().sum(dim=-1)
imshow(utils.to_numpy(full_OV_copying_score), xaxis="Head", yaxis="Layer", title="OV Copying Score for each head in GPT-2 Small", zmax=1.0, zmin=-1.0)

FactoredMatrix: Shape(torch.Size([12, 12, 50257, 50257])), Hidden Dim(64)


In [68]:
scatter(x=full_OV_copying_score.flatten(), y=OV_copying_score.flatten(), hover_name=[f"L{layer}H{head}" for layer in range(12) for head in range(12)], title="OV Copying Score for each head in GPT-2 Small", xaxis="Full OV Copying Score", yaxis="OV Copying Score")

We can characterise an **induction head** by just giving a sequence of random tokens repeated once, and measuring the average attention paid from the *second copy* of a token to the token *after* the first copy. At the same time, we can also measure the average attention paid *from the second copy* of a token to the *first* copy of the token, which is the attention that the induction head would pay if it were a **duplicate token head**, and the average attention paid to the *previous* token to find **previous token heads**.

In conformity with the IOI paper, these 3 types of heads compose the early heads of the circuit.

<details> <summary>Technical Implementation Details</summary> 
We can do this again by using hooks, this time just to access the attention patterns rather than to intervene on them. 

Our hook function acts on the attention pattern activation. This has the name
"blocks.{layer}.{layer_type}.hook_{activation_name}" in general, here it's
"blocks.{layer}.attn.hook_attn". And it has shape [batch, head_index, query_pos, token_pos]. Our
hook function takes in the attention pattern activation, calculates the score for the relevant type
of head, and write it to an external cache.

We add in hooks using `model.run_with_hooks(tokens, fwd_hooks=[(names_filter, hook_fn)])` to
temporarily add in the hooks and run the model, getting the resulting output. Previously
names_filter was the name of the activation, but here it's a boolean function mapping activation
names to whether we want to hook them or not. Here it's just whether the name ends with hook_attn.
hook_fn must take in the two inputs activation (the activation tensor) and hook (the HookPoint
object, which contains the name of the activation and some metadata such as the current layer).

Internally our hooks use the function `tensor.diagonal`, this takes the diagonal between two
dimensions, and allows an arbitrary offset - offset by 1 to get previous tokens, seq_len to get
duplicate tokens (the distance to earlier copies) and seq_len-1 to get induction heads (the distance
to the token *after* earlier copies). Different offsets give a different length of output tensor,
and we can now just average to get a score in [0, 1] for each head.
</details>

In [69]:
seq_len = len(model.to_str_tokens(clean_tokens[0]))
batch_size = 1

prev_token_scores = torch.zeros((model.cfg.n_layers, model.cfg.n_heads), device=device)

def prev_token_hook(pattern, hook):
    layer = hook.layer()
    diagonal = pattern.diagonal(offset=1, dim1=-1, dim2=-2)
    # print(diagonal)
    # print(pattern)
    prev_token_scores[layer] = einops.reduce(
        diagonal, "batch head_index diagonal -> head_index", "mean"
    )


duplicate_token_scores = torch.zeros(
    (model.cfg.n_layers, model.cfg.n_heads), device=device
)


def duplicate_token_hook(pattern, hook):
    layer = hook.layer()
    diagonal = pattern.diagonal(offset=seq_len, dim1=-1, dim2=-2)
    duplicate_token_scores[layer] = einops.reduce(
        diagonal, "batch head_index diagonal -> head_index", "mean"
    )


induction_scores = torch.zeros((model.cfg.n_layers, model.cfg.n_heads), device=device)


def induction_hook(pattern, hook):
    layer = hook.layer()
    diagonal = pattern.diagonal(offset=seq_len - 1, dim1=-1, dim2=-2)
    induction_scores[layer] = einops.reduce(
        diagonal, "batch head_index diagonal -> head_index", "mean"
    )


torch.manual_seed(0)
original_tokens = torch.randint(
    100, 20000, size=(batch_size, seq_len), device="cpu"
).to(device)
repeated_tokens = einops.repeat(
    original_tokens, "batch seq_len -> batch (2 seq_len)"
).to(device)

pattern_filter = lambda act_name: act_name.endswith("hook_pattern")

loss = model.run_with_hooks(
    repeated_tokens,
    return_type="loss",
    fwd_hooks=[
        (pattern_filter, prev_token_hook),
        (pattern_filter, duplicate_token_hook),
        (pattern_filter, induction_hook),
    ],
)
print(torch.round(utils.get_corner(prev_token_scores).detach().cpu(), decimals=3))
print(torch.round(utils.get_corner(duplicate_token_scores).detach().cpu(), decimals=3))
print(torch.round(utils.get_corner(induction_scores).detach().cpu(), decimals=3))

tensor([[0.0920, 0.0000, 0.0710],
        [0.2040, 0.2060, 0.1310],
        [0.1730, 0.0540, 0.5050]])
tensor([[0.0210, 0.3740, 0.0470],
        [0.0010, 0.0020, 0.0120],
        [0.0070, 0.0360, 0.0000]])
tensor([[0.0300, 0.0000, 0.0370],
        [0.0030, 0.0030, 0.0140],
        [0.0110, 0.0310, 0.0060]])


We can now plot the head scores, and instantly see that the relevant early heads are induction heads or duplicate token heads.

In [70]:
imshow(
    prev_token_scores, labels={"x": "Head", "y": "Layer"}, title="Previous Token Scores"
)
imshow(
    duplicate_token_scores,
    labels={"x": "Head", "y": "Layer"},
    title="Duplicate Token Scores",
)
imshow(
    induction_scores, labels={"x": "Head", "y": "Layer"}, title="Induction Head Scores"
)

Can also be done with `transformer_lens.head_detector`. From docs: currently supported heads are `"previous_token_head", "duplicate_token_head", "induction_head"`. The advantage of using this is because we can pass in all prompts with the `seq` argument.

In [46]:
from transformer_lens import head_detector

head_detector.get_supported_heads()

previous_token_score = head_detector.detect_head(model=model, seq=clean_prompts, cache=ioi_cache, detection_pattern="previous_token_head")
# imshow(previous_token_score, labels={"x": "Head", "y": "Layer"}, title="Previous Heads")

duplicate_token_score = head_detector.detect_head(model=model, seq=clean_prompts, cache=ioi_cache, detection_pattern="duplicate_token_head")
# imshow(duplicate_token_scores, labels={"x": "Head", "y": "Layer"}, title="Duplicate Heads")

induction_token_score = head_detector.detect_head(model=model, seq=clean_prompts, cache=ioi_cache, detection_pattern="induction_head")
# imshow(induction_token_score, labels={"x": "Head", "y": "Layer"}, title="Induction Heads")

Supported heads: ('previous_token_head', 'duplicate_token_head', 'induction_head')


### Backup Name Movers

In chapter "Exploring Anomalies" from the Exploratory Analysis Demo notebook, the Backup Name Movers are investigated, as part of the late heads. They form when the Name Movers are ablated (this can be approached in two ways: ablating with *zero* and with the *mean* of the other heads) as helpers to continue the task that the Name Movers were previously doing.

<details> <summary>Implementation Details</summary> 

Ablating heads is really easy in TransformerLens! We can just define a hook on the `z` activation in the relevant attention layer (recall, `z` is the mixed values, and comes immediately before multiplying by the output weights $W_O$). `z` has a head_index axis, so we can set the component for the relevant head and for position -1 to zero, and return it. (Technically we could just edit in place without returning it, but by convention we always return an edited activation). 

We now want to compare all internal activations with a hook, which is hard to do with the nice `run_with_hooks` API. So we can directly access the hook on the `z` activation with `model.blocks[layer].attn.hook_z` and call its `add_hook` method. This adds in the hook to the *global state* of the model. We can now use `run_with_cache`, and don't need to care about the global state, `because run_with_cache` internally adds a bunch of caching hooks, and then removes all hooks after the run, *including* the previously added ablation hook. This can be disabled with the `reset_hooks_end` flag, but here it's useful! 

</details>

In [71]:
top_name_mover = per_head_logit_diffs.flatten().argmax().item()
top_name_mover_layer = top_name_mover // model.cfg.n_heads
top_name_mover_head = top_name_mover % model.cfg.n_heads
print(f"Top Name Mover to ablate: L{top_name_mover_layer}H{top_name_mover_head}")

def ablate_top_head_hook(z: Float[torch.Tensor, "batch pos head_index d_head"], hook):
    z[:, -1, top_name_mover_head, :] = 0
    return z

# Adds a hook into global model state
model.blocks[top_name_mover_layer].attn.hook_z.add_hook(ablate_top_head_hook)
# Runs the model, temporarily adds caching hooks and then removes *all* hooks after running, including the ablation hook.
ablated_logits, ablated_cache = model.run_with_cache(clean_tokens)
print(f"Original logit diff: {ioi_average_logit_diff:.2f}")
print(
    f"Post ablation logit diff: {logits_to_ave_logit_diff(ablated_logits, clean_answers_tokens_ids).item():.2f}"
)
print(
    f"Direct Logit Attribution of top name mover head: {per_head_logit_diffs.flatten()[top_name_mover].item():.2f}"
)
print(
    f"Naive prediction of post ablation logit diff: {ioi_average_logit_diff - per_head_logit_diffs.flatten()[top_name_mover].item():.2f}"
)

Top Name Mover to ablate: L9H9
Original logit diff: 3.15
Post ablation logit diff: 0.21
Direct Logit Attribution of top name mover head: 0.25
Naive prediction of post ablation logit diff: 2.90


In [72]:
def residual_stack_to_logit_diff(
    residual_stack: Float[torch.Tensor, "components batch d_model"],
    cache: ActivationCache,
) -> float:
    scaled_residual_stack = cache.apply_ln_to_stack(
        residual_stack, layer=-1, pos_slice=-1
    )
    return einsum(
        "... batch d_model, batch d_model -> ...",
        scaled_residual_stack,
        logit_diff_directions,
    ) / len(clean_prompts)

answer_residual_directions = model.tokens_to_residual_directions(clean_answers_tokens_ids)
print("Answer residual directions shape:", answer_residual_directions.shape)
logit_diff_directions = (
    answer_residual_directions[:, 0] - answer_residual_directions[:, 1]
)
print("Logit difference directions shape:", logit_diff_directions.shape)

per_head_ablated_residual, labels = ablated_cache.stack_head_results(
    layer=-1, pos_slice=-1, return_labels=True
)
per_head_ablated_logit_diffs = residual_stack_to_logit_diff(
    per_head_ablated_residual, ablated_cache
)
per_head_ablated_logit_diffs = per_head_ablated_logit_diffs.reshape(
    model.cfg.n_layers, model.cfg.n_heads
)
imshow(
    torch.stack([
        per_head_logit_diffs,
        per_head_ablated_logit_diffs,
        per_head_ablated_logit_diffs - per_head_logit_diffs
    ]),
    title="Direct logit contribution by head, pre / post ablation",
    labels={"x":"Head", "y":"Layer"},
    facet_col=0,
    facet_labels=["No ablation", "9.9 is ablated", "Change in head contribution post-ablation"],
)
scatter(
    y=per_head_logit_diffs.flatten(),
    x=per_head_ablated_logit_diffs.flatten(),
    hover_name=head_labels,
    range_x=(-3, 3),
    range_y=(-3, 3),
    xaxis="Ablated",
    yaxis="Original",
    title="Original vs Post-Ablation Direct Logit Attribution of Heads",
)

Answer residual directions shape: torch.Size([100, 2, 768])
Logit difference directions shape: torch.Size([100, 768])
Tried to stack head results when they weren't cached. Computing head results now


How can we interpet these plots?

1. The first plot indicates that Backup Name Movers form after ablating L9H9. They are: L8H10, L10H6, L10H10.

2. What we measure in the second plot is the effect of ablating the top Name Mover in terms of logit attribution. So this effect is actually how the direct logit attribution of each individual head is modified because of the ablation of the top Name Mover. We see that ablating L9H9 actually decreases the negaive effect of Negative Name Movers, such as L10H7 and L11H2. 

If there is evidence on the existance of Backup Head Movers, looking at the Original vs Post-Ablation Direct Logit Attribution of Heads is useful to determine by how much these class of heads write in the same direction as the Name Movers, so it's not just a form of validating them as part of the circuit. 

One natural hypothesis is that this is because the final LayerNorm scaling has changed, which can scale up or down the final residual stream. This is slightly true, and we can see that the typical head is a bit off from the x=y line. But the average LN scaling ratio is 1.05, and this should uniformly change all heads by the same factor, so this can't be sufficient.

In [73]:
print(
    "Average LN scaling ratio:",
    round(
        (
            ioi_cache["ln_final.hook_scale"][:, -1]
            / ablated_cache["ln_final.hook_scale"][:, -1]
        )
        .mean()
        .item(),
        3,
    ),
)
print(
    "Ablation LN scale",
    ablated_cache["ln_final.hook_scale"][:, -1].detach().cpu().round(decimals=2),
)
print(
    "Original LN scale",
    ioi_cache["ln_final.hook_scale"][:, -1].detach().cpu().round(decimals=2),
)

Average LN scaling ratio: 1.003
Ablation LN scale tensor([[23.6700],
        [23.4800],
        [23.4100],
        [23.4000],
        [23.2900],
        [23.1600],
        [23.6200],
        [23.2800],
        [23.2900],
        [23.4200],
        [22.3500],
        [23.3800],
        [22.2200],
        [23.5600],
        [23.6100],
        [23.2400],
        [15.8600],
        [23.5100],
        [23.2000],
        [17.3000],
        [22.3100],
        [23.3100],
        [23.3300],
        [23.2200],
        [23.1600],
        [23.4300],
        [23.3900],
        [23.3700],
        [22.3100],
        [23.4900],
        [15.4400],
        [23.2300],
        [22.3900],
        [23.2300],
        [23.3900],
        [23.4800],
        [23.3100],
        [23.4600],
        [23.2000],
        [23.2400],
        [22.3700],
        [23.2900],
        [23.2800],
        [16.8800],
        [23.2800],
        [16.2100],
        [23.2600],
        [23.3200],
        [23.4000],
        [14.6900],


### Copying & Writing Directions of NMs and Negative NMs

In [74]:
def check_copy_circuit(model, layer, head, ioi_dataset, verbose=False, neg=False):
    cache = {}
    def hook_fn(activation: torch.Tensor, hook: HookPoint, name: str = "activation"):
        '''Stores activations in hook context.'''
        cache[name] = activation
        return activation
    # model.cache_some(cache, lambda x: x == "blocks.0.hook_resid_post")
    module_name = "blocks.0.hook_resid_post"
    fwd_hooks: List[Tuple[str, callable]] = [
        (module_name, partial(hook_fn, name='l0_hook_resid_post'))
    ]
    model.run_with_hooks(ioi_dataset.toks.long(), fwd_hooks=fwd_hooks)
    if neg:
        sign = -1
    else:
        sign = 1
    z_0 = cache["l0_hook_resid_post"]

    v = torch.einsum("eab,bc->eac", z_0, model.blocks[layer].attn.W_V[head])
    v += model.blocks[layer].attn.b_V[head].unsqueeze(0).unsqueeze(0)

    o = sign * torch.einsum("sph,hd->spd", v, model.blocks[layer].attn.W_O[head])
    logits = model.unembed(model.ln_final(o))

    k = 5
    n_right = 0

    for seq_idx, prompt in enumerate(ioi_dataset.ioi_prompts):
        for word in ["IO", "S1", "S2"]:
            pred_tokens = [
                model.tokenizer.decode(token)
                for token in torch.topk(
                    logits[seq_idx, ioi_dataset.word_idx[word][seq_idx]], k
                ).indices
            ]
            if "S" in word:
                name = "S"
            else:
                name = word
            if " " + prompt[name] in pred_tokens:
                n_right += 1
            else:
                if verbose:
                    print("-------")
                    print("Seq: " + ioi_dataset.sentences[seq_idx])
                    print("Target: " + ioi_dataset.ioi_prompts[seq_idx][name])
                    print(
                        " ".join(
                            [
                                f"({i+1}):{model.tokenizer.decode(token)}"
                                for i, token in enumerate(
                                    torch.topk(
                                        logits[
                                            seq_idx, ioi_dataset.word_idx[word][seq_idx]
                                        ],
                                        k,
                                    ).indices
                                )
                            ]
                        )
                    )
    percent_right = (n_right / (ioi_dataset.N * 3)) * 100
    print(
        f"Copy circuit for head {layer}.{head} (sign={sign}) : Top {k} accuracy: {percent_right}%"
    )
    return percent_right


neg_sign = False
print(" --- Name Mover heads --- ")
check_copy_circuit(model, 9, 9, ioi_dataset, neg=neg_sign)
check_copy_circuit(model, 10, 0, ioi_dataset, neg=neg_sign)
check_copy_circuit(model, 9, 6, ioi_dataset, neg=neg_sign)

neg_sign = True
print(" --- Negative heads --- ")
check_copy_circuit(model, 10, 7, ioi_dataset, neg=neg_sign)
check_copy_circuit(model, 11, 10, ioi_dataset, neg=neg_sign)

neg_sign = False
print(" ---  Random heads for control ---  ")
check_copy_circuit(
    model, random.randint(0, 11), random.randint(0, 11), ioi_dataset, neg=neg_sign
)
check_copy_circuit(
    model, random.randint(0, 11), random.randint(0, 11), ioi_dataset, neg=neg_sign
)
check_copy_circuit(
    model, random.randint(0, 11), random.randint(0, 11), ioi_dataset, neg=neg_sign
)

 --- Name Mover heads --- 
Copy circuit for head 9.9 (sign=1) : Top 5 accuracy: 100.0%
Copy circuit for head 10.0 (sign=1) : Top 5 accuracy: 94.66666666666667%
Copy circuit for head 9.6 (sign=1) : Top 5 accuracy: 97.66666666666667%
 --- Negative heads --- 
Copy circuit for head 10.7 (sign=-1) : Top 5 accuracy: 100.0%
Copy circuit for head 11.10 (sign=-1) : Top 5 accuracy: 100.0%
 ---  Random heads for control ---  
Copy circuit for head 9.10 (sign=1) : Top 5 accuracy: 0.0%
Copy circuit for head 3.2 (sign=1) : Top 5 accuracy: 0.0%
Copy circuit for head 6.1 (sign=1) : Top 5 accuracy: 0.0%


0.0

Attention probability vs projection of the head output along $W_{U}$[IO] or $W_{U}$[S] respectively. 




In [75]:
def scatter_embedding_vs_attn(
    attn_from_end_to_io: torch.FloatTensor,
    attn_from_end_to_s: torch.FloatTensor,
    projection_in_io_dir: torch.FloatTensor,
    projection_in_s_dir: torch.FloatTensor,
    layer: int,
    head: int
):
    x = torch.concat([attn_from_end_to_io, attn_from_end_to_s], dim=0).cpu()
    y = torch.concat([projection_in_io_dir, projection_in_s_dir], dim=0).cpu()
    # color=["IO"] * N + ["S"] * N,
    color = ["IO"] * len(attn_from_end_to_io) + ["S"] * len(attn_from_end_to_s)


    fig = px.scatter(
        x=x,
        y=y,
        color=color,
        title=f"Projection of the output of {layer}.{head} along the name<br>embedding vs attention probability on name",
        labels={"x": "Attn prob on name", "y": "Dot w Name Embed", "color": "Name type"},
        color_discrete_sequence=["#72FF64", "#C9A5F7"],
        width=650
    )

    fig.show()

def calculate_and_show_scatter_embedding_vs_attn(
    layer: int,
    head: int,
    cache: ActivationCache = ioi_cache,
    dataset: IOIDataset = ioi_dataset,
) -> None:
    '''
    Creates and plots a figure equivalent to 3(c) in the paper.

    This should involve computing the four 1D tensors:
        attn_from_end_to_io
        attn_from_end_to_s
        projection_in_io_dir
        projection_in_s_dir
    and then calling the scatter_embedding_vs_attn function.
    '''
    # SOLUTION
    # Get the value written to the residual stream at the end token by this head
    z: Float[Tensor, "batch seq d_head"] = cache[utils.get_act_name("z", layer)][:, :, head]
    N = z.size(0)
    output: Float[Tensor, "batch seq d_model"] = z @ model.W_O[layer, head]
    output_on_end_token: Float[Tensor, "batch d_model"] = output[torch.arange(N), dataset.word_idx["end"]]

    # Get the directions we'll be projecting onto
    io_unembedding: Float[Tensor, "batch d_model"] = model.W_U.T[dataset.io_tokenIDs]
    s_unembedding: Float[Tensor, "batch d_model"] = model.W_U.T[dataset.s_tokenIDs]

    # Get the value of projections, by multiplying and summing over the d_model dimension
    projection_in_io_dir: Float[Tensor, "batch"] = (output_on_end_token * io_unembedding).sum(-1)
    projection_in_s_dir: Float[Tensor, "batch"] = (output_on_end_token * s_unembedding).sum(-1)

    # Get attention probs, and index to get the probabilities from END -> IO / S
    attn_probs: Float[Tensor, "batch q k"] = cache["pattern", layer][:, head]
    attn_from_end_to_io = attn_probs[torch.arange(N), dataset.word_idx["end"], dataset.word_idx["IO"]]
    attn_from_end_to_s = attn_probs[torch.arange(N), dataset.word_idx["end"], dataset.word_idx["S1"]]

    # Show scatter plot
    scatter_embedding_vs_attn(
        attn_from_end_to_io,
        attn_from_end_to_s,
        projection_in_io_dir,
        projection_in_s_dir,
        layer,
        head
    )

nmh = (9, 9)
calculate_and_show_scatter_embedding_vs_attn(*nmh)

nnmh = (11, 10)
calculate_and_show_scatter_embedding_vs_attn(*nnmh)

Validating Name Movers and Negative Name Movers can be done by their OV-circuit by studying what values are written via the heads’ OV matrix.

From the IOI paper: 

"To check that the Name Mover Heads copy names generally, we studied what values are written via the heads’ OV matrix. Specifically, we first obtained the state of the residual stream at the position of each name token after the first MLP layer. Then, we multiplied this by the OV matrix of a Name Mover Head (simulating what would happen if the head attended perfectly to that token), multiplied by the unembedding matrix, and applied the final layer norm to obtain logit probabilities. We compute the proportion of samples that contain the input name token in the top 5 logits (N = 1000) and call this the copy score. All three Name Mover Heads have a copy score above 95%, compared to less than 20% for an average head.

Negative Name Mover Heads ... have a large negative copy score–the copy score calculated with the negative of the OV matrix (98% compared to 12% for an average head)."

<details><summary>Technical details for Copying Scores</summary>

You should replicate these results by completing the `get_copying_scores` function below.

You could do this by indexing from the `ioi_cache`, but a much more principled alternative would be to embed all the names in the `NAMES` list and apply operations like MLPs, layernorms and OV matrices manually. This is what the solutions do.

A few notes:

- You can use `model.to_tokens` to convert the names to tokens. Remember to use `prepend_bos=False`, since you just want the tokens of names so you can embed them. Note that this function will treat the list of names as a batch of single-token inputs, which works fine for our purposes.

- You can apply MLPs and layernorms as functions, by just indexing the model's blocks (e.g. use `model.blocks[i].mlp` or `model.blocks[j].ln1` as a function). Remember that `ln1` is the layernorm that comes *before attention*, and `ln2` comes *before the MLP*.

- Remember that you need to apply MLP0 before you apply the OV matrix (which is why we omit the 0th layer in our scores). The reason for this is that ablating MLP0 has a strangely large effect in gpt2-small relative to ablating other MLPs, possibly because it's acting as an extended embedding (see here for an explanation).

In [76]:
from jaxtyping import Float, Int, Bool
from transformer_lens.components import Embed, Unembed, LayerNorm, MLP

def get_copying_scores(
    model: HookedTransformer,
    k: int = 5,
    names: list = NAMES
) -> Float[Tensor, "2 layer-1 head"]:
    '''
    Gets copying scores (both positive and negative) as described in page 6 of the IOI paper, for every (layer, head) pair in the model.

    Returns these in a 3D tensor (the first dimension is for positive vs negative).
    '''
    # SOLUTION
    results = torch.zeros((2, model.cfg.n_layers, model.cfg.n_heads), device="cuda")

    # Define components from our model (for typechecking, and cleaner code)
    embed: Embed = model.embed
    mlp0: MLP = model.blocks[0].mlp
    ln0: LayerNorm = model.blocks[0].ln2
    unembed: Unembed = model.unembed
    ln_final: LayerNorm = model.ln_final

    # Get embeddings for the names in our list
    name_tokens: Int[Tensor, "batch 1"] = model.to_tokens(names, prepend_bos=False)
    name_embeddings: Int[Tensor, "batch 1 d_model"] = embed(name_tokens)

    # Get residual stream after applying MLP
    resid_after_mlp1 = name_embeddings + mlp0(ln0(name_embeddings))

    # Loop over all (layer, head) pairs
    for layer in range(1, model.cfg.n_layers):
        for head in range(model.cfg.n_heads):

            # Get W_OV matrix
            W_OV = model.W_V[layer, head] @ model.W_O[layer, head]

            # Get residual stream after applying W_OV or -W_OV respectively
            # (note, because of bias b_U, it matters that we do sign flip here, not later)
            resid_after_OV_pos = resid_after_mlp1 @ W_OV
            resid_after_OV_neg = resid_after_mlp1 @ -W_OV

            # Get logits from value of residual stream
            logits_pos: Float[Tensor, "batch d_vocab"] = unembed(ln_final(resid_after_OV_pos)).squeeze()
            logits_neg: Float[Tensor, "batch d_vocab"] = unembed(ln_final(resid_after_OV_neg)).squeeze()

            # Check how many are in top k
            topk_logits: Int[Tensor, "batch k"] = torch.topk(logits_pos, dim=-1, k=k).indices
            in_topk = (topk_logits == name_tokens).any(-1)
            # Check how many are in bottom k
            bottomk_logits: Int[Tensor, "batch k"] = torch.topk(logits_neg, dim=-1, k=k).indices
            in_bottomk = (bottomk_logits == name_tokens).any(-1)

            # Fill in results
            results[:, layer-1, head] = torch.tensor([in_topk.float().mean(), in_bottomk.float().mean()])

    return results

In [77]:
copying_results = get_copying_scores(model)

imshow(
    copying_results,
    facet_col=0,
    facet_labels=["Positive copying scores", "Negative copying scores"],
    title="Copying scores of attention heads' OV circuits",
    width=800
)

In [78]:
from rich.table import Table, Column
from rich import print as rprint

def make_table(cols, colnames, title="", n_rows=5, decimals=4):
    '''Makes and displays a table, from cols rather than rows (using rich print)'''
    table = Table(*colnames, title=title)
    rows = list(zip(*cols))
    f = lambda x: x if isinstance(x, str) else f"{x:.{decimals}f}"
    for row in rows[:n_rows]:
        table.add_row(*list(map(f, row)))
    rprint(table)

heads = {"name mover": [(8, 11), (9, 6), (9, 9), (10, 0)], "negative name mover": [(10, 7), (11, 10)]}

for i, name in enumerate(["name mover", "negative name mover"]):
    make_table(
        title=f"Copying Scores ({name} heads)",
        colnames=["Head", "Score"],
        cols=[
            list(map(str, heads[name])) + ["[dark_orange bold]Average"],
            [f"{copying_results[i, layer-1, head]:.2%}" for (layer, head) in heads[name]] + [f"[dark_orange bold]{copying_results[i].mean():.2%}"]
        ]
    )

In [79]:
def generate_repeated_tokens(
    model: HookedTransformer,
    seq_len: int,
    batch: int = 1
) -> Float[Tensor, "batch 2*seq_len"]:
    '''
    Generates a sequence of repeated random tokens (no start token).
    '''
    rep_tokens_half = torch.randint(0, model.cfg.d_vocab, (batch, seq_len), dtype=torch.int64)
    rep_tokens = torch.cat([rep_tokens_half, rep_tokens_half], dim=-1).to(device)
    return rep_tokens


def get_attn_scores(
    model: HookedTransformer,
    seq_len: int,
    batch: int,
    head_type: Literal["duplicate", "prev", "induction"]
):
    '''
    Returns attention scores for sequence of duplicated tokens, for every head.
    '''
    # SOLUTION
    rep_tokens = generate_repeated_tokens(model, seq_len, batch)

    _, cache = model.run_with_cache(
        rep_tokens,
        return_type=None,
        names_filter=lambda name: name.endswith("pattern")
    )

    # Get the right indices for the attention scores

    if head_type == "duplicate":
        src_indices = range(seq_len)
        dest_indices = range(seq_len, 2 * seq_len)
    elif head_type == "prev":
        src_indices = range(seq_len)
        dest_indices = range(1, seq_len + 1)
    elif head_type == "induction":
        dest_indices = range(seq_len, 2 * seq_len)
        src_indices = range(1, seq_len + 1)
    else:
        raise ValueError(f"Unknown head type {head_type}")

    results = torch.zeros(model.cfg.n_layers, model.cfg.n_heads, device="cuda", dtype=torch.float32)
    for layer in range(model.cfg.n_layers):
        for head in range(model.cfg.n_heads):
            attn_scores: Float[Tensor, "batch head dest src"] = cache["pattern", layer]
            avg_attn_on_duplicates = attn_scores[:, head, dest_indices, src_indices].mean().item()
            results[layer, head] = avg_attn_on_duplicates

    return results


def plot_early_head_validation_results(seq_len: int = 50, batch: int = 50):
    '''
    Produces a plot that looks like Figure 18 in the paper.
    '''
    head_types = ["duplicate", "prev", "induction"]

    results = torch.stack([
        get_attn_scores(model, seq_len, batch, head_type=head_type)
        for head_type in head_types
    ])

    imshow(
        results,
        facet_col=0,
        facet_labels=[
            f"{head_type.capitalize()} token attention prob.<br>on sequences of random tokens"
            for head_type in head_types
        ],
        labels={"x": "Head", "y": "Layer"},
        width=1300,
    )

model.reset_hooks()
plot_early_head_validation_results()

# Path Patching

Path Patching is an algorithm that can be used to isolate the direct effect of a path (input -> **head** -> output logits) from the indirect effects (input -> MLP -> **head 1** -> head 2 -> output logits) of other components in the Transformer. 

In the IOI paper Path Patching is used as follows: 

- start from the Name Mover Heads with the highest direct effect on logits (from Logit Difference)
- work backwards; which Head influences the Name Mover Heads, then which heads influences those heads and so on
    - note that Path Patching supports receiver heads as Q/K/V as `attn.hook_q/k/v`
- it is important for patching at specific position from the prompt. the question can be framed as: Which Heads influence the Name Movers at position `END`?  
    - the implementation `ioi_utils.path_patching` supports patching at specific positions and it will be throughout this section

In [None]:
# activations and shapes of the cache dictionary
for k, v in ioi_cache.values() and ioi_cache.items():
    print(f"{k}, {v.shape}")

hook_embed, torch.Size([100, 21, 768])
hook_pos_embed, torch.Size([100, 21, 768])
blocks.0.hook_resid_pre, torch.Size([100, 21, 768])
blocks.0.ln1.hook_scale, torch.Size([100, 21, 1])
blocks.0.ln1.hook_normalized, torch.Size([100, 21, 768])
blocks.0.attn.hook_q, torch.Size([100, 21, 12, 64])
blocks.0.attn.hook_k, torch.Size([100, 21, 12, 64])
blocks.0.attn.hook_v, torch.Size([100, 21, 12, 64])
blocks.0.attn.hook_attn_scores, torch.Size([100, 12, 21, 21])
blocks.0.attn.hook_pattern, torch.Size([100, 12, 21, 21])
blocks.0.attn.hook_z, torch.Size([100, 21, 12, 64])
blocks.0.hook_attn_out, torch.Size([100, 21, 768])
blocks.0.hook_resid_mid, torch.Size([100, 21, 768])
blocks.0.ln2.hook_scale, torch.Size([100, 21, 1])
blocks.0.ln2.hook_normalized, torch.Size([100, 21, 768])
blocks.0.mlp.hook_pre, torch.Size([100, 21, 3072])
blocks.0.mlp.hook_post, torch.Size([100, 21, 3072])
blocks.0.hook_mlp_out, torch.Size([100, 21, 768])
blocks.0.hook_resid_post, torch.Size([100, 21, 768])
blocks.1.hook_r

In [80]:
from src.patching.path_patching import get_path_patch_head_to_final_resid_post

abc_logits, abc_cache = model.run_with_cache(abc_dataset.toks)
abc_average_logit_diff = logits_to_ave_logit_diff_2(abc_logits)

def ioi_metric_2(
    logits: Float[Tensor, "batch seq d_vocab"],
    clean_logit_diff: float = ioi_average_logit_diff,
    corrupted_logit_diff: float = abc_average_logit_diff,
    ioi_dataset: IOIDataset = ioi_dataset,
) -> float:
    '''
    We calibrate this so that the value is 0 when performance isn't harmed (i.e. same as IOI dataset),
    and -1 when performance has been destroyed (i.e. is same as ABC dataset).
    '''
    patched_logit_diff = logits_to_ave_logit_diff_2(logits, ioi_dataset)
    return (patched_logit_diff - clean_logit_diff) / (clean_logit_diff - corrupted_logit_diff)


path_patch_head_to_final_resid_post = get_path_patch_head_to_final_resid_post(
    model, 
    ioi_metric_2,
    abc_dataset,
    ioi_dataset,
    abc_cache,
    ioi_cache,
    )

imshow(
    100 * path_patch_head_to_final_resid_post,
    title="Direct effect on logit difference",
    labels={"x":"Head", "y":"Layer", "color": "Logit diff. variation"},
    # coloraxis=dict(colorbar_ticksuffix = "%"),
    width=600,
)

100%|██████████| 144/144 [00:03<00:00, 39.74it/s]


In [81]:
from src.patching.path_patching import get_path_patch_head_to_heads

model.reset_hooks()

## NM heads queries patching
nm_heads_query_path_patching_results = get_path_patch_head_to_heads(
    receiver_heads = [(9, 9), (9, 6), (10, 0)],
    receiver_input = "q",
    model = model,
    patching_metric = ioi_metric_2,
    new_dataset = abc_dataset,
    orig_dataset = ioi_dataset,
    new_cache= abc_cache,
    orig_cache = ioi_cache,
    positions = "end",
)

imshow(
    100 * nm_heads_query_path_patching_results,
    title="Direct effect on Name Mover Heads' queries",
    labels={"x": "Head", "y": "Layer", "color": "Logit diff.<br>variation"},
    width=600,
    # coloraxis=dict(colorbar_ticksuffix = "%"),
)


  0%|          | 0/120 [00:00<?, ?it/s]

100%|██████████| 120/120 [00:04<00:00, 27.12it/s]


In [82]:
model.reset_hooks()

s_inhibition_value_path_patching_results = get_path_patch_head_to_heads(
    receiver_heads = [(8, 6), (8, 10), (7, 9), (7, 3)],
    receiver_input = "v",
    model = model,
    patching_metric = ioi_metric_2,
    new_dataset = abc_dataset,
    orig_dataset = ioi_dataset,
    new_cache= abc_cache,
    orig_cache = ioi_cache,
    positions = "S2",
)

imshow(
    100 * s_inhibition_value_path_patching_results,
    title="Direct effect on S-Inhibition Heads' values",
    labels={"x": "Head", "y": "Layer", "color": "Logit diff.<br>variation"},
    width=600,
    # coloraxis=dict(colorbar_ticksuffix = "%"),
)

  0%|          | 0/96 [00:00<?, ?it/s]

100%|██████████| 96/96 [00:03<00:00, 27.02it/s]


In [83]:
model.reset_hooks()

s_inhibition_key_path_patching_results = get_path_patch_head_to_heads(
    receiver_heads = [(8, 6), (8, 10), (7, 9), (7, 3)],
    receiver_input = "k",
    model = model,
    patching_metric = ioi_metric_2,
    new_dataset = abc_dataset,
    orig_dataset = ioi_dataset,
    new_cache= abc_cache,
    orig_cache = ioi_cache,
    positions = "S2",
)

imshow(
    100 * s_inhibition_key_path_patching_results,
    title="Direct effect on S-Inhibition Heads' keys",
    labels={"x": "Head", "y": "Layer", "color": "Logit diff.<br>variation"},
    width=600,
    # coloraxis=dict(colorbar_ticksuffix = "%"),
)

  0%|          | 0/96 [00:00<?, ?it/s]

100%|██████████| 96/96 [00:03<00:00, 27.05it/s]


In [84]:
model.reset_hooks()

induction_key_path_patching_results = get_path_patch_head_to_heads(
    receiver_heads = [(5, 5), (5, 8), (5, 9), (6, 9)],
    receiver_input = "k",
    model = model,
    patching_metric = ioi_metric_2,
    new_dataset = abc_dataset,
    orig_dataset = ioi_dataset,
    new_cache= abc_cache,
    orig_cache = ioi_cache,
    positions = "S1+1",
)

imshow(
    100 * induction_key_path_patching_results,
    title="Direct effect on Induction Heads' keys",
    labels={"x": "Head", "y": "Layer", "color": "Logit diff.<br>variation"},
    width=600,
    # coloraxis=dict(colorbar_ticksuffix = "%"),
)



  0%|          | 0/72 [00:00<?, ?it/s]

100%|██████████| 72/72 [00:02<00:00, 27.02it/s]


# Circuit Validation

In [87]:
CIRCUIT = {
    "name mover": [(9, 9), (10, 0), (9, 6)],
    "backup name mover": [(10, 10), (10, 6), (10, 2), (10, 1), (11, 2), (9, 7), (9, 0), (11, 9)],
    "negative name mover": [(10, 7), (11, 10)],
    "s2 inhibition": [(7, 3), (7, 9), (8, 6), (8, 10)],
    "induction": [(5, 5), (5, 8), (5, 9), (6, 9)],
    "duplicate token": [(0, 1), (0, 10), (3, 0)],
    "previous token": [(2, 2), (4, 11)],
}

SEQ_POS_TO_KEEP = {
    "name mover": "end",
    "backup name mover": "end",
    "negative name mover": "end",
    "s2 inhibition": "end",
    "induction": "S2",
    "duplicate token": "S2",
    "previous token": "S1+1",
}

To be clear, the things that we'll be mean-ablating are:

- Every head not in the `CIRCUIT` dict

- Every sequence position for the heads in `CIRCUIT` dict, except for the sequence positions given by the `SEQ_POS_TO_KEEP` dict

And we'll be mean-ablating by replacing a head's output with the mean output for abc_dataset, over all sentences with the same template as the sentence in the batch. You can access the templates for a dataset using the `dataset.groups` attribute, which returns a list of tensors (each one containing the indices of sequences in the batch sharing the same template).

In [88]:
def get_heads_and_posns_to_keep(
    means_dataset: IOIDataset,
    model: HookedTransformer,
    circuit: Dict[str, List[Tuple[int, int]]],
    seq_pos_to_keep: Dict[str, str],
) -> Dict[int, Bool[Tensor, "batch seq head"]]:
    '''
    Returns a dictionary mapping layers to a boolean mask giving the indices of the
    z output which *shouldn'torch* be mean-ablated.

    The output of this function will be used for the hook function that does ablation.
    '''
    heads_and_posns_to_keep = {}
    batch, seq, n_heads = len(means_dataset), means_dataset.max_len, model.cfg.n_heads

    for layer in range(model.cfg.n_layers):

        mask = torch.zeros(size=(batch, seq, n_heads))

        for (head_type, head_list) in circuit.items():
            seq_pos = seq_pos_to_keep[head_type]
            indices = means_dataset.word_idx[seq_pos]
            for (layer_idx, head_idx) in head_list:
                if layer_idx == layer:
                    mask[:, indices, head_idx] = 1

        heads_and_posns_to_keep[layer] = mask.bool()

    return heads_and_posns_to_keep



def hook_fn_mask_z(
    z: Float[Tensor, "batch seq head d_head"],
    hook: HookPoint,
    heads_and_posns_to_keep: Dict[int, Bool[Tensor, "batch seq head"]],
    means: Float[Tensor, "layer batch seq head d_head"],
) -> Float[Tensor, "batch seq head d_head"]:
    '''
    Hook function which masks the z output of a transformer head.

    heads_and_posns_to_keep
        Dict created with the get_heads_and_posns_to_keep function. This tells
        us where to mask.

    means
        Tensor of mean z values of the means_dataset over each group of prompts
        with the same template. This tells us what values to mask with.
    '''
    # Get the mask for this layer, and add d_head=1 dimension so it broadcasts correctly
    mask_for_this_layer = heads_and_posns_to_keep[hook.layer()].unsqueeze(-1).to(z.device)

    # Set z values to the mean
    z = torch.where(mask_for_this_layer, z, means[hook.layer()])

    return z


def compute_means_by_template(
    means_dataset: IOIDataset,
    model: HookedTransformer
) -> Float[Tensor, "layer batch seq head_idx d_head"]:
    '''
    Returns the mean of each head's output over the means dataset. This mean is
    computed separately for each group of prompts with the same template (these
    are given by means_dataset.groups).
    '''
    # Cache the outputs of every head
    _, means_cache = model.run_with_cache(
        means_dataset.toks.long(),
        return_type=None,
        names_filter=lambda name: name.endswith("z"),
    )
    # Create tensor to store means
    n_layers, n_heads, d_head = model.cfg.n_layers, model.cfg.n_heads, model.cfg.d_head
    batch, seq_len = len(means_dataset), means_dataset.max_len
    means = torch.zeros(size=(n_layers, batch, seq_len, n_heads, d_head), device=model.cfg.device)

    # Get set of different templates for this data
    for layer in range(model.cfg.n_layers):
        z_for_this_layer: Float[Tensor, "batch seq head d_head"] = means_cache[utils.get_act_name("z", layer)]
        for template_group in means_dataset.groups:
            z_for_this_template = z_for_this_layer[template_group]
            z_means_for_this_template = einops.reduce(z_for_this_template, "batch seq head d_head -> seq head d_head", "mean")
            means[layer, template_group] = z_means_for_this_template

    return means

def add_mean_ablation_hook(
    model: HookedTransformer,
    means_dataset: IOIDataset,
    circuit: Dict[str, List[Tuple[int, int]]] = CIRCUIT,
    seq_pos_to_keep: Dict[str, str] = SEQ_POS_TO_KEEP,
    is_permanent: bool = True,
) -> HookedTransformer:
    '''
    Adds a permanent hook to the model, which ablates according to the circuit and
    seq_pos_to_keep dictionaries.

    In other words, when the model is run on ioi_dataset, every head's output will
    be replaced with the mean over means_dataset for sequences with the same template,
    except for a subset of heads and sequence positions as specified by the circuit
    and seq_pos_to_keep dicts.
    '''

    model.reset_hooks(including_permanent=True)

    # Compute the mean of each head's output on the ABC dataset, grouped by template
    means = compute_means_by_template(means_dataset, model)

    # Convert this into a boolean map
    heads_and_posns_to_keep = get_heads_and_posns_to_keep(means_dataset, model, circuit, seq_pos_to_keep)

    # Get a hook function which will patch in the mean z values for each head, at
    # all positions which aren't  important for the circuit
    hook_fn = partial(
        hook_fn_mask_z,
        heads_and_posns_to_keep=heads_and_posns_to_keep,
        means=means
    )

    # Apply hook
    model.add_hook(lambda name: name.endswith("z"), hook_fn, is_permanent=is_permanent)

    return model

model = add_mean_ablation_hook(model, means_dataset=abc_dataset, circuit=CIRCUIT, seq_pos_to_keep=SEQ_POS_TO_KEEP)

ioi_logits_minimal = model(ioi_dataset.toks)

print(f"Average logit difference (IOI dataset, using entire model): {logits_to_ave_logit_diff_2(ioi_logits_original):.4f}")
print(f"Average logit difference (IOI dataset, only using circuit): {logits_to_ave_logit_diff_2(ioi_logits_minimal):.4f}")

Average logit difference (IOI dataset, using entire model): 3.1490
Average logit difference (IOI dataset, only using circuit): 2.8226


We see that the model cannot generate text, idk why yet.

In [115]:
model.generate("(CNN) President Barack Obama caught in embarrassing new scandal\n", max_new_tokens=20, prepend_bos=True)

  0%|          | 0/20 [00:00<?, ?it/s]

RuntimeError: The size of tensor a (21) must match the size of tensor b (13) at non-singleton dimension 1

In [90]:
CIRCUIT

{'name mover': [(9, 9), (10, 0), (9, 6)],
 'backup name mover': [(10, 10),
  (10, 6),
  (10, 2),
  (10, 1),
  (11, 2),
  (9, 7),
  (9, 0),
  (11, 9)],
 'negative name mover': [(10, 7), (11, 10)],
 's2 inhibition': [(7, 3), (7, 9), (8, 6), (8, 10)],
 'induction': [(5, 5), (5, 8), (5, 9), (6, 9)],
 'duplicate token': [(0, 1), (0, 10), (3, 0)],
 'previous token': [(2, 2), (4, 11)]}

In [91]:
K_FOR_EACH_COMPONENT = {
    (9, 9): set(),
    (10, 0): {(9, 9)},
    (9, 6): {(9, 9), (10, 0)},
    (10, 7): {(11, 10)},
    (11, 10): {(10, 7)},
    (8, 10): {(7, 9), (8, 6), (7, 3)},
    (7, 9): {(8, 10), (8, 6), (7, 3)},
    (8, 6): {(7, 9), (8, 10), (7, 3)},
    (7, 3): {(7, 9), (8, 10), (8, 6)},
    (5, 5): {(5, 9), (6, 9), (5, 8)},
    (5, 9): {(11, 10), (10, 7)},
    (6, 9): {(5, 9), (5, 5), (5, 8)},
    (5, 8): {(11, 10), (10, 7)},
    (0, 1): {(0, 10), (3, 0)},
    (0, 10): {(0, 1), (3, 0)},
    (3, 0): {(0, 1), (0, 10)},
    (4, 11): {(2, 2)},
    (2, 2): {(4, 11)},
    (11, 2): {(9, 9), (10, 0), (9, 6)},
    (10, 6): {(9, 9), (10, 0), (9, 6), (11, 2)},
    (10, 10): {(9, 9), (10, 0), (9, 6), (11, 2), (10, 6)},
    (10, 2): {(9, 9), (10, 0), (9, 6), (11, 2), (10, 6), (10, 10)},
    (9, 7): {(9, 9), (10, 0), (9, 6), (11, 2), (10, 6), (10, 10), (10, 2)},
    (10, 1): {(9, 9), (10, 0), (9, 6), (11, 2), (10, 6), (10, 10), (10, 2), (9, 7)},
    (11, 9): {(9, 9), (10, 0), (9, 6), (9, 0)},
    (9, 0): {(9, 9), (10, 0), (9, 6), (11, 9)},
}

In [92]:
def get_score(
    model: HookedTransformer,
    ioi_dataset: IOIDataset,
    abc_dataset: IOIDataset,
    K: Set[Tuple[int, int]],
    C: Dict[str, List[Tuple[int, int]]],
) -> float:
    '''
    Returns the value F(C \ K), where F is the logit diff, C is the
    core circuit, and K is the set of circuit components to remove.
    '''
    C_excl_K = {k: [head for head in v if head not in K] for k, v in C.items()}
    model = add_mean_ablation_hook(model, abc_dataset, C_excl_K, SEQ_POS_TO_KEEP)
    logits = model(ioi_dataset.toks.long())
    score = logits_to_ave_logit_diff_2(logits, ioi_dataset).item()

    return score


def get_minimality_score(
    model: HookedTransformer,
    ioi_dataset: IOIDataset,
    abc_dataset: IOIDataset,
    v: Tuple[int, int],
    K: Set[Tuple[int, int]],
    C: Dict[str, List[Tuple[int, int]]] = CIRCUIT,
) -> float:
    '''
    Returns the value | F(C \ K_union_v) - F(C | K) |, where F is
    the logit diff, C is the core circuit, K is the set of circuit
    components to remove, and v is a head (not in K).
    '''
    assert v not in K
    K_union_v = K | {v}
    C_excl_K_score = get_score(model, ioi_dataset, abc_dataset, K, C)
    C_excl_Kv_score = get_score(model, ioi_dataset, abc_dataset, K_union_v, C)

    return abs(C_excl_K_score - C_excl_Kv_score)


def get_all_minimality_scores(
    model: HookedTransformer,
    ioi_dataset: IOIDataset = ioi_dataset,
    abc_dataset: IOIDataset = abc_dataset,
    k_for_each_component: Dict = K_FOR_EACH_COMPONENT
) -> Dict[Tuple[int, int], float]:
    '''
    Returns dict of minimality scores for every head in the model (as
    a fraction of F(M), the logit diff of the full model).

    Warning - this resets all hooks at the end (including permanent).
    '''
    # Get full circuit score F(M), to divide minimality scores by
    model.reset_hooks(including_permanent=True)
    logits = model(ioi_dataset.toks.long())
    full_circuit_score = logits_to_ave_logit_diff_2(logits, ioi_dataset).item()

    # Get all minimality scores, using the `get_minimality_score` function
    minimality_scores = {}
    for v, K in tqdm(k_for_each_component.items()):
        score = get_minimality_score(model, ioi_dataset, abc_dataset, v, K)
        minimality_scores[v] = score / full_circuit_score

    model.reset_hooks(including_permanent=True)

    return minimality_scores

In [93]:
minimality_scores = get_all_minimality_scores(model)
print(minimality_scores)

100%|██████████| 26/26 [01:07<00:00,  2.61s/it]

{(9, 9): 0.10926508694524731, (10, 0): 0.07147942254803424, (9, 6): 0.08601778085671401, (10, 7): 0.3815147987681731, (11, 10): 0.3110908708909924, (8, 10): 0.21769024892063726, (7, 9): 0.2174503190295294, (8, 6): 0.20810463714294908, (7, 3): 0.05832533630122596, (5, 5): 0.4074826100041452, (5, 9): 0.23873539003749614, (6, 9): 0.206406483936849, (5, 8): 0.11590877515752733, (0, 1): 0.44827290605522024, (0, 10): 0.4048859500190225, (3, 0): 0.1371713976919335, (4, 11): 0.34689054570989986, (2, 2): 0.14237567690854613, (11, 2): 0.011904883584033191, (10, 6): 0.09304880934129062, (10, 10): 0.11365264678102896, (10, 2): 0.10673927403983563, (9, 7): 0.045173975670094506, (10, 1): 0.0928292837119857, (11, 9): 0.006304046214328032, (9, 0): 0.00504201044440785}





In [94]:
def plot_minimal_set_results(minimality_scores: Dict[Tuple[int, int], float]):
    '''
    Plots the minimality results, in a way resembling figure 7 in the paper.

    minimality_scores:
        Dict with elements like (9, 9): minimality score for head 9.9 (as described
        in section 4.2 of the paper)
    '''

    CIRCUIT_reversed = {head: k for k, v in CIRCUIT.items() for head in v}
    colors = [CIRCUIT_reversed[head].capitalize() + " head" for head in minimality_scores.keys()]
    color_sequence = [px.colors.qualitative.Dark2[i] for i in [0, 1, 2, 5, 3, 6]] + ["#BAEA84"]

    bar(
        list(minimality_scores.values()),
        x=list(map(str, minimality_scores.keys())),
        labels={"x": "Attention head", "y": "Change in logit diff", "color": "Head type"},
        color=colors,
        template="ggplot2",
        color_discrete_sequence=color_sequence,
        bargap=0.02,
        yaxis_tickformat=".0%",
        legend_title_text="",
        title="Plot of minimality scores (as percentages of full model logit diff)",
        width=800,
        hovermode="x unified"
    )

plot_minimal_set_results(minimality_scores)