## Setup

In [2]:
from transformer_lens.cautils.notebook import *

  from pandas.core.computation.check import NUMEXPR_INSTALLED


In [3]:
# GPT-2 small with TransformerLens weight processing enabled: LayerNorm folded into the
# following weights, writing weights and unembedding centred, and the factored attention
# matrices refactored (see the `from_pretrained` docs for the exact transformations).
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    fold_ln=True,
    center_writing_weights=True,
    center_unembed=True,
    refactor_factored_attn_matrices=True,
)
# Give each head separate query / key / value input hooks.
model.set_use_split_qkv_input(True)

# Hide the model-loading log output in this cell.
clear_output()

In [4]:
def _logits_to_ave_logit_diff(logits: Float[Tensor, "batch seq d_vocab"], ioi_dataset: IOIDataset, per_prompt=False):
    '''
    Logit of the indirect-object (IO) token minus the logit of the subject (S) token,
    read at each prompt's "end" position.

    Args:
        logits: model output logits, shape (batch, seq, d_vocab).
        ioi_dataset: supplies the per-prompt "end" positions and the IO / S answer token ids.
        per_prompt: if True, return the per-prompt differences (shape (batch,)) instead of
            their mean.
    '''
    rows = range(logits.size(0))
    # Only the final ("end") position predicts the answer token.
    end_pos = ioi_dataset.word_idx["end"]

    io_logits = logits[rows, end_pos, ioi_dataset.io_tokenIDs]  # (batch,)
    s_logits = logits[rows, end_pos, ioi_dataset.s_tokenIDs]  # (batch,)
    logit_diff = io_logits - s_logits

    if per_prompt:
        return logit_diff
    return logit_diff.mean()



def _ioi_metric_noising(
    logits: Float[Tensor, "batch seq d_vocab"],
    clean_logit_diff: float,
    corrupted_logit_diff: float,
    ioi_dataset: IOIDataset,
) -> float:
    '''
    Noising metric for IOI, normalised so that:
      *  0 means performance is unharmed (patched logit diff equals the clean IOI logit diff);
      * -1 means performance is destroyed (patched logit diff equals the corrupted ABC logit diff).
    '''
    patched = _logits_to_ave_logit_diff(logits, ioi_dataset)
    clean_minus_corrupted = clean_logit_diff - corrupted_logit_diff
    normalised = (patched - clean_logit_diff) / clean_minus_corrupted
    return normalised.item()



def generate_data_and_caches(N: int, verbose: bool = False, seed: int = 42):
    '''
    Build a clean IOI dataset and its ABC-corrupted counterpart, cache the model's
    activations on both, and return a noising metric calibrated on their average
    logit differences.

    Args:
        N: number of prompts to generate.
        verbose: if True, print both baseline average logit differences.
        seed: forwarded to `IOIDataset`.

    Returns:
        (ioi_dataset, abc_dataset, ioi_cache, abc_cache, ioi_metric_noising), where
        `ioi_metric_noising` is `_ioi_metric_noising` with the two baselines and the IOI
        dataset already bound, so it only needs `logits`.
    '''
    ioi_dataset = IOIDataset(
        prompt_type="mixed",
        N=N,
        tokenizer=model.tokenizer,
        prepend_bos=False,
        seed=seed,
        device=str(device),
    )
    # Corrupted prompts; presumably the ABB / BAB name patterns are replaced by three fresh names (XYZ).
    abc_dataset = ioi_dataset.gen_flipped_prompts("ABB->XYZ, BAB->XYZ")

    # Remove every hook (permanent ones included) so the cached activations are unpatched.
    model.reset_hooks(including_permanent=True)

    ioi_logits, ioi_cache = model.run_with_cache(ioi_dataset.toks)
    abc_logits, abc_cache = model.run_with_cache(abc_dataset.toks)

    # Both runs are scored with the IOI dataset's positions and IO / S answer tokens, so the
    # ABC baseline measures how much the corruption hurts the *clean* answer.
    ioi_average_logit_diff, abc_average_logit_diff = (
        _logits_to_ave_logit_diff(run_logits, ioi_dataset).item()
        for run_logits in (ioi_logits, abc_logits)
    )

    if verbose:
        print(f"Average logit diff (IOI dataset): {ioi_average_logit_diff:.4f}")
        print(f"Average logit diff (ABC dataset): {abc_average_logit_diff:.4f}")

    ioi_metric_noising = partial(
        _ioi_metric_noising,
        ioi_dataset=ioi_dataset,
        clean_logit_diff=ioi_average_logit_diff,
        corrupted_logit_diff=abc_average_logit_diff,
    )
    return ioi_dataset, abc_dataset, ioi_cache, abc_cache, ioi_metric_noising


N = 100  # prompts per dataset
ioi_dataset, abc_dataset, ioi_cache, abc_cache, ioi_metric_noising = generate_data_and_caches(
    N, verbose=True
)
seq_len = ioi_dataset.toks.shape[1]  # shared token length of the tokenized prompts

Average logit diff (IOI dataset): 3.0733
Average logit diff (ABC dataset): 0.3129


In [17]:
NAME_TOKENS = model.to_tokens(NAMES, prepend_bos=False).squeeze().tolist()
NNMH = [(10, 7), (11, 0)]  # (layer, head) of the negative name mover heads studied here

def attn_scores_as_linear_func_of_keys(
    batch_idx: Union[int, List[int], Int[Tensor, "batch"]] = None,
    head: Tuple[int, int] = NNMH[0],
    model: HookedTransformer = model,
    ioi_cache: ActivationCache = ioi_cache,
    dataset: Union[IOIDataset, None] = None,
) -> Float[Tensor, "*batch d_model"]:
    '''
    If you hold the keys fixed, attention scores are a linear function of the queries.

    This fixes the key of `head` at each prompt's IO position and returns, per prompt, the
    vector v_b = W_Q[layer, head] @ k_b. Dotting a query-side residual-stream vector x with v_b
    gives x @ W_Q @ k_b, i.e. (up to the caveat below) that query's END -> IO attention score.
    One can then check, e.g., whether the unembedding vector of the IO token has a really big
    image under this linear map.

    Caveat: only W_Q and the cached keys are used, so the map omits the query bias b_Q and any
    attention-score scaling (TransformerLens divides scores by sqrt(d_head) by default —
    confirm for this config). Values are therefore proportional to, not equal to, the model's
    cached `attn_scores`.

    Args:
        batch_idx: prompt index or indices to use; None means every prompt in the cache.
        head: (layer, head_index); defaults to NNMH 10.7.
        model: supplies W_Q.
        ioi_cache: activation cache from running `model` on the IOI prompts.
        dataset: IOIDataset providing the IO positions. None (the default) means the
            notebook-level `ioi_dataset`, looked up at call time (the original behaviour).

    Returns:
        Shape (d_model,) when `batch_idx` is an int, otherwise (batch, d_model).
    '''
    layer, head_idx = head
    if dataset is None:
        # Resolved at call time so loops that rebind the global `ioi_dataset` stay in sync.
        dataset = ioi_dataset

    keys = ioi_cache["k", layer][:, :, head_idx]  # (all_batch, seq_K, d_head)

    # Remember whether a single index was requested *before* normalising it to a list;
    # the original re-checked after rebinding, so the squeeze below never fired.
    single_prompt = isinstance(batch_idx, int)
    if single_prompt:
        batch_idx = [batch_idx]
    elif batch_idx is None:
        batch_idx = range(keys.size(0))

    keys_at_IO = keys[batch_idx, dataset.word_idx["IO"][batch_idx]]  # (batch, d_head)

    W_Q = model.W_Q[layer, head_idx]  # (d_model, d_head)

    linear_map = einops.einsum(W_Q, keys_at_IO, "d_model d_head, batch d_head -> batch d_model")
    return linear_map[0] if single_prompt else linear_map


# For NNMH 10.7 (NNMH[0]): hold the keys fixed, treat the END -> IO attention score as a linear
# function of the query-side vector, and compare several candidate query vectors. Also swap each
# candidate's score into the cached END-position scores to see what attention prob IO would get.
# Results are pooled over 10 dataset seeds.
#
# NOTE: each iteration rebinds the notebook-level `ioi_dataset`, `abc_dataset`, `ioi_cache`,
# `abc_cache` and `ioi_metric_noising`, so after this cell they hold the LAST seed's data (not the
# seed-42 data created earlier). `attn_scores_as_linear_func_of_keys` reads the global
# `ioi_dataset` for IO positions by default, so that rebinding also keeps positions and cache aligned.
nnmh_layer, nnmh_head = NNMH[0]
nmh_layer, nmh_head = 9, 9  # name mover head whose output is one candidate query vector

query_vector_names = ["IO", "S", "random", "random_name", "99_out"]
score_chunks = {name: [] for name in query_vector_names}
prob_chunks = {name: [] for name in query_vector_names}

for seed in tqdm(range(10)):
    # Seed the extra random draws (random tokens / random names) so reruns are reproducible.
    t.manual_seed(seed)
    np.random.seed(seed)

    ioi_dataset, abc_dataset, ioi_cache, abc_cache, ioi_metric_noising = generate_data_and_caches(N, seed=seed)

    linear_map = attn_scores_as_linear_func_of_keys(head=(nnmh_layer, nnmh_head), ioi_cache=ioi_cache)
    assert linear_map.shape == (N, model.cfg.d_model)

    batch = range(N)
    end_pos = ioi_dataset.word_idx["end"]
    io_pos = ioi_dataset.word_idx["IO"]

    # Candidate query-side vectors, one per prompt, each of shape (N, d_model).
    candidates = {
        "IO": model.W_U.T[t.tensor(ioi_dataset.io_tokenIDs)],
        "S": model.W_U.T[t.tensor(ioi_dataset.s_tokenIDs)],
        "random": model.W_U.T[t.randint(size=(N,), low=0, high=model.cfg.d_vocab)],
        "random_name": model.W_U.T[np.random.choice(NAME_TOKENS, size=(N,))],
        "99_out": einops.einsum(
            ioi_cache["z", nmh_layer][batch, end_pos, nmh_head],
            model.W_O[nmh_layer, nmh_head],
            "batch d_head, d_head d_model -> batch d_model",
        ),
    }

    # Cached attention scores from END to every key position for this head: (N, seq_K).
    real_scores_from_end = ioi_cache["attn_scores", nnmh_layer][batch, nnmh_head, end_pos]

    for name in query_vector_names:
        vec = candidates[name]
        # Unit-normalise so candidates are compared by direction (the per-prompt |linear_map|
        # factor is shared across candidates). No LayerNorm scale is applied. An earlier variant
        # divided by ioi_cache["scale", 10, "ln2"], but ln2 is the pre-MLP LayerNorm; a query-side
        # scale would presumably come from ln1 — TODO if that variant is revived.
        unit_vec = vec / vec.norm(dim=-1, keepdim=True)
        scores = einops.einsum(linear_map, unit_vec, "batch d_model, batch d_model -> batch")
        score_chunks[name].append(scores)

        # Replace the END -> IO score with the candidate's score and read off IO's attention prob.
        # NOTE(review): `scores` omits b_Q and any score scaling and uses a unit-norm query, so it
        # is likely not on the same scale as the cached scores — interpret these probs with care.
        patched_scores = real_scores_from_end.clone()
        patched_scores[batch, io_pos] = scores
        prob_chunks[name].append(patched_scores.softmax(dim=-1)[batch, io_pos])

# Concatenate once at the end (instead of growing tensors inside the loop).
attn_scores_IO, attn_scores_S, attn_scores_random, attn_scores_random_name, attn_scores_99_out = (
    t.cat(score_chunks[name]) for name in query_vector_names
)
probs_IO, probs_S, probs_random, probs_random_name, probs_99_out = (
    t.cat(prob_chunks[name]) for name in query_vector_names
)
probs_list = [probs_IO, probs_S, probs_random, probs_random_name, probs_99_out]

  0%|          | 0/10 [00:00<?, ?it/s]

In [18]:
# One histogram trace per query-side vector; dict order fixes the trace order.
attn_score_series = {
    "W_U[IO]": attn_scores_IO,
    "W_U[S]": attn_scores_S,
    "W_U[random name]": attn_scores_random_name,
    "W_U[random]": attn_scores_random,
    "NMH 9.9 output": attn_scores_99_out,
}
hist(
    list(attn_score_series.values()),
    names=list(attn_score_series.keys()),
    title="Attn scores (from END -> IO) in NNMH 10.7 (keys fixed, scores are linear func of queries)",
    labels={"variable": "Query-side vector", "value": "Attention scores"},
    template="simple_white",
    marginal="box",
    opacity=0.7,
    nbins=300,
    width=1000,
    height=600,
)

In [22]:
# One histogram trace per query-side vector; dict order fixes the trace order.
attn_prob_series = {
    "W_U[IO]": probs_IO,
    "W_U[S]": probs_S,
    "W_U[random name]": probs_random_name,
    "W_U[random]": probs_random,
}
hist(
    list(attn_prob_series.values()),
    names=list(attn_prob_series.keys()),
    title="Attn probs (from END -> IO) in NNMH 10.7 (keys fixed, query-side vector patched in)",
    labels={"variable": "Query-side vector", "value": "Attention prob"},
    template="simple_white",
    marginal="box",
    opacity=0.7,
    width=1000,
    height=600,
)