# Setup

In [1]:
# Notebook utilities. This star import presumably provides HookedTransformer, t (torch), einops, utils,
# imshow, tqdm, clear_output, etc., which the rest of the notebook uses — TODO confirm.
from transformer_lens.cautils.notebook import *

# `project` is used by `hook_fn_at_result`; `get_effective_embedding_2` by the effective-embeddings cell.
from transformer_lens.rs.callum.keys_fixed import (
    project,
    get_effective_embedding_2,
)

# Helpers from the orthogonal-query investigation (none of these are used in the visible part of this notebook).
from transformer_lens.rs.callum.orthogonal_query_investigation import (
    decompose_attn_scores_full,
    create_fucking_massive_plot_1,
    create_fucking_massive_plot_2,
    token_to_qperp_projection
)

# Hide the cell output produced by the imports above.
clear_output()

In [2]:
# GPT-2 small, with layer norm folded into the weights and the writing / unembedding weights centered.
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    fold_ln=True,
    center_writing_weights=True,
    center_unembed=True,
)

# The hooks in this notebook read per-head outputs (`hook_result`), so those must be materialised;
# split per-head QKV inputs are not needed.
model.set_use_split_qkv_input(False)
model.set_use_attn_result(True)

clear_output()

# Shell note kept from earlier work: sudo pkill -9 python

In [3]:
effective_embeddings = get_effective_embedding_2(model)

W_EE = effective_embeddings['W_E (including MLPs)']
W_EE0 = effective_embeddings['W_E (only MLPs)']
W_E = model.W_E
# Difference between the full effective embedding and the raw token embedding
W_EE0A = W_EE - W_E

# Define an easier-to-use dict! (Reuses W_EE0A rather than recomputing the same difference a second time.)
effective_embeddings = {"W_EE": W_EE, "W_EE0": W_EE0, "W_E": W_E, "W_EE0A": W_EE0A}

# Part 1

I want to calculate, for a given head and a given input, the weighted average amount by which this head suppresses the logits of the tokens it attends to, counting only the component of its output along the unembedding direction of each attended-to token.

I do this by the following method:

* After `v` is calculated, I calculate `result_pre_attn`, which is actually the thing you get by applying `W_O` first, i.e. before taking a weighted average of `z` in accordance with the attention probabilities. In other words, `result_pre_attn` has shape `(batch, seqK, d_model)`, and it contains the vectors which will get moved from the source position `seqK` to the destination positions if it's paid 100% attention to.
    * I then calculate `result_pre_attn_projected`, by projecting `result_pre_attn` onto the unembedding directions for that particular source token. This still has shape `(batch, seqK, d_model)`.
* At the attention-pattern hook, I cache this head's pattern in the context of the `result` hook (the value vectors `v` are cached there in the same way).
* At the `result` hook, I get `result_post_attn_projected`, which has shape `(batch, seqQ, d_model)`: the actual (projections of) vectors that will be added to the residual stream at the destination positions `seqQ`.
* At the `scale` hook (the final layer norm's scale), I divide these `result_post_attn_projected` vectors by the layer-norm scale and convert them into logits. I then calculate a weighted average of how much they suppress the logits of the tokens attended to, i.e. the sumproduct of `amount_seqK_token_suppressed_at_seqQ * attn_from_seqQ_to_seqK`.

In [43]:
def hook_fn_cache_v(
    v: Float[Tensor, "batch seq n_heads d_head"],
    hook: HookPoint,
    model: HookedTransformer,
    head_idx: int,
):
    '''
    Forward hook on a layer's `v` activation.

    Stashes head `head_idx`'s value vectors, of shape (batch, seq, d_head), under the key "v" in the context
    of the same layer's `result` hook. There, `hook_fn_at_result` pops them.

    Returns None, so the activation is left unchanged.
    '''
    this_head_v = v[:, :, head_idx]
    result_act_name = utils.get_act_name("result", hook.layer())
    model.hook_dict[result_act_name].ctx["v"] = this_head_v



def hook_fn_cache_attn(
    pattern: Float[Tensor, "batch n_heads seqQ seqK"],
    hook: HookPoint,
    model: HookedTransformer,
    head_idx: int,
):
    '''
    Forward hook on a layer's attention `pattern`.

    Stashes head `head_idx`'s pattern, of shape (batch, seqQ, seqK), under the key "pattern" in the context of
    the same layer's `result` hook, for later use by the result / scale hooks.

    Returns None, so the pattern is left unchanged.
    '''
    result_ctx = model.hook_dict[utils.get_act_name("result", hook.layer())].ctx
    result_ctx["pattern"] = pattern[:, head_idx]



def hook_fn_at_result(
    result: Float[Tensor, "batch seq n_heads d_model"],
    hook: HookPoint,
    toks: Int[Tensor, "batch seq"],
    model: HookedTransformer,
    head_idx: int,
    only_use_neg_projections: bool,
    return_new_loss: bool,
):
    '''
    Hook on the `result` activation of the layer containing the head being studied. It never modifies `result`.

    It pops the value vectors `v` stashed in this hook's context by `hook_fn_cache_v` and maps them through the
    head's `W_O`. The result is `result_pre_attn` (batch, seqK, d_model): what each source position would write
    if it received 100% attention. This is then compared with the unembedding of the token at that source position.

    If `return_new_loss` is True, we compute the projected result (using the cached v and pattern), so that the
    scale hook can later calculate the change in logits and hence the change in loss. We store the original
    per-head result ("result_orig") and the attention-weighted projection onto the source tokens' unembeddings
    ("result_new"), both (batch, seqQ, d_model), in this hook's context.

    If `return_new_loss` is False, we don't need to project results in the unembedding direction; we only need the
    coeffs of that projection (here, the dot product with the source token's unembedding). We store these as
    "unembedding_coeffs" with shape (batch, seqK). We can't finish here, because ln_scale still has to be applied,
    so the cached pattern is deliberately left in the context for `hook_fn_at_scale` to pop.

    If `only_use_neg_projections` is True, source positions whose coefficient is non-negative are zeroed out, so only
    contributions that push *down* the source token's logit are kept.
    '''
    v = hook.ctx.pop("v")

    # (batch, seqK, d_model): unembedding vector of the token at each source position
    unembeddings = model.W_U.T[toks]
    result_pre_attn = einops.einsum(
        v, model.W_O[hook.layer(), head_idx],
        "batch seqK d_head, d_head d_model -> batch seqK d_model"
    )

    if return_new_loss:
        pattern = hook.ctx.pop("pattern")
        # The perpendicular component returned by `project` is not needed here
        result_pre_attn_projected, _, unembedding_coeffs = project(result_pre_attn, unembeddings, return_type="both")
        if only_use_neg_projections:
            result_pre_attn_projected *= (unembedding_coeffs < 0).float()
        # Move the projected vectors from source to destination positions according to the attention pattern
        result_post_attn_projected = einops.einsum(
            result_pre_attn_projected, pattern,
            "batch seqK d_model, batch seqQ seqK -> batch seqQ d_model"
        )
        hook.ctx["result_orig"] = result[:, :, head_idx]
        hook.ctx["result_new"] = result_post_attn_projected

    else:
        # Direct effect that the vector written from each source position has on that source token's own logit
        # (before dividing by ln_scale, which happens in the scale hook)
        unembedding_coeffs = einops.einsum(result_pre_attn, unembeddings, "batch seqK d_model, batch seqK d_model -> batch seqK")
        if only_use_neg_projections:
            unembedding_coeffs *= (unembedding_coeffs < 0).float()
        # No `.squeeze()`: the einsum already returns (batch, seqK). Squeezing would drop the batch (or seq)
        # dimension whenever it has size 1 and make this shape check fail.
        assert unembedding_coeffs.shape == toks.shape
        hook.ctx["unembedding_coeffs"] = unembedding_coeffs


def _logit_changes_from_projection(
    result_orig: Float[Tensor, "batch seqQ d_model"],
    result_new: Float[Tensor, "batch seqQ d_model"],
    ln_scale: Float[Tensor, "batch seqQ 1"],
    W_U: Float[Tensor, "d_model d_vocab"],
) -> dict:
    '''
    Compute direct-path logit changes, each (batch, seqQ, d_vocab), from three edits of one head's output.

    The three edits are: keep only the projection, keep only the perpendicular part, and remove the head.
    '''
    result_orig = result_orig / ln_scale
    result_new = result_new / ln_scale
    # Replace the head's output by its projection: change = (new - orig) @ W_U
    logit_change_for_projection = einops.einsum(result_new - result_orig, W_U, "batch seqQ d_model, d_model d_vocab -> batch seqQ d_vocab")
    # Remove the head's output entirely: change = -orig @ W_U
    logit_change_for_ablation = einops.einsum(- result_orig, W_U, "batch seqQ d_model, d_model d_vocab -> batch seqQ d_vocab")
    return {
        "logit_change_for_proj": logit_change_for_projection,
        # Keep only the perpendicular part (orig - new): change = ablation - projection = -new @ W_U
        "logit_change_for_perp": logit_change_for_ablation - logit_change_for_projection,
        "logit_change_for_ablation": logit_change_for_ablation,
    }


def _source_suppression_scores(
    unembedding_coeffs: Float[Tensor, "batch seqK"],
    pattern: Float[Tensor, "batch seqQ seqK"],
    ln_scale: Float[Tensor, "batch seqQ 1"],
    toks: Int[Tensor, "batch seq"],
    bos_token_id: int,
    exclude_self_attn: bool,
) -> dict:
    '''
    Compute attention-weighted scores of how much one head's output pushes down the logits of the tokens it attends to.

    Let A[b, q, k] be the head's attention from destination q to source k, with every entry that involves BOS
    (as source or destination) zeroed. If `exclude_self_attn` is True, the diagonal is zeroed too. Then:
      "per_position" (batch, seqK): sum_q coeffs[b, k] * A[b, q, k] / ln_scale[b, q]
      "per_sequence" (batch,):      average of "per_position" over k, weighted by the attention k receives,
                                    i.e. by sum_q A[b, q, k]
    With a positive ln_scale, negative values mean the head suppresses the source token's logit.

    NOTE(review): "per_sequence" weights by attention twice (once inside "per_position", once in the average).
    This is kept exactly as before; confirm it is intended.
    A sequence with no non-BOS attention at all gives 0/0 = NaN for "per_sequence".
    '''
    seq_len = toks.shape[-1]

    # (batch, seqQ, seqK) mask: 1 where neither the destination nor the source token is BOS
    is_not_bos_2d_mask = (toks != bos_token_id).float()
    full_mask = einops.einsum(is_not_bos_2d_mask, is_not_bos_2d_mask, "batch seqQ, batch seqK -> batch seqQ seqK")
    if exclude_self_attn:
        # A token attending to itself is unlikely to be copy suppression (the word would come straight after itself).
        # The (seqQ, seqK) mask broadcasts over the batch dimension.
        full_mask = full_mask * (1.0 - t.eye(seq_len, device=full_mask.device))

    non_bos_attn = pattern * full_mask
    non_bos_attn_per_src = einops.reduce(non_bos_attn, "batch seqQ seqK -> batch seqK", "sum")

    # ln_scale is (batch, seqQ, 1), so each term is divided by the scale at its destination position
    unembedding_coeffs_per_src_and_dest = einops.einsum(
        unembedding_coeffs, non_bos_attn,
        "batch seqK, batch seqQ seqK -> batch seqQ seqK"
    ) / ln_scale

    unembedding_coeffs_per_src = einops.einsum(
        unembedding_coeffs_per_src_and_dest,
        "batch seqQ seqK -> batch seqK",
    )

    unembedding_coeffs_weighted_avg_by_attn_to_src = einops.einsum(
        unembedding_coeffs_per_src, non_bos_attn_per_src,
        "batch seqK, batch seqK -> batch"
    ) / non_bos_attn_per_src.sum(dim=-1)

    return {
        "per_position": unembedding_coeffs_per_src,
        "per_sequence": unembedding_coeffs_weighted_avg_by_attn_to_src,
    }


def hook_fn_at_scale(
    ln_scale: Float[Tensor, "batch seq 1"],
    hook: HookPoint,
    toks: Int[Tensor, "batch seq"],
    layer: int,
    model: HookedTransformer,
    exclude_self_attn: bool = False,
):
    '''
    Hook on the `utils.get_act_name("scale")` activation (presumably the final layer norm's scale).

    It runs after `hook_fn_at_result` has stashed its outputs in the context of layer `layer`'s `result` hook.
    It never modifies `ln_scale`. It handles both modes of `compute_weighted_avg_logit_suppression`:

    * return_new_loss=True ("result_orig" is present in the result hook's context): divide the original and
      projected head outputs by ln_scale and map them through W_U. The resulting logit changes
      "logit_change_for_proj", "logit_change_for_perp" and "logit_change_for_ablation", each
      (batch, seqQ, d_vocab), are stored in the result hook's context.
    * return_new_loss=False: turn the per-source unembedding coefficients and the cached pattern into suppression
      scores (see `_source_suppression_scores`). They are stored in *this* hook's context under "unembedding_coeffs".

    `exclude_self_attn`: if True, also ignore attention from a position to itself (BOS is always ignored).
    '''
    result_hook = model.hook_dict[utils.get_act_name("result", layer)]

    if "result_orig" in result_hook.ctx:
        result_orig = result_hook.ctx.pop("result_orig")
        result_new = result_hook.ctx.pop("result_new")
        result_hook.ctx.update(_logit_changes_from_projection(result_orig, result_new, ln_scale, model.W_U))
    else:
        unembedding_coeffs = result_hook.ctx.pop("unembedding_coeffs")
        pattern = result_hook.ctx.pop("pattern")
        hook.ctx["unembedding_coeffs"] = _source_suppression_scores(
            unembedding_coeffs,
            pattern,
            ln_scale,
            toks,
            model.tokenizer.bos_token_id,
            exclude_self_attn,
        )



def compute_weighted_avg_logit_suppression(
    model: HookedTransformer,
    toks: Int[Tensor, "batch seq"],
    head: Tuple[int, int],
    only_use_neg_projections: bool = False,
    return_new_loss: bool = False,
):
    '''
    Run one hooked forward pass measuring how much `head` = (layer, head_idx) suppresses the logits of the tokens
    it attends to. The hooks only record things in hook contexts; the forward pass itself is unchanged.

    Returns:
        If `return_new_loss`: (loss_proj, loss_perp, loss_ablation). These are per-token losses from
            `utils.lm_cross_entropy_loss(..., per_token=True)` after, respectively: keeping only the head's
            projection onto the source tokens' unembeddings; keeping only the perpendicular part; and removing
            the head's output entirely. Each edit is applied only through the direct path: a logit change is
            added to the clean logits, and nothing downstream is recomputed.
        Otherwise: the dict stored by `hook_fn_at_scale`, with "per_position" (batch, seqK) and
            "per_sequence" (batch,) suppression scores.

    Side effects: removes every hook on `model`, including permanent ones, before running. Hook contexts are
    cleared on exit, even if an error occurs.
    '''
    layer, head_idx = head

    model.reset_hooks(including_permanent=True)
    model.clear_contexts()

    fwd_hooks = [
        (utils.get_act_name("v", layer), partial(hook_fn_cache_v, model=model, head_idx=head_idx)),
        (utils.get_act_name("pattern", layer), partial(hook_fn_cache_attn, model=model, head_idx=head_idx)),
        (utils.get_act_name("result", layer), partial(hook_fn_at_result, toks=toks, model=model, head_idx=head_idx, only_use_neg_projections=only_use_neg_projections, return_new_loss=return_new_loss)),
        (utils.get_act_name("scale"), partial(hook_fn_at_scale, toks=toks, model=model, layer=layer)),
    ]

    try:
        # Forward pass which stores things in hook contexts (but doesn't actually change the logits)
        logits = model.run_with_hooks(
            toks,
            return_type = "logits",
            fwd_hooks = fwd_hooks
        )

        if not return_new_loss:
            scale_hook = model.hook_dict[utils.get_act_name("scale")]
            return scale_hook.ctx.pop("unembedding_coeffs")

        result_hook = model.hook_dict[utils.get_act_name("result", layer)]
        # Patch each intervention's logit change into the clean logits, then score it
        return tuple(
            utils.lm_cross_entropy_loss(logits + result_hook.ctx.pop(key), toks, per_token=True)
            for key in ("logit_change_for_proj", "logit_change_for_perp", "logit_change_for_ablation")
        )
    finally:
        # The contexts can hold large cached tensors: always drop them, even if the forward pass or a lookup raises
        model.clear_contexts()

### Test my function

I'm going to take a batch of standard sentences from the IOI distribution and run my function on them. I should find that:

* For head `10.7`, the amount of suppression at the source position (`seqK`) of the IO token is about the same as the head's direct output projected in the IO direction (i.e. very negative).
* For head `9.9`, for similar reasons, it should be very positive.

In [44]:
batch_size = 10
seed = 0
ioi_dataset, ioi_cache = generate_data_and_caches(batch_size, model=model, seed=seed, only_ioi=True, prepend_bos=True, symmetric=True)

def investigate_a_few_heads(
    toks: Int[Tensor, "batch seq"],
    head_list = ((9, 9), (10, 7), (11, 10)),
    only_use_neg_projections: bool = False,
):
    '''
    Plot the "per_position" suppression score from `compute_weighted_avg_logit_suppression` for each
    (layer, head_idx) in `head_list`.

    Each head gets one animation frame showing a (batch, seq) heatmap, with each sequence's str tokens written
    on its row (BOS shown as blank). `head_list` defaults to a tuple rather than a list, to avoid a mutable
    default argument.
    '''
    labels = []
    results = []

    for head in head_list:
        unembedding_coeffs = compute_weighted_avg_logit_suppression(
            model=model,
            toks=toks,
            head=head,
            only_use_neg_projections=only_use_neg_projections,
            return_new_loss=False,
        )
        if isinstance(unembedding_coeffs, tuple):
            unembedding_coeffs = unembedding_coeffs[0]

        # (batch, seqK): one score per source position
        results.append(unembedding_coeffs["per_position"])
        labels.append(f"{head[0]}.{head[1]}")

    # One row of token strings per sequence, exactly as wide as the heatmap. Only per-position scores are
    # plotted, so there is no extra summary column to label.
    text = [
        ["" if str_tok == "<|endoftext|>" else str_tok for str_tok in model.to_str_tokens(seq_toks)]
        for seq_toks in toks
    ]

    results = t.stack(results)
    # Symmetric colour scale centred on zero
    zmax = results.abs().max().item()

    fig = imshow(
        results,
        animation_frame = 0,
        animation_labels = labels,
        height = 800,
        width = 1600,
        return_fig = True,
        zmax = zmax,
        zmin = -zmax,
    )

    for trace in fig.data:
        trace.update(
            text=text,
            texttemplate="%{text}",
            textfont={"size": 12}
        )

    fig.show()


investigate_a_few_heads(ioi_dataset.toks, only_use_neg_projections=False)

In [25]:
# Sanity check on a single IOI prompt: print the model's top next-token predictions, with " Sarah" (the indirect object) as the answer.
utils.test_prompt("Then, Sarah and Arthur had a lot of fun at the hospital. Arthur gave a bone to", " Sarah", model)

Tokenized prompt: ['<|endoftext|>', 'Then', ',', ' Sarah', ' and', ' Arthur', ' had', ' a', ' lot', ' of', ' fun', ' at', ' the', ' hospital', '.', ' Arthur', ' gave', ' a', ' bone', ' to']
Tokenized answer: [' Sarah']


Top 0th token. Logit: 16.50 Prob: 46.79% Token: | Sarah|
Top 1th token. Logit: 15.31 Prob: 14.15% Token: | Arthur|
Top 2th token. Logit: 14.42 Prob:  5.85% Token: | the|
Top 3th token. Logit: 13.46 Prob:  2.24% Token: | her|
Top 4th token. Logit: 13.09 Prob:  1.54% Token: | his|
Top 5th token. Logit: 12.59 Prob:  0.94% Token: | a|
Top 6th token. Logit: 12.52 Prob:  0.87% Token: | Mary|
Top 7th token. Logit: 12.47 Prob:  0.83% Token: | their|
Top 8th token. Logit: 11.71 Prob:  0.39% Token: | one|
Top 9th token. Logit: 11.69 Prob:  0.38% Token: | them|


In [7]:
batch_idx = 0

# Attention patterns of heads 10.7, 9.9 and 11.10 on the first IOI prompt.
# The stacked patterns must appear in the same order as `attention_head_names`.
cv.attention.attention_patterns(
    tokens = model.to_str_tokens(ioi_dataset.toks[batch_idx]),
    attention = t.stack([
        ioi_cache["pattern", 10][batch_idx, 7],
        ioi_cache["pattern", 9][batch_idx, 9],
        ioi_cache["pattern", 11][batch_idx, 10],  # head 11.10 (index 8 would be head 11.8, mismatching the label)
    ]),
    attention_head_names = ["10.7", "9.9", "11.10"],
)

In [47]:
def create_really_hacky_first_pass_metric(
    batch_size: int = 20,
    seed: int = 0,
):
    '''
    Score every head on `batch_size` IOI prompts.

    The score is the mean over prompts of the "per_sequence" suppression score from
    `compute_weighted_avg_logit_suppression`; negative means the head suppresses the tokens it attends to.

    Returns an (n_layers, n_heads) tensor.
    '''
    ioi_dataset = generate_data_and_caches(batch_size, model=model, seed=seed, only_ioi=True, prepend_bos=True, return_cache=False)

    # Use the model's own dimensions instead of hard-coding GPT-2 small's 12 x 12
    n_layers, n_heads = model.cfg.n_layers, model.cfg.n_heads
    unembedding_coeffs_all = t.zeros(n_layers, n_heads)

    for layer, head_idx in tqdm(list(itertools.product(range(n_layers), range(n_heads)))):
        unembedding_coeffs = compute_weighted_avg_logit_suppression(
            model=model,
            toks=ioi_dataset.toks,
            head=(layer, head_idx),
            only_use_neg_projections=False,
            return_new_loss=False,
        )
        if isinstance(unembedding_coeffs, tuple):
            unembedding_coeffs = unembedding_coeffs[0]

        unembedding_coeffs_all[layer, head_idx] = unembedding_coeffs["per_sequence"].mean()

    return unembedding_coeffs_all

unembedding_coeffs_first_pass = create_really_hacky_first_pass_metric()

  0%|          | 0/144 [00:00<?, ?it/s]

In [48]:
# Per-head IOI source-suppression scores, keeping only the negative (suppressing) values.
imshow(
    unembedding_coeffs_first_pass * (unembedding_coeffs_first_pass < 0),
    width=600,
    height=600,
    title="Source suppression scores",
)
# TODO - these are not sparse because the metric doesn't more highly weight the values that are more paid attention to.
# Maybe crop to all situations where >10% attn is paid? Is this a reasonable thing to do?
# Open question: why are there so many false negatives, i.e. why do so many heads have nontrivially large dot products,
# when cosine sims are about 0 for random vectors?

# OpenWebText

In [49]:
# Load OpenWebText samples; later cells index into `data` to get batches of text.
# NOTE(review): `seed` presumably controls which samples are drawn — confirm in get_webtext.
data = get_webtext(seed=6)
clear_output()

In [50]:
def create_really_hacky_second_pass_metric(
    batch_size: int = 60,
    seq_len: int = 100,
):
    '''
    Score every head in the same way as the IOI first-pass metric, but on OpenWebText.

    Uses the first `batch_size` samples of the global `data`, truncated to `seq_len` tokens. Each head's score is
    the mean over sequences of the "per_sequence" suppression score from `compute_weighted_avg_logit_suppression`.

    Returns an (n_layers, n_heads) tensor.
    '''
    toks = model.to_tokens(data[:batch_size])[:, :seq_len]

    # Use the model's own dimensions instead of hard-coding GPT-2 small's 12 x 12
    n_layers, n_heads = model.cfg.n_layers, model.cfg.n_heads
    unembedding_coeffs_all = t.zeros(n_layers, n_heads)

    for layer, head_idx in tqdm(list(itertools.product(range(n_layers), range(n_heads)))):
        unembedding_coeffs = compute_weighted_avg_logit_suppression(
            model=model,
            toks=toks,
            head=(layer, head_idx),
            only_use_neg_projections=False,
            return_new_loss=False,
        )
        # Same defensive tuple handling as the IOI first-pass metric, so both metrics accept the same return shapes
        if isinstance(unembedding_coeffs, tuple):
            unembedding_coeffs = unembedding_coeffs[0]

        unembedding_coeffs_all[layer, head_idx] = unembedding_coeffs["per_sequence"].mean()

    return unembedding_coeffs_all

unembedding_coeffs_second_pass = create_really_hacky_second_pass_metric()
# unembedding_coeffs_second_pass_smaller_seqlen = create_really_hacky_second_pass_metric(seq_len=20)

  0%|          | 0/144 [00:00<?, ?it/s]

In [51]:
imshow(
    t.stack((
        unembedding_coeffs_first_pass * (unembedding_coeffs_first_pass < 0),
        unembedding_coeffs_second_pass * (unembedding_coeffs_second_pass < 0),
    )),
    title="Average direct effect of source token on prediction for source logit at destination position",
    facet_col=0,
    facet_labels=["IOI", "Webtext"],
    text_auto=".2f",
    width=1200,
    height=700,
    static=False,
)

# ! This is "how much, on average, does the attention head's output suppress the logits of the source tokens
# (weighted by how much the source tokens are attended to)", with BOS ignored as either source or destination.
# NOTE(review): with the default settings, `hook_fn_at_scale` masks out only BOS, not self-attention.

# It's very large and negative on IOI: basically all attention which isn't to BOS or to the token itself goes to the
# IO token, which the head then suppresses. It's also noticeably negative on Webtext, and still sparse across heads.
# Key question - why isn't it sparse if you don't give more weight to the things which are attended to more?
# Answer - most tokens in sequences are generic, like "the" and ",", so this doesn't really work for them.

In [None]:
# NOTE(review): this cell repeats the IOI-vs-Webtext comparison plot from the cell above.
imshow(
    t.stack([
        scores * (scores < 0)
        for scores in (unembedding_coeffs_first_pass, unembedding_coeffs_second_pass)
    ]),
    static=False,
    text_auto=".2f",
    title="Average direct effect of source token on prediction for source logit at destination position",
    width=1200,
    height=700,
    facet_labels=["IOI", "Webtext"],
    facet_col=0,
)

# ! Metric: average suppression, by the head's output, of the logits of the source tokens it attends to
# (weighted by attention), with BOS ignored. Self-attention is not masked out with the default settings.

# On IOI it is very large and negative, since nearly all non-BOS, non-self attention goes to the IO token,
# which the head suppresses. On Webtext it is noticeably negative too, and still sparse across heads.
# Why isn't it sparse without up-weighting heavily-attended tokens? Because most tokens in sequences
# are generic, like "the" and ",", so this doesn't really work for them.

# Change in loss?

In [52]:
def compute_loss_diffs_from_projection(
    num_batches: int = 5,
    batch_size: int = 30,
    seq_len: int = 50,
):
    '''
    For every head, compute per-token loss on the global `data` after three direct-path edits of that head's output.

    The edits are: keep only the projection onto the source tokens' unembeddings, keep only the perpendicular part,
    and ablate entirely. They come from `compute_weighted_avg_logit_suppression(..., return_new_loss=True)`.

    Returns:
        orig_loss_all: (num_batches, batch_size, seq_len-1), the unmodified per-token loss.
        head_removal_loss_all: (n_layers, n_heads, 3, num_batches, batch_size, seq_len-1); axis 2 is
            (projection only, perpendicular only, full ablation).

    NOTE(review): assumes every batch tokenizes to at least `seq_len` tokens, otherwise the shapes won't match.
    '''
    t.cuda.empty_cache()
    n_layers, n_heads = model.cfg.n_layers, model.cfg.n_heads
    orig_loss_all = t.zeros(num_batches, batch_size, seq_len-1)
    head_removal_loss_all = t.zeros(n_layers, n_heads, 3, num_batches, batch_size, seq_len-1)

    # The context manager closes the progress bar even if a forward pass raises (e.g. out of memory)
    with tqdm(total=num_batches*n_layers*n_heads) as progress_bar:
        for batch_idx in range(num_batches):
            toks = model.to_tokens(data[batch_idx*batch_size: (batch_idx+1)*batch_size])[:, :seq_len]

            orig_loss = model(toks, return_type="loss", loss_per_token=True)
            orig_loss_all[batch_idx] = orig_loss

            for layer in range(n_layers):
                for head_idx in range(n_heads):
                    loss_proj, loss_perp, loss_ablation = compute_weighted_avg_logit_suppression(
                        model = model,
                        toks = toks,
                        head = (layer, head_idx),
                        only_use_neg_projections = False,
                        return_new_loss = True,
                    )
                    head_removal_loss_all[layer, head_idx, :, batch_idx] = t.stack([loss_proj, loss_perp, loss_ablation])
                    progress_bar.update(1)

    return orig_loss_all, head_removal_loss_all

In [53]:
# Small run (1 batch of 10 sequences, 30 tokens each) to keep the sweep over every head cheap.
orig_loss_all, head_removal_loss_all = compute_loss_diffs_from_projection(
    seq_len=30,
    batch_size=10,
    num_batches=1,
)

  0%|          | 0/144 [00:00<?, ?it/s]

In [57]:
# Mean unmodified per-token loss over all sequences and positions (displayed as the cell's output).
orig_loss_all.mean()

tensor(4.3094)

: 

In [56]:
# Mean loss pooled over every head and all three intervention types at once; the per-head,
# per-type breakdown is plotted in a later cell.
head_removal_loss_all.mean()

tensor(4.3109)

In [48]:
# Per-token loss change relative to the unmodified model.
# orig_loss_all (num_batches, batch, seq_pos) broadcasts against the trailing dims of
# head_removal_loss_all (layer, head, type, num_batches, batch, seq_pos).
# This must be `=`, not `:`: `name: expr` is only an annotation and would leave the name unassigned.
proportional_loss_increase = (head_removal_loss_all / orig_loss_all) - 1
absolute_loss_increase = head_removal_loss_all - orig_loss_all

mean_loss_increase = einops.reduce(
    absolute_loss_increase, "layer head type num_batches batch_idx seq_pos -> type layer head", "mean"
)

imshow(
    mean_loss_increase,
    facet_col=0,
    facet_labels=["Only keep projection", "Only keep perpendicular component", "Ablate entirely"],
    text_auto=".2f",
    title="Change in loss from projecting",
    width=1400,
    height=600,
    static=True,
)