<a href="https://colab.research.google.com/github/andyrdt/mi/blob/main/ARENA/monthly_algorithmic_problems/07_2023/Palindromes.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Monthly Algorithmic Challenge (July 2023): Palindromes

This is the first of what will hopefully be a sequence of monthly mechanistic interpretability challenges. I designed them in the spirit of [Stephen Casper's challenges](https://www.lesswrong.com/posts/KSHqLzQscwJnv44T8/eis-vii-a-challenge-for-mechanists), but with the more specific aim of working well in the context of the rest of the [ARENA material](https://arena-ch1-transformers.streamlit.app/), and helping people put into practice all the things they've learned so far.

If you prefer, you can access the Streamlit page [here](https://arena-ch1-transformers.streamlit.app/Monthly_Algorithmic_Problems#task-dataset).

<img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/zoom.png" width="400">

## Setup

In [1]:
%%capture
try:
    import google.colab # type: ignore
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

import os; os.environ["ACCELERATE_DISABLE_RICH"] = "1"
import sys

if IN_COLAB:
    # Install packages
    %pip install einops
    %pip install jaxtyping
    %pip install transformer_lens
    %pip install git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python

    # Code to download the necessary files (e.g. solutions, test funcs)
    import os, sys
    if not os.path.exists("chapter1_transformers"):
        !wget https://github.com/callummcdougall/ARENA_2.0/archive/refs/heads/main.zip
        !unzip /content/main.zip 'ARENA_2.0-main/chapter1_transformers/exercises/*'
        sys.path.append("/content/ARENA_2.0-main/chapter1_transformers/exercises")
        os.remove("/content/main.zip")
        os.rename("ARENA_2.0-main/chapter1_transformers", "chapter1_transformers")
        os.rmdir("ARENA_2.0-main")
        os.chdir("chapter1_transformers/exercises")
else:
    from IPython import get_ipython
    ipython = get_ipython()
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

In [2]:
import torch as t
from pathlib import Path

# Make sure exercises are in the path
chapter = r"chapter1_transformers"
exercises_dir = Path(f"{os.getcwd().split(chapter)[0]}/{chapter}/exercises").resolve()
section_dir = exercises_dir / "monthly_algorithmic_problems" / "july23_palindromes"
if str(exercises_dir) not in sys.path: sys.path.append(str(exercises_dir))

from monthly_algorithmic_problems.july23_palindromes.dataset import PalindromeDataset, display_seq
from monthly_algorithmic_problems.july23_palindromes.model import create_model
from plotly_utils import hist, bar, imshow

device = t.device("cuda" if t.cuda.is_available() else "cpu")



In [3]:
import os
import sys
import json
import functools
import webbrowser
from pathlib import Path
from functools import partial

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
from plotly.subplots import make_subplots
from tqdm import tqdm
from IPython.display import display
import torch as t
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
import einops
import gdown

from jaxtyping import Float, Int, Bool
from transformer_lens import (utils, ActivationCache, HookedTransformer,
                             HookedTransformerConfig, FactoredMatrix)
from transformer_lens.hook_points import HookPoint
from transformer_lens.components import LayerNorm
import circuitsvis as cv

from typing import List, Tuple, Union, Optional, Callable, Dict


## Task & Dataset

The directory containing all the relevant files is `chapter1_transformers/exercises/monthly_algorithmic_problems/july23_palindromes`. This contains files `model.py` (for defining the model), `training.py` (for training the model), and `dataset.py` (for the dataset of palindromes and non-palindromes).

Each sequence in the dataset looks like:

```
[start_token, a_1, a_2, ..., a_N, end_token]
```

where `start_token = 31`, `end_token = 32`, and each value `a_i` is a value in the range `[0, 30]` inclusive.

Each sequence has a corresponding label, which is `1` if the sequence is a palindrome (i.e. `(a_1, a_2, ..., a_N) == (a_N, ..., a_2, a_1)`), and `0` otherwise. The model has been trained to classify each sequence according to this label.
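
As a concrete reading of this labelling rule, here is a minimal sketch in plain Python. The helper name `palindrome_label` is hypothetical (it is not part of `dataset.py`), and the sequences are shortened to four values purely for readability.

```python
from typing import List

START_TOKEN, END_TOKEN = 31, 32

def palindrome_label(seq: List[int]) -> int:
    """Return 1 if the values between START and END read the same in both directions, else 0."""
    assert seq[0] == START_TOKEN and seq[-1] == END_TOKEN
    values = seq[1:-1]
    return int(values == values[::-1])

pal = [START_TOKEN, 3, 7, 7, 3, END_TOKEN]
non_pal = [START_TOKEN, 3, 7, 8, 3, END_TOKEN]
print(palindrome_label(pal), palindrome_label(non_pal))
```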

We've given you the class `PalindromeDataset` to store your data. You can slice this object to get batches of tokens and labels. You can also use the function `display_seq` to display a sequence in a more readable format (with any tokens that stop it from being a palindrome highlighted). There's an example later on this page.

Some other useful methods and attributes of this dataset (you can inspect `dataset.py` to see for yourself) are:

* `dataset.toks`, to get a batch of all the tokens in the dataset, of shape `(size, 2 * half_length + 2)`.
* `dataset.is_palindrome`, to get a tensor of all the labels in the dataset, of shape `(size,)`.
* `dataset.str_toks`, to get a list of lists, with string representations of each sequence, e.g. `["START", "1", "4", ..., "END"]`. This is useful for visualisation, e.g. circuitsvis.

## Model

Our model was trained by minimising cross-entropy loss between its predictions and the true labels. You can inspect the notebook `training_model.ipynb` to see how it was trained.

The model is a 2-layer transformer with 2 attention heads and causal attention. It includes layernorm, but no MLP layers. It was trained to predict the palindrome label at the `[END]` token for each sequence. You can load it in as follows:

In [4]:
filename = section_dir / "palindrome_classifier.pt"

model = create_model(
    half_length=10, # this is half the length of the palindrome sequences
    max_value=30, # values in palindrome sequence are between 0 and max_value inclusive
    seed=42,
    d_model=28,
    d_head=14,
    n_heads=2,
    normalization_type="LN",
    d_mlp=None, # this is an attn-only model
    device=device
)

state_dict = t.load(filename, map_location=device)

state_dict = model.center_writing_weights(state_dict)
state_dict = model.center_unembed(state_dict)
state_dict = model.fold_layer_norm(state_dict)
state_dict = model.fold_value_biases(state_dict)
model.load_state_dict(state_dict, strict=False);

The code to process the state dictionary is a bit messy, but it's necessary to make sure the model is easy to work with. For instance, if you inspect the model's parameters, you'll see that `model.ln_final.w` is a vector of 1s, and `model.ln_final.b` is a vector of 0s (because the weight and bias have been folded into the unembedding).

In [5]:
print("ln_final weight: ", model.ln_final.w)
print("\nln_final, bias: ", model.ln_final.b)

print("ln block 0 weight: ", model.blocks[0].ln1.w)
print("ln block 0 bias: ", model.blocks[0].ln1.b)

print("ln block 1 weight: ", model.blocks[1].ln1.w)
print("ln block 1 bias: ", model.blocks[1].ln1.b)

ln_final weight:  Parameter containing:
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], device='cuda:0',
       requires_grad=True)

ln_final, bias:  Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0.], device='cuda:0', requires_grad=True)
ln block 0 weight:  Parameter containing:
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], device='cuda:0',
       requires_grad=True)
ln block 0 bias:  Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0.], device='cuda:0', requires_grad=True)
ln block 1 weight:  Parameter containing:
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1.

<details>
<summary>Aside - the other weight processing parameters</summary>

Here's some more code to verify that our weights processing worked, in other words:

* The unembedding matrix has mean zero over both its input dimension (`d_model`) and output dimension (`d_vocab`)
* All writing weights (i.e. `b_O`, `W_O`, and both embeddings) have mean zero over their output dimension (`d_model`)
* The value biases `b_V` are zero (because these can just be folded into the output biases `b_O`)

```python
W_U_mean_over_input = einops.reduce(model.W_U, "d_model d_vocab -> d_model", "mean")
t.testing.assert_close(W_U_mean_over_input, t.zeros_like(W_U_mean_over_input))

W_U_mean_over_output = einops.reduce(model.W_U, "d_model d_vocab -> d_vocab", "mean")
t.testing.assert_close(W_U_mean_over_output, t.zeros_like(W_U_mean_over_output))

W_O_mean_over_output = einops.reduce(model.W_O, "layer head d_head d_model -> layer head d_head", "mean")
t.testing.assert_close(W_O_mean_over_output, t.zeros_like(W_O_mean_over_output))

b_O_mean_over_output = einops.reduce(model.b_O, "layer d_model -> layer", "mean")
t.testing.assert_close(b_O_mean_over_output, t.zeros_like(b_O_mean_over_output))

W_E_mean_over_output = einops.reduce(model.W_E, "token d_model -> token", "mean")
t.testing.assert_close(W_E_mean_over_output, t.zeros_like(W_E_mean_over_output))

W_pos_mean_over_output = einops.reduce(model.W_pos, "position d_model -> position", "mean")
t.testing.assert_close(W_pos_mean_over_output, t.zeros_like(W_pos_mean_over_output))

b_V = model.b_V
t.testing.assert_close(b_V, t.zeros_like(b_V))
```

</details>

The model was trained to output the correct classification at the `END` token. In other words, the value of the residual stream at `END` (post-layernorm) is mapped through `model.W_U`, which has shape `(d_model, 2)`, and this gives us our classification logits for `(not palindrome, palindrome)`.

A demonstration of the model working (and of the `display_seq` function):

In [6]:
dataset = PalindromeDataset(size=100, max_value=30, half_length=10)

toks, is_palindrome = dataset[:5]

logits = model(toks)[:, -1]
probs = logits.softmax(-1)
probs_palindrome = probs[:, 1]

for tok, prob in zip(toks, probs_palindrome):
    display_seq(tok, prob)

<details>
<summary>Click on this dropdown for a hint on how to start (and some example code).</summary>

The following code will display the attention patterns for each head, on a particular example.

```python
display_seq(dataset.toks[batch_idx], probs_palindrome[batch_idx])

import circuitsvis as cv

cv.attention.attention_patterns(
    attention = t.concat([cache["pattern", layer][batch_idx] for layer in range(model.cfg.n_layers)]),
    tokens = dataset.str_toks[batch_idx],
    attention_head_names = [f"{layer}.{head}" for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)],
)
```

Find (1) a palindromic example, and (2) a non-palindromic example which is close to being palindromic (i.e. only 1 or 2 tokens are different). Then, compare the attention patterns for these two examples. Questions you might want to answer:

* How do the attention patterns for numbers which are palindromic (i.e. they are the same as their mirror image) differ from the numbers which aren't?
* How does information eventually get to the `[END]` token?

</details>

Note - although this model was trained for long enough to get loss close to zero (you can test this for yourself), it's not perfect. There are some weaknesses that the model has which make it vulnerable to adversarial examples, which I've decided to leave in as a fun extra challenge! Note that the model is still very good at its intended task, and the main focus of this challenge is on figuring out how it solves the task, not dissecting the situations where it fails. However, you might find that the adversarial examples help you understand the model better.

Best of luck! 🎈

# Andy's work begins here

## Summary of findings

### Mechanism overview

I'll give some intuitive explanations for how I think the mechanism works.

**Layer 0 Head 0 (H0.0)**
- QK circuit
  - The QK circuit serves two functions - one to do with embeddings and another to do with positions.
    - Embedding: Sequence position $i$ attends to a previous position $j$ if their token embeddings are the same.
    - Position: For $i \geq 11$, position $i$ attends to its "palindromic corresponding" position $j$ (i.e. 11 attends to 10, 12 attends to 9, and so on).
- OV circuit
  - A value's output encodes whether the originating position corresponds to the first half, or second half of the sequence.
- QK and OV in combination
  - Consider some token in the second half of the sequence.
    - If the input is a palindrome, the token will attend strongly to its corresponding token in the first half, and thus the resulting output value will strongly reflect the "first half direction."
    - If the input is not a palindrome, then the token should not attend strongly to any token in the first half. It will always attend strongly to itself, and thus its resulting value will strongly reflect the "second half direction."
  - Thus, a position with output in the "second half direction" indicates that the palindrome property has been violated at that position, and is thus a "problematic token."
    - This is the signal that H1.0 will query for in the next layer, and the signal that it will propagate.

**Layer 1 Head 0 (H1.0)**
- QK circuit
  - The last sequence position attends to keys that indicate a "problematic token," or a token that H0.0 has discovered breaks the palindrome property (it has no corresponding token in the first half).
  - H0.0 and H1.0 interact through K-composition:
    - The value from H0.0 is used as input for the key. This value from H0.0 indicates whether a token is "problematic."
    - H1.0 queries for problematic tokens, and pays attention to them.
- OV circuit
  - H1.0 projects attended-to-values to a "palindrome-or-not" direction.
  - H0.0 and H1.0 interact through V-composition:
    - The value from H0.0 is used as input for the value. This value from H0.0 indicates whether a token is "problematic."
    - H1.0 maps this value to a direction which shifts the logit difference in the "non-palindrome" direction.

**Other heads**
- Ablation testing showed that H1.1 is not useful at all.
- Ablation testing suggested H0.1 is doing *something*, but I leave it to further work to find exactly *what* it is doing.
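
Read as an algorithm, the hypothesised circuit behaves roughly like the sketch below. This is an idealised, hand-written analogue in plain Python, not code extracted from the model. The function names are hypothetical, the hard `True`/`False` flags stand in for what are really soft attention weights and residual-stream directions, and H0.1 and H1.1 are left out entirely. The input is the list of values without the START and END tokens.

```python
from typing import List

def h0_0_flags(values: List[int]) -> List[bool]:
    """Idealised H0.0: flag each second-half position whose mirrored first-half token differs."""
    n = len(values)
    flags = [False] * n
    for i in range(n // 2, n):
        mirror = n - 1 - i
        flags[i] = values[i] != values[mirror]
    return flags

def h1_0_classify(flags: List[bool]) -> int:
    """Idealised H1.0: END attends to flagged positions; any flag means 'not a palindrome' (label 0)."""
    return 0 if any(flags) else 1

values = [4, 9, 2, 2, 8, 4]  # made-up example: values[4] = 8 does not match values[1] = 9
flags = h0_0_flags(values)
print(flags, h1_0_classify(flags))
```

In this picture, a "problematic token" is simply a position where `flags[i]` is `True`.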


## H0.0 QK Circuit

### Observations

From just observing the attention pattern of H0.0 on a few examples, we can immediately notice an interesting pattern.

It seems as though there are high attention scores between repeated tokens. That is, the attention score $A_{i,j}^{0.0}$ is high if the tokens in the $i^{th}$ and $j^{th}$ position are the same.

Notice this pattern in the subsequent examples:

In [7]:
dataset = PalindromeDataset(size=100, max_value=30, half_length=10)

toks, is_palindrome = dataset[:100]
logits, cache = model.run_with_cache(toks[:5])
probs = logits[:, -1].softmax(-1)[:, 1]

layer = 0
head  = 0

for batch_idx in [1, 2, 3]:
  display_seq(dataset.toks[batch_idx], probs[batch_idx])

  display(cv.attention.attention_patterns(
      attention = cache["pattern", layer][batch_idx][head].unsqueeze(0),
      tokens = dataset.str_toks[batch_idx],
      attention_head_names = [f"{layer}.{head}"],
  ))

### The embed circuit

Let's examine the circuit $W_{emb} W_{QK}^{0.0} W_{emb}^T$.

If we have two tokens $a$ and $b$ (represented as one-hot encoding column vectors), this matrix determines the attention in H0.0 between these tokens:
$$a^T W_{emb} W_{QK}^{0.0} W_{emb}^T b$$

We expect this circuit to approximate the identity matrix, as the attention score seems to be high exactly when $a=b$.

Let's compute the circuit explicitly, and check that it resembles the identity matrix:

In [8]:
def mask_scores(attn_scores):
    '''
    Helper function to apply a causal mask to the attention scores,
    so that tokens can't attend to later tokens.
    '''
    mask = t.tril(t.ones_like(attn_scores)).bool()
    neg_inf = t.tensor(-1.0e6).to(attn_scores.device)
    masked_attn_scores = t.where(mask, attn_scores, neg_inf)
    return masked_attn_scores

In [9]:
W_Q_0_0 = model.W_Q[0, 0] # (d_model, d_head)
W_K_0_0 = model.W_K[0, 0] # (d_model, d_head)

W_E = model.W_E # (vocab_size, d_model)

# approximate the L0 LayerNorm as a constant scaling factor
W_E_scaled = W_E / cache["scale", 0, "ln1"].mean()

W_QK = W_Q_0_0 @ W_K_0_0.T

token_by_token = W_E_scaled @ W_QK @ W_E_scaled.T # (vocab_size, vocab_size)
token_by_token = token_by_token / model.cfg.d_head**0.5

circuit_QK_0_0 = t.softmax(token_by_token, dim=-1)

display(
    imshow(
      circuit_QK_0_0,
      title=f"Head 0.0 QK circuit (embed)",
      labels={"x": "Vocab", "y": "Vocab", "color": "QK weight"},
      width=700,
      height=700
    )
)


In [10]:
def top_1_acc(matrix) -> float:
    '''
    Returns the fraction of rows whose largest entry lies on the diagonal.
    '''
    num_rows = matrix.shape[0]
    top_diag_count = 0.0
    for i in range(num_rows):
        # use argmax rather than unpacking into a variable named `max`,
        # which would shadow the built-in
        max_index = matrix[i].argmax(-1)
        if max_index.item() == i:
            top_diag_count += 1.0

    return top_diag_count / num_rows

In [11]:
# we don't care about start or end tokens
effective_circuit_QK_0_0 = circuit_QK_0_0[:-2]
print(f"proportion of rows with diag element as heaviest: {top_1_acc(effective_circuit_QK_0_0)}")
print("(not including <start> or <end> tokens)")

proportion of rows with diag element as heaviest: 1.0
(not including <start> or <end> tokens)


We computed the $W_{emb} W_{QK}^{0.0} W_{emb}^T$ matrix and performed some analysis on it to support the claim that it is approximately diagonal.

Its plot (after a row-wise softmax) shows that the diagonal entries are the heaviest. The `top_1_acc` measurement shows that every row, excluding those corresponding to the `start` and `end` tokens, has the diagonal element as its heaviest element.
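
To make clear what this measurement does and does not establish, here is a small stand-alone sketch in plain Python. The function `diag_top_1_fraction` is a hypothetical list-based analogue of `top_1_acc`, and the matrices are made up for illustration.

```python
from typing import List

def diag_top_1_fraction(matrix: List[List[float]]) -> float:
    """Fraction of rows whose largest entry sits on the diagonal."""
    hits = 0
    for i, row in enumerate(matrix):
        argmax = max(range(len(row)), key=lambda j: row[j])
        hits += int(argmax == i)
    return hits / len(matrix)

clearly_diagonal = [[0.90, 0.05, 0.05], [0.05, 0.90, 0.05], [0.05, 0.05, 0.90]]
barely_diagonal = [[0.34, 0.33, 0.33], [0.33, 0.34, 0.33], [0.33, 0.33, 0.34]]
one_row_off = [[0.90, 0.05, 0.05], [0.60, 0.30, 0.10], [0.05, 0.05, 0.90]]

print(diag_top_1_fraction(clearly_diagonal))
print(diag_top_1_fraction(barely_diagonal))
print(diag_top_1_fraction(one_row_off))
```

The first two matrices score the same, so the metric only checks the position of each row's maximum, not its margin. That is why it is read here together with the plot rather than on its own.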

Let's next explore how H0.0's QK circuit attends to positional information.

### The position circuit

Let's examine the circuit $W_{pos} W_{QK}^{0.0} W_{pos}^T$.

If we have two positions $i$ and $j$ (represented as one-hot encoding column vectors), this matrix determines the attention in H0.0 between these positions:
$$i^T W_{pos} W_{QK}^{0.0} W_{pos}^T j$$

Starting with $i = 11$, we expect the $i^{th}$ token to attend to its palindromic corresponding position (11 to 10, 12 to 9, and so on). This will give us a matrix that looks sort of like a lambda $\lambda$.

Let's compute the circuit explicitly, and check that it resembles a lambda:

In [12]:
W_Q_0_0 = model.W_Q[0, 0] # (d_model, d_head)
W_K_0_0 = model.W_K[0, 0] # (d_model, d_head)

W_pos = model.W_pos # (seq_len, d_model)

# approximate the L0 LayerNorm as a constant scaling factor
W_pos_scaled = W_pos / cache["scale", 0, "ln1"].mean()

W_QK = W_Q_0_0 @ W_K_0_0.T

pos_by_pos = W_pos_scaled @ W_QK @ W_pos_scaled.T # (seq_len, seq_len)
pos_by_pos = pos_by_pos / model.cfg.d_head**0.5
masked = mask_scores(pos_by_pos)

circuit_pos_QK_0_0 = t.softmax(masked, dim=-1)

display(
    imshow(
      circuit_pos_QK_0_0,
      title=f"Head 0.0 QK circuit (pos)",
      labels={"x": "Position", "y": "Position", "color": "QK weight"},
      width=700,
      height=700
    )
)


We can roughly see the lambda $\lambda$ shape.

It's also notable that for *most* of the palindromic positional pairs `[(11, 10), (12, 9), ...]`, more attention is placed on the earlier position (although I see notable exceptions for `(19, 2)` and `(17, 4)`).
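
The palindromic pairs follow directly from the sequence layout: with `START` at position 0 and `END` at position 21, position $i$ in the second half mirrors position $21 - i$. A minimal sketch, where the helper `mirror_position` is hypothetical and introduced only for illustration:

```python
half_length = 10
seq_len = 2 * half_length + 2  # START + 20 values + END = 22 positions

def mirror_position(i: int) -> int:
    """Position holding the token that position i must match (START is position 0)."""
    return seq_len - 1 - i

# second-half positions are 11..20 inclusive
pairs = [(i, mirror_position(i)) for i in range(half_length + 1, 2 * half_length + 1)]
print(pairs)
```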

Let's zoom in on a single example's H0.0 attention pattern to really see what effect the positional QK circuit has when combined with the embedding QK circuit.

Consider the sequence `[1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5]`.

If the embedding circuit were the only contributor, then we would see attention evenly spread throughout all the repeated characters (first visualization). But in reality, some are weighted higher than others - namely the ones along the "lambda," due to the contribution of the positional circuit (second visualization). Switching between these two visualizations clearly demonstrates the power of the positional circuit.
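
The effect can be previewed with a toy softmax in plain Python. All numbers are made up, and the sketch assumes the token and position contributions simply add, which ignores LayerNorm and the token-position cross terms present in the real head.

```python
import math
from typing import List

def softmax(scores: List[float]) -> List[float]:
    """Numerically stable softmax over a list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Made-up scores from one query to four earlier keys:
# keys 0, 2 and 3 hold the same token as the query, key 1 does not.
token_scores = [3.0, 0.0, 3.0, 3.0]
# Made-up positional bonus for the mirrored key (index 2 here).
pos_scores = [0.0, 0.0, 2.0, 0.0]

attn_tokens_only = softmax(token_scores)
attn_with_pos = softmax([a + b for a, b in zip(token_scores, pos_scores)])
print([round(a, 3) for a in attn_tokens_only])
print([round(a, 3) for a in attn_with_pos])
```

With token information alone, the three matching keys share the attention equally. Adding the positional bonus concentrates most of it on the mirrored key.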

In [13]:
# consider this example, with a bunch of spurious duplicated tokens
custom_toks = [31, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 32]

# first compute attn using only embedding information (no positional embedding)
embed = W_E[custom_toks]
resid = embed

# manually compute LayerNorm
resid = (resid - resid.mean(dim=-1, keepdim=True))
ln_out = resid / (resid.pow(2).mean(-1, keepdim=True) + 1e-5)**0.5

# QK circuit
pos_by_pos = ln_out @ W_QK @ ln_out.T
masked = mask_scores(pos_by_pos / model.cfg.d_head**0.5)
attn_no_pos = t.softmax(masked, dim=-1)

# now compute attn with embedding and positional embedding
embed = W_E[custom_toks]
pos_embed = W_pos[range(22)]
resid = embed + pos_embed

# manually compute LayerNorm
resid = (resid - resid.mean(dim=-1, keepdim=True))
ln_out = resid / (resid.pow(2).mean(-1, keepdim=True) + 1e-5)**0.5

# QK circuit
pos_by_pos = ln_out @ W_QK @ ln_out.T
masked = mask_scores(pos_by_pos / model.cfg.d_head**0.5)
attn_with_pos = t.softmax(masked, dim=-1)

display(cv.attention.attention_patterns(
      attention = t.stack((attn_no_pos, attn_with_pos)),
      tokens = [str(i) for i in custom_toks],
      attention_head_names = [f"Without position", f"With position"],
  ))

### Summary of H0.0 QK Circuit

So to summarize, there are two things that H0.0's QK circuit is effectively doing.

With respect to embeddings: a position will attend highly to another position if they have the same token embedding. This is evidenced by the diagonal nature of $W_{emb} W_{QK}^{0.0} W_{emb}^T$.

With respect to position: a position will attend highly to another position if it is its "palindromic correspondent." This is evidenced by the "lambda" nature of $W_{pos} W_{QK}^{0.0} W_{pos}^T$.

## Direct logit attribution

Next, we'll look at direct logit attribution. We will work backwards from the logit outputs, similar to the analysis done in [ARENA Chapter 1.4](https://arena-ch1-transformers.streamlit.app/[1.4]_Balanced_Bracket_Classifier).

### Translating through softmax

The probability `P(not_palindrome)` is a function of the difference between the two logits:

$$\text{softmax}\left( \begin{bmatrix} \text{logit}_0 \\ \text{logit}_1 \end{bmatrix} \right)_0 = \frac{e^{\text{logit}_0}}{e^{\text{logit}_0}+ e^{\text{logit}_1}} = \frac{1}{1 + e^{\text{logit}_1 - \text{logit}_0}} = \text{sigmoid}(\text{logit}_0 - \text{logit}_1)$$

Thus, we can just reason about the difference between the logits from now on.
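
A quick numerical check of this identity, using only the standard library. The logit values are arbitrary and the function names are illustrative.

```python
import math

def softmax_first(logit_0: float, logit_1: float) -> float:
    """First component of a two-way softmax, i.e. P(not palindrome)."""
    m = max(logit_0, logit_1)
    e0, e1 = math.exp(logit_0 - m), math.exp(logit_1 - m)
    return e0 / (e0 + e1)

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

for logit_0, logit_1 in [(2.0, -1.0), (0.3, 0.3), (-4.0, 1.5)]:
    lhs = softmax_first(logit_0, logit_1)
    rhs = sigmoid(logit_0 - logit_1)
    print(f"{lhs:.6f} {rhs:.6f}")
```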

### Translating through linear

`logits = final_LN_output @ W_U`
- `W_U` has shape `(d_model, 2)`
- `final_LN_output` has shape `(seq_len, d_model)`

```python
logit_diff = (final_LN_output @ W_U)[-1, 0] - (final_LN_output @ W_U)[-1, 1]
# by linearity of the matrix product, this is the same as:
logit_diff = final_LN_output[-1, :] @ (W_U[:, 0] - W_U[:, 1])
```

This means that the `logit_diff` is exactly the dot product between the last row of `final_LN_output` and the vector given by `W_U[:, 0] - W_U[:, 1]`.
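
The same rearrangement can be checked by hand on a tiny example. The sketch below uses plain Python lists with a made-up `d_model` of 3; the values of `x_last` and `W_U` are illustrative and have nothing to do with the trained weights.

```python
import math
from typing import List

x_last: List[float] = [1.0, -2.0, 0.5]  # stand-in for final_LN_output[-1, :]
W_U: List[List[float]] = [              # shape (d_model, 2)
    [0.2, -0.1],
    [0.4, 0.3],
    [-0.5, 0.1],
]

# Route 1: compute both logits, then subtract.
logits = [sum(x_last[d] * W_U[d][c] for d in range(3)) for c in range(2)]
diff_from_logits = logits[0] - logits[1]

# Route 2: subtract the unembedding columns first, then take one dot product.
direction = [row[0] - row[1] for row in W_U]
diff_from_direction = sum(x * d for x, d in zip(x_last, direction))

print(diff_from_logits, diff_from_direction)
```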

We'll call this vector `post_final_ln_dir := W_U[:, 0] - W_U[:, 1]`. We can compute it directly:

In [14]:
def get_post_final_ln_dir(model: HookedTransformer) -> Float[Tensor, "d_model"]:
    '''
    Returns the direction in which final_ln_output[-1, :] should point to maximize P(not palindrome)
    '''
    return model.W_U[:, 0] - model.W_U[:, 1]

### Ignoring the final LayerNorm

For this directional analysis, we're going to ignore the final LayerNorm for simplicity.

### Direct logit attribution for residual stream components

With this assumption, we have made it into the residual stream. We can check a residual stream's contribution to the "non-palindrome" direction by dotting it with `post_final_ln_dir`.

Let's decompose the pre-final-layernorm residual stream into its constituents, and compute the direct logit attribution for each of them.

$$x_{\text{residual}} = emb + pos + h_{0.0} + h_{0.1} + h_{1.0} + h_{1.1} + b_O^{0} + b_O^{1}$$
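
Because the dot product is linear, the attribution of the summed residual stream is just the sum of the per-component attributions. A minimal sketch with made-up three-dimensional vectors (the component values and the `direction` vector are illustrative, and only three components are shown):

```python
import math
from typing import Dict, List

def dot(u: List[float], v: List[float]) -> float:
    return sum(a * b for a, b in zip(u, v))

direction = [1.0, -1.0, 0.5]  # made-up stand-in for post_final_ln_dir
components: Dict[str, List[float]] = {
    "embed": [0.1, 0.2, 0.0],
    "H0.0": [0.0, -0.3, 0.4],
    "H1.0": [1.5, -0.5, 0.2],
}

# Elementwise sum of all components gives the (toy) residual stream.
residual = [sum(vals) for vals in zip(*components.values())]
attributions = {name: dot(vec, direction) for name, vec in components.items()}

print(attributions)
print(sum(attributions.values()), dot(residual, direction))
```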

In [15]:
def get_out_by_components(model: HookedTransformer, data: PalindromeDataset) -> Float[Tensor, "component batch seq d_model"]:
    '''
    Computes a tensor of shape [6, dataset_size, seq_pos, emb] representing the output of the model's components when run on the data.
    The first dimension is [emb, pos_emb, head 0.0, head 0.1, head 1.0, head 1.1]
    '''
    logits, cache = model.run_with_cache(data.toks)

    emb = cache[utils.get_act_name('embed')].unsqueeze(0)
    pos_emb = cache[utils.get_act_name('pos_embed')].unsqueeze(0)

    head_0 = einops.rearrange(cache[utils.get_act_name('result', 0)], 'b s h d -> h b s d')
    head_1 = einops.rearrange(cache[utils.get_act_name('result', 1)], 'b s h d -> h b s d')

    out = t.cat((emb, pos_emb, head_0, head_1), dim=0)
    return out

In [16]:
# double check that the sum of all components is equal to the cached final ln input
biases = model.b_O.sum(0)
out_by_components = get_out_by_components(model, dataset)
summed_terms = out_by_components.sum(dim=0) + biases

logits, cache = model.run_with_cache(dataset.toks)
final_ln_input = cache[utils.get_act_name('resid_post', 1)]

t.testing.assert_close(summed_terms, final_ln_input)
print("Output decomposition is correct!")

Output decomposition is correct!


In [17]:
out_by_components = get_out_by_components(model, dataset) # (n, batch, seq, d_model)
# we only care about the last sequence position:
out_by_components = out_by_components[:, :, -1, :] # (n, batch, d_model)

post_final_ln_dir = get_post_final_ln_dir(model)

out_by_components_in_nonpal_dir = einops.einsum(out_by_components, post_final_ln_dir, 'n batch d_model, d_model -> n batch')

labels = ['embed', 'pos_embed', 'H0.0', 'H0.1', 'H1.0', 'H1.1']

for i in range(out_by_components_in_nonpal_dir.shape[0]):
    nonpal = out_by_components_in_nonpal_dir[i][dataset.is_palindrome == 0].cpu().detach().numpy()
    pal = out_by_components_in_nonpal_dir[i][dataset.is_palindrome == 1].cpu().detach().numpy()

    overall_min = min(nonpal.min(), pal.min())
    overall_max = max(nonpal.max(), pal.max())
    bin_edges = np.linspace(overall_min, overall_max, num=51)  # num = number of bins + 1

    fig = go.Figure()
    fig.add_trace(go.Histogram(x=nonpal, xbins=dict(start=bin_edges[0], end=bin_edges[-1], size=(overall_max-overall_min)/50), name='non-pal', marker_color='red'))
    fig.add_trace(go.Histogram(x=pal, xbins=dict(start=bin_edges[0], end=bin_edges[-1], size=(overall_max-overall_min)/50), name='pal', marker_color='blue'))

    fig.update_layout(
        barmode='overlay',
        title=f'{labels[i]}',  # Sets the title for the histogram
        xaxis_title="Dot prod with non-palindrome direction",  # Sets the label for the x-axis
        yaxis_title="Frequency",  # Sets the label for the y-axis
        autosize=False,
        width=800,
        height=300,
    )
    fig.update_traces(opacity=0.60)
    fig.show()  # Displays the histogram

From these graphs, it is clear that the output of H1.0 directly contributes greatly to the correct answer. There is a clear separation between the contributions for palindromes vs non-palindromes.

H1.1 also seems to contribute a bit to boosting the non-palindromic signal for non-palindromes. However, there is not as clean of a separation between the palindrome inputs and non-palindrome inputs.

The rest of the components (embed, pos_embed, H0.0, H0.1) do not seem to have significant direct impact on the results.

## Composition scores


In the previous section, we saw that most of the direct logit attribution was coming from H1.0. Now we want to investigate how H1.0 works, and how it interacts with other heads.

A good place to start is composition score analysis. We are interested in seeing which heads interact with H1.0, and in what rough capacities.

We will check the composition scores between {H0.0 V, H0.1 V} and {H1.0 Q, H1.0 K, H1.0 V}.

This will suggest to us if there is Q-composition, K-composition, or V-composition going on.
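
For intuition, the composition score used here is the ratio $\|W_A W_B\|_F / (\|W_A\|_F \|W_B\|_F)$, which is what `get_comp_score` computes with torch. The sketch below re-implements it on plain Python lists with made-up 2x2 matrices: the score is 1 when `W_B` reads exactly the subspace that `W_A` writes to, and 0 when the two subspaces are orthogonal. The helper names are hypothetical.

```python
import math
from typing import List

Matrix = List[List[float]]

def matmul(A: Matrix, B: Matrix) -> Matrix:
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]

def frobenius_norm(A: Matrix) -> float:
    return math.sqrt(sum(x * x for row in A for x in row))

def comp_score(W_A: Matrix, W_B: Matrix) -> float:
    """Frobenius norm of the product, normalised by the norms of the factors."""
    return frobenius_norm(matmul(W_A, W_B)) / (frobenius_norm(W_A) * frobenius_norm(W_B))

writes_dim0 = [[1.0, 0.0], [0.0, 0.0]]  # only writes to dimension 0
reads_dim0 = [[1.0, 0.0], [0.0, 0.0]]   # only reads from dimension 0
reads_dim1 = [[0.0, 0.0], [0.0, 1.0]]   # only reads from dimension 1

print(comp_score(writes_dim0, reads_dim0))
print(comp_score(writes_dim0, reads_dim1))
```

Real weight matrices fall between these extremes, which is why the scores are compared against a random baseline.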

In [18]:
def get_comp_score(
    W_A: Float[Tensor, "in_A out_A"],
    W_B: Float[Tensor, "out_A out_B"]
) -> float:
    '''
    Return the composition score between W_A and W_B.
    '''
    # W_A_W_B = FactoredMatrix(W_A, W_B)
    return ((W_A @ W_B).norm() / (W_A.norm() * W_B.norm())).item()

In [19]:
# Get all QK and OV matrices
W_QK = model.W_Q @ model.W_K.transpose(-1, -2)
W_OV = model.W_V @ model.W_O

# Define tensors to hold the composition scores
composition_scores = {
    "Q": t.zeros(model.cfg.n_heads, model.cfg.n_heads).to(device),
    "K": t.zeros(model.cfg.n_heads, model.cfg.n_heads).to(device),
    "V": t.zeros(model.cfg.n_heads, model.cfg.n_heads).to(device),
}

for i in tqdm(range(model.cfg.n_heads)):
    for j in range(model.cfg.n_heads):
        composition_scores["Q"][i, j] = get_comp_score(W_OV[0, i], W_QK[1, j])
        composition_scores["K"][i, j] = get_comp_score(W_OV[0, i], W_QK[1, j].T)
        composition_scores["V"][i, j] = get_comp_score(W_OV[0, i], W_OV[1, j])



In [20]:
# generate a baseline score
def generate_single_random_comp_score() -> float:
    '''
    A function which generates a single composition score for random matrices
    '''
    W_A_left = t.empty(model.cfg.d_model, model.cfg.d_head)
    W_B_left = t.empty(model.cfg.d_model, model.cfg.d_head)
    W_A_right = t.empty(model.cfg.d_model, model.cfg.d_head)
    W_B_right = t.empty(model.cfg.d_model, model.cfg.d_head)

    for W in [W_A_left, W_B_left, W_A_right, W_B_right]:
        nn.init.kaiming_uniform_(W, a=np.sqrt(5))

    W_A = W_A_left @ W_A_right.T
    W_B = W_B_left @ W_B_right.T

    return get_comp_score(W_A, W_B)

n_samples = 300
comp_scores_baseline = np.zeros(n_samples)
for i in tqdm(range(n_samples)):
    comp_scores_baseline[i] = generate_single_random_comp_score()
print("\nMean:", comp_scores_baseline.mean())
print("Std:", comp_scores_baseline.std())

baseline = comp_scores_baseline.mean()
print(baseline)



Mean: 0.18756299674510957
Std: 0.01410002791311248
0.18756299674510957





In [21]:
for comp_type, comp_scores in composition_scores.items():
    display(px.imshow(
          comp_scores.cpu().detach().numpy(),
          y=[f"L0H{h}" for h in range(model.cfg.n_heads)],
          x=[f"L1H{h}" for h in range(model.cfg.n_heads)],
          labels={"x": "Layer 1", "y": "Layer 0"},
          title=f"{comp_type} Composition Scores",
          color_continuous_scale="RdBu" if baseline is not None else "Blues",
          color_continuous_midpoint=baseline if baseline is not None else None,
          zmin=None if baseline is not None else 0.0,
          width=600,
          height=400
    ))


It is pretty clear that H0.0's OV output space overlaps significantly with H1.0's K input space and V input space. None of the other input/output spaces have significant overlap.

This suggests that H0.0 and H1.0 are related by K-composition (H0.0's output is used in H1.0's keys) and V-composition (H0.0's output is used in H1.0's values).

## H0.0 & H1.0 K-composition

Let's take a look at H1.0's QK circuit. We only care about the last position's attention pattern, as none of the other positions can have an impact on the output.

Let's look at the last position of H1.0's attention pattern for 3 examples of a palindrome and 3 examples of a non-palindrome.

In [22]:
def head_1_0_last_attn_vis_hook(
    pattern: Float[Tensor, "batch head_index dest_pos source_pos"],
    hook: HookPoint,
):
    if (hook.layer() == 1):
        display_seq(current_tok.squeeze(), -1)

        # Create a bar chart
        last_token_attn = pattern[0,0,-1].cpu().detach().numpy().tolist()
        fig = go.Figure(data=[
            go.Bar(x=[f"({i}) {s}" for i, s in enumerate(current_tok)], y=last_token_attn)
        ])

        fig.update_layout(
            title='Attention scores from last position',
            xaxis_title='Target position',
            yaxis_title='Attention score',
            autosize=False,
            width=500,
            height=300,
        )
        fig.show()

        display(
            cv.attention.attention_patterns(
                tokens=[f"{s}" for i, s in enumerate(current_tok)],
                attention=pattern.mean(0)
            )
        )

toks, is_palindrome = dataset[:100]
toks_pos = toks[is_palindrome == 1]
toks_neg = toks[is_palindrome == 0]

for label, example_set in zip(["Palindrome", "Non-palindrome"], [toks_pos, toks_neg]):
    for i in range(3):
        print(label)
        model.reset_hooks()
        current_tok = example_set[i]
        pattern_hook_names_filter = lambda name: name.endswith("pattern")
        logits = model.run_with_hooks(
            current_tok,
            return_type="logits",
            fwd_hooks=[
                (pattern_hook_names_filter, head_1_0_last_attn_vis_hook)
            ]
        )




Palindrome


Palindrome


Palindrome


Non-palindrome


Non-palindrome


Non-palindrome


Observations from the above graphs:
- For palindromes, the attention pattern is approximately diffuse across many tokens.
- For non-palindromes, the attention pattern is sharply concentrated on the tokens which cause the string not to be a palindrome.

This suggests there is something going on with the QK circuit in H1.0 that causes it to attend to token positions in the second half that violate the palindrome condition.

Since we suspect that H0.0 interacts with H1.0 through K-composition, the following is a hypothesized mechanism:
- H0.0's value indicates whether a position breaks the palindromic property.
- H0.0's value is used as a key in H1.0.
- H1.0 queries for positions where the palindromic property is broken.
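The hypothesized mechanism can be sketched with a toy attention computation. Everything here is hypothetical: a 2-dimensional key space where feature 0 stands for "this position breaks the palindromic property", a constant query that reads that feature, and a toy sequence length of 8.

```python
import numpy as np

def softmax(x: np.ndarray) -> np.ndarray:
    z = np.exp(x - x.max())
    return z / z.sum()

seq_len = 8                   # toy length, not the model's 22
query = np.array([4.0, 0.0])  # constant query that "looks for" feature 0

def last_position_attention(violating_positions: list) -> np.ndarray:
    # Hypothetical keys: feature 0 is 1.0 wherever layer 0 flagged a violation.
    keys = np.zeros((seq_len, 2))
    for p in violating_positions:
        keys[p, 0] = 1.0
    return softmax(keys @ query)

attn_palindrome = last_position_attention([])
attn_non_palindrome = last_position_attention([5])
print(attn_palindrome.round(3))
print(attn_non_palindrome.round(3))
```

With no flagged positions, all scores tie and the attention is uniform, which matches the diffuse pattern seen for palindromes. A single flagged position takes most of the attention, matching the sharp pattern seen for non-palindromes.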

### H1.0's queries

Let's examine H1.0's queries.

First we will perform mean and zero ablations on H1.0's query activations to get a sense of what's going on.

In [23]:
def get_loss(logits, is_palindrome):
    last_logits = logits[:, -1, :]
    loss = t.nn.functional.cross_entropy(last_logits, is_palindrome.to(device))
    return loss.item()

In [24]:


layer_to_ablate = 1
head_to_ablate  = 0
activations_to_ablate = 'q'

def qkv_zero_ablation_hook(
    qkorv: Float[Tensor, "batch seq n_heads d_head"],
    hook: HookPoint,
) -> Float[Tensor, "batch seq n_heads d_head"]:
    '''Zero-ablates a q, k, or v activation'''
    if (hook.layer() == layer_to_ablate):
        qkorv[:,:,head_to_ablate,:] = 0
    return qkorv

def qkv_mean_ablation_hook(
    qkorv: Float[Tensor, "batch seq n_heads d_head"],
    hook: HookPoint,
) -> Float[Tensor, "batch seq n_heads d_head"]:
    '''Mean-ablates a q, k, or v activation (mean taken over the batch dimension)'''
    if (hook.layer() == layer_to_ablate):
        qkorv[:,:,head_to_ablate] = qkorv[:,:,head_to_ablate,:].mean(dim=0, keepdim=True)
    return qkorv

toks, is_palindrome = dataset[0:100]

model.reset_hooks()
logits_no_ablation = model(toks, return_type="logits")
loss_no_ablation = get_loss(logits_no_ablation, is_palindrome)
for type, hook in zip(['mean', 'zero'], [qkv_mean_ablation_hook, qkv_zero_ablation_hook]):
    ablated_logits = model.run_with_hooks(
        toks,
        return_type="logits",
        fwd_hooks=[(utils.get_act_name(activations_to_ablate, layer_to_ablate), hook)]
    )
    ablated_loss = get_loss(ablated_logits, is_palindrome) - loss_no_ablation
    print(f"loss diff {type} ablating {activations_to_ablate} in h{layer_to_ablate}.{head_to_ablate}: {ablated_loss}")

loss diff mean ablating q in h1.0: 0.0004616358783096075
loss diff zero ablating q in h1.0: 2.4908793556969613


We can deduce that

1. H1.0's q values are important, since ablating them elevates loss.
2. H1.0's q values are roughly the same for every input, since mean ablating them did not impact the loss.

Therefore, we can simply extract the mean q and treat it as a fixed query direction.
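The reasoning behind the second deduction can be checked on a toy tensor. This is a sketch with made-up activations, assuming the query depends on the position but not on the input; it does not use the model.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, seq, d_head = 4, 6, 3  # illustrative sizes

# Hypothetical query activations: every batch element carries the same (seq, d_head) pattern.
per_position = rng.normal(size=(seq, d_head))
q = np.broadcast_to(per_position, (batch, seq, d_head)).copy()

# Mean ablation replaces each example with the batch mean; zero ablation wipes the activation.
q_mean_ablated = np.broadcast_to(q.mean(axis=0, keepdims=True), q.shape)
q_zero_ablated = np.zeros_like(q)

print("mean ablation changes q:", not np.allclose(q_mean_ablated, q))
print("zero ablation changes q:", not np.allclose(q_zero_ablated, q))
```

If the activation is the same for every input, mean ablation is a no-op while zero ablation destroys it. That is the pattern of loss differences seen above.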

In [25]:
toks, is_palindrome = dataset[0:100]

model.reset_hooks()
logits, cache = model.run_with_cache(toks, return_type="logits")

q_h_1_0 = cache["q", 1][:, :, 0, :]
q_avg_h_1_0 = q_h_1_0.mean(0)

# we only care about the query in the last position
q_h_1_0_direction = q_avg_h_1_0[-1]

### K-composition circuit

H0.0's OV circuit interacts with H1.0's QK circuit in K-composition.

The circuit can be written as follows:

$$\begin{align}
\text{attn between pos -1 and pos i} &= -1^T W_{pos} W_{QK}^{1.0} (W_{OV}^{0.0})^T W_{pos}^T i\\
&= -1^T W_{pos} W_Q^{1.0} (W_K^{1.0})^T (W_O^{0.0})^T (W_{V}^{0.0})^T W_{pos}^T i\\
&= (-1^T W_{pos} W_Q^{1.0} )(i^T W_{pos} W_{OV}^{0.0}W_K^{1.0})^T
\end{align}$$

where $-1$ and $i$ are one-hot column vectors encoding positions $-1$ and $i$ respectively. Note that the left expression in parentheses is a constant vector: that's the mean query direction we found above, `q_h_1_0_direction`!

The expression $W_{pos} W_{QK}^{1.0} (W_{OV}^{0.0})^T W_{pos}^T$ is a `seq_len x seq_len` matrix that maps positions (on the left) to the positions they attend to (on the right). When the input on the left is $-1$, we get a vector where the $i^{th}$ entry is the attention paid from $-1$ to the $i^{th}$ position.
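As a quick check that the first and last lines of the derivation agree, here is a NumPy sketch with random stand-in weights. Only `seq_len = 22` is taken from the notebook; the other sizes and all the matrices are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d_model, d_head = 22, 16, 4

# Random stand-ins for the model's weights
W_pos = rng.normal(size=(seq_len, d_model))
W_Q1 = rng.normal(size=(d_model, d_head))
W_K1 = rng.normal(size=(d_model, d_head))
W_V0 = rng.normal(size=(d_model, d_head))
W_O0 = rng.normal(size=(d_head, d_model))

W_QK1 = W_Q1 @ W_K1.T
W_OV0 = W_V0 @ W_O0

# First line of the derivation: the last row of the seq_len x seq_len matrix
scores_full = (W_pos @ W_QK1 @ W_OV0.T @ W_pos.T)[-1]

# Last line: one fixed query vector dotted with each position's key
query = W_pos[-1] @ W_Q1         # shape (d_head,)
keys = W_pos @ W_OV0 @ W_K1      # shape (seq_len, d_head)
scores_factored = keys @ query   # shape (seq_len,)

print(np.allclose(scores_full, scores_factored))
```

The factored form makes the interpretation explicit: a single query vector at the last position is compared against a per-position key that has passed through H0.0's OV circuit.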

In [26]:
W_pos = model.W_pos
W_E = model.W_E
W_O_0_0 = model.W_O[0,0]
W_V_0_0 = model.W_V[0,0]

W_Q_1_0 = model.W_Q[1,0]
W_K_1_0 = model.W_K[1,0]

full_K_comp_circuit = W_pos[-1].unsqueeze(0) @ W_Q_1_0 @ W_K_1_0.T @ (W_V_0_0 @ W_O_0_0).T @ W_pos.T
full_K_comp_circuit = full_K_comp_circuit.view(-1, 1)
fig = px.imshow(
    full_K_comp_circuit.T.cpu().detach().numpy(),
    color_continuous_scale='RdBu',
    color_continuous_midpoint=0.0,
    title='K-composition circuit between H0.0 OV & H1.0 QK',
    height=500, width=800)
fig.update_yaxes(showticklabels=False)
tickvals = list(range(full_K_comp_circuit.shape[0]))
ticktext = [str(i) for i in tickvals]
fig.update_xaxes(tickvals=tickvals, ticktext=ticktext)
fig.show()

fig = go.Figure(data=go.Scatter(x=list(range(0, 22)), y=full_K_comp_circuit.view(-1).cpu().detach().numpy(), mode='markers', marker=dict(color=['orange'] + ['red']*10 + ['blue']*10 + ['orange'])))
fig.update_layout(title='K-composition circuit between H0.0 OV & H1.0 QK', xaxis_title='Target position', yaxis_title='H1.0 attention attribution', height=400, width=800)
fig.add_shape(type='line', x0=0, y0=0, x1=22, y1=0, line=dict(color='Black'))
fig.show()

## H0.0 & H1.0 V-composition, and the H0.0 OV circuit

From the composition score analysis, it looks like H0.0 and H1.0 are involved in V-composition in addition to K-composition.

Through the logit attribution analysis, we know that H1.0's outputs give the most significant logit attribution.

If H0.0's output is being used in H1.0's values, then we should be able to measure the direct logit attribution of H0.0's output *through* H1.0's OV circuit.

We can basically repeat the previous direct logit attribution analysis with a couple of changes.
1. We will multiply H0.0's output by H1.0's OV circuit
2. Previously, only the last sequence position mattered. However, in this case, all the sequence positions matter, as they can be mixed in with the last position's H1.0 output.
  - I will just average across the sequence position for now (this is not great, but still gives a sort of reasonable result).
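One way to read the averaging step: it gives the attribution we would get if H1.0 attended uniformly from the last position. The sketch below uses random stand-in arrays and ignores biases and layer norm, so it only illustrates the linear algebra.

```python
import numpy as np

rng = np.random.default_rng(0)
seq, d_model = 22, 16  # illustrative sizes

h00_out = rng.normal(size=(seq, d_model))       # stand-in for H0.0's output at each position
W_OV_1_0 = rng.normal(size=(d_model, d_model))  # stand-in for H1.0's OV matrix
direction = rng.normal(size=d_model)            # stand-in for the non-palindrome direction

# Attribution of each source position through H1.0's OV circuit
per_position = h00_out @ W_OV_1_0 @ direction   # shape (seq,)

def attribution(attn_from_last: np.ndarray) -> float:
    """Attribution at the last position for a given H1.0 attention row."""
    return float(attn_from_last @ per_position)

uniform = np.full(seq, 1.0 / seq)
sharp = np.zeros(seq)
sharp[13] = 1.0  # all attention on a single position

print(attribution(uniform), per_position.mean())
print(attribution(sharp), per_position[13])
```

Since H1.0's real attention is sharply peaked for non-palindromes, the sequence average understates the contribution of the attended position, which is why it is only a rough first pass.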

In [27]:
out_by_components = get_out_by_components(model, dataset) # (n, batch, seq, d_model)

# we only want to look at H0.0
out_by_components = out_by_components[2, :, :, :].unsqueeze(0) # (1, batch, seq, d_model)

# multiply by H1.0's OV circuit
h10_W_V = model.W_V[1,0]
h10_W_O = model.W_O[1,0]
h10_W_OV = h10_W_V @ h10_W_O

out_by_components_h10_OV = einops.einsum(out_by_components, h10_W_OV, 'n b seq d1, d1 d2-> n b seq d2')

# take the mean over sequence dimension
out_by_components_h10_OV = out_by_components_h10_OV.mean(dim=2)

post_final_ln_dir = get_post_final_ln_dir(model)

out_by_components_in_nonpal_dir = einops.einsum(out_by_components_h10_OV, post_final_ln_dir, 'n batch d_model, d_model -> n batch')

labels = ['H0.0']

for i in range(out_by_components_in_nonpal_dir.shape[0]):
    nonpal = out_by_components_in_nonpal_dir[i][dataset.is_palindrome == 0].cpu().detach().numpy()
    pal = out_by_components_in_nonpal_dir[i][dataset.is_palindrome == 1].cpu().detach().numpy()

    overall_min = min(nonpal.min(), pal.min())
    overall_max = max(nonpal.max(), pal.max())
    bin_edges = np.linspace(overall_min, overall_max, num=51)  # num = number of bins + 1

    fig = go.Figure()
    fig.add_trace(go.Histogram(x=nonpal, xbins=dict(start=bin_edges[0], end=bin_edges[-1], size=(overall_max-overall_min)/50), name='non-pal', marker_color='red'))
    fig.add_trace(go.Histogram(x=pal, xbins=dict(start=bin_edges[0], end=bin_edges[-1], size=(overall_max-overall_min)/50), name='pal', marker_color='blue'))

    fig.update_layout(
        barmode='overlay',
        title=f'{labels[i]}',  # Sets the title for the histogram
        xaxis_title="Dot prod with non-palindrome direction, after H1.0 OV applied",  # Sets the label for the x-axis
        yaxis_title="Frequency",  # Sets the label for the y-axis
        autosize=False,
        width=800,
        height=300,
    )
    fig.update_traces(opacity=0.60)
    fig.show()  # Displays the histogram

This looks pretty reasonable, even though we were a bit careless to average over the sequence position.

Let's get more granular, to the level of a single example.

I want to show this:
- H0.0 generates a specific OV direction for sequence positions that violate the palindrome property.
- This direction composes with H1.0's OV circuit, and ends up pointing in the non-palindrome direction.

I will first show this on a single example.

In [28]:
toks, is_palindrome = dataset[:5]
logits, cache = model.run_with_cache(toks)

# we'll focus just on a single example at index 2: START |30|22|22|00|28|01|00|09|23|27|27|23|30|00|01|28|00|22|22|23| END
index = 2
display_seq(dataset.toks[index])

display(cv.attention.attention_patterns(
    attention = t.concat([cache["pattern", layer][index] for layer in range(model.cfg.n_layers)]),
    tokens = dataset.str_toks[index],
    attention_head_names = [f"{layer}.{head}" for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)],
))

We can see that the last position of H1.0 places a lot of attention on the token at index 13. This corresponds to the token '30', which is (partially) responsible for the input not being a palindrome.

We also see that H0.0 places high attention from index 13 to index 13. Thus, H0.0's output at index 13 is essentially the value corresponding to index 13 (passed through its W_O).

We can show that this value, when composed with H1.0's OV circuit, results in a large direct logit attribution in the 'non-palindrome direction.'
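Before going further, the claim that index 13 breaks the palindromic property can be checked directly from the tokens. This is a small pure-Python sketch; the helper name is mine, and writing START and END as 31 and 32 is an assumption based on the custom examples later in the notebook (the check does not depend on those two values).

```python
def second_half_violations(toks: list) -> list:
    """Return second-half positions whose token differs from the token at the mirror position."""
    last = len(toks) - 1   # index of the END token
    half = len(toks) // 2  # first index of the second half
    return [p for p in range(half, last) if toks[p] != toks[last - p]]

# START |30|22|22|00|28|01|00|09|23|27|27|23|30|00|01|28|00|22|22|23| END
example = [31, 30, 22, 22, 0, 28, 1, 0, 9, 23, 27, 27, 23, 30, 0, 1, 28, 0, 22, 22, 23, 32]
violations = second_half_violations(example)
print(violations)  # [13, 20]
```

Index 13 ('30') mirrors index 8 ('09'), and index 20 ('23') mirrors index 1 ('30'), so two positions are responsible, which is why index 13 is only partially responsible.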

In [29]:
nonpal_token_index = 13

# get the value from h0.0 corresponding to the nonpal token index
h00_v_nonpal = cache[utils.get_act_name('v', 0)][index, nonpal_token_index, 0, :].unsqueeze(0)

# convert to h0.0 output
h00_W_O = model.W_O[0,0]
h00_o_nonpal = h00_v_nonpal @ h00_W_O

# apply h1.0's OV circuit
h10_W_V = model.W_V[1,0]
h10_W_O = model.W_O[1,0]
h10_o_nonpal = h00_o_nonpal @ h10_W_V @ h10_W_O

# now we have a direction in residual space
# now we can dot it with the non-pal direction
# we expect a high dot product
post_final_ln_dir = get_post_final_ln_dir(model).unsqueeze(0)
dot = h10_o_nonpal @ post_final_ln_dir.T
print(f"Dot product corresponding to non-pal sequence position: {dot.item():.4f}")

Dot product corresponding to non-pal sequence position: 20.1647


We can check this value for every sequence position.

In [30]:
# get the value from h0.0 corresponding to each token index
h00_v = cache[utils.get_act_name('v', 0)][index, 1:21, 0, :]

# # convert to h0.0 output
h00_W_O = model.W_O[0,0]
h00_o = h00_v @ h00_W_O #einops.einsum(h00_v, h00_W_O, 'seq d_head, d_head d_model -> seq d_model')


# apply h1.0's OV circuit
h10_W_V = model.W_V[1,0]
h10_W_O = model.W_O[1,0]
h10_W_OV = h10_W_V @ h10_W_O
h10_o = einops.einsum(h00_o, h10_W_OV, 'seq d_model1, d_model1 d_model2 -> seq d_model2')
# # now we have a direction in residual space
# # now we can dot it with the non-pal direction
# # we expect a high dot product
post_final_ln_dir = get_post_final_ln_dir(model).unsqueeze(0)
dot = einops.einsum(h10_o, post_final_ln_dir, 'seq d_model, n d_model -> seq n')

fig = go.Figure(data=go.Scatter(x=list(range(1, 21)), y=dot.view(-1).cpu().detach().numpy(), mode='markers', marker=dict(color=['red']*10 + ['blue']*10)))
fig.update_layout(title='Logit attribution of H0.0 values through H1.0 OV, on example #2', xaxis_title='Sequence position', yaxis_title='Logit attribution', height=400, width=800)
fig.add_shape(type='line', x0=0, y0=0, x1=22, y1=0, line=dict(color='Black'))
fig.show()

Let's now perform this analysis across 100 examples!

In [31]:
toks, is_palindrome = dataset[:100]
logits, cache = model.run_with_cache(toks)

# get the value from h0.0 corresponding to each token index
h00_v = cache[utils.get_act_name('v', 0)][:, 1:21, 0, :] # batch seq d_head

# # convert to h0.0 output
h00_W_O = model.W_O[0,0]
h00_o = einops.einsum(h00_v, h00_W_O, 'batch seq d_head, d_head d_model -> batch seq d_model')

# apply h1.0's OV circuit
h10_W_V = model.W_V[1,0]
h10_W_O = model.W_O[1,0]
h10_W_OV = h10_W_V @ h10_W_O
h10_o = einops.einsum(h00_o, h10_W_OV, 'batch seq d_model1, d_model1 d_model2 -> batch seq d_model2')

# # now we have a direction in residual space
# # now we can dot it with the non-pal direction
# # we expect a high dot product
post_final_ln_dir = get_post_final_ln_dir(model).unsqueeze(0)
dot = einops.einsum(h10_o, post_final_ln_dir, 'batch seq d_model, n d_model -> batch seq n')

dot_mean = dot.mean(dim=0)
dot_std  = dot.std(dim=0)
print(f"std ~= {dot_std.mean(dim=0).item()}")

fig = go.Figure(data=go.Scatter(x=list(range(1, 21)), y=dot_mean.view(-1).cpu().detach().numpy(), mode='markers', marker=dict(color=['red']*10 + ['blue']*10)))
fig.update_layout(title='Logit attribution of H0.0 values through H1.0 OV, across 100 examples', xaxis_title='Sequence position', yaxis_title='Logit attribution', height=400, width=800)
fig.add_shape(type='line', x0=0, y0=0, x1=22, y1=0, line=dict(color='Black'))
fig.show()

std ~= 1.6621723175048828


### The H0.0 OV & H1.0 OV circuits, and the full V-composition circuit

From the above plot, we can roughly see what the H0.0 and H1.0 OV circuits are doing.

$W_{OV}^{0.0}$ is detecting if the sequence position is in the first half or second half. If in the second half, it outputs a vector that, when run through  $W_{OV}^{1.0}$, points in the direction of 'non-palindrome'. If in the first half, the resulting vector is in the opposite direction.

$W_{OV}^{1.0}$ is responsible for taking the vectors emitted by H0.0 and transforming them according to the above relation, so that the output is either in the 'non-palindrome' direction, or in the opposite.

We can describe the full circuit like this:
$$W_{pos} W_{OV}^{0.0} W_{OV}^{1.0} \cdot \text{post\_ln\_dir}$$

The above `seq x 1` circuit maps a position to a 'non-palindrome' score.
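To connect this weights-only circuit to the per-example computation from the earlier cells, here is a NumPy sketch with random stand-in weights. It assumes the H0.0 value at a position comes from the positional embedding alone (no token embedding, bias, or layer norm), which is exactly the simplification the circuit makes.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d_model, d_head = 22, 16, 4  # seq_len matches the notebook; other sizes are illustrative

W_pos = rng.normal(size=(seq_len, d_model))
W_V0, W_V1 = rng.normal(size=(2, d_model, d_head))
W_O0, W_O1 = rng.normal(size=(2, d_head, d_model))
direction = rng.normal(size=d_model)  # stand-in for post_final_ln_dir

# Weights-only circuit: one score per position
circuit = W_pos @ W_V0 @ W_O0 @ W_V1 @ W_O1 @ direction  # shape (seq_len,)

# Step-by-step path for a single position
i = 13
v = W_pos[i] @ W_V0                    # H0.0 value from the positional embedding alone
o = v @ W_O0                           # H0.0 output
score_i = o @ W_V1 @ W_O1 @ direction  # through H1.0's OV circuit, then the direction

print(np.isclose(score_i, circuit[i]))
```

Each entry of the circuit is the step-by-step attribution for one position with the token-dependent part of the value stripped away.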

When an input is a palindrome, H0.0 will attend strongly to tokens in the first half from all positions. When an input is a non-palindrome, H0.0 will strongly self-attend for some position in the second half. Thus, in the circuit above, we expect indices 1-10 to map to negative values (resulting in a vote for 'palindrome'), and indices 11-20 to map to positive values (resulting in a vote for 'non-palindrome').

We compute and display the full circuit:

In [32]:
W_pos = model.W_pos
W_O_0_0 = model.W_O[0,0]
W_V_0_0 = model.W_V[0,0]
W_O_1_0 = model.W_O[1,0]
W_V_1_0 = model.W_V[1,0]
post_final_ln_dir = get_post_final_ln_dir(model).unsqueeze(0)

full_composed_OV_circuit = W_pos @ W_V_0_0 @ W_O_0_0 @ W_V_1_0 @ W_O_1_0 @ post_final_ln_dir.T

fig = px.imshow(
    full_composed_OV_circuit.T.cpu().detach().numpy(),
    color_continuous_scale='RdBu',
    color_continuous_midpoint=0.0,
    title='V-composition circuit between H0.0 OV & H1.0 OV',
    height=500, width=800)
fig.update_yaxes(showticklabels=False)
tickvals = list(range(full_composed_OV_circuit.shape[0]))
ticktext = [str(i) for i in tickvals]
fig.update_xaxes(tickvals=tickvals, ticktext=ticktext)
fig.show()

fig = go.Figure(data=go.Scatter(x=list(range(0, 22)), y=full_composed_OV_circuit.view(-1).cpu().detach().numpy(), mode='markers', marker=dict(color=['orange'] + ['red']*10 + ['blue']*10 + ['orange'])))
fig.update_layout(title='V-composition circuit between H0.0 OV & H1.0 OV', xaxis_title='Sequence position', yaxis_title='Logit attribution', height=400, width=800)
fig.add_shape(type='line', x0=0, y0=0, x1=22, y1=0, line=dict(color='Black'))
fig.show()

### Some adversarial examples

Looking at the plot, our theory breaks down at the very beginning of the sequence and at the very end of the sequence (positions 1, 2, 4 and their corresponding 17, 19, 20).

But it turns out that the model's functionality breaks down at these positions as well!

 Here are some adversarial examples that I created to take advantage of this. All four of these examples are non-palindromes, and the non-palindromic violations come at these weak positions (1, 2, 4 and 17, 19, 20). The model classifies them as palindromes!
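As a sanity check on the examples themselves, this pure-Python sketch confirms that each sequence is a non-palindrome and that every mismatch falls on one of the weak positions. The helper name and the `WEAK_POSITIONS` constant are mine; the token lists are copied from the cell below.

```python
WEAK_POSITIONS = {1, 2, 4, 17, 19, 20}

def mismatched_positions(toks: list) -> set:
    """Return every position whose token differs from the token at its mirror position."""
    last = len(toks) - 1
    return {p for p in range(1, last) if toks[p] != toks[last - p]}

adversarial = [
    [31, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 9, 8, 7, 6, 5, 4, 3, 15, 1, 32],
    [31, 1, 2, 3, 15, 5, 6, 7, 8, 9, 10, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 32],
    [31, 1, 2, 3, 3, 3, 6, 7, 8, 9, 10, 10, 9, 8, 7, 6, 3, 3, 3, 2, 3, 32],
    [31, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 9, 8, 7, 6, 5, 15, 3, 15, 1, 32],
]

mismatches = [mismatched_positions(row) for row in adversarial]
for bad in mismatches:
    print(sorted(bad), bad <= WEAK_POSITIONS)
```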

In [33]:
custom_toks = t.tensor(
    [
        [31, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 9, 8, 7, 6, 5, 4, 3, 15, 1, 32],
        [31, 1, 2, 3, 15, 5, 6, 7, 8, 9, 10, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 32],
        [31, 1, 2, 3, 3, 3, 6, 7, 8, 9, 10, 10, 9, 8, 7, 6, 3, 3, 3, 2, 3, 32],
        [31, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 9, 8, 7, 6, 5, 15, 3, 15, 1, 32],
    ]
)

custom_logits, custom_cache = model.run_with_cache(custom_toks)

custom_logits = custom_logits[:, -1]
probs = custom_logits.softmax(-1)
probs_palindrome = probs[:, 1]

for i in range(custom_toks.shape[0]):
    display_seq(custom_toks[i], probs_palindrome[i])

## Further work

- What role does H0.1 play?
  - Head ablations (see [Other experiments -> Head ablation](#scrollTo=H1q6NbMy0XIs)) show that H0.1 plays some role in the mechanism:
    - When zero-ablated, loss increases.
    - When mean-ablated, there's no change in loss.

## Other experiments

### Head ablations

We can perform ablation analysis to identify which heads are most important.

In [34]:
def get_loss(logits, is_palindrome):
    last_logits = logits[:, -1, :]
    loss = t.nn.functional.cross_entropy(last_logits, is_palindrome.to(device))
    return loss.item()

In [35]:
def head_ablation_hook(
    attn_result: Float[Tensor, "batch seq n_heads d_model"],
    hook: HookPoint,
    head_index_to_ablate: int,
    ablation_type: str, # either 'zero', 'random', 'mean'
) -> Float[Tensor, "batch seq n_heads d_model"]:
    if ablation_type == 'zero':
        attn_result[:,:,head_index_to_ablate,:] = 0
    elif ablation_type == 'random':
        attn_result[:,:,head_index_to_ablate] = t.randn_like(attn_result[:,:,head_index_to_ablate])
    elif ablation_type == 'mean':
        attn_result[:,:,head_index_to_ablate] = attn_result[:,:,head_index_to_ablate].mean(0, keepdim=True)
    return attn_result

In [36]:
dataset = PalindromeDataset(size=500, max_value=30, half_length=10)
toks, is_palindrome = dataset[:500]

ablation_scores = t.zeros((model.cfg.n_layers, model.cfg.n_heads))

model.reset_hooks()
logits_no_ablation = model(toks, return_type="logits")
loss_no_ablation = get_loss(logits_no_ablation, is_palindrome)
print(f"loss_no_ablation: {loss_no_ablation}")
for ablation_type in ['zero', 'random', 'mean']:
    ablation_scores = t.zeros((model.cfg.n_layers, model.cfg.n_heads))
    for layer in range(model.cfg.n_layers):
        for head in range(model.cfg.n_heads):
            temp_hook_fn = functools.partial(
                head_ablation_hook,
                head_index_to_ablate=head,
                ablation_type=ablation_type)
            ablated_logits = model.run_with_hooks(
                toks,
                return_type="logits",
                fwd_hooks=[(utils.get_act_name("result", layer), temp_hook_fn)]
            )
            ablated_loss = get_loss(ablated_logits, is_palindrome) - loss_no_ablation
            ablation_scores[layer, head] = ablated_loss
            # print(f"loss diff ablating {layer}.{head}: {ablated_loss}")

    imshow(
        ablation_scores,
        labels={"x": "Head", "y": "Layer", "color": "Loss diff"},
        title=f"Loss Difference After Ablating Heads, ablation_type: {ablation_type}",
        text_auto=".2f",
        width=900, height=400
    )

loss_no_ablation: 0.013713347725570202


Some conclusions from this ablation analysis:
- **Head 1.1 seems to be doing absolutely nothing.**
- Heads 0.0, 0.1, and 1.0 are quite important to the mechanism.
- Mean-ablating head 0.1 seems to have no impact on performance. This suggests that its output is roughly the same for all inputs (?)


### QKV ablations

We can ablate the individual Q, K, and V activations for each head.

In [37]:
def qkv_zero_ablation_hook(
    qkorv: Float[Tensor, "batch seq n_heads d_head"],
    hook: HookPoint,
) -> Float[Tensor, "batch seq n_heads d_head"]:
    # `layer` and `head` are read from the enclosing loops below
    if (hook.layer() == layer):
        qkorv[:,:,head,:] = 0
    return qkorv

def qkv_mean_ablation_hook(
    qkorv: Float[Tensor, "batch seq n_heads d_head"],
    hook: HookPoint,
) -> Float[Tensor, "batch seq n_heads d_head"]:
    # `layer` and `head` are read from the enclosing loops below
    if (hook.layer() == layer):
        qkorv[:,:,head] = qkorv[:,:,head,:].mean(dim=0, keepdim=True)
    return qkorv

toks, is_palindrome = dataset[0:200]

for type in ['q', 'k', 'v']:
    print(f"Mean ablating {type}")
    for layer in [0,1]:
        for head in [0,1]:
            model.reset_hooks()
            logits_no_ablation = model(toks, return_type="logits")
            loss_no_ablation = get_loss(logits_no_ablation, is_palindrome)

            ablated_logits = model.run_with_hooks(
                toks,
                return_type="logits",
                fwd_hooks=[(utils.get_act_name(type, layer), qkv_mean_ablation_hook)]
            )
            ablated_loss = get_loss(ablated_logits, is_palindrome) - loss_no_ablation
            print(f"loss diff mean ablating {type} in h{layer}.{head}: {ablated_loss}")

print('\n\n')

for type in ['q', 'k', 'v']:
    print(f"Zero ablating {type}")
    for layer in [0,1]:
        for head in [0,1]:
            model.reset_hooks()
            logits_no_ablation = model(toks, return_type="logits")
            loss_no_ablation = get_loss(logits_no_ablation, is_palindrome)

            ablated_logits = model.run_with_hooks(
                toks,
                return_type="logits",
                fwd_hooks=[(utils.get_act_name(type, layer), qkv_zero_ablation_hook)]
            )
            ablated_loss = get_loss(ablated_logits, is_palindrome) - loss_no_ablation
            print(f"loss diff zero ablating {type} in h{layer}.{head}: {ablated_loss}")

Mean ablating q
loss diff mean ablating q in h0.0: 1.289652420207858
loss diff mean ablating q in h0.1: 0.0038201939314603806
loss diff mean ablating q in h1.0: 0.0001571420580148697
loss diff mean ablating q in h1.1: 0.0001202831044793129
Mean ablating k
loss diff mean ablating k in h0.0: 0.9433671962469816
loss diff mean ablating k in h0.1: 0.027933679521083832
loss diff mean ablating k in h1.0: 0.2121445070952177
loss diff mean ablating k in h1.1: 0.0009737040381878614
Mean ablating v
loss diff mean ablating v in h0.0: 0.9077602755278349
loss diff mean ablating v in h0.1: 0.0028489665128290653
loss diff mean ablating v in h1.0: 0.6797835361212492
loss diff mean ablating v in h1.1: -4.720850847661495e-05



Zero ablating q
loss diff zero ablating q in h0.0: 2.9085682164877653
loss diff zero ablating q in h0.1: -0.0005261398619040847
loss diff zero ablating q in h1.0: 1.8412707578390837
loss diff zero ablating q in h1.1: -0.0014881233219057322
Zero ablating k
loss diff zero ablating k