# Setup

In [1]:
from transformer_lens.cautils.notebook import *
from transformer_lens.rs.callum.keys_fixed import (
    attn_scores_as_linear_func_of_keys,
    attn_scores_as_linear_func_of_queries,
    get_attn_scores_as_linear_func_of_queries_for_histogram,
    get_attn_scores_as_linear_func_of_keys_for_histogram,
    decompose_attn_scores,
    plot_contribution_to_attn_scores,
    project,
    decompose_attn_scores_full,
    create_fucking_massive_plot_1,
    create_fucking_massive_plot_2,
    get_effective_embedding_2,
)

# Earlier approach, kept for reference only: effective embeddings built from the
# model. These lines cannot be re-enabled here as-is, because `model` is only
# created in the next cell.
# effective_embeddings = get_effective_embedding(model) 

# W_U = effective_embeddings["W_U (or W_E, no MLPs)"]
# W_EE = effective_embeddings["W_E (including MLPs)"]
# W_EE_subE = effective_embeddings["W_E (only MLPs)"]

# Clear this cell's output (e.g. anything printed while importing).
clear_output()

In [2]:
# Load GPT-2 small with TransformerLens weight processing applied at load time
# (LayerNorm folding and centering of the unembedding / writing weights).
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    # refactor_factored_attn_matrices=True,
)
# Later cells hook "q_input", "q_normalized_input" and "k_normalized_input".
# These flags presumably expose those per-head hook points (and per-head
# attention results); TODO confirm against the TransformerLens fork in use.
model.set_use_split_qkv_input(True)
model.set_use_attn_result(True)
model.set_use_split_qkv_normalized_input(True)

clear_output()

In [35]:
BATCH_SIZE = 30  # number of IOI prompts
NEG_NMH = (10, 7)  # (layer, head) studied throughout; presumably the negative name-mover head
seed = 0

# Build the IOI prompts and a clean-run activation cache. Later cells read
# `ioi_cache["mlp_out", 0]`, so this cache is assumed to include MLP0 outputs.
ioi_dataset, ioi_cache = generate_data_and_caches(BATCH_SIZE, model=model, seed=seed, only_ioi=True, prepend_bos=True, symmetric=True)

### Hook functions

In [36]:
def hook_fn_queries(
    q_input: Float[Tensor, "batch seq n_heads d_model"], 
    hook: HookPoint,
    head: Tuple[int, int] = NEG_NMH,
    ioi_dataset: IOIDataset = ioi_dataset,
    model: HookedTransformer = model,
    project_in_S_dir: bool = True,
    par: bool = False,
):
    '''
    Query-side hook: edit head `head[1]`'s query input at the END position of every prompt.

    The END-position slice is split by `project` relative to the IO token's unembedding
    vector (plus the S token's unembedding vector when `project_in_S_dir`). The slice is
    then overwritten with the first returned component when `par` is True, or the second
    when it is False. Assumes `project` returns (parallel, perpendicular); see keys_fixed.

    `q_input` is modified in place and also returned. `hook` is not used.
    '''
    unembed_rows = model.W_U.T  # indexing by token id gives one (d_model,) row per prompt
    io_direction = unembed_rows[ioi_dataset.io_tokenIDs]  # (batch, d_model)
    s_direction = unembed_rows[ioi_dataset.s_tokenIDs]  # (batch, d_model)
    directions = [io_direction] + ([s_direction] if project_in_S_dir else [])

    # The same (prompt, END position, head) selection is used for both the read and the write.
    where = (range(len(ioi_dataset)), ioi_dataset.word_idx["end"], head[1])

    end_queries = q_input[where]
    assert end_queries.shape == io_direction.shape
    component_par, component_perp = project(end_queries, directions)

    q_input[where] = component_par if par else component_perp
    return q_input


# Quick visual check: attach the query-side hook at layer NEG_NMH[0], cache the
# attention patterns, and plot head NEG_NMH[1]'s pattern for the first prompt.
model.reset_hooks()
t.cuda.empty_cache()

# NOTE: this (non-permanent) hook is not removed at the end of this cell.
model.add_hook(utils.get_act_name("q_input", NEG_NMH[0]), hook_fn_queries)

logits, cache = model.run_with_cache(ioi_dataset.toks, names_filter=lambda name: name.endswith("pattern"))

cv.attention.attention_patterns(
    # Indexing with the list [NEG_NMH[1]] keeps a head dimension of size 1.
    attention = cache["pattern", NEG_NMH[0]][0, [NEG_NMH[1]]],
    tokens = model.to_str_tokens(ioi_dataset.toks[0]),
    # attention_head_names = ["10.7"]
)

In [37]:
def hook_fn_keys(
    k_input: Float[Tensor, "batch seq n_heads d_model"], 
    hook: HookPoint,
    head: Tuple[int, int] = NEG_NMH,
    ioi_dataset: IOIDataset = ioi_dataset,
    ioi_cache: ActivationCache = ioi_cache,
    model: HookedTransformer = model,
    project_in_S_dir: bool = True,
    par: bool = True,
):
    '''
    Key-side hook: edit head `head[1]`'s key input at the IO and S1 positions of every prompt.

    At each of those positions the key slice is split by `project` relative to the MLP0
    output cached in `ioi_cache` at that same position. The slice is then overwritten with
    the first returned component when `par` is True, or the second when it is False.
    Assumes `project` returns (parallel, perpendicular); see keys_fixed.

    `k_input` is modified in place and also returned. `hook`, `model` and
    `project_in_S_dir` are not used by this function.
    '''
    rows = range(len(ioi_dataset))
    head_idx = head[1]
    mlp0_out = ioi_cache["mlp_out", 0]
    positions = {name: ioi_dataset.word_idx[name] for name in ("IO", "S1")}

    # Read every slice before writing anything back, so each projection sees the unmodified input.
    targets = {name: mlp0_out[rows, pos] for name, pos in positions.items()}  # (batch, d_model) each
    keys = {name: k_input[rows, pos, head_idx] for name, pos in positions.items()}

    assert keys["IO"].shape == targets["IO"].shape
    projected = {name: project(keys[name], targets[name]) for name in positions}

    for name, pos in positions.items():
        component_par, component_perp = projected[name]
        k_input[rows, pos, head_idx] = component_par if par else component_perp

    return k_input




# (hook name, hook function) pairs, meant to be passed as model.add_hook(*pair).
# Both target layer NEG_NMH[0] on the normalized split q/k inputs; the hook
# functions' `head` default selects head NEG_NMH[1].
q_hook = (utils.get_act_name("q_normalized_input", NEG_NMH[0]), hook_fn_queries)
k_hook = (utils.get_act_name("k_normalized_input", NEG_NMH[0]), hook_fn_keys)

def test_model(model: HookedTransformer, show_too: bool = False):
    '''
    Sweep the query-side (`q_hook`) and key-side (`k_hook`) hooks on and off.

    For each of the four combinations (order: qk, q, k, none) the model is run on
    `ioi_dataset.toks`. The function prints, for head NEG_NMH, the mean attention score
    END->IO minus the mean attention score END->S1. The bracketed label shows which
    hooks were active.

    `model.reset_hooks()` (non-permanent hooks only) is called before every run, so
    permanent hooks stay in effect. When `show_too` is True, the first prompt's full
    attention-score matrix is also plotted for each combination.
    '''
    batch = range(len(ioi_dataset))
    end_pos = ioi_dataset.word_idx["end"]
    layer, head_idx = NEG_NMH

    for use_q, use_k in itertools.product([True, False], [True, False]):
        model.reset_hooks()
        for enabled, hook_pair in ((use_q, q_hook), (use_k, k_hook)):
            if enabled:
                model.add_hook(*hook_pair)
        desc = ("q" if use_q else " ") + ("k" if use_k else " ")

        t.cuda.empty_cache()

        _, cache = model.run_with_cache(ioi_dataset.toks, names_filter=lambda name: name.endswith("attn_scores"))
        head_scores = cache["attn_scores", layer][:, head_idx]

        to_IO = head_scores[batch, end_pos, ioi_dataset.word_idx["IO"]]
        to_S = head_scores[batch, end_pos, ioi_dataset.word_idx["S1"]]
        print(f"Diff [{desc}] = {to_IO.mean() - to_S.mean():.3f}")

        if not show_too:
            continue

        labels = [f"{tok}_{i}" for i, tok in enumerate(model.to_str_tokens(ioi_dataset.toks[0]))]
        imshow(
            cache["attn_scores", layer][0, head_idx],
            x = labels,
            y = labels,
            labels = {"x": "Key", "y": "Query"},
            height = 800,
        )

# Leftover from an earlier attention-pattern visualization (kept for reference).
# cv.attention.attention_patterns(
#     attention = cache["pattern", NEG_NMH[0]][0, [NEG_NMH[1]]],
#     tokens = model.to_str_tokens(ioi_dataset.toks[0]),
#     attention_head_names = ["10.7"]
# )
# labels = [f"{x}_{i}" for (i, x) in enumerate(model.to_str_tokens(ioi_dataset.toks[0]))]


# Run the four-way q/k hook sweep; printed results appear below this cell.
test_model(model)

Diff [qk] = 1.858
Diff [q ] = 2.098
Diff [ k] = 3.099
Diff [  ] = 3.254


In [38]:
def hook_fn_patch_wpos_MLP0(
    resid_pre: Float[Tensor, "batch seq d_model"],
    hook: HookPoint,
    add: bool,
    model: HookedTransformer = model,
    permute: bool = False,
    ioi_dataset: IOIDataset = ioi_dataset,
):
    '''
    Edit the positional-embedding content of the residual stream.

    permute=False: return a new tensor equal to `resid_pre` plus (add=True) or minus
        (add=False) W_pos[:seq_len], applied at every position.
    permute=True: at each prompt's IO position add sign * (W_pos[S1] - W_pos[IO]), and
        at its S1 position add sign * (W_pos[IO] - W_pos[S1]), with sign = +1 if `add`
        else -1. `resid_pre` is modified in place and returned. If `resid_pre` already
        contains W_pos at those positions (as layer-0 resid_pre presumably does), then
        add=True effectively swaps the IO and S1 positional embeddings.

    `hook` is not used.
    '''
    n_pos = resid_pre.shape[1]
    assert model.W_pos.shape[-1] == model.cfg.d_model
    pos_embed = model.W_pos[:n_pos]

    if not permute:
        return resid_pre + pos_embed if add else resid_pre - pos_embed

    rows = torch.arange(len(ioi_dataset))
    io_idx = ioi_dataset.word_idx["IO"]
    s1_idx = ioi_dataset.word_idx["S1"]
    io_embed = pos_embed[io_idx]  # one positional vector per prompt
    s1_embed = pos_embed[s1_idx]
    sign = 1.0 if add else -1.0

    assert resid_pre[rows, io_idx].shape == s1_embed.shape

    resid_pre[rows, io_idx] += sign * (s1_embed - io_embed)
    resid_pre[rows, s1_idx] += sign * (io_embed - s1_embed)
    return resid_pre


# Build a cache with the IO/S1 positional embeddings exchanged. A permanent hook on
# layer-0 resid_pre applies hook_fn_patch_wpos_MLP0 (add=True, permute=True), and every
# activation of a run on the clean IOI prompts is cached. The result is used as
# `new_cache` in the path-patching sweep.
model.reset_hooks(including_permanent=True)
model.add_hook(utils.get_act_name("resid_pre", 0), partial(hook_fn_patch_wpos_MLP0, add=True, permute=True), is_permanent=True)

logits, mlp_positional_signals_flipped_cache = model.run_with_cache(ioi_dataset.toks)

# Optional experiment: also apply the add=False variant at layer 1 (presumably undoing
# the swap before layer 1) and rerun the q/k sweep with the hooks still attached.
# model.add_hook(utils.get_act_name("resid_pre", 1), partial(hook_fn_patch_wpos_MLP0, add=False, permute=True), is_permanent=True)

# test_model(model, show_too=False)

# Remove the permanent hook(s). Otherwise they silently alter every later forward pass
# in this notebook that does not reset permanent hooks itself.
model.reset_hooks(including_permanent=True)

In [7]:
# Diff [qk] = 1.570
# Diff [q ] = 1.711
# Diff [ k] = 2.791
# Diff [  ] = 2.911

## Path patching

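### 1. Path patch from MLP0 → keyside 10.7 (note: the sweep below currently uses the value inputs of heads 9.9 and 9.6 as receivers; the 10.7 key receiver is commented out)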

In [39]:
# Flipped prompts ("ABB->BAB, BAB->ABB" template mapping), and a cache of the MLP outputs on them.
flipped_dataset = ioi_dataset.gen_flipped_prompts("ABB->BAB, BAB->ABB")
# flipped_dataset = ioi_dataset.gen_flipped_prompts("ABB->BAA, BAB->ABA")

# Start from a hook-free model. A permanent hook left attached by an earlier cell
# (e.g. the positional-embedding swap) would otherwise be applied to this run.
model.reset_hooks(including_permanent=True)
_, flipped_cache = model.run_with_cache(flipped_dataset.toks, names_filter=lambda name: name.endswith("mlp_out"))

In [41]:
def patching_metric(cache: ActivationCache, head: Tuple[int, int] = NEG_NMH) -> float:
    '''
    Attention-score gap for `head`: mean END->IO score minus mean END->S1 score.

    Args:
        cache: activation cache that contains "attn_scores" for layer `head[0]`
            (batch dimension aligned with the global `ioi_dataset`).
        head: (layer, head) to evaluate. Defaults to NEG_NMH, the head used by the
            rest of the notebook (previously hard-coded as layer 10, head 7).

    Returns:
        The gap as a Python float. This is the same quantity that `test_model` prints.
    '''
    layer, head_idx = head
    attn_scores = cache["attn_scores", layer][:, head_idx]
    rows = range(len(ioi_dataset))
    end_pos = ioi_dataset.word_idx["end"]
    attn_scores_to_IO = attn_scores[rows, end_pos, ioi_dataset.word_idx["IO"]]
    attn_scores_to_S = attn_scores[rows, end_pos, ioi_dataset.word_idx["S1"]]
    return (attn_scores_to_IO.mean() - attn_scores_to_S.mean()).item()


def test_model_PP(model: HookedTransformer):
    '''
    Sweep path patching on/off, with and without the query-side hook.

    For each combination (order: qP, q, P, none) the function prints
    `patching_metric` for head NEG_NMH. The label means:
      * q: `q_hook` is attached as a *permanent* hook (presumably so that it survives
        any hook resets done inside `path_patch`; TODO confirm).
      * P: path patching from MLP0 (sender) into the value inputs of heads 9.9 and 9.6
        (receivers). Clean activations come from `ioi_cache` and patched activations
        from `mlp_positional_signals_flipped_cache`.
      * otherwise: a plain forward pass on the clean prompts.

    Every iteration starts with `model.reset_hooks(including_permanent=True)`.
    '''
    for use_q, use_PP in itertools.product([True, False], [True, False]):
        model.reset_hooks(including_permanent=True)
        if use_q:
            model.add_hook(*q_hook, is_permanent=True)
        desc = ("q" if use_q else " ") + ("P" if use_PP else " ")

        t.cuda.empty_cache()

        if not use_PP:
            # TODO - sanity check, do this with path_patch instead
            _, cache = model.run_with_cache(ioi_dataset.toks, names_filter=lambda name: name.endswith("attn_scores"))
            diff = patching_metric(cache)
        else:
            diff = path_patch(
                model = model,
                patching_metric = patching_metric,
                apply_metric_to_cache = True,
                orig_input = ioi_dataset.toks,
                # new_input = flipped_dataset.toks,
                orig_cache = ioi_cache,
                new_cache = mlp_positional_signals_flipped_cache,
                direct_includes_mlps = True,
                sender_nodes = Node("mlp_out", layer=0),
                receiver_nodes = [Node("v", layer=9, head=9), Node("v", layer=9, head=6)],
                # receiver_nodes = Node("k", layer=10, head=7),
            )

        print(f"Diff [{desc}] = {diff:.3f}")


# Start from a hook-free model, then run the path-patching sweep (printed results follow).
model.reset_hooks(including_permanent=True)
test_model_PP(model)

Diff [qP] = -0.486
Diff [q ] = 2.098
Diff [ P] = -0.923
Diff [  ] = 3.254


In [42]:
# Rebinds `flipped_dataset` for the activation-patching experiment below. The
# "ABB->CBB, BAB->BCB" mapping presumably replaces the IO name with a new name C.
flipped_dataset = ioi_dataset.gen_flipped_prompts("ABB->CBB, BAB->BCB")

In [44]:
# Show the first clean and flipped prompts. This is printed explicitly because only a
# cell's final expression is echoed automatically; a bare tuple here displayed nothing.
print(ioi_dataset.sentences[0], flipped_dataset.sentences[0], sep="\n")

# Cache only layer-10 key activations on the flipped prompts. The filter keeps hook names
# that end in "k" and contain ".10."; these keys feed the act_patch call below.
_, flipped_cache = model.run_with_cache(flipped_dataset.toks, names_filter=lambda name: name.endswith("k") and ".10." in name)

# Act patch

In [None]:
# Activation patching: run the clean IOI prompts with head 10.7's key activations
# replaced by those in `flipped_cache`, then score with `patching_metric`
# (apply_metric_to_cache=True presumably means the metric receives the patched
# run's cache; confirm against act_patch).
act_patch(
    model = model,
    orig_input = ioi_dataset.toks,
    patching_nodes = Node("k", layer=10, head=7),
    new_cache = flipped_cache,
    apply_metric_to_cache = True,
    patching_metric = patching_metric,
)