In [1]:
from transformer_lens.cautils.notebook import *



In [2]:
# GPT-2 small, with the weight-processing options below applied at load time.
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    fold_ln=True,
    center_writing_weights=True,
    center_unembed=True,
    refactor_factored_attn_matrices=True,
)
# Separate q/k/v inputs per head; presumably needed so path patching can target a single
# head's query input (receiver Node("q", ...) further down) — confirm against path_patch.
model.set_use_split_qkv_input(True)
# Hide the noisy loading output.
clear_output()

In [8]:
N = 100
# IOI / ABC datasets, their activation caches, and the IOI noising metric.
(
    ioi_dataset,
    abc_dataset,
    ioi_cache,
    abc_cache,
    ioi_metric_noising,
) = generate_data_and_caches(N, model, verbose=True)
seq_len = ioi_dataset.toks.shape[1]

Average logit diff (IOI dataset): 3.0731
Average logit diff (ABC dataset): 0.2349


In [9]:
# Display the tokenizer's (pad, eos, bos) token ids.
tuple(getattr(model.tokenizer, f"{kind}_token_id") for kind in ("pad", "eos", "bos"))

(50256, 50256, 50256)

# Path patching from NMHs to neg NMHs

The theory was that the negative name mover heads (neg NMHs) suppress the IO token only because the IO token is being predicted (thanks to the NMHs). In other words, they aren't acting in parallel to the NMHs while doing the opposite; rather, they take the NMH output into their query vectors, use it to attend to the IO token, and then move "suppress IO prediction" to the END token.

We already know the OV circuit basically does negative copying for names, so the key question is what goes into the query. The plot below shows that path patching from the NMH outputs to the neg NMH query input does significantly affect the attention patterns: it turns them from "END token attends a lot more to IO than to S1" into "END token attends about equally to both".

In [10]:
# (layer, head) pairs used as path-patching receivers / senders below.
NEG_NAME_MOVERS = [
    (10, 7),
    (11, 10),
]
NAME_MOVERS = [
    (9, 6),
    (9, 9),
    (10, 10),
]


def get_io_vs_s_attn_for_nmh(
    patched_cache: ActivationCache,
    orig_dataset: IOIDataset,
    orig_cache: ActivationCache,
    neg_nmh: Tuple[int, int],
) -> Tuple[Float[Tensor, "batch"], Float[Tensor, "batch"]]:
    '''
    Returns patterns[END, IO] - patterns[END, S1] for the head `neg_nmh`, once for the patched
    run and once for the original run.

    Args:
        patched_cache: cache produced by the path patching algorithm.
        orig_dataset: dataset whose `word_idx["IO" | "S1" | "end"]` give per-prompt positions.
        orig_cache: cache of the original (unpatched) forward pass on `orig_dataset`.
        neg_nmh: (layer, head) of the head whose attention pattern is read.

    Returns:
        A tuple (patched_diff, clean_diff) of two tensors with one entry per prompt.
    '''
    layer, head = neg_nmh
    # One row per prompt; named so it doesn't shadow the notebook-level `N`.
    prompt_idx = range(orig_dataset.toks.size(0))
    io_pos = orig_dataset.word_idx["IO"]
    s1_pos = orig_dataset.word_idx["S1"]
    end_pos = orig_dataset.word_idx["end"]

    def io_minus_s1(cache: ActivationCache) -> Float[Tensor, "batch"]:
        # After selecting the head, the pattern is indexed (batch, seq_Q, seq_K); take the
        # END-query row and compare the IO and S1 key columns.
        pattern = cache["pattern", layer][:, head]
        return pattern[prompt_idx, end_pos, io_pos] - pattern[prompt_idx, end_pos, s1_pos]

    return io_minus_s1(patched_cache), io_minus_s1(orig_cache)


def get_nnmh_patching_patterns(num_batches=40, neg_nmh=NEG_NAME_MOVERS[0], orig_is_ioi=True, batch_size=20):
    '''
    Path-patches from the output (`z`) of every head in NAME_MOVERS to the query input of
    `neg_nmh`, and measures END->IO minus END->S1 attention for that head.

    Args:
        num_batches: number of freshly generated datasets, seeded 0..num_batches-1.
        neg_nmh: (layer, head) of the receiver head. Despite the name, any head can be passed
            (e.g. a backup name mover).
        orig_is_ioi: if True, the original run is the IOI dataset and activations are patched in
            from the ABC dataset; if False, the two roles are swapped.
        batch_size: number of prompts per generated dataset (previously hard-coded to 20).

    Returns:
        (results_patched, results_clean): 1D tensors concatenated over all batches, holding the
        attention difference under the patched run and under the original run respectively.
    '''
    receiver_layer, receiver_head = neg_nmh
    patched_chunks = []
    clean_chunks = []

    for seed in tqdm(range(num_batches)):
        ioi_dataset, abc_dataset, ioi_cache, abc_cache, _ = generate_data_and_caches(batch_size, model=model, seed=seed)

        if orig_is_ioi:
            orig_dataset, new_dataset = ioi_dataset, abc_dataset
            orig_cache, new_cache = ioi_cache, abc_cache
        else:
            orig_dataset, new_dataset = abc_dataset, ioi_dataset
            orig_cache, new_cache = abc_cache, ioi_cache

        batch_patched, batch_clean = path_patch(
            model,
            orig_input=orig_dataset.toks,
            new_input=new_dataset.toks,
            orig_cache=orig_cache,
            new_cache=new_cache,
            # Output of all name mover heads
            sender_nodes=[Node("z", layer=sender_layer, head=sender_head) for sender_layer, sender_head in NAME_MOVERS],
            # To the query input of the receiver head
            receiver_nodes=Node("q", layer=receiver_layer, head=receiver_head),
            patching_metric=partial(get_io_vs_s_attn_for_nmh, orig_dataset=orig_dataset, orig_cache=orig_cache, neg_nmh=neg_nmh),
            apply_metric_to_cache=True,
            direct_includes_mlps=not model.cfg.use_split_qkv_input,
        )
        # Collect and concatenate once at the end (concatenating inside the loop is quadratic).
        patched_chunks.append(batch_patched)
        clean_chunks.append(batch_clean)

        t.cuda.empty_cache()

    if not patched_chunks:
        # num_batches == 0: t.cat would fail on an empty list.
        return t.empty(size=(0,)).to(device), t.empty(size=(0,)).to(device)
    return t.cat(patched_chunks), t.cat(clean_chunks)

In [15]:
# First negative name mover head, 10.7 (kept in `neg_nmh` for the plot title below).
neg_nmh = NEG_NAME_MOVERS[0]
results_patched, results_clean = get_nnmh_patching_patterns(neg_nmh=neg_nmh)

  0%|          | 0/40 [00:00<?, ?it/s]

In [18]:
fig = hist(
    [results_patched, results_clean],
    labels={"variable": "", "value": "Attention paid to IO - Attention paid to S1"},
    title=f"Change in Attention of {neg_nmh}",
    names=["Patched Name Mover Heads", "Normal forward pass"],
    width=800,
    height=600,
    opacity=0.7,
    template="simple_white",
    return_fig=True,
)

# Dashed reference line at value 0: to its right the head attends more to IO than to S1, to its
# left more to S1. The value is on the histogram's x-axis, so the line must be vertical. A
# horizontal line at y=0 would sit on the count baseline and be invisible. Assumes `hist` puts
# values on x (plotly-express default).
fig.add_vline(x=0, line_color="black", line_width=3, line_dash="dash")
fig

In [7]:
# Same experiment for the second negative name mover head, 11.10.
results_patched, results_clean = get_nnmh_patching_patterns(neg_nmh=NEG_NAME_MOVERS[1])

hist(
    [results_patched, results_clean],
    names=["Patched (ABC)", "Clean (IOI)"],
    title="Difference in attn from END➔IO vs. END➔S1 (path-patched vs clean)",
    labels={"variable": "Version", "value": "Attn diff (positive ⇒ more attn paid to IO than S1)"},
    template="simple_white",
    opacity=0.7,
    width=800,
    height=600,
)

  0%|          | 0/40 [00:00<?, ?it/s]

KeyboardInterrupt: 

I just realised this should have been obvious from the paper: we can see the neg name mover heads' output becoming much less significant post-ablation. In fact, `10.7` is entirely wiped out.

<img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/negnmh.png" width="600">

# Path patching from NMHs to backup NMHs

What's the equivalent result we might expect when backup name mover heads are considered?

It's the opposite: we expect them not to attend much from END to IO when the name mover heads are working normally, and to kick in only when we path patch from the name movers to the backup name movers.

I'll do the plot for `11.2`, `10.6`, and `10.10`, because these are the biggest backup heads, and I'll try the patching in both directions (from IOI to ABC, and vice versa). (The cells below currently show head `10.2`.)

In [None]:
# Head 10.2: original run on IOI prompts, name-mover outputs patched in from ABC.
# NOTE(review): the markdown above lists 11.2, 10.6 and 10.10; confirm 10.2 is the intended head.
results_patched, results_clean = get_nnmh_patching_patterns(neg_nmh=(10, 2), num_batches=50)

hist(
    [results_patched, results_clean],
    names=["Patched (ABC)", "Orig (IOI)"],
    title="Difference in attn from END➔IO vs. END➔S1 (path-patched vs clean)",
    labels={"variable": "Version", "value": "Attn diff (positive ⇒ more attn paid to IO than S1)"},
    marginal="box",
    template="simple_white",
    opacity=0.7,
    width=800,
    height=600,
)

  0%|          | 0/50 [00:00<?, ?it/s]

In [None]:
# Head 10.2, reversed direction: original run on ABC prompts, name-mover outputs patched in from IOI.
results_patched, results_clean = get_nnmh_patching_patterns(
    neg_nmh=(10, 2),
    orig_is_ioi=False,
    num_batches=50,
)

hist(
    [results_patched, results_clean],
    names=["Patched (IOI)", "Orig (ABC)"],
    title="Difference in attn from END➔IO vs. END➔S1 (path-patched vs clean)",
    labels={"variable": "Version", "value": "Attn diff (positive ⇒ more attn paid to IO than S1)"},
    marginal="box",
    template="simple_white",
    opacity=0.7,
    width=800,
    height=600,
)

  0%|          | 0/50 [00:00<?, ?it/s]

Conclusion: we do somewhat see this (patching increases attention to IO), but the effect is very weak.