<a href="https://colab.research.google.com/github/andyrdt/mi/blob/main/ARENA/monthly_algorithmic_problems/09_2023/Sum_Of_Two_Numbers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Monthly Algorithmic Challenge (September 2023): Sum Of Two Numbers

This post is the third in the sequence of monthly mechanistic interpretability challenges. They are designed in the spirit of [Stephen Casper's challenges](https://www.lesswrong.com/posts/KSHqLzQscwJnv44T8/eis-vii-a-challenge-for-mechanists), but with the more specific aim of working well in the context of the rest of the ARENA material, and helping people put into practice all the things they've learned so far.


If you prefer, you can access the Streamlit page [here](https://arena-ch1-transformers.streamlit.app/Monthly_Algorithmic_Problems).

<img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/machines.png" width="350">

## Setup

In [1]:
%%capture
try:
    import google.colab # type: ignore
    IN_COLAB = True
except:
    IN_COLAB = False

import os; os.environ["ACCELERATE_DISABLE_RICH"] = "1"
import sys

if IN_COLAB:
    # Install packages
    %pip install einops
    %pip install jaxtyping
    %pip install transformer_lens
    %pip install git+https://github.com/callummcdougall/eindex.git
    %pip install git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python

    # Code to download the necessary files (e.g. solutions, test funcs)
    import os, sys
    if not os.path.exists("chapter1_transformers"):
        !curl -o /content/main.zip https://codeload.github.com/callummcdougall/ARENA_2.0/zip/refs/heads/main
        !unzip /content/main.zip 'ARENA_2.0-main/chapter1_transformers/exercises/*'
        sys.path.append("/content/ARENA_2.0-main/chapter1_transformers/exercises")
        os.remove("/content/main.zip")
        os.rename("ARENA_2.0-main/chapter1_transformers", "chapter1_transformers")
        os.rmdir("ARENA_2.0-main")
        os.chdir("chapter1_transformers/exercises")
else:
    from IPython import get_ipython
    ipython = get_ipython()
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

In [2]:
import torch as t
from pathlib import Path

# Make sure exercises are in the path
chapter = r"chapter1_transformers"
exercises_dir = Path(f"{os.getcwd().split(chapter)[0]}/{chapter}/exercises").resolve()
section_dir = exercises_dir / "monthly_algorithmic_problems" / "september23_sum"
if str(exercises_dir) not in sys.path: sys.path.append(str(exercises_dir))

from monthly_algorithmic_problems.september23_sum.dataset import SumDataset
from monthly_algorithmic_problems.september23_sum.model import create_model
from plotly_utils import hist, bar, imshow

device = t.device("cuda" if t.cuda.is_available() else "cpu")



In [3]:
import functools
from typing import List, Tuple, Union, Optional, Callable, Dict

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from tqdm import tqdm
from rich.table import Table, Column
from rich import print as rprint

from torch import Tensor
import torch.nn.functional as F

import einops
import circuitsvis as cv
from jaxtyping import Float, Int, Bool
from transformer_lens import utils, HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
from transformer_lens.components import LayerNorm
from transformer_lens.hook_points import HookPoint


## Task & Dataset

The problem for this month (or at least as much of the month as remains!) is interpreting a model which has been trained to perform simple addition. The model was fed input in the form of a sequence of digits (plus special + and = characters with token ids 10 and 11), and was tasked with predicting the sum of digits one sequence position before they would appear. Cross entropy loss was only applied to these four token positions, so the model's output at other sequence positions is meaningless.
                
<img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/predictions.png" width="600">

All the left-hand numbers are below 5000, so we don't have to worry about carrying past the thousands digit.

Here is an example of what this dataset looks like:

In [4]:
dataset = SumDataset(size=1, num_digits=4, seed=42)

print(dataset[0].tolist()) # tokens, for passing into model
print("".join(dataset.str_toks[0])) # string tokens, for printing

[2, 7, 6, 4, 10, 1, 5, 0, 4, 11, 4, 2, 6, 8]
2764+1504=4268


The relevant files can be found at:

```
chapter1_transformers/
└── exercises/
    └── monthly_algorithmic_problems/
        └── september23_sum/
            ├── model.py               # code to create the model
            ├── dataset.py             # code to define the dataset
            ├── training.py            # code to training the model
            └── training_model.ipynb   # actual training script
```

We've given you the class `SumDataset` to store your data, as you can see above. You can slice this object to get tokens, or use the `str_toks` attribute (a list of lists of strings).

## Model

Our model was trained by minimising cross-entropy loss between its predictions and the true labels, at the four positions of the sum's digits. You can inspect the notebook `training_model.ipynb` to see how it was trained. I used the version of the model which achieved highest accuracy over 100 epochs (accuracy ~100%).



The model is is a 2-layer transformer with 3 attention heads, and causal attention. It includes layernorm, but no MLP layers. You can load it in as follows:

In [5]:
filename = section_dir / "sum_model.pt"

model = create_model(
    num_digits=4,
    seed=0,
    d_model=48,
    d_head=24,
    n_layers=2,
    n_heads=3,
    normalization_type="LN",
    d_mlp=None
)

state_dict = t.load(filename)

state_dict = model.center_writing_weights(t.load(filename))
state_dict = model.center_unembed(state_dict)
state_dict = model.fold_layer_norm(state_dict)
state_dict = model.fold_value_biases(state_dict)
model.load_state_dict(state_dict, strict=False);

The code to process the state dictionary is a bit messy, but it's necessary to make sure the model is easy to work with. For instance, if you inspect the model's parameters, you'll see that `model.ln_final.w` is a vector of 1s, and `model.ln_final.b` is a vector of 0s (because the weight and bias have been folded into the unembedding).

In [6]:
print("ln_final weight: ", model.ln_final.w)
print("\nln_final, bias: ", model.ln_final.b)

ln_final weight:  Parameter containing:
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], device='cuda:0',
       requires_grad=True)

ln_final, bias:  Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       device='cuda:0', requires_grad=True)


<details>
<summary>Aside - the other weight processing parameters</summary>

Here's some more code to verify that our weights processing worked, in other words:

* The unembedding matrix has mean zero over both its input dimension (`d_model`) and output dimension (`d_vocab`)
* All writing weights (i.e. `b_O`, `W_O`, and both embeddings) have mean zero over their output dimension (`d_model`)
* The value biases `b_V` are zero (because these can just be folded into the output biases `b_O`)

```python
W_U_mean_over_input = einops.reduce(model.W_U, "d_model d_vocab -> d_model", "mean")
t.testing.assert_close(W_U_mean_over_input, t.zeros_like(W_U_mean_over_input))

W_U_mean_over_output = einops.reduce(model.W_U, "d_model d_vocab -> d_vocab", "mean")
t.testing.assert_close(W_U_mean_over_output, t.zeros_like(W_U_mean_over_output))

W_O_mean_over_output = einops.reduce(model.W_O, "layer head d_head d_model -> layer head d_head", "mean")
t.testing.assert_close(W_O_mean_over_output, t.zeros_like(W_O_mean_over_output))

b_O_mean_over_output = einops.reduce(model.b_O, "layer d_model -> layer", "mean")
t.testing.assert_close(b_O_mean_over_output, t.zeros_like(b_O_mean_over_output))

W_E_mean_over_output = einops.reduce(model.W_E, "token d_model -> token", "mean")
t.testing.assert_close(W_E_mean_over_output, t.zeros_like(W_E_mean_over_output))

W_pos_mean_over_output = einops.reduce(model.W_pos, "position d_model -> position", "mean")
t.testing.assert_close(W_pos_mean_over_output, t.zeros_like(W_pos_mean_over_output))

b_V = model.b_V
t.testing.assert_close(b_V, t.zeros_like(b_V))
```

</details>


A demonstration of the model working:


In [7]:
from eindex import eindex

dataset = SumDataset(size=1000, num_digits=4, seed=42).to(device)

targets = dataset.toks[:, -4:]

logits, cache = model.run_with_cache(dataset.toks)
logits = logits[:, -5:-1]

logprobs = logits.log_softmax(-1) # [batch seq_len vocab_out]
probs = logprobs.softmax(-1)

logprobs_correct = eindex(logprobs, targets, "batch seq [batch seq]")
probs_correct = eindex(probs, targets, "batch seq [batch seq]")

print(f"Average cross entropy loss: {-logprobs_correct.mean().item():.3f}")
print(f"Mean probability on correct label: {probs_correct.mean():.3f}")
print(f"Median probability on correct label: {probs_correct.median():.3f}")
print(f"Min probability on correct label: {probs_correct.min():.3f}")

Average cross entropy loss: 0.007
Mean probability on correct label: 0.993
Median probability on correct label: 0.996
Min probability on correct label: 0.759


And a visualisation of its probability output for a single sequence:

In [8]:
def show(i):

    imshow(
        probs[i].T,
        y=dataset.vocab,
        x=[f"{dataset.str_toks[i][j]}<br><sub>({j})</sub>" for j in range(9, 13)],
        labels={"x": "Token", "y": "Vocab"},
        xaxis_tickangle=0,
        title=f"Sample model probabilities:<br>{''.join(dataset.str_toks[i])}",
        text=[
            ["〇" if (str_tok == target) else "" for target in dataset.str_toks[i][-4:]]
            for str_tok in dataset.vocab
        ],
        width=400,
        height=550,
    )

show(0)

If you want some guidance on how to get started, I'd recommend reading the solutions for the July problem - I expect there to be a lot of overlap in the best way to tackle these two problems. You can also reuse some of that code!


Best of luck! 🎈

# Andy's work starts here

## Unembedding matrix

We'll start by examining the unembedding matrix. This matrix has a very interesting structure, and gives insight into how the model represents digits (at least towards the end of the model).

In [9]:
fig = px.imshow(
    model.W_U[:, :-2].detach().cpu().numpy(),
    title=f"W_U",
    labels={"x": "vocab", "y": "d_model"},
    width=400, height=400,
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0,
    x=dataset.vocab[:-2],
)
display(fig)

Observations:
- `W_U` is extremely sparse:
  - Only dimensions 19, 25, 26, 37 are significant
- The significant rows resemble cosine waves with the same frequency but different phase


### Fourier analysis

We can extract the magnitude and phase of each significant row by performing a Fourier transform.

In [10]:
significant_dims = [19, 25, 26, 37]

W_U_fft = t.fft.rfft(model.W_U[:, :-2], dim=1)

W_U_fft_magnitude = t.sqrt(W_U_fft.real**2 + W_U_fft.imag**2)
W_U_fft_phase = t.atan2(W_U_fft.imag, W_U_fft.real)

fig = px.imshow(
    W_U_fft_magnitude[significant_dims].detach().cpu().numpy(),
    title=f"W_U frequency magnitudes",
    labels={"x": "Frequency bucket", "y": "Dimension"},
    width=400, height=400,
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0,
    y=[f"dim_{d}" for d in significant_dims],
)
display(fig)

This confirms that each significant dimension primarily has a single frequency: $\omega_1 = \frac{2\pi}{10}$.

We can compute each significant dimension's phase:

In [11]:
fig = go.Figure()
fig.add_trace(go.Scatterpolar(
    r=W_U_fft_magnitude[significant_dims,1].detach().cpu().numpy(),
    theta=W_U_fft_phase[significant_dims,1].detach().cpu().numpy(),
    thetaunit='radians',
    marker=dict(size=5, color='blue'),
    name='Freq Bin 1',
    mode='markers+text',
    textposition="bottom center",
    text=[f"dim_{d}" for d in significant_dims],
))

fig.update_layout(
    title='W_U, phase of significant dimensions',
    font=dict(size=10),
    width=400, height=400
)
fig.update_polars(radialaxis=dict(showticklabels=False))

fig.show()

Observations:
- The phases for `dim_19` and `dim_26` differ by approximately $\pi$, and so they're encoding approximately the same information (one is the negative of the other)
- The phase for `dim_37` differs from `dim_19` / `dim_26` by approximately $\frac{\pi}{2}$, and so they can be thought of as representing a cos and sin of the same input

Each significant dimension $d \in \{ 19, 25, 26, 39\}$ roughly contains the value $$W_U[d, i] \approx A_d \cos \left( i \cdot \frac{2\pi}{10} + \phi_{d}\right)$$
where $A_d$ represents the amplitude and $\phi_d$ the phase offset for dimension $d$.

### Digits as phases

The $i^{th}$ column of `W_U` represents the direction in activation space corresponding to digit $i$.

We can project each direction onto each of the other directions:

In [12]:
W_U_proj = model.W_U[:, :-2].T @ model.W_U[:, :-2]

fig = px.imshow(
    W_U_proj.detach().cpu().numpy(),
    title=f"W_U.T @ W_U",
    labels={"x": "Vocab", "y": "Vocab"},
    width=400, height=400,
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0,
    x=dataset.vocab[:-2],
    y=dataset.vocab[:-2]
)
display(fig)

Now doing a Fourier transform:

In [13]:
W_U_proj_fft = t.fft.rfft(W_U_proj, dim=1)

W_U_proj_fft_magnitude = t.sqrt(W_U_proj_fft.real**2 + W_U_proj_fft.imag**2)
W_U_proj_fft_phase = t.atan2(W_U_proj_fft.imag, W_U_proj_fft.real)

fig = go.Figure()
fig.add_trace(go.Scatterpolar(
    r=W_U_proj_fft_magnitude[:,1].detach().cpu().numpy(),
    theta=W_U_proj_fft_phase[:,1].detach().cpu().numpy(),
    thetaunit='radians',
    marker=dict(size=5, color='blue'),
    mode='markers+text',
    textposition="bottom center",
    text=dataset.vocab[:-2],
))
fig.update_layout(
    title='W_U.T @ W_U',
    font=dict(size=10),
    width=400, height=400
)
fig.update_polars(radialaxis=dict(showticklabels=False))
fig.show()

Each output direction (i.e. each column of `W_U`) can be thought of as an angle. The angle is encoded by the significant dimensions as cos functions with various phases.

### SVD analysis

We can gain similar insight into `W_U` by looking at its SVD.

In [14]:
U, S, Vh = t.svd(model.W_U[:, :-2])

fig1 = px.bar(S.detach().cpu().numpy())
fig1.update_layout(title="W_U singular values",
                  xaxis_title="Index",
                  yaxis_title="Value",
                  showlegend=False,
                  width=450, height=300)
fig1.show()

fig2 = go.Figure()
for i in range(2):
    fig2.add_trace(go.Scatter(x=list(range(model.cfg.d_model)), y=U[:, i].detach().cpu().numpy(), name=f'{i}'))
fig2.update_layout(title="W_U, first 2 principal components",
                  xaxis_title="Index",
                  yaxis_title="Value",
                  width=450, height=300)
fig2.show()

We can see that `W_U` has two very large singular values, roughly corresponding to `dim_37` and `(-dim_19, dim_26)`.

We can project each digit's representation onto these directions:

In [15]:
primary_dirs = U[:, :2] # (d 2)

components = model.W_U[:, :-2].T @ primary_dirs
components = components.detach().cpu().numpy()

fig = go.Figure()
fig.add_trace(go.Scatter(
    x=components[:, 0],
    y=components[:, 1],
    mode="markers+text",
    text=dataset.vocab[:-2],
    textposition="top center",
))
fig.update_layout(
    height=500, width=500,
    title_text='W_U projected onto primary and secondary directions',
    xaxis_title="Component 1",
    yaxis_title="Component 2",
)
fig.show()

Beautiful!

### Summary of W_U analysis

Roughly, we can think of pre-`W_U` activations as angles:
- Dimensions `dim_37` and `(-dim_19, dim_26)` represent the cos and sin values of the angle
- Each column of `W_U` corresponds to cos and sin values for a particular digit/angle
  - `W_U[37,i]` *roughly* corresponds to $\cos(i \cdot \frac{2\pi}{10} + \phi)$
  - `W_U[26,i]` (and  `-W_U[19,i]`) *roughly* corresponds to $\sin(i \cdot \frac{2\pi}{10} + \phi)$

## Ablation analysis

### Head ablations

We'll do some simple head ablations to see if we can narrow down which heads are important.

In [16]:
def head_ablation_hook(
    attn_result: Float[Tensor, "batch seq n_heads d_model"],
    hook: HookPoint,
    ablation_type: str, # either 'zero' or 'mean'
    head_index_to_ablate: Optional[int] = None, # if -1, then ablate all heads at this layer
) -> Float[Tensor, "batch seq n_heads d_model"]:
    if ablation_type == 'zero':
        if head_index_to_ablate == None:
            attn_result[:, :, :, :] = 0
        else:
            attn_result[:,:,head_index_to_ablate,:] = 0
    elif ablation_type == 'mean':
        if head_index_to_ablate == None:
            attn_result[:, :, :, :] = attn_result.mean(0, keepdim=True)
        else:
            attn_result[:,:,head_index_to_ablate,:] = attn_result[:,:,head_index_to_ablate].mean(0, keepdim=True)
    return attn_result

In [17]:
def get_loss(
    logits: Float[Tensor, "batch seq_len d_vocab"],
    labels: Int[Tensor, "batch seq_len"]
) -> Float:
    logprobs = logits.log_softmax(-1) # [batch seq_len d_vocab]
    batch_size, seq_len = labels.shape
    logprobs_correct = logprobs[t.arange(batch_size)[:, None], t.arange(seq_len)[None, :], labels] # [batch seq_len]
    logprobs_correct = eindex(logprobs, labels, "batch seq [batch seq]")
    avg_cross_entropy_loss = -logprobs_correct.mean().item()

    return avg_cross_entropy_loss

def get_loss_at_k(
    logits: Float[Tensor, "batch seq_len d_vocab"],
    labels: Int[Tensor, "batch seq_len"],
    k: Int
) -> Float:
    seq_len = logits.shape[1]
    return get_loss(logits[:, k, :].unsqueeze(1), labels[:, k].unsqueeze(-1))

In [18]:
dataset = SumDataset(size=1000, num_digits=4, seed=42).to(device)
logits, cache = model.run_with_cache(dataset.toks)

model.reset_hooks()
logits_no_ablation = model(dataset.toks, return_type="logits")
loss_no_ablation = get_loss(logits_no_ablation[:, -5:-1], dataset.toks[:, -4:])

for ablation_type in ['mean']: #['zero', 'mean']:
    ablation_scores = t.zeros((model.cfg.n_layers, model.cfg.n_heads))
    ablation_scores_per_pos = t.zeros((model.cfg.n_layers, model.cfg.n_heads, 4))

    for layer in range(model.cfg.n_layers):
        for head in list(range(model.cfg.n_heads)):
            temp_hook_fn = functools.partial(
                head_ablation_hook,
                head_index_to_ablate=head,
                ablation_type=ablation_type)
            ablated_logits = model.run_with_hooks(
                dataset.toks,
                return_type="logits",
                fwd_hooks=[(utils.get_act_name("result", layer), temp_hook_fn)]
            )

            for k in range(4):
                loss_no_ablation_k = get_loss_at_k(logits_no_ablation[:, -5:-1], dataset.toks[:, -4:], k=k)
                ablated_loss_k = get_loss_at_k(ablated_logits[:, -5:-1], dataset.toks[:, -4:], k=k)
                ablation_scores_per_pos[layer, head, k] = ablated_loss_k - loss_no_ablation_k

            ablated_loss = get_loss(ablated_logits[:, -5:-1], dataset.toks[:, -4:])
            ablation_scores[layer, head] = ablated_loss - loss_no_ablation

    # display ablation scores, mean across position
    fig = px.imshow(
        ablation_scores.cpu().numpy(),
        title=f"Ablation loss diff, ablation_type: {ablation_type}",
        labels={"x": "Head", "y": "Layer"},
        text_auto=".2f",
        color_continuous_scale="RdBu",
        color_continuous_midpoint=0,
        x=[f"{head}" for head in range(model.cfg.n_heads)],
        y = [f"{layer}" for layer in range(model.cfg.n_layers)],
        width=600, height=400)
    fig.update_layout(coloraxis_showscale=False)
    display(fig)

    # display ablation scores, per output digit position
    df = pd.DataFrame()
    for layer in range(model.cfg.n_layers):
        for head in range(model.cfg.n_heads):
            df_temp = pd.DataFrame({
                'X': range(4),
                'Y': ablation_scores_per_pos[layer, head],
                'Label': [f'H{layer}.{head}'] * 4
            })
            df = pd.concat([df, df_temp])

    fig = px.scatter(df, x='X', y='Y', color='Label',
                    title=f'Ablation loss diff by digit position, ablation_type: {ablation_type}',
                    labels={'X': 'Digit position', 'Y': 'Logit diff', 'Label': 'Head'},
                    width=600, height=400)
    fig.update_traces(mode='lines+markers')
    display(fig)

model.reset_hooks()

Observations:
- H0.1 and H0.2 appear to be the most critical overall
  - H0.1 is especially critical for digit positions 2, 3
  - H0.2 is especially critical for digit positions 1, 2, 3
- H0.0 seems to be the least critical overall

### QKV ablations

We can perform more fine-grained ablations, targeting specific activations.

In [19]:
def qkv_ablation_hook(
    qkv: Float[Tensor, "batch seq n_heads d_head"],
    hook: HookPoint,
    head_index_to_ablate: int,
    ablation_type: str, # either 'zero' or 'mean'
) -> Float[Tensor, "batch seq n_heads d_model"]:
    if ablation_type == 'zero':
        qkv[:,:,head_index_to_ablate,:] = 0
    elif ablation_type == 'mean':
        qkv[:,:,head_index_to_ablate] = qkv[:,:,head_index_to_ablate].mean(0, keepdim=True)
    return qkv

In [20]:
dataset = SumDataset(size=1000, num_digits=4, seed=42).to(device)
logits, cache = model.run_with_cache(dataset.toks)

model.reset_hooks()
logits_no_ablation = model(dataset.toks, return_type="logits")
loss_no_ablation = get_loss(logits_no_ablation[:, -5:-1], dataset.toks[:, -4:])

activation_types = ['q', 'k', 'v']

for ablation_type in ['mean']: #['zero', 'mean']:
    ablation_scores = t.zeros((len(activation_types), model.cfg.n_layers, model.cfg.n_heads))
    for i, activation_type in enumerate(activation_types):
        for layer in range(model.cfg.n_layers):
            for head in range(model.cfg.n_heads):
                temp_hook_fn = functools.partial(
                    qkv_ablation_hook,
                    head_index_to_ablate=head,
                    ablation_type=ablation_type)
                ablated_logits = model.run_with_hooks(
                    dataset.toks,
                    return_type="logits",
                    fwd_hooks=[(utils.get_act_name(activation_type, layer), temp_hook_fn)]
                )

                ablated_loss = get_loss(ablated_logits[:, -5:-1], dataset.toks[:, -4:])
                ablated_loss_diff = ablated_loss - loss_no_ablation

                ablation_scores[i, layer, head] = ablated_loss_diff

    fig = px.imshow(
        ablation_scores.view(len(activation_types), -1).cpu().numpy(),
        title=f"Ablation loss diff, ablation_type: {ablation_type}",
        labels={"x": "Head", "y": "Ablated activation type"},
        text_auto=".2f",
        color_continuous_scale="RdBu",
        color_continuous_midpoint=0,
        x=[f"H{layer}.{head}" for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)],
        y = [f"{activation_type.upper()}" for activation_type in activation_types],
        width=600, height=500)
    fig.update_layout(coloraxis_showscale=False)
    display(fig)

Observations:
- Mean-ablating queries for L0 heads does not impact loss.
  - This suggests that L0 head queries are approximately the same across all distinct inputs, and are therefore probably determined by positional encodings (and the constant embeddings `+` and `=`).
- H0.1 and H0.2 seem to be the most critical components, with important functionality coming from their values.

### Summary of ablation analysis

- H0.1 and H0.2 are most critical, with most important functionality coming from their values.
- Mean-ablating L0 queries does not impact loss

## Layer 0 QK circuits

### Eyeballing L0 attention patterns

In [21]:
dataset = SumDataset(size=1000, num_digits=4, seed=42).to(device)
logits, cache = model.run_with_cache(dataset.toks)

attn_patterns = cache[utils.get_act_name("pattern", 0)]

for i in range(5):
    display(
        cv.attention.attention_patterns(
          tokens=dataset.str_toks[i],
          attention=attn_patterns[i],
          attention_head_names=[f"{l}.{h}" for l in range(model.cfg.n_layers) for h in range(model.cfg.n_heads)])
    )

Observations:

- There's a common pattern where the $i^{th}$ digit of the output attends to the corresponding $i^{th}$ digits in both inputs. We'll call this attention pattern "CD" for "corresponding digit".

- H0.0: CD for position 0
- H0.1: CD for position 0, 2, 3
- H0.2: CD for position 1, 2, 3

### L0 QK circuit

From our QKV ablation analysis, we know that L0 queries are roughly constant.

We will therefore take the mean L0 queries and use them to compute the L0 QK circuits.

In [22]:
# visualize layer 0 QK circuits
layer = 0

mean_q = cache["q",0].mean(dim=0) # seq head d_head

W_emb_pos = t.concat((model.W_pos, model.W_E), dim=0)
W_emb_pos_scaled = W_emb_pos / cache["scale", 0, "ln1"][:, :, 0, 0].mean()

for head in range(model.cfg.n_heads):
    QK_full = mean_q[:, head, :] @ model.W_K[layer, head].T @ W_emb_pos_scaled.T
    QK_full /= model.cfg.d_head**0.5
    # make pos x pos part lower triangular
    QK_full[:, :14] = t.tril(QK_full[:,:14])

    fig = px.imshow(
        QK_full[:-1, :-1].detach().cpu().numpy(),
        title=f"H{layer}.{head} QK circuit",
        labels={"x": "Key", "y": "Query", "color": "QK weight"},
        width=600,
        height=500,
        color_continuous_scale="RdBu",
        color_continuous_midpoint=0,
        x=[f"p={i}" for i in range(4*3 + 2 - 1)]+dataset.vocab,
        y=[f"p={i}" for i in range(4*3 + 2 - 1)],
    )

    display(fig)

There is some interaction between the positional queries and the embedding keys - some queries are biased towards extreme embedding values, and others are biased away from extreme embedding values.

However, the most significant portion of the L0 QK circuits is the positional encoding portion. Let's zoom in on this portion.

In [23]:
# visualize layer 0 QK circuits
layer = 0

mean_q = cache["q",0].mean(dim=0) # seq head d_head

W_pos = model.W_pos.clone().detach()

# add in constant embeddings +, = at positions 4, 9
W_pos[4, :] = 0.5 * (model.W_pos[4, :] + model.W_E[10, :])
W_pos[9, :] = 0.5 * (model.W_pos[9, :] + model.W_E[11, :])

W_pos_scaled = W_pos / cache["scale", 0, "ln1"][:, :, 0, 0].mean()

for head in range(model.cfg.n_heads):
    QK_pos = mean_q[:, head, :] @ model.W_K[layer, head].T @ W_pos_scaled.T
    QK_pos /= model.cfg.d_head**0.5

    # causal masking
    QK_pos = QK_pos.masked_fill(t.tril(t.ones_like(QK_pos)) == 0, float("-Inf"))

    fig = px.imshow(
        t.softmax(QK_pos, dim=-1)[:-1, :-1].detach().cpu().numpy(),
        title=f"Head {layer}.{head} QK circuit (pos)",
        labels={"x": "Key", "y": "Query", "color": "QK weight"},
        width=400,
        height=400,
        color_continuous_scale="RdBu",
        color_continuous_midpoint=0,
        x=[f"p={i}" for i in range(4*3 + 2 - 1)],
        y=[f"p={i}" for i in range(4*3 + 2 - 1)],
    )

    display(fig)

Observations:
- These figures confirm our claims about corresponding digit detectors:
  - H0.0: CD for position 0
  - H0.1: CD for position 0, 2, 3
  - H0.2: CD for position 1, 2, 3

- There is also some other interesting behavior (possibly used for detecting carries):
  - H0.0: `pos=8` → `pos=3`, `pos=8`
  - H0.0: `pos=11` → `pos=3`, `pos=8`
  - H0.1: `pos=8` → `pos=2`, `pos=7`
  - H0.2: `pos=9` → `pos=1`, `pos=6`

- A reasonable hypothesis would be that corresponding digit detectors compute the sum without taking into account carries, and then somehow the other behavior detects when a carry is necessary
  - To predict output digit 0 (at `pos=9`): H0.0 and H0.1 compute `pos=0 + pos=5`, and H0.2 detects a carry from `pos=1 + pos=6` (and stores this carry info in `pos=9`)
  - To predict output digit 1 (at `pos=10`): H0.2 computes `pos=1 + pos=6`, and H0.0 and H0.1 detect a carry from `pos=2 + pos=7` (and stores this info in `pos=8`)
  - To predict output digit 2 (at `pos=11`): H0.1 and H0.2 compute `pos=2 + pos=7`, and H0.0 detects a carry from `pos=3 + pos=8` (and stores this carry info in `pos=11`)
  - To predict output digit 3 (at `pos=12`): H0.1 and H0.2 compute `pos=3 + pos=8`, and no carry information is required

## Direct logit attribution

### Simple DLA

In [24]:
def get_dla_simple(cache: ActivationCache) -> Tuple[Float[Tensor, "batch n seq vocab"], List[str]]:
    labels = ['embed', 'pos_embed', 'H0.0', 'H0.1', 'H0.2', 'H1.0', 'H1.1', 'H1.2']

    emb_pos_components = cache.decompose_resid(layer=0, incl_embeds=True)
    attn_components    = cache.stack_head_results(layer=-1)
    components = t.concat((emb_pos_components, attn_components), dim=0)

    t.testing.assert_close(components.sum(dim=0)+ model.b_O[0]+ model.b_O[1], cache["resid_post", 1])

    components = components / cache[utils.get_act_name('scale')]
    components = einops.rearrange(components, 'n batch seq d -> batch n seq d')
    dla = components @ model.W_U

    return dla, labels

In [25]:
dataset = SumDataset(size=1000, num_digits=4, seed=42).to(device)
logits, cache = model.run_with_cache(dataset.toks)

dla_simple, labels_simple = get_dla_simple(cache)

current_pos = -5
for current_ex in range(5):
    print()
    print("".join(dataset.str_toks[current_ex]))
    display(
        px.imshow(
            dla_simple[current_ex, :, current_pos, :-2].detach().cpu().numpy(),
            title=f"DLA for ex {current_ex}, pos {current_pos}",
            labels={"x": "Output logit", "y": "Component", "color": "Logit attribution"},
            width=500, height=400,
            color_continuous_scale="RdBu", color_continuous_midpoint=0,
            x=dataset.vocab[:-2], y=labels_simple,
        )
    )


2968+1606=4574



2262+0672=2934



1501+3261=4762



4300+4863=9163



1746+2946=4692


Observations:
- L1 head DLA values dominate over L0 heads, at least for predicting position -5
- It's not always the case that a single head gives the correct answer. It looks like multiple heads work together to steer the answer in the right direction.
  - Looking at example 0, H1.0 steers towards `5` while H1.1 steers towards `3`. When the logits are summed, they give the right answer of `4`.

We can check that L1 heads are dominant across all positions, not just position -5.

In [26]:
dla_simple, labels_simple = get_dla_simple(cache)
dla_simple = dla_simple[:, :, -5:-1, :] # batch n 4 d_vocab
dla_simple = einops.rearrange(dla_simple, 'b n seq d -> n b seq d')
correct_dla_simple = dla_simple[
    t.arange(dla_simple.shape[0])[:, None, None],
    t.arange(dla_simple.shape[1])[None, :, None],
    t.arange(dla_simple.shape[2])[None, None, :],
    dataset.toks[:, -4:]]

correct_dla_simple_mean = correct_dla_simple.mean(dim=1).detach().cpu().numpy()
correct_dla_simple_std = correct_dla_simple.std(dim=1).detach().cpu().numpy()

df_list = []
for i in range(8):
    for j in range(4):
        df_list.append({
            "Component": labels_simple[i],
            "Digit position": j,
            "Mean": correct_dla_simple_mean[i, j],
            "Std": correct_dla_simple_std[i, j]
        })

df = pd.DataFrame(df_list)

fig = px.bar(df,
             x='Digit position',
             y='Mean',
             color='Component',
             barmode='group',
             title='Correct DLA by component (mean)',
             height=400, width=800)
fig.show()

fig = px.bar(df,
             x='Digit position',
             y='Std',
             color='Component',
             barmode='group',
             title='Correct DLA by component (std)',
             height=400, width=800)
fig.show()

### Fine-grained DLA

We can get even more granular, getting the direct logit attributions through each path through the layer 1 heads (inspired by the [August solutions](https://arena-ch1-transformers.streamlit.app/Monthly_Algorithmic_Problems#monthly-algorithmic-challenge-september-2023-sum-of-two-numbers)).

In [27]:
def get_dla(cache: ActivationCache) -> Tuple[Float[Tensor, "batch n seq vocab"], List[str]]:
    resid_decomposed = t.stack([
        cache["embed"] + cache["pos_embed"],
        *[cache["result", 0][:, :, head] for head in range(3)]
    ], dim=1)
    t.testing.assert_close(resid_decomposed.sum(1) + model.b_O[0], cache["resid_post", 0])

    dla = (resid_decomposed / cache["scale"].unsqueeze(1)) @ model.W_U

    # Get DLA from paths through layer 1
    resid_decomposed_post_W_OV = einops.einsum(
        (resid_decomposed / cache["scale", 0, "ln1"][:, None, :, 0]),
        model.W_V[1] @ model.W_O[1],
        "batch decomp seqK d_model, head d_model d_model_out -> batch decomp seqK head d_model_out"
    )
    resid_decomposed_post_attn = einops.einsum(
        resid_decomposed_post_W_OV,
        cache["pattern", 1],
        "batch decomp seqK head d_model, batch head seqQ seqK -> batch decomp seqQ head d_model"
    )
    new_dla = (resid_decomposed_post_attn / cache["scale"][:, None, :, None]) @ model.W_U
    dla = t.concat([
        dla,
        einops.rearrange(new_dla, "batch decomp seq head vocab -> batch (decomp head) seq vocab")
    ], dim=1)

    layer0 = [" "] + [f"0.{i} " for i in range(3)]
    layer1 = [f"1.{i} " for i in range(3)]
    labels = layer0 + [f"{c0}➔ {c1}".lstrip(" ➔ ") for c0 in layer0 for c1 in layer1]
    return dla, labels

In [28]:
dataset = SumDataset(size=1000, num_digits=4, seed=42).to(device)
logits, cache = model.run_with_cache(dataset.toks)

dla, labels = get_dla(cache)

pos = -5
for ex in range(5):
    print()
    print("".join(dataset.str_toks[ex]))
    display(px.imshow(
        dla[ex, :, pos, :-2].detach().cpu().numpy(),
        title=f"Direct logit attribution for ex {ex}, pos {pos}",
        labels={"x": "Output logit", "y": "Component", "color": "Logit attribution"},
        width=500, height=500,
        color_continuous_scale="RdBu", color_continuous_midpoint=0,
        x=dataset.vocab[:-2], y=labels,
    ))


2968+1606=4574



2262+0672=2934



1501+3261=4762



4300+4863=9163



1746+2946=4692


### DLA in Fourier space

We can visualize each component's DLA in Fourier space.

In [29]:
def plot_dla(
    dla: Float[Tensor, "batch n seq vocab"],
    labels: List[str],
    ex: int,
    show_sum: bool,
    show_pred: bool,
    show_correct_dir: bool):

    W_U_dot = model.W_U[:, :-2].T @ model.W_U[:, :-2]
    W_U_fft = t.fft.rfft(W_U_dot, dim=0)
    W_U_magnitude = t.sqrt(W_U_fft.real**2 + W_U_fft.imag**2)
    W_U_phase = t.atan2(W_U_fft.imag, W_U_fft.real)

    fig = make_subplots(
        rows=1, cols=4,
        column_widths=[0.2, 0.2, 0.2, 0.2],
        specs=[[{'type': 'polar'}, {'type': 'polar'}, {'type': 'polar'}, {'type': 'polar'}]],
        subplot_titles=(f"Digit 0 ({dataset.toks[ex, -4]})", f"Digit 1 ({dataset.toks[ex, -3]})", f"Digit 2 ({dataset.toks[ex, -2]})", f"Digit 3 ({dataset.toks[ex, -1]})")
    )

    for pos in range(-5, -1):
        dla_ftt = t.fft.rfft(dla[ex, :, pos, :-2], dim=-1)
        magnitude = t.sqrt(dla_ftt.real**2 + dla_ftt.imag**2)
        phase = t.atan2(dla_ftt.imag, dla_ftt.real)

        fig.add_trace(
            go.Scatterpolar(
                r=magnitude[:, 1].detach().cpu().numpy(),
                theta=phase[:, 1].detach().cpu().numpy(),
                thetaunit='radians',
                marker=dict(size=5, color='blue'),
                mode='markers' if len(labels) > 5 else 'markers+text',
                name=f"DLA {pos+5}",
                textposition="bottom center",
                text=labels,
            ),
            row=1, col=pos+6
        )


        dla_sum_fft = t.fft.rfft(dla[ex, :, pos, :-2].sum(dim=0), dim=-1).unsqueeze(0)
        sum_magnitude = t.sqrt(dla_sum_fft.real**2 + dla_sum_fft.imag**2)
        sum_phase = t.atan2(dla_sum_fft.imag, dla_sum_fft.real)
        if show_sum:
            fig.add_trace(
                go.Scatterpolar(
                    r=np.concatenate([[0], sum_magnitude[:, 1].detach().cpu().numpy()]),
                    theta=np.concatenate([[0], sum_phase[:, 1].detach().cpu().numpy()]),
                    thetaunit='radians',
                    marker=dict(size=5, color='red'),
                    mode='lines',
                    name=f"sum {pos+5}",
                ),
                row=1, col=pos+6
            )

        pred_fft = t.fft.rfft(logits[ex, pos, :-2], dim=-1).unsqueeze(0)
        pred_magnitude = t.sqrt(pred_fft.real**2 + pred_fft.imag**2)
        pred_phase = t.atan2(pred_fft.imag, pred_fft.real)
        if show_pred:
            fig.add_trace(
                go.Scatterpolar(
                    r=np.concatenate([[0], pred_magnitude[:, 1].detach().cpu().numpy()]),
                    theta=np.concatenate([[0], pred_phase[:, 1].detach().cpu().numpy()]),
                    thetaunit='radians',
                    marker=dict(size=5, color='yellow'),
                    mode='lines',
                    name=f"prediction {pos+5}",
                ),
                row=1, col=pos+6
            )

        correct_phase = W_U_phase[1, dataset.toks[ex,pos+1]].detach().cpu().numpy()
        max_rad = max(
            magnitude[:, 1].max().detach().cpu().numpy(),
            pred_magnitude[:, 1].max().detach().cpu().numpy(),
            sum_magnitude[:, 1].max().detach().cpu().numpy(),
        )
        if show_correct_dir:
            fig.add_trace(
                go.Scatterpolar(
                    r=[0, max_rad],
                    theta=[correct_phase, correct_phase],
                    thetaunit='radians',
                    mode='lines',
                    name=f"correct dir {pos+5}",
                    line=dict(width=2, color='green'),
                ),
                row=1, col=pos+6
            )

    fig.update_layout(
        title_text=f'Ex {ex}: {"".join(dataset.str_toks[ex])}',
        height=400,
        width=1200,
    )

    annot =list(fig.layout.annotations)
    for i in range(4): annot[i].y = 1.1
    fig.layout.annotations = annot

    fig.update_polars(radialaxis=dict(showticklabels=False), angularaxis=dict(showticklabels=False))

    fig.show()

In [30]:
logits, cache = model.run_with_cache(dataset.toks)

dla_simple, labels_simple = get_dla_simple(cache)

for ex in range(5):
    plot_dla(dla_simple[:, -3:], labels_simple[-3:], ex, show_sum=False, show_pred=True, show_correct_dir=True)

We can also plot each componenet's DLAs across the entire batch to get a sense of each component's directional trend.

In [31]:
def plot_dla_by_all_path_components(
    dla: Float[Tensor, "batch n seq vocab"],
    labels: List[str],
):
    assert(len(labels)==16)

    W_U_dot = model.W_U[:, :-2].T @ model.W_U[:, :-2]
    W_U_fft = t.fft.rfft(W_U_dot, dim=0)
    W_U_magnitude = t.sqrt(W_U_fft.real**2 + W_U_fft.imag**2)
    W_U_phase = t.atan2(W_U_fft.imag, W_U_fft.real)

    fig = make_subplots(
        rows=5, cols=3,
        column_widths=[0.2, 0.2, 0.2],
        specs=[[{'type': 'polar'}]*3]*5,
        subplot_titles=labels[1:]
    )

    colors = ['blue', 'green', 'red', 'yellow']

    for component in range(1, len(labels)):
        for pos in range(-5, -1):
            dla_ftt = t.fft.rfft(dla[:, component, pos, :-2], dim=-1)
            magnitude = t.sqrt(dla_ftt.real**2 + dla_ftt.imag**2)
            phase = t.atan2(dla_ftt.imag, dla_ftt.real)

            fig.add_trace(
                go.Scatterpolar(
                    r=magnitude[:, 1].detach().cpu().numpy(),
                    theta=phase[:, 1].detach().cpu().numpy(),
                    thetaunit='radians',
                    marker=dict(size=3, color=colors[pos+5]),
                    mode='markers',
                    name=f"{labels[component]} - digit {pos+5}",
                    textposition="bottom center",
                    # showlegend=False
                ),
                col=((component-1)%3+1), row=((component-1)//3 +1)
            )

    fig.update_layout(
        title_text=f'DLAs across batch, by component',
        height=1800,
        width=1400,
    )

    annot =list(fig.layout.annotations)
    fig.layout.annotations = annot

    fig.update_polars(radialaxis=dict(showticklabels=False), angularaxis=dict(showticklabels=False))

    fig.show()

In [32]:
logits, cache = model.run_with_cache(dataset.toks)

dla, labels = get_dla(cache)
plot_dla_by_all_path_components(dla, labels)

Observations;
- Some of the components tend to contribute in a particular direction. For example, the path 0.1 ➔ 1.0 seems to contribute along the 45° direction
- Even for an individual component, DLA directions differ based on digit
- Overall, these charts are quite messy and hard to make sense of

### Summary of DLA analysis

- L1 heads are primarily responsible for direct logit attribution
- In some cases, a single head is solely responsible for the correct logit.
- In other cases, multiple heads issue strong predictions, and the correct answer lies in the middle.

## Layer 1 OV

### Embeddings

We'll take a look at the OV circuits on embedding inputs.

In [33]:
W_OV = model.W_V @ model.W_O

for head0 in range(model.cfg.n_heads):
    for head1 in range(model.cfg.n_heads):

        OV = model.W_E @ W_OV[0, head0] @ W_OV[1, head1] @ model.W_U
        OV = OV[:, :-2]

        fig = make_subplots(
            rows=1, cols=2,
            column_widths=[0.1, 0.2],
            specs=[[{'type': 'heatmap'}, {'type': 'polar'}]],
            subplot_titles=(f"OV circuit", "FFT polar")
        )

        fig.add_trace(
            go.Heatmap(z=OV[:-2, :].detach().cpu().numpy(),
                      colorscale="RdBu",
                      zmid=0, showscale=False),
            row=1, col=1
        )

        OV_fft = t.fft.rfft(OV, dim=1)

        magnitude = t.sqrt(OV_fft.real**2 + OV_fft.imag**2)
        phase = t.atan2(OV_fft.imag, OV_fft.real)

        fig.add_trace(
            go.Scatterpolar(
                r=magnitude[:-2, 1].detach().cpu().numpy(),
                theta=phase[:-2, 1].detach().cpu().numpy(),
                thetaunit='radians',
                marker=dict(size=5, color='blue'),
                name='digits',
                mode='markers+text',
                textposition="bottom center",
                text=dataset.vocab[:-2],
            ),
            row=1, col=2
        )

        fig.update_layout(
            title_text=f'0.{head0} ➔ 1.{head1}',
            height=400,
            width=800,
        )

        annot =list(fig.layout.annotations)
        annot[1].y = 1.1
        fig.layout.annotations = annot

        tick_vals = list(range(len(dataset.vocab)))
        tick_text = dataset.vocab
        fig.update_xaxes(title_text="Output logit", tickvals=tick_vals[:-2], ticktext=tick_text[:-2], row=1, col=1)
        fig.update_yaxes(title_text="Input emb", tickvals=tick_vals[:-2], ticktext=tick_text[:-2], row=1, col=1)
        fig.update_yaxes(autorange="reversed", row=1, col=1)

        fig.update_polars(radialaxis=dict(showticklabels=True))

        fig.show()

### Summed embeddings

Rather than evaluating the OV circuits on each digit embedding individually, we can see how the OV circuits map the sum of each digit embedding pair.

In [34]:
n_digits = model.W_E[:-2].shape[0]
d_model  = model.W_E.shape[1]

W_E_sums = t.zeros((n_digits, n_digits, d_model)).cuda()
for i in range(n_digits):
    for j in range(n_digits):
        W_E_sums[i, j, :] = model.W_E[i] + model.W_E[j]

In [35]:
W_OV = model.W_V @ model.W_O

for head0 in range(model.cfg.n_heads):
    for head1 in range(model.cfg.n_heads):
        W_E_sums_flat = einops.rearrange(W_E_sums, 'x y d -> (x y) d')
        OV = W_E_sums_flat @ W_OV[0, head0] @ W_OV[1, head1] @ model.W_U

        OV = OV[:, :-2]

        fig = make_subplots(
            rows=1, cols=2,
            column_widths=[0.1, 0.2],
            specs=[[{'type': 'heatmap'}, {'type': 'polar'}]],
            subplot_titles=(f"OV circuit", "FFT polar")
        )

        fig.add_trace(
            go.Heatmap(z=OV.detach().cpu().numpy(),
                      colorscale="RdBu",
                      zmid=0, showscale=False),
            row=1, col=1
        )

        OV_fft = t.fft.rfft(OV, dim=1)

        magnitude = t.sqrt(OV_fft.real**2 + OV_fft.imag**2)
        phase = t.atan2(OV_fft.imag, OV_fft.real)

        fig.add_trace(
            go.Scatterpolar(
                r=magnitude[:, 1].detach().cpu().numpy(),
                theta=phase[:, 1].detach().cpu().numpy(),
                thetaunit='radians',
                marker=dict(size=5, color='blue'),
                name='digits',
                mode='markers',
                textposition="bottom center",
                text=[f"{i}+{j}" for i in range(10) for j in range(10)],
            ),
            row=1, col=2
        )

        fig.update_layout(
            title_text=f'0.{head0} ➔ 1.{head1}',
            height=400,
            width=800,
        )

        annot =list(fig.layout.annotations)
        annot[1].y = 1.1
        fig.layout.annotations = annot

        tick_text = [f"{i}+{j}" for i in range(10) for j in range(10)]
        tick_vals = list(range(len(tick_text)))

        fig.update_xaxes(title_text="Output logit", tickvals=list(range(len(dataset.vocab[:-2]))), ticktext=dataset.vocab[:-2], row=1, col=1)
        fig.update_yaxes(title_text="Input emb (sum)", tickvals=tick_vals, ticktext=tick_text, row=1, col=1, showticklabels=False)
        fig.update_yaxes(autorange="reversed", row=1, col=1)

        fig.update_polars(radialaxis=dict(showticklabels=True))

        fig.show()

Observations:
- H0.0 → H1.* seems to map all pairs in similar directions.
- H0.1 → H1.* appears to treat 3 custers differently: both digits <5, one digit <5, no digit <5
- H0.2 → H1.* is the most expressive, with more complex patterns


## Separating carries

The mechanisms for predicting a digit with a carry vs without a carry are presumably different.

We can separate the examples that require carry logic from the examples that do not, and analyze both side-by-side to check how the model deals with carries.

In [36]:
def requires_carry(toks: Int[Tensor, "seq_len"], num_digits: int) -> int:
    for i in range(num_digits):
        if toks[i] + toks[i + num_digits + 1] % 10 != toks[i + 2*num_digits + 2]:
            return True
    return False

def requires_carry_at_pos(toks: Int[Tensor, "seq_len"], num_digits: int, pos: int) -> int:
    if toks[pos] + toks[pos + num_digits + 1] % 10 != toks[pos + 2*num_digits + 2]:
        return True
    return False

In [37]:
dataset = SumDataset(size=1000, num_digits=4, seed=42).to(device)

carries = t.zeros((dataset.size)).bool()
carries_at_pos0 = t.zeros((dataset.size)).bool()
carries_at_pos1 = t.zeros((dataset.size)).bool()
carries_at_pos2 = t.zeros((dataset.size)).bool()

for i in range(dataset.size):
    carries[i] = requires_carry(dataset.toks[i], num_digits=4)
    carries_at_pos0[i] = requires_carry_at_pos(dataset.toks[i], num_digits=4, pos=0)
    carries_at_pos1[i] = requires_carry_at_pos(dataset.toks[i], num_digits=4, pos=1)
    carries_at_pos2[i] = requires_carry_at_pos(dataset.toks[i], num_digits=4, pos=2)

print(f"# of carries: {carries.sum()}")
print(f"# of non-carries: {(~carries).sum()}")
print(f"# of carries at pos 0: {carries_at_pos0.sum()}")
print(f"# of carries at pos 1: {carries_at_pos1.sum()}")
print(f"# of carries at pos 2: {carries_at_pos2.sum()}")

# of carries: 826
# of non-carries: 174
# of carries at pos 0: 515
# of carries at pos 1: 715
# of carries at pos 2: 689


### Ablation

In [38]:
dataset = SumDataset(size=1000, num_digits=4, seed=42).to(device)

activation_types = ['q', 'k', 'v']

for ablation_type in ['mean']: #['zero', 'mean']:
    for custom_dataset_label, dataset_mask in zip(["carries", "non-carries"], [carries, ~carries]):
        model.reset_hooks()
        custom_dataset = dataset[dataset_mask]
        logits_no_ablation = model(custom_dataset, return_type="logits")
        loss_no_ablation = get_loss(logits_no_ablation[:, -5:-1], custom_dataset[:, -4:])

        ablation_scores = t.zeros((len(activation_types), model.cfg.n_layers, model.cfg.n_heads))
        for i, activation_type in enumerate(activation_types):
            for layer in range(model.cfg.n_layers):
                for head in range(model.cfg.n_heads):
                    temp_hook_fn = functools.partial(
                        qkv_ablation_hook,
                        head_index_to_ablate=head,
                        ablation_type=ablation_type)
                    ablated_logits = model.run_with_hooks(
                        custom_dataset,
                        return_type="logits",
                        fwd_hooks=[(utils.get_act_name(activation_type, layer), temp_hook_fn)]
                    )

                    ablated_loss = get_loss(ablated_logits[:, -5:-1], custom_dataset[:, -4:])
                    ablated_loss_diff = ablated_loss - loss_no_ablation

                    ablation_scores[i, layer, head] = ablated_loss_diff

        fig = px.imshow(
            ablation_scores.view(len(activation_types), -1).cpu().numpy(),
            title=f"Ablation loss diff, ablation_type: {ablation_type}, dataset: {custom_dataset_label}",
            labels={"x": "Head", "y": "Ablated activation type"},
            text_auto=".2f",
            color_continuous_scale="RdBu",
            color_continuous_midpoint=0,
            x=[f"H{layer}.{head}" for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)],
            y = [f"{activation_type.upper()}" for activation_type in activation_types],
            width=600, height=500)
        fig.update_layout(coloraxis_showscale=False)
        display(fig)

Observations:
- When evaluated on non-carries, performing mean ablation on L1 keys (and H1.0 values) doesn't impact loss
  - This suggests that the L1 keys are important to the carry functionality, but not much else

### L1 attention

We can check to see if there is a difference in L1 attention patterns between carries and non-carries:

In [39]:
def plot_l1_attention_stats(dataset: SumDataset, mask: Bool[Tensor, "batch"], source_digit: int, label: str):
        logits, cache = model.run_with_cache(dataset[mask])
        l1_attn = cache["pattern", 1]

        l1_attn_from_digit = l1_attn[:, :, source_digit, :] # b h target
        l1_attn_from_digit_mean = l1_attn_from_digit.mean(dim=0)
        l1_attn_from_digit_std = l1_attn_from_digit.std(dim=0)

        data = []
        for head1 in range(model.cfg.n_heads):
            data.extend([{
                'X': i,
                'Y': l1_attn_from_digit_mean[head1, i].detach().cpu().item(),
                'Y_std': l1_attn_from_digit_std[head1, i].detach().cpu().item(),
                'Label': f"H1.{head1}"
            } for i in range(source_digit + 1)])

        df = pd.DataFrame(data)
        fig = px.bar(df, x='X', y='Y', color='Label',
                    title=f'Average L1 attention from position {source_digit}, {label}',
                    labels={'X': 'Target position', 'Y': 'Attention magnitude'},
                    error_y='Y_std',
                    width=600, height=400)
        fig.update_layout(barmode='group')
        display(fig)

In [40]:
dataset = SumDataset(size=1000, num_digits=4, seed=42).to(device)

for carries_mask, digit in zip([carries_at_pos0, carries_at_pos1, carries_at_pos2], [9, 10, 11]):
    for label, mask in zip(["carries", "non-carries"],[carries_mask, ~carries_mask]):
        plot_l1_attention_stats(dataset, mask, digit, label)
plot_l1_attention_stats(dataset, t.ones_like(carries).bool(), 12, "non-carries")

We can see very clear differences in attention patterns between carries and non-carries. This sugggests that the mechanism for separating carry logic from non-carry logic comes from differences in L1 attention patterns.

## Further work

Most of the model is still mysterious. The following are some directions of further investigation.

- Understanding OV circuits for each component
  - We were able to visualize them in the Fourier basis, but they were hard to make sense of.
  - How do the L1 OVs work together to output the correct answer?
- Understanding L0 and L1 attention mechanisms
  - Examining L1 attention on the carry vs non-carry dataset revealed differences in L1 attention
  - What is the mechanism by which this difference in attention is achieved?
  - How do the L0 heads encode carry signal?
  - How does this difference in L1 attention lead to `+1`-ing the output value?

## Appendix

### W_U Fourier analysis by hand

In [41]:
# based off of https://arena-ch1-transformers.streamlit.app/[1.5]_Grokking_and_Modular_Arithmetic
def make_fourier_basis(p: int) -> Tuple[t.Tensor, List[str]]:
    '''
    Returns a pair `fourier_basis, fourier_basis_names`, where `fourier_basis` is
    a `(p, p)` tensor whose rows are Fourier components and `fourier_basis_names`
    is a list of length `p` containing the names of the Fourier components (e.g.
    `["const", "cos 1", "sin 1", ...]`).
    '''
    # Initialize the tensor to store Fourier components
    fourier_basis = t.ones(p, p)

    # Initialize the list to store component names
    fourier_basis_names = ['Const']

    if p % 2 == 0:  # If p is even
        freq_count = p // 2  # Number of distinct frequency components
        for i in range(1, freq_count):
            fourier_basis[2*i-1] = t.cos(2*t.pi*t.arange(p)*i/p)
            fourier_basis[2*i] = t.sin(2*t.pi*t.arange(p)*i/p)
            fourier_basis_names.extend([f'cos {i}', f'sin {i}'])

        fourier_basis[-1] = t.cos(t.pi * t.arange(p))
        fourier_basis_names.append(f'cos {p // 2}')

    else:  # If p is odd
        for i in range(1, (p // 2) + 1):
            fourier_basis[2*i-1] = t.cos(2*t.pi*t.arange(p)*i/p)
            fourier_basis[2*i] = t.sin(2*t.pi*t.arange(p)*i/p)
            fourier_basis_names.extend([f'cos {i}', f'sin {i}'])

    # Normalize vectors, and return them
    fourier_basis /= fourier_basis.norm(dim=1, keepdim=True)
    return fourier_basis.to(device), fourier_basis_names

In [42]:
# transform the rows of W_U to fourier domain
fourier_basis, fourier_basis_names = make_fourier_basis(10)
W_U_fourier = model.W_U[:, :-2] @ fourier_basis.T

fig = px.imshow(
    W_U_fourier.detach().cpu().numpy(),
    title=f"W_U after fourier transform",
    labels={"x": "d_model", "y": "Fourier component"},
    x=fourier_basis_names,
    width=400,
    height=400,
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0,
)
display(fig)

This confirms that there is primarily only one frequency present. This frequency corresponds to $$\omega_{1} = 1 \cdot \frac{2\pi}{10}$$

Then each non-zero row $r$ of `W_U` can each be effectively written as $$r_i = A_r \cos(i \cdot \omega_1) + B_r \sin(i \cdot \omega_1)$$
for $i = 0, 1, \ldots, 9$, where $A_r$ and $B_r$ represent the amplitude of cos and sin components, respectively.

In [43]:
W_U_fourier = model.W_U[:, :-2] @ fourier_basis.T
W_U_fourier_first_freq = t.zeros_like(W_U_fourier)
W_U_fourier_first_freq[:, 1] = W_U_fourier[:, 1]
W_U_fourier_first_freq[:, 2] = W_U_fourier[:, 2]

W_U_reconstructed = W_U_fourier_first_freq @ fourier_basis

fig = px.imshow(
    W_U_reconstructed.detach().cpu().numpy(),
    title=f"W_U (reconstructed from first freq)",
    labels={"x": "d_model", "y": "Fourier component"},
    x=fourier_basis_names,
    width=400,
    height=400,
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0,
)
display(fig)

fig = px.imshow(
    model.W_U[:, :-2].detach().cpu().numpy(),
    title=f"W_U (original)",
    labels={"x": "d_model", "y": "Fourier component"},
    x=fourier_basis_names,
    width=400,
    height=400,
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0,
)
display(fig)