This notebook is where I generate the histograms & heatmaps of attention score contributions in the neg name mover heads.

The hypothesis is that the main contributor on the query-side is the unembedding of token X, and on the key-side is the embedding of token X. In the IOI task, X is the IO token.

## Setup

In [1]:
from transformer_lens.cautils.notebook import *
from transformer_lens.rs.callum.keys_fixed import (
    attn_scores_as_linear_func_of_keys,
    attn_scores_as_linear_func_of_queries,
    get_attn_scores_as_linear_func_of_queries_for_histogram,
    get_attn_scores_as_linear_func_of_keys_for_histogram,
    decompose_attn_scores,
    plot_contribution_to_attn_scores,
    project,
    decompose_attn_scores_full,
    create_fucking_massive_plot_1,
    create_fucking_massive_plot_2
)
from transformer_lens.rs.callum.generate_bag_of_words_quad_plot import get_effective_embedding

# effective_embeddings = get_effective_embedding(model) 

# W_U = effective_embeddings["W_U (or W_E, no MLPs)"]
# W_EE = effective_embeddings["W_E (including MLPs)"]
# W_EE_subE = effective_embeddings["W_E (only MLPs)"]

clear_output()

In [2]:
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    # refactor_factored_attn_matrices=True,
)
model.set_use_split_qkv_input(True)
model.set_use_attn_result(True)

clear_output()

# What happens when you freeze keys / queries, and patch the other?

## Freeze keys, patch queries

In [9]:
NUM_BATCHES = 30
N = 40
NAME_TOKENS = model.to_tokens(NAMES, prepend_bos=False).squeeze().tolist()
NNMH_LIST = [(10, 7), (11, 10)]

attn_scores, attn_probs = get_attn_scores_as_linear_func_of_queries_for_histogram(
    NNMH_LIST[0], 
    num_batches=NUM_BATCHES,
    batch_size=N,
    model=model,
    name_tokens=NAME_TOKENS,
    subtract_S1_attn_scores=True
)

  0%|          | 0/30 [00:00<?, ?it/s]

In [13]:
labels_list, attn_scores_list = zip(*attn_scores.items())

sorted_indices = t.argsort(t.tensor([score.mean() for score in attn_scores_list]))
labels_list = [labels_list[i] for i in sorted_indices]
attn_scores_list = [attn_scores_list[i] for i in sorted_indices]

h = hist(
    attn_scores_list,
    labels={"variable": "Query-side vector", "value": "Attention scores"},
    title="Attn<sub>END➔IO</sub> - Attn<sub>END➔S1</sub> in NNMH 10.7 (keys fixed, scores are linear func of queries)",
    names=labels_list,
    width=1000,
    height=600,
    opacity=0.65,
    marginal="box",
    template="simple_white",
    nbins=120,
    return_fig=True,
)
for i in [0, 2, 3]:
    h.data[2*i].visible = "legendonly"
    h.data[2*i+1].visible = "legendonly"
h.show(config={"staticPlot": True})

The unembedding of IO has way more influence on the final attention scores than any other query-side vector. Most importantly, **it boosts attn scores more than the actual output of NMH 9.9**.

We've already observed that the name mover heads' output is basically entirely responsible for creating the queries of neg heads like 10.7 (i.e. if you path patch from the name movers to 10.7, then attention from END to IO in 10.7 is no greater than attention from END to S1). So this suggests that **the entire reason the neg heads attend back to IO (on the query side) is because they pick up on the unembedding of IO which is stored there**.

I also want to check how large the component of `W_U[IO]` is in the output of head 9.9. Let's do that quickly.

In [14]:
batch = 100

ioi_dataset, ioi_cache = generate_data_and_caches(batch, model=model, only_ioi=True, prepend_bos=True)

result = ioi_cache["result", 9][range(batch), ioi_dataset.word_idx["end"], 9]
result_normalized = result / result.norm(dim=-1, keepdim=True)

io_dir = model.W_U.T[ioi_dataset.io_tokenIDs]
s_dir = model.W_U.T[ioi_dataset.s_tokenIDs]

result_in_io_dir, result_in_io_perp_dir = project(result_normalized, io_dir, test=False)

io_component_norm = result_in_io_dir.norm(dim=-1).mean()
io_perp_component_norm = result_in_io_perp_dir.norm(dim=-1).mean()

print(f"W_U[IO] component fraction explained: {io_component_norm**2:.4f}")
print(f"W_U[IO]-perp component fraction explained: {io_perp_component_norm**2:.4f}")

result_in_io_s_dir, result_in_io_s_perp_dir = project(result_normalized, [io_dir, s_dir], test=False)

io_s_component_norm = result_in_io_s_dir.norm(dim=-1).mean()
io_s_perp_component_norm = result_in_io_s_perp_dir.norm(dim=-1).mean()

print(f"\nW_U[IO] & W_U[S] component fraction explained: {io_s_component_norm**2:.4f}")
print(f"W_U[IO] & W_U[S]-perp component fraction explained: {io_s_perp_component_norm**2:.4f}")

W_U[IO] component fraction explained: 0.1492
W_U[IO]-perp component fraction explained: 0.8488

W_U[IO] & W_U[S] component fraction explained: 0.1590
W_U[IO] & W_U[S]-perp component fraction explained: 0.8388


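Yeah okay, not that large: the `W_U[IO]` direction accounts for only about 15% of 9.9's (normalized) output, and adding `W_U[S]` barely changes this. It's an important part of its output, but not the most important part. This suggests that 9.9 is partly communicating with other heads (maybe the neg heads!) rather than just writing in the unembedding direction of IO.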
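
The `project` helper is imported from the repo and its implementation isn't shown here. As a hypothetical stand-in (numpy rather than torch, and not the repo's code), here is a minimal sketch of splitting each vector into its component inside the span of some per-example directions (like `[io_dir, s_dir]`) and the perpendicular remainder, plus the fraction of squared norm in each part. One design note: averaging the per-example squared-norm fractions makes the two fractions sum to exactly 1, whereas squaring the *mean* norm (as the cell above does) gives values that sum to slightly less than 1 by Jensen's inequality, which is why 0.1492 + 0.8488 < 1.

```python
import numpy as np


def project_onto_span(x, dirs):
    """Split each row of x into its component inside span(dirs[i]) and the remainder.

    x:    (batch, d_model) vectors to decompose, e.g. head outputs at END
    dirs: (batch, k, d_model) per-example spanning vectors, e.g. W_U[IO] and W_U[S]
    """
    par = np.empty_like(x)
    for i in range(x.shape[0]):
        # Reduced QR gives an orthonormal basis (columns of q) for span(dirs[i])
        q, _ = np.linalg.qr(dirs[i].T)
        par[i] = q @ (q.T @ x[i])
    return par, x - par


def fraction_of_squared_norm(par, perp):
    """Mean (over the batch) fraction of each vector's squared norm in each part."""
    par_sq = (par ** 2).sum(-1)
    perp_sq = (perp ** 2).sum(-1)
    total = par_sq + perp_sq  # equals |x|^2 because the parts are orthogonal
    return (par_sq / total).mean(), (perp_sq / total).mean()


# Toy usage with random data standing in for the cached head outputs
rng = np.random.default_rng(0)
x = rng.standard_normal((100, 768))
dirs = rng.standard_normal((100, 2, 768))
par, perp = project_onto_span(x, dirs)
frac_par, frac_perp = fraction_of_squared_norm(par, perp)
```

For random directions the parallel fraction is tiny (around $k / d_{model}$), so 15% for a single fixed direction is far above chance.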

Lastly, let's look at the cosine similarities. I want to see if any part of `9.9`'s output which is not in the IO unembedding direction is also useful.

In [15]:
linear_map, bias_term = attn_scores_as_linear_func_of_queries(None, (10, 7), model, ioi_cache, ioi_dataset, subtract_S1_attn_scores=False)

io_cos_sim = t.cosine_similarity(result_in_io_dir, linear_map, dim=-1).mean()
io_perp_cos_sim = t.cosine_similarity(result_in_io_perp_dir, linear_map, dim=-1).mean()

print(f"Cos sim with IO component: {io_cos_sim:.4f}")
print(f"Cos sim with IO-perp component: {io_perp_cos_sim:.4f}\n")

linear_map, bias_term = attn_scores_as_linear_func_of_queries(None, (10, 7), model, ioi_cache, ioi_dataset, subtract_S1_attn_scores=True)

io_cos_sim = t.cosine_similarity(result_in_io_dir, linear_map, dim=-1).mean()
io_perp_cos_sim = t.cosine_similarity(result_in_io_perp_dir, linear_map, dim=-1).mean()

print(f"Cos sim with IO component (subtract S1 baseline): {io_cos_sim:.4f}")
print(f"Cos sim with IO-perp component (subtract S1 baseline): {io_perp_cos_sim:.4f}\n")

io_s_cos_sim = t.cosine_similarity(result_in_io_s_dir, linear_map, dim=-1).mean()
io_s_perp_cos_sim = t.cosine_similarity(result_in_io_s_perp_dir, linear_map, dim=-1).mean()

print(f"Cos sim with IO & S component (subtract S1 baseline): {io_s_cos_sim:.4f}")
print(f"Cos sim with (IO & S)-perp component (subtract S1 baseline): {io_s_perp_cos_sim:.4f}")

Cos sim with IO component: 0.2794
Cos sim with IO-perp component: 0.0834

Cos sim with IO component (subtract S1 baseline): 0.2001
Cos sim with IO-perp component (subtract S1 baseline): 0.1520

Cos sim with IO & S component (subtract S1 baseline): 0.2302
Cos sim with (IO & S)-perp component (subtract S1 baseline): 0.1353


Are the outputs of heads 9.9 and 9.6, restricted to the parts perpendicular to $W_U[IO]$ and $W_U[S]$, close to each other? **Unfortunately, no.** This isn't super telling though: obviously a big component of their output won't be used by head 10.7, and they might also be complementary (i.e. fit together rather than add on top of each other).

In [16]:
result_99 = ioi_cache["result", 9][range(batch), ioi_dataset.word_idx["end"], 9]
result_96 = ioi_cache["result", 9][range(batch), ioi_dataset.word_idx["end"], 6]

io_dir = model.W_U.T[ioi_dataset.io_tokenIDs]
s_dir = model.W_U.T[ioi_dataset.s_tokenIDs]

# Split each head's output into the part in span(W_U[IO], W_U[S]) and the part perpendicular to it
result_99_in_io_s_dir, result_99_in_io_s_perp_dir = project(result_99, [io_dir, s_dir], test=False)
result_96_in_io_s_dir, result_96_in_io_s_perp_dir = project(result_96, [io_dir, s_dir], test=False)

# Compare only the perpendicular parts of the two heads' outputs
cos_sim_99_to_96 = t.cosine_similarity(result_99_in_io_s_perp_dir, result_96_in_io_s_perp_dir, dim=-1).mean()

print(f"Cos sim between 9.9 & 9.6 (perpendicular to query unembedding direction): {cos_sim_99_to_96:.4f}")

Cos sim between 9.9 & 9.6 (perpendicular to query unembedding direction): -0.0341


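Going back to the cosine similarities with 10.7's query-side linear map: the IO-perpendicular component of 9.9's output has about 3x lower cosine similarity than the IO component (0.0834 vs 0.2794), but it also has a much larger norm. Also, when `subtract_S1_attn_scores=True`, its cosine similarity becomes much larger (0.1520).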
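
Why a lower cosine similarity doesn't mean a smaller effect: since the score is linear in the query-side vector, each component's contribution is the dot product with the linear map, i.e. $|\text{map}| \cdot |\text{component}| \cdot \cos\theta$, and the contributions of the parallel and perpendicular parts add up to the total. A minimal numpy sketch with random stand-in data, assuming `linear_map` is a `(batch, d_model)` array (which is how the cell above uses it with `t.cosine_similarity(..., dim=-1)`); the direction `u` is a hypothetical stand-in for $W_U[IO]$.

```python
import numpy as np


def score_contributions(linear_map, x_par, x_perp):
    """Per-example score contributions of the parallel and perpendicular parts.

    The score is linear in the query-side vector, so contribution = <linear_map, component>
    = |linear_map| * |component| * cos(angle).
    """
    return (linear_map * x_par).sum(-1), (linear_map * x_perp).sum(-1)


rng = np.random.default_rng(0)
linear_map = rng.standard_normal((100, 768))
x = rng.standard_normal((100, 768))
# Hypothetical per-example unit direction (stand-in for W_U[IO])
u = rng.standard_normal((100, 768))
u /= np.linalg.norm(u, axis=-1, keepdims=True)
x_par = (x * u).sum(-1, keepdims=True) * u
x_perp = x - x_par
c_par, c_perp = score_contributions(linear_map, x_par, x_perp)
```

As a rough back-of-envelope check using the numbers printed above (multiplying mean norms by mean cosine sims, which is only approximate): the mean component norms are about $\sqrt{0.1492} \approx 0.39$ and $\sqrt{0.8488} \approx 0.92$, so without the baseline the IO and IO-perp parts give roughly $0.39 \times 0.28 \approx 0.11$ vs $0.92 \times 0.083 \approx 0.076$ (in units of the linear map's norm), while with the S1 baseline subtracted they give roughly $0.39 \times 0.20 \approx 0.078$ vs $0.92 \times 0.15 \approx 0.14$.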

### Attn probs?

The plot below shows this in terms of attention probs (because it's good to know that the attention scores above are sufficient to affect the probs), even though attention scores are a more natural way of thinking about the function from queries -> attention, because of linearity.

Sure enough, attention probs are already close to 1 with the actual output of the NMH (and with no patching), and they get pushed all the way to 1 when we use the unembedding of IO instead.
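
A toy illustration of why scores are the more natural quantity (pure Python, made-up scores, not values from the model): the score is linear in the query-side vector, but the prob goes through a softmax, so once attention to `IO` is already high, even large score increases barely move the prob.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical END-position scores for [IO, S1, three other tokens]
base_scores = [4.0, 1.0, 0.0, 0.0, 0.0]
boosts = [0.0, 2.0, 4.0, 8.0]  # extra score added to IO, e.g. from a stronger query-side signal
io_probs = [softmax([base_scores[0] + b] + base_scores[1:])[0] for b in boosts]

# Score steps of +2, +2, +4, yet the prob gain shrinks each time as the softmax saturates
prob_gains = [later - earlier for earlier, later in zip(io_probs, io_probs[1:])]
```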

In [None]:
labels_list, attn_probs_list = zip(*attn_probs.items())

sorted_indices = t.argsort(t.tensor([score.mean() for score in attn_probs_list]))
labels_list = [labels_list[i] for i in sorted_indices]
attn_probs_list = [attn_probs_list[i] for i in sorted_indices]

h = hist(
    attn_probs_list,
    labels={"variable": "Query-side vector", "value": "Attention probs"},
    title="Attn probs (from END -> IO) in NNMH 10.7 (keys fixed, scores are linear func of queries)",
    names=labels_list,
    width=1000,
    height=600,
    opacity=0.6,
    marginal="box",
    template="simple_white",
    nbins=100,
    return_fig=True,
)
for i in [1, 2]:
    h.data[2*i].visible = "legendonly"
    h.data[2*i+1].visible = "legendonly"
h.show(config={"staticPlot": True})

## Freeze queries, patch keys

Now that we've got satisfying results when the keys are fixed (and attn is a linear func of queries), what about the other way around? We want to show the equivalent result, i.e. that the main thing determining attn on the key-side is the IO embedding. If not, then what the hell does the unembedding of IO match with?

In [None]:
attn_scores, attn_probs = get_attn_scores_as_linear_func_of_keys_for_histogram(
    NNMH_LIST[0], 
    num_batches=NUM_BATCHES,
    batch_size=N,
    model=model,
)

labels_list, attn_scores_list = zip(*attn_scores.items())

sorted_indices = t.argsort(t.tensor([score.mean() for score in attn_scores_list]))
labels_list = [labels_list[i] for i in sorted_indices]
attn_scores_list = [attn_scores_list[i] for i in sorted_indices]

hist(
    attn_scores_list,
    labels={"variable": "Key-side vector", "value": "Attention scores"},
    title="Attn scores (from END -> IO) in NNMH 10.7 (queries fixed, scores are linear func of keys)",
    names=labels_list,
    width=1000,
    height=600,
    opacity=0.6,
    marginal="box",
    template="simple_white",
    nbins=120,
    static=True,
)

  0%|          | 0/30 [00:00<?, ?it/s]

In [None]:
# labels_list, attn_probs_list = zip(*attn_probs.items())

# sorted_indices = t.argsort(t.tensor([score.mean() for score in attn_probs_list]))
# labels_list = [labels_list[i] for i in sorted_indices]
# attn_probs_list = [attn_probs_list[i] for i in sorted_indices]

# hist(
#     attn_probs_list,
#     labels={"variable": "Key-side vector", "value": "Attention probs"},
#     title="Attn probs (from END -> IO) in NNMH 10.7 (queries fixed, scores are linear func of keys)",
#     names=labels_list,
#     width=1000,
#     height=600,
#     opacity=0.6,
#     marginal="box",
#     template="simple_white",
#     nbins=100,
#     static=True,
# )

In [None]:
attn_scores, _ = get_attn_scores_as_linear_func_of_keys_for_histogram(
    NNMH_LIST[0], 
    num_batches=NUM_BATCHES,
    batch_size=N,
    model=model,
    subtract_S1_attn_scores=True,
)

labels_list, attn_scores_list = zip(*attn_scores.items())

sorted_indices = t.argsort(t.tensor([score.mean() for score in attn_scores_list]))
labels_list = [labels_list[i] for i in sorted_indices]
attn_scores_list = [attn_scores_list[i] for i in sorted_indices]

hist(
    attn_scores_list,
    labels={"variable": "Key-side vector", "value": "Attention scores diff"},
    title="Attn scores (END->IO) minus (END->S1) in NNMH 10.7 (queries fixed, scores are linear func of keys)",
    names=labels_list,
    width=1000,
    height=600,
    opacity=0.6,
    marginal="box",
    template="simple_white",
    nbins=120,
    static=False,
    legend_traceorder="reversed",
)

  0%|          | 0/30 [00:00<?, ?it/s]

## Analysis

Okay, so replacing the residual stream with the effective embedding (or just the MLP out) is an improvement. Not much of an improvement (as expected from the cosine sim of just 0.1 below), but still an improvement.

Note that we don't get an improvement unless we subtract the S1 baseline, which I claim is a principled thing to do (see discussion below).

In [None]:
batch = 100

ioi_dataset, ioi_cache = generate_data_and_caches(batch, model=model, only_ioi=True, prepend_bos=True)

resid_pre = ioi_cache["resid_pre", 10][range(batch), ioi_dataset.word_idx["IO"]]
resid_pre_normalized = resid_pre / resid_pre.norm(dim=-1, keepdim=True)

mlp0_dir = ioi_cache["mlp_out", 0][range(batch), ioi_dataset.word_idx["IO"]]

resid_pre_in_mlp0_dir, resid_pre_in_mlp0_perp_dir = project(resid_pre_normalized, mlp0_dir)

mlp0_component_norm = resid_pre_in_mlp0_dir.norm(dim=-1).mean()
mlp0_perp_component_norm = resid_pre_in_mlp0_perp_dir.norm(dim=-1).mean()

print(f"MLP0 component norm: {mlp0_component_norm:.4f}")
print(f"MLP0-perp component norm: {mlp0_perp_component_norm:.4f}")

assert (mlp0_component_norm ** 2 + mlp0_perp_component_norm ** 2 - 1).abs() < 5e-3

MLP0 component norm: 0.6732
MLP0-perp component norm: 0.7362


In [None]:
linear_map, bias_term = attn_scores_as_linear_func_of_keys(None, (10, 7), model, ioi_cache, ioi_dataset)

MLP0_cos_sim = t.cosine_similarity(resid_pre_in_mlp0_dir, linear_map, dim=-1).mean()
MLP0_perp_cos_sim = t.cosine_similarity(resid_pre_in_mlp0_perp_dir, linear_map, dim=-1).mean()

print(f"Cos sim with MLP0 component: {MLP0_cos_sim:.4f}")
print(f"Cos sim with MLP0-perp component: {MLP0_perp_cos_sim:.4f}\n")

resid_pre_S1 = ioi_cache["resid_pre", 10][range(batch), ioi_dataset.word_idx["S1"]]
resid_pre_normalized_S1 = resid_pre_S1 / resid_pre_S1.norm(dim=-1, keepdim=True)
mlp0_dir_S1 = ioi_cache["mlp_out", 0][range(batch), ioi_dataset.word_idx["S1"]]
resid_pre_in_mlp0_dir_S1, resid_pre_in_mlp0_perp_dir_S1 = project(resid_pre_normalized_S1, mlp0_dir_S1)

MLP0_cos_sim = t.cosine_similarity(resid_pre_in_mlp0_dir - resid_pre_in_mlp0_dir_S1, linear_map, dim=-1).mean()
MLP0_perp_cos_sim = t.cosine_similarity(resid_pre_in_mlp0_perp_dir - resid_pre_in_mlp0_perp_dir_S1, linear_map, dim=-1).mean()

print(f"Cos sim with MLP0 component (subtract baseline): {MLP0_cos_sim:.4f}")
print(f"Cos sim with MLP0-perp component (subtract baseline): {MLP0_perp_cos_sim:.4f}")

Cos sim with MLP0 component: 0.0278
Cos sim with MLP0-perp component: 0.1050

Cos sim with MLP0 component (subtract baseline): 0.1082
Cos sim with MLP0-perp component (subtract baseline): -0.0016


# Which components contribute on the key and query sides?

Here I'm going to break down the residual stream at the key position (and later the query position) by every head and every MLP, to see which components contribute most to the attention scores.

It'll be a heatmap over all the attention heads (and the MLPs). Each value will be the attention score contribution (with the appropriate LN scale applied).

## Formalism

Here's an equation for a single scalar attention score $S$, from the `END` token to the `IO` token:

$$
S = (W_Q^T x_{E} + b_Q)^T(W_K^T x_{IO} + b_K)
$$

where $x_E$ is the query-side vector (of shape `d_model`) at the `END` token, and $x_{IO}$ is the key-side vector at the `IO` token.

### Decomposing on key-side

If we fix queries $q = W_Q^T x_E + b_Q$, we can expand this into:

$$
S = q^T W_K^T x_{IO} + q^T b_K
$$

and if we want to figure out which key-side components affect the attention most, then we can decompose $x_{IO}$ as a sum of terms: one for each attn head and MLP before head 10.7, plus the direct terms and the bias term:

$$
S = \underbrace{\left(\sum_{L=0}^9 \sum_{H=0}^{11} q^T W_K^T x_{IO}^{L.H}\right)}_{\text{attn heads}} + \underbrace{\left(\sum_{L=0}^9 q^T W_K^T x_{IO}^{L}\right)}_{\text{MLPs}} + \underbrace{q^T W_K^T x_{IO}^{e} + q^T W_K^T x_{IO}^{pe}}_{\text{direct path}} + \underbrace{q^T b_K}_{\text{bias term}}
$$

where $x_{IO}^{L.H}$ is the key-side vector (i.e. at the `IO` position) output by head $H$ in layer $L$, $x_{IO}^{L}$ is the key-side vector output by the layer-$L$ MLP, and $x_{IO}^e$ and $x_{IO}^{pe}$ are the embedding and positional embedding for `IO` respectively.
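
A minimal numpy check that this decomposition is exact once $q$ is fixed: the per-component terms plus the bias term sum to $S$. The weights and component vectors are random stand-ins (not GPT-2's), and the sketch ignores the $1/\sqrt{d_{head}}$ scaling and treats LayerNorm as a fixed per-position scale; dividing every component by the same scale keeps the decomposition additive.

```python
import numpy as np

rng = np.random.default_rng(0)
# Toy random weights (hypothetical); d_model / d_head match GPT-2 small
d_model, d_head, n_components = 768, 64, 12
W_Q = rng.standard_normal((d_model, d_head)) / np.sqrt(d_model)
W_K = rng.standard_normal((d_model, d_head)) / np.sqrt(d_model)
b_Q = rng.standard_normal(d_head)
b_K = rng.standard_normal(d_head)

# Key-side residual stream at IO as a sum of per-component outputs
# (stand-ins for each head, each MLP, W_E and W_pos)
components_IO = rng.standard_normal((n_components, d_model))
x_IO = components_IO.sum(axis=0)
x_E = rng.standard_normal(d_model)

# Full score S = (W_Q^T x_E + b_Q)^T (W_K^T x_IO + b_K)
q = W_Q.T @ x_E + b_Q
S = q @ (W_K.T @ x_IO + b_K)

# With q fixed, S is affine in x_IO: linear_map = W_K q (i.e. q^T W_K^T as a vector)
linear_map = W_K @ q
bias_term = q @ b_K
contributions = components_IO @ linear_map  # one score contribution per component
```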

**What does this tell us?** If one of the terms is much larger than the others, this tells us that this component is the most important one (on the key-side) for determining attention scores. For instance, our theory is that the main contributor on the key-side is the embedding of `IO` (specifically the part of the extended embedding which isn't tied to the unembedding, in other words `MLP0` and possibly some layer-0 attention heads).

> The function `attn_scores_as_linear_func_of_keys` returns `(linear_map, bias_term)`, which are the terms $q^T W_K^T$ and $q^T b_K$ respectively.

### Subtracting the baseline

Rather than calculating the attn score contribution from `END -> IO`, we might want to subtract the contribution from `END -> S1`. This effectively removes noise from our analysis (since we're concerned with why the neg name movers attend to `IO` rather than `S1`). In other words, we have $x_{IO} - x_{S1}$ in place of $x_{IO}$, and $S$ now denotes the score difference $S_{END \to IO} - S_{END \to S1}$. Note that this deletes our bias term (which is the same at the `IO` and `S1` positions).

$$
S = \underbrace{\left(\sum_{L=0}^9 \sum_{H=0}^{11} q^T W_K^T (x_{IO}^{L.H} - x_{S1}^{L.H})\right)}_{\text{attn heads}} + \underbrace{\left(\sum_{L=0}^9 q^T W_K^T (x_{IO}^{L} - x_{S1}^{L})\right)}_{\text{MLPs}} + \underbrace{q^T W_K^T (x_{IO}^{e} - x_{S1}^{e}) + q^T W_K^T (x_{IO}^{pe} - x_{S1}^{pe})}_{\text{direct path}}
$$
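
A minimal numpy check of the bias cancellation, with a fixed query and random stand-ins for $W_K$, $b_K$ and the per-component outputs at `IO` and `S1` (hypothetical values; LayerNorm ignored): the per-component differences sum to the score difference with no bias term needed.

```python
import numpy as np

rng = np.random.default_rng(1)
d_model, d_head, n_components = 768, 64, 12
W_K = rng.standard_normal((d_model, d_head)) / np.sqrt(d_model)
b_K = rng.standard_normal(d_head)
q = rng.standard_normal(d_head)  # fixed query at END

# Per-component outputs at the IO and S1 positions (random stand-ins)
components_IO = rng.standard_normal((n_components, d_model))
components_S1 = rng.standard_normal((n_components, d_model))


def score(x_key):
    """Attention score from END to a key position with residual vector x_key."""
    return q @ (W_K.T @ x_key + b_K)


score_diff = score(components_IO.sum(axis=0)) - score(components_S1.sum(axis=0))

# q^T b_K appears in both scores and cancels, so only the linear part remains
diff_contributions = (components_IO - components_S1) @ (W_K @ q)
```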

> **A bit more on why I think subtracting the baseline is a principled thing to do**
> 
> `IO` and `S1` are basically symmetric in terms of the IOI task: they're both names which occur at the start of the sentence (and their position actually gets shuffled, sometimes `IO` is first and sometimes `S1` is first). If we don't subtract the baseline, we're basically examining the sum of two circuits: name attention (attend to one of the first names in the sequence) and copy-suppression (attend to `IO` over `S1`). If we subtract the baseline, we're examining the copy-suppression circuit in isolation.
> 
> Will we always be able to subtract the baseline? No, but we'll probably be able to for more examples than just IOI. It would be good to have some more examples where we can subtract the baseline, and get evidence for our prediction-attention hypothesis.

Most of the plots below will subtract the baseline (denoted by **Baseline subtracted** vs. **Baseline not subtracted**).

### Decomposing on query-side

If we fix keys $k = W_K^T x_{IO} + b_K$, we can expand this into:

$$
S = k^T W_Q^T x_E + k^T b_Q
$$

If we want to decompose by query-side component, we can write this as:

$$
S = \underbrace{\left(\sum_{L=0}^9 \sum_{H=0}^{11} k^T W_Q^T x_{E}^{L.H}\right)}_{\text{attn heads}} + \underbrace{\left(\sum_{L=0}^9 k^T W_Q^T x_{E}^{L}\right)}_{\text{MLPs}} + \underbrace{k^T W_Q^T x_{E}^{e} + k^T W_Q^T x_{E}^{pe}}_{\text{direct path}} + \underbrace{k^T b_Q}_{\text{bias term}}
$$

If we want to subtract the baseline attention to `S1`, then rather than changing our components $x_E$, we change our linear map:

$$
S = (k_{IO} - k_{S1})^T W_Q^T x_E + (k_{IO} - k_{S1})^T b_Q
$$

where $k_{IO} - k_{S1} = (W_K^T x_{IO} + b_K) - (W_K^T x_{S1} + b_K) = W_K^T (x_{IO} - x_{S1})$. With this new linear map, we can do the same decomposition.
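
A minimal numpy check of the query-side version, again with random stand-in weights and vectors (hypothetical, LayerNorm ignored): $b_K$ cancels in $k_{IO} - k_{S1}$, and the query-side components plus the bias term $(k_{IO} - k_{S1})^T b_Q$ sum to the score difference.

```python
import numpy as np

rng = np.random.default_rng(2)
d_model, d_head, n_components = 768, 64, 12
W_Q = rng.standard_normal((d_model, d_head)) / np.sqrt(d_model)
W_K = rng.standard_normal((d_model, d_head)) / np.sqrt(d_model)
b_Q = rng.standard_normal(d_head)
b_K = rng.standard_normal(d_head)

x_IO = rng.standard_normal(d_model)
x_S1 = rng.standard_normal(d_model)
components_E = rng.standard_normal((n_components, d_model))  # query-side components at END
x_E = components_E.sum(axis=0)

k_IO = W_K.T @ x_IO + b_K
k_S1 = W_K.T @ x_S1 + b_K
k_diff = k_IO - k_S1  # equals W_K^T (x_IO - x_S1), since b_K cancels

q = W_Q.T @ x_E + b_Q
score_diff = q @ k_IO - q @ k_S1

# Query-side linear map (k_IO - k_S1)^T W_Q^T as a d_model vector, plus its bias term
linear_map = W_Q @ k_diff
bias_term = k_diff @ b_Q
contributions = components_E @ linear_map
```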

**What does this tell us?** Same thing as last time, but for the query-side. For instance, we expect the name mover heads to show up strongest on this plot, because they seem to be what causes the `END` token to attend to `IO` in the neg name mover 10.7.

### Decomposing into parallel and perpendicular components

So far, this has all been pretty standard, and it doesn't go beyond e.g. basic heatmap attention attributions which are commonly done when investigating induction heads. But here, we go further, by decomposing the key-side and query-side components into parallel and perpendicular components (wrt some specific vector). So rather than just identifying the components responsible for this effect, we're trying to identify **the specific directions these components write to which give us our attention scores.**

To be more specific - we think that the most important interaction is between:

* the query-side component in the direction of the unembedding $W_U[IO]$ (output by the name mover heads), and
* the key-side component in the direction of the (extended) embedding for the IO token (output primarily by MLP0).

So for any given term $x_Q^T W_Q W_K^T x_K$, we can decompose it into up to 4 terms:

$$
(x_Q^∥)^T W_Q W_K^T x_K^{∥} + (x_Q^∥)^T W_Q W_K^T x_K^{⟂} + (x_Q^⟂)^T W_Q W_K^T x_K^{∥} + (x_Q^⟂)^T W_Q W_K^T x_K^{⟂}
$$

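where $x_Q^∥$, $x_Q^⟂$ are the components of the query-side residual stream vector $x_Q$ which are parallel / perpendicular to the unembedding vector $W_U[IO]$ respectively (so $x_Q^∥ + x_Q^⟂ = x_Q$), and similarly $x_K^∥$, $x_K^⟂$ are the components of $x_K$ parallel / perpendicular to the (extended) embedding direction for `IO` (e.g. the output of MLP0 at `IO`).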

> **What does this tell us?** If our theory is correct, then we expect the first of these terms to be the largest. Note that this is a pretty high burden of proof: if we just picked random directions to project onto, each parallel component would keep only about $1/d_{model}$ of its vector's squared norm, so we'd expect the first term's squared magnitude to be roughly $(1/d_{model}) \times (1/d_{model}) \approx 0.00017\%$ of the full term's (i.e. a magnitude of roughly $1/d_{model} \approx 0.13\%$).
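
A quick Monte Carlo check of this random-direction baseline, using toy random weights and inputs (an assumption, not GPT-2's weights or activations): it estimates the ratio of the mean squared parallel-parallel term to the mean squared full term, which should come out around $1/d_{model}^2$, i.e. an RMS magnitude ratio around $1/d_{model}$.

```python
import numpy as np


def random_direction_ratio(d_model=768, d_head=64, n_trials=2000, seed=0):
    """Estimate E[(parallel-parallel term)^2] / E[(full term)^2] for random inputs and directions."""
    rng = np.random.default_rng(seed)
    # Random (hypothetical) query/key weights; the QK matrix is W_Q @ W_K.T
    W_Q = rng.standard_normal((d_model, d_head))
    W_K = rng.standard_normal((d_model, d_head))
    x_q = rng.standard_normal((n_trials, d_model))
    x_k = rng.standard_normal((n_trials, d_model))
    # One random unit direction per trial on each side
    u = rng.standard_normal((n_trials, d_model))
    u /= np.linalg.norm(u, axis=-1, keepdims=True)
    v = rng.standard_normal((n_trials, d_model))
    v /= np.linalg.norm(v, axis=-1, keepdims=True)
    # Full bilinear term x_q^T W_Q W_K^T x_k, one per trial
    full = ((x_q @ W_Q) * (x_k @ W_K)).sum(-1)
    # Parallel parts (x_q.u) u and (x_k.v) v pushed through the same bilinear form
    coef_q = (x_q * u).sum(-1)
    coef_k = (x_k * v).sum(-1)
    par = coef_q * coef_k * ((u @ W_Q) * (v @ W_K)).sum(-1)
    return (par ** 2).mean() / (full ** 2).mean()


d_model = 768
ratio = random_direction_ratio(d_model=d_model)
rms_ratio = np.sqrt(ratio)  # magnitude ratio, roughly 1 / d_model
```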

Note on the bias term - this by default always goes in the "⟂" component, because this seems more reasonable (this way the ∥ component is ***only*** capturing the interaction from the residual-stream vector in a particular direction). Also, this is a higher burden of proof, because then the observation "all the effect comes from the ∥ component" implies a stronger conclusion.
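
A minimal numpy sketch of the four-term split with the biases lumped into the ⟂ parts, checking that the four terms add back up to the full score. The directions `u` (stand-in for $W_U[IO]$) and `v` (stand-in for the MLP0 output at `IO`), the weights, and the dictionary labels are all hypothetical; this is not the repo's `decompose_attn_scores`.

```python
import numpy as np


def four_term_decomposition(x_Q, x_K, u, v, W_Q, W_K, b_Q, b_K):
    """Split (W_Q^T x_Q + b_Q)^T (W_K^T x_K + b_K) into (q par/perp) x (k par/perp) terms.

    Biases go into the perpendicular parts, so the parallel parts only carry
    the residual-stream component along u (query side) or v (key side).
    """
    u = u / np.linalg.norm(u)
    v = v / np.linalg.norm(v)
    x_Q_par = (x_Q @ u) * u
    x_K_par = (x_K @ v) * v
    q_par = W_Q.T @ x_Q_par
    q_perp = W_Q.T @ (x_Q - x_Q_par) + b_Q
    k_par = W_K.T @ x_K_par
    k_perp = W_K.T @ (x_K - x_K_par) + b_K
    return {
        ("q_par", "k_par"): q_par @ k_par,
        ("q_par", "k_perp"): q_par @ k_perp,
        ("q_perp", "k_par"): q_perp @ k_par,
        ("q_perp", "k_perp"): q_perp @ k_perp,
    }


# Toy usage with random (hypothetical) weights and vectors
rng = np.random.default_rng(3)
d_model, d_head = 768, 64
W_Q = rng.standard_normal((d_model, d_head)) / np.sqrt(d_model)
W_K = rng.standard_normal((d_model, d_head)) / np.sqrt(d_model)
b_Q, b_K = rng.standard_normal(d_head), rng.standard_normal(d_head)
x_Q, x_K = rng.standard_normal(d_model), rng.standard_normal(d_model)
u, v = rng.standard_normal(d_model), rng.standard_normal(d_model)

terms = four_term_decomposition(x_Q, x_K, u, v, W_Q, W_K, b_Q, b_K)
full_score = (W_Q.T @ x_Q + b_Q) @ (W_K.T @ x_K + b_K)
```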

In [None]:
# define a new decompose_attn_scores function which cuts down on all the code! (I don't use partial cause I don't like when the function shows up as blue)

model = model.cuda();
BATCH_SIZE = 50

def _decompose_attn_scores(
    decompose_by: Literal["keys", "queries"],
    show_plot: bool = False,
    intervene_on_query: Literal["sub_W_U_IO", "project_to_W_U_IO", None] = None,
    intervene_on_key: Literal["sub_MLP0", "project_to_MLP0", None] = None,
    use_effective_embedding: bool = False,
    use_layer0_heads: bool = False,
    subtract_S1_attn_scores: bool = False,
    include_S1_in_unembed_projection: bool = False,
    static: bool = False,
):
    t.cuda.empty_cache()
    return decompose_attn_scores(
        batch_size=BATCH_SIZE,
        seed=42,
        nnmh=(10, 7),
        model=model,
        decompose_by=decompose_by,
        show_plot=show_plot,
        intervene_on_query=intervene_on_query,
        intervene_on_key=intervene_on_key,
        use_effective_embedding=use_effective_embedding,
        use_layer0_heads=use_layer0_heads,
        subtract_S1_attn_scores=subtract_S1_attn_scores,
        include_S1_in_unembed_projection=include_S1_in_unembed_projection,
        static=static,
    )

Moving model to device:  cuda


## Decompose with baseline subtracted / not subtracted

First, before any parallel/perp decomposition, I want to compare what subtracting vs. not subtracting the baseline does. The decompositions shown in each plot are written out in the subsections below.

My prior is that subtracting the baseline leads to less noisy plots, where we can observe our prediction-attention patterns more crisply.

### First plot

On the left is the key-side decomposition with the baseline subtracted, i.e. each of the cells equals the value of one of the terms in the following sum:

$$
S = \underbrace{\left(\sum_{L=0}^9 \sum_{H=0}^{11} q^T W_K^T (x_{IO}^{L.H} - x_{S1}^{L.H})\right)}_{\text{attn heads}} + \underbrace{\left(\sum_{L=0}^9 q^T W_K^T (x_{IO}^{L} - x_{S1}^{L})\right)}_{\text{MLPs}} + \underbrace{q^T W_K^T (x_{IO}^{e} - x_{S1}^{e}) + q^T W_K^T (x_{IO}^{pe} - x_{S1}^{pe})}_{\text{direct path}}
$$

And on the right is the key-side decomposition, with the baseline not subtracted:

$$
S = \underbrace{\left(\sum_{L=0}^9 \sum_{H=0}^{11} q^T W_K^T x_{IO}^{L.H}\right)}_{\text{attn heads}} + \underbrace{\left(\sum_{L=0}^9 q^T W_K^T x_{IO}^{L}\right)}_{\text{MLPs}} + \underbrace{q^T W_K^T x_{IO}^{e} + q^T W_K^T x_{IO}^{pe}}_{\text{direct path}} + \underbrace{q^T b_K}_{\text{bias term}}
$$

> As expected, very crisp on MLP0 when the baseline is subtracted (plus some very small contributions from some of the layer-0 attention heads). Interesting that some of the other MLPs slightly matter; we can guess these might be extended embeddings in a much weaker sense than the first one.

In [None]:
contribution_to_attn_scores_decompose_k = _decompose_attn_scores(
    decompose_by = "keys",
)
contribution_to_attn_scores_decompose_k_sub_S1 = _decompose_attn_scores(
    decompose_by = "keys",
    subtract_S1_attn_scores = True,
)

plot_contribution_to_attn_scores(
    t.stack([contribution_to_attn_scores_decompose_k_sub_S1, contribution_to_attn_scores_decompose_k]),
    decompose_by = "keys",
    facet_labels = ["Baseline subtracted", "Baseline not subtracted"],
    title = "Decompose on key-side, subtract vs don't subtract baseline",
    static = True,
)

### Decomp queries

On the left is the query-side decomposition with baselines subtracted, i.e. each of the cells equals the value of one of the terms in the following sum:

$$
S = \underbrace{\left(\sum_{L=0}^9 \sum_{H=0}^{11} (k_{IO} - k_{S1})^T W_Q^T x_{E}^{L.H}\right)}_{\text{attn heads}} + \underbrace{\left(\sum_{L=0}^9 (k_{IO} - k_{S1})^T W_Q^T x_{E}^{L}\right)}_{\text{MLPs}} + \underbrace{(k_{IO} - k_{S1})^T W_Q^T x_{E}^{e} + (k_{IO} - k_{S1})^T W_Q^T x_{E}^{pe}}_{\text{direct path}}
$$

And on the right is the query-side decomposition, with baseline not subtracted:

$$
S = \underbrace{\left(\sum_{L=0}^9 \sum_{H=0}^{11} k^T W_Q^T x_{E}^{L.H}\right)}_{\text{attn heads}} + \underbrace{\left(\sum_{L=0}^9 k^T W_Q^T x_{E}^{L}\right)}_{\text{MLPs}} + \underbrace{k^T W_Q^T x_{E}^{e} + k^T W_Q^T x_{E}^{pe}}_{\text{direct path}} + \underbrace{k^T b_Q}_{\text{bias term}}
$$

> As expected, very crisp on two of the three name mover heads, namely the two before layer 10 (9.9 and 9.6), and basically zero everywhere else. 9.9 is especially strong (agreeing with the result from the IOI paper that it's the main name mover).

In [None]:
contribution_to_attn_scores_decompose_q = _decompose_attn_scores(
    decompose_by = "queries",
)
contribution_to_attn_scores_decompose_q_sub_S1 = _decompose_attn_scores(
    decompose_by = "queries",
    subtract_S1_attn_scores = True,
)

plot_contribution_to_attn_scores(
    t.stack([contribution_to_attn_scores_decompose_q_sub_S1, contribution_to_attn_scores_decompose_q]),
    decompose_by = "queries",
    facet_labels = ["Attn(END->IO), with baseline Attn(END->S1) subtracted", "Attn(END->IO)"],
    title = "Decompose on query-side, subtract vs don't subtract baseline",
    static = True,
)

### Try substituting the IO unembedding on query side

I no longer really endorse this plot or think it's saying much of interest, but nonetheless it's neat enough to leave here.

Idea - since the most important part of the query-side vector seems to be the unembedding of IO, let's sub in the unembedding of IO and see if that works better!

In other words, in the following decomposition:

$$
S = \underbrace{\left(\sum_{L=0}^9 \sum_{H=0}^{11} q^T W_K^T x_{IO}^{L.H}\right)}_{\text{attn heads}} + \underbrace{\left(\sum_{L=0}^9 q^T W_K^T x_{IO}^{L}\right)}_{\text{MLPs}} + \underbrace{q^T W_K^T x_{IO}^{e} + q^T W_K^T x_{IO}^{pe}}_{\text{direct path}} + \underbrace{q^T b_K}_{\text{bias term}}
$$

we replace $q$ (the query at the `END` token) with the vector $W_Q^T u_{IO} + b_Q$ (where $u_{IO}$ is the unembedding of the `IO` token).

I'll do four plots of key-side decomposition: with baseline subtracted / not subtracted, and with query $q$ replaced / not replaced.

> As expected, replacing the query $q$ with the unembedding vector makes the signals way cleaner. It's interesting that doing this cleans up the noise even if we don't subtract the baseline (although it makes the signal weaker). This suggests that the noise from not subtracting the baseline comes from all the non-$u_{IO}$ parts of the query. (This isn't surprising, but still useful to have confirmed.)

In [None]:
contribution_to_attn_scores_decompose_k = _decompose_attn_scores(
    decompose_by = "keys",
)
contribution_to_attn_scores_decompose_k_sub_S1 = _decompose_attn_scores(
    decompose_by = "keys",
    subtract_S1_attn_scores = True,
)
contribution_to_attn_scores_decompose_k_replace_q = _decompose_attn_scores(
    decompose_by = "keys",
    intervene_on_query = "sub_W_U_IO",
)
contribution_to_attn_scores_decompose_k_replace_q_sub_S1 = _decompose_attn_scores(
    decompose_by = "keys",
    intervene_on_query = "sub_W_U_IO",
    subtract_S1_attn_scores = True,
)

plot_contribution_to_attn_scores(
    t.stack([
        # Order matches facet_labels below
        contribution_to_attn_scores_decompose_k_replace_q_sub_S1,
        contribution_to_attn_scores_decompose_k_sub_S1,
        contribution_to_attn_scores_decompose_k_replace_q,
        contribution_to_attn_scores_decompose_k
    ]),
    decompose_by = "keys",
    facet_labels = [
        "S1 baseline subtracted, Q replaced with W<sub>U</sub>[IO]", "S1 baseline subtracted, Q not altered",
        "S1 baseline not subtracted, Q replaced with W<sub>U</sub>[IO]", "S1 baseline not subtracted, Q not altered",
    ],
    facet_col_wrap = 2,
    title = "Decompose on key-side, subtract vs don't subtract baseline, also sub W<sub>U</sub>[IO] for Q vs don't sub",
    static = False,
)

## Decomp with perpendicular / parallel components

To take a deeper dive / sanity check, rather than replacing the query-side vector with the unembedding vector (scaled to the same norm), I'm going to split the attention contributions into "contributions via the part projected onto the IO unembedding" and "contributions via the part perpendicular to the IO unembedding". Same for keys (projecting onto the MLP0 output).

### Decomp keys

There are 8 plots below: 4 on the first animation frame, 4 on the second (you can move the slider to go between them).

The first 4 show the 4 different terms in the expression:

$$
S = \underbrace{\left(\sum_{L=0}^9 \sum_{H=0}^{11} q^T W_K^T (x_{IO}^{L.H} - x_{S1}^{L.H})\right)}_{\text{attn heads}} + \underbrace{\left(\sum_{L=0}^9 q^T W_K^T (x_{IO}^{L} - x_{S1}^{L})\right)}_{\text{MLPs}} + \underbrace{q^T W_K^T (x_{IO}^{e} - x_{S1}^{e}) + q^T W_K^T (x_{IO}^{pe} - x_{S1}^{pe})}_{\text{direct path}}
$$

where we turn this into 4 terms by:

* writing each component $x_{IO} = x_{IO}^{∥} + x_{IO}^{⟂}$, i.e. parallel / perpendicular to the output of MLP0 at `IO` (and same for $x_{S1}$),
* writing the query $q$ as $q = (W_Q^T x_E^∥) + (W_Q^T x_E^⟂ + b_Q)$, i.e. the contribution from the residual stream components parallel / perpendicular to the IO-unembedding vector (with the bias term lumped into the perpendicular part),
* and then expanding out the product into $(1 + 1) \times (1 + 1) = 4$ different terms.

The second 4 plots show the same terms, except without subtracting the baseline (i.e. we just have the components $x_{IO}$ rather than $x_{IO} - x_{S1}$).

> Just looking at the first set of 4 plots because I think it's more representative of what we care about.
> 
> MLP0 shows up really strongly in the first plot, much stronger than any other component anywhere else. This is good for our theory.
> 
> It's slightly surprising that there's still some query-side component which increases attention to the `IO` token relative to the `S1` token. From this knowledge alone, my first guess would have been that this is a positional signal, but it interacts strongly with `MLP0` rather than `W_pos`, so apparently not! It's also interesting that this shows up in the $k ∥ MLP_0$ plot rather than $k ⟂ MLP_0$. **Possible intuitive explanation: part of the vector written to the residual stream signals "I want the later components to look back at this word, but I don't necessarily want to predict it."**

In [145]:
contribution_to_attn_scores_decompose_k_split_qk = _decompose_attn_scores(
    decompose_by = "keys",
    intervene_on_query = "project_to_W_U_IO",
    intervene_on_key = "project_to_MLP0",
    subtract_S1_attn_scores = False,
    include_S1_in_unembed_projection = False,
)

contribution_to_attn_scores_decompose_k_split_qk_sub_S1 = _decompose_attn_scores(
    decompose_by = "keys",
    intervene_on_query = "project_to_W_U_IO",
    intervene_on_key = "project_to_MLP0",
    subtract_S1_attn_scores = True,
    include_S1_in_unembed_projection = False,
)

contribution_to_attn_scores_decompose_k_split_qk_sub_S1_qboth = _decompose_attn_scores(
    decompose_by = "keys",
    intervene_on_query = "project_to_W_U_IO",
    intervene_on_key = "project_to_MLP0",
    subtract_S1_attn_scores = True,
    include_S1_in_unembed_projection = True,
)

plot_contribution_to_attn_scores(
    t.stack([
        t.stack([
            contribution_to_attn_scores_decompose_k_split_qk_sub_S1[("IO_dir", "MLP0_dir")],
            contribution_to_attn_scores_decompose_k_split_qk_sub_S1[("IO_dir", "MLP0_perp")],
            contribution_to_attn_scores_decompose_k_split_qk_sub_S1[("IO_perp", "MLP0_dir")],
            contribution_to_attn_scores_decompose_k_split_qk_sub_S1[("IO_perp", "MLP0_perp")],
        ]),
        t.stack([
            contribution_to_attn_scores_decompose_k_split_qk_sub_S1_qboth[("IO_dir", "MLP0_dir")],
            contribution_to_attn_scores_decompose_k_split_qk_sub_S1_qboth[("IO_dir", "MLP0_perp")],
            contribution_to_attn_scores_decompose_k_split_qk_sub_S1_qboth[("IO_perp", "MLP0_dir")],
            contribution_to_attn_scores_decompose_k_split_qk_sub_S1_qboth[("IO_perp", "MLP0_perp")],
        ]),
        t.stack([
            contribution_to_attn_scores_decompose_k_split_qk[("IO_dir", "MLP0_dir")],
            contribution_to_attn_scores_decompose_k_split_qk[("IO_dir", "MLP0_perp")],
            contribution_to_attn_scores_decompose_k_split_qk[("IO_perp", "MLP0_dir")],
            contribution_to_attn_scores_decompose_k_split_qk[("IO_perp", "MLP0_perp")],
        ]),
    ]),
    decompose_by = "keys",
    facet_labels = [
        "q ∥ W<sub>U</sub>[IO], k ∥ MLP<sub>0</sub>",
        "q ∥ W<sub>U</sub>[IO], k ⊥ MLP<sub>0</sub>", 
        "q ⊥ W<sub>U</sub>[IO], k ∥ MLP<sub>0</sub>", 
        "q ⊥ W<sub>U</sub>[IO], k ⊥ MLP<sub>0</sub>"
    ],
    facet_col_wrap = 2,
    animation_labels = ["Baseline subtracted", "Baseline subtracted, more Q-projection", "Baseline not subtracted"],
    title = "Decompose on key-side, split by q ∥/⟂ W<sub>U</sub>[IO] and k ∥/⟂ MLP<sub>0</sub>",
    static = False
)

In [141]:
pct_explained = (
    contribution_to_attn_scores_decompose_k_split_qk_sub_S1[('IO_dir', 'MLP0_dir')][1, -1]
    /
    t.stack(list(contribution_to_attn_scores_decompose_k_split_qk_sub_S1.values())).sum()
)
print(f"Pct of net attention diff explained by both parallel components (only MLP0 output): {pct_explained:.2%}")

pct_explained = (
    contribution_to_attn_scores_decompose_k_split_qk_sub_S1[('IO_dir', 'MLP0_dir')][:, -1].sum()
    /
    t.stack(list(contribution_to_attn_scores_decompose_k_split_qk_sub_S1.values())).sum()
)
print(f"Pct of net attention diff explained by both parallel components (all MLPs' output in MLP0 dir): {pct_explained:.2%}")

Pct of net attention diff explained by both parallel components (only MLP0 output): 23.48%
Pct of net attention diff explained by both parallel components (all MLPs' output in MLP0 dir): 41.95%
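Both percentages divide one entry (or one column) of the `("IO_dir", "MLP0_dir")` tensor by the sum over all four pairings. Below is a minimal numpy sketch of that arithmetic using random toy numbers, not results. The tensor layout is an assumption read off the indices used in this notebook: `[1, -1]` is described as "only MLP0 output", and the name-mover cell indexes head `layer.head` as `[layer+1, head]`. That suggests row 0 is the direct path, row L+1 is layer L, and the last column is the MLP.

```python
import numpy as np

# Toy stand-in for the dict returned by `_decompose_attn_scores`. ASSUMED layout,
# inferred from the notebook's indexing: row 0 = direct path, row L+1 = layer L,
# column h = head h, last column = that layer's MLP. Random values, NOT results.
n_layers, n_heads = 10, 12  # layers 0-9 write into the residual stream read by head 10.7
rng = np.random.default_rng(0)
pairings = [("IO_dir", "MLP0_dir"), ("IO_dir", "MLP0_perp"),
            ("IO_perp", "MLP0_dir"), ("IO_perp", "MLP0_perp")]
contributions = {p: rng.normal(size=(n_layers + 1, n_heads + 1)) for p in pairings}

# Denominator: every component of every pairing, i.e. the net attention-score difference
total = np.stack(list(contributions.values())).sum()
par_par = contributions[("IO_dir", "MLP0_dir")]

frac_mlp0 = par_par[1, -1] / total             # row 1 (layer 0), MLP column
frac_mlp_column = par_par[:, -1].sum() / total  # MLP column summed over every row
frac_movers = sum(par_par[layer + 1, head] for layer, head in [(9, 9), (9, 6)]) / total

print(f"{frac_mlp0:.2%}, {frac_mlp_column:.2%}, {frac_movers:.2%}")
```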


### Decomp queries

There are 12 plots below: 4 on each of 3 animation frames (move the slider to switch between them).

The first 4 show the 4 different terms in the expression:

$$
S = \underbrace{\left(\sum_{L=0}^9 \sum_{H=0}^{11} (k_{IO} - k_{S1})^T W_Q^T x_{E}^{L.H}\right)}_{\text{attn heads}} + \underbrace{\left(\sum_{L=0}^9 (k_{IO} - k_{S1})^T W_Q^T x_{E}^{L}\right)}_{\text{MLPs}} + \underbrace{(k_{IO} - k_{S1})^T W_Q^T x_{E}^{e} + (k_{IO} - k_{S1})^T W_Q^T x_{E}^{pe}}_{\text{direct path}}
$$

where we turn this into 4 terms by:

* writing each component $x_{E} = x_{E}^{∥} + x_{E}^{⟂}$, i.e. parallel / perpendicular to the IO-unembedding vector,
* writing the keys as $k_{IO} - k_{S1} = W_K^T (x_{IO}^∥ - x_{S1}^∥) + W_K^T (x_{IO}^⟂ - x_{S1}^⟂)$, i.e. the contribution from the residual stream components parallel / perpendicular to the output of MLP0 at `IO` (and at `S1` for $x_{S1}$), ignoring the bias term $b_K$ because it cancels when we subtract the S1 baseline,
* and then expanding out the product into $(1 + 1) \times (1 + 1) = 4$ different terms.

The second frame (labelled "Baseline subtracted, more Q-projection") shows the same terms computed with `include_S1_in_unembed_projection=True`, and the third frame shows them without subtracting the baseline (i.e. we just have the linear map from $k_{IO}$ rather than $k_{IO} - k_{S1}$).
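The query-side version differs in that each residual component $x_E^{c}$ gets its own parallel / perpendicular split, giving one score per component for each of the four pairings. A minimal numpy sketch with random stand-ins (all names hypothetical; the rows of `x_E_components` play the role of the per-head, per-MLP and embedding outputs) shows that summing over components and pairings recovers the full score difference. The query bias $b_Q$ belongs to no residual component and does not cancel in the `IO - S1` difference, so the sketch keeps it as a separate `bias_term`; how `_decompose_attn_scores` attributes it is not shown here.

```python
import numpy as np

rng = np.random.default_rng(1)
d_model, d_head, n_comp = 16, 4, 6  # n_comp: toy number of residual components

# Random stand-ins for the head's weights (hypothetical, not GPT-2's)
W_Q = rng.normal(size=(d_model, d_head))
W_K = rng.normal(size=(d_model, d_head))
b_Q = rng.normal(size=d_head)

x_E_components = rng.normal(size=(n_comp, d_model))  # rows sum to x_E
u_IO = rng.normal(size=d_model)                      # stand-in for W_U[:, IO]
x_IO, x_S1 = rng.normal(size=(2, d_model))           # key-side residuals
mlp0_IO, mlp0_S1 = rng.normal(size=(2, d_model))     # stand-ins for MLP0 output


def split(x, direction):
    """Split the last axis of x into parts parallel / perpendicular to `direction`."""
    coef = (x @ direction) / (direction @ direction)
    par = np.asarray(coef)[..., None] * direction
    return par, x - par


# Keys: split relative to MLP0's output at each token (b_K cancels, so omitted)
xIO_par, xIO_perp = split(x_IO, mlp0_IO)
xS1_par, xS1_perp = split(x_S1, mlp0_S1)
k_par = W_K.T @ (xIO_par - xS1_par)
k_perp = W_K.T @ (xIO_perp - xS1_perp)

# Queries: split every component's output relative to W_U[IO]
xc_par, xc_perp = split(x_E_components, u_IO)  # each (n_comp, d_model)

# One score per component for each (query, key) pairing; row c of xc @ W_Q is W_Q^T x_E^c
terms = {
    ("IO_dir", "MLP0_dir"): (xc_par @ W_Q) @ k_par,
    ("IO_dir", "MLP0_perp"): (xc_par @ W_Q) @ k_perp,
    ("IO_perp", "MLP0_dir"): (xc_perp @ W_Q) @ k_par,
    ("IO_perp", "MLP0_perp"): (xc_perp @ W_Q) @ k_perp,
}
bias_term = b_Q @ (k_par + k_perp)  # b_Q belongs to no residual component

# Check against the full baseline-subtracted score (1/sqrt(d_head) omitted)
q = W_Q.T @ x_E_components.sum(axis=0) + b_Q
S = q @ (W_K.T @ (x_IO - x_S1))
print(np.isclose(sum(v.sum() for v in terms.values()) + bias_term, S))  # True
```

Each value in `terms` is a vector over components, which is what the per-layer / per-head grids in the plots display for the corresponding pairing.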

> Just looking at the first set of 4 plots, because I think they're more representative of what we care about.
> 
> The name movers show up pretty strongly in the first plot, and nothing else shows up very strongly anywhere. This is good for our theory.
> 
> However, the name movers in the top-left plot only capture about 1/3 of the attention scores. The other 2/3 comes from the $k ⟂ MLP_0$ component. This seems consistent with our previous results, which suggested that the IO-unembedding direction isn't the only thing the name movers write. There's also some signal along the lines of **"I want the later components to look back at this word, but I don't necessarily want to predict it."**

In [144]:
contribution_to_attn_scores_decompose_q_split_qk = _decompose_attn_scores(
    decompose_by = "queries",
    intervene_on_query = "project_to_W_U_IO",
    intervene_on_key = "project_to_MLP0",
    subtract_S1_attn_scores = False,
)
contribution_to_attn_scores_decompose_q_split_qk_sub_S1 = _decompose_attn_scores(
    decompose_by = "queries",
    intervene_on_query = "project_to_W_U_IO",
    intervene_on_key = "project_to_MLP0",
    subtract_S1_attn_scores = True,
)
contribution_to_attn_scores_decompose_q_split_qk_sub_S1_qboth = _decompose_attn_scores(
    decompose_by = "queries",
    intervene_on_query = "project_to_W_U_IO",
    intervene_on_key = "project_to_MLP0",
    subtract_S1_attn_scores = True,
    include_S1_in_unembed_projection = True,
)

plot_contribution_to_attn_scores(
    t.stack([
        t.stack([
            contribution_to_attn_scores_decompose_q_split_qk_sub_S1[("IO_dir", "MLP0_dir")],
            contribution_to_attn_scores_decompose_q_split_qk_sub_S1[("IO_dir", "MLP0_perp")],
            contribution_to_attn_scores_decompose_q_split_qk_sub_S1[("IO_perp", "MLP0_dir")],
            contribution_to_attn_scores_decompose_q_split_qk_sub_S1[("IO_perp", "MLP0_perp")],
        ]),
        t.stack([
            contribution_to_attn_scores_decompose_q_split_qk_sub_S1_qboth[("IO_dir", "MLP0_dir")],
            contribution_to_attn_scores_decompose_q_split_qk_sub_S1_qboth[("IO_dir", "MLP0_perp")],
            contribution_to_attn_scores_decompose_q_split_qk_sub_S1_qboth[("IO_perp", "MLP0_dir")],
            contribution_to_attn_scores_decompose_q_split_qk_sub_S1_qboth[("IO_perp", "MLP0_perp")],
        ]),
        t.stack([
            contribution_to_attn_scores_decompose_q_split_qk[("IO_dir", "MLP0_dir")],
            contribution_to_attn_scores_decompose_q_split_qk[("IO_dir", "MLP0_perp")],
            contribution_to_attn_scores_decompose_q_split_qk[("IO_perp", "MLP0_dir")],
            contribution_to_attn_scores_decompose_q_split_qk[("IO_perp", "MLP0_perp")],
        ]),
    ]),
    decompose_by = "queries",
    facet_labels = [
        "q ∥ W<sub>U</sub>[IO], k ∥ MLP<sub>0</sub>",
        "q ∥ W<sub>U</sub>[IO], k ⊥ MLP<sub>0</sub>", 
        "q ⊥ W<sub>U</sub>[IO], k ∥ MLP<sub>0</sub>", 
        "q ⊥ W<sub>U</sub>[IO], k ⊥ MLP<sub>0</sub>"
    ],
    facet_col_wrap = 2,
    animation_labels = ["Baseline subtracted", "Baseline subtracted, more Q-projection", "Baseline not subtracted"],
    title = "Decompose on query-side, split by q ∥/⟂ W<sub>U</sub>[IO] and k ∥/⟂ MLP<sub>0</sub>",
    static = True
)

In [143]:
mover_heads = [(9, 9), (9, 6)]

pct_explained = (
    sum([contribution_to_attn_scores_decompose_q_split_qk_sub_S1[('IO_dir', 'MLP0_dir')][layer+1, head] for (layer, head) in mover_heads])
    /
    t.stack(list(contribution_to_attn_scores_decompose_q_split_qk_sub_S1.values())).sum()
)

print(f"Pct of net attention diff explained by both parallel components: {pct_explained:.2%}")

Pct of net attention diff explained by both parallel components: 29.42%


# Fucking massive plot

In [124]:
model = model.to("cpu")

contribution_to_attn_scores = decompose_attn_scores_full(
    batch_size = BATCH_SIZE,
    seed = 0,
    nnmh = NNMH_LIST[0],
    model = model,
    use_effective_embedding = False,
    use_layer0_heads = False,
    subtract_S1_attn_scores = True,
    include_S1_in_unembed_projection = True,
)

Moving model to device:  cpu


In [125]:
create_fucking_massive_plot_1(contribution_to_attn_scores)

In [15]:
# Q = 9.9, K = MLP_0, q ⟂ W_U[IO], k ∥ MLP_0
# Q = 9.9 ⟂ W_U[IO], K = MLP_0

create_fucking_massive_plot_2(contribution_to_attn_scores)

In [126]:
create_fucking_massive_plot_2(contribution_to_attn_scores)