<a href="https://colab.research.google.com/github/andyrdt/mi/blob/main/ARENA/monthly_algorithmic_problems/11_2023/Cumulative_Sum.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Monthly Algorithmic Challenge (November 2023): Cumulative Sum

This post is the fifth in the sequence of monthly mechanistic interpretability challenges. They are designed in the spirit of [Stephen Casper's challenges](https://www.lesswrong.com/posts/KSHqLzQscwJnv44T8/eis-vii-a-challenge-for-mechanists), but with the more specific aim of working well in the context of the rest of the ARENA material, and helping people put into practice all the things they've learned so far.


If you prefer, you can access the Streamlit page [here](https://arena-ch1-transformers.streamlit.app/Monthly_Algorithmic_Problems).

<img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/cumsum2.png" width="350">

## Setup

In [1]:
try:
    import google.colab # type: ignore
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

import os; os.environ["ACCELERATE_DISABLE_RICH"] = "1"
import sys

if IN_COLAB:
    # Install packages
    %pip install einops
    %pip install jaxtyping
    %pip install transformer_lens
    %pip install git+https://github.com/callummcdougall/eindex.git
    %pip install git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python

    # Code to download the necessary files (e.g. solutions, test funcs)
    import os, sys
    if not os.path.exists("chapter1_transformers"):
        !curl -o /content/main.zip https://codeload.github.com/callummcdougall/ARENA_2.0/zip/refs/heads/main
        !unzip /content/main.zip 'ARENA_2.0-main/chapter1_transformers/exercises/*'
        sys.path.append("/content/ARENA_2.0-main/chapter1_transformers/exercises")
        os.remove("/content/main.zip")
        os.rename("ARENA_2.0-main/chapter1_transformers", "chapter1_transformers")
        os.rmdir("ARENA_2.0-main")
        os.chdir("chapter1_transformers/exercises")
else:
    from IPython import get_ipython
    ipython = get_ipython()
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")


In [2]:
import torch as t
from pathlib import Path
from eindex import eindex
from transformer_lens import HookedTransformer

# Make sure exercises are in the path
chapter = r"chapter1_transformers"
exercises_dir = Path(f"{os.getcwd().split(chapter)[0]}/{chapter}/exercises").resolve()
section_dir = exercises_dir / "monthly_algorithmic_problems" / "november23_cumsum"
if str(exercises_dir) not in sys.path: sys.path.append(str(exercises_dir))

from monthly_algorithmic_problems.november23_cumsum.dataset import CumsumDataset
from monthly_algorithmic_problems.november23_cumsum.model import create_model
from plotly_utils import hist, bar, imshow

device = t.device("cuda" if t.cuda.is_available() else "cpu")



In [3]:
import functools
import circuitsvis as cv
import einops
import plotly.express as px
import plotly.graph_objects as go

from torch import Tensor
from transformer_lens.hook_points import HookPoint
from jaxtyping import Float
from transformer_lens import utils

## Task & Dataset

The problem for this month is interpreting a model which has been trained to classify the cumulative sum of a sequence.

The model is fed sequences of integers, and is trained to classify the cumulative sum at a given sequence position. There are 3 possible classifications:

* 0 (if the cumsum is negative),
* 1 (if the cumsum is zero),
* 2 (if the cumsum is positive).

Here is an example (and also a demonstration of all the important attributes of the dataset class you'll be using):

In [4]:
dataset = CumsumDataset(size=1, seq_len=6, max_value=3, seed=40)

print(dataset[0]) # same as (dataset.toks[0], dataset.labels[0])

print(", ".join(dataset.str_toks[0])) # inputs to the model

print(", ".join(dataset.str_labels[0])) # sign of the cumsum at each position (neg / zero / pos)

(tensor([ 0,  1, -3, -3, -2,  3]), tensor([1, 2, 0, 0, 0, 0]))
+0, +1, -3, -3, -2, +3
zero, pos, neg, neg, neg, neg
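
The labelling rule can be sketched in a few lines of plain Python. This is an illustrative reimplementation, not the code in `dataset.py`; the helper name `cumsum_labels` is made up here, and the class indices follow the list above (0 = neg, 1 = zero, 2 = pos).

```python
from itertools import accumulate

LABEL_NAMES = ["neg", "zero", "pos"]  # class indices 0, 1, 2


def cumsum_labels(values: list[int]) -> list[int]:
    """Classify the sign of each running total: 0 = neg, 1 = zero, 2 = pos."""
    labels = []
    for total in accumulate(values):
        sign = (total > 0) - (total < 0)  # -1, 0 or +1
        labels.append(sign + 1)
    return labels


example = [0, 1, -3, -3, -2, 3]  # same values as the dataset example above
labels = cumsum_labels(example)
print(labels)                                      # [1, 2, 0, 0, 0, 0]
print(", ".join(LABEL_NAMES[i] for i in labels))   # zero, pos, neg, neg, neg, neg
```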


The relevant files can be found at:

```
chapter1_transformers/
└── exercises/
    └── monthly_algorithmic_problems/
        └── november23_cumsum/
            ├── model.py               # code to create the model
            ├── dataset.py             # code to define the dataset
            ├── training.py            # code to train the model
            └── training_model.ipynb   # actual training script
```

## Model

The model is **not attention only**. It has one attention layer with a single head, and one MLP layer. It does *not* have layernorm at the end of the model. It was trained with weight decay, and an Adam optimizer with linearly decaying learning rate.

You can load the model in as follows. Note that this code is different to previous months, because we've removed the layernorm folding.


In [5]:
filename = section_dir / "cumsum_model.pt"

model = create_model(
    max_value=5,
    seq_len=20,
    seed=0,
    d_model=24,
    d_head=12,
    n_layers=1,
    n_heads=1,
    normalization_type=None,
    d_mlp=8,
)

state_dict = t.load(filename)

state_dict = model.center_writing_weights(state_dict)
state_dict = model.center_unembed(state_dict)
state_dict = model.fold_value_biases(state_dict)
model.load_state_dict(state_dict, strict=False);

> **Important announcement** - a mistake was found in the initial setup of this problem: the dataset tokens could be negative, which caused negative indexing into the embedding matrix. You should use the functions `fix_dataset` and `fix_model` to fix this problem.

In [6]:
def fix_dataset(dataset: CumsumDataset):
    '''
    There was a mistake in the original setup of the problem: some tokens were negative, so they
    were causing negative indexing into the model's embedding matrix.

    This function adds to the tokens so they're all non-negative. In other words, the token indices
    (0, 1, 2, ..., max_value*2) now correspond to the values (-max_value, ..., +max_value) when we
    take the cumulative sum.
    '''
    dataset.toks += dataset.max_value


def fix_model(model: HookedTransformer):
    '''
    There was a mistake in the original setup of the problem: some tokens were negative, so they
    were causing negative indexing into the model's embedding matrix.

    This function rearranges the model's embedding matrix so that it works with the dataset returned
    from 'fix_dataset'. In other words, the rows of the model's embedding matrix now correspond to
    the values (-max_value, ..., +max_value) respectively.
    '''
    max_value = model.W_E.shape[0] // 2
    model.embed.W_E.data = t.concat([model.W_E[-max_value:], model.W_E[:-max_value]])


# Example of this being used (only has to be run once):
N = 1000
dataset = CumsumDataset(size=N, max_value=5, seq_len=20, seed=42).to(device)
fix_dataset(dataset)
fix_model(model)
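
To see why the concatenation in `fix_model` is the right rearrangement, here is a toy check with a plain Python list standing in for the embedding matrix. The names `old_rows` and `new_rows` are illustrative only; the real matrix holds embedding vectors rather than integers.

```python
max_value = 5
vocab_size = 2 * max_value + 1  # 11 rows

# Stand-in for the original embedding matrix: row i is just the integer i.
old_rows = list(range(vocab_size))

# Same rearrangement as fix_model: move the last max_value rows to the front.
new_rows = old_rows[-max_value:] + old_rows[:-max_value]

for value in range(-max_value, max_value + 1):
    # Before the fix, the raw (possibly negative) value indexed the matrix,
    # so negative values wrapped around to the end.
    row_before = old_rows[value]
    # After the fix, the token is value + max_value, which is never negative.
    row_after = new_rows[value + max_value]
    assert row_before == row_after

print(new_rows)  # [6, 7, 8, 9, 10, 0, 1, 2, 3, 4, 5]
```

Every value reaches the same row before and after the fix, so the model's behaviour is unchanged while the token indices become valid.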

A demonstration of the model working:

In [7]:
logits, cache = model.run_with_cache(dataset.toks)

logprobs = logits.log_softmax(-1) # [batch seq_len vocab_out]
probs = logits.softmax(-1)

batch_size, seq_len = dataset.toks.shape
logprobs_correct = eindex(logprobs, dataset.labels, "batch seq [batch seq]")
probs_correct = eindex(probs, dataset.labels, "batch seq [batch seq]")

print(f"Average cross entropy loss: {-logprobs_correct.mean().item():.3f}")
print(f"Mean probability on correct label: {probs_correct.mean():.3f}")
print(f"Median probability on correct label: {probs_correct.median():.3f}")
print(f"Min probability on correct label: {probs_correct.min():.3f}")

Average cross entropy loss: 0.077
Mean probability on correct label: 0.936
Median probability on correct label: 0.999
Min probability on correct label: 0.551
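
For reference, the reported loss is the mean negative log-probability on the correct label, and the other numbers are statistics of that probability. A minimal sketch for a single position, using made-up logits (the values below are not model outputs):

```python
import math


def log_softmax(logits: list[float]) -> list[float]:
    """Numerically stable log-softmax over a list of logits."""
    peak = max(logits)
    log_norm = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return [x - log_norm for x in logits]


logits = [-2.0, 0.5, 3.0]  # hypothetical logits for (neg, zero, pos)
correct_label = 2          # pos

logprobs = log_softmax(logits)
probs = [math.exp(lp) for lp in logprobs]

loss = -logprobs[correct_label]   # cross entropy at this position
p_correct = probs[correct_label]  # probability on the correct label

print(f"loss = {loss:.3f}, p_correct = {p_correct:.3f}")
```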


And a visualisation of its probability output for a single sequence:

In [8]:
def show(dataset: CumsumDataset, batch_idx: int):

    logits = model(dataset.toks[batch_idx].unsqueeze(0)).squeeze() # [seq_len vocab_out]
    probs = logits.softmax(dim=-1) # [seq_len vocab_out]

    imshow(
        probs.T,
        y=dataset.vocab_out,
        x=[f"{s}<br><sub>({j})</sub>" for j, s in enumerate(dataset.str_toks[batch_idx])],
        labels={"x": "Token", "y": "Vocab"},
        xaxis_tickangle=0,
        title=f"Sample model probabilities:<br>{', '.join(dataset.str_toks[batch_idx])}",
        text=[
            ["〇" if (s == target) else "" for target in dataset.str_labels[batch_idx]]
            for s in dataset.vocab_out
        ],
        width=750,
        height=350,
    )

show(dataset, 1)

Note that the model was trained with heavy weight decay, which is why its probabilities are sometimes far from 100% even though its accuracy is close to 100%.


Best of luck! 🎈

# Work starts here

For this month's challenge, I pair programmed with **Ben Wu**. We figured out most of the important pieces of the model together.

## Attention patterns

Since the model is only one layer, each token position must attend directly to all previous token positions and itself in order to determine its cumulative sum.

By visualizing attention patterns, we can see that **attention is spread uniformly across token positions**.

In [9]:
pattern = cache["pattern", 0]

for ex in range(3):

    display(cv.attention.attention_patterns(
        attention=pattern[ex],
        tokens=dataset.str_toks[ex]
    ))



We can verify this by intervening on the forward pass: replace the attention pattern with an exactly uniform causal pattern, and check that the loss is essentially unchanged.

In [10]:
def loss_from_logits(logits: Float[Tensor, 'batch seq vocab_out'], dataset: CumsumDataset):
    logprobs = logits.log_softmax(-1) # [batch seq_len vocab_out]
    logprobs_correct = eindex(logprobs, dataset.labels, "batch seq [batch seq]")

    return -logprobs_correct.mean().item()

In [11]:
def uniform_pattern_hook(pattern: Float[Tensor, 'batch head seq seq'], hook: HookPoint):
    uniform_pattern = t.tril(t.ones_like(pattern))
    uniform_pattern = uniform_pattern / uniform_pattern.sum(-1, keepdim=True)
    pattern[:] = uniform_pattern[:]

    return pattern

In [12]:
logits = model(dataset.toks)

logits_uniform_attn = model.run_with_hooks(
    dataset.toks,
    fwd_hooks=[(utils.get_act_name("pattern", 0), uniform_pattern_hook)],
)

print("\nBASELINE:")
print(f"\tloss: {loss_from_logits(logits, dataset):.5f}")
print("\nUNIFORM ATTENTION PATTERN:")
print(f"\tloss: {loss_from_logits(logits_uniform_attn, dataset):.5f}")



BASELINE:
	loss: 0.07664

UNIFORM ATTENTION PATTERN:
	loss: 0.07663


#### Conclusion

- Attention is spread uniformly across the current token and all previous tokens.
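
One consequence worth spelling out: with uniform causal attention, position `i` averages over positions `0..i`, so what the head writes is proportional to a running mean rather than a running sum. The mean is the sum divided by a positive count, so its sign matches the sign of the cumulative sum, which is all the task requires. A small sketch in plain Python (scalars stand in for the per-token OV outputs; this is not the model's code):

```python
from itertools import accumulate


def uniform_causal_pattern(seq_len: int) -> list[list[float]]:
    """Row i puts weight 1 / (i + 1) on positions 0..i and 0 elsewhere."""
    return [
        [1.0 / (i + 1) if j <= i else 0.0 for j in range(seq_len)]
        for i in range(seq_len)
    ]


def sign(x: float) -> int:
    return (x > 0) - (x < 0)


values = [0, 1, -3, -3, -2, 3]  # the example sequence from the dataset section
pattern = uniform_causal_pattern(len(values))

# Attention output at position i is the weighted average of the values so far.
running_mean = [sum(w * v for w, v in zip(row, values)) for row in pattern]
running_sum = list(accumulate(values))

for mean, total in zip(running_mean, running_sum):
    # Rounding guards against tiny floating point residue when the sum is zero.
    assert sign(round(mean, 9)) == sign(total)
```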

## Direct logit attribution

In [13]:
per_layer_resid, component_labels = cache.decompose_resid(layer=-1, return_labels=True, incl_embeds=False)
dla_by_component = per_layer_resid @ model.W_U



for i, component_label in enumerate(component_labels):
    fig = go.Figure()
    for j, out_label in enumerate(dataset.vocab_out):
        dla = dla_by_component[i][dataset.labels[:, :] == j]

        dla_mean = dla.mean(dim=0).detach().cpu().numpy()
        dla_std  = dla.std(dim=0).detach().cpu().numpy()

        fig.add_trace(go.Bar(
            x=dataset.vocab_out,
            y=dla_mean,
            name=f"{out_label}",
            error_y=dict(
                type='data',
                array=dla_std,
                visible=True
            ),
        ))

    fig.update_layout(barmode='group', title=f'DLA {component_label}, split by correct label', xaxis_title="Vocab out", yaxis_title="DLA", width=800)
    fig.show()

#### Observations

- Both `attn_out` and `mlp_out` contribute directly to outputting the correct answer.
  - `mlp_out` has stronger contributions (higher magnitude)
- For negative sums, they upweight `neg` output logits and downweight `pos` output logits
- For positive sums, they downweight `neg` output logits and upweight `pos` output logits
- For zero sums, they slightly downweight both `neg` and `pos` output logits
  - The `zero` output logit is slightly positive in all cases.
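
The decomposition works because, with no final layernorm, the logits are a linear function of the residual stream, so each component's contribution can be read off separately by multiplying it with $W_U$. A toy check with small integer matrices (all numbers below are invented for illustration, and the embeddings and unembed bias are left out):

```python
def vec_mat(vec: list[int], mat: list[list[int]]) -> list[int]:
    """Compute vec @ mat for a vector of length d and a d x k matrix."""
    return [sum(v * m for v, m in zip(vec, col)) for col in zip(*mat)]


# Hypothetical unembedding: d_model = 4 rows, 3 output classes (neg, zero, pos).
W_U = [
    [-2, 0, 2],
    [1, 1, -1],
    [0, 3, 0],
    [-1, 0, 1],
]

attn_out = [1, 0, 0, 2]  # hypothetical attention-layer output
mlp_out = [3, -1, 0, 1]  # hypothetical MLP output

resid = [a + m for a, m in zip(attn_out, mlp_out)]

dla_attn = vec_mat(attn_out, W_U)
dla_mlp = vec_mat(mlp_out, W_U)
logits = vec_mat(resid, W_U)

# Per-component attributions add up to the full logits.
assert logits == [a + m for a, m in zip(dla_attn, dla_mlp)]
print(dla_attn, dla_mlp, logits)  # [-4, 0, 4] [-8, -1, 8] [-12, -1, 12]
```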



## Check structure of matrices

### $W_E$

We'd hypothesize that $W_E$ is projecting all tokens onto a 1D subspace that functions as a number line.

We can verify that the rows of $W_E$ lie in a 1D subspace by examining its SVD, and verifying that it has primarily one large singular value.

In [14]:
# SVD of W_E
U, S, Vh = t.linalg.svd(model.W_E, full_matrices=False)

In [15]:
bar(S, title="Singular values of W_E", width=600, height=400, xaxis_title="Index", yaxis_title="Value")

In [16]:
# project embeddings onto primary direction

primary_dir = Vh[0]
emb_projs = model.W_E @ primary_dir

bar(emb_projs, x=dataset.vocab, title="Embedding projection onto primary direction", width=600, xaxis_title="Input token", yaxis_title="Projection onto primary direction")

This analysis supports the claim that the embedding matrix projects tokens onto a 1D subspace that functions as a number line.
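
The same check can be tried on a synthetic matrix where the answer is known. The sketch below uses NumPy rather than the notebook's PyTorch calls, and builds a made-up "number line" embedding (an outer product plus small noise), so it illustrates the diagnostic rather than reproducing the model's weights.

```python
import numpy as np

rng = np.random.default_rng(0)
max_value, d_model = 5, 24

values = np.arange(-max_value, max_value + 1, dtype=float)  # -5, ..., +5
direction = rng.normal(size=d_model)
direction /= np.linalg.norm(direction)

# Synthetic embedding: token value times a fixed direction, plus small noise.
noise = 0.01 * rng.normal(size=(values.size, d_model))
W_E_toy = np.outer(values, direction) + noise

U, S, Vh = np.linalg.svd(W_E_toy, full_matrices=False)

# Fraction of squared singular-value mass carried by the first direction.
top_fraction = S[0] ** 2 / np.sum(S ** 2)

# Projecting onto the first right-singular vector recovers the number line,
# up to an overall sign (singular vectors are only defined up to sign).
projection = W_E_toy @ Vh[0]
correlation = np.corrcoef(projection, values)[0, 1]

print(f"top direction carries {top_fraction:.4f} of the squared mass")
print(f"|correlation with token value| = {abs(correlation):.4f}")
```

Because of the sign ambiguity, a projection plot that decreases with token value is just as consistent with a number line as one that increases.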

### $W_U$

In [17]:
# SVD of W_U
U_unemb, S_unemb, Vh_unemb = t.linalg.svd(model.W_U, full_matrices=False)

In [18]:
bar(S_unemb, title="Singular values of W_U", width=600, height=400, xaxis_title="Index", yaxis_title="Value")

In [19]:
proj_W_U = model.W_U.T @ U_unemb[:, :2]
proj_W_U = proj_W_U.detach().cpu().numpy()

fig = px.scatter(x=proj_W_U[:, 0], y=proj_W_U[:, 1], text=dataset.vocab_out)
fig.update_layout(
    width=600, height=500, xaxis_title="1st SVD dir", yaxis_title="2nd SVD dir", title="W_U directions, projected onto SVD directions")
fig.update_traces(textposition='top center', textfont_size=10)
fig.show()

## OV circuit

We'd expect that the OV circuit maps negative values to high logits on `neg`, and positive values to high logits on `pos`.

In [20]:
W_OV = model.W_V[0, 0] @ model.W_O[0, 0]

In [21]:
imshow(
    model.W_E @ W_OV @ model.W_U,
    y=dataset.vocab, x=dataset.vocab_out, title="W_E -> OV -> W_U", xaxis_title="Output token", yaxis_title="Input token", width=500
)

Very clean! And this also explains how DLA works from `attn_out`.

We can further analyze the structure of $W_E W_{OV}$:

In [22]:
U_OV, S_OV, Vh_OV = t.linalg.svd(model.W_E @ W_OV, full_matrices=False)
bar(S_OV, title="Singular values of W_E -> OV", width=500)

In [23]:
emb_projs = model.W_E @ W_OV @ Vh_OV[0]

bar(emb_projs, x=dataset.vocab, title="Embedding projection onto primary direction", width=500)

In [24]:
OV = model.W_E @ model.W_V[0, 0] @ model.W_O[0, 0]

proj_W_U = OV @ U_unemb[:, :2]
proj_W_U = proj_W_U.detach().cpu().numpy()

fig = px.scatter(x=proj_W_U[:, 0], y=proj_W_U[:, 1], text=dataset.vocab)
fig.update_layout(
    width=700, height=500, xaxis_title="1st principal dir", yaxis_title="2nd principal dir", title="W_E -> OV, projected onto principal directions of W_U")
fig.update_traces(textposition='top center', textfont_size=10)
fig.show()


The matrix $W_E W_{OV}$ is approximately rank 1. It essentially serves as a number line, mapping numerical tokens to a point on the line.

Note that this circuit doesn't really do anything helpful when the sum is 0: the resulting vector is simply (approximately) the 0 vector.

## MLPs

### When does each neuron get activated?

In [25]:
fig = go.Figure()

for i, label in enumerate(dataset.vocab_out):
    neuron_values = cache["post", 0][dataset.labels[:, :] == i]

    mean = neuron_values.mean(0).detach().cpu().numpy()
    std = neuron_values.std(0).detach().cpu().numpy()


    fig.add_trace(go.Bar(
        x=list(range(8)),
        y=mean,
        name=f"{label}",
        error_y=dict(
            type='data',
            array=std,
            visible=True
        ),
    ))

fig.update_layout(barmode='group', title='Neuron activations, split by label', xaxis_title="Neuron", yaxis_title="Mean activation", width=800)
fig.show()


We can categorize the neurons as follows:
- **Negative detectors**: 0, 1, 3, 4
- **Positive detectors**: 2, 7
- **Don't really do anything**: 5, 6

Note that when the sum is zero, neurons 0, 1, 2, 3, 4, 7 fire a little bit.
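
One way to picture this split: if a neuron reads a signed scalar off the number line and applies a ReLU-like nonlinearity, a negative input weight makes it fire only for negative sums, and a positive input weight makes it fire only for positive sums. A small positive bias would also give the faint activation at zero. The sketch below is a toy with invented weights; the source doesn't state the model's activation function, so ReLU is an assumption here.

```python
def relu(x: float) -> float:
    return max(0.0, x)


# Hypothetical input weights and biases, not read from the model.
NEG_DETECTOR = {"weight": -1.0, "bias": 0.1}
POS_DETECTOR = {"weight": 1.0, "bias": 0.1}


def activation(neuron: dict, signal: float) -> float:
    """Pre-activation is weight * signal + bias, followed by ReLU."""
    return relu(neuron["weight"] * signal + neuron["bias"])


for signal in (-2.0, 0.0, 2.0):
    neg_act = activation(NEG_DETECTOR, signal)
    pos_act = activation(POS_DETECTOR, signal)
    print(f"signal={signal:+.1f}  neg_detector={neg_act:.1f}  pos_detector={pos_act:.1f}")
```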



### What does each neuron contribute to logits?

In [26]:
signals = model.W_out[0] @ model.W_U
imshow(
    signals.T,
    title="W_out -> W_U", xaxis_title="Neuron", yaxis_title="Output token", y=dataset.vocab_out, width=800, height=300)

In [27]:
proj_W_U = model.W_out[0] @ U_unemb[:, :2]
proj_W_U = proj_W_U.detach().cpu().numpy()

fig = px.scatter(x=proj_W_U[:, 0], y=proj_W_U[:, 1], text=[f'Neuron {i}' for i in range(model.cfg.d_mlp)])
fig.update_layout(
    width=700, height=500, xaxis_title="1st principal dir", yaxis_title="2nd principal dir", title="W_out, projected onto principal directions of W_U")
fig.update_traces(textposition='top center', textfont_size=10)
fig.show()

These results are consistent with our categorization of the neurons:
- Neurons 0, 1, 3, 4 (the "negative detectors") contribute to `neg` output logit direction
- Neurons 2, 7 (the "positive detectors") contribute to `pos` output logit direction

## What's going on with zero?

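So far we understand that the output of `attn_out` represents the sum, on a 1D subspace. (Strictly, because attention is uniform, it is the mean of the tokens so far, which has the same sign as the sum.)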

- When this signal is in the negative direction, `neg` logits are upweighted and `pos` logits are downweighted.
- When this signal is in the positive direction, `neg` logits are downweighted, and `pos` logits are upweighted.
- But what about when the signal is zero (there is no signal)?

### Biases

Basically, when there's no signal, the bias terms come into play. The model has learned biases that, in the absence of a strong `pos`/`neg` signal, push the output logits away from `pos`/`neg` and towards `zero`.

We look at three bias terms (the value bias was folded into $b_O$ by `fold_value_biases` when the model was loaded):
- Attention: $b_O$
- MLP: $b_{in}$, $b_{out}$

In [28]:
biases = t.stack([
    model.b_O[0],
    model.b_in[0] @ model.W_out[0],
    model.b_out[0],
])

bias_labels = ["b_O", "b_in -> W_out", "b_out"]

In [29]:
bias_signal = biases @ model.W_U

imshow(
    bias_signal,
    title="Biases projected onto W_U directions", xaxis_title="Output token", yaxis_title="Bias vector", x=dataset.vocab_out, y=bias_labels, width=400, height=300)

We can see that all three of the bias terms induce a positive bias on the `zero` output logit.

Thus, when the output of the OV circuit is ~the zero vector (which happens when the cumulative sum is ~zero), the model will place the highest logit value on `zero`.
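
Putting the pieces together, here is a toy version of the read-out: a signed signal pushes `neg` and `pos` in opposite directions, while a constant bias favours `zero`. The numbers are invented for illustration and are not the model's weights.

```python
VOCAB_OUT = ["neg", "zero", "pos"]

# Hypothetical read-out: how strongly the signed signal moves each logit,
# and a bias that favours `zero` when there is no signal.
SIGNAL_TO_LOGITS = [-3.0, 0.0, 3.0]
BIAS_LOGITS = [-1.0, 1.0, -1.0]


def predict(signal: float) -> str:
    """Return the output label with the highest logit for a given signal."""
    logits = [w * signal + b for w, b in zip(SIGNAL_TO_LOGITS, BIAS_LOGITS)]
    return VOCAB_OUT[logits.index(max(logits))]


for signal in (-1.0, 0.0, 1.0):
    print(f"signal={signal:+.1f} -> {predict(signal)}")
```

In this toy, the prediction only leaves `zero` once the signal's magnitude exceeds 2/3, which is the point where the signal term outweighs the bias gap.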