In [1]:
from transformer_lens.cautils.notebook import *

  from pandas.core.computation.check import NUMEXPR_INSTALLED


In [2]:
# Load Pythia-410m with LayerNorm folded into the weights and with the unembedding
# and residual-writing weights centred.
model = HookedTransformer.from_pretrained(
    "pythia-410m",
    fold_ln=True,
    center_writing_weights=True,
    center_unembed=True,
    # Optional, currently disabled:
    # refactor_factored_attn_matrices=True,
)
# NOTE(review): presumably exposes separate q/k/v input hook points, which the
# q/k/v patching further down would need -- confirm against the TransformerLens docs.
model.set_use_split_qkv_input(True)
# Remove the loading messages from the cell output.
clear_output()

In [3]:
# Single IOI example used as a smoke test: the model should complete the prompt
# with the indirect object (" Mary").
example_prompt = (
    "After John and Mary went to the store, "
    "John gave a bottle of milk to"
)
example_answer = " Mary"
# Prints the tokenization and the top predicted next tokens for the prompt.
utils.test_prompt(
    example_prompt,
    example_answer,
    model,
    prepend_bos=True,
)

Tokenized prompt: ['<|endoftext|>', 'After', ' John', ' and', ' Mary', ' went', ' to', ' the', ' store', ',', ' John', ' gave', ' a', ' bottle', ' of', ' milk', ' to']
Tokenized answer: [' Mary']


Top 0th token. Logit: 18.29 Prob: 64.54% Token: | Mary|
Top 1th token. Logit: 16.21 Prob:  8.05% Token: | the|
Top 2th token. Logit: 15.58 Prob:  4.29% Token: | his|
Top 3th token. Logit: 15.45 Prob:  3.79% Token: | each|
Top 4th token. Logit: 14.23 Prob:  1.11% Token: | a|
Top 5th token. Logit: 14.14 Prob:  1.02% Token: | her|
Top 6th token. Logit: 13.69 Prob:  0.65% Token: | Mrs|
Top 7th token. Logit: 13.67 Prob:  0.64% Token: | John|
Top 8th token. Logit: 13.16 Prob:  0.38% Token: | their|
Top 9th token. Logit: 12.78 Prob:  0.26% Token: | one|


## Neel's setup

In [4]:
# Four IOI templates; "{}" is the slot for the subject of the second clause.
prompt_format = [
    "When John and Mary went to the shops,{} gave the bag to",
    "When Tom and James went to the park,{} gave the ball to",
    "When Dan and Sid went to the shops,{} gave an apple to",
    "After Martin and Amy went to the park,{} gave a drink to",
]
# One pair of names per template. Each name keeps its leading space so it can be
# dropped straight into the "{}" slot.
name_pairs = [
    (" Mary", " John"),
    (" Tom", " James"),
    (" Dan", " Sid"),
    (" Martin", " Amy"),
]

# 8 prompts, two per template: the subject slot is filled with the second name of
# the pair first, then with the first name, so adjacent prompts have swapped answers.
prompts = [
    template.format(subject)
    for template, (first_name, second_name) in zip(prompt_format, name_pairs)
    for subject in (second_name, first_name)
]
# (correct, incorrect) answer for every prompt, in the same order as `prompts`.
answers = [
    ordered_pair
    for first_name, second_name in name_pairs
    for ordered_pair in ((first_name, second_name), (second_name, first_name))
]
# Token ids of the answers, one (correct, incorrect) row per prompt.
answer_tokens = t.concat([
    model.to_tokens(pair, prepend_bos=False).T for pair in answers
])

# Tokenize the prompts (with a BOS token) and move them to the compute device.
tokens = model.to_tokens(prompts, prepend_bos=True).to(device)
# Forward pass that also records every intermediate activation.
original_logits, cache = model.run_with_cache(tokens)

def logits_to_ave_logit_diff(
    logits: Float[t.Tensor, "batch seq d_vocab"],
    answer_tokens: Float[t.Tensor, "batch 2"] = answer_tokens,
    per_prompt: bool = False
):
    '''
    Logit of the correct answer minus logit of the incorrect answer, read off at
    the last sequence position.

    logits:        model output for a batch of prompts.
    answer_tokens: one (correct, incorrect) pair of token ids per prompt. It is
                   used as a `gather` index, so despite the `Float` annotation
                   it has to hold integer ids.
    per_prompt:    if True, return one difference per prompt; otherwise return
                   their mean.
    '''
    # The prediction for the answer lives at the last position of each prompt
    last_position_logits: Float[t.Tensor, "batch d_vocab"] = logits[:, -1, :]
    # Pick out the two logits of interest for every prompt
    picked_logits: Float[t.Tensor, "batch 2"] = last_position_logits.gather(dim=-1, index=answer_tokens)
    logit_of_correct, logit_of_incorrect = picked_logits.unbind(dim=-1)
    per_prompt_diff = logit_of_correct - logit_of_incorrect
    if per_prompt:
        return per_prompt_diff
    return per_prompt_diff.mean()

# --- Logit difference measured directly on the output logits ---
original_per_prompt_diff = logits_to_ave_logit_diff(original_logits, answer_tokens, per_prompt=True)
print("Per prompt logit difference:", original_per_prompt_diff)
original_average_logit_diff = logits_to_ave_logit_diff(original_logits, answer_tokens)
print("Average logit difference:", original_average_logit_diff)

# --- The same quantity, rebuilt from the residual stream ---
# Residual-stream direction associated with each answer token.
answer_residual_directions: Float[t.Tensor, "batch 2 d_model"] = model.tokens_to_residual_directions(answer_tokens)
print("Answer residual directions shape:", answer_residual_directions.shape)

# Column 0 holds the correct answer, column 1 the incorrect one.
correct_residual_directions = answer_residual_directions[:, 0]
incorrect_residual_directions = answer_residual_directions[:, 1]
logit_diff_directions: Float[t.Tensor, "batch d_model"] = correct_residual_directions - incorrect_residual_directions
print("Logit difference directions shape:", logit_diff_directions.shape)

# Cache lookup: "resid_post" is the residual stream at the end of a layer, and -1 selects the final layer.
final_residual_stream: Float[Tensor, "batch seq d_model"] = cache["resid_post", -1]
print(f"Final residual stream shape: {final_residual_stream.shape}")
# Keep only the last position of every prompt.
final_token_residual_stream: Float[Tensor, "batch d_model"] = final_residual_stream[:, -1, :]

# Apply the final LayerNorm scaling; pos_slice=-1 tells the cache that the stack
# holds the last sequence position only.
scaled_final_token_residual_stream = cache.apply_ln_to_stack(
    final_token_residual_stream,
    layer=-1,
    pos_slice=-1,
)

# Dot each prompt's scaled residual with its logit-diff direction, then average over prompts.
average_logit_diff = einops.einsum(
    scaled_final_token_residual_stream,
    logit_diff_directions,
    "batch d_model, batch d_model ->",
) / len(prompts)

print(f"Calculated average logit diff: {average_logit_diff:.10f}")
print(f"Original logit difference:     {original_average_logit_diff:.10f}")

# Both routes must agree.
t.testing.assert_close(average_logit_diff, original_average_logit_diff)

def residual_stack_to_logit_diff(
    residual_stack: Float[Tensor, "... batch d_model"],
    cache: ActivationCache,
    logit_diff_directions: Float[Tensor, "batch d_model"] = logit_diff_directions,
) -> Float[Tensor, "..."]:
    '''
    Average (over the batch) logit difference that each entry of a stack of
    residual-stream components contributes.

    residual_stack:        components at the final sequence position; any number
                           of leading dimensions, then (batch, d_model).
    cache:                 activation cache used to apply the final LayerNorm
                           scaling to the stack.
    logit_diff_directions: per-prompt direction (correct minus incorrect answer).

    Returns a tensor with the leading dimensions of `residual_stack`.
    '''
    # Scale the stack the way the final LayerNorm would (last position only)
    normalised_stack = cache.apply_ln_to_stack(residual_stack, layer=-1, pos_slice=-1)
    # Project every component onto its prompt's direction and sum over the batch
    summed_over_batch = einops.einsum(
        normalised_stack, logit_diff_directions,
        "... batch d_model, batch d_model -> ..."
    )
    # Turn the sum into a mean; the batch dimension is second from last
    return summed_over_batch / residual_stack.size(-2)


# Test function by checking that it gives the same result as the original logit difference
# Feeding the final residual stream through the helper must reproduce the
# logit difference measured on the model's output.
t.testing.assert_close(
    actual=residual_stack_to_logit_diff(final_token_residual_stream, cache),
    expected=original_average_logit_diff,
)

Per prompt logit difference: tensor([4.4561, 3.8249, 5.3478, 3.8952, 6.8011, 6.1392, 2.9355, 5.0044],
       device='cuda:0')
Average logit difference: tensor(4.8005, device='cuda:0')
Answer residual directions shape: torch.Size([8, 2, 1024])
Logit difference directions shape: torch.Size([8, 1024])
Final residual stream shape: torch.Size([8, 15, 1024])
Calculated average logit diff: 4.8005232811
Original logit difference:     4.8005232811


# 1️⃣ Logit Attribution

In [20]:
# Residual stream accumulated up to each layer, at the final position only.
# Shape: (component, batch, d_model).
accumulated_residual, labels = cache.accumulated_resid(
    layer=-1,
    incl_mid=False,
    pos_slice=-1,
    return_labels=True,
)

# Logit difference obtained by unembedding each accumulated residual.
logit_lens_logit_diffs: Float[Tensor, "component"] = residual_stack_to_logit_diff(accumulated_residual, cache)

line(
    logit_lens_logit_diffs,
    title="Logit Difference From Accumulated Residual Stream",
    labels={"x": "Layer", "y": "Logit Diff"},
    xaxis_tickvals=labels,
    hovermode="x unified",
    width=800
)

In [26]:
# Split the final-position residual stream into the individual components written to it.
per_layer_residual, labels = cache.decompose_resid(
    layer=-1,
    pos_slice=-1,
    return_labels=True,
)
# Contribution of every component to the average logit difference.
per_layer_logit_diffs = residual_stack_to_logit_diff(per_layer_residual, cache)

line(
    per_layer_logit_diffs,
    title="Logit Difference From Each Layer",
    labels={"x": "Layer", "y": "Logit Diff"},
    xaxis_tickvals=labels,
    hovermode="x unified",
    width=800
)

In [28]:
# Same plot restricted to every second component, starting at index 2.
# NOTE(review): presumably these are the MLP outputs, if `labels` is ordered
# embed, attn, mlp, attn, mlp, ... -- confirm against `labels`.
line(
    per_layer_logit_diffs[2::2],
    title="Logit Difference From Each Layer",
    labels={"x": "Layer", "y": "Logit Diff"},
    xaxis_tickvals=labels[2::2],
    hovermode="x unified",
    width=800
)

In [30]:
# Two series on one axis: the components at even indices (from 2) and the
# components at odd indices (from 1). Tick labels come from the first series.
# NOTE(review): presumably MLP outputs vs. attention outputs -- confirm against `labels`.
line(
    [per_layer_logit_diffs[2::2], per_layer_logit_diffs[1::2]],
    title="Logit Difference From Each Layer",
    labels={"x": "Layer", "y": "Logit Diff"},
    xaxis_tickvals=labels[2::2],
    hovermode="x unified",
    width=800
)

In [22]:
# Output of every attention head at the final position, stacked along one
# flat (layer * head) axis.
flat_head_residual, labels = cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)
# Unflatten the leading axis into separate layer and head axes.
per_head_residual = einops.rearrange(
    flat_head_residual,
    "(layer head) ... -> layer head ...",
    layer=model.cfg.n_layers
)
# Contribution of each head to the average logit difference, shape (layer, head).
per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, cache)

imshow(
    per_head_logit_diffs,
    title="Logit Difference From Each Head",
    labels={"x":"Head", "y":"Layer"},
    width=600
)

Tried to stack head results when they weren't cached. Computing head results now


In [37]:
def topk_of_Nd_tensor(tensor: Float[Tensor, "rows cols"], k: int):
    '''
    Multi-dimensional counterpart of `tensor.topk(k).indices`.

    The tensor is flattened, the positions of its k largest entries are found,
    and each flat position is converted back into a full index.

    Returns a list of k indices, each a list of length `tensor.ndim`, ordered
    from the largest value downwards.

    Example: for a (layer, head) tensor of per-head scores this returns the
    [layer, head] pairs of the k highest-scoring heads.
    '''
    flat_positions = t.topk(tensor.flatten(), k).indices
    # unravel_index gives one array per dimension; stacking and transposing
    # turns that into one row per selected element
    per_dim_indices = np.unravel_index(utils.to_numpy(flat_positions), tensor.shape)
    return np.array(per_dim_indices).T.tolist()

from IPython.display import HTML, display
import circuitsvis as cv

k = 3

for head_type in ["Positive", "Negative"]:

    # Flip the sign for "Negative" so that top-k picks the most negative heads
    sign = 1 if head_type == "Positive" else -1
    top_heads = topk_of_Nd_tensor(per_head_logit_diffs * sign, k)

    # Attention pattern of each selected head on the prompt at batch index 1
    attn_patterns_for_important_heads: Float[Tensor, "head q k"] = t.stack([
        cache["pattern", layer][1, head]
        for layer, head in top_heads
    ])
    head_names = [f"{layer}.{head}" for layer, head in top_heads]

    # Heading followed by the interactive attention view (tokens of that same prompt)
    display(HTML(f"<h2>Top {k} {head_type} Logit Attribution Heads</h2>"))
    display(cv.attention.attention_patterns(
        tokens = model.to_str_tokens(tokens[1]),
        attention = attn_patterns_for_important_heads,
        attention_head_names = head_names,
    ))

# 2️⃣ Activation Patching

In [5]:
clean_tokens = tokens
# Prompts come in adjacent pairs with swapped answers, so the corrupted version of
# prompt i is its partner: flipping the lowest bit maps 0<->1, 2<->3, ...
indices = [i ^ 1 for i in range(len(tokens))]
corrupted_tokens = clean_tokens[indices]

print(
    "Clean string 0:    ", model.to_string(clean_tokens[0]),
    "\nCorrupted string 0:", model.to_string(corrupted_tokens[0]),
)

# Run both versions, keeping the activations for patching
clean_logits, clean_cache = model.run_with_cache(clean_tokens)
corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)

# Baselines: average logit difference on the clean and on the corrupted prompts
# (the answer tokens are those of the clean prompts in both cases)
clean_logit_diff = logits_to_ave_logit_diff(clean_logits, answer_tokens)
corrupted_logit_diff = logits_to_ave_logit_diff(corrupted_logits, answer_tokens)

print(f"Clean logit diff: {clean_logit_diff:.4f}")
print(f"Corrupted logit diff: {corrupted_logit_diff:.4f}")

def ioi_metric(
    logits: Float[Tensor, "batch seq d_vocab"],
    answer_tokens: Float[Tensor, "batch 2"] = answer_tokens,
    corrupted_logit_diff: float = corrupted_logit_diff,
    clean_logit_diff: float = clean_logit_diff,
) -> float:
    '''
    Linear function of logit diff, calibrated so that it equals 0 when performance is
    same as on corrupted input, and 1 when performance is same as on clean input.

    logits:               model output to score.
    answer_tokens:        (correct, incorrect) token ids per prompt.
    corrupted_logit_diff: average logit difference on the corrupted prompts (maps to 0).
    clean_logit_diff:     average logit difference on the clean prompts (maps to 1).

    Returns a plain Python float (the result is unwrapped with `.item()`), not a tensor.
    '''
    patched_logit_diff = logits_to_ave_logit_diff(logits, answer_tokens)
    # Rescale so that corrupted -> 0 and clean -> 1
    baseline_gap = clean_logit_diff - corrupted_logit_diff
    normalised = (patched_logit_diff - corrupted_logit_diff) / baseline_gap
    return normalised.item()

# Calibration checks: clean logits score 1, corrupted logits score 0, and their
# average scores 0.5.
t.testing.assert_close(
    actual=ioi_metric(clean_logits),
    expected=1.0,
)
t.testing.assert_close(
    actual=ioi_metric(corrupted_logits),
    expected=0.0,
)
t.testing.assert_close(
    actual=ioi_metric((clean_logits + corrupted_logits) / 2),
    expected=0.5,
)

Clean string 0:     <|endoftext|>When John and Mary went to the shops, John gave the bag to 
Corrupted string 0: <|endoftext|>When John and Mary went to the shops, Mary gave the bag to
Clean logit diff: 4.8005
Corrupted logit diff: -4.8005


In [33]:
# Run on the corrupted prompts and patch clean activations into resid_pre,
# scoring each patch with the IOI metric.
act_patch_resid_pre = patching.get_act_patch_resid_pre(
    model = model,
    clean_cache = clean_cache,
    corrupted_tokens = corrupted_tokens,
    patching_metric = ioi_metric
)

# Axis labels: "<token> <position>" for the first clean prompt
labels = [
    f"{str_token} {position}"
    for position, str_token in enumerate(model.to_str_tokens(clean_tokens[0]))
]

imshow(
    act_patch_resid_pre,
    title="resid_pre Activation Patching",
    labels={"x": "Position", "y": "Layer"},
    x=labels,
    width=600
)

  0%|          | 0/360 [00:00<?, ?it/s]

In [46]:
# Activation patching from the clean cache into a corrupted run, iterating over
# five kinds of node: z (head output), q, k, v and the attention pattern.
results = act_patch(
    model=model,
    orig_input=corrupted_tokens,
    new_cache=clean_cache,
    patching_nodes=IterNode(["z", "q", "k", "v", "pattern"]),
    patching_metric=ioi_metric,
    verbose=True,
)

  0%|          | 0/1920 [00:00<?, ?it/s]

results['z'].shape = (layer=24, head=16)
results['q'].shape = (layer=24, head=16)
results['k'].shape = (layer=24, head=16)
results['v'].shape = (layer=24, head=16)
results['pattern'].shape = (layer=24, head=16)


In [None]:
# Node types in the order in which their facets are labelled below. Stacking by
# explicit key, rather than by the iteration order of `results`, keeps every
# facet title attached to the right data.
patched_node_names = ["z", "q", "k", "v", "pattern"]
assert results.keys() == set(patched_node_names)
# One value per (layer, head) for every node type
assert all(r.shape == (model.cfg.n_layers, model.cfg.n_heads) for r in results.values())

imshow(
    # Stack to (node type, layer, head) and convert to percent
    t.stack([results[node_name] for node_name in patched_node_names]) * 100,
    facet_col=0,
    facet_labels=["Output", "Query", "Key", "Value", "Pattern"],
    title="Patching output of attention heads (corrupted -> clean)",
    labels={"x": "Head", "y": "Layer", "color": "Logit diff variation"},
    coloraxis=dict(colorbar_ticksuffix = "%"),
    border=True,
    width=1500,
)

# 3️⃣ Path Patching - histogram experiment

In [11]:
# (layer, head) pairs treated as name mover heads in the experiment below.
NAME_MOVERS = [
    (17, 10),
    (17, 11),
    (18, 8),
    (18, 12),
    (18, 13),
    (18, 14),
]
# (layer, head) pairs treated as negative name mover heads.
NEG_NAME_MOVERS = [
    (19, 15),
    (20, 15),
]


def get_io_vs_s_attn_for_nmh(
    patched_cache: ActivationCache,
    orig_dataset: IOIDataset,
    orig_cache: ActivationCache,
    neg_nmh: Tuple[int, int],
) -> Tuple[Float[Tensor, "batch"], Float[Tensor, "batch"]]:
    '''
    Returns the difference between patterns[END, IO] and patterns[END, S1], where patterns
    are the attention patterns for the negative name mover head.

    This is returned in the form of a tuple of 2 tensors: one for the patched distribution
    (calculated using `patched_cache` which is returned by the path patching algorithm), and
    one for the clean distribution (calculated directly from `orig_cache`).

    patched_cache: cache of the patched forward pass.
    orig_dataset:  dataset supplying the batch size (`toks`) and the per-prompt
                   positions of the IO, S1 and end tokens (`word_idx`).
    orig_cache:    cache of the unpatched forward pass on `orig_dataset`.
    neg_nmh:       (layer, head) of the head whose attention pattern is inspected.
    '''
    layer, head = neg_nmh

    batch_idx = range(orig_dataset.toks.size(0))
    io_seq_pos = orig_dataset.word_idx["IO"]
    s1_seq_pos = orig_dataset.word_idx["S1"]
    end_seq_pos = orig_dataset.word_idx["end"]

    def io_minus_s1(cache: ActivationCache) -> Float[Tensor, "batch"]:
        # Pattern of the chosen head is (batch, seq_Q, seq_K); per prompt, read the
        # attention from the end position to IO and to S1 and take the difference
        pattern = cache["pattern", layer][:, head]
        return pattern[batch_idx, end_seq_pos, io_seq_pos] - pattern[batch_idx, end_seq_pos, s1_seq_pos]

    return io_minus_s1(patched_cache), io_minus_s1(orig_cache)


def get_nnmh_patching_patterns(num_batches = 50, N = 10, neg_nmh = NEG_NAME_MOVERS[0], orig_is_ioi = True):
    '''
    Path-patches from the name mover heads' outputs into the query input of one
    negative name mover head, over `num_batches` freshly generated datasets, and
    collects the head's (END->IO minus END->S1) attention difference.

    num_batches: number of datasets to generate; the loop index is used as the seed.
    N:           number of prompts per dataset.
    neg_nmh:     (layer, head) of the receiving negative name mover head.
    orig_is_ioi: if True the IOI dataset is the original input and the ABC dataset
                 the source of patched activations; if False the roles are swapped.

    Returns (results_patched, results_clean): two 1D tensors holding the attention
    difference for every prompt of every batch, with and without patching.
    '''
    results_patched = t.empty(size=(0,)).to(device)
    results_clean = t.empty(size=(0,)).to(device)

    for seed in tqdm(range(num_batches)):

        # The metric returned alongside the data is not needed here; discarding it
        # also avoids a local name that shadows the notebook's `ioi_metric`.
        ioi_dataset, abc_dataset, ioi_cache, abc_cache, _ = generate_data_and_caches(N, model=model, seed=seed)

        # Datasets and caches are selected together so that each cache always
        # belongs to the input it is passed with.
        if orig_is_ioi:
            orig_dataset, new_dataset = ioi_dataset, abc_dataset
            orig_cache, new_cache = ioi_cache, abc_cache
        else:
            orig_dataset, new_dataset = abc_dataset, ioi_dataset
            orig_cache, new_cache = abc_cache, ioi_cache

        new_results_patched, new_results_clean = path_patch(
            model,
            orig_input=orig_dataset.toks,
            new_input=new_dataset.toks,
            orig_cache=orig_cache,
            new_cache=new_cache,
            sender_nodes=[Node("z", layer=layer, head=head) for layer, head in NAME_MOVERS], # Output of all name mover heads
            receiver_nodes=Node("q", neg_nmh[0], head=neg_nmh[1]), # To query input of negative name mover head
            # The "clean" half of the metric is computed from the original run
            patching_metric=partial(get_io_vs_s_attn_for_nmh, orig_dataset=orig_dataset, orig_cache=orig_cache, neg_nmh=neg_nmh),
            apply_metric_to_cache=True,
            direct_includes_mlps=False,
        )
        results_patched = t.concat([results_patched, new_results_patched])
        results_clean = t.concat([results_clean, new_results_clean])

        # Release cached GPU memory between batches
        t.cuda.empty_cache()

    return results_patched, results_clean

In [12]:
# Run the experiment for the first negative name mover head, with the default
# number of batches and batch size.
results_patched, results_clean = get_nnmh_patching_patterns(
    neg_nmh=NEG_NAME_MOVERS[0],
)

  0%|          | 0/50 [00:00<?, ?it/s]

In [13]:
# Overlaid histograms (with box-plot marginals) of the attention difference,
# with and without path patching.
hist(
    [results_patched, results_clean],
    names=["Patched", "Clean"],
    title="Difference in attn from END➔IO vs. END➔S1 (path-patched vs clean)",
    labels={"variable": "Version", "value": "Attn diff (positive ⇒ more attn paid to IO than S1)"},
    marginal="box",
    opacity=0.7,
    template="simple_white",
    width=800,
    height=600
)