# Setup

In [1]:
from transformer_lens.cautils.notebook import *
from transformer_lens.rs.callum.keys_fixed import (
    attn_scores_as_linear_func_of_keys,
    attn_scores_as_linear_func_of_queries,
    get_attn_scores_as_linear_func_of_queries_for_histogram,
    get_attn_scores_as_linear_func_of_keys_for_histogram,
    decompose_attn_scores,
    plot_contribution_to_attn_scores,
    project,
    decompose_attn_scores_full,
    create_fucking_massive_plot_1,
    create_fucking_massive_plot_2,
    get_effective_embedding_2,
)

# NOTE(review): the block below is disabled. If it is re-enabled it has to move
# below the cell that creates `model`, because `model` is not defined yet at
# this point in the notebook. Also, only `get_effective_embedding_2` is imported
# explicitly above -- confirm `get_effective_embedding` is available (e.g. via
# the star import) before uncommenting.
# effective_embeddings = get_effective_embedding(model)

# W_U = effective_embeddings["W_U (or W_E, no MLPs)"]
# W_EE = effective_embeddings["W_E (including MLPs)"]
# W_EE_subE = effective_embeddings["W_E (only MLPs)"]

# Drop whatever this cell has printed so far (presumably IPython's
# clear_output, brought in by the star import -- TODO confirm).
clear_output()

In [2]:
# Load GPT-2 small as a HookedTransformer, with the weight-processing options
# written out explicitly rather than left to the library defaults.
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    # refactor_factored_attn_matrices=True,
)
# Switch on the additional hook points. The cells below attach hooks named
# "q_input", "q_normalized_input" and "k_normalized_input".
# NOTE(review): presumably each setter enables one family of those hook points;
# `set_use_split_qkv_normalized_input` is not part of upstream TransformerLens,
# so confirm its behaviour (and any ordering requirement) against the fork.
model.set_use_split_qkv_input(True)
model.set_use_attn_result(True)
model.set_use_split_qkv_normalized_input(True)

# Hide the loading messages printed by this cell.
clear_output()

In [3]:
# --- Experiment configuration -------------------------------------------------
# BATCH_SIZE: passed as the first positional argument to generate_data_and_caches.
# NEG_NMH:    (layer, head) pair; element 0 picks the layer whose hook points and
#             cache entries are used below, element 1 picks the head index.
# seed:       forwarded to generate_data_and_caches.
BATCH_SIZE = 30
NEG_NMH = (10, 7)
seed = 42

# One call returns both objects used by every later cell: `ioi_dataset`
# (tokens, word positions, token IDs) and `ioi_cache` (indexed below with
# ("mlp_out", 0)).
ioi_dataset, ioi_cache = generate_data_and_caches(
    BATCH_SIZE,
    model=model,
    seed=seed,
    only_ioi=True,
    prepend_bos=True,
)

### Hook functions

In [7]:
def hook_fn_queries(
    q_input: Float[Tensor, "batch seq n_heads d_model"],
    hook: HookPoint,
    head: Tuple[int, int] = NEG_NMH,
    ioi_dataset: IOIDataset = ioi_dataset,
    model: HookedTransformer = model,
    project_in_S_dir: bool = True,
    par: bool = True,
):
    """Overwrite one head's query-side input at the "end" position with a projection.

    For every prompt in `ioi_dataset`, the vector at
    `(prompt, word_idx["end"], head[1])` is passed to `project` together with
    the unembedding rows of the prompt's IO token (and of its S token when
    `project_in_S_dir` is true). One of the two tensors returned by `project`
    is then written back in place.

    Args:
        q_input: Activation handed over by the hook point; modified in place.
        hook: Hook point object supplied by the model (not used here).
        head: (layer, head) pair. Only `head[1]` is read; the layer is fixed by
            the hook point this function is attached to.
        ioi_dataset: Supplies `io_tokenIDs`, `s_tokenIDs` and `word_idx["end"]`.
            The default is bound when the function is defined.
        model: Supplies `W_U`. The default is bound when the function is defined.
        project_in_S_dir: If true, project onto both the IO and S unembedding
            directions; otherwise onto the IO direction only.
        par: If true, keep the first tensor returned by `project`, otherwise
            the second (presumably the parallel and perpendicular components,
            going by how the results are named -- confirm against `project`).

    Returns:
        The same `q_input` tensor, after the in-place write.
    """
    # Index pieces shared by the read and the write-back.
    batch_idx = range(len(ioi_dataset))
    end_pos = ioi_dataset.word_idx["end"]
    head_idx = head[1]

    # One unembedding row per prompt: (batch, d_model).
    unembed_rows = model.W_U.T
    io_dirs = unembed_rows[ioi_dataset.io_tokenIDs]
    s_dirs = unembed_rows[ioi_dataset.s_tokenIDs]

    directions = [io_dirs]
    if project_in_S_dir:
        directions.append(s_dirs)

    end_queries = q_input[batch_idx, end_pos, head_idx]
    assert end_queries.shape == io_dirs.shape

    first_component, second_component = project(end_queries, directions)
    if par:
        q_input[batch_idx, end_pos, head_idx] = first_component
    else:
        q_input[batch_idx, end_pos, head_idx] = second_component

    return q_input


# Remove any hooks left over from earlier cells and free cached GPU memory
# before running the model again.
model.reset_hooks()
t.cuda.empty_cache()

# Attach the query-projection hook on the layer of the head under study.
# NOTE(review): this cell hooks "q_input", while the later cell that adds both
# hooks uses "q_normalized_input" -- confirm which hook point is intended here.
# The hook stays attached after this cell until the next reset_hooks() call.
model.add_hook(utils.get_act_name("q_input", NEG_NMH[0]), hook_fn_queries)

logits, cache = model.run_with_cache(ioi_dataset.toks)

# Show the attention pattern of the studied head on the first prompt.
# The head label is derived from NEG_NMH so it cannot drift from the head that
# is actually being plotted.
cv.attention.attention_patterns(
    attention = cache["pattern", NEG_NMH[0]][0, [NEG_NMH[1]]],
    tokens = model.to_str_tokens(ioi_dataset.toks[0]),
    attention_head_names = [f"{NEG_NMH[0]}.{NEG_NMH[1]}"]
)

In [10]:
def hook_fn_keys(
    k_input: Float[Tensor, "batch seq n_heads d_model"],
    hook: HookPoint,
    head: Tuple[int, int] = NEG_NMH,
    ioi_dataset: IOIDataset = ioi_dataset,
    ioi_cache: ActivationCache = ioi_cache,
    model: HookedTransformer = model,
    project_in_S_dir: bool = True,
    par: bool = True,
):
    """Overwrite one head's key-side input at the IO and S1 positions with a projection.

    For every prompt, the key-side vectors at `word_idx["IO"]` and
    `word_idx["S1"]` for head `head[1]` are each passed to `project` together
    with the layer-0 `mlp_out` activation cached at the same position. One of
    the two tensors returned by `project` is written back in place.

    Args:
        k_input: Activation handed over by the hook point; modified in place.
        hook: Hook point object supplied by the model (not used here).
        head: (layer, head) pair. Only `head[1]` is read.
        ioi_dataset: Supplies `word_idx["IO"]` and `word_idx["S1"]`. The default
            is bound when the function is defined.
        ioi_cache: Cache providing `("mlp_out", 0)`. The default is bound when
            the function is defined.
        model: Accepted for signature parity with `hook_fn_queries`; not read.
        project_in_S_dir: Accepted but currently never read -- the S1 position
            is always projected. NOTE(review): confirm whether this flag was
            meant to gate the S1 projection.
        par: If true, keep the first tensor returned by `project`, otherwise
            the second (presumably parallel / perpendicular components, going
            by how the results are named -- confirm against `project`).

    Returns:
        The same `k_input` tensor, after the in-place writes.
    """
    batch_idx = range(len(ioi_dataset))
    head_idx = head[1]
    io_pos = ioi_dataset.word_idx["IO"]
    s1_pos = ioi_dataset.word_idx["S1"]

    # Layer-0 MLP output at each prompt's IO / S1 position: (batch, d_model).
    mlp0_out = ioi_cache["mlp_out", 0]
    io_dirs = mlp0_out[batch_idx, io_pos]
    s1_dirs = mlp0_out[batch_idx, s1_pos]

    # Read both slices before writing either one back.
    io_keys = k_input[batch_idx, io_pos, head_idx]
    s1_keys = k_input[batch_idx, s1_pos, head_idx]
    assert io_keys.shape == io_dirs.shape

    io_first, io_second = project(io_keys, io_dirs)
    s1_first, s1_second = project(s1_keys, s1_dirs)

    if par:
        k_input[batch_idx, io_pos, head_idx] = io_first
        k_input[batch_idx, s1_pos, head_idx] = s1_first
    else:
        k_input[batch_idx, io_pos, head_idx] = io_second
        k_input[batch_idx, s1_pos, head_idx] = s1_second

    return k_input


# Remove every hook on the model, permanent ones included, and free cached GPU
# memory before running the model again.
model.reset_hooks(including_permanent=True)
t.cuda.empty_cache()

# Attach both projection hooks on the layer of the head under study: queries at
# "q_normalized_input", keys at "k_normalized_input".
# The hooks stay attached after this cell until the next reset_hooks() call.
model.add_hook(utils.get_act_name("q_normalized_input", NEG_NMH[0]), hook_fn_queries)
model.add_hook(utils.get_act_name("k_normalized_input", NEG_NMH[0]), hook_fn_keys)

logits, cache = model.run_with_cache(ioi_dataset.toks)

# Show the attention pattern of the studied head on the first prompt.
# The head label is derived from NEG_NMH so it cannot drift from the head that
# is actually being plotted.
cv.attention.attention_patterns(
    attention = cache["pattern", NEG_NMH[0]][0, [NEG_NMH[1]]],
    tokens = model.to_str_tokens(ioi_dataset.toks[0]),
    attention_head_names = [f"{NEG_NMH[0]}.{NEG_NMH[1]}"]
)

In [11]:
# Axis labels for the first prompt, one per token, formatted "<token>_<position>"
# so that a token occurring twice still gets two distinct labels.
labels = [
    f"{tok}_{pos}"
    for pos, tok in enumerate(model.to_str_tokens(ioi_dataset.toks[0]))
]

# Heatmap of the "attn_scores" cache entry for head NEG_NMH on the first prompt.
# `cache` comes from the most recent run_with_cache call, so this cell reflects
# whichever hooks were attached during that run.
imshow(
    cache["attn_scores", NEG_NMH[0]][0, NEG_NMH[1]],
    x=labels,
    y=labels,
    labels={"x": "Key", "y": "Query"},
    height=800,
)