# Keys fixed

This notebook is where I generate the histograms & heatmaps of attention score contributions in the negative name mover heads (NNMHs).

The hypothesis is that the main contributor on the query-side is the unembedding of token X, and on the key-side is the embedding of token X. In the IOI task, X is the IO token.
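
For reference, here is a sketch of why it makes sense to fix one side and vary the other. It assumes standard scaled dot-product attention with LayerNorm folded into the weights; the symbols are illustrative notation, not names from the codebase. Writing $x_q$ and $x_k$ for the layer-normed residual stream vectors at the query and key positions, the attention score is

$$
s(x_q, x_k) = \frac{(x_q W_Q + b_Q) \cdot (x_k W_K + b_K)}{\sqrt{d_{\text{head}}}}
$$

With the key vector $k = x_k W_K + b_K$ held fixed, this becomes

$$
s(x_q) = x_q \cdot \frac{W_Q k}{\sqrt{d_{\text{head}}}} + \frac{b_Q \cdot k}{\sqrt{d_{\text{head}}}}
$$

i.e. a linear map applied to $x_q$ plus a bias term, and symmetrically for $x_k$ when the query is fixed. Under the hypothesis, the score should be large when $x_q$ points along the unembedding of X and $x_k$ points along the embedding of X.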

## Setup

In [38]:
from transformer_lens.cautils.notebook import *
from transformer_lens.rs.callum.keys_fixed import (
    attn_scores_as_linear_func_of_keys,
    attn_scores_as_linear_func_of_queries,
    get_attn_scores_and_probs_as_linear_func_of_keys,
    get_attn_scores_and_probs_as_linear_func_of_queries,
)
from transformer_lens.rs.callum.generate_bag_of_words_quad_plot import get_effective_embedding

# effective_embeddings = get_effective_embedding(model) 

# W_U = effective_embeddings["W_U (or W_E, no MLPs)"]
# W_EE = effective_embeddings["W_E (including MLPs)"]
# W_EE_subE = effective_embeddings["W_E (only MLPs)"]

clear_output()

In [39]:
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    # refactor_factored_attn_matrices=True,
)
model.set_use_split_qkv_input(True)

clear_output()

# Keys fixed, attn is linear func of queries

In [48]:
NUM_BATCHES = 30
N = 40
NAME_TOKENS = model.to_tokens(NAMES, prepend_bos=False).squeeze().tolist()
NNMH_LIST = [(10, 7), (11, 10)]

attn_scores, attn_probs = get_attn_scores_and_probs_as_linear_func_of_queries(
    NNMH_LIST[0], 
    num_batches=NUM_BATCHES,
    batch_size=N,
    model=model,
    name_tokens=NAME_TOKENS,
)

In [49]:
labels_list, attn_scores_list = zip(*attn_scores.items())

sorted_indices = t.argsort(t.tensor([score.mean() for score in attn_scores_list]))
labels_list = [labels_list[i] for i in sorted_indices]
attn_scores_list = [attn_scores_list[i] for i in sorted_indices]

h = hist(
    attn_scores_list,
    labels={"variable": "Query-side vector", "value": "Attention scores"},
    title="Attn scores (from END -> IO) in NNMH 10.7 (keys fixed, scores are linear func of queries)",
    names=labels_list,
    width=1000,
    height=600,
    opacity=0.6,
    marginal="box",
    template="simple_white",
    nbins=120,
    return_fig=True,
)
for i in [1, 2]:
    h.data[2*i].visible = "legendonly"
    h.data[2*i+1].visible = "legendonly"
h.show()

Unembedding for IO has way more influence on the final attention scores than anything else. Most importantly, **it boosts attn scores more than the actual output of NMH 9.9**.

We've already observed that the name mover heads' output almost entirely creates the queries for the neg heads like 10.7 (i.e. if you path patch from the name movers to 10.7, then attention from END to IO in 10.7 is no greater than attention from END to S). So this suggests that **the entire reason the neg heads attend back to IO (on the query side) is because they pick up on the unembedding of IO which is stored there**.

## Attention probs?

The plot below shows this in terms of attention probs (because it's good to know that the attention scores above are sufficient to affect the probs), even though attention scores are a more natural way of thinking about the function from queries -> attention, because of linearity.

Sure enough, attention probs are nearly 1 with the actual output of NMH (and with no patching), but they get absolutely hammered up to 1 when we use the unembedding of IO instead.
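
To make the linearity point concrete, here is a minimal toy sketch (plain Python, made-up numbers, nothing taken from GPT-2). The helper names `dot`, `attn_scores` and `softmax` are defined here for illustration and are not part of the notebook's codebase. It shows that doubling the query doubles every attention score, while the resulting probabilities do not scale the same way.

```python
import math

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def attn_scores(query, keys, d_head):
    # One score per key position: q . k / sqrt(d_head)
    return [dot(query, k) / math.sqrt(d_head) for k in keys]

def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

# Toy values: 3 key positions, d_head = 2 (illustrative only)
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
q = [0.5, -0.25]
d_head = 2

scores = attn_scores(q, keys, d_head)
scores_doubled = attn_scores([2 * x for x in q], keys, d_head)

# Scores are linear in the query: doubling q doubles every score
assert all(math.isclose(2 * s, s2) for s, s2 in zip(scores, scores_doubled))

# Probs are not linear: both distributions still sum to 1,
# so doubling q cannot double the probabilities
probs = softmax(scores)
probs_doubled = softmax(scores_doubled)
assert not all(math.isclose(2 * p, p2) for p, p2 in zip(probs, probs_doubled))
```

This is why the scores plot is the cleaner object to reason about, and the probs plot is mostly a sanity check that the score differences are big enough to matter after softmax.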

In [42]:
labels_list, attn_probs_list = zip(*attn_probs.items())

sorted_indices = t.argsort(t.tensor([score.mean() for score in attn_probs_list]))
labels_list = [labels_list[i] for i in sorted_indices]
attn_probs_list = [attn_probs_list[i] for i in sorted_indices]

h = hist(
    attn_probs_list,
    labels={"variable": "Query-side vector", "value": "Attention probs"},
    title="Attn probs (from END -> IO) in NNMH 10.7 (keys fixed, scores are linear func of queries)",
    names=labels_list,
    width=1000,
    height=600,
    opacity=0.6,
    marginal="box",
    template="simple_white",
    nbins=100,
    return_fig=True,
)
for i in [1, 2]:
    h.data[2*i].visible = "legendonly"
    h.data[2*i+1].visible = "legendonly"
h.show()

# Queries fixed, attn is linear func of keys

Now that we've got satisfying results when the keys are fixed (and attn is a linear func of queries), what about the other way around? We want to show the equivalent result, i.e. that the main thing determining attn on the key-side is the IO embedding. If not, then what the hell does the unembedding of IO match with?

In [51]:
attn_scores, attn_probs = get_attn_scores_and_probs_as_linear_func_of_keys(
    NNMH_LIST[0], 
    num_batches=NUM_BATCHES,
    batch_size=N,
    model=model,
    name_tokens=NAME_TOKENS,
)

In [52]:
labels_list, attn_scores_list = zip(*attn_scores.items())

sorted_indices = t.argsort(t.tensor([score.mean() for score in attn_scores_list]))
labels_list = [labels_list[i] for i in sorted_indices]
attn_scores_list = [attn_scores_list[i] for i in sorted_indices]

hist(
    attn_scores_list,
    labels={"variable": "Key-side vector", "value": "Attention scores"},
    title="Attn scores (from END -> IO) in NNMH 10.7 (queries fixed, scores are linear func of keys)",
    names=labels_list,
    width=1000,
    height=600,
    opacity=0.6,
    marginal="box",
    template="simple_white",
    nbins=120
)

In [46]:
labels_list, attn_probs_list = zip(*attn_probs.items())

sorted_indices = t.argsort(t.tensor([score.mean() for score in attn_probs_list]))
labels_list = [labels_list[i] for i in sorted_indices]
attn_probs_list = [attn_probs_list[i] for i in sorted_indices]

hist(
    attn_probs_list,
    labels={"variable": "Key-side vector", "value": "Attention probs"},
    title="Attn probs (from END -> IO) in NNMH 10.7 (queries fixed, scores are linear func of keys)",
    names=labels_list,
    width=1000,
    height=600,
    opacity=0.6,
    marginal="box",
    template="simple_white",
    nbins=100
)

## Analysis - wtf?

Okay, this seems really strange. I'd have expected replacing the key-side vector with the embedding to increase attention scores (since we'd have the unembedding attending back to the embedding of the same token). But in fact it decreases them (or at least adds noise to them).

If the embedding isn't the main contributor, then what is the main contributor? Is it super distributed, or is one head / MLP mainly responsible?
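
One way to make "super distributed vs. one component" concrete is to ask what share of the total absolute contribution comes from the top few components. The sketch below is a toy illustration with made-up numbers; `top_k_share` is a hypothetical helper, not something from the codebase, and the component labels are placeholders.

```python
def top_k_share(contributions, k):
    """Fraction of total absolute contribution carried by the k largest-magnitude components."""
    magnitudes = sorted((abs(v) for v in contributions.values()), reverse=True)
    total = sum(magnitudes)
    return sum(magnitudes[:k]) / total if total > 0 else 0.0

# Made-up per-component contributions to the attention score (placeholder labels)
toy = {"W_E": 0.2, "3.0": 2.4, "MLP0": -0.3, "5.5": 0.1}

# Rank by magnitude, since contributions can be negative
ranked = sorted(toy.items(), key=lambda kv: abs(kv[1]), reverse=True)
share = top_k_share(toy, k=1)

print(ranked[0][0], round(share, 3))
```

A share close to 1 for small `k` would point to one head or MLP being mainly responsible; a share that grows slowly with `k` would point to a distributed signal.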

# Which components contribute on the key-side?

Here I'm going to break down the components of the key position (by every head and every MLP) to see which one contributes most to the attention scores.

It'll be a heatmap of all the attention heads (and the MLPs). Each value will be the attention score contribution (with the appropriate LN scale applied).
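
The reason this breakdown is meaningful: once the query is fixed and the LN scale is treated as a constant, the score is a linear map of the key-side residual stream plus a bias, so per-component contributions add up to the full score. Below is a toy sketch of that bookkeeping with made-up values. The names `linear_map`, `bias_term` and `ln_scale` mirror the notebook, but the numbers and component labels are illustrative. It also ignores LayerNorm's mean-centering step, on the assumption that this is handled by `center_writing_weights=True`.

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

# Toy values with d_model = 3 (illustrative only, not from the model)
linear_map = [0.5, -1.0, 2.0]  # maps the scaled key-side input to an attention score
bias_term = 0.25               # key-independent part of the score
ln_scale = 2.0                 # LN scale at the key position, treated as a constant

# The residual stream at the key position is the sum of component outputs
components = {
    "embed":     [1.0, 0.0, 0.5],
    "pos_embed": [0.0, 1.0, 0.0],
    "head 0.0":  [0.5, 0.5, -0.5],
    "MLP0":      [-1.0, 0.0, 1.0],
}
resid = [sum(vals) for vals in zip(*components.values())]

# Full score computed from the whole residual stream
full_score = dot([x / ln_scale for x in resid], linear_map) + bias_term

# Per-component contributions, each scaled by the same LN scale
contributions = {
    name: dot([x / ln_scale for x in vec], linear_map)
    for name, vec in components.items()
}

# By linearity, the contributions plus the bias recover the full score
assert abs(sum(contributions.values()) + bias_term - full_score) < 1e-9
```

The same additivity is what makes summing the heatmap a useful sanity check against the mean attention score.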

In [55]:
batch_size = N
seed = 0
NNMH = NNMH_LIST[0]

ioi_dataset, abc_dataset, ioi_cache, abc_cache, ioi_metric_noising = generate_data_and_caches(batch_size, model=model, seed=seed)

linear_map, bias_term = attn_scores_as_linear_func_of_keys(batch_idx=None, head=NNMH, model=model, ioi_cache=ioi_cache, ioi_dataset=ioi_dataset)
assert linear_map.shape == (batch_size, model.cfg.d_model)
assert bias_term.shape == (batch_size,)

contribution_to_attn_scores = t.zeros(1 + NNMH[0], model.cfg.n_heads + 1)

ln_scale = ioi_cache["scale", NNMH[0], "ln1"][range(batch_size), ioi_dataset.word_idx["IO"], NNMH[1]]

# bit hacky - row 0 is a "zeroth layer" holding the embedding, pos embedding and key bias in its first three columns
embed = ioi_cache["embed"][range(batch_size), ioi_dataset.word_idx["IO"]]
pos_embed = ioi_cache["pos_embed"][range(batch_size), ioi_dataset.word_idx["IO"]]
embed_scaled = embed / ln_scale
pos_embed_scaled = pos_embed / ln_scale
contribution_to_attn_scores[0, 0] = einops.einsum(embed_scaled, linear_map, "batch d_model, batch d_model -> batch").mean()
contribution_to_attn_scores[0, 1] = einops.einsum(pos_embed_scaled, linear_map, "batch d_model, batch d_model -> batch").mean()
contribution_to_attn_scores[0, 2] = bias_term.mean()

for layer in range(NNMH[0]):

    z = ioi_cache["z", layer][range(batch_size), ioi_dataset.word_idx["IO"]]
    assert z.shape == (batch_size, model.cfg.n_heads, model.cfg.d_head)

    # ! todo - see if I can fit this on the cache, i.e. all heads at once (doesn't really matter though since it's super fast anyways)
    for head in range(model.cfg.n_heads):
        attn_out = einops.einsum(
            z[:, head], model.W_O[layer, head],
            "batch d_head, d_head d_model -> batch d_model"
        )
        results_scaled = attn_out / ln_scale
        contribution_to_attn_scores[1 + layer, head] = einops.einsum(results_scaled, linear_map, "batch d_model, batch d_model -> batch").mean()

    mlp_out = ioi_cache["mlp_out", layer][range(batch_size), ioi_dataset.word_idx["IO"]]
    assert mlp_out.shape == (batch_size, model.cfg.d_model)
    mlp_out_scaled = mlp_out / ln_scale
    contribution_to_attn_scores[1 + layer, -1] = einops.einsum(mlp_out_scaled, linear_map, "batch d_model, batch d_model -> batch").mean()

In [65]:
text = [["W_E", "W_pos", "b_K"] + ["" for _ in range(10)]]
for layer in range(0, 10):
    text.append([f"{layer}.{head}" for head in range(12)] + [f"MLP{layer}"])

fig = imshow(
    contribution_to_attn_scores,
    title="Contribution to attention scores (key-side)",
    labels={"x": "Component (attn heads & MLP)", "y": "Layer"},
    y=["misc"] + [str(i) for i in range(10)],
    x=[f"L.{i}" for i in range(12)] + ["MLP"],
    border=True,
    width=900,
    height=600,
    return_fig=True
)
fig.data[0].update(
    text=text, 
    texttemplate="%{text}", 
    textfont={"size": 12}
)
fig.show()

In [34]:
contribution_to_attn_scores.sum()

tensor(3.0991)

Idea: sub out the query-side vector with IO, then try this key-side breakdown again!