# Setup

In [1]:
from transformer_lens.cautils.notebook import *

from transformer_lens.rs.callum.keys_fixed import (
    project,
    get_effective_embedding_2,
)

from transformer_lens.rs.callum.orthogonal_query_investigation import (
    decompose_attn_scores_full,
    create_fucking_massive_plot_1,
    create_fucking_massive_plot_2,
    token_to_qperp_projection
)

# Clear whatever output the imports above printed, so the cell stays clean.
clear_output()

In [2]:
# GPT-2 small, with LayerNorm folded into the following weights and the writing weights / unembedding centred.
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    fold_ln=True,
    center_writing_weights=True,
    center_unembed=True,
)

# Enable split q/k/v inputs, and per-head attention outputs (`hook_result`), which the hooks below read.
model.set_use_split_qkv_input(True)
model.set_use_attn_result(True)

clear_output()

# sudo pkill -9 python

In [3]:
# Effective embeddings as computed by `get_effective_embedding_2` (presumably the token embedding after being
# passed through the early layers / MLP0 — TODO confirm against that helper).
effective_embeddings = get_effective_embedding_2(model)

W_EE = effective_embeddings['W_E (including MLPs)']
W_EE0 = effective_embeddings['W_E (only MLPs)']
W_E = model.W_E
# The part of the effective embedding that does not come directly from W_E.
W_EE0A = W_EE - W_E

# Define an easier-to-use dict! (Reuses W_EE0A rather than allocating a second copy of the same difference.)
effective_embeddings = {"W_EE": W_EE, "W_EE0": W_EE0, "W_E": W_E, "W_EE0A": W_EE0A}

# Part 1

I want to calculate, for a given head and a given input, the attention-weighted average amount by which this head suppresses the logits of the tokens it attends to (counting only the component of its output along the unembedding direction of each attended-to token).

I do this by the following method:

* After `v` is calculated, I calculate `result_pre_attn`, which is what you get by applying `W_O` to the values *first*, i.e. before taking a weighted average of `z` according to the attention probabilities. In other words, `result_pre_attn` has shape `(batch, seqK, d_model)`, and it contains the vectors that would be moved from each source position `seqK` to a destination position if that destination paid 100% of its attention to that source.
    * I then calculate `result_pre_attn_projected`, by projecting `result_pre_attn` onto the unembedding directions for that particular source token. This still has shape `(batch, seqK, d_model)`.
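* At the attention-pattern hook, I cache this head's attention pattern (shape `(batch, seqQ, seqK)`) in the context of the result hook.
* At the result hook, I take the attention-weighted sum of `result_pre_attn_projected` over source positions to get `result_post_attn_projected`, which has shape `(batch, seqQ, d_model)`. These are the actual (projections of) vectors that will be added to the residual stream at the destination positions `seqQ`.
* At the final LayerNorm's scale hook, I normalise these `result_post_attn_projected` vectors the same way: centre them, then divide by the scale. I then take their dot product with the unembeddings of the source tokens to get the direct logit effect, and calculate a weighted average of how much they suppress the logits of the tokens attended to (i.e. the sumproduct of `amount_seqK_token_suppressed_at_seqQ * attn_from_seqQ_to_seqK`). Attention from or to BOS tokens, and attention from a position to itself, are excluded. The per-sequence score is divided by the total attention that remains after this masking; the per-destination score is not renormalised.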

In [4]:
def hook_fn_cache_new_result_projections(
    v: Float[Tensor, "batch seq n_heads d_head"],
    hook: HookPoint,
    toks: Int[Tensor, "batch seq"],
    model: HookedTransformer,
    head_idx: int,
):
    '''
    Hook function (attached to `hook_v`) which computes, for a single head, the vector each source position
    would write into the residual stream if it received 100% of the attention. Each vector is projected
    onto the unembedding direction of the token at that source position.

    The projection, shape `(batch, seqK, d_model)`, is stored in the context of this layer's `result` hook
    under the key "result_pre_attn_projected". The values themselves are not changed (nothing is returned).

    Args:
        v: value vectors for every head in this layer.
        hook: the hook point this function is attached to (only used to get the layer index).
        toks: input tokens; `toks[b, k]` selects the unembedding direction for source position k.
        model: the model being run (provides `W_O`, `W_U` and `hook_dict`).
        head_idx: index of the head (within this layer) to analyse.
    '''
    layer = hook.layer()

    # Apply this head's output matrix W_O to each source position's value vector *before* any mixing by the
    # attention pattern, giving the per-source contribution to the residual stream.
    result_pre_attn = einops.einsum(
        v[:, :, head_idx], model.W_O[layer, head_idx],
        "batch seq d_head, d_head d_model -> batch seq d_model",
    )

    # Unembedding direction of the token at each source position.
    token_unembeddings = model.W_U.T[toks]

    # The projected tensor tells us, for each source token, the vector that will get moved from source to
    # destination (after projecting it onto the source unembedding). The second output of `project` isn't needed.
    result_pre_attn_projected, _ = project(result_pre_attn, token_unembeddings)

    result_hook = model.hook_dict[utils.get_act_name("result", layer)]
    result_hook.ctx["result_pre_attn_projected"] = result_pre_attn_projected



def hook_fn_cache_attn(
    pattern: Float[Tensor, "batch n_heads seqQ seqK"],
    hook: HookPoint,
    model: HookedTransformer,
    head_idx: int,
):
    '''
    Hook function (attached to `hook_pattern`) which stashes the attention pattern of head `head_idx`,
    shape `(batch, seqQ, seqK)`, in the context of the same layer's `result` hook under the key "pattern".
    The pattern itself is left unchanged.
    '''
    result_hook_name = utils.get_act_name("result", hook.layer())
    head_pattern = pattern[:, head_idx]
    model.hook_dict[result_hook_name].ctx["pattern"] = head_pattern



def hook_fn_take_weighted_avg(
    result: Float[Tensor, "batch seq n_heads d_model"],
    hook: HookPoint,
    model: HookedTransformer,
):
    '''
    Hook function (attached to `hook_result`) which combines the two things cached earlier in this layer:
    the per-source projected vectors from `hook_fn_cache_new_result_projections`, and the attention pattern
    from `hook_fn_cache_attn`.

    Stores `result_post_attn_projected`, shape `(batch, seqQ, d_model)`, in this hook's context. For each
    destination position, this is the attention-weighted sum of the (unembedding-projected) vectors moved from
    every source position. `result` itself is neither read nor modified.
    '''
    # TODO - `result` is unused, so this logic could be moved into the hook on `scale`.
    result_hook = model.hook_dict[utils.get_act_name("result", hook.layer())]

    # The projections are consumed here; the pattern is left in place because the scale hook pops it later.
    result_pre_attn_projected = result_hook.ctx.pop("result_pre_attn_projected")
    pattern = result_hook.ctx["pattern"]

    # For each source AND destination token: the vector actually moved from source to destination
    # (after projecting it onto the source unembedding), summed over sources.
    result_post_attn_projected = einops.einsum(
        result_pre_attn_projected, pattern,
        "batch seqK d_model, batch seqQ seqK -> batch seqQ d_model",
    )
    result_hook.ctx["result_post_attn_projected"] = result_post_attn_projected



def hook_fn_calculate_logit_diff(
    scale: Float[Tensor, "batch seq 1"],
    hook: HookPoint,
    toks: Int[Tensor, "batch seq"],
    model: HookedTransformer,
    layer: int,
):
    '''
    Hook function, attached to the `utils.get_act_name("scale")` hook (presumably `ln_final.hook_scale`).
    It measures how much head `layer.?`'s output directly changes the logits of the tokens that head attends to.

    It stores a dict in `hook.ctx["weighted_avg_logit_suppression"]` with:
        "per_position":    (batch, seqQ) - for each destination q, sum over sources k of
                           masked_attn[q, k] * (direct effect on the logit of token k). NOT renormalised:
                           if a destination only attends to masked positions its score is just ~0.
        "per_sequence":    (batch,) - the per-position scores summed over destinations, divided by the total
                           masked attention in that sequence, i.e. a true weighted average
                           (NaN if no attention survives the mask).
        "pattern_non_bos": (batch, seqQ, seqK) - the masked attention pattern used as the weights.

    How BOS is handled:
        - We don't care about predictions made by BOS tokens, so attention is zeroed whenever destination = BOS.
        - We don't care about attending to BOS tokens, so attention is zeroed whenever source = BOS.
        Every position whose token id equals `bos_token_id` is masked, not only position 0.
    Attention from a position to itself is also zeroed.
    '''
    assert isinstance(toks, Int[Tensor, "batch seq"])
    _, seq_len = toks.shape
    bos_token_id = model.tokenizer.bos_token_id

    # Check BOS is at start (this matters!)
    assert t.all(toks[:, 0] == bos_token_id)

    # Get the attention-weighted projected result (and the pattern) cached by the hooks on this head's layer.
    result_hook = model.hook_dict[utils.get_act_name("result", layer)]
    result_post_attn_projected = result_hook.ctx.pop("result_post_attn_projected")
    pattern = result_hook.ctx.pop("pattern")

    # Normalise the head's output the same way the residual stream is normalised here: centre over d_model,
    # then divide by the per-position scale.
    result_post_attn_projected_scaled = (
        result_post_attn_projected - result_post_attn_projected.mean(-1, keepdim=True)
    ) / scale

    # all_logit_suppression[i, q, k] = how much (in sequence i, destination position q) the logit of the
    # token at source position k is directly affected.
    token_unembeddings = model.W_U.T[toks]
    all_logit_suppression = einops.einsum(
        result_post_attn_projected_scaled, token_unembeddings,
        "batch seqQ d_model, batch seqK d_model -> batch seqQ seqK",
    )

    # Build the masks on the same device/dtype as `pattern` (toks may live elsewhere, e.g. on CPU).
    is_not_bos_2d_mask = (toks != bos_token_id).to(device=pattern.device, dtype=pattern.dtype)
    is_not_bos_3d_mask = einops.einsum(
        is_not_bos_2d_mask, is_not_bos_2d_mask, "batch seqQ, batch seqK -> batch seqQ seqK"
    )
    # (seqQ, seqK) mask, broadcast over the batch dimension.
    is_not_self_mask = 1.0 - t.eye(seq_len, device=pattern.device, dtype=pattern.dtype)
    pattern_non_bos = pattern * (is_not_self_mask * is_not_bos_3d_mask)

    # We don't want to divide by anything per dest. If a dest token only attends to BOS, we don't want to
    # reweight - we don't care about that dest!
    weighted_avg_logit_suppression_per_dest = einops.einsum(
        all_logit_suppression, pattern_non_bos,
        "batch seqQ seqK, batch seqQ seqK -> batch seqQ",
    )

    # We do want to renormalise per sequence, by dividing by the total attention we're summing over.
    non_bos_attn_per_seq = pattern_non_bos.sum((-1, -2))
    weighted_avg_logit_suppression_per_seq = weighted_avg_logit_suppression_per_dest.sum(-1) / non_bos_attn_per_seq

    hook.ctx["weighted_avg_logit_suppression"] = {
        "per_position": weighted_avg_logit_suppression_per_dest,
        "per_sequence": weighted_avg_logit_suppression_per_seq,
        "pattern_non_bos": pattern_non_bos
    }



def compute_weighted_avg_logit_suppression(
    model: HookedTransformer,
    toks: Int[Tensor, "batch seq"],
    head: Tuple[int, int],
):
    '''
    Run `model` on `toks` with the four hooks above attached for a single head. Return the dict they build,
    with keys "per_position", "per_sequence" and "pattern_non_bos".

    Args:
        model: the model to run. NOTE: *all* its hooks, including permanent ones, are reset first.
        toks: input tokens, which must start with BOS.
        head: (layer, head_idx) of the head to analyse.
    '''
    layer, head_idx = head

    # Each hook reads what the previous one cached, so they must fire in forward order:
    # v -> pattern -> result (all at `layer`), then the scale hook.
    fwd_hooks = [
        (utils.get_act_name("v", layer), partial(hook_fn_cache_new_result_projections, toks=toks, model=model, head_idx=head_idx)),
        (utils.get_act_name("pattern", layer), partial(hook_fn_cache_attn, model=model, head_idx=head_idx)),
        (utils.get_act_name("result", layer), partial(hook_fn_take_weighted_avg, model=model)),
        (utils.get_act_name("scale"), partial(hook_fn_calculate_logit_diff, toks=toks, model=model, layer=layer)),
    ]
    scale_hook_name = utils.get_act_name("scale")

    model.reset_hooks(including_permanent=True)
    model.run_with_hooks(toks, return_type=None, fwd_hooks=fwd_hooks)

    return model.hook_dict[scale_hook_name].ctx.pop("weighted_avg_logit_suppression")

### Test my function

I'm going to take a batch of standard prompts from the IOI distribution and run my function on them. I should find that:

* For head `10.7`, the amount of suppression when `seqQ` is the final (END) position, where the head attends mainly to the IO token, is about the same as the head's direct output projected onto the IO unembedding direction (i.e. very negative).
* For head `9.9`, for similar reasons, it should be very positive.

In [25]:
batch_size = 10
seed = 0
ioi_dataset, ioi_cache = generate_data_and_caches(batch_size, model=model, seed=seed, only_ioi=True, prepend_bos=True)

def investigate_a_few_heads(
    toks: Int[Tensor, "batch seq"],
    head_list = ((9, 9), (10, 7), (11, 0)),
):
    '''
    Plot an animated heatmap (one row per sequence in `toks`) with two frames per (layer, head) in `head_list`:
        "{layer}.{head}":       per-destination suppression score ("per_position"). The per-sequence weighted
                                average is appended as an extra final column (labelled "*").
        "{layer}.{head} probs": total masked (non-BOS, non-self) attention paid by each destination. Its mean
                                over destinations is appended as the final column.
    Cells are annotated with the token strings (BOS shown as blank). Uses the global `model`.
    '''
    labels = []
    results = []
    batch_size = toks.shape[0]

    for layer, head_idx in head_list:
        results_dict = compute_weighted_avg_logit_suppression(
            model=model,
            toks=toks,
            head=(layer, head_idx),
        )
        suppression_per_position = results_dict["per_position"]  # (batch, seqQ)
        suppression_per_sequence = results_dict["per_sequence"]  # (batch,)
        attn_kept = results_dict["pattern_non_bos"].sum(-1)      # (batch, seqQ)

        results.append(t.concat([suppression_per_position, suppression_per_sequence.unsqueeze(-1)], dim=-1))
        results.append(t.concat([attn_kept, attn_kept.mean(1, keepdim=True)], dim=1))

        head_name = f"{layer}.{head_idx}"
        labels.extend([head_name, f"{head_name} probs"])

    # One label per cell: the token strings plus "*" for the appended summary column; BOS is blanked out.
    text = [
        [("" if tok_str == "<|endoftext|>" else tok_str) for tok_str in model.to_str_tokens(toks[i]) + ["*"]]
        for i in range(batch_size)
    ]

    fig = imshow(
        t.stack(results),
        animation_frame = 0,
        animation_labels = labels,
        height = 800,
        width = 1600,
        return_fig = True,
    )

    for trace in fig.data:
        trace.update(
            text=text,
            texttemplate="%{text}",
            textfont={"size": 12}
        )

    fig.show()


investigate_a_few_heads(ioi_dataset.toks)

In [216]:
batch_idx = 3
# (layer, head) pairs to visualise, in display order; tensors and labels are both derived from this list.
heads_to_show = [(10, 7), (9, 9), (11, 8)]

cv.attention.attention_patterns(
    tokens=model.to_str_tokens(ioi_dataset.toks[batch_idx]),
    attention=t.stack([ioi_cache["pattern", layer][batch_idx, head] for layer, head in heads_to_show]),
    attention_head_names=[f"{layer}.{head}" for layer, head in heads_to_show],
)

In [8]:
def create_really_hacky_first_pass_metric(
    batch_size: int = 50,
    seed: int = 0,
):
    '''
    Score every head on `batch_size` IOI prompts (generated with `seed`). A head's score is the mean over
    prompts of the per-sequence weighted-average direct effect of the head on the logits of the tokens it
    attends to (negative = suppression).

    Returns:
        A `(n_layers, n_heads)` tensor of scores (uses the global `model`).
    '''
    # The cache is not needed here, only the tokens.
    ioi_dataset, _ = generate_data_and_caches(batch_size, model=model, seed=seed, only_ioi=True, prepend_bos=True)
    n_layers, n_heads = model.cfg.n_layers, model.cfg.n_heads

    all_results = t.zeros(n_layers, n_heads)

    for layer, head_idx in tqdm(list(itertools.product(range(n_layers), range(n_heads)))):
        results = compute_weighted_avg_logit_suppression(
            model=model,
            toks=ioi_dataset.toks,
            head=(layer, head_idx)
        )
        all_results[layer, head_idx] = results["per_sequence"].mean()

    return all_results

all_results = create_really_hacky_first_pass_metric()

  0%|          | 0/144 [00:00<?, ?it/s]

In [9]:
# Keep only negative scores (suppression); positive scores are zeroed out.
negative_scores = all_results * (all_results < 0)
imshow(negative_scores, height=800, title="Source suppression scores")

# OpenWebText

In [11]:
# Load OpenWebText documents (used as `data` by the second-pass metric below). `seed` is passed through to
# `get_webtext` — presumably it controls which documents are sampled/shuffled; TODO confirm.
data = get_webtext(seed=6)

Found cached dataset openwebtext-10k (/home/ubuntu/.cache/huggingface/datasets/stas___openwebtext-10k/plain_text/1.0.0/3a8df094c671b4cb63ed0b41f40fb3bd855e9ce2e3765e5df50abcdfb5ec144b)


  0%|          | 0/1 [00:00<?, ?it/s]

In [37]:
def create_really_hacky_second_pass_metric(
    batch_size: int = 60,
    seq_len: int = 100,
):
    '''
    Same per-head score as `create_really_hacky_first_pass_metric`, but computed on the first `batch_size`
    documents of the global `data` (OpenWebText), truncated to `seq_len` tokens.

    The suppression hook asserts that every sequence starts with BOS. This relies on `model.to_tokens`
    prepending BOS (the TransformerLens default — TODO confirm if the model config changes).

    Returns:
        A `(n_layers, n_heads)` tensor of scores (uses the globals `model` and `data`).
    '''
    toks = model.to_tokens(data[:batch_size])[:, :seq_len]
    n_layers, n_heads = model.cfg.n_layers, model.cfg.n_heads

    all_results = t.zeros(n_layers, n_heads)

    for layer, head_idx in tqdm(list(itertools.product(range(n_layers), range(n_heads)))):
        results = compute_weighted_avg_logit_suppression(
            model=model,
            toks=toks,
            head=(layer, head_idx)
        )
        all_results[layer, head_idx] = results["per_sequence"].mean()

    return all_results

# `all_results` (IOI scores) was already computed with the same defaults in the first-pass cell above;
# recomputing it here would just rerun the metric over every head a second time.
all_results_2 = create_really_hacky_second_pass_metric()

  0%|          | 0/144 [00:00<?, ?it/s]

In [39]:
# Side-by-side IOI vs webtext scores, keeping only negative values (suppression) in each.
scores_to_plot = t.stack([scores * (scores < 0) for scores in (all_results, all_results_2)])
imshow(
    scores_to_plot,
    facet_col=0,
    facet_labels=["IOI", "Webtext"],
    title="Source suppression scores",
    height=700,
    width=1500,
    text_auto=".2f",
)