<a href="https://colab.research.google.com/github/andyrdt/mi/blob/main/ARENA/monthly_algorithmic_problems/08_2023/First_Unique_Character.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Monthly Algorithmic Challenge (August 2023): First Unique Character

This post is the second in the sequence of monthly mechanistic interpretability challenges. They are designed in the spirit of [Stephen Casper's challenges](https://www.lesswrong.com/posts/KSHqLzQscwJnv44T8/eis-vii-a-challenge-for-mechanists), but with the more specific aim of working well in the context of the rest of the ARENA material, and helping people put into practice all the things they've learned so far.

If you prefer, you can access the Streamlit page [here](https://arena-ch1-transformers.streamlit.app/Monthly_Algorithmic_Problems).

<img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/writer.png" width="350">

## Setup

In [1]:
try:
    import google.colab # type: ignore
    IN_COLAB = True
except:
    IN_COLAB = False

import os; os.environ["ACCELERATE_DISABLE_RICH"] = "1"
import sys

if IN_COLAB:
    # Install packages
    %pip install einops
    %pip install jaxtyping
    %pip install transformer_lens
    %pip install git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python

    # Code to download the necessary files (e.g. solutions, test funcs)
    import os, sys
    if not os.path.exists("chapter1_transformers"):
        !curl -o /content/main.zip https://codeload.github.com/callummcdougall/ARENA_2.0/zip/refs/heads/main
        !unzip /content/main.zip 'ARENA_2.0-main/chapter1_transformers/exercises/*'
        sys.path.append("/content/ARENA_2.0-main/chapter1_transformers/exercises")
        os.remove("/content/main.zip")
        os.rename("ARENA_2.0-main/chapter1_transformers", "chapter1_transformers")
        os.rmdir("ARENA_2.0-main")
        os.chdir("chapter1_transformers/exercises")
else:
    from IPython import get_ipython
    ipython = get_ipython()
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")


In [2]:
import torch as t
from pathlib import Path

# Make sure exercises are in the path
chapter = r"chapter1_transformers"
exercises_dir = Path(f"{os.getcwd().split(chapter)[0]}/{chapter}/exercises").resolve()
section_dir = exercises_dir / "monthly_algorithmic_problems" / "august23_unique_char"
if str(exercises_dir) not in sys.path: sys.path.append(str(exercises_dir))

from monthly_algorithmic_problems.august23_unique_char.dataset import UniqueCharDataset
from monthly_algorithmic_problems.august23_unique_char.model import create_model
from plotly_utils import hist, bar, imshow

device = t.device("cuda" if t.cuda.is_available() else "cpu")

In [3]:
# Andy's imports

import einops
import circuitsvis as cv
from transformer_lens import utils, HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
from transformer_lens.components import LayerNorm
from transformer_lens.hook_points import HookPoint

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression

import torch.nn.functional as F

from jaxtyping import Float, Int, Bool
from torch import Tensor
from typing import List, Tuple, Union, Optional, Callable, Dict

import functools
from tqdm import tqdm
from rich.table import Table, Column
from rich import print as rprint




## Prerequisites

The following ARENA material should be considered essential:

* **[1.1] Transformer from scratch** (sections 1-3)
* **[1.2] Intro to Mech Interp** (sections 1-3)

The following material isn't essential, but is very strongly recommended:

* **[1.2] Intro to Mech Interp** (section 4)
* **[1.4] Balanced Bracket Classifier** (all sections)


## Difficulty

This problem is of roughly comparable difficulty to the July problem. The algorithmic problem is of a similar flavour, and the model architecture is very similar (the main difference is that this model has 3 attention heads per layer, instead of 2). I've done this because this problem is the first I'm also crossposting to LessWrong, and I want it to be reasonably accessible. The next problem in this sequence will probably be a step up in difficulty.


## Motivation

Neel Nanda's post [200 COP in MI: Interpreting Algorithmic Problems](https://www.lesswrong.com/posts/ejtFsvyhRkMofKAFy/200-cop-in-mi-interpreting-algorithmic-problems) does a good job explaining the motivation behind solving algorithmic problems such as these. I'd strongly recommend reading the whole post, because it also gives some high-level advice for approaching such problems.

The main purpose of these challenges isn't to break new ground in mech interp; rather, they're designed to help you practice using and develop a better understanding of standard MI tools (e.g. interpreting attention, direct logit attribution), and more generally to get comfortable working with libraries like TransformerLens.

Also, they're hopefully pretty fun, because why shouldn't we have some fun while we're learning?

## Logistics

The solution to this problem will be published on this page in the first few days of September, at the same time as the next problem in the sequence. There will also be an associated LessWrong post.

If you try to interpret this model, you can send your attempt in any of the following formats:

* Colab notebook,
* GitHub repo (e.g. with ipynb or markdown file explaining results),
* Google Doc (with screenshots and explanations),
* or any other sensible format.

You can send your attempt to me (Callum McDougall) via any of the following methods:

* The [Slack group](https://join.slack.com/t/arena-la82367/shared_invite/zt-1uvoagohe-JUv9xB7Vr143pdx1UBPrzQ), via a direct message to me
* My personal email: `cal.s.mcdougall@gmail.com`
* LessWrong message ([here](https://www.lesswrong.com/users/themcdouglas) is my user)

**I'll feature the names of everyone who sends me a solution on this website, and also give a shout out to the best solutions.** It's possible that future challenges will also feature a monetary prize, but this is not guaranteed.

Please don't discuss specific things you've found about this model until the challenge is over (although you can discuss general strategies and techniques, and you're also welcome to work in a group if you'd like). The deadline for this problem will be the end of this month, i.e. 31st August.

## What counts as a solution?

Going through the solutions for the previous problem in the sequence (July: Palindromes) as well as the exercises in **[1.4] Balanced Bracket Classifier** should give you a good idea of what I'm looking for. In particular, I'd expect you to:

* Describe a mechanism for how the model solves the task, in the form of the QK and OV circuits of various attention heads (and possibly any other mechanisms the model uses, e.g. the direct path, or nonlinear effects from layernorm),
* Provide evidence for your mechanism, e.g. with tools like attention plots, targeted ablation / patching, or direct logit attribution.
* (Optional) Include additional detail, e.g. identifying the subspaces that the model uses for certain forms of information transmission, or using your understanding of the model's behaviour to construct adversarial examples.

## Task & Dataset

The algorithmic task is as follows: the model is presented with a sequence of characters, and for each character it has to correctly identify the first character in the sequence (up to and including the current character) which is unique up to that point.

The null character `"?"` has two purposes:

* In the input, it's used as the start character (because it's often helpful for interp to have a constant start character, to act as a "rest position").
* In the output, it's also used as the start character, **and** to represent the classification "no unique character exists".

Here is an example of what this dataset looks like:

```python
dataset = UniqueCharDataset(size=2, vocab=list("abc"), seq_len=6, seed=42)

for seq, first_unique_char_seq in zip(dataset.str_toks, dataset.str_tok_labels):
    print(f"Seq = {''.join(seq)}, Target = {''.join(first_unique_char_seq)}")
```

<div style='font-family:monospace;'>
Seq = ?acbba, Target = ?aaaac<br>
Seq = ?cbcbc, Target = ?ccb??
</div><br>

Explanation:

1. In the first sequence, `"a"` is unique in the prefix substring `"acbb"`, but it repeats at the 5th sequence position, meaning the final target character is `"c"` (which appears second in the sequence).
2. In the second sequence, `"c"` is unique in the prefix substring `"cb"`, then it repeats so `"b"` is the new first unique token, and for the last 2 positions there are no unique characters (since both `"b"` and `"c"` have been repeated) so the correct classification is `"?"` (the "null character").
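
The labelling rule can be written in a few lines of plain Python. The sketch below is a reference implementation for illustration only (the function name `first_unique_labels` is hypothetical and is not part of `dataset.py`); it reproduces the two targets shown above.

```python
from collections import Counter

def first_unique_labels(seq: str, null_char: str = "?") -> str:
    """Return the target string for `seq`, which is assumed to start with the null character."""
    counts = Counter()
    order = []             # characters in order of first appearance
    labels = [null_char]   # position 0 is always labelled with the null character
    for ch in seq[1:]:
        if counts[ch] == 0:
            order.append(ch)
        counts[ch] += 1
        # First character (by first appearance) that has been seen exactly once so far;
        # fall back to the null character if every character has repeated.
        labels.append(next((c for c in order if counts[c] == 1), null_char))
    return "".join(labels)

print(first_unique_labels("?acbba"))  # ?aaaac
print(first_unique_labels("?cbcbc"))  # ?ccb??
```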

The relevant files can be found in local storage (after you run the setup code at the top of this notebook), at:

```
chapter1_transformers/
└── exercises/
    └── monthly_algorithmic_problems/
        └── august23_unique_char/
            ├── model.py               # code to create the model
            ├── dataset.py             # code to define the dataset
            ├── training.py            # code to train the model
            └── training_model.ipynb   # actual training script
```

We've given you the class `UniqueCharDataset` to store your data, as you can see above. You can slice this object to get batches of tokens and labels (e.g. `dataset[:5]` returns a length-2 tuple, containing the 2D tensors representing the tokens and correct labels respectively). You can also use `dataset.toks` or `dataset.labels` to access these tensors directly, or `dataset.str_toks` and `dataset.str_tok_labels` to get the string representations of the tokens and labels (like we did in the code above).

## Model

Our model was trained by minimising cross-entropy loss between its predictions and the true labels, at every sequence position simultaneously (including the zeroth sequence position, which is trivial because the input and target are both always `"?"`). You can inspect the notebook `training_model.ipynb` to see how it was trained. I used the version of the model which achieved highest accuracy over 40 epochs.



The model is a 2-layer transformer with 3 attention heads per layer and causal attention. It includes layernorm, but no MLP layers. You can load it in as follows:

In [4]:
filename = section_dir / "first_unique_char_model.pt"

model = create_model(
    seq_len=20,
    vocab=list("abcdefghij"),
    seed=42,
    d_model=42,
    d_head=14,
    n_layers=2,
    n_heads=3,
    normalization_type="LN",
    d_mlp=None # attn-only model
)

state_dict = t.load(filename)

state_dict = model.center_writing_weights(state_dict)
state_dict = model.center_unembed(state_dict)
state_dict = model.fold_layer_norm(state_dict)
state_dict = model.fold_value_biases(state_dict)
model.load_state_dict(state_dict, strict=False);

The code to process the state dictionary is a bit messy, but it's necessary to make sure the model is easy to work with. For instance, if you inspect the model's parameters, you'll see that `model.ln_final.w` is a vector of 1s, and `model.ln_final.b` is a vector of 0s (because the weight and bias have been folded into the unembedding).

In [5]:
print("ln_final weight: ", model.ln_final.w)
print("\nln_final, bias: ", model.ln_final.b)

ln_final weight:  Parameter containing:
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1.], device='cuda:0', requires_grad=True)

ln_final, bias:  Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       device='cuda:0', requires_grad=True)
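
To see why folding leaves the model's behaviour unchanged, here is a minimal pure-Python sketch (not the TransformerLens implementation). A LayerNorm with weight `w` and bias `b`, followed by a linear map `W, c`, gives the same output as a weightless LayerNorm followed by `W' = diag(w) W` and `c' = b W + c`. All names and numbers below are made up for illustration.

```python
import math

def normalize(x, eps=1e-5):
    """LayerNorm without learned weight or bias: centre, then divide by the std."""
    mean = sum(x) / len(x)
    var = sum((xi - mean) ** 2 for xi in x) / len(x)
    return [(xi - mean) / math.sqrt(var + eps) for xi in x]

def linear(x, W, c):
    """Compute x @ W + c for a vector x and a matrix W given as a list of rows."""
    return [sum(x[i] * W[i][j] for i in range(len(x))) + c[j] for j in range(len(c))]

# Toy parameters (illustrative values only)
x = [0.5, -1.0, 2.0]
w = [1.5, 0.5, -2.0]                       # LayerNorm weight
b = [0.1, -0.2, 0.3]                       # LayerNorm bias
W = [[1.0, 2.0], [0.0, -1.0], [3.0, 0.5]]  # the linear layer that reads the LayerNorm output
c = [0.25, -0.75]

# Unfolded: full LayerNorm, then the linear layer
ln_out = [n * wi + bi for n, wi, bi in zip(normalize(x), w, b)]
unfolded = linear(ln_out, W, c)

# Folded: absorb w and b into the linear layer, leaving weight 1 and bias 0 in the LayerNorm
W_folded = [[w[i] * W[i][j] for j in range(len(c))] for i in range(len(w))]
c_folded = linear(b, W, c)
folded = linear(normalize(x), W_folded, c_folded)

assert all(abs(u - f) < 1e-9 for u, f in zip(unfolded, folded))
```

After folding, the only LayerNorm behaviour left to account for is the centring and the division by the standard deviation, which is what makes the "constant scale" approximation later in this notebook possible.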


<details>
<summary>Aside - the other weight processing parameters</summary>

Here's some more code to verify that our weights processing worked, in other words:

* The unembedding matrix has mean zero over both its input dimension (`d_model`) and output dimension (`d_vocab`)
* All writing weights (i.e. `b_O`, `W_O`, and both embeddings) have mean zero over their output dimension (`d_model`)
* The value biases `b_V` are zero (because these can just be folded into the output biases `b_O`)

```python
W_U_mean_over_input = einops.reduce(model.W_U, "d_model d_vocab -> d_model", "mean")
t.testing.assert_close(W_U_mean_over_input, t.zeros_like(W_U_mean_over_input))

W_U_mean_over_output = einops.reduce(model.W_U, "d_model d_vocab -> d_vocab", "mean")
t.testing.assert_close(W_U_mean_over_output, t.zeros_like(W_U_mean_over_output))

W_O_mean_over_output = einops.reduce(model.W_O, "layer head d_head d_model -> layer head d_head", "mean")
t.testing.assert_close(W_O_mean_over_output, t.zeros_like(W_O_mean_over_output))

b_O_mean_over_output = einops.reduce(model.b_O, "layer d_model -> layer", "mean")
t.testing.assert_close(b_O_mean_over_output, t.zeros_like(b_O_mean_over_output))

W_E_mean_over_output = einops.reduce(model.W_E, "token d_model -> token", "mean")
t.testing.assert_close(W_E_mean_over_output, t.zeros_like(W_E_mean_over_output))

W_pos_mean_over_output = einops.reduce(model.W_pos, "position d_model -> position", "mean")
t.testing.assert_close(W_pos_mean_over_output, t.zeros_like(W_pos_mean_over_output))

b_V = model.b_V
t.testing.assert_close(b_V, t.zeros_like(b_V))
```
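
The last bullet relies on the fact that each row of an attention pattern sums to one. As a short derivation for a single head (with $A_{qk}$ denoting the attention paid by query position $q$ to key position $k$, and $x_k$ the normalized input at position $k$; this notation is introduced here for illustration):

$$\sum_k A_{qk} \left( x_k W_V + b_V \right) W_O + b_O = \sum_k A_{qk} \, x_k W_V W_O + \left( b_V W_O + b_O \right), \qquad \text{since } \sum_k A_{qk} = 1$$

So the value bias only ever contributes the constant vector $b_V W_O$, which can be absorbed into the output bias. With several heads per layer, the per-head constants are summed into the layer's shared $b_O$.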

</details>

The model's output is a logit tensor, of shape `(batch_size, seq_len, d_vocab+1)`. The `[i, j, :]`-th element of this tensor is the logit distribution for the label at position `j` in the `i`-th sequence in the batch. The first `d_vocab` elements of this tensor correspond to the elements in the vocabulary, and the last element corresponds to the null character `"?"` (which is not in the input vocab).

A demonstration of the model working:


In [6]:
dataset = UniqueCharDataset(size=1000, vocab=list("abcdefghij"), seq_len=20, seed=42)

logits, cache = model.run_with_cache(dataset.toks)

logprobs = logits.log_softmax(-1) # [batch seq_len d_vocab]
probs = logprobs.softmax(-1) # [batch seq_len d_vocab]

batch_size, seq_len = dataset.toks.shape
logprobs_correct = logprobs[t.arange(batch_size)[:, None], t.arange(seq_len)[None, :], dataset.labels] # [batch seq_len]
probs_correct = probs[t.arange(batch_size)[:, None], t.arange(seq_len)[None, :], dataset.labels] # [batch seq_len]

avg_cross_entropy_loss = -logprobs_correct.mean().item()
avg_correct_prob = probs_correct.mean().item()
min_correct_prob = probs_correct.min().item()

print(f"Average cross entropy loss: {avg_cross_entropy_loss:.3f}")
print(f"Average probability on correct label: {avg_correct_prob:.3f}")
print(f"Min probability on correct label: {min_correct_prob:.3f}")

Average cross entropy loss: 0.017
Average probability on correct label: 0.988
Min probability on correct label: 0.001
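
A note on the line `probs = logprobs.softmax(-1)` in the cell above: taking a softmax of log-probabilities gives the same result as taking a softmax of the raw logits, because `log_softmax` only subtracts a constant from each row. The pure-Python sketch below checks this on made-up numbers (the helper functions are written here for illustration and are not the PyTorch implementations).

```python
import math

def log_softmax(xs):
    """Numerically stable log-softmax of a list of floats."""
    m = max(xs)
    log_z = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - log_z for x in xs]

def softmax(xs):
    """Numerically stable softmax of a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    z = sum(exps)
    return [e / z for e in exps]

toy_logits = [2.0, -1.0, 0.5]  # illustrative values only

probs_from_logits = softmax(toy_logits)
probs_from_logprobs = softmax(log_softmax(toy_logits))
probs_from_exp = [math.exp(lp) for lp in log_softmax(toy_logits)]

# All three routes agree up to floating-point error
for a, b, c in zip(probs_from_logits, probs_from_logprobs, probs_from_exp):
    assert abs(a - b) < 1e-12 and abs(a - c) < 1e-12
```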


And a visualisation of its probability output for a single sequence:

In [7]:
imshow(
    probs[0].T,
    y=dataset.vocab,
    x=[f"{dataset.str_toks[0][i]}<br>({i})" for i in range(model.cfg.n_ctx)],
    labels={"x": "Token", "y": "Vocab"},
    xaxis_tickangle=0,
    title="Sample model probabilities (for batch idx = 0), with correct classification highlighted",
    text=[
        ["〇" if str_tok == correct_str_tok else "" for correct_str_tok in dataset.str_tok_labels[0]]
        for str_tok in dataset.vocab
    ], # text can be a 2D list of lists, with the same shape as the data
)

If you want some guidance on how to get started, I'd recommend reading the solutions for the July problem - I expect there to be a lot of overlap in the best way to tackle these two problems. You can also reuse some of that code!


Note - although this model was trained for long enough to get loss close to zero (you can test this for yourself), it's not perfect. There are some weaknesses that the model has which might make it vulnerable to adversarial examples, and I've decided to leave these in. The model is still very good at its intended task, and the main focus of this challenge is on figuring out how it solves the task, not dissecting the situations where it fails. However, you might find that the adversarial examples help you understand the model better.


Best of luck! 🎈

# Andy's work starts here

## Chores

Feel free to skip this section. Just implementing some helper functions.

### Estimating LayerNorm

We examine the scale factor of each of the model's 3 LayerNorms (`ln1` in each layer, plus `ln_final`).

From this point forward, we will approximate each LayerNorm operation as a constant scaling.

In [8]:
dataset = UniqueCharDataset(size=1000, vocab=list("abcdefghij"), seq_len=20, seed=42)
logits, cache = model.run_with_cache(dataset.toks)

scale_ln1_layer0 = cache["scale", 0, "ln1"][:, :, 0, 0] # shape (batch, seq)
scale_ln1_layer1 = cache["scale", 1, "ln1"][:, :, 0, 0] # shape (batch, seq)
scale_lnfinal = cache["scale"][:, :, 0] # shape (batch, seq)

for scale, label in zip(
    [scale_ln1_layer0, scale_ln1_layer1, scale_lnfinal],
    ["ln1, layer 0", "ln1, layer 1", "lnfinal"]):

    df = pd.DataFrame({
        "std": scale.std(0).cpu().numpy(),
        "mean": scale.mean(0).cpu().numpy(),
    })

    display(
        px.bar(
            df,
            title=f"Mean & std of LayerNorm scale ({label})",
            template="simple_white", width=450, height=300, barmode="group"
        )
    )
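
One way to judge whether a constant is a fair stand-in for the LayerNorm scale is the ratio std / mean at each position: the closer it is to zero, the less the scale varies across examples. Below is a small pure-Python sketch on made-up values (the helper `relative_spread` and the numbers are hypothetical and are not taken from the model).

```python
from statistics import mean, pstdev

def relative_spread(scales_by_example):
    """For each sequence position, return std / mean of the LayerNorm scale across examples.

    `scales_by_example` is a list of per-example lists, with one scale value per position.
    """
    per_position = zip(*scales_by_example)  # transpose to (position, example)
    return [pstdev(col) / mean(col) for col in per_position]

# Illustrative values only: 3 examples, 4 positions
toy_scales = [
    [1.00, 0.52, 0.50, 0.49],
    [1.00, 0.50, 0.51, 0.50],
    [1.00, 0.48, 0.49, 0.51],
]
spread = relative_spread(toy_scales)
print([round(s, 3) for s in spread])
```

Note that `pstdev` is the population standard deviation, whereas `Tensor.std` in PyTorch applies Bessel's correction by default, so the two differ slightly for small samples.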

### Some helper functions

We define some helper functions that aid with later analysis.

In [9]:
def get_out_by_components(model: HookedTransformer, toks: Int[Tensor, "batch seq"]) -> Float[Tensor, "component batch seq d_model"]:
    '''
    Computes a tensor of shape [8, dataset_size, seq_pos, d_model] representing the output of the model's components when run on the data.
    The first dimension is  [emb, pos_emb, head 0.0, head 0.1, head 0.2, head 1.0, head 1.1, head 1.2]
    '''
    logits, cache = model.run_with_cache(toks)

    emb = cache[utils.get_act_name('embed')].unsqueeze(0)
    pos_emb = cache[utils.get_act_name('pos_embed')].unsqueeze(0)

    head_0 = einops.rearrange(cache[utils.get_act_name('result', 0)], 'b s h d -> h b s d')
    head_1 = einops.rearrange(cache[utils.get_act_name('result', 1)], 'b s h d -> h b s d')

    out = t.cat((emb, pos_emb, head_0, head_1), dim=0)
    return out

In [10]:
def plot_logit_att_by_component(logit_contributions: Float[Tensor, "n seq d_vocab"],
                                vocab: List[str],
                                component_labels: List[str],
                                toks: List[int],
                                str_toks: List[str],
                                labels: List[int],
                                title: str,
                                height: int = 600,
                                width: int = 800) -> go.Figure:
    '''
    Plots the logit attributions for each component.
    Returns a 2x1 subplot, where the first plot allows user to select components
    via a dropdown menu, and the second plot displays the sum of logits over all components.
    '''
    fig = make_subplots(rows=2, cols=1, subplot_titles=("Selected component", "Sum of components"))

    zmax = float(t.abs(logit_contributions).max())
    n_components = logit_contributions.shape[0]

    for component in range(n_components):
        fig.add_trace(
            go.Heatmap(
                z=logit_contributions[component].T.detach().cpu().numpy(),
                colorscale="RdBu",
                zmin=-zmax, zmid=0, zmax=zmax,
                y=vocab,
                x=[f"{str_toks[i]}({i})" for i in range(len(str_toks))],
                visible=(component==0)
            ),
            row=1, col=1
        )

    fig.add_trace(
        go.Heatmap(
            z=logit_contributions.sum(axis=0).T.detach().cpu().numpy(),
            colorscale="RdBu",
            zmin=-zmax, zmid=0, zmax=zmax,
            y=vocab,
            x=[f"{str_toks[i]}({i})" for i in range(len(str_toks))],
            visible=True,
        ),
        row=2, col=1
    )

    for r in [1, 2]:
        fig.update_xaxes(title_text="Position", row=r, col=1)
        fig.update_yaxes(title_text="Output logit", row=r, col=1)
        fig.update_yaxes(autorange="reversed", row=r, col=1)

    # Mark the correct label at each position, using the `vocab` argument rather than a global dataset
    label_annotation_array = [
        ["〇" if tok == correct_tok else "" for correct_tok in labels] for tok in range(len(vocab))]

    for i, row in enumerate(label_annotation_array):
        for j, val in enumerate(row):
            if val:  # Only add the annotation if val is not an empty string
                fig.add_annotation(
                    go.layout.Annotation(
                        text=val,
                        x=j,
                        y=vocab[i],
                        xref='x2',
                        yref='y2',
                        showarrow=False,
                        font=dict(color="black",size=10)))
                fig.add_annotation(
                    go.layout.Annotation(
                        text=val,
                        x=j,
                        y=vocab[i],
                        xref='x1',
                        yref='y1',
                        showarrow=False,
                        font=dict(color="black",size=10)))

    # Create the dropdown menu
    buttons = []
    for i in range(n_components):
        visibility = [False] * (n_components + 1)
        visibility[i] = True
        visibility[-1] = True  # Always show the sum
        buttons.append(
            dict(label=component_labels[i],
                method='update',
                args=[{'visible': visibility}])
        )

    # Add dropdown to the figure
    fig.update_layout(
        updatemenus=[{
            'buttons': buttons,
            'direction': 'down',
            'showactive': True,
            'x': 1.05,
            'xanchor': 'left',
            'y': 1.15,
            'yanchor': 'top'
        }],
        title_text=title,
        height=height, width=width
    )
    fig.update_xaxes(tickangle=90)

    return fig

In [11]:
def print_repeat_info(dataset, data_index, position_index, logits):
    table = Table("Token", "Occurrences", "Repeated?", "Logit", title=f"Logits for example {data_index}, position {position_index}")

    for i, c in enumerate(dataset.vocab):
        occurrences = t.where(dataset.toks[data_index][:position_index+1] == i, 1, 0).sum()
        repeat = occurrences > 1
        correct = i == dataset.labels[data_index][position_index]

        style = "white"
        if repeat:
            style = "red1"
        if correct:
            style = "green"
        table.add_row(c, repr(occurrences.item()), repr(repeat.item()), f"{logits[data_index][position_index][i].item() :.3f}", style=style)
    print()
    rprint(table)
    print("Sequence: " + ''.join(dataset.str_toks[data_index]))
    print("Correct label: " + dataset.str_tok_labels[data_index][position_index])

def print_sequence(dataset, data_index):
    print(f"\n### EXAMPLE {data_index} ###")
    for i, c in enumerate(dataset.str_toks[data_index]):
        print(f"|{i: <2}", end="| ")
    print()
    for i, c in enumerate(dataset.str_toks[data_index]):
        print(f"|{c: <2}", end="| ")
    print("\n")

## Layer 0 QK

### Observing Layer 0 attention patterns

We'll start our analysis by observing the attention patterns of Layer 0 heads on a handful of examples.

In [12]:
dataset = UniqueCharDataset(size=1000, vocab=list("abcdefghij"), seq_len=20, seed=42)
logits, cache = model.run_with_cache(dataset.toks)

attn_patterns = cache[utils.get_act_name("pattern", 0)]

for i in range(5):
    display(
        cv.attention.attention_patterns(
          tokens=dataset.str_toks[i],
          attention=attn_patterns[i],
          attention_head_names=[f"0.{h}" for h in range(model.cfg.n_heads)])
    )

From staring at these attention patterns, we make the following observations:

- **Head 0.0** - "*anti-duplicate detector*"
  - Seems to attend roughly evenly across tokens that are not duplicates of the current token
- **Head 0.1** - "*duplicate detector*"
  - Seems to attend to tokens that are duplicates of the current token
  - It also seems to attend to earlier positions (including position 0) more strongly than the current position
- **Head 0.2** - "*hybrid detector*"
  - This head looks like a hybrid between an anti-duplicate detector and a duplicate detector
    - For some tokens, it behaves as an anti-duplicate detector
    - For other tokens, it behaves as a duplicate detector
- It looks like the behaviors noted above are inconsistent across tokens
  - H0.0 seems to pay some attention to duplicate `h` and `j` tokens in the 3rd example
  - H0.1 seems not to attend to duplicate `a` tokens in the 1st and 3rd examples
- Position 1 seems like a special case
  - All 3 Layer 0 heads seem to attend strongly to the token at position 1, and it should likely be treated as a special case

By "*duplicate detector*" I mean a head that pays attention to instances of the current token.

By "*anti-duplicate detector*" I mean a head that pays little or no attention to instances of the current token, and weighs attention fairly evenly over the remaining tokens.
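
These two definitions can be made quantitative. The sketch below is one possible metric, not something used elsewhere in this notebook: for each query position, it sums the attention placed on earlier positions holding the same token. A duplicate detector should score high on this metric and an anti-duplicate detector should score near zero. The function name and the toy pattern are hypothetical, and restricting to strictly earlier positions is an assumption of this sketch.

```python
def duplicate_attention_mass(pattern, tokens):
    """For each query position q, return the attention mass on earlier key positions
    that hold the same token as position q, or None if that token has not appeared before.

    `pattern[q][k]` is the attention from query position q to key position k.
    """
    out = []
    for q, row in enumerate(pattern):
        dup_keys = [k for k in range(q) if tokens[k] == tokens[q]]
        out.append(sum(row[k] for k in dup_keys) if dup_keys else None)
    return out

# Toy causal attention pattern for the sequence "?aba" (illustrative values only)
toy_tokens = list("?aba")
toy_pattern = [
    [1.0, 0.0, 0.0, 0.0],
    [0.5, 0.5, 0.0, 0.0],
    [0.3, 0.3, 0.4, 0.0],
    [0.1, 0.7, 0.1, 0.1],
]
print(duplicate_attention_mass(toy_pattern, toy_tokens))  # [None, None, None, 0.7]
```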

#### An illustrative example

To solidify and refine this understanding, we examine the attention patterns for the simple input string `?aabbccddeeffgghhiij`.

It contains adjacent duplicates of each token (except for `j`, since there is a 20 character limit), and so we will be able to see each head's behavior on each token duplicate.

In [13]:
custom_toks = t.tensor([[10, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9]])
custom_toks_str = [list("?aabbccddeeffgghhiij")]

logits, cache = model.run_with_cache(custom_toks)
attn_patterns = cache[utils.get_act_name("pattern", 0)]
for i in range(custom_toks.shape[0]):
    display(
        cv.attention.attention_patterns(
          tokens=custom_toks_str[i],
          attention=attn_patterns[i],
          attention_head_names=[f"{0}.{h}" for h in range(model.cfg.n_heads)])
    )

This example suggests that the behavior of each Layer 0 head differs across specific tokens.

- **Head 0.0**
  - Anti-duplicate detector for `{a, b, c, d, g, i}`
  - Neutral for `{e, f, h, j}`
  - <img src="https://github.com/andyrdt/mi/blob/main/ARENA/monthly_algorithmic_problems/08_2023/assets/h00_attn.png?raw=true" width="200">
- **Head 0.1**
  - Duplicate detector for `{b, c, d, e, f, g, h, i, j}`
  - Anti-duplicate detector for `{a}`
  - <img src="https://github.com/andyrdt/mi/blob/main/ARENA/monthly_algorithmic_problems/08_2023/assets/h01_attn.png?raw=true" width="200">
- **Head 0.2**
  - Duplicate detector for `{a, b, g}`
  - Anti-duplicate detector for `{c, e, f, h, j}`
  - Neutral for `{d, i}`
  - <img src="https://github.com/andyrdt/mi/blob/main/ARENA/monthly_algorithmic_problems/08_2023/assets/h02_attn.png?raw=true" width="200">

### Layer 0 QK circuits - embedding

We can confirm the above understanding by examining the QK circuits of the Layer 0 heads.

For $h \in \{ 0, 1, 2\}$, we will look at the circuit
$$W_{emb} W_{QK}^{0.h} W_{emb}^T$$

This circuit is a `(vocab_size x vocab_size)` matrix. The $(i, j)^{th}$ entry represents how strongly a query (destination) token with embedding $i$ attends to a key (source) token with embedding $j$ in head $0.h$, ignoring positional information.

We expect the anti-duplicate detectors to show up as low values along the diagonal (token $i$ *does not* attend to the same token $i$), and the duplicate detectors to show up as high values along the diagonal (token $i$ does attend to the same token $i$).
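
Written out, and using the constant-scale LayerNorm approximation from earlier (with $s$ denoting the mean layer-0 scale, a symbol introduced here), the quantity computed for each pair of tokens is

$$\left[ W_{emb} W_{QK}^{0.h} W_{emb}^T \right]_{ij} = \frac{1}{s^2 \sqrt{d_{head}}} \; W_{emb}[i] \; W_Q^{0.h} \left( W_K^{0.h} \right)^T W_{emb}[j]^T$$

This approximates the true pre-softmax attention score between two tokens. It leaves out any query and key biases, as well as every term involving the positional embeddings.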

In [14]:
# visualize layer 0 QK embedding circuits
layer = 0

W_emb = model.W_E
W_emb_scaled = W_emb / scale_ln1_layer0.mean()

for head in range(model.cfg.n_heads):
    W_QK = model.W_Q[layer, head] @ model.W_K[layer, head].T / model.cfg.d_head**0.5
    QK_full_emb = W_emb_scaled[:-1] @ W_QK @ W_emb_scaled.T

    fig = px.imshow(
        # softmax over keys, so each row sums to 1
        t.softmax(QK_full_emb, dim=-1).detach().cpu().numpy(),
        title=f"Head {layer}.{head} QK circuit (embed)",
        labels={"x": "Key", "y": "Query", "color": "Attention (softmax)"},
        width=400,
        height=400,
        color_continuous_scale="RdBu",
        color_continuous_midpoint=0,
        x=dataset.vocab[:],
        y=dataset.vocab[:-1]
    )

    display(fig)

These visualizations are consistent with the per-token behavior we read off the attention patterns.

To summarize:
- **Head 0.0**
  - Anti-duplicate detector (anti-diagonal) for `{a, b, c, d, g, i}`
- **Head 0.1**
  - Duplicate detector (diagonal) for `{b, c, d, e, f, g, h, i, j}`
  - Anti-duplicate detector (anti-diagonal) for `{a}`
- **Head 0.2**
  - Duplicate detector for `{a, b, g}`
  - Anti-duplicate detector for `{c, e, f, h, j}`


In fact, we can write a function to retrieve these mappings:

In [15]:
from collections import defaultdict

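def get_L0_detectors(model: HookedTransformer, dataset: UniqueCharDataset):
    '''
    Classifies each (head, token) pair in layer 0 using the softmaxed diagonal of the token-token QK circuit:
    a value above 0.30 counts as a duplicate detector, a value below 0.04 as an anti-duplicate detector.
    Relies on the global `scale_ln1_layer0` computed earlier.
    '''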
    tok_to_dup = defaultdict(list)
    dup_to_toks = defaultdict(list)
    tok_to_antidup = defaultdict(list)
    antidup_to_toks = defaultdict(list)

    W_emb = model.W_E
    W_emb_scaled = W_emb / scale_ln1_layer0.mean()

    for head in range(model.cfg.n_heads):
        W_QK = model.W_Q[0, head] @ model.W_K[0, head].T / model.cfg.d_head**0.5
        QK_full_emb = W_emb_scaled[:-1] @ W_QK @ W_emb_scaled[:-1].T
        QK_full_emb = t.softmax(QK_full_emb, dim=-1)

        for i, v in enumerate(dataset.vocab[:-1]):
            if QK_full_emb[i, i] > 0.30:
                tok_to_dup[i].append(head)
                dup_to_toks[head].append(i)
            if QK_full_emb[i, i] < 0.04:
                tok_to_antidup[i].append(head)
                antidup_to_toks[head].append(i)
    return dict(tok_to_dup), dict(dup_to_toks), dict(tok_to_antidup), dict(antidup_to_toks)

In [16]:
tok_to_dup, dup_to_toks, tok_to_antidup, antidup_to_toks = get_L0_detectors(model, dataset)
print("Duplicate detectors:")
for k, v in sorted(dup_to_toks.items()):
    # use `tok` as the loop variable to avoid reusing the name `t` (the torch alias)
    tokens = ', '.join(dataset.vocab[tok] for tok in v)
    print(f"  H0.{k}: {tokens}")
print("Anti-duplicate detectors:")
for k, v in sorted(antidup_to_toks.items()):
    tokens = ', '.join(dataset.vocab[tok] for tok in v)
    print(f"  H0.{k}: {tokens}")

Duplicate detectors:
  H0.1: b, c, d, e, f, g, h, i, j
  H0.2: a, b, g
Anti-duplicate detectors:
  H0.0: a, b, c, d, g, i
  H0.1: a
  H0.2: c, e, f, h, j
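
For context on the thresholds in `get_L0_detectors`: the softmax there runs over the 10 non-null key tokens, so a head with no token preference would put 0.10 on each key. The cutoff of 0.30 is therefore three times uniform, and 0.04 is less than half of uniform. The classification rule itself can be sketched in plain Python as follows (the function name and the toy rows are hypothetical, and the toy example uses only 3 tokens).

```python
def classify_self_attention(attn_rows, hi=0.30, lo=0.04):
    """Label each query token by how much softmaxed attention it gives to its own token.

    `attn_rows[i][j]` is the attention from query token i to key token j.
    """
    labels = []
    for i, row in enumerate(attn_rows):
        p_self = row[i]
        if p_self > hi:
            labels.append("duplicate")
        elif p_self < lo:
            labels.append("anti-duplicate")
        else:
            labels.append("neutral")
    return labels

# Illustrative values only; each row sums to 1
toy_rows = [
    [0.80, 0.10, 0.10],
    [0.49, 0.02, 0.49],
    [0.40, 0.40, 0.20],
]
print(classify_self_attention(toy_rows))  # ['duplicate', 'anti-duplicate', 'neutral']
```

Tokens that fall between the two thresholds are the ones described as "neutral" in the earlier attention-pattern summary.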


### Layer 0 QK circuits - position

From the above analysis, it's pretty clear that the embedding QK circuit $W_{emb} W_{QK}^{0.h} W_{emb}^T$ plays a large role in determining Layer 0 attention patterns.

However, it is not the full story. Positional information also plays a role through the following pathways:
- $W_{pos} W_{QK}^{0.h} W_{pos}^T$
- $W_{pos} W_{QK}^{0.h} W_{emb}^T$
- $W_{emb} W_{QK}^{0.h} W_{pos}^T$
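
These pathways come from expanding the bilinear attention score. As a sketch that ignores LayerNorm scaling and the query/key biases (simplifying assumptions made here), write the Layer 0 input at the query position as $e_q + p_q$ and at the key position as $e_k + p_k$, where $e$ denotes a row of $W_{emb}$ and $p$ a row of $W_{pos}$ (notation introduced for this sketch):

$$
(e_q + p_q) W_{QK}^{0.h} (e_k + p_k)^T = e_q W_{QK}^{0.h} e_k^T + e_q W_{QK}^{0.h} p_k^T + p_q W_{QK}^{0.h} e_k^T + p_q W_{QK}^{0.h} p_k^T
$$

The first term is the embedding QK circuit analyzed above; the remaining three are the pathways that involve position.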

We can visualize each pathway's circuit:

In [17]:
layer = 0

W_E_pos = t.concat([model.W_E, model.W_pos], dim=0)
W_E_pos_scaled = W_E_pos / scale_ln1_layer0.mean()

n_vocab = len(dataset.vocab)

W_E_labels = dataset.vocab
W_pos_labels = [f"[{i}]" for i in range(20)]

for head in range(model.cfg.n_heads):
    W_QK = model.W_Q[layer, head] @ model.W_K[layer, head].T / model.cfg.d_head**0.5

    QK_full = W_E_pos_scaled @ W_QK @ W_E_pos_scaled.T

    # make pos x pos part lower triangular
    QK_full[n_vocab:, n_vocab:] = t.tril(QK_full[n_vocab:, n_vocab:])

    fig = px.imshow(
        QK_full.detach().cpu().numpy(),
        title=f"Head {layer}.{head} QK circuit",
        labels={"x": "Key", "y": "Query", "color": "QK weight"},
        width=700, height=700,
        color_continuous_scale="RdBu",
        color_continuous_midpoint=0,
        x=W_E_labels + W_pos_labels,
        y=W_E_labels + W_pos_labels,
    )
    fig.add_shape(
        type="line",
        x0=len(W_E_labels)-0.5, x1=len(W_E_labels)-0.5,
        y0=0-0.5, y1=QK_full.size(0)-0.5,
        line=dict(color="Black", width=1.5)
    )
    fig.add_shape(
        type="line",
        x0=0-0.5, x1=QK_full.size(1)-0.5,
        y0=len(W_E_labels)-0.5, y1=len(W_E_labels)-0.5,
        line=dict(color="Black", width=1.5)
    )
    display(fig)

Some observations:
- Upper-right quadrant `(Q=emb, K=pos)`
  - For all 3 heads, we can see that there is a general bias towards paying attention to earlier positions over later positions
- Lower-right quadrant `(Q=pos, K=pos)`
  - H0.0 and H0.2 have an extremely high positive score for `(Q=1, K=1)`
  - For H0.1, there is stronger attention to previous positions than to the current position (at least for the first 10 positions or so)
- Lower-left quadrant `(Q=pos, K=emb)`
  - H0.0 and H0.2 again have interesting entries for `Q_pos=1` - it looks like they attend to particular tokens in this special case
  - All 3 heads bias attention towards the `?` token

### Relative importance

We approximate the relative importance of each of these pathways by computing the norm of each empirical contribution to attention scores.

In [18]:
labels = ['(Q=emb, K=emb)', '(Q=emb, K=pos)', '(Q=pos, K=emb)', '(Q=pos, K=pos)']

dataset = UniqueCharDataset(size=1000, vocab=list("abcdefghij"), seq_len=20, seed=42)
logits, cache = model.run_with_cache(dataset.toks)

components = cache.decompose_resid(layer=0, incl_embeds=True)
components = components / cache[utils.get_act_name('scale', 0, 'ln1')][:, :, 0, :]

q_by_components = einops.einsum(components, model.W_Q[0], 'n batch seq d_model, n_head d_model d_head -> n n_head batch seq d_head')
k_by_components = einops.einsum(components, model.W_K[0], 'n batch seq d_model, n_head d_model d_head -> n n_head batch seq d_head')

attn_by_components = einops.einsum(q_by_components, k_by_components, 'nq n_head batch seqq d_head, nk n_head batch seqk d_head -> nq nk n_head batch seqq seqk')
attn_by_components = einops.rearrange(attn_by_components, 'nq nk n_head batch seqq seqk -> (nq nk) n_head batch seqq seqk')

for head in range(model.cfg.n_heads):
    data = []
    for component, label in enumerate(labels):
        for pos in range(1, 20):
            attn_by_components_norm = attn_by_components[component, head, :, pos].reshape((1000, -1)).norm(dim=-1)
            mean = attn_by_components_norm.mean().item()
            std  = attn_by_components_norm.std().item()

            data.append({
                'Position': pos,
                'Components': label,
                'Mean': mean,
                'Std': std
            })

    df = pd.DataFrame(data)

    fig = px.scatter(
        df,
        x='Position',
        y='Mean',
        color='Components',
        title=f'Head 0.{head} attention pattern norms by (Q, K) component',
        # error_y='Std',
        labels={'Mean': 'Attention pattern norm (mean)', 'Position': 'Position'},
        width=700, height=400,
    )
    fig.update_traces(mode='lines+markers')
    fig.show()

Observations:
- `(Q=emb, K=emb)`
  - Overall the strongest impact across all positions for all 3 heads
- `(Q=pos, K=pos)` and `(Q=pos, K=emb)`
  - Extremely strong impact for H0.0 and H0.2 at position 1
  - Significant impact for all 3 heads in early positions (up until position 10 or so), and then they drop off
- `(Q=emb, K=pos)`
  - Has fairly significant impact for all 3 heads across all positions
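
To make the norm comparison concrete, here is a minimal sketch of the same idea on toy data: split queries and keys into an embedding part and a positional part, compute the pre-softmax score block for each (Q, K) pairing, and compare Frobenius norms. The helper names and all numbers are hypothetical; the sketch uses plain Python lists instead of the cached activations.

```python
import math

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def score_block(q_parts, k_parts):
    """scores[i][j] = q_parts[i] . k_parts[j] for one (Q-component, K-component) pair."""
    return [[dot(q, k) for k in k_parts] for q in q_parts]

def frobenius_norm(mat):
    return math.sqrt(sum(x * x for row in mat for x in row))

# Hypothetical per-position query/key contributions (2 positions, d_head = 2).
q = {"emb": [[1.0, 0.0], [0.0, 2.0]], "pos": [[0.5, 0.0], [0.0, 0.0]]}
k = {"emb": [[2.0, 0.0], [0.0, 1.0]], "pos": [[0.0, 1.0], [0.0, 0.0]]}

# One score block per (Q-component, K-component) pairing, plus its norm.
blocks = {(a, b): score_block(q[a], k[b]) for a in q for b in k}
norms = {f"(Q={a}, K={b})": frobenius_norm(m) for (a, b), m in blocks.items()}

# Sanity check: the four blocks add up to the score of the summed inputs.
q_full = [[x + y for x, y in zip(qe, qp)] for qe, qp in zip(q["emb"], q["pos"])]
k_full = [[x + y for x, y in zip(ke, kp)] for ke, kp in zip(k["emb"], k["pos"])]
full = score_block(q_full, k_full)
summed = [[sum(blocks[key][i][j] for key in blocks) for j in range(2)] for i in range(2)]

for label, value in norms.items():
    print(f"{label}: {value:.3f}")
```

Because the score is bilinear, the blocks sum exactly to the full pre-softmax score, so their norms give a rough (not exact) sense of which pathway dominates.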

### Summary of Layer 0 QK analysis

- L0 heads serve as "duplicate detectors" and "anti-duplicate detectors" for particular token embeddings.
  - A head is a "duplicate detector" for token embedding `t` if, from `t`, it attends strongly to `t`.
  - A head is an "anti-duplicate detector" for token embedding `t` if, from `t`, it does not attend to `t` and instead attends to other token embeddings ~evenly.
- Positional information also plays a role, but is not as significant as embedding information
  - Positional information generally biases L0 heads towards attending to earlier positions, and towards attending to `emb=?` / `pos=0`
  - L0 attends to previous positions more strongly than its own position

Here is a table summarizing duplicate detectors and anti-duplicate detectors:

>|          | **Duplicate detector**        | **Anti-duplicate detector** |
|----------|-------------------------------|-----------------------------|
| **H0.0** | `{}`                          | `{a, b, c, d, g, i}`        |
| **H0.1** | `{b, c, d, e, f, g, h, i, j}` | `{a}`                       |
| **H0.2** | `{a, b, g}`                   | `{c, e, f, h, j}`           |

## Direct logit attribution

### Observing direct logit attribution for individual examples

We'll start by taking a handful of examples and looking at each head's direct contribution to each token's logit.

Suppose we have $x_{post}$ (a row vector) as the post-final-LayerNorm residual stream. The logit contribution to token $i$ is given by
$$(x_{post} W_U)_i = x_{post} W_U[:, i]$$

We will estimate $x_{post}$ from the pre-final-LayerNorm residual stream $x_{pre}$ by approximating the final LayerNorm operation as a multiplication by scalar $\alpha_f \in \mathbb{R}$: $$x_{post} \approx \alpha_{f} x_{pre}$$

Decomposing pre-final-LayerNorm residual stream $x_{pre}$, we have
$$x_{post} \approx \alpha_{f} x_{pre} = \alpha_{f} (emb + pos + h_{0.0} + h_{0.1} + h_{0.2} + h_{1.0} + h_{1.1} + h_{1.2}) := \sum_{j=1}^{8} \alpha_{f} y_j$$

We can compute the contribution of each component $y_j$ on each vocabulary word $i$'s logit value: $\alpha_{f} y_j W_U[:, i]$.

We will take a look at a handful of examples, and visualize the logit contributions of each component on the *last* token prediction. In particular, we expect to see significant negative logit contributions for duplicated tokens.
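
The decomposition works because, once the final LayerNorm is treated as a fixed scalar, the map from residual stream to logits is linear. A minimal sketch of that additivity, with hypothetical component vectors, a hypothetical 2x3 unembedding, and `alpha_f` as a stand-in for the scalar $\alpha_f$:

```python
def matvec(x, W):
    """Row vector x (d_model) times matrix W (d_model x d_vocab)."""
    return [sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))]

# Hypothetical residual stream components (d_model = 2).
components = {
    "embed": [0.5, 0.0],
    "H0.0": [0.0, 0.5],
    "H1.0": [2.0, -1.0],
}
# Hypothetical unembedding matrix (d_model = 2, d_vocab = 3).
W_U = [[1.0, 0.0, -1.0],
       [0.0, 2.0, 1.0]]
alpha_f = 0.5  # stand-in for the final LayerNorm scale factor

# Per-component logit contributions: alpha_f * y_j @ W_U
per_component = {
    name: [alpha_f * v for v in matvec(y, W_U)] for name, y in components.items()
}

# Logits computed from the summed residual stream.
x_pre = [sum(y[i] for y in components.values()) for i in range(2)]
total_logits = [alpha_f * v for v in matvec(x_pre, W_U)]

# Summing the per-component contributions recovers the total logits.
summed = [sum(c[j] for c in per_component.values()) for j in range(3)]
print(per_component)
print(total_logits, summed)
```

In the real model the scale varies per example and position, so the attribution is only as good as the scalar approximation of LayerNorm.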

In [19]:
num_examples = 3
current_pos  = 19

logits, cache = model.run_with_cache(dataset.toks[:num_examples])

emb_pos_components = cache.decompose_resid(layer=0, incl_embeds=True)
attn_components    = cache.stack_head_results(layer=-1)
components = t.concat((emb_pos_components, attn_components), dim=0)
components = components / cache[utils.get_act_name('scale')]
labels = ['embed', 'pos_embed', 'H0.0', 'H0.1', 'H0.2', 'H1.0', 'H1.1', 'H1.2']

for current_ex in range(num_examples):
    cur_components = components[:, current_ex, current_pos, :] # (n, d_model)
    logit_contributions_by_components = einops.einsum(cur_components, model.W_U, 'n d_model, d_model d_vocab -> n d_vocab')

    print_repeat_info(dataset=dataset, data_index=current_ex, position_index=current_pos, logits=logits)
    display(
        px.imshow(
            logit_contributions_by_components.detach().cpu().numpy(),
            title=f"Direct logit attribution for ex {current_ex}, pos {current_pos}",
            labels={"x": "Output logit", "y": "Component", "color": "Logit attribution"},
            width=500, height=400,
            color_continuous_scale="RdBu", color_continuous_midpoint=0,
            x=dataset.vocab, y=labels,
        )
    )




Sequence: ?chgegfaeadieaebcffh
Correct label: d

Sequence: ?gjgdbjdbjhjcafjdejg
Correct label: h

Sequence: ?cagchjhddedhajajgjf
Correct label: e


From the above visualizations, we make the following observations:
- Almost all significant direct logit contributions (at least for predicting position `19`) come from the Layer 1 heads
  - This entails both large positive contributions and large negative contributions
- Repeated tokens are heavily penalized with a dramatic negative logit contribution
- Different heads seem to be "responsible for" different tokens
  - **H1.0** seems to be responsible for `{a, c}`
  - **H1.1** seems to be responsible for `{d, e, f, j}`
  - **H1.2** seems to be responsible for `{b, g, h, i}`

By saying "H1.x is responsible for token `t`", I mean that head H1.x is primarily responsible for decreasing the logit value for `t` when it is a duplicate (and perhaps primarily responsible for increasing the logit value when it is the correct answer, but this is not so clear yet).

### Direct logit attribution by position

The above observations only examine position `19`. We saw that L1 heads dominate logit contributions, while L0 heads do not contribute much.

Let's see if this observation holds at other positions. We can check this by computing the overall correct logit attribution for each component at every position.

In [20]:
logits, cache = model.run_with_cache(dataset.toks)

emb_pos_components = cache.decompose_resid(layer=0, incl_embeds=True)
attn_components    = cache.stack_head_results(layer=-1)
components = t.concat((emb_pos_components, attn_components), dim=0)
components = components / cache[utils.get_act_name('scale')]
labels = ['embed', 'pos_embed', 'H0.0', 'H0.1', 'H0.2', 'H1.0', 'H1.1', 'H1.2']

correct_W_U_directions = model.W_U.T[dataset.labels]

correct_logit_contribution_by_components = einops.einsum(
    components,
    correct_W_U_directions,
    'n batch seq d_model, batch seq d_model -> n batch seq')

correct_logit_contribution_mean = correct_logit_contribution_by_components.mean(dim=1) # (n, seq)
correct_logit_contribution_std  = correct_logit_contribution_by_components.std(dim=1) # (n, seq)

data = []
for i, label in enumerate(labels):
    data.extend([{
        'X': x,
        'Y': correct_logit_contribution_mean[i, x].cpu().detach(),
        'Y_std': correct_logit_contribution_std[i, x].cpu().detach(),
        'Label': label
    } for x in range(1, 20)])

df = pd.DataFrame(data)

fig1 = px.scatter(df, x='X', y='Y', color='Label',
                title=f'Correct logit attribution by component (mean)',
                labels={'X': 'Position', 'Y': 'Correct logit attribution (mean)', 'Label': 'Component'},
                width=600, height=400)
fig1.update_traces(mode='lines+markers')
display(fig1)

# fig2 = px.scatter(df, x='X', y='Y_std', color='Label',
#                 title=f'Correct logit attribution by component (std)',
#                 labels={'X': 'Position', 'Y_std': 'Correct logit attribution (std)', 'Label': 'Component'},
#                 width=600, height=400)
# fig2.update_traces(mode='lines+markers')
# display(fig2)

This analysis confirms that the L1 heads are by far the most significant components with respect to direct logit attribution.

There is a notable exception for position `1`, where heads H0.0 and H0.2 also play a significant role in direct logit attribution.

### Direct logit attribution by duplicate token

We previously observed that each L1 head seems to be "responsible" for a particular set of tokens. When a token is duplicated, its responsible head decreases its logit value.

To see this more clearly, we can isolate duplicate tokens from the dataset, and measure each head's logit contribution towards outputting the duplicated token.

In [21]:
dataset = UniqueCharDataset(size=1000, vocab=list("abcdefghij"), seq_len=20, seed=42)
seq_len = dataset.toks.shape[-1]

# flag all duplicates in dataset: is_dup[b, i] is True if toks[b, i] already appeared at some earlier position j < i
is_dup = t.zeros((dataset.size, seq_len), dtype=t.bool)
for i in range(seq_len):
    for j in range(i):
        is_dup[:, i] = is_dup[:, i] | (dataset.toks[:, j] == dataset.toks[:, i])

dup_W_U_directions = model.W_U.T[dataset.toks] # (batch, seq, d_model)

components = cache.stack_head_results(layer=-1)[-3:]
components = components / cache[utils.get_act_name('scale')]
labels = ['H1.0', 'H1.1', 'H1.2']

dup_logit_contribution_by_components = einops.einsum(
    components,
    dup_W_U_directions,
    'n batch seq d_model, batch seq d_model -> n batch seq')

dup_logit_contributions_mean = t.zeros((len(labels), len(dataset.vocab)))
dup_logit_contributions_std  = t.zeros((len(labels), len(dataset.vocab)))

for h, _ in enumerate(labels):
    for i, _ in enumerate(dataset.vocab): # split examples by token
        dup_logit_contributions = dup_logit_contribution_by_components[h][is_dup & (dataset.toks == i)]
        dup_logit_contributions_mean[h, i] = dup_logit_contributions.mean()
        dup_logit_contributions_std[h, i]  = dup_logit_contributions.std()

data = []
for h, label in enumerate(labels):
    data.extend([{
        'X': v,
        'Y': dup_logit_contributions_mean[h, i].cpu().detach(),
        'Y_std': dup_logit_contributions_std[h, i].cpu().detach(),
        'Label': label
    } for i, v in enumerate(dataset.vocab[:-1])])

df = pd.DataFrame(data)

fig = px.bar(df, x='X', y='Y', color='Label',
             title='Duplicate token logit attribution',
             labels={'X': 'Duplicate token', 'Y': 'Duplicate token logit attribution (mean)'},
             #error_y='Y_std',
             width=900, height=500)
fig.update_layout(barmode='group')
display(fig)

This chart is very clear - each token has one L1 head that is responsible for decreasing its logit value when it is duplicated.
- **H1.0**: `{a, c}`
- **H1.1**: `{d, e, f, j}`
- **H1.2**: `{b, g, h, i}`
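
For reference, the notion of "duplicate" used in this analysis is: a position is flagged when its token has already appeared at some earlier position. A minimal pure-Python sketch of that rule for a single sequence (`flag_duplicates` is a hypothetical helper, not part of the notebook):

```python
def flag_duplicates(seq):
    """Return a list of booleans: True where seq[i] already appeared at some j < i."""
    seen = set()
    flags = []
    for tok in seq:
        flags.append(tok in seen)  # duplicate only if seen strictly earlier
        seen.add(tok)
    return flags

flags = flag_duplicates("?abcdabcd")
print(flags)
```

The first occurrence of each token is never flagged, so only second and later instances enter the per-token averages.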

### Direct logit attribution by correct token

We now know that each L1 head is responsible for *decreasing* logit values for some set of tokens when duplicated.

It may be reasonable to check if each L1 head is also primarily responsible for *increasing* logit values for some set of tokens when correct.

We can visualize this by splitting the dataset by correct label tokens, and computing each L1 head's contribution.

In [22]:
dataset = UniqueCharDataset(size=1000, vocab=list("abcdefghij"), seq_len=20, seed=42)
seq_len = dataset.toks.shape[-1]

correct_W_U_directions = model.W_U.T[dataset.labels] # (batch, seq, d_model)

components = cache.stack_head_results(layer=-1)[-3:]
components = components / cache[utils.get_act_name('scale')]

labels = ['H1.0', 'H1.1', 'H1.2']

correct_logit_contribution_by_components = einops.einsum(
    components,
    correct_W_U_directions,
    'n batch seq d_model, batch seq d_model -> n batch seq')

correct_logit_contributions_mean = t.zeros((len(labels), len(dataset.vocab)))
correct_logit_contributions_std  = t.zeros((len(labels), len(dataset.vocab)))

for h, _ in enumerate(labels):
    for i, _ in enumerate(dataset.vocab): # split examples by token
        correct_logit_contributions = correct_logit_contribution_by_components[h][dataset.labels == i]
        correct_logit_contributions_mean[h, i] = correct_logit_contributions.mean()
        correct_logit_contributions_std[h, i]  = correct_logit_contributions.std()

data = []
for h, label in enumerate(labels):
    data.extend([{
        'X': v,
        'Y': correct_logit_contributions_mean[h, i].cpu().detach(),
        'Y_std': correct_logit_contributions_std[h, i].cpu().detach(),
        'Label': label
    } for i, v in enumerate(dataset.vocab)])

df = pd.DataFrame(data)

fig = px.bar(df, x='X', y='Y', color='Label',
             title='Correct token logit attribution',
             labels={'X': 'Correct token', 'Y': 'Correct token logit attribution (mean)'},
             #error_y='Y_std',
             width=900, height=500)
fig.update_layout(barmode='group')
display(fig)

Although not as clear as the duplicate token logit attribution, this visualization suggests roughly that each token has one head that is *primarily* responsible for it.
- H1.0: `{a, c}`
- H1.1: `{d, e, f, j}`
- H1.2: `{b, g, h, i}`
- Responsibility for the null token `?` is split evenly

Note that these are the same responsibility sets that we found from the duplicate token analysis!

Some of the head-to-token responsibilities are clearer than others. For example, it seems very clear that H1.0 is most responsible for increasing `a` and `c` logits when they are the correct answer, while increasing `j` logits seems like more of a team effort across the three heads.

### Summary of direct logit attribution analysis

- L1 heads are essentially entirely responsible for direct logit contributions
  - Except in the case of position `1`, where H0.0 and H0.2 also have significant direct logit contributions
- Each L1 head has a set of tokens that it is "responsible for"
  - If an L1 head is "responsible for token `t`", then it is exclusively responsible for decreasing the logit value for `t` when it is a duplicate
  - Responsible heads also seem to be primarily responsible for increasing logit values of correct tokens in their responsibility set.
    - However, this mechanism is less clear cut, and it seems like the L1 heads generally work together to increase correct logit values.

Here is a table summarizing each L1 head's responsibility set:

>|          | **Responsibility set** |
|----------|------------------------|
| **H1.0** | `{a, c}`               |
| **H1.1** | `{d, e, f, j}`         |
| **H1.2** | `{b, g, h, i}`         |


## Ablation

### Head ablations

From our analysis so far, we have a sense that every head is playing a role in the mechanism.

We can quickly confirm this by ablating each head and checking the resulting difference in loss.

In [23]:
def head_ablation_hook(
    attn_result: Float[Tensor, "batch seq n_heads d_model"],
    hook: HookPoint,
    ablation_type: str, # either 'zero' or 'mean'
    head_index_to_ablate: Optional[int] = None, # if None, then ablate all heads at this layer
) -> Float[Tensor, "batch seq n_heads d_model"]:
    if ablation_type == 'zero':
        if head_index_to_ablate is None:
            attn_result[:, :, :, :] = 0
        else:
            attn_result[:,:,head_index_to_ablate,:] = 0
    elif ablation_type == 'mean':
        # mean is taken over the batch dimension, separately for each position
        if head_index_to_ablate is None:
            attn_result[:, :, :, :] = attn_result.mean(0, keepdim=True)
        else:
            attn_result[:,:,head_index_to_ablate,:] = attn_result[:,:,head_index_to_ablate].mean(0, keepdim=True)
    return attn_result

In [24]:
def get_loss(
    logits: Float[Tensor, "batch seq_len d_vocab"],
    labels: Int[Tensor, "batch seq_len"]
) -> float:
    logprobs = logits.log_softmax(-1) # [batch seq_len d_vocab]
    batch_size, seq_len = labels.shape
    logprobs_correct = logprobs[t.arange(batch_size)[:, None], t.arange(seq_len)[None, :], labels] # [batch seq_len]
    avg_cross_entropy_loss = -logprobs_correct.mean().item()

    return avg_cross_entropy_loss

def get_loss_at_k(
    logits: Float[Tensor, "batch seq_len d_vocab"],
    labels: Int[Tensor, "batch seq_len"],
    k: int
) -> float:
    # loss restricted to sequence position k
    return get_loss(logits[:, k, :].unsqueeze(1), labels[:, k].unsqueeze(-1))

In [25]:
dataset = UniqueCharDataset(size=1000, vocab=list("abcdefghij"), seq_len=20, seed=42)

model.reset_hooks()
logits_no_ablation = model(dataset.toks, return_type="logits")
loss_no_ablation = get_loss(logits_no_ablation, dataset.labels)
# print(f"loss_no_ablation: {loss_no_ablation}")

# note - both zero and mean ablation yield similar results, so just display zero
for ablation_type in ['zero']: #, 'mean']:
    ablation_scores = t.zeros((model.cfg.n_layers, model.cfg.n_heads))
    ablation_scores_per_pos = t.zeros((model.cfg.n_layers, model.cfg.n_heads, 20))
    ablation_scores_per_layer = t.zeros((model.cfg.n_layers, 20))
    for layer in range(model.cfg.n_layers):
        for head in list(range(model.cfg.n_heads)) + [None]:
            temp_hook_fn = functools.partial(
                head_ablation_hook,
                head_index_to_ablate=head,
                ablation_type=ablation_type)
            ablated_logits = model.run_with_hooks(
                dataset.toks,
                return_type="logits",
                fwd_hooks=[(utils.get_act_name("result", layer), temp_hook_fn)]
            )

            for k in range(20):
                ablated_loss_k = get_loss_at_k(ablated_logits, dataset.labels, k=k)
                if head == None:
                    ablation_scores_per_layer[layer, k] = ablated_loss_k - loss_no_ablation
                    continue
                ablation_scores_per_pos[layer, head, k] = ablated_loss_k - loss_no_ablation

            if head == None:
                continue

            ablated_loss = get_loss(ablated_logits, dataset.labels)
            ablation_scores[layer, head] = ablated_loss - loss_no_ablation

    # display ablation scores, mean across position
    fig = px.imshow(
        ablation_scores.cpu().numpy(),
        title=f"Ablation loss diff, ablation_type: {ablation_type}",
        labels={"x": "Head", "y": "Layer"},
        text_auto=".2f",
        color_continuous_scale="RdBu",
        color_continuous_midpoint=0,
        x=[f"{head}" for head in range(model.cfg.n_heads)],
        y = [f"{layer}" for layer in range(model.cfg.n_layers)],
        width=700, height=400)
    fig.update_layout(coloraxis_showscale=False)
    display(fig)

    # display ablation scores, per position
    df = pd.DataFrame()
    for layer in range(model.cfg.n_layers):
        for head in range(model.cfg.n_heads):
            df_temp = pd.DataFrame({
                'X': range(20),
                'Y': ablation_scores_per_pos[layer, head],
                'Label': [f'H{layer}.{head}'] * 20
            })
            df = pd.concat([df, df_temp])

    fig = px.scatter(df, x='X', y='Y', color='Label',
                    title=f'Ablation loss diff by position, ablation_type: {ablation_type}',
                    labels={'X': 'Position', 'Y': 'Loss diff', 'Label': 'Head'},
                    width=700, height=400)
    fig.update_traces(mode='lines+markers')
    display(fig)

    # display ablation scores, all heads in layer
    df = pd.DataFrame()
    for layer in range(model.cfg.n_layers):
        df_temp = pd.DataFrame({
            'X': range(20),
            'Y': ablation_scores_per_layer[layer],
            'Label': [f'L{layer}'] * 20
        })
        df = pd.concat([df, df_temp])

    fig = px.scatter(df, x='X', y='Y', color='Label',
                    title=f'Ablation loss diff by position, ablation_type: {ablation_type} (all heads in layer)',
                    labels={'X': 'Position', 'Y': 'Loss diff', 'Label': 'Layer'},
                    width=700, height=400)
    fig.update_traces(mode='lines+markers')
    display(fig)

Some observations:
- All heads play a significant role in the mechanism - ablating any one head leads to a significant increase in loss
- H0.1 becomes more critical as position increases
  - I speculate this is because the probability of a duplicate existing increases as position increases, and H0.1 is critical to detecting duplicates
- H0.0 and H0.2 seem to be primarily responsible for getting positions `1` and `2` correct
  - Even when all L1 heads are ablated, loss is not too high for positions `1` and `2`
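
The ablation score plotted above is a difference in cross-entropy loss between the ablated and the clean run. As a minimal sketch of that quantity on toy logits (plain Python; `cross_entropy`, `mean_loss`, and the logit values are hypothetical stand-ins for the tensor code):

```python
import math

def cross_entropy(logits, label):
    """-log softmax(logits)[label] for a single prediction."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[label]

def mean_loss(batch_logits, batch_labels):
    losses = [cross_entropy(l, y) for l, y in zip(batch_logits, batch_labels)]
    return sum(losses) / len(losses)

# Hypothetical logits for 2 examples over a 2-token vocabulary.
clean_logits = [[2.0, 0.0], [0.0, 2.0]]    # confident and correct
ablated_logits = [[0.0, 0.0], [0.0, 0.0]]  # ablation removed the signal
labels = [0, 1]

loss_diff = mean_loss(ablated_logits, labels) - mean_loss(clean_logits, labels)
print(f"loss increase from ablation: {loss_diff:.3f}")
```

A positive difference means the ablated component was helping; a value near zero means the model did not rely on it for these examples.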

### QKV ablations

We can do more fine-grained ablations on Q, K, and V activations of each head.

In [26]:
def qkv_ablation_hook(
    qkv: Float[Tensor, "batch seq n_heads d_head"],
    hook: HookPoint,
    head_index_to_ablate: int,
    ablation_type: str, # either 'zero' or 'mean'
) -> Float[Tensor, "batch seq n_heads d_head"]:
    if ablation_type == 'zero':
        qkv[:,:,head_index_to_ablate,:] = 0
    elif ablation_type == 'mean':
        qkv[:,:,head_index_to_ablate] = qkv[:,:,head_index_to_ablate].mean(0, keepdim=True)
    return qkv

In [27]:
dataset = UniqueCharDataset(size=1000, vocab=list("abcdefghij"), seq_len=20, seed=42)

model.reset_hooks()
logits_no_ablation = model(dataset.toks, return_type="logits")
loss_no_ablation = get_loss(logits_no_ablation, dataset.labels)

activation_types = ['q', 'k', 'v']

for ablation_type in ['zero', 'mean']:
    ablation_scores = t.zeros((len(activation_types), model.cfg.n_layers, model.cfg.n_heads))
    for i, activation_type in enumerate(activation_types):
        for layer in range(model.cfg.n_layers):
            for head in range(model.cfg.n_heads):
                temp_hook_fn = functools.partial(
                    qkv_ablation_hook,
                    head_index_to_ablate=head,
                    ablation_type=ablation_type)
                ablated_logits = model.run_with_hooks(
                    dataset.toks,
                    return_type="logits",
                    fwd_hooks=[(utils.get_act_name(activation_type, layer), temp_hook_fn)]
                )

                ablated_loss = get_loss(ablated_logits, dataset.labels)
                ablated_loss_diff = ablated_loss - loss_no_ablation

                ablation_scores[i, layer, head] = ablated_loss_diff

    fig = px.imshow(
        ablation_scores.view(len(activation_types), -1).cpu().numpy(),
        title=f"Ablation loss diff, ablation_type: {ablation_type}",
        labels={"x": "Head", "y": "Ablated activation type"},
        text_auto=".2f",
        color_continuous_scale="RdBu",
        color_continuous_midpoint=0,
        x=[f"H{layer}.{head}" for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)],
        y = [f"{activation_type.upper()}" for activation_type in activation_types],
        width=600, height=500)
    fig.update_layout(coloraxis_showscale=False)
    display(fig)

Some observations:
- For H0.0, mean ablating the Q and K activations doesn't have a significant impact on loss
  - I speculate that this is because H0.0's attention pattern is usually very diffuse, so the mean will also be very diffuse.
- For all L1 heads, mean ablating Q doesn't have a significant impact on loss
  - This implies something like: the L1 heads are always querying for the same type of information
  - We will leverage this insight when analyzing L1 QK circuits
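
To make the zero vs. mean distinction concrete, here is a minimal sketch on plain Python lists rather than tensors. `ablate_head` and the toy values are hypothetical, and each head's activation is reduced to a single scalar per example.

```python
def ablate_head(results, head, ablation_type):
    """results[b][h] is a scalar stand-in for head h's activation on example b."""
    out = [row[:] for row in results]  # copy so the input is left untouched
    if ablation_type == "zero":
        fill = 0.0
    elif ablation_type == "mean":
        fill = sum(row[head] for row in results) / len(results)  # batch mean
    else:
        raise ValueError("ablation_type must be 'zero' or 'mean'")
    for row in out:
        row[head] = fill
    return out

# Head 0 varies across the batch; head 1 is constant.
toy = [[1.0, 4.0], [3.0, 4.0]]

zeroed = ablate_head(toy, head=0, ablation_type="zero")
meaned = ablate_head(toy, head=0, ablation_type="mean")
constant_head_meaned = ablate_head(toy, head=1, ablation_type="mean")
print(zeroed, meaned, constant_head_meaned)
```

Head 1 is constant across the toy batch, so mean ablation leaves it unchanged while zero ablation would not. This is the logic behind reading "mean-ablating Q barely changes the loss" as "Q carries little example-specific information".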

## Layer 1 head theory

Consider L1 head H1.$h$, with responsibility set $R_{1.h} \subseteq V$.

We know that when $t \in R_{1.h}$ is a duplicate, H1.$h$ is responsible for decreasing the logit value for $t$. Thus, H1.$h$ must be listening for ("querying for") a duplicate detection signal for $t$, and then transforming this signal into a negative logit value for $t$. In other words, H1.$h$ composes with $t$'s L0 duplicate detector via K-composition and V-composition.

We also know that when $t \in R_{1.h}$ is the correct answer, H1.$h$ is primarily responsible for increasing the logit value for $t$. If $t$ is the correct token, then it appears relatively early in the sequence, and so all subsequent non-$t$ positions will carry an anti-duplicate detection signal for $t$. H1.$h$ somehow attends to these non-$t$ tokens, and transforms the anti-duplicate detection signal for $t$ into a positive logit attribution. In other words, H1.$h$ composes with $t$'s anti-duplicate detection signal via V-composition.

To explore this theory, let's investigate the L1 QK and OV circuits, and in particular how they compose with the L0 OV circuits.

## Layer 1 QK

From the QKV ablation experiment, we know that mean-ablating Q activations for L1 heads does not impair the model. This suggests that we can use these mean Q activations when analyzing L1 QK circuits (rather than splitting Q activation contributions by residual stream component):
$$x_{resid}W_{Q}^{1.h} \approx \text{mean } Q \text{ activation for head } 1.h $$

As for the K activations, we will split them by residual stream component (ignoring LayerNorms and the mixing across positions done by the L0 attention patterns, for simplicity):
$$
\begin{align}
x_{resid} W_K^{1.h} &= (emb + pos + h_{0.0} + h_{0.1} + h_{0.2}) W_K^{1.h} \\
&= (emb + pos)\left( I + W_{OV}^{0.0} + W_{OV}^{0.1} + W_{OV}^{0.2} \right) W_K^{1.h} \\
&= (emb)\left( I + W_{OV}^{0.0} + W_{OV}^{0.1} + W_{OV}^{0.2} \right) W_K^{1.h} + (pos)\left( I + W_{OV}^{0.0} + W_{OV}^{0.1} + W_{OV}^{0.2} \right) W_K^{1.h}\\
\end{align}
$$

We'll visualize the left (embedding components) and right (positional components) terms separately.
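
As a reading aid, here is a sketch of the score each heatmap panel represents under the simplifications above (biases ignored). The symbols $\bar{q}^{1.h}_p$ (mean query of head 1.$h$ at position $p$), $w_x$ (a row of $W_{emb}$ or $W_{pos}$), and $\bar{s}_0, \bar{s}_1$ (mean `ln1` scale factors of layers 0 and 1) are notation introduced here, not taken from the original text.

$$
\text{direct path:}\quad \frac{\bar{q}^{1.h}_p \left( w_x W_K^{1.h} \right)^T}{\bar{s}_1 \sqrt{d_{head}}}
\qquad
\text{path through H0.}h'\text{:}\quad \frac{\bar{q}^{1.h}_p \left( w_x W_{OV}^{0.h'} W_K^{1.h} \right)^T}{\bar{s}_0 \bar{s}_1 \sqrt{d_{head}}}
$$

The row index of each heatmap is the query position $p$ and the column index is the token or position $x$ on the key side.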

In [28]:
def visualize_l1_QK(path_type: str, cmax: Optional[int] = None):
    if path_type == "Emb":
        W_emb_pos = model.W_E
    elif path_type == "Pos":
        W_emb_pos = model.W_pos
    else:
        raise ValueError("path_type must be either 'Emb' or 'Pos'")

    dataset = UniqueCharDataset(size=1000, vocab=list("abcdefghij"), seq_len=20, seed=42)
    logits, cache = model.run_with_cache(dataset.toks)

    # get mean queries for layer 1 heads
    q = einops.rearrange(cache['q', 1], 'b seq n_heads d_head -> n_heads b seq d_head')
    q_mean = q.mean(dim=1) # n_heads seq d_head

    OV = model.W_V @ model.W_O

    component_labels = [path_type, "H0.0", "H0.1", "H0.2"]
    x_labels = dataset.vocab if path_type == "Emb" else list(range(20))

    subplot_titles = [f"(Q=H{1}.{head1}, K={label})" for head1 in range(model.cfg.n_heads) for label in component_labels]
    fig = make_subplots(rows=model.cfg.n_heads, cols=len(component_labels), subplot_titles=subplot_titles)

    for head1 in range(model.cfg.n_heads):
        query = q_mean[head1]  # seq d_head

        QK = query @ model.W_K[1, head1].T @ (W_emb_pos.T / scale_ln1_layer1.mean())
        QK /= model.cfg.d_head**0.5

        if path_type == "Pos":
            QK = t.tril(QK)

        fig.add_trace(
            go.Heatmap(
                z=QK.detach().cpu().numpy(),
                x=x_labels,
                y=list(range(20)),
                coloraxis="coloraxis"),
            row=head1+1, col=1)

        for head0 in range(model.cfg.n_heads):
            OV_0 = OV[0, head0] / scale_ln1_layer1.mean()
            QK_OV_0 = query @ model.W_K[1, head1].T @ OV_0.T @ (W_emb_pos.T / scale_ln1_layer0.mean())
            QK_OV_0 /= model.cfg.d_head**0.5

            if path_type == "Pos":
                QK_OV_0 = t.tril(QK_OV_0)

            fig.add_trace(
                go.Heatmap(
                    z=QK_OV_0.detach().cpu().numpy(),
                    x=x_labels,
                    y=list(range(20)),
                    coloraxis="coloraxis"),
                row=head1+1, col=head0+2)

    for r in range(model.cfg.n_heads):
        for c in range(len(component_labels)):
            fig.update_yaxes(autorange="reversed", row=r+1, col=c+1)
            fig.update_xaxes(title_text=f"Key ({path_type})", row=r+1, col=c+1, titlefont=dict(size=12), title_standoff=5)
            fig.update_yaxes(title_text="Query", row=r+1, col=c+1, titlefont=dict(size=12), title_standoff=0)

    cmin = -cmax if cmax else cmax
    fig.update_layout(
        title=f"Layer 1 QK by K component (path_type={path_type})",
        height=1000,
        width=1000,
        coloraxis=dict(colorscale='RdBu', cmid=0, cmin=cmin, cmax=cmax),
    )
    display(fig)

#### Layer 1 QK circuits - embedding

In [29]:
visualize_l1_QK(path_type="Emb", cmax=10)

Observations:
- Looking at the left column `(Q=H1.*, K=Emb)`, we notice that, in the absence of L0 head K-composition, L1 heads **attend to tokens outside their responsibility set**, and **attend away from tokens in their responsibility set**
  - `(Q=H1.0, K=Emb)`: H1.0 attends to `V \ {a, c}` more than `{a, c}`
  - `(Q=H1.1, K=Emb)`: H1.1 attends to `V \ {d, e, f, j}` more than `{d, e, f, j}`
  - `(Q=H1.2, K=Emb)`: H1.2 attends to `V \ {b, g, h, i}` more than `{b, g, h, i}`
- Responsible heads **attend to duplicate detection signal of tokens in their responsibility set**, and **attend away from duplicate detection signal of tokens outside their responsibility set**
  - H1.0
      - `(Q=H1.0, K=H0.1)`: H1.0 attends to H0.1's duplicate detection of `c`
      - `(Q=H1.0, K=H0.2)`: H1.0 attends to H0.2's duplicate detection of `a`
  - H1.1
    - `(Q=H1.1, K=H0.1)`: H1.1 attends to H0.1's duplicate detection of `d, e, f, j`
  - H1.2
    - `(Q=H1.2, K=H0.1)`: H1.2 attends to H0.1's duplicate detection of `b, g, h, i`
    - `(Q=H1.2, K=H0.2)`: H1.2 attends to H0.2's duplicate detection of `b, g`

Another way of stating the above observations is: "an L1 head attends to non-duplicate tokens (first instances only) outside of its responsibility set, and to duplicate tokens (2nd+ instances only) inside of its responsibility set."

This can be validated by looking at a concrete example. Looking at attention patterns for `?abcdabcd`, we can observe the behavior described above (not perfectly, but I'd say it looks roughly correct). For example, H1.2 attends from the last position strongly to the first instances of `c, d` (outside responsibility set), and the second instance of `b` (inside responsibility set).
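
The rule of thumb above can also be written down as a small predicate. This is only a sketch of the hypothesis, not a measurement: `predicted_targets` is a hypothetical helper, the responsibility sets are the ones listed in the observations above, and the rule ignores attention magnitudes entirely.

```python
# Responsibility sets as listed in the observations above.
RESP_SETS = {"H1.0": set("ac"), "H1.1": set("defj"), "H1.2": set("bghi")}

def predicted_targets(seq: str, head: str) -> list[int]:
    """Key positions the rule of thumb says `head` should favour (hypothetical helper)."""
    seen, targets = set(), []
    for pos, ch in enumerate(seq):
        if ch == "?":
            continue
        is_duplicate = ch in seen            # 2nd+ instance of this token
        seen.add(ch)
        in_resp_set = ch in RESP_SETS[head]
        # Attend to duplicates inside the set, and to first instances outside it.
        if is_duplicate == in_resp_set:
            targets.append(pos)
    return targets

print(predicted_targets("?abcdabcd", "H1.2"))  # [1, 3, 4, 6]
```

For the tokens `?abcdabcd` and H1.2, the rule picks out the first `a`, `c`, `d` and the second `b`. The rule is binary, so it says nothing about how strongly each of these positions is attended to.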

In [30]:
custom_toks = t.tensor([[10, 0, 1, 2, 3, 0, 1, 2, 3]])
custom_toks_str = [list("?abcdabcd")]

logits, cache = model.run_with_cache(custom_toks)
attn_patterns = cache[utils.get_act_name("pattern", 1)]
for i in range(custom_toks.shape[0]):
    display(
        cv.attention.attention_patterns(
          tokens=custom_toks_str[i],
          attention=attn_patterns[i],
          attention_head_names=[f"{1}.{h}" for h in range(model.cfg.n_heads)])
    )

#### Layer 1 QK circuits - position

In [31]:
visualize_l1_QK(path_type="Pos")

Some observations:
- The overall magnitude of the position pathway is smaller than that of the embedding pathway
- Responsible L1 heads seem to attend more strongly to duplicate detection signals in earlier positions
- Responsible L1 heads attend more strongly to anti-duplicate detection signals in later positions

#### A note on how L1 heads handle duplicates

When a duplicate detection signal for token `t` is propagated from an L0 head, there are two possible causes:
1. It was simply the first instance of token `t`, and the L0 duplicate detector attention is split between `t` and `?`
2. Token `t` is a duplicate, and the L0 duplicate detector's attention is split between multiple instances of `t` and `?`

The L1 head responsible for `t` should attend only in the second case. How does it distinguish these two cases?

I notice the following mechanisms:
- For L0 duplicate detectors:
  - Attention is biased away from the current position (via positional QK), meaning that in case 1 a majority of attention is placed on `?`
  - More instances of `t` means that the total proportion of weight on `t` tokens is higher compared to `?`
  - Thus the two cases can be distinguished by the relative proportion of weight placed on `t` and `?`
- L1 responsible heads:
  - Can distinguish between the two cases by having duplicate signal from `emb=?` or `pos=0` inhibit attention
    - H1.2 detects duplicate signal from `emb=?` to inhibit attention (see `(Q=H1.2, K=H0.1)` in the L1 QK embedding figure)
    - H1.0 and H1.1 detect duplicate signal from `pos=0` to inhibit attention (see `(Q=H1.0, K=H0.2)` and `(Q=H1.1, K=H0.1)` in the L1 QK position figure)

I leave this as an incomplete theory for now, but wanted to note it down.
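
To make the L0 side of this argument concrete, here is a toy softmax calculation. It is a sketch under made-up assumptions: the scores (`score_t`, `score_q`, `self_penalty`) are hypothetical numbers chosen for illustration, not values read off the model, and `weight_on_t` is a hypothetical helper.

```python
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def weight_on_t(n_prev_instances, score_t=2.0, score_q=2.0, self_penalty=2.0):
    """Total attention a toy duplicate detector places on instances of `t`.

    Keys are: `?`, each previous instance of `t`, and the current position
    (also `t`), whose score is reduced to mimic the positional bias away
    from the current position.
    """
    scores = [score_q] + [score_t] * n_prev_instances + [score_t - self_penalty]
    probs = softmax(scores)
    return sum(probs[1:])  # everything except `?`

# Case 1: first instance of `t` (no previous instances) -> most weight on `?`.
# Case 2: `t` has appeared before -> most weight on `t`.
print(weight_on_t(0), weight_on_t(1), weight_on_t(2))
```

With these toy numbers, the weight on `t` is below one half in case 1 and above one half in case 2, and it keeps growing with each extra instance. This is the kind of separation the L1 heads would need to read.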

### Summary of Layer 1 QK analysis

- By default, L1 heads *attend away* from tokens in their responsibility set and *attend towards* tokens outside their responsibility set
- This is inverted for duplicate tokens - L1 heads attend towards duplicate tokens in their responsibility set, and away from duplicate tokens outside of their responsibility set
  - This inversion is implemented via K-composition between L0 duplicate detectors and L1 responsible heads

## Layer 1 OV

We want to examine the OV circuits of L1 heads:
$$W_{OV}^{1.h}W_{U}$$

As usual, we can split the residual stream into components (ignoring LayerNorms for simplicity):
$$
\begin{align}
x_{resid} W_{OV}^{1.h}W_{U} &= (emb + pos + h_{0.0} + h_{0.1} + h_{0.2}) W_{OV}^{1.h}W_{U} \\
&= (emb + pos)\left( I + W_{OV}^{0.0} + W_{OV}^{0.1} + W_{OV}^{0.2} \right) W_{OV}^{1.h}W_{U} \\
&\approx (emb)\left( I + W_{OV}^{0.0} + W_{OV}^{0.1} + W_{OV}^{0.2} \right) W_{OV}^{1.h}W_{U}\\
\end{align}
$$

The approximation in the last line comes from the assumption that positional information likely does not play a role in the OV circuit.
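
The decomposition relies only on linearity, which can be checked on a toy example. The matrices below are made-up integers (not model weights), only a single L0 head is included, and attention patterns are ignored just as in the equations above.

```python
def matmul(A, B):
    """Plain-Python matrix product for small nested lists."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def add(A, B):
    """Element-wise sum of two equally shaped nested lists."""
    return [[a + b for a, b in zip(ra, rb)] for ra, rb in zip(A, B)]

# Made-up integer matrices: 2 tokens, d_model = 2, 3 output logits.
emb = [[1, 0], [2, -1]]
OV_0 = [[0, 1], [3, 2]]          # stands in for one W_OV^{0.i}
OV_1 = [[1, -2], [0, 1]]         # stands in for W_OV^{1.h}
W_U = [[1, 0, 2], [-1, 1, 0]]

resid = add(emb, matmul(emb, OV_0))                       # emb (I + W_OV^{0.i})
full = matmul(matmul(resid, OV_1), W_U)                   # everything at once
direct = matmul(matmul(emb, OV_1), W_U)                   # emb W_OV^{1.h} W_U
composed = matmul(matmul(matmul(emb, OV_0), OV_1), W_U)   # V-composition path

# The full circuit is the sum of the direct path and the composed path.
assert full == add(direct, composed)
```

This is why each path can be plotted separately: the direct path and each V-composition path add up to the full circuit.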

### Visualizing Layer 1 OV circuits

From our theory, we expect the following:
- Let $t$ be a one-hot row vector encoding a token.
- Let $a(t)$ be a mapping from tokens to corresponding anti-duplicate detector L0 heads.
- Let $d(t)$ be a mapping from tokens to corresponding duplicate detector L0 heads.
- Let $r(t)$ be a mapping from tokens to corresponding responsible L1 heads.

Then:
- $t W_E W_{OV}^{d(t)} W_{OV}^{r(t)} W_U t^T$ should be very negative
  - If $d(t)$ pays attention to $t$ and propagates it, then $r(t)$ should detect it and decrease the logits for $t$
- $t W_E W_{OV}^{a(t)} W_{OV}^{r(t)} W_U t^T$ should be very positive
  - If $a(t)$ pays attention to $t$ and propagates it, then $r(t)$ should detect it and increase the logits for $t$
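
Since $t$ is one-hot, each of these expressions is just one diagonal entry of the composed matrix, so the predictions are really statements about the signs of diagonals. A minimal sketch with a made-up 3x3 matrix standing in for $W_E W_{OV}^{d(t)} W_{OV}^{r(t)} W_U$ (`one_hot` and `quad_form` are hypothetical helpers):

```python
def one_hot(i, n):
    return [1 if j == i else 0 for j in range(n)]

def quad_form(t_vec, M):
    """Compute t M t^T for a row vector t and a square matrix M."""
    n = len(t_vec)
    return sum(t_vec[i] * M[i][j] * t_vec[j] for i in range(n) for j in range(n))

# Made-up stand-in for a composed circuit (rows: input token, cols: output logit).
M_dup = [[-120, 15, 10],
         [12, -95, 8],
         [9, 11, -130]]

diag = [quad_form(one_hot(i, 3), M_dup) for i in range(3)]
assert diag == [M_dup[i][i] for i in range(3)]  # t M t^T picks out M[i][i]
assert all(d < 0 for d in diag)                 # the "very negative" prediction, in toy form
```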

In [32]:
# visualize layer 1 OV circuits

W_emb = model.W_E
W_U = model.W_U
W_U_scaled = W_U / scale_lnfinal.mean()

component_labels = [None, "OV0.0", "OV0.1", "OV0.2"]
min_max = 150
subplot_titles = [f"({label}, OV{1}.{head1})" if label else f"(OV{1}.{head1})" for head1 in range(model.cfg.n_heads) for label in component_labels]
fig = make_subplots(rows=model.cfg.n_heads, cols=len(component_labels), subplot_titles=subplot_titles)

for head1 in range(model.cfg.n_heads):
    W_OV_1 = model.W_V[1, head1] @ model.W_O[1, head1]
    W_OV_1_scaled = W_OV_1 / scale_ln1_layer1.mean()

    OV_full_emb = W_emb @ W_OV_1_scaled @ W_U_scaled
    fig.add_trace(
        go.Heatmap(z=OV_full_emb.detach().cpu().numpy(),
                   x=dataset.vocab, y=dataset.vocab, coloraxis="coloraxis"),
        row=head1+1, col=1
    )

    for head0 in range(model.cfg.n_heads):
        W_OV_0 = model.W_V[0, head0] @ model.W_O[0, head0]
        W_OV_0_scaled = W_OV_0 / scale_ln1_layer0.mean()
        OV_comp = W_emb @ W_OV_0_scaled @ W_OV_1_scaled @ W_U_scaled
        fig.add_trace(
            go.Heatmap(z=OV_comp.detach().cpu().numpy(),
                       x=dataset.vocab, y=dataset.vocab, coloraxis="coloraxis"),
            row=head1+1, col=head0+2
        )

for r in range(model.cfg.n_heads):
    for c in range(len(component_labels)):
        fig.update_yaxes(autorange="reversed", row=r+1, col=c+1)
        fig.update_xaxes(title_text="Output logit", row=r+1, col=c+1, titlefont=dict(size=12), title_standoff=5)
        fig.update_yaxes(title_text="Input emb", row=r+1, col=c+1, titlefont=dict(size=12), title_standoff=0)

fig.update_layout(
    title="Layer 1 OV circuits",
    height=1000,
    width=1100,
    coloraxis=dict(colorscale='RdBu', cmin=-min_max, cmid=0, cmax=min_max),
)

fig.show()

These OV circuit visualizations confirm our V-composition theory.

Observations:
- Each L1 head's most significant logit contributions, positive or negative, correspond to tokens in its responsibility set
- Whether the contribution is positive or negative depends on whether the L0 head is a duplicate detector or anti-duplicate detector.
  - For example, let's look at `d`
    - `d` has responsible head H1.1, duplicate detector H0.1, and anti-duplicate detector H0.0
      - `(OV0.0, OV1.1)` is the anti-duplicate detector V-composing with the responsible head, and we observe positive logit contribution for `d`
      - `(OV0.1, OV1.1)` is the duplicate detector V-composing with the responsible head, and we observe negative logit contribution for `d`
- Duplicate detection of a token `t` boosts logits for other tokens `t' != t` (including `?`), while anti-duplicate detection of a token `t` decreases logits for other tokens `t' != t` (including `?`)
- The first column of matrices, representing vanilla OV circuits with no L0 V-composition, is also noteworthy
  - It indicates that when a responsible head attends to a token in its responsibility set, it will output a negative logit value for that token.
  - This makes sense, since responsible heads only pay attention to tokens in their responsibility set if the token is a duplicate.

We can get a clearer visualization of what's going on by taking the diagonal values of each matrix.

In [33]:
# visualize layer 1 OV circuit diagonals

W_emb = model.W_E

W_U = model.W_U
W_U_scaled = W_U / scale_lnfinal.mean()

component_labels = ["Emb", "OV0.0", "OV0.1", "OV0.2"]
min_max = 150

subplot_titles = [f"H{1}.{head1}" for head1 in range(model.cfg.n_heads)]
fig = make_subplots(rows=1, cols=model.cfg.n_heads, subplot_titles=subplot_titles)

for head1 in range(model.cfg.n_heads):
    diag_OVs = t.zeros((len(dataset.vocab), len(component_labels)))

    W_OV_1 = model.W_V[1, head1] @ model.W_O[1, head1]
    W_OV_1_scaled = W_OV_1 / scale_ln1_layer1.mean()

    OV_full_emb = W_emb @ W_OV_1_scaled @ W_U_scaled
    diag_OVs[:, 0] = t.diag(OV_full_emb)

    for head0 in range(model.cfg.n_heads):
        W_OV_0 = model.W_V[0, head0] @ model.W_O[0, head0]
        W_OV_0_scaled = W_OV_0 / scale_ln1_layer0.mean()
        OV_comp = W_emb @ W_OV_0_scaled @ W_OV_1_scaled @ W_U_scaled

        diag_OVs[:, head0+1] = t.diag(OV_comp)

    fig.add_trace(
        go.Heatmap(z=diag_OVs.detach().cpu().numpy(),
                   x=component_labels, y=dataset.vocab, coloraxis="coloraxis"),
        row=1, col=head1+1
    )

for h in range(model.cfg.n_heads):
    fig.update_yaxes(autorange="reversed", row=1, col=h+1)

fig.update_layout(
    title="Layer 1 OV composition diagonal values",
    height=500, width=600,
    coloraxis=dict(colorscale='RdBu', cmin=-min_max, cmid=0, cmax=min_max),
)

fig.show()

From this figure, we can very easily read off each L1 head's responsibility set, and see the V-composition with each token's duplicate detector and anti-duplicate detector.

In fact, the norm of each row seems like a good metric for determining which heads are responsible for which tokens (implemented below).
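
As a toy illustration of this metric, consider two heads and two tokens. All numbers below are made up for illustration and `responsible_head` is a hypothetical helper: for each token, take the norm of its row of diagonal values under each L1 head, and pick the head with the largest norm.

```python
import math

# toy_diag[head][token] = [Emb, OV0.0, OV0.1, OV0.2] diagonal values (made-up numbers)
toy_diag = {
    "H1.0": {"a": [-40, 30, 5, -90], "b": [2, 5, 3, 4]},
    "H1.2": {"a": [1, 6, 4, 3], "b": [-35, 25, -80, -60]},
}

def row_norm(values):
    return math.sqrt(sum(v * v for v in values))

def responsible_head(diag, token):
    """Head whose row of diagonal values for `token` has the largest L2 norm."""
    return max(diag, key=lambda head: row_norm(diag[head][token]))

print(responsible_head(toy_diag, "a"), responsible_head(toy_diag, "b"))  # H1.0 H1.2
```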

Also note the diagonal entries for tokens outside of the L1 head's responsibility set: the OV0.0 and OV0.2 diagonals (and even OV0.1 for H1.0) have positive value for tokens outside the responsibility set. These positive values are milder in magnitude, but still seem fairly significant. This suggests a pathway for L1 heads to contribute positive logits for tokens not in their responsibility set (we saw this was the case in our direct logit attribution analysis).

### Computing responsible heads

In [34]:
from collections import defaultdict

def get_L1_responsible_heads(model: HookedTransformer, dataset: UniqueCharDataset):
    tok_to_resp = defaultdict(list)
    resp_to_toks = defaultdict(list)

    W_emb = model.W_E

    W_U = model.W_U
    W_U_scaled = W_U / scale_lnfinal.mean()

    component_labels = ["Emb", "OV0.0", "OV0.1", "OV0.2"]
    min_max = 150

    diag_OVs = t.zeros((model.cfg.n_heads, len(dataset.vocab), len(component_labels)))

    for head1 in range(model.cfg.n_heads):
        W_OV_1 = model.W_V[1, head1] @ model.W_O[1, head1]
        W_OV_1_scaled = W_OV_1 / scale_ln1_layer1.mean()

        OV_full_emb = W_emb @ W_OV_1_scaled @ W_U_scaled
        diag_OVs[head1, :, 0] = t.diag(OV_full_emb)

        for head0 in range(model.cfg.n_heads):
            W_OV_0 = model.W_V[0, head0] @ model.W_O[0, head0]
            W_OV_0_scaled = W_OV_0 / scale_ln1_layer0.mean()
            OV_comp = W_emb @ W_OV_0_scaled @ W_OV_1_scaled @ W_U_scaled

            diag_OVs[head1, :, head0+1] = t.diag(OV_comp)

    # Compute a norm score per L1 head per token by considering the diagonal
    # entries for that token. Responsible head can be understood as the L1 head
    # with the largest score.
    diag_OVs_norm = diag_OVs.norm(dim=-1)
    resp_heads = t.argmax(diag_OVs_norm, dim=0)

    for i, v in enumerate(dataset.vocab[:-1]):
        resp_head = resp_heads[i].item()
        tok_to_resp[i].append(resp_head)
        resp_to_toks[resp_head].append(i)

    return dict(tok_to_resp), dict(resp_to_toks)

In [35]:
tok_to_resp, resp_to_toks = get_L1_responsible_heads(model, dataset)
print("Responsible heads:")
for k, v in sorted(resp_to_toks.items()):
    tokens = ', '.join(dataset.vocab[tok] for tok in v)  # `tok`, to avoid shadowing the torch alias `t`
    print(f"  H1.{k}: {tokens}")


Responsible heads:
  H1.0: a, c
  H1.1: d, e, f, j
  H1.2: b, g, h, i


### Summary of Layer 1 OV analysis

- L1 heads V-compose with L0 duplicate detectors / anti-duplicate detectors corresponding to tokens in their responsibility set
  - Duplicate signal for a token is translated to negative logit contribution while anti-duplicate signal is translated to positive logit contribution
- L1 heads also V-compose with L0 anti-duplicate detectors for tokens outside their responsibility set
  - Anti-duplicate signal is translated to positive logit contribution

## Putting it all together

### Mechanism description

Here is my intuitive description of how the mechanism works:
- L0 heads serve as "duplicate detectors" and "anti-duplicate detectors". These L0 heads can be understood as adding some signal to the residual stream, which will later be read by the L1 heads.
  - Duplicate detectors add signal indicating whether a particular token has been repeated
  - Anti-duplicate detectors add signal consisting of a mix of all tokens that have appeared previously in the sequence, excluding the current token
  - Each L0 head can be understood to function as a duplicate detector for some set of tokens, and an anti-duplicate detector for another set of tokens.
- The L1 heads listen to ("compose with") these signals in order to determine the correct answer.
  - Duplicate detector signal for token `t` is read and translated to a heavy negative logit contribution for `t`
  - Anti-duplicate detector signal for token `t` is read and translated to positive logit contribution for `t`
    - When a token is at an earlier position, there will be more anti-duplicate detector signal for it, since all subsequent non-duplicated positions will attend to it.
    - Thus, with anti-duplicate detector signal translating to positive logit contribution, the earliest tokens should have the largest positive logit contribution (subject to duplicates).

From this description, it is clear that the mechanism relies on two major sub-mechanisms: decreasing logits for duplicates, and increasing logits for previous tokens.
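
These two sub-mechanisms can be caricatured in a few lines. The sketch below is a deliberately crude cartoon, not the trained model: it ignores attention weights, LayerNorm and head assignments, `toy_scores` and `dup_penalty` are hypothetical, and it makes no attempt to handle the cases where the answer is a first-instance token or `?`.

```python
def toy_scores(prefix: str, dup_penalty: float = 100.0) -> dict[str, float]:
    """Cartoon logit score per token, given the sequence seen so far."""
    antidup: dict[str, int] = {}  # token -> accumulated anti-duplicate signal
    dup: dict[str, int] = {}      # token -> accumulated duplicate signal
    seen: list[str] = []
    for ch in prefix:
        if ch == "?":
            continue
        if ch in seen:
            # Repeated token: the duplicate detector fires for it.
            dup[ch] = dup.get(ch, 0) + 1
        else:
            # First instance: its anti-duplicate signal mixes in all earlier tokens,
            # so earlier tokens accumulate more of this signal.
            for prev in seen:
                antidup[prev] = antidup.get(prev, 0) + 1
            seen.append(ch)
    # Anti-duplicate signal adds to the score, duplicate signal heavily subtracts.
    return {ch: antidup.get(ch, 0) - dup_penalty * dup.get(ch, 0) for ch in seen}

scores = toy_scores("?abca")
print(max(scores, key=scores.get))  # b
```

In `?abca`, the duplicated `a` is pushed far down, and `b` beats `c` because it appeared earlier and so collected more anti-duplicate signal.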

As a next step, we visualize the full OV circuits for each of these two mechanisms.

### Full duplicate and anti-duplicate OV circuits

- Let $a(t)$ be a mapping from tokens to corresponding anti-duplicate detector L0 head indices.
- Let $d(t)$ be a mapping from tokens to corresponding duplicate detector L0 head indices.
- Let $r(t)$ be a mapping from tokens to their corresponding responsible L1 head index.

For token $t$:
- The full duplicate OV circuit is given by:
$$W_E \left( \sum_{d \in d(t)} OV_{0.d}\right) OV_{1.r(t)} W_U$$

- The full anti-duplicate OV circuit is given by:
$$W_E \left( \sum_{a \in a(t)} OV_{0.a}\right) \left( \sum_{h \in \{ 0, 1, 2\}} OV_{1.h} \right) W_U$$
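
Note that the choice of heads depends on $t$, so each expression above defines one row per input token rather than a single fixed matrix. Writing $e_t$ for the one-hot row vector of token $t$ and $C_{dup}$, $C_{antidup}$ for the stacked results (notation introduced here for illustration only), my reading is that row $t$ of each circuit is:

$$
\begin{align}
C_{dup}[t, :] &= e_t W_E \left( \sum_{d \in d(t)} OV_{0.d}\right) OV_{1.r(t)} W_U \\
C_{antidup}[t, :] &= e_t W_E \left( \sum_{a \in a(t)} OV_{0.a}\right) \left( \sum_{h \in \{ 0, 1, 2\}} OV_{1.h} \right) W_U
\end{align}
$$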


We can compute and visualize these circuits explicitly:

In [36]:
n_heads = model.cfg.n_heads
d_model = model.cfg.d_model
n_vocab = len(dataset.vocab)

W_E = model.W_E
W_U = model.W_U
W_OV = model.W_V @ model.W_O

scale_factor = scale_ln1_layer0.mean() * scale_ln1_layer1.mean() * scale_lnfinal.mean()

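# allocate on the same device as the model weights rather than assuming CUDA
dup_resp_OV = t.zeros(n_vocab-1, d_model, d_model, device=W_E.device)
antidup_all_OV = t.zeros(n_vocab-1, d_model, d_model, device=W_E.device)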

# get L1 responsible heads
tok_to_resp, _ = get_L1_responsible_heads(model, dataset)

# get L0 duplicate/anti-duplicate detectors
tok_to_dup, _, tok_to_antidup, _ = get_L0_detectors(model, dataset)

for tok, _ in enumerate(dataset.vocab[:-1]):
    resp_head = tok_to_resp[tok][0]
    dup_detectors = tok_to_dup[tok]
    antidup_detectors = tok_to_antidup[tok]

    # compute effective duplicate and anti-duplicate OVs for this token
    dup_OV = W_OV[0, dup_detectors].sum(dim=0)
    antidup_OV = W_OV[0, antidup_detectors].sum(dim=0)

    # compute the responsible head's OV and the summed OV over all L1 heads
    resp_head_OV = W_OV[1, resp_head]
    all_L1_heads_OV = W_OV[1].sum(dim=0)

    # compose L0 and L1 OVs
    dup_resp_OV[tok]     = dup_OV @ resp_head_OV
    antidup_all_OV[tok] = antidup_OV @ all_L1_heads_OV

# plot full duplicate OV circuit
dup_circuit = W_E[:-1] @ dup_resp_OV @ W_U / scale_factor
dup_circuit = dup_circuit[list(range(n_vocab-1)), list(range(n_vocab-1))]
fig = px.imshow(
    dup_circuit.detach().cpu().numpy(),
    title=f"Full duplicate OV circuit",
    labels={"x": "Output (logits)", "y": "Input (embedding)", "color": "Logits"},
    width=500, height=500,
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0,
    x=dataset.vocab,
    y=dataset.vocab[:-1]
)
display(fig)

# plot full anti-duplicate OV circuit
antidup_circuit = W_E[:-1] @ antidup_all_OV @ W_U / scale_factor
antidup_circuit = antidup_circuit[list(range(n_vocab-1)), list(range(n_vocab-1))]
fig = px.imshow(
    antidup_circuit.detach().cpu().numpy(),
    title=f"Full anti-duplicate OV circuit",
    labels={"x": "Output (logits)", "y": "Input (embedding)", "color": "Logits"},
    width=500, height=500,
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0,
    x=dataset.vocab,
    y=dataset.vocab[:-1]
)
display(fig)

We can further split the full anti-duplicate circuit into contributions from responsible heads, and contributions from non-responsible heads:

- The anti-duplicate OV circuit *through responsible heads* is given by:
$$W_E \left( \sum_{a \in a(t)} OV_{0.a}\right) OV_{1.r(t)} W_U$$

- The anti-duplicate OV circuit *through non-responsible heads* is given by:
$$W_E \left( \sum_{a \in a(t)} OV_{0.a}\right) \left( \sum_{nr \neq r(t)} OV_{1.nr} \right) W_U$$

In [37]:
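# allocate on the same device as the model weights rather than assuming CUDA
antidup_resp_head_OV = t.zeros(n_vocab-1, d_model, d_model, device=W_E.device)
antidup_nonresp_head_OV = t.zeros(n_vocab-1, d_model, d_model, device=W_E.device)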

for tok, _ in enumerate(dataset.vocab[:-1]):
    resp_head = tok_to_resp[tok][0]
    antidup_detectors = tok_to_antidup[tok]

    # compute effective anti-duplicate OV for this token
    antidup_OV = W_OV[0, antidup_detectors].sum(dim=0)

    # compute effective responsible head and non-responsible head OVs for this token
    resp_head_OV = W_OV[1, resp_head]
    nonresp_heads = [(resp_head + i) % n_heads for i in range(1, n_heads)]
    nonresp_head_OV = W_OV[1, nonresp_heads].sum(dim=0)

    # compose L0 and L1 OVs
    antidup_resp_head_OV[tok]    = antidup_OV @ resp_head_OV
    antidup_nonresp_head_OV[tok] = antidup_OV @ nonresp_head_OV

# plot full anti-duplicate OV circuit (responsible L1 head)
antidup_resp_circuit = W_E[:-1] @ antidup_resp_head_OV @ W_U / scale_factor
antidup_resp_circuit = antidup_resp_circuit[list(range(n_vocab-1)), list(range(n_vocab-1))]
fig = px.imshow(
    antidup_resp_circuit.detach().cpu().numpy(),
    title=f"Anti-duplicate OV circuit (responsible L1 head)",
    labels={"x": "Output (logits)", "y": "Input (embedding)", "color": "Logits"},
    width=500, height=500,
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0,
    x=dataset.vocab,
    y=dataset.vocab[:-1]
)
display(fig)

# plot full anti-duplicate OV circuit (non-responsible heads)
antidup_nonresp_circuit = W_E[:-1] @ antidup_nonresp_head_OV @ W_U / scale_factor
antidup_nonresp_circuit = antidup_nonresp_circuit[list(range(n_vocab-1)), list(range(n_vocab-1))]
fig = px.imshow(
    antidup_nonresp_circuit.detach().cpu().numpy(),
    title=f"Anti-duplicate OV circuit (non-responsible L1 heads)",
    labels={"x": "Output (logits)", "y": "Input (embedding)", "color": "Logits"},
    width=500, height=500,
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0,
    x=dataset.vocab,
    y=dataset.vocab[:-1]
)
display(fig)

## Further work


The crux of the mechanism is explained above. However, there are still some questions to be explored:

1. How does the mechanism correctly predict "first-instance" tokens?
  - For example, consider the sequence `?aaab`. The correct answer for the last position is `b`, and this is the first instance of `b`.
  - The mechanisms described above do not account for this. The residual in the last position will contain a little bit of duplicate information, and no anti-duplicate information, so any correct positive attribution will somehow have to translate the little bit of duplicate information into positive logits.
  - See [Appendix: Correct first-instance predictions](#scrollTo=LjsY5ajcPf0U) for some initial analysis.

2. How does the mechanism reliably predict `?` when all present tokens are duplicates?
  - Duplicate detection leads to negative logits for the duplicated token and positive logits for all other tokens, including `?`. It's not immediately clear how `?` is privileged here. Perhaps there is some mechanism to bias towards `?` over other tokens to begin with.
  - I think the anti-duplicate OV circuit through non-responsible L1 heads could also play a role here, as when an L1 head detects presence of a non-responsible token, it boosts logits for that token, but also boosts logits for `?`.

3. How do Layer 1 heads reliably not attend to "first instance" tokens in their responsibility set?
  - I explored this briefly (see [A note on how L1 heads handle duplicates](#scrollTo=xUqmWb_NSqB_)), but I don't understand it fully. There is also a gap in my understanding for how H1.0 does not attend to first instances of `c`.

4. How is token position precisely quantified?
  - I have some intuitive understanding that earlier tokens will have more anti-duplicate signal, but it seems like this quantification needs to be fairly precise, for example to distinguish between two adjacent tokens. It would be good to do some quantitative analysis mapping token position to anti-duplicate signal.

5. What is the impact of the L1 QK position circuits?
  - I noticed that L1 heads attend more strongly to duplicate detection signals in earlier positions, while they attend more strongly to anti-duplicate detection signals in later positions. I don't understand why this is the case.

## Appendix

### Position 1

Position 1 has a special mechanism. It's not that difficult a task - the correct prediction at position 1 is always just the token at position 1.

The model achieves this essentially by having H0.0 and H0.2 always self-attend to position 1. The OV circuits of H0.0 and H0.2 contribute positive logits for the correct answer directly. They also V-compose with L1 heads in most cases - H0.0 and H0.2 are primarily used as anti-duplicate detectors, so L1 heads read their signal and output positive logits (there are a few exceptions for tokens where H0.2 is a duplicate detector - in these three cases either the L1 head does not attend to position 1, or the L1 head outputs negative logits that get outweighed by the other positive contributions).

We can visualize the logit contributions and attention patterns for each component on each of the possible starting sequences:

In [84]:
short_toks = t.tensor([[10, i] for i in range(len(dataset.vocab)-1)], dtype=t.int)
short_labels = short_toks
short_toks_str = [['?',f'{dataset.vocab[i]}'] for i in range(len(dataset.vocab)-1)]
seq_len = short_toks.shape[-1]

short_logits, short_cache = model.run_with_cache(short_toks)

head_labels = ['H0.0', 'H0.1', 'H0.2', 'H1.0', 'H1.1', 'H1.2']
attn = t.cat((short_cache["pattern", 0], short_cache["pattern", 1]), dim=1)
attn = einops.rearrange(attn, 'batch n_head seqq seqk -> n_head batch seqq seqk')



correct_W_U_directions = model.W_U.T[short_labels] # (batch, seq, d_model)

emb_pos_components = short_cache.decompose_resid(layer=0, incl_embeds=True)
attn_components    = short_cache.stack_head_results(layer=-1)
components = t.concat((emb_pos_components, attn_components), dim=0)
components = components / short_cache[utils.get_act_name('scale')]
components = components[2:]

correct_logit_contribution_by_components = einops.einsum(
    components,
    correct_W_U_directions,
    'n batch seq d_model, batch seq d_model -> n batch seq')

correct_logit_contributions_mean = t.zeros((len(head_labels), len(dataset.vocab)-1))
pos_1_self_attn = t.zeros((len(head_labels), len(dataset.vocab)-1))

for h, _ in enumerate(head_labels):
    for i, _ in enumerate(dataset.vocab[:-1]): # split examples by token
        correct_logit_contributions = correct_logit_contribution_by_components[h][short_labels == i]
        correct_logit_contributions_mean[h, i] = correct_logit_contributions.mean()
        pos_1_self_attn[h, i] = attn[h, i, 1, 1]

fig = px.imshow(
    correct_logit_contributions_mean.detach().cpu().numpy(),
    title=f"Correct logit attribution, position 1",
    labels={"x": "Token at position 1", "y": "Component", "color": "Logits"},
    width=500, height=500,
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0,
    x=dataset.vocab[:-1],
    y=head_labels
)
display(fig)

fig = px.imshow(
    pos_1_self_attn.detach().cpu().numpy(),
    title=f"Self-attention at position 1 (to and from pos 1)",
    labels={"x": "Token at position 1", "y": "Component", "color": "Self-attention"},
    width=500, height=500,
    color_continuous_scale="RdBu",
    # color_continuous_midpoint=0,
    x=dataset.vocab[:-1],
    y=head_labels
)
display(fig)

We can see that logit contributions for position 1 are usually a combination of {H0.0, H0.2} and {H1.0, H1.1, H1.2}.

We can take a look at the L0 OV circuits, and observe that H0.0 and H0.2 are more-or-less diagonal:

In [39]:
# visualize layer 0 OV circuits

W_emb = model.W_E
W_emb_scaled = W_emb / scale_ln1_layer0.mean()

W_U = model.W_U
W_U_scaled = W_U / scale_lnfinal.mean()

for head in range(model.cfg.n_heads):
    W_OV = model.W_V[0, head] @ model.W_O[0, head]
    OV = W_emb_scaled @ W_OV @ W_U_scaled

    fig = px.imshow(
        OV.detach().cpu().numpy(),
        title=f"Head {0}.{head} OV circuit",
        labels={"x": "Output logit", "y": "Input token", "color": "Logit weight"},
        width=500,
        height=500,
        color_continuous_scale="RdBu",
        color_continuous_midpoint=0,
        x=dataset.vocab,
        y=dataset.vocab
    )

    display(fig)

Observations:
- H0.0 seems to be responsive to `{a, b, d, g, h, i, j}`
- H0.2 seems to be responsive to `{a, b, d, e, f, h, j}`

Since H0.0 and H0.2 attend strongly from position 1 to position 1, they contribute positive logits towards the token at position 1 (when the token is in the head's "responsive set").

The token missing from both H0.0 and H0.2's responsive sets is `c`. H1.0 fills in this gap with a strong correct logit attribution for `c`.
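
This coverage claim is easy to check with set arithmetic. The sets below are copied from the observations above and from H1.0's responsibility set; the variable names are only for illustration.

```python
vocab = set("abcdefghij")            # the ten non-`?` tokens
responsive_H0_0 = set("abdghij")     # from the observations above
responsive_H0_2 = set("abdefhj")
resp_set_H1_0 = set("ac")            # H1.0's responsibility set

# Tokens that neither H0.0 nor H0.2 responds to.
uncovered = vocab - (responsive_H0_0 | responsive_H0_2)
print(sorted(uncovered))  # ['c']

# The gap falls inside H1.0's responsibility set.
assert uncovered <= resp_set_H1_0
```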

### Correct first-instance predictions

The below custom dataset consists of examples of correct first-instance predictions. We display direct logit contributions and attention patterns.

The most confusing part to me is that the responsible head often has the largest positive logit contribution, while not even attending to the first-instance token that it is contributing positive logits towards.

As a concrete example, check out example #2, H1.2 predicting the first-instance `g` correctly. Attention patterns indicate that H1.2 pays no attention to `g` (and therefore has no information about `g`), and yet still outputs significant positive logits for `g`.
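
To help read the label tensors in the next cell, here is a reference implementation of the target as I understand the task: at each position, the label is the first character that has appeared exactly once so far, or `?` if there is none. `first_unique_labels` is a hypothetical helper, not part of the ARENA code.

```python
def first_unique_labels(seq: str) -> list[str]:
    """Label at each position: first char appearing exactly once in seq[:i+1], else '?'."""
    labels = []
    for i in range(len(seq)):
        prefix = seq[: i + 1]
        label = "?"
        for ch in prefix:
            if ch != "?" and prefix.count(ch) == 1:
                label = ch
                break
        labels.append(label)
    return labels

print("".join(first_unique_labels("?aabbccddee")))  # ?a?b?c?d?e?
```

Every non-`?` label in these examples is a first-instance prediction: the predicted token has only just appeared at the current position.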

In [40]:
custom_toks = t.tensor([
    [10, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4],
    [10, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6],
    [10, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8],
    [10, 6, 6, 7, 7, 8, 8, 9, 9, 0, 0],
    [10, 8, 8, 9, 9, 0, 0, 1, 1, 2, 2],
])
custom_toks_str = [
    list("?aabbccddee"),
    list("?ccddeeffgg"),
    list("?eeffgghhii"),
    list("?gghhiijjaa"),
    list("?iijjaabbcc"),
]

custom_toks_labels = t.tensor([
    [10, 0, 10, 1, 10, 2, 10, 3, 10, 4, 10],
    [10, 2, 10, 3, 10, 4, 10, 5, 10, 6, 10],
    [10, 4, 10, 5, 10, 6, 10, 7, 10, 8, 10],
    [10, 6, 10, 7, 10, 8, 10, 9, 10, 0, 10],
    [10, 8, 10, 9, 10, 0, 10, 1, 10, 2, 10],
    ])

logits, cache = model.run_with_cache(custom_toks)

emb_pos_components = cache.decompose_resid(layer=0, incl_embeds=True)
attn_components    = cache.stack_head_results(layer=-1)
components = t.concat((emb_pos_components, attn_components), dim=0)
components = components / cache[utils.get_act_name('scale')]
components = components[-3:]
labels = ['H1.0', 'H1.1', 'H1.2']

logit_contributions_out_by_components = einops.einsum(
    components, model.W_U, 'n batch seq d_model, d_model d_vocab -> batch n seq d_vocab')

attn0 = einops.rearrange(cache["pattern", 0], 'b n sq sk -> n b sq sk')
attn1 = einops.rearrange(cache["pattern", 1], 'b n sq sk -> n b sq sk')
attn = t.cat((attn0, attn1), dim=0)

for ex in range(5):
    fig = plot_logit_att_by_component(logit_contributions_out_by_components[ex],
                                      dataset.vocab,
                                      labels,
                                      custom_toks[ex],
                                      custom_toks_str[ex],
                                      custom_toks_labels[ex],
                                      f'Direct logit attribution for example #{ex}')
    display(fig)

    display(
        cv.attention.attention_patterns(
          tokens=custom_toks_str[ex],
          attention=attn[:, ex],
          attention_head_names=["0.0", "0.1", "0.2", "1.0", "1.1", "1.2"])
    )