This notebook is where I generate the histograms & heatmaps of attention score contributions in the negative name mover heads (NNMHs).

The hypothesis is that the main contributor on the query-side is the unembedding of token X, and on the key-side is the embedding of token X. In the IOI task, X is the IO token.
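
As a rough weight-space illustration of this hypothesis (not something this notebook does), you can push every token's unembedding through `W_Q`, push every token's embedding through `W_K`, and ask how often a token's unembedding scores highest against its own embedding. The helper name `diag_top1_fraction` is hypothetical. The sketch assumes TransformerLens weight shapes and ignores LayerNorm, biases, the attention-score scaling and any key-side MLP0 contribution.

```python
import numpy as np


def diag_top1_fraction(W_U, W_Q, W_K, W_E):
    """Hypothetical helper: put W_U[:, X] on the query side and every W_E[Y] on the
    key side, then return the fraction of tokens X whose top-scoring key is Y == X.

    Assumed shapes (TransformerLens convention):
      W_U: (d_model, d_vocab), W_Q and W_K: (d_model, d_head), W_E: (d_vocab, d_model)
    """
    queries = W_U.T @ W_Q       # (d_vocab, d_head): query built from each token's unembedding
    keys = W_E @ W_K            # (d_vocab, d_head): key built from each token's embedding
    scores = queries @ keys.T   # scores[x, y] = score from W_U[x] to W_E[y]
    top_key = scores.argmax(axis=1)
    return float((top_key == np.arange(scores.shape[0])).mean())


# Toy weights: 10 orthonormal token embeddings in a 16-dim residual stream,
# tied unembedding, identity query/key maps.
W_E = np.eye(16)[:10]
W_U = W_E.T
print(diag_top1_fraction(W_U, np.eye(16), np.eye(16), W_E))   # -> 1.0 (each token scores highest on itself)
print(diag_top1_fraction(W_U, -np.eye(16), np.eye(16), W_E))  # -> 0.0 (sign-flipped QK circuit)
```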

## Setup

In [None]:
from transformer_lens.cautils.notebook import *
from transformer_lens.rs.callum.keys_fixed import (
    attn_scores_as_linear_func_of_keys,
    attn_scores_as_linear_func_of_queries,
    get_attn_scores_and_probs_as_linear_func_of_keys,
    get_attn_scores_and_probs_as_linear_func_of_queries,
    decompose_attn_scores,
    plot_contribution_to_attn_scores,
    project
)
from transformer_lens.rs.callum.generate_bag_of_words_quad_plot import get_effective_embedding

# effective_embeddings = get_effective_embedding(model) 

# W_U = effective_embeddings["W_U (or W_E, no MLPs)"]
# W_EE = effective_embeddings["W_E (including MLPs)"]
# W_EE_subE = effective_embeddings["W_E (only MLPs)"]

clear_output()

In [None]:
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    # refactor_factored_attn_matrices=True,
)
model.set_use_split_qkv_input(True)
model.set_use_attn_result(True)

clear_output()

# What happens when you freeze keys / queries, and patch the other?

## Freeze keys, patch queries

In [None]:
NUM_BATCHES = 30
N = 40
NAME_TOKENS = model.to_tokens(NAMES, prepend_bos=False).squeeze().tolist()
NNMH_LIST = [(10, 7), (11, 10)]

attn_scores, attn_probs = get_attn_scores_and_probs_as_linear_func_of_queries(
    NNMH_LIST[0], 
    num_batches=NUM_BATCHES,
    batch_size=N,
    model=model,
    name_tokens=NAME_TOKENS,
)

In [None]:
labels_list, attn_scores_list = zip(*attn_scores.items())

sorted_indices = t.argsort(t.tensor([score.mean() for score in attn_scores_list]))
labels_list = [labels_list[i] for i in sorted_indices]
attn_scores_list = [attn_scores_list[i] for i in sorted_indices]

h = hist(
    attn_scores_list,
    labels={"variable": "Query-side vector", "value": "Attention scores"},
    title="Attn scores (from END -> IO) in NNMH 10.7 (keys fixed, scores are linear func of queries)",
    names=labels_list,
    width=1000,
    height=600,
    opacity=0.6,
    marginal="box",
    template="simple_white",
    nbins=120,
    return_fig=True,
)
for i in [1, 2]:
    h.data[2*i].visible = "legendonly"
    h.data[2*i+1].visible = "legendonly"
h.show(config={"staticPlot": True})

The unembedding of IO has way more influence on the final attention scores than any other query-side vector. Most importantly, **it boosts attn scores more than the actual output of NMH 9.9**.

We've already observed that the name mover heads' output basically entirely creates the queries for the neg heads like 10.7 (i.e. if you path patch from the name movers to 10.7, then attention from END to IO in 10.7 is no greater than attention from END to S1). So this suggests that **the entire reason the neg heads attend back to IO (on the query side) is that they pick up on the unembedding of IO which is stored there**.

I also want to check how large the component of `W_U[IO]` is in the output of head 9.9. Let's do that quickly.

In [None]:
ioi_dataset, ioi_cache = generate_data_and_caches(80, model=model, only_ioi=True, prepend_bos=True)

# Output of head 9.9 (before the W_O bias) at the END position
z = ioi_cache["z", 9][range(80), ioi_dataset.word_idx["end"], 9]
result = einops.einsum(z, model.W_O[9, 9], "batch d_head, d_head d_model -> batch d_model")
result_normalized = result / result.norm(dim=-1, keepdim=True)

io_dir = model.W_U.T[ioi_dataset.io_tokenIDs]

result_in_io_dir, result_in_perp_dir = project(result_normalized, io_dir)

io_component_norm = result_in_io_dir.norm(dim=-1).mean()
io_perp_component_norm = result_in_perp_dir.norm(dim=-1).mean()

print(f"IO component norm: {io_component_norm:.4f}")
print(f"IO perp component norm: {io_perp_component_norm:.4f}")

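# Pythagoras holds per example; the squares of the *mean* norms needn't sum to exactly 1
assert ((result_in_io_dir.norm(dim=-1) ** 2 + result_in_perp_dir.norm(dim=-1) ** 2 - 1).abs() < 5e-3).all()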

Yeah okay, not that large. It's an important part of its output, but not the most important part. This suggests that 9.9 is partly communicating with other heads (maybe the neg heads!) rather than just writing in the unembedding direction of IO.
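
For reference, here is a minimal numpy sketch of the parallel / perpendicular split used in the cell above. It assumes `project` behaves like the hypothetical `project_onto` below (split each row into the component along the matching direction plus a remainder); the real function in `keys_fixed` may differ. The data is random, purely to show that for unit-norm rows the squared norms of the two parts sum to 1 per example, while the squares of the *mean* norms can only be at most 1.

```python
import numpy as np


def project_onto(vectors, directions):
    """Hypothetical stand-in for `project`: split each row of `vectors` into the
    component along the matching row of `directions` and the perpendicular rest."""
    unit = directions / np.linalg.norm(directions, axis=-1, keepdims=True)
    coeff = (vectors * unit).sum(axis=-1, keepdims=True)  # scalar projection per row
    parallel = coeff * unit
    return parallel, vectors - parallel


rng = np.random.default_rng(0)
result = rng.normal(size=(80, 768))
result_normalized = result / np.linalg.norm(result, axis=-1, keepdims=True)
io_dir = rng.normal(size=(80, 768))  # random stand-in for W_U[:, IO] in each prompt

par, perp = project_onto(result_normalized, io_dir)
par_norm = np.linalg.norm(par, axis=-1)
perp_norm = np.linalg.norm(perp, axis=-1)

per_example = par_norm**2 + perp_norm**2               # equals 1 for every row
from_means = par_norm.mean() ** 2 + perp_norm.mean() ** 2  # <= 1 (Jensen)
print(np.abs(per_example - 1).max(), from_means)
```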

### Attn probs?

The plot below shows the same thing in terms of attention probs, to check that the changes in attention scores above are large enough to move the probs. Attention scores are still the more natural way to think about the function from queries -> attention, because they are linear in the queries.

Sure enough, attention probs are already close to 1 with the actual output of the NMH (and with no patching), but they get absolutely hammered all the way up to 1 when we use the unembedding of IO instead.
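
A toy illustration of why the probs saturate even though the scores move linearly: if a query-side vector adds `c` to the END -> IO score, the score grows linearly in `c` but the softmax probability on IO flattens out near 1. The four baseline score values below are made up purely for illustration.

```python
import numpy as np


def softmax(x):
    z = np.exp(x - x.max())  # subtract the max for numerical stability
    return z / z.sum()


# Made-up attention scores from END to [BOS, IO, S1, other token]
base_scores = np.array([2.0, 1.0, 0.5, 0.0])
io_index = 1

coeffs = [0.0, 2.0, 4.0, 8.0]
io_scores, io_probs = [], []
for c in coeffs:
    scores = base_scores.copy()
    scores[io_index] += c  # linear effect on the IO score
    io_scores.append(scores[io_index])
    io_probs.append(softmax(scores)[io_index])

for c, s, p in zip(coeffs, io_scores, io_probs):
    print(f"c={c:>4}: IO score {s:5.1f}, IO prob {p:.3f}")
```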

In [None]:
labels_list, attn_probs_list = zip(*attn_probs.items())

sorted_indices = t.argsort(t.tensor([score.mean() for score in attn_probs_list]))
labels_list = [labels_list[i] for i in sorted_indices]
attn_probs_list = [attn_probs_list[i] for i in sorted_indices]

h = hist(
    attn_probs_list,
    labels={"variable": "Query-side vector", "value": "Attention probs"},
    title="Attn probs (from END -> IO) in NNMH 10.7 (keys fixed, scores are linear func of queries)",
    names=labels_list,
    width=1000,
    height=600,
    opacity=0.6,
    marginal="box",
    template="simple_white",
    nbins=100,
    return_fig=True,
)
for i in [1, 2]:
    h.data[2*i].visible = "legendonly"
    h.data[2*i+1].visible = "legendonly"
h.show(config={"staticPlot": True})

## Freeze queries, patch keys

Now that we've got satisfying results when the keys are fixed (and attn is a linear func of queries), what about the other way around? We want to show the equivalent result, i.e. that the main thing determining attn on the key-side is the IO embedding. If not, then what the hell does the unembedding of IO match with?

In [43]:
attn_scores, attn_probs = get_attn_scores_and_probs_as_linear_func_of_keys(
    NNMH_LIST[0], 
    num_batches=NUM_BATCHES,
    batch_size=N,
    model=model,
)

In [44]:
labels_list, attn_scores_list = zip(*attn_scores.items())

sorted_indices = t.argsort(t.tensor([score.mean() for score in attn_scores_list]))
labels_list = [labels_list[i] for i in sorted_indices]
attn_scores_list = [attn_scores_list[i] for i in sorted_indices]

hist(
    attn_scores_list,
    labels={"variable": "Key-side vector", "value": "Attention scores"},
    title="Attn scores (from END -> IO) in NNMH 10.7 (queries fixed, scores are linear func of keys)",
    names=labels_list,
    width=1000,
    height=600,
    opacity=0.6,
    marginal="box",
    template="simple_white",
    nbins=120,
    static=True,
)

In [47]:
attn_scores, attn_probs = get_attn_scores_and_probs_as_linear_func_of_keys(
    NNMH_LIST[0], 
    num_batches=NUM_BATCHES,
    batch_size=N,
    model=model,
    subtract_S1_attn_scores=True,
)

labels_list, attn_scores_list = zip(*attn_scores.items())

sorted_indices = t.argsort(t.tensor([score.mean() for score in attn_scores_list]))
labels_list = [labels_list[i] for i in sorted_indices]
attn_scores_list = [attn_scores_list[i] for i in sorted_indices]

hist(
    attn_scores_list,
    labels={"variable": "Key-side vector", "value": "Attention scores"},
    title="Attn scores (from END -> IO, minus END -> S1) in NNMH 10.7 (queries fixed, scores are linear func of keys)",
    names=labels_list,
    width=1000,
    height=600,
    opacity=0.6,
    marginal="box",
    template="simple_white",
    nbins=120,
    static=True,
)

In [None]:
labels_list, attn_probs_list = zip(*attn_probs.items())

sorted_indices = t.argsort(t.tensor([score.mean() for score in attn_probs_list]))
labels_list = [labels_list[i] for i in sorted_indices]
attn_probs_list = [attn_probs_list[i] for i in sorted_indices]

hist(
    attn_probs_list,
    labels={"variable": "Key-side vector", "value": "Attention probs"},
    title="Attn probs (from END -> IO) in NNMH 10.7 (queries fixed, scores are linear func of keys)",
    names=labels_list,
    width=1000,
    height=600,
    opacity=0.6,
    marginal="box",
    template="simple_white",
    nbins=100,
    static=True,
)

## Analysis - wtf?

Okay, this seems really strange. I'd have expected replacing the key-side vector with the embedding to increase attention scores (since the hypothesis is that the unembedding of a token attends back to the embedding of the same token). But in fact, it decreases them (or at least adds noise).

If the embedding isn't the main contributor, then what is the main contributor? Is it super distributed, or is one head / MLP mainly responsible?

# Which components contribute on the key and query sides?

Here I'm going to break down the residual stream at the key position into its components (the output of every head and every MLP) to see which one contributes most to the attention scores.

It'll be a heatmap over all the attention heads (plus the MLPs). Each value will be that component's contribution to the attention score (with the appropriate LN scale applied).

## Formalism

Here's an equation for a single scalar attention score $S$:

$$
S = (x_Q^T W_Q + b_Q^T)(x_K^T W_K + b_K^T)^T
$$

where $x_Q$ is the query-side vector (of shape `d_model`), and $x_K$ is the key-side vector.
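
A quick numerical check, with random weights, that $S = q^T k$ with $q^T = x_Q^T W_Q + b_Q^T$ and $k^T = x_K^T W_K + b_K^T$ agrees with both linear expansions used below. The dimensions are arbitrary assumptions, and like the formula the sketch omits the scaling factor TransformerLens applies to attention scores.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_head = 16, 4
W_Q, W_K = rng.normal(size=(d_model, d_head)), rng.normal(size=(d_model, d_head))
b_Q, b_K = rng.normal(size=d_head), rng.normal(size=d_head)
x_Q, x_K = rng.normal(size=d_model), rng.normal(size=d_model)

q = x_Q @ W_Q + b_Q
k = x_K @ W_K + b_K
S = q @ k

S_keys = x_K @ W_K @ q + b_K @ q     # queries fixed: S is affine in x_K
S_queries = x_Q @ W_Q @ k + b_Q @ k  # keys fixed: S is affine in x_Q
print(S, S_keys, S_queries)
```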

### Decomposing on key-side

If we fix queries $q^T = x_Q^T W_Q + b_Q^T$, we can expand this into:

$$
S = x_K^T W_K q + b_K^T q
$$

and if we want to figure out which key-side components affect the attention most, then we can decompose $x_K$ as a sum of terms.
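
A minimal sketch of this key-side decomposition. It assumes (as with `fold_ln=True` and `center_writing_weights=True`) that every component is already mean-centred, so LayerNorm at the key position reduces to dividing the whole residual stream by one scalar; dividing each component by that same scale is what makes the per-component contributions add up exactly. The function name and the random components are hypothetical.

```python
import numpy as np


def key_side_contributions(components, q, W_K, b_K, eps=1e-5):
    """Split S = x_K^T W_K q + b_K^T q into one term per residual-stream component.

    components: (n_components, d_model), mean-centred writes to the key position
    (e.g. embeddings, each head's output, each MLP's output).
    Returns (per-component contributions, bias contribution).
    """
    resid = components.sum(axis=0)
    ln_scale = np.sqrt((resid**2).mean() + eps)  # shared LayerNorm scale (LN weights folded in)
    contributions = (components / ln_scale) @ W_K @ q
    return contributions, b_K @ q


rng = np.random.default_rng(0)
d_model, d_head, n_components = 16, 4, 6
components = rng.normal(size=(n_components, d_model))
components -= components.mean(axis=-1, keepdims=True)  # mimic centred writing weights
W_K, b_K, q = rng.normal(size=(d_model, d_head)), rng.normal(size=d_head), rng.normal(size=d_head)

contributions, bias_term = key_side_contributions(components, q, W_K, b_K)
resid = components.sum(axis=0)
x_K = resid / np.sqrt((resid**2).mean() + 1e-5)  # full normalised key-side vector
S = x_K @ W_K @ q + b_K @ q
print(contributions.round(2), bias_term, S)
```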

Rather than calculating the attn score contribution from `END -> IO`, we might want to subtract the contribution from `END -> S1`. This effectively removes noise from our analysis (since we're concerned with how the name movers attend to `IO` rather than `S1`). In other words, we have $x_K^{IO} - x_K^{S1}$ in place of $x_K$.

Below I'll also generate some plots which don't subtract the baseline; by default, the first facet / animation frame in each figure is the one with the baseline subtracted.
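
With the query fixed, subtracting the S1 baseline on the key side just means taking per-component differences, and the bias term $b_K^T q$ is the same at both key positions, so it cancels. The sketch below (hypothetical function name, random data, components assumed already divided by the LayerNorm scale at their own position) checks that the differences sum to the score gap whatever $b_K$ is.

```python
import numpy as np


def key_side_contributions_minus_s1(components_io, components_s1, q, W_K):
    """Per-component contribution to S(END->IO) - S(END->S1) with the query fixed.
    The b_K^T q term is identical at both key positions and cancels."""
    return (components_io - components_s1) @ W_K @ q


rng = np.random.default_rng(1)
d_model, d_head, n = 16, 4, 5
comp_io, comp_s1 = rng.normal(size=(n, d_model)), rng.normal(size=(n, d_model))
W_K, q = rng.normal(size=(d_model, d_head)), rng.normal(size=d_head)


def score(components, b_K):
    return components.sum(axis=0) @ W_K @ q + b_K @ q


diffs = key_side_contributions_minus_s1(comp_io, comp_s1, q, W_K)
b_K_zero, b_K_rand = np.zeros(d_head), rng.normal(size=d_head)
gap_zero = score(comp_io, b_K_zero) - score(comp_s1, b_K_zero)
gap_rand = score(comp_io, b_K_rand) - score(comp_s1, b_K_rand)
print(diffs.sum(), gap_zero, gap_rand)  # agree up to float rounding: b_K drops out
```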

### Decomposing on query-side

If we fix keys $k^T = x_K^T W_K + b_K^T$, we can expand this into:

$$
S = x_Q^T W_Q k + b_Q^T k
$$

If we want to subtract the baseline attention to `S1`, then rather than changing our components $x_Q$, we change our linear map:

$$
S = x_Q^T W_Q (k^{IO} - k^{S1}) + b_Q^T (k^{IO} - k^{S1})
$$
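
A sketch of the query-side version: the query-side components stay the same and only the linear map changes, from $W_Q k$ to $W_Q (k^{IO} - k^{S1})$. The function name and the random data are hypothetical; the components are assumed already scaled by LayerNorm.

```python
import numpy as np


def query_side_contributions_minus_s1(query_components, k_io, k_s1, W_Q, b_Q):
    """Per-component contribution to S(END->IO) - S(END->S1) with the keys fixed."""
    k_diff = k_io - k_s1
    return query_components @ W_Q @ k_diff, b_Q @ k_diff


rng = np.random.default_rng(2)
d_model, d_head, n = 16, 4, 5
query_components = rng.normal(size=(n, d_model))
W_Q, b_Q = rng.normal(size=(d_model, d_head)), rng.normal(size=d_head)
k_io, k_s1 = rng.normal(size=d_head), rng.normal(size=d_head)

contribs, bias_term = query_side_contributions_minus_s1(query_components, k_io, k_s1, W_Q, b_Q)
q = query_components.sum(axis=0) @ W_Q + b_Q
print(contribs.sum() + bias_term, q @ k_io - q @ k_s1)
```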

### Decomposing into parallel and perpendicular components

On top of this, we can further decompose the attention scores to just get the bits we care about. For instance, if we think that the most important query-side component in the residual stream is the part in the direction of the unembedding $W_U[IO]$, then when decomposing by keys we can do this:

$$
S = (x_K^T W_K + b_K^T)q_{∥} + (x_K^T W_K + b_K^T)q_{⟂}
$$

where $q_{∥}$ is the component of $q$ coming from the portion of the residual stream (pre-$W_Q$) in the direction of $W_U[IO]$, and $q_{⟂}$ is the component of $q$ perpendicular to $W_U[IO]$. In this way, we generate two separate attribution heatmaps. If the first of these stands out more (and MLP0 is highlighted), while the second just looks like random noise and is generally small in magnitude, this is evidence that the model is doing exactly the form of prediction-attention which we theorised.
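
A minimal sketch of the two key-side heatmaps: each key-side component's contribution is split into the part via $q_{∥}$ and the part via $q_{⟂}$, giving one row per heatmap. It assumes the projection happens on the residual stream before $W_Q$, and puts $b_Q$ in the ⟂ part; the function name and random data are hypothetical, with `u` standing in for $W_U[IO]$.

```python
import numpy as np


def split_by_query_direction(key_components, b_K, x_Q, W_Q, b_Q, W_K, u):
    """Key-side decomposition split into terms via q_par and via q_perp.
    Returns a (2, n_components) array and the two b_K terms."""
    u_hat = u / np.linalg.norm(u)
    x_Q_par = (x_Q @ u_hat) * u_hat
    q_par = x_Q_par @ W_Q                 # query built from the u-direction only
    q_perp = (x_Q - x_Q_par) @ W_Q + b_Q  # everything else, including the bias
    keys = key_components @ W_K            # (n_components, d_head)
    contributions = np.stack([keys @ q_par, keys @ q_perp])
    bias_terms = np.array([b_K @ q_par, b_K @ q_perp])
    return contributions, bias_terms


rng = np.random.default_rng(3)
d_model, d_head, n = 16, 4, 5
key_components = rng.normal(size=(n, d_model))
W_Q, W_K = rng.normal(size=(d_model, d_head)), rng.normal(size=(d_model, d_head))
b_Q, b_K = rng.normal(size=d_head), rng.normal(size=d_head)
x_Q, u = rng.normal(size=d_model), rng.normal(size=d_model)

contributions, bias_terms = split_by_query_direction(key_components, b_K, x_Q, W_Q, b_Q, W_K, u)
S = (x_Q @ W_Q + b_Q) @ (key_components.sum(axis=0) @ W_K + b_K)
print(contributions.sum() + bias_terms.sum(), S)
```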

We can also decompose the key-side components: $x_K = (x_K)_{∥} + (x_K)_{⟂}$, where $(x_K)_{∥}$ is the component of $x_K$ in the direction of the MLP0 output at the IO sequence position, and $(x_K)_{⟂}$ is the component of $x_K$ perpendicular to this. Combining both splits gives four heatmaps, one for each bilinear term in the full expansion.

We can do the same thing if we're decomposing by queries too.

Note on the bias term - by default, I always include the bias component in the ⟂ rather than the ∥ terms, because the bias is a constant, and the reason for separating out ∥ is to try and identify the terms which actively contribute to the answer in a particular direction, not just things which are always the same!
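
A sketch of the full four-way split for a single attention score, following the convention above that both biases go into the ⟂ terms. `u` stands in for $W_U[IO]$ and `m` for the MLP0 output direction; both, the function name and the random data are hypothetical. In the idealised case (query exactly along `u`, key exactly along `m`, no biases) only the ∥/∥ term is non-zero.

```python
import numpy as np


def four_way_split(x_Q, x_K, W_Q, b_Q, W_K, b_K, u, m):
    """terms[i, j] = (query part i) . (key part j), parts ordered (par, perp)."""
    def split(x, direction):
        d_hat = direction / np.linalg.norm(direction)
        par = (x @ d_hat) * d_hat
        return par, x - par

    xq_par, xq_perp = split(x_Q, u)
    xk_par, xk_perp = split(x_K, m)
    queries = [xq_par @ W_Q, xq_perp @ W_Q + b_Q]  # bias goes with perp
    keys = [xk_par @ W_K, xk_perp @ W_K + b_K]     # bias goes with perp
    return np.array([[q @ k for k in keys] for q in queries])


rng = np.random.default_rng(5)
d_model, d_head = 16, 4
W_Q, W_K = rng.normal(size=(d_model, d_head)), rng.normal(size=(d_model, d_head))
b_Q, b_K = rng.normal(size=d_head), rng.normal(size=d_head)
x_Q, x_K, u, m = (rng.normal(size=d_model) for _ in range(4))

terms = four_way_split(x_Q, x_K, W_Q, b_Q, W_K, b_K, u, m)
S = (x_Q @ W_Q + b_Q) @ (x_K @ W_K + b_K)

ideal = four_way_split(3 * u, 2 * m, W_Q, np.zeros(d_head), W_K, np.zeros(d_head), u, m)
print(terms.sum(), S)
print(ideal.round(6))
```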

In [48]:
# define a new decompose_attn_scores function which cuts down on all the code! (I don't use partial cause I don't like when the function shows up as blue)

BATCH_SIZE = 50

def _decompose_attn_scores(
    decompose_by: Literal["keys", "queries"],
    show_plot: bool = False,
    intervene_on_query: Literal["sub_W_U_IO", "project_to_W_U_IO", None] = None,
    intervene_on_key: Literal["sub_MLP0", "project_to_MLP0", None] = None,
    use_effective_embedding: bool = False,
    subtract_S1_attn_scores: bool = False,
    static: bool = False,
):
    return decompose_attn_scores(
        batch_size=BATCH_SIZE,
        seed=42,
        nnmh=(10, 7),
        model=model,
        decompose_by=decompose_by,
        show_plot=show_plot,
        intervene_on_query=intervene_on_query,
        intervene_on_key=intervene_on_key,
        use_effective_embedding=use_effective_embedding,
        subtract_S1_attn_scores=subtract_S1_attn_scores,
        static=static,
    )

## Decompose with baseline subtracted / not subtracted

First, before any parallel/perp decomposition, I want to compare what the baseline being subtracted / not subtracted does.

Prior is that baseline being subtracted leads to less noisy plots where we can observe our prediction-attention patterns more crisply.


#### Conclusion

Theory looks pretty good! It's very crisp on MLP0 when decomposed key-side, and very crisp on the name movers when decomposed on query-side.

In [69]:
contribution_to_attn_scores_decompose_k = _decompose_attn_scores(
    decompose_by = "keys",
)
contribution_to_attn_scores_decompose_k_sub_S1 = _decompose_attn_scores(
    decompose_by = "keys",
    subtract_S1_attn_scores = True,
)

plot_contribution_to_attn_scores(
    t.stack([contribution_to_attn_scores_decompose_k_sub_S1, contribution_to_attn_scores_decompose_k]),
    decompose_by = "keys",
    facet_labels = ["Attn(END->IO), with baseline Attn(END->S1) subtracted", "Attn(END->IO)"],
    title = "Decompose on key-side, subtract vs don't subtract baseline",
    static = True,
)

In [70]:
contribution_to_attn_scores_decompose_q = _decompose_attn_scores(
    decompose_by = "queries",
)
contribution_to_attn_scores_decompose_q_sub_S1 = _decompose_attn_scores(
    decompose_by = "queries",
    subtract_S1_attn_scores = True,
)

plot_contribution_to_attn_scores(
    t.stack([contribution_to_attn_scores_decompose_q_sub_S1, contribution_to_attn_scores_decompose_q]),
    decompose_by = "queries",
    facet_labels = ["Attn(END->IO), with baseline Attn(END->S1) subtracted", "Attn(END->IO)"],
    title = "Decompose on query-side, subtract vs don't subtract baseline",
    static = True,
)

### Try substituting the IO unembedding on query side

I no longer really like this plot or think it's saying much of interest, but nonetheless it's kinda neat.

Idea - since the most important part of the query-side vector seems to be the unembedding of IO, let's sub in the unembedding of IO and see if that works better!

I'll do four plots of key-side decomposition: with the S1 baseline subtracted / not subtracted, crossed with the query-side vector replaced by `W_U[IO]` / left as it actually is.
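
One reading of "sub in the unembedding of IO" is to swap each query-side residual vector for the `W_U[IO]` direction, rescaled to the original vector's norm. The sketch below assumes exactly that; the function name is hypothetical, and the real `sub_W_U_IO` option may differ in details such as where LayerNorm is applied.

```python
import numpy as np


def substitute_direction(x, direction):
    """Replace each row of x with `direction`, rescaled to that row's norm."""
    d_hat = direction / np.linalg.norm(direction, axis=-1, keepdims=True)
    return np.linalg.norm(x, axis=-1, keepdims=True) * d_hat


rng = np.random.default_rng(4)
x_Q = rng.normal(size=(3, 16))     # query-side residual vectors for 3 prompts
w_u_io = rng.normal(size=(3, 16))  # random stand-in for W_U[:, IO] in each prompt
x_Q_sub = substitute_direction(x_Q, w_u_io)

cos = (x_Q_sub * w_u_io).sum(-1) / (np.linalg.norm(x_Q_sub, axis=-1) * np.linalg.norm(w_u_io, axis=-1))
print(np.linalg.norm(x_Q_sub, axis=-1) - np.linalg.norm(x_Q, axis=-1), cos)
```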


#### Conclusion

Yep, great clearing up!

Also, note that either subtracting the baseline or injecting the unembedding on the query side is sufficient on its own to clean the plot up; both effects are super strong!

This updates me towards:

> *Head `10.7` is doing some idealised version of copying-suppression. It's just that in real life this tends to be messier, e.g. if there are other things than just the unembedding of the token to be suppressed stored in the query position.*

In [71]:
contribution_to_attn_scores_decompose_k = _decompose_attn_scores(
    decompose_by = "keys",
)
contribution_to_attn_scores_decompose_k_sub_S1 = _decompose_attn_scores(
    decompose_by = "keys",
    subtract_S1_attn_scores = True,
)
contribution_to_attn_scores_decompose_k_replace_q = _decompose_attn_scores(
    decompose_by = "keys",
    intervene_on_query = "sub_W_U_IO",
)
contribution_to_attn_scores_decompose_k_replace_q_sub_S1 = _decompose_attn_scores(
    decompose_by = "keys",
    intervene_on_query = "sub_W_U_IO",
    subtract_S1_attn_scores = True,
)

plot_contribution_to_attn_scores(
    t.stack([
        contribution_to_attn_scores_decompose_k_replace_q_sub_S1, contribution_to_attn_scores_decompose_k_replace_q,
        contribution_to_attn_scores_decompose_k_sub_S1, contribution_to_attn_scores_decompose_k
    ]),
    decompose_by = "keys",
    facet_labels = [
        "S1 baseline subtracted, Q replaced with W<sub>U</sub>[IO]", "S1 baseline subtracted, Q not altered",
        "S1 baseline not subtracted, Q replaced with W<sub>U</sub>[IO]", "S1 baseline not subtracted, Q not altered",
    ],
    facet_col_wrap = 2,
    title = "Decompose on key-side, subtract vs don't subtract baseline, also sub W<sub>U</sub>[IO] for Q vs don't sub",
    static = True,
)

## Decomp with perpendicular / parallel components

To take a deeper dive / sanity check, rather than replacing the query-side vector with the IO unembedding (rescaled to the same norm), I'm going to split the attention contributions into "contributions via the part of the query parallel to the IO unembedding" and "contributions via the part perpendicular to it". That way, it'll give me a more reasonable picture of how much MLP0 really helps on the key-side, given what the query *actually* is.

I do the same thing for keys: splitting them into the projection onto the MLP0 direction, and the part perpendicular to it.

### Decomp keys

In [75]:
contribution_to_attn_scores_decompose_k_split_qk = _decompose_attn_scores(
    decompose_by = "keys",
    intervene_on_query = "project_to_W_U_IO",
    intervene_on_key = "project_to_MLP0",
    subtract_S1_attn_scores = False,
).flatten(0, 1)
contribution_to_attn_scores_decompose_k_split_qk_sub_S1 = _decompose_attn_scores(
    decompose_by = "keys",
    intervene_on_query = "project_to_W_U_IO",
    intervene_on_key = "project_to_MLP0",
    subtract_S1_attn_scores = True,
).flatten(0, 1)

plot_contribution_to_attn_scores(
    t.stack([contribution_to_attn_scores_decompose_k_split_qk_sub_S1, contribution_to_attn_scores_decompose_k_split_qk]),
    decompose_by = "keys",
    facet_labels = [
        "q ∥ W<sub>U</sub>[IO], k ∥ MLP<sub>0</sub>",
        "q ∥ W<sub>U</sub>[IO], k ⊥ MLP<sub>0</sub>", 
        "q ⊥ W<sub>U</sub>[IO], k ∥ MLP<sub>0</sub>", 
        "q ⊥ W<sub>U</sub>[IO], k ⊥ MLP<sub>0</sub>"
    ],
    facet_col_wrap = 2,
    animation_labels = ["Attn(END->IO), with baseline Attn(END->S1) subtracted", "Attn(END->IO)"],
    title = "Decompose on key-side, split by q ∥/⟂ W<sub>U</sub>[IO] and k ∥/⟂ MLP<sub>0</sub>",
    static = True
)

### Decomp queries

In [81]:
contribution_to_attn_scores_decompose_q_split_qk = _decompose_attn_scores(
    decompose_by = "queries",
    intervene_on_query = "project_to_W_U_IO",
    intervene_on_key = "project_to_MLP0",
    subtract_S1_attn_scores = False,
).flatten(0, 1)
contribution_to_attn_scores_decompose_q_split_qk_sub_S1 = _decompose_attn_scores(
    decompose_by = "queries",
    intervene_on_query = "project_to_W_U_IO",
    intervene_on_key = "project_to_MLP0",
    subtract_S1_attn_scores = True,
).flatten(0, 1)

plot_contribution_to_attn_scores(
    t.stack([contribution_to_attn_scores_decompose_q_split_qk_sub_S1, contribution_to_attn_scores_decompose_q_split_qk]),
    decompose_by = "queries",
    facet_labels = [
        "q ∥ W<sub>U</sub>[IO], k ∥ MLP<sub>0</sub>",
        "q ∥ W<sub>U</sub>[IO], k ⊥ MLP<sub>0</sub>", 
        "q ⊥ W<sub>U</sub>[IO], k ∥ MLP<sub>0</sub>", 
        "q ⊥ W<sub>U</sub>[IO], k ⊥ MLP<sub>0</sub>"
    ],
    facet_col_wrap = 2,
    animation_labels = ["Attn(END->IO), with baseline Attn(END->S1) subtracted", "Attn(END->IO)"],
    title = "Decompose on query-side, split by q ∥/⟂ W<sub>U</sub>[IO] and k ∥/⟂ MLP<sub>0</sub>",
    static = False
)

# TODO - try the effective embeddings. Maybe the top-right plot could be improved if we captured the effective embedding in a better way?

In [None]:
pct_explained = (contribution_to_attn_scores_decompose_k_split_qk_sub_S1[0, 1, -1] / contribution_to_attn_scores_decompose_k_split_qk_sub_S1.sum()).item()

print(f"Pct of net attention diff explained by both parallel components: {pct_explained:.2%}")

This makes sense - we've been unable to scale the MLP0 component, because it also fucks up when dot producted with the `W_U_IO_perpdir` direction!

This updates me towards "the name movers are doing a clear, well-defined algorithm of copy-suppression, there's just some messy shit on the top". This is good!

In [None]:
contribution_to_attn_scores_split_qk_decompose_q = _decompose_attn_scores(
    show_plot = False,
    decompose_by = "queries",
    intervene_on_query = "project_to_W_U_IO",
    intervene_on_key = "project_to_MLP0",
).flatten(0, 1)
contribution_to_attn_scores_split_qk_decompose_q_sub_S1 = _decompose_attn_scores(
    show_plot = False,
    decompose_by = "queries",
    intervene_on_query = "project_to_W_U_IO",
    intervene_on_key = "project_to_MLP0",
    subtract_S1_attn_scores = True,
).flatten(0, 1)

plot_contribution_to_attn_scores(
    t.stack([contribution_to_attn_scores_split_qk_decompose_q_sub_S1, contribution_to_attn_scores_split_qk_decompose_q]),
    decompose_by = "queries",
    facet_labels = [
        "q ∥ W<sub>U</sub>[IO], k ∥ MLP<sub>0</sub>",
        "q ∥ W<sub>U</sub>[IO], k ⊥ MLP<sub>0</sub>", 
        "q ⊥ W<sub>U</sub>[IO], k ∥ MLP<sub>0</sub>", 
        "q ⊥ W<sub>U</sub>[IO], k ⊥ MLP<sub>0</sub>"
    ],
    facet_col_wrap = 2,
    animation_labels = ["Baseline END -> S1 subtracted", "Attn scores from END -> IO"],
    title = "Contribution to attention scores, decomposed by query-side component, split by projections on key and query-side",
    static = True
)

There are 2 questions I now want answered:

1. What happens if I break down the query-side components, and see which one of the `W_U_IO_perpdir` query-side components results in an annoyingly negative dot product with MLP0?

2. What happens if, rather than projecting key-side bits onto the actual MLP0 output, I project them onto the effective embedding (which should be a good approximation for the MLP0 output)?

### 1. Break down query-side

In [None]:
contribution_to_attn_scores_query_decomp_full4 = decompose_attn_scores(
    batch_size = BATCH_SIZE,
    seed = 42,
    nnmh = NNMH_LIST[0],
    model = model,
    show_plot = False,
    decompose_by = "queries",
    intervene_on_query = "project_to_W_U_IO",
    intervene_on_key = "project_to_MLP0",
)
plot_contribution_to_attn_scores(
    einops.rearrange(contribution_to_attn_scores_query_decomp_full4, "k_facet q_facet layer component -> (q_facet k_facet) layer component"),
    decompose_by = "queries",
    facet_labels = [
        "q ∥ W<sub>U</sub>[IO], k ∥ MLP<sub>0</sub>",
        "q ∥ W<sub>U</sub>[IO], k ⊥ MLP<sub>0</sub>", 
        "q ⊥ W<sub>U</sub>[IO], k ∥ MLP<sub>0</sub>", 
        "q ⊥ W<sub>U</sub>[IO], k ⊥ MLP<sub>0</sub>"
    ],
    title = "Contribution to attention scores, decomposed by query-side component, split by projections on key and query-side",
    static = True,
)

Let's break down these facets one by one.

* **Top-left**
    * This is good, and agrees with our theories.
    * The name movers are the main ones which write to the query side in the `W_U_IO_dir`, and positively interact with the `MLP0_dir` while doing so.
* **Top-right**
    * This is good, and agrees with our theories.
    * We know that the `W_U_IO_dir` is a big part of the name movers' output, and this tells us that they pretty much only affect the attention scores via the **prediction-attention mechanism** which we've theorised (i.e. they have a high dot product with the `MLP0_dir`, but very small with `MLP0_perpdir`).
* **Bottom-right**
    * Not super meaningful I think. It's nice that nothing here stands out as being particularly massive - I think this is just noise.
* **Bottom-left**
    * This is where things get spicy!
    * Firstly, let's not over-update on the query bias. This term might be pretty similar for all attention scores (because the MLP0 outputs will all be somewhat similar). Would be worth testing this: rather than just getting attention scores, we could get attention score diffs (i.e. relative to the other attention scores from `END` to other tokens in the sequence).
    * Head 9.9 does a lot of work in this facet plot. This is pretty weird and surprising - why does so much of the name mover head's output work **against** prediction-attention? My guess - there are some offsetting forces here. There are some situations where prediction-attention is actively bad, and the model needs a way of overriding the default tendency to attend back.


**Overall conclusion**

This is a reasonable thing to happen because the residual stream on the key-side is so fucking noisy!

### 2. Effective embedding

When we do this, things just get a bit noisier and messier. This is to be expected, since the effective embedding is only an approximation to the MLP0 output. The main reason I think this is that head 9.9 gets drowned out in the top-left plot.

In [None]:
contribution_to_attn_scores_query_decomp_full4 = decompose_attn_scores(
    batch_size = BATCH_SIZE,
    seed = 42,
    nnmh = NNMH_LIST[0],
    model = model,
    show_plot = False,
    decompose_by = "queries",
    intervene_on_query = "project_to_W_U_IO",
    intervene_on_key = "project_to_MLP0",
    use_effective_embedding = True,
)
plot_contribution_to_attn_scores(
    einops.rearrange(contribution_to_attn_scores_query_decomp_full4, "k_facet q_facet layer component -> (q_facet k_facet) layer component"),
    decompose_by = "queries",
    facet_labels = [
        "q ∥ W<sub>U</sub>[IO], k ∥ MLP<sub>0</sub>",
        "q ∥ W<sub>U</sub>[IO], k ⊥ MLP<sub>0</sub>", 
        "q ⊥ W<sub>U</sub>[IO], k ∥ MLP<sub>0</sub>", 
        "q ⊥ W<sub>U</sub>[IO], k ⊥ MLP<sub>0</sub>"
    ],
    title = "Contribution to attention scores, decomposed by query-side component, split by projections on key and query-side",
    static = True,
)

# TODO - get a better effective embedding?

## Normalize

Gonna see if subtracting other attention scores (either the means, or the `"END -> S1"` ones) cleans up some of the noise (particularly for the perp plots).
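
A sketch of the two normalisations, applied per component. Here `contributions[c, pos]` is assumed to be component `c`'s contribution to the score from END to key position `pos`; "mean" is assumed to mean the average over all key positions other than IO. Both baselines are linear, so the normalised contributions still add up to the normalised total score. The function name and the data are hypothetical.

```python
import numpy as np


def normalized_contributions(contributions, io_pos, s1_pos, mode):
    """contributions: (n_components, seq_len) contributions to the END -> pos scores.

    mode="S1":   contribution to END->IO minus contribution to END->S1
    mode="mean": contribution to END->IO minus the mean over all other key positions
    """
    io = contributions[:, io_pos]
    if mode == "S1":
        return io - contributions[:, s1_pos]
    if mode == "mean":
        others = np.delete(contributions, io_pos, axis=1)
        return io - others.mean(axis=1)
    raise ValueError(f"unknown mode {mode!r}")


rng = np.random.default_rng(6)
contributions = rng.normal(size=(4, 6))  # 4 components, 6 key positions
total = contributions.sum(axis=0)        # total score from END to each position

by_s1 = normalized_contributions(contributions, io_pos=2, s1_pos=4, mode="S1")
by_mean = normalized_contributions(contributions, io_pos=2, s1_pos=4, mode="mean")
print(by_s1.sum(), total[2] - total[4])
print(by_mean.sum(), total[2] - np.delete(total, 2).mean())
```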

In [None]:
# attn_scores_as_linear_func_of_queries(
#     batch_idx: Optional[Union[int, List[int], Int[Tensor, "batch"]]],
#     head: Tuple[int, int],
#     model: HookedTransformer,
#     ioi_cache: ActivationCache,
#     ioi_dataset: IOIDataset,
# )

'''
subtract_S1_attn_scores:
    If True, we subtract the attention score from "END" to "S1".
    This seems like it might help clear up some annoying noise we see in the plots, and make the core pattern a bit cleaner.

    To be clear: 
        if decompose_by == "keys", then for each keyside component, we want to see if (END -> component_IO) is higher than (END -> component_S1)
        if decompose_by == "queries", then for each queryside component, we want to see if (component_END -> IO) is higher than (component_END -> S1)
'''