# Setup

In [22]:
# Star import of the project's notebook utilities. It presumably supplies most names used
# below without explicit imports (e.g. HookedTransformer, t, utils, tqdm, hist, imshow,
# einops, Path, cv, path_patch, Node, get_webtext, Float, Tensor) — TODO confirm and
# replace with explicit imports.
from transformer_lens.cautils.notebook import *

from transformer_lens.rs.callum.max_activating_exploration import print_best_outputs, find_best_improvements, clear_plots, decompose_attn_scores_full, decompose_attn_scores
from transformer_lens.rs.callum.keys_fixed import plot_contribution_to_attn_scores, create_fucking_massive_plot_1, create_fucking_massive_plot_2

# Hide any output produced by the imports.
clear_output()

In [2]:
# GPT-2 small with the usual TransformerLens weight processing (LayerNorm folded,
# writing weights and unembedding centered).
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    fold_ln=True,
    center_writing_weights=True,
    center_unembed=True,
)
# Expose per-head outputs ("result" activations, which the head-ablation hook below
# targets) and split q/k/v inputs.
model.set_use_attn_result(True)
model.set_use_split_qkv_input(True)

data = get_webtext(seed=6)

clear_output()

In [3]:
# The head under study: 10.7 (layer 10, head 7).
LAYER_IDX, HEAD_IDX = (10, 7)
W_U = model.W_U.clone()
# Hook point holding every head's output at LAYER_IDX.
HEAD_HOOK_NAME = utils.get_act_name("result", LAYER_IDX)

NUM_PROMPTS = 100
BATCH_SIZE = 10

def hook_to_ablate_head(head_output: Float[Tensor, "batch seq_len head_idx d_model"], hook: HookPoint, head = (LAYER_IDX, HEAD_IDX)):
    '''
    Zero-ablate a single attention head's output at a "result" hook point.

    Args:
        head_output: per-head outputs at the hooked layer. At a "result" hook the last
            axis is d_model (each head's output after W_O), not d_head.
        hook: the HookPoint currently running; must be a "result" hook on layer head[0].
        head: (layer, head_index) of the head to ablate; defaults to (LAYER_IDX, HEAD_IDX).

    Returns:
        The same tensor, modified in place so that head head[1] contributes zero.
    '''
    # Guard against attaching this hook to the wrong layer or the wrong activation.
    assert head[0] == hook.layer()
    assert "result" in hook.name
    head_output[:, :, head[1], :] = 0
    return head_output

# Max activating examples for 10.7 (by ablation) 

We want to see where head 10.7 is most useful, i.e. the token positions where ablating it increases the loss the most.

We can see that cross-entropy loss increases by 0.01 on average when this head is ablated. That may not seem like much, but it is not far off the typical effect of ablating other late-layer heads.

In [None]:
str_token_list = []
loss_list = []
ablated_loss_list = []

# No gradients are needed here. Without no_grad (unless grad is already disabled
# globally), every stored loss tensor would keep its autograd graph alive across all
# NUM_PROMPTS prompts.
with t.no_grad():
    for i in tqdm(range(NUM_PROMPTS)):
        new_str = data[i]
        new_str_tokens = model.to_str_tokens(new_str)
        tokens = model.to_tokens(new_str)
        # Per-token loss with the model intact vs. with the studied head zero-ablated.
        loss = model(tokens, return_type="loss", loss_per_token=True)
        ablated_loss = model.run_with_hooks(tokens, return_type="loss", loss_per_token=True, fwd_hooks=[(HEAD_HOOK_NAME, hook_to_ablate_head)])
        loss_list.append(loss)
        ablated_loss_list.append(ablated_loss)
        str_token_list.append(new_str_tokens)


# Concatenate per-prompt losses along the position axis, then drop the batch axis.
all_loss = t.cat(loss_list, dim=-1).squeeze()
all_ablated_loss = t.cat(ablated_loss_list, dim=-1).squeeze()

hist(
    all_ablated_loss - all_loss,
    title="Difference in loss after ablating (positive ⇒ loss increases)",
    labels={"x": "Difference in cross-entropy loss"},
    template="simple_white",
    add_mean_line=True,
    width=1000,
    nbins=200,
    static=True,
)

In [5]:
# Cut-off size: 1% of all token positions across the prompts.
total_num_tokens = sum(map(len, str_token_list))
top_pct = int(0.01 * total_num_tokens)

# Both tails of the loss-change distribution; `worst=True` presumably selects the
# opposite tail — confirm in find_best_improvements.
best_k_indices, best_k_loss_decrease = find_best_improvements(str_token_list, loss_list, ablated_loss_list, k=top_pct)
worst_k_indices, worst_k_loss_decrease = find_best_improvements(str_token_list, loss_list, ablated_loss_list, k=top_pct, worst=True)

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

In [6]:
# Display the first three selected examples (random=False) and return their caches.
# The names_filter keeps only the attention pattern of layer LAYER_IDX in each cache.
n_best_examples = 3
pattern_hook_name = utils.get_act_name("pattern", LAYER_IDX)
caches_and_tokens = print_best_outputs(
    best_k_indices[:n_best_examples],
    best_k_loss_decrease[:n_best_examples],
    hook=(HEAD_HOOK_NAME, hook_to_ablate_head),
    model=model,
    data=data,
    n=n_best_examples,
    random=False,
    return_caches=True,
    names_filter=lambda name: name == pattern_hook_name,
)

In [8]:
# NOTE(review): absolute, machine-specific path. It presumably has to match the directory
# that `clear_plots` empties — confirm before making it configurable.
p = Path("/home/ubuntu/Transformerlens/transformer_lens/rs/callum/plots")
# Make sure the output directory exists before cleaning / writing into it.
p.mkdir(parents=True, exist_ok=True)

clear_plots()

# Only visualise the last `window` token positions of each example.
window = 100

for i, (cache, tokens) in enumerate(caches_and_tokens):

    # Attention pattern of the studied head. This assumes the cached pattern is
    # (batch, head, dest, src), so the result is (batch, dest, src) — TODO confirm.
    pattern = cache["pattern", LAYER_IDX][:, HEAD_IDX]
    pattern_sliced = pattern[:, -window:, -window:]
    tokens_sliced = tokens[-window:]
    html = cv.attention.attention_heads(
        attention = pattern_sliced,
        tokens = tokens_sliced,
        attention_head_names = [f"{LAYER_IDX}.{HEAD_IDX}, example {i}"]
    )

    # Write as UTF-8 explicitly: token strings may contain non-ASCII characters, which
    # would fail to encode under a non-UTF-8 platform default encoding.
    with open(p / f"temp_file_{i}.html", "w", encoding="utf-8") as f:
        f.write(str(html))

    print("".join(tokens_sliced))
    print("\n" + "=" * 60 + "\n")
    print("\n" + "=" * 60 + "\n")

 say 1 / 1 Back to Gallery

Police in Fairfield busted a teenager who was in possession of a half pound of methamphetamine, two guns, an ounce of marijuana and a wad of cash during a probation search Saturday evening, authorities said.

The raid happened around 5:45 p.m. at a home on the 800 block of Fifth Street, where officers discovered the teen and the trove of contraband, police said.

The teenager, a 17-year-old


 reduce stresses as much as possible so they can produce as much offspring as they can and they are more able to deal with the adverse effects of climate change.��

Advocates say more must be protected

The most effective way to protect species would be to create protected areas far from human activity, but in practice, there would have to be protected areas in a lot of different places so they could provide the broadest possible benefits, Roberts said.

He described the 10 percent


 York City in 1939. He was educated at Grinnell College (B.A., 1960) and Harvard Univer

In [9]:
# Remove the temporary plot files (presumably the temp_file_*.html written earlier).
clear_plots()

# Max activating examples for 10.7 (by removing direct effect) 

In [None]:
str_token_list = []
loss_list = []
patched_loss_list = []

# Remove only the *direct* effect of the studied head: zero-patch its z output along paths
# that end at the final residual stream. `direct_includes_mlps=True` presumably counts
# paths through MLPs as part of the direct path — confirm in path_patch.
# no_grad: nothing here needs gradients, and stored losses must not retain graphs.
with t.no_grad():
    for i in tqdm(range(NUM_PROMPTS)):
        new_str = data[i]
        new_str_tokens = model.to_str_tokens(new_str)
        tokens = model.to_tokens(new_str)
        loss = model(tokens, return_type="loss", loss_per_token=True)
        patched_loss = path_patch(
            model=model,
            orig_input=tokens,
            new_cache="zero",
            sender_nodes=Node("z", layer=LAYER_IDX, head=HEAD_IDX),
            receiver_nodes=Node("resid_post", layer=model.cfg.n_layers - 1),
            direct_includes_mlps=True,
            patching_metric="loss_per_token"
        )
        loss_list.append(loss)
        patched_loss_list.append(patched_loss)
        str_token_list.append(new_str_tokens)


# Concatenate per-prompt losses along the position axis, then drop the batch axis.
all_loss = t.cat(loss_list, dim=-1).squeeze()
all_patched_loss = t.cat(patched_loss_list, dim=-1).squeeze()

In [10]:
# Per-token change in loss when only the head's direct effect is removed (path patching),
# as opposed to the full ablation plotted earlier.
hist(
    all_patched_loss - all_loss,
    title="Difference in loss after removing direct effect (positive ⇒ loss increases)",
    labels={"x": "Difference in cross-entropy loss"},
    template="simple_white",
    add_mean_line=True,
    width=800,
    nbins=200,
    static=True,
)

In [22]:
# Cut-off sizes in token positions. top_pct is kept for reference; the selection below
# uses the stricter top-0.5% cut.
total_num_tokens = sum(map(len, str_token_list))
top_pct = int(0.01 * total_num_tokens)
top_half_pct = int(0.005 * total_num_tokens)

selection_args = (str_token_list, loss_list, patched_loss_list)
best_k_indices, best_k_loss_decrease = find_best_improvements(*selection_args, k=top_half_pct)
worst_k_indices, worst_k_loss_decrease = find_best_improvements(*selection_args, k=top_half_pct, worst=True)
clear_output()

In [23]:
# Five examples sampled at random (seed 0) from the selected top-0.5% positions.
# NOTE(review): the hook passed here zero-ablates the whole head (total effect), whereas
# these positions were selected by removing only its direct effect — confirm this is the
# intended comparison inside print_best_outputs.
caches_and_tokens = print_best_outputs(
    best_k_indices,
    best_k_loss_decrease,
    model = model,
    data = data,
    hook = (HEAD_HOOK_NAME, hook_to_ablate_head),
    n = 5,
    random = True,
    seed = 0,
    return_caches = False,
)

# All's fair in love and war

First, verify that the model assigns `" love"` some probability. It does: `" love"` is the 4th most likely next token (3.05%).

In [4]:
# Start from a clean model, then check the idiom's completion vs. the copied word.
model.reset_hooks()
str_input = "All's fair in love and"
answer, incorrect = " war", " love"
utils.test_prompt(str_input, answer, model)



Tokenized prompt: ['<|endoftext|>', 'All', "'s", ' fair', ' in', ' love', ' and']
Tokenized answer: [' war']


Top 0th token. Logit: 14.31 Prob: 12.80% Token: | war|
Top 1th token. Logit: 13.79 Prob:  7.59% Token: | hate|
Top 2th token. Logit: 12.95 Prob:  3.29% Token: | all|
Top 3th token. Logit: 12.87 Prob:  3.05% Token: | love|
Top 4th token. Logit: 12.82 Prob:  2.90% Token: | peace|
Top 5th token. Logit: 12.18 Prob:  1.52% Token: | good|
Top 6th token. Logit: 12.03 Prob:  1.31% Token: | in|
Top 7th token. Logit: 11.93 Prob:  1.19% Token: | death|
Top 8th token. Logit: 11.83 Prob:  1.08% Token: | friendship|
Top 9th token. Logit: 11.75 Prob:  1.00% Token: | happiness|


Now, I want to see what the direct effect of head 10.7 is on the logits.

In [5]:
toks = model.to_tokens(str_input)

model.reset_hooks()
logits, cache = model.run_with_cache(toks, return_type="logits")
logits = logits[0, -1]

# Direct effect of the studied head at the final position: map its "result" output straight
# through the unembedding. No final-LayerNorm scale is applied, so only the ranking
# (not the magnitudes) of these values is comparable to the model's real logits.
neg_head_output = cache["result", LAYER_IDX][0, -1, HEAD_IDX]
neg_head_logits = neg_head_output @ model.W_U
assert neg_head_logits.shape == (model.cfg.d_vocab,)
neg_head_logprobs = neg_head_logits.log_softmax(dim=-1)

# The five tokens the head pushes down the most (smallest log-probs).
top5 = neg_head_logprobs.topk(5, largest=False)

for index, value in zip(top5.indices, top5.values):
    token = model.to_single_str_token(index.item())
    print(f"|{token}| = {value:.2f}")

|love| = -51.87
| Love| = -50.06
|Love| = -49.12
| LOVE| = -46.14
| love| = -46.00


Amazing! Does the model actually predict `" love"` more strongly when head 10.7 is ablated?

In [6]:
# Re-run the prompt with the studied head zero-ablated. The hook is removed afterwards
# (even if test_prompt raises), so the ablation does not silently leak into later cells.
model.add_hook(HEAD_HOOK_NAME, hook_to_ablate_head)
try:
    utils.test_prompt(str_input, answer, model)
finally:
    model.reset_hooks()

Tokenized prompt: ['<|endoftext|>', 'All', "'s", ' fair', ' in', ' love', ' and']
Tokenized answer: [' war']


Top 0th token. Logit: 14.29 Prob: 10.47% Token: | war|
Top 1th token. Logit: 14.11 Prob:  8.78% Token: | love|
Top 2th token. Logit: 14.00 Prob:  7.82% Token: | hate|
Top 3th token. Logit: 13.17 Prob:  3.41% Token: | all|
Top 4th token. Logit: 12.74 Prob:  2.22% Token: | peace|
Top 5th token. Logit: 12.47 Prob:  1.70% Token: | good|
Top 6th token. Logit: 12.38 Prob:  1.55% Token: | happiness|
Top 7th token. Logit: 12.31 Prob:  1.44% Token: | in|
Top 8th token. Logit: 12.08 Prob:  1.15% Token: | death|
Top 9th token. Logit: 12.05 Prob:  1.12% Token: | friendship|


Yes, it does! `" love"` becomes the second most likely token (8.78%, up from 3.05%).

## Deeper dive into love and war

Now, let's break down this particular example, and see which heads are responsible for this dumb copying behaviour. I suspect it'll be one of the early ones.

I originally wondered whether nothing copies `" love"` at all (it might only score highly because `" love"` and `" war"` have similar embeddings). This is possible but unlikely: `" love"` has a fairly high logit even without any intervention, whereas `" War"` and similar words do not.

### Direct logit attribution - which heads write in the "love" direction?

I'll look at `" love" - " war"` because this'll be easier. I'll crib code from the IOI notebook.

In [7]:
# Token ids of the two candidate answers, shaped (batch=1, 2) after the transpose.
answer_tokens = model.to_tokens([" love", " war"], prepend_bos=False).T
# Each answer must be exactly one token; otherwise padding would silently add rows here.
assert answer_tokens.shape == (1, 2), answer_tokens.shape

answer_residual_directions: Float[Tensor, "batch=1 2 d_model"] = model.tokens_to_residual_directions(answer_tokens)
print("Answer residual directions shape:", answer_residual_directions.shape)

# Separate the " love" and " war" unembedding directions, each (batch=1, d_model).
love_residual_directions, war_residual_directions = answer_residual_directions.unbind(dim=1)
logit_diff_directions: Float[Tensor, "batch=1 d_model"] = love_residual_directions - war_residual_directions
print("Logit difference directions shape:", logit_diff_directions.shape)

Answer residual directions shape: torch.Size([1, 2, 768])
Logit difference directions shape: torch.Size([1, 768])


In [8]:
def residual_stack_to_logit_diff(
    residual_stack: Float[Tensor, "... batch d_model"],
    cache: ActivationCache,
    logit_diff_directions: Float[Tensor, "batch d_model"] = logit_diff_directions,
) -> Float[Tensor, "..."]:
    '''
    Project each component of a residual-stream stack onto a direction, averaged over batch.

    The stack is first rescaled with `cache.apply_ln_to_stack(layer=-1, pos_slice=-1)`
    (presumably the LayerNorm feeding the unembedding). Each batch element is then dotted
    with its own direction, and the result is summed over the batch and divided by the
    batch size.

    The default direction is captured when this cell runs (" love" - " war"). Any other
    direction, e.g. a single answer's unembedding, can be passed instead.
    '''
    scaled_stack = cache.apply_ln_to_stack(residual_stack, layer=-1, pos_slice=-1)
    summed_over_batch = einops.einsum(
        scaled_stack,
        logit_diff_directions,
        "... batch d_model, batch d_model -> ...",
    )
    return summed_over_batch / residual_stack.size(-2)

In [9]:
per_head_residual, labels = cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)
# Unflatten the (layer * head) component axis into a (layer, head) grid.
per_head_residual = einops.rearrange(
    per_head_residual,
    "(layer head) ... -> layer head ...",
    layer=model.cfg.n_layers,
)

# Each head's direct contribution to " love", to " war", and to their difference.
per_head_logit_love = residual_stack_to_logit_diff(per_head_residual, cache, logit_diff_directions=love_residual_directions)
per_head_logit_war = residual_stack_to_logit_diff(per_head_residual, cache, logit_diff_directions=war_residual_directions)
per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, cache)

facets = t.stack([per_head_logit_love, per_head_logit_war, per_head_logit_diffs])
imshow(
    facets,
    facet_col=0,
    facet_labels=["' love'", "' war'", "' love' - ' war'"],
    labels={"x": "Head", "y": "Layer"},
    title="Logit Difference From Each Head",
    width=1200,
    static=True,
)

It looks like there are quite a few heads which are writing `' love'` into the residual stream. `8.6` is most notable, but other important ones are `6.1`, `8.8`, `9.2`, `9.3`, `9.6`.

The important thing isn't which heads write to the residual stream though, it's whether `' love'` is present in the residual stream, and if so then by how much (and also what proportion of the `" love"` unembedding is used to construct the query?).

In [46]:
toks = model.to_tokens(str_input)
dest_indices = t.tensor([-1])
src_indices = t.tensor([model.to_str_tokens(str_input).index(" love")]) # = -2

# Cross-prompt baseline (currently unused; kept for the commented-out alternative below).
str_input_baseline = "All's fair in war and"
toks_baseline = model.to_tokens(str_input_baseline)
src_baseline_indices = t.tensor([model.to_str_tokens(str_input_baseline).index(" war")]) # = -2

# Baseline source position within the *same* prompt: the " in" token (position 4).
src_baseline_indices_same_prompt = t.tensor([model.to_str_tokens(str_input).index(" in")])

# Project the query onto / off the W_U[IO] direction and the key onto / off the MLP0
# direction, giving four (query, key) combinations.
contribution_to_attn_scores = decompose_attn_scores(
    toks = toks,
    dest_indices = dest_indices,
    src_indices = src_indices,
    src_baseline_indices = src_baseline_indices_same_prompt, # alternative: src_baseline_indices
    toks_baseline = None, # alternative: toks_baseline
    nnmh = (10, 7),
    model = model,
    decompose_by = "keys",
    intervene_on_query = "project_to_W_U_IO",
    intervene_on_key = "project_to_MLP0",
    use_effective_embedding = False,
    use_layer0_heads = False,
)

plot_contribution_to_attn_scores(
    t.stack([
        contribution_to_attn_scores[('IO_dir', 'MLP0_dir')],
        contribution_to_attn_scores[('IO_dir', 'MLP0_perp')],
        contribution_to_attn_scores[('IO_perp', 'MLP0_dir')],
        contribution_to_attn_scores[('IO_perp', 'MLP0_perp')],
    ]),
    decompose_by = "keys",
    facet_labels = [
        "q ∥ W<sub>U</sub>[IO], k ∥ MLP<sub>0</sub>",
        "q ∥ W<sub>U</sub>[IO], k ⊥ MLP<sub>0</sub>",
        "q ⊥ W<sub>U</sub>[IO], k ∥ MLP<sub>0</sub>",
        "q ⊥ W<sub>U</sub>[IO], k ⊥ MLP<sub>0</sub>"
    ],
    facet_col_wrap = 2,
    title = "Decompose on query-side, split by projections key & query-side"
)

In [49]:
toks = model.to_tokens(str_input)
dest_indices = t.tensor([-1])
src_indices = t.tensor([model.to_str_tokens(str_input).index(" love")]) # = -2

# Cross-prompt baseline (currently unused; kept for reference).
str_input_baseline = "All's fair in war and"
toks_baseline = model.to_tokens(str_input_baseline)
src_baseline_indices = t.tensor([model.to_str_tokens(str_input_baseline).index(" war")]) # = -2

# Baseline source position within the *same* prompt: the " in" token (position 4).
src_baseline_indices_same_prompt = t.tensor([model.to_str_tokens(str_input).index(" in")])

# Only the query side is projected here (onto / off W_U[IO]); keys are left unchanged.
contribution_to_attn_scores = decompose_attn_scores(
    toks = toks,
    dest_indices = dest_indices,
    src_indices = src_indices,
    src_baseline_indices = src_baseline_indices_same_prompt, # alternative: src_baseline_indices
    toks_baseline = None, # alternative: toks_baseline
    nnmh = (10, 7),
    model = model,
    decompose_by = "keys",
    intervene_on_query = "project_to_W_U_IO",
    intervene_on_key = None,
    use_effective_embedding = False,
    use_layer0_heads = False,
)

plot_contribution_to_attn_scores(
    t.stack([
        contribution_to_attn_scores[('IO_dir', 'unchanged')],
        contribution_to_attn_scores[('IO_perp', 'unchanged')],
    ]),
    decompose_by = "keys",
    facet_labels = [
        "q ∥ W<sub>U</sub>[IO]",
        "q ⊥ W<sub>U</sub>[IO]",
    ],
    facet_col_wrap = 2,
    title = "Decompose on query-side, split by query-side projection"
)

In [48]:
# List the decomposition terms returned by the last call; judging by the labels these are
# (query-side, key-side) pairs.
contribution_to_attn_scores.keys()

dict_keys([('IO_dir', 'unchanged'), ('IO_perp', 'unchanged')])

## Dataset


I used GPT to generate more examples with the following prompt:

> Give me 30 examples of short sentences with common "x and y" word pairings, for example:
> 
> All's fair in love and war (love, war)
> Nothing is certain except death and taxes (death, taxes)
> 
> The x and y should come at the end of the sentence.

<details>
<summary>Dataset</summary>

```python
s = r"""We remember him for his kindness and generosity
It's a balance of tradition and innovation
She loved the colors of pink and blue
The main ingredients are flour and sugar
It was a journey of self-discovery and transformation
They chose to adopt a lifestyle of simplicity and minimalism
Their marriage was a combination of trust and respect
The garden is full of birds and butterflies
This meal needs a touch of salt and pepper
The sunset was a blend of orange and purple
The store sells goods new and used
The climate here is both hot and humid
The area is known for its wine and cheese
The theme of the party was black and white
She's a blend of beauty and brains
This painting is a fusion of reality and fantasy
His life was a mixture of pleasure and pain
The landscape was filled with trees and flowers
She was a paradox of innocence and cunning
The play was a mix of tragedy and comedy
The book is a combination of facts and fiction
The sky was filled with stars and moonlight
We traveled by both land and sea
The pie was a combination of sweet and savory
This role requires both skill and dedication
The holiday was filled with rest and relaxation
Their home was full of love and warmth
The music was a blend of rhythm and melody
We saw a combination of elephants and lions
The path was full of twists and turns"""

# Split each line at its last space: the prefix is the prompt and the final word (with a
# leading space, as GPT-2 tokenizes it) is the expected answer.
sentences, answers = [], []
for line in s.split("\n"):
    prompt_part, _sep, last_word = line.rpartition(" ")
    sentences.append(prompt_part)
    answers.append(" " + last_word)
```
</details>