## Setup

In [10]:
from transformer_lens.cautils.notebook import *
from transformer_lens.rs.callum.keys_fixed import (
    attn_scores_as_linear_func_of_keys,
    attn_scores_as_linear_func_of_queries,
    get_attn_scores_and_probs_keys_fixed,
    get_attn_scores_and_probs_queries_fixed
)
from transformer_lens.rs.callum.generate_bag_of_words_quad_plot import get_effective_embedding

# effective_embeddings = get_effective_embedding(model) 

# W_U = effective_embeddings["W_U (or W_E, no MLPs)"]
# W_EE = effective_embeddings["W_E (including MLPs)"]
# W_EE_subE = effective_embeddings["W_E (only MLPs)"]

In [2]:
# Load GPT-2 small through TransformerLens with its weight-preprocessing flags enabled
# (LayerNorm folding plus centring of the writing and unembedding weights).
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    fold_ln=True,
    center_writing_weights=True,
    center_unembed=True,
    # refactor_factored_attn_matrices=True,
)
# Split the attention input into separate query / key / value hook points.
# NOTE(review): presumably required by the keys-fixed / queries-fixed helpers below — confirm.
model.set_use_split_qkv_input(True)

# Remove the model-loading messages from the cell output.
clear_output()

In [3]:
# Sampling budget passed to the helpers below: NUM_BATCHES batches with batch size N.
NUM_BATCHES, N = 40, 60

# (layer, head) pairs of the heads under study ("NNMH" — presumably negative name mover heads).
NNMHs = [(10, 7), (11, 10)]

# Token ids of NAMES, tokenised without a BOS token and flattened into a plain Python list.
NAME_TOKENS = model.to_tokens(NAMES, prepend_bos=False).squeeze().tolist()

# Keys-fixed experiment for head 10.7: returns two mappings (scores and probs) that the
# plotting cells iterate over with `.items()`.
attn_scores, attn_probs = get_attn_scores_and_probs_keys_fixed(
    NNMHs[0],
    model=model,
    name_tokens=NAME_TOKENS,
    batch_size=N,
    num_batches=NUM_BATCHES,
)

  0%|          | 0/40 [00:00<?, ?it/s]

In [5]:
# Split the results mapping into parallel sequences of labels and score tensors.
labels_list = list(attn_scores.keys())
attn_scores_list = list(attn_scores.values())

# Sort both sequences by mean attention score (ascending) so the legend is ordered.
sorted_indices = t.argsort(t.tensor([scores.mean() for scores in attn_scores_list]))
ordering = sorted_indices.tolist()
labels_list = [labels_list[idx] for idx in ordering]
attn_scores_list = [attn_scores_list[idx] for idx in ordering]

# Overlaid histograms (one per label) with a box-plot marginal.
hist(
    attn_scores_list,
    names=labels_list,
    title="Attn scores (from END -> IO) in NNMH 10.7 (keys fixed, scores are linear func of queries)",
    labels={"variable": "Query-side vector", "value": "Attention scores"},
    nbins=100,
    opacity=0.6,
    marginal="box",
    template="simple_white",
    width=1000,
    height=600,
)

In [5]:
# Split the results mapping into parallel sequences of labels and probability tensors.
labels_list, attn_probs_list = zip(*attn_probs.items())

# Sort both sequences by mean attention probability (ascending) so the legend is ordered.
sorted_indices = t.argsort(t.tensor([probs.mean() for probs in attn_probs_list]))
labels_list = [labels_list[i] for i in sorted_indices]
attn_probs_list = [attn_probs_list[i] for i in sorted_indices]

# Overlaid histograms (one per label) with a box-plot marginal.
# The value axis shows probabilities here, so it is labelled accordingly.
hist(
    attn_probs_list,
    labels={"variable": "Query-side vector", "value": "Attention probs"},
    title="Attn probs (from END -> IO) in NNMH 10.7 (keys fixed, scores are linear func of queries)",
    names=labels_list,
    width=1000,
    height=600,
    opacity=0.6,
    marginal="box",
    template="simple_white",
    nbins=100
)

In [11]:
# Queries-fixed experiment for head 10.7, using the same sampling budget as the keys-fixed run.
# NOTE: this rebinds `attn_scores` / `attn_probs`, replacing the keys-fixed results.
attn_scores, attn_probs = get_attn_scores_and_probs_queries_fixed(
    NNMHs[0],
    model=model,
    name_tokens=NAME_TOKENS,
    batch_size=N,
    num_batches=NUM_BATCHES,
)

  0%|          | 0/40 [00:00<?, ?it/s]

In [16]:
# Split the queries-fixed results mapping into parallel sequences of labels and score tensors.
labels_list, attn_scores_list = zip(*attn_scores.items())

# Sort both sequences by mean attention score (ascending) so the legend is ordered.
sorted_indices = t.argsort(t.tensor([scores.mean() for scores in attn_scores_list]))
labels_list = [labels_list[i] for i in sorted_indices]
attn_scores_list = [attn_scores_list[i] for i in sorted_indices]

# Overlaid histograms (one per label) with a box-plot marginal.
# These results come from the queries-fixed helper, so the varying quantity is the key-side
# vector; the title and legend label say so (they previously repeated the keys-fixed wording).
hist(
    attn_scores_list,
    labels={"variable": "Key-side vector", "value": "Attention scores"},
    title="Attn scores (from END -> IO) in NNMH 10.7 (queries fixed, scores are linear func of keys)",
    names=labels_list,
    width=1000,
    height=600,
    opacity=0.6,
    marginal="box",
    template="simple_white",
    nbins=100
)

In [15]:
# Split the queries-fixed results mapping into parallel sequences of labels and probability tensors.
labels_list, attn_probs_list = zip(*attn_probs.items())

# Sort both sequences by mean attention probability (ascending) so the legend is ordered.
sorted_indices = t.argsort(t.tensor([probs.mean() for probs in attn_probs_list]))
labels_list = [labels_list[i] for i in sorted_indices]
attn_probs_list = [attn_probs_list[i] for i in sorted_indices]

# Overlaid histograms (one per label) with a box-plot marginal.
# These results come from the queries-fixed helper, so the varying quantity is the key-side
# vector, and the plotted values are probabilities rather than raw scores.
hist(
    attn_probs_list,
    labels={"variable": "Key-side vector", "value": "Attention probs"},
    title="Attn probs (from END -> IO) in NNMH 10.7 (queries fixed, scores are linear func of keys)",
    names=labels_list,
    width=1000,
    height=600,
    opacity=0.6,
    marginal="box",
    template="simple_white",
    nbins=100
)

# Which components contribute?

Here I'm going to break down the components of the key position (by every head and every MLP) to see which one contributes most to the attention scores.

It'll be a heatmap of all the attention heads (and the MLP). Each value will be the attention score contribution (with appropriate LN scale applied).

In [41]:
batch_size = N
seed = 0
NNMH = NNMHs[0]
nnmh_layer, nnmh_head = NNMH

ioi_dataset, abc_dataset, ioi_cache, abc_cache, ioi_metric_noising = generate_data_and_caches(batch_size, model=model, seed=seed)

# One d_model-sized vector per prompt; its dot product with a key-side residual-stream
# component is that component's contribution to the attention score.
linear_map = attn_scores_as_linear_func_of_keys(batch_idx=None, head=NNMH, model=model, ioi_cache=ioi_cache, ioi_dataset=ioi_dataset)
assert linear_map.shape == (batch_size, model.cfg.d_model)

# Every lookup below selects, for each prompt in the batch, the IO token position.
batch_range = range(batch_size)
io_positions = ioi_dataset.word_idx["IO"]

# LayerNorm scale at ln1 of the studied layer, taken at the IO position for the studied head.
# TODO: update the diagram so people use "ln1".
ln_scale = ioi_cache["scale", nnmh_layer, "ln1"][batch_range, io_positions, nnmh_head]


def mean_score_contribution(component_out):
    """Divide a (batch, d_model) component output by the LN scale, dot it with `linear_map`
    per prompt, and return the batch mean."""
    return einops.einsum(component_out / ln_scale, linear_map, "batch d_model, batch d_model -> batch").mean()


# Rows: 0 = embeddings, 1 + layer = that layer's components.
# Columns: one per attention head, last column = MLP.
contribution_to_attn_scores = t.zeros(1 + nnmh_layer, model.cfg.n_heads + 1)

# Slightly hacky: the token and positional embeddings share a "zeroth layer" row (columns 0 and 1).
embed = ioi_cache["embed"][batch_range, io_positions]
pos_embed = ioi_cache["pos_embed"][batch_range, io_positions]
contribution_to_attn_scores[0, 0] = mean_score_contribution(embed)
contribution_to_attn_scores[0, 1] = mean_score_contribution(pos_embed)

for layer in range(nnmh_layer):

    z = ioi_cache["z", layer][batch_range, io_positions]
    assert z.shape == (batch_size, model.cfg.n_heads, model.cfg.d_head)

    # TODO: see if this can be done for all heads at once rather than head by head.
    for head in range(model.cfg.n_heads):
        # Project this head's z through its slice of W_O to get its residual-stream output.
        attn_out = einops.einsum(
            z[:, head], model.W_O[layer, head],
            "batch d_head, d_head d_model -> batch d_model"
        )
        contribution_to_attn_scores[1 + layer, head] = mean_score_contribution(attn_out)

    mlp_out = ioi_cache["mlp_out", layer][batch_range, io_positions]
    assert mlp_out.shape == (batch_size, model.cfg.d_model)
    contribution_to_attn_scores[1 + layer, -1] = mean_score_contribution(mlp_out)

In [62]:
# Heatmap of the mean key-side contribution of every component to the attention score.
# Row 0 holds the embedding terms; each remaining row is one earlier layer. The row labels
# are derived from the matrix itself so there is exactly one label per row (the previous
# hard-coded `range(12)` produced 13 labels for a matrix with 1 + NNMH[0] rows).
fig = imshow(
    contribution_to_attn_scores,
    title="Contribution to attention scores (key-side)",
    labels={"x": "Component (attn heads & MLP)", "y": "Layer"},
    y=["pre-layers"] + [str(layer_idx) for layer_idx in range(contribution_to_attn_scores.shape[0] - 1)],
    border=True,
    width=1000,
    height=800
)

In [60]:
# Per-cell annotations for the contribution heatmap, laid out like the matrix itself:
# one row for the embeddings followed by one row per layer below the studied head, and
# one column per attention head plus a final MLP column. Sizes come from `NNMH` and the
# model config rather than hard-coded 10 / 11 / 12, so the grid always matches the heatmap.
n_components = model.cfg.n_heads + 1
text = [["W_E", "W_pos"] + [""] * (n_components - 2)]
for layer in range(NNMH[0]):
    text.append([f"{layer}.{head}" for head in range(model.cfg.n_heads)] + [f"MLP{layer}"])

In [64]:
# Write the component names into the cells of the heatmap (the figure's first trace).
heatmap_trace = fig.data[0]
heatmap_trace.update(
    text=text,
    texttemplate="%{text}",
    textfont=dict(size=12),
)

# Shrink the figure so the annotated cells stay readable, then render it.
fig.update_layout(width=850, height=600)
fig.show()

In [34]:
# Total of all per-component mean contributions in the heatmap (embeddings, heads and MLPs).
contribution_to_attn_scores.sum()

tensor(3.0991)