In [1]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = True
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    # Install another version of node that makes PySvelte work way faster
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    %pip install git+https://github.com/neelnanda-io/PySvelte.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

Running as a Jupyter notebook - intended for development only!


  ipython.magic("load_ext autoreload")
  ipython.magic("autoreload 2")


In [2]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

if IN_COLAB or not DEBUG_MODE:
    # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "vscode"

In [3]:
import torch
from fancy_einsum import einsum
from transformer_lens import HookedTransformer, HookedTransformerConfig, utils, ActivationCache
from torchtyping import TensorType as TT
import plotly.express as px
import numpy as np
import einops
from typing import List, Union, Optional
import pysvelte
from IPython.display import HTML
from functools import partial


IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html



In [4]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x10bf66a40>

In [5]:
def imshow(tensor, renderer=None, **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", **kwargs).show(renderer)

def line(tensor, renderer=None, **kwargs):
    px.line(y=utils.to_numpy(tensor), **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

In [6]:
# device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.has_mps else "cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [7]:
model = HookedTransformer.from_pretrained(
    "gpt2-large",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
    device=device,
)

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-large into HookedTransformer


In [318]:
import os
# os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
# example_prompt = "I climbed the pear tree and picked a pear. I climbed the lemon tree and picked"
# example_prompt = "I climbed up the pear tree and picked a pear. I climbed up the apple tree and picked"
# example_prompt = "I climbed up the pear tree and picked a pear. I climbed up the avocado tree and picked"
example_prompt = "To fight the Civil War, the king will need to raise"
# example_prompt = "I climbed up the pear tree and picked a pear. I climbed up the apple tree and picked"
example_answer = " an"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'To', ' fight', ' the', ' Civil', ' War', ',', ' the', ' king', ' will', ' need', ' to', ' raise']
Tokenized answer: [' an']
stable sort


Top 0th token. Logit: 16.87 Prob: 32.46% Token: | an|
Top 1th token. Logit: 16.24 Prob: 17.32% Token: | a|
Top 2th token. Logit: 15.21 Prob:  6.19% Token: | the|
Top 3th token. Logit: 14.78 Prob:  4.03% Token: | armies|
Top 4th token. Logit: 14.64 Prob:  3.50% Token: | more|
Top 5th token. Logit: 14.53 Prob:  3.13% Token: | his|
Top 6th token. Logit: 14.41 Prob:  2.77% Token: | troops|
Top 7th token. Logit: 14.29 Prob:  2.48% Token: | up|
Top 8th token. Logit: 14.23 Prob:  2.32% Token: | money|
Top 9th token. Logit: 14.01 Prob:  1.87% Token: | taxes|


In [269]:
# prompt_format = "I climbed up the {} tree and picked {} {}. I climbed up the {} tree and picked"
# fruits_answer_and_wrong_answer = [
#     # ("pear", "a", "an"),
#     ("apple", "an", "a"),
#     # ("lemon", "a", "an"),
#     ("orange", "an", "a"),
# ]
# prompts = []
# # List of the token (ie an integer) corresponding to each answer, in the format (correct_token, incorrect_token)
# answer_and_wrong_answer_tokens = []
# max_prompt_count = 2
# prompt_count = 0
# for j, (fruit1, answer1, wrong_answer1) in enumerate(fruits_answer_and_wrong_answer):
#     for i, (fruit2, answer2, wrong_answer2) in enumerate(fruits_answer_and_wrong_answer):
#         if fruit1 != fruit2 and prompt_count < max_prompt_count:
#             prompts.append(prompt_format.format(fruit1, answer1, fruit1, fruit2))
#             answer_and_wrong_answer_tokens.append((model.to_single_token(f" {answer2}"), model.to_single_token(f" {wrong_answer2}")))
#             prompt_count += 1

original_prompts = ["I climbed up the pear tree and picked a pear. I climbed up the apple tree and picked",
           "I climbed up the pear tree and picked a pear. I climbed up the avocado tree and picked",
           "To fight the Civil War, the king will need to raise"]
original_long_prompts = original_prompts[0:2]
original_short_prompts = original_prompts[2:3]
answer_and_wrong_answer_tokens = torch.tensor([(model.to_single_token(" an"), model.to_single_token(" a")) for _ in range(len(original_prompts))]).to(device=device)
long_answer_and_wrong_answer_tokens = answer_and_wrong_answer_tokens[0:2]
short_answer_and_wrong_answer_tokens = answer_and_wrong_answer_tokens[2:3]
print(original_prompts)
print(len(original_prompts))
print(answer_and_wrong_answer_tokens)

['I climbed up the pear tree and picked a pear. I climbed up the apple tree and picked', 'I climbed up the pear tree and picked a pear. I climbed up the avocado tree and picked', 'To fight the Civil War, the king will need to raise']
3
tensor([[281, 257],
        [281, 257],
        [281, 257]])


In [168]:
def print_predictions_from_logits(logits_to_print, prediction_prompts, show_raw_logits=False, top_k=5, prepend_bos=True):
    sorted_logits, sorted_logits_indices = logits_to_print.sort(dim=-1, descending=True)
    probs = logits_to_print.softmax(dim=-1)
    sorted_probs, sorted_probs_indices = probs.sort(dim=-1, descending=True)
    for i, prompt in enumerate(prediction_prompts):
        print(f'prompt: {model.to_str_tokens(prompt, prepend_bos=prepend_bos)}')
        prompt_token_index = model.to_tokens(prompt).shape[1] - 1
        print('prompt_token_index:', prompt_token_index)
        for k in range(top_k):
            print(f"'{model.tokenizer.decode(sorted_probs_indices[i, prompt_token_index, k].item())}'", '-', f"{sorted_probs[i, prompt_token_index, k].item() * 100:.2f}%", end='')
            if show_raw_logits:
                print(' -', f"{sorted_logits[i, prompt_token_index, k].item():.2f}", end='')
            print()
    # patched_neuron_correct_incorrect_logit_diff = logits_to_ave_correct_incorrect_logit_diff(logits_to_print, answer_and_wrong_answer_tokens)

In [270]:
original_tokens = model.to_tokens(original_prompts, prepend_bos=True).to(device=device)
original_logits, original_cache = model.run_with_cache(original_tokens)
original_long_tokens = model.to_tokens(original_long_prompts, prepend_bos=True).to(device=device)
original_long_logits, original_long_cache = model.run_with_cache(original_long_tokens)
original_short_tokens = model.to_tokens(original_short_prompts, prepend_bos=True).to(device=device)
original_short_logits, original_short_cache = model.run_with_cache(original_short_tokens)
print_predictions_from_logits(original_logits, original_prompts, show_raw_logits=True)

prompt: ['<|endoftext|>', 'I', ' climbed', ' up', ' the', ' pear', ' tree', ' and', ' picked', ' a', ' pear', '.', ' I', ' climbed', ' up', ' the', ' apple', ' tree', ' and', ' picked']
prompt_token_index: 19
' an' - 64.92% - 20.52
' a' - 24.22% - 19.53
' apples' - 2.78% - 17.37
' two' - 2.43% - 17.23
' another' - 2.07% - 17.07
prompt: ['<|endoftext|>', 'I', ' climbed', ' up', ' the', ' pear', ' tree', ' and', ' picked', ' a', ' pear', '.', ' I', ' climbed', ' up', ' the', ' avocado', ' tree', ' and', ' picked']
prompt_token_index: 19
' an' - 56.61% - 19.77
' a' - 31.79% - 19.20
' two' - 3.02% - 16.84
' some' - 2.09% - 16.47
' another' - 1.74% - 16.29
prompt: ['<|endoftext|>', 'To', ' fight', ' the', ' Civil', ' War', ',', ' the', ' king', ' will', ' need', ' to', ' raise']
prompt_token_index: 12
' an' - 32.46% - 16.87
' a' - 17.32% - 16.24
' the' - 6.19% - 15.21
' armies' - 4.03% - 14.78
' more' - 3.50% - 14.64


In [272]:
def logits_to_ave_correct_incorrect_logit_diff(logits, prompts, answer_and_wrong_answer_tokens, per_prompt=False):
    final_token_indices = torch.tensor([model.to_tokens(p).shape[1] - 1 for p in prompts])
    gather_final_logits = einops.repeat(final_token_indices, '(batch seq) -> batch seq vocab', batch=logits.shape[0], vocab=logits.shape[-1])
    final_token_logits = logits.gather(dim=1, index=gather_final_logits).squeeze(1).to(device=device)
    answer_logits = final_token_logits.gather(dim=-1, index=answer_and_wrong_answer_tokens)
    correct_answer_logits = answer_logits[:, 0]
    incorrect_answer_logits = answer_logits[:, 1]
    answer_logit_diff = correct_answer_logits - incorrect_answer_logits
    if per_prompt:
        return answer_logit_diff
    else:
        return answer_logit_diff.mean()

original_average_logit_diff = logits_to_ave_correct_incorrect_logit_diff(original_logits, original_prompts, answer_and_wrong_answer_tokens)
original_long_average_logit_diff = logits_to_ave_correct_incorrect_logit_diff(original_long_logits, original_long_prompts, long_answer_and_wrong_answer_tokens)
original_short_average_logit_diff = logits_to_ave_correct_incorrect_logit_diff(original_short_logits, original_short_prompts, short_answer_and_wrong_answer_tokens)
print("Average logit difference:", original_average_logit_diff.item())
print("Long average logit difference:", original_long_average_logit_diff.item())
print("Short average logit difference:", original_short_average_logit_diff.item())

Average logit difference: 0.7304058074951172
Long average logit difference: 0.7815141677856445
Short average logit difference: 0.6281833648681641


In [258]:
answer_residual_directions = model.tokens_to_residual_directions(answer_and_wrong_answer_tokens)
print("Answer residual directions shape:", answer_residual_directions.shape)
logit_diff_directions = answer_residual_directions[:, 0] - answer_residual_directions[:, 1]
print("Logit difference directions shape:", logit_diff_directions.shape)

Answer residual directions shape: torch.Size([3, 2, 1280])
Logit difference directions shape: torch.Size([3, 1280])


In [260]:
# cache syntax - resid_post is the residual stream at the end of the layer, -1 gets the final layer. The general syntax is [activation_name, layer_index, sub_layer_type]. 
final_residual_stream = original_cache["resid_post", -1]
final_token_indices = torch.tensor([model.to_tokens(p).shape[1] - 1 for p in original_prompts])
gather_final_token_res = einops.repeat(final_token_indices, '(batch seq) -> batch seq dmodel', batch=len(original_prompts), dmodel=final_residual_stream.shape[-1])
final_token_residual_stream = final_residual_stream.gather(dim=1, index=gather_final_token_res).squeeze(1).to(device=device)
print("Final token stream shape:", final_token_residual_stream.shape)
# Apply LayerNorm scaling
# pos_slice is the subset of the positions we take - here the final token of each prompt
scaled_final_token_residual_stream = original_cache.apply_ln_to_stack(final_token_residual_stream, layer = -1, pos_slice=-1)

average_logit_diff = einsum("batch d_model, batch d_model -> ", scaled_final_token_residual_stream, logit_diff_directions)/len(original_prompts)
print("Calculated average logit diff:", average_logit_diff.item())
print("Original logit difference:",original_average_logit_diff.item())

Final token stream shape: torch.Size([3, 1280])
Calculated average logit diff: 0.8913342356681824
Original logit difference: 0.7304058074951172


In [261]:

def residual_stack_to_logit_diff(residual_stack: TT["components", "batch", "d_model"], cache: ActivationCache) -> float:
    gather_accumulated_res = einops.repeat(torch.arange((len(original_prompts))), '(batch seq) -> resids batch seq dmodel', resids=residual_stack.shape[0], batch=len(original_prompts), dmodel=residual_stack.shape[-1])
    scaled_residual_stack = cache.apply_ln_to_stack(residual_stack, layer = -1, pos_slice=final_token_indices)
    print('scaled_residual_stack shape:', scaled_residual_stack.shape)
    scaled_residual_stack = scaled_residual_stack.gather(dim=2, index=gather_accumulated_res).squeeze(2).to(device=device)
    print('scaled residual stack shape:', scaled_residual_stack.shape)
    return einsum("... batch d_model, batch d_model -> ...", scaled_residual_stack, logit_diff_directions)/len(original_prompts)

In [262]:
accumulated_residual, labels = original_cache.accumulated_resid(layer=-1, incl_mid=True, pos_slice=final_token_indices, return_labels=True)
logit_lens_logit_diffs = residual_stack_to_logit_diff(accumulated_residual, original_cache)
print('logit_lens_logit_diffs shape:', logit_lens_logit_diffs.shape)
# logit_lens_logit_diffs = logit_lens_logit_diffs.gather(dim=-1, index=gather_accumulated_res).squeeze(2)
line(logit_lens_logit_diffs, x=np.arange(model.cfg.n_layers*2+1)/2, hover_name=labels, title="Logit Difference From Accumulate Residual Stream")

scaled_residual_stack shape: torch.Size([73, 3, 3, 1280])
scaled residual stack shape: torch.Size([73, 3, 1280])
logit_lens_logit_diffs shape: torch.Size([73])


In [263]:
per_layer_residual, labels = original_cache.decompose_resid(layer=-1, pos_slice=final_token_indices, return_labels=True)
per_layer_logit_diffs = residual_stack_to_logit_diff(per_layer_residual, original_cache)
line(per_layer_logit_diffs, hover_name=labels, title="Logit Difference From Each Layer")

scaled_residual_stack shape: torch.Size([74, 3, 3, 1280])
scaled residual stack shape: torch.Size([74, 3, 1280])


In [264]:
per_head_residual, labels = original_cache.stack_head_results(layer=-1, pos_slice=final_token_indices, return_labels=True)
per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, original_cache)
per_head_logit_diffs = einops.rearrange(per_head_logit_diffs, "(layer head_index) -> layer head_index", layer=model.cfg.n_layers, head_index=model.cfg.n_heads)
imshow(per_head_logit_diffs, labels={"x":"Head", "y":"Layer"}, title="Logit Difference From Each Head")

Tried to stack head results when they weren't cached. Computing head results now
scaled_residual_stack shape: torch.Size([720, 3, 3, 1280])
scaled residual stack shape: torch.Size([720, 3, 1280])


In [265]:
def visualize_attention_patterns(
    heads: Union[List[int], int, TT["heads"]], 
    local_cache: Optional[ActivationCache]=None, 
    local_tokens: Optional[torch.Tensor]=None, 
    title: str=""):
    # Heads are given as a list of integers or a single integer in [0, n_layers * n_heads)
    if isinstance(heads, int):
        heads = [heads]
    elif isinstance(heads, list) or isinstance(heads, torch.Tensor):
        heads = utils.to_numpy(heads)
    # Cache defaults to the original activation cache
    if local_cache is None:
        local_cache = original_cache
    # Tokens defaults to the tokenization of the first prompt (including the BOS token)
    if local_tokens is None:
        # The tokens of the first prompt
        local_tokens = original_tokens[0]
    
    labels = []
    patterns = []
    batch_index = 0
    for head in heads:
        layer = head // model.cfg.n_heads
        head_index = head % model.cfg.n_heads
        # Get the attention patterns for the head
        # Attention patterns have shape [batch, head_index, query_pos, key_pos]
        patterns.append(local_cache["attn", layer][batch_index, head_index])
        labels.append(f"L{layer}H{head_index}")
    str_tokens = model.to_str_tokens(local_tokens)
    patterns = torch.stack(patterns, dim=-1)
    # Plot the attention patterns
    attention_vis = pysvelte.AttentionMulti(attention=patterns, tokens=str_tokens, head_labels=labels)
    display(HTML(f"<h3>{title}</h3>"))
    attention_vis.show()

In [266]:
top_k = 10
top_positive_logit_attr_heads = torch.topk(per_head_logit_diffs.flatten(), k=top_k).indices
visualize_attention_patterns(top_positive_logit_attr_heads, title=f"Top {top_k} Positive Logit Attribution Heads")
top_negative_logit_attr_heads = torch.topk(-per_head_logit_diffs.flatten(), k=top_k).indices
visualize_attention_patterns(top_negative_logit_attr_heads, title=f"Top {top_k} Negative Logit Attribution Heads")

In [273]:
corrupted_prompts = [
    "I climbed up the pear tree and picked a pear. I climbed up the lemon tree and picked",
    "I climbed up the pear tree and picked a pear. I climbed up the lime tree and picked",
    "To fight the Civil War, the king will need to write"
]
corrupted_long_prompts = corrupted_prompts[0:2]
corrupted_short_prompts = corrupted_prompts[2:3]
corrupted_tokens = model.to_tokens(corrupted_prompts, prepend_bos=True)
corrupted_long_tokens = model.to_tokens(corrupted_long_prompts, prepend_bos=True)
corrupted_short_tokens = model.to_tokens(corrupted_short_prompts, prepend_bos=True)
corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens, return_type="logits")
corrupted_long_logits, corrupted_long_cache = model.run_with_cache(corrupted_long_tokens, return_type="logits")
corrupted_short_logits, corrupted_short_cache = model.run_with_cache(corrupted_short_tokens, return_type="logits")
corrupted_average_logit_diff = logits_to_ave_correct_incorrect_logit_diff(corrupted_logits, corrupted_prompts, answer_and_wrong_answer_tokens)
corrupted_long_average_logit_diff = logits_to_ave_correct_incorrect_logit_diff(corrupted_long_logits, corrupted_long_prompts, long_answer_and_wrong_answer_tokens)
corrupted_short_average_logit_diff = logits_to_ave_correct_incorrect_logit_diff(corrupted_short_logits, corrupted_short_prompts, short_answer_and_wrong_answer_tokens)
print("Corrupted Average Logit Diff", corrupted_average_logit_diff)
print("Clean Average Logit Diff", original_average_logit_diff)
print("Corrupted Long Average Logit Diff", corrupted_long_average_logit_diff)
print("Clean Long Average Logit Diff", original_long_average_logit_diff)
print("Corrupted Short Average Logit Diff", corrupted_short_average_logit_diff)
print("Clean Short Average Logit Diff", original_short_average_logit_diff)

Corrupted Average Logit Diff tensor(-2.6100)
Clean Average Logit Diff tensor(0.7304)
Corrupted Long Average Logit Diff tensor(-2.7857)
Clean Long Average Logit Diff tensor(0.7815)
Corrupted Short Average Logit Diff tensor(-2.2586)
Clean Short Average Logit Diff tensor(0.6282)


In [274]:
def normalize_patched_logit_diff(patched_logit_diff):
    # Subtract corrupted logit diff to measure the improvement, divide by the total improvement from clean to corrupted to normalise
    # 0 means zero change, negative means actively made worse, 1 means totally recovered clean performance, >1 means actively *improved* on clean performance
    return (patched_logit_diff - corrupted_average_logit_diff)/(original_average_logit_diff - corrupted_average_logit_diff)

In [279]:
def patch_residual_component(
    corrupted_residual_component: TT["batch", "pos", "d_model"],
    hook, 
    pos, 
    clean_cache):
    corrupted_residual_component[:, pos, :] = clean_cache[hook.name][:, pos, :]
    return corrupted_residual_component

for this_tokens, this_prompts, this_cache, this_answer_wrong_answer_tokens in [(corrupted_long_tokens, corrupted_long_prompts, original_long_cache, long_answer_and_wrong_answer_tokens), (corrupted_short_tokens, corrupted_short_prompts, original_short_cache, short_answer_and_wrong_answer_tokens)]:
    patched_residual_stream_diff = torch.zeros(model.cfg.n_layers, this_tokens.shape[1], device=device, dtype=torch.float32)
    for layer in range(model.cfg.n_layers):
        for position in range(this_tokens.shape[1]):
            hook_fn = partial(patch_residual_component, pos=position, clean_cache=this_cache)
            patched_logits = model.run_with_hooks(
                this_tokens,
                fwd_hooks = [(utils.get_act_name("resid_pre", layer), hook_fn)],
                return_type="logits"
            )
            patched_logit_diff = logits_to_ave_correct_incorrect_logit_diff(patched_logits, this_prompts, this_answer_wrong_answer_tokens)

            patched_residual_stream_diff[layer, position] = normalize_patched_logit_diff(patched_logit_diff)

    prompt_position_labels = [f"{tok}_{i}" for i, tok in enumerate(model.to_str_tokens(this_tokens[0]))]
    imshow(patched_residual_stream_diff, x=prompt_position_labels, title="Logit Difference From Patched Residual Stream", labels={"x":"Position", "y":"Layer"})

In [284]:
for this_tokens, this_prompts, this_cache, this_answer_wrong_answer_tokens in [(corrupted_long_tokens, corrupted_long_prompts, original_long_cache, long_answer_and_wrong_answer_tokens), (corrupted_short_tokens, corrupted_short_prompts, original_short_cache, short_answer_and_wrong_answer_tokens)]:
    patched_attn_diff = torch.zeros(model.cfg.n_layers, this_tokens.shape[1], device=device, dtype=torch.float32)
    patched_mlp_diff = torch.zeros(model.cfg.n_layers, this_tokens.shape[1], device=device, dtype=torch.float32)
    for layer in range(model.cfg.n_layers):
        for position in range(this_tokens.shape[1]):
            hook_fn = partial(patch_residual_component, pos=position, clean_cache=this_cache)
            patched_attn_logits = model.run_with_hooks(
                this_tokens, 
                fwd_hooks = [(utils.get_act_name("attn_out", layer), 
                    hook_fn)], 
                return_type="logits"
            )
            patched_attn_logit_diff = logits_to_ave_correct_incorrect_logit_diff(patched_attn_logits, this_prompts, this_answer_wrong_answer_tokens)
            patched_mlp_logits = model.run_with_hooks(
                this_tokens, 
                fwd_hooks = [(utils.get_act_name("mlp_out", layer), 
                    hook_fn)], 
                return_type="logits"
            )
            patched_mlp_logit_diff = logits_to_ave_correct_incorrect_logit_diff(patched_mlp_logits, this_prompts, this_answer_wrong_answer_tokens)

            patched_attn_diff[layer, position] = normalize_patched_logit_diff(patched_attn_logit_diff)
            patched_mlp_diff[layer, position] = normalize_patched_logit_diff(patched_mlp_logit_diff)
    prompt_position_labels = [f"{tok}_{i}" for i, tok in enumerate(model.to_str_tokens(this_tokens[0]))]
    imshow(patched_attn_diff, x=prompt_position_labels, title="Logit Difference From Patched Attention Layer", labels={"x":"Position", "y":"Layer"})
    imshow(patched_mlp_diff, x=prompt_position_labels, title="Logit Difference From Patched MLP Layer", labels={"x":"Position", "y":"Layer"})

In [248]:
def patch_head_vector(
    corrupted_head_vector: TT["batch", "pos", "head_index", "d_head"],
    hook, 
    head_index, 
    clean_cache):
    corrupted_head_vector[:, :, head_index, :] = clean_cache[hook.name][:, :, head_index, :]
    return corrupted_head_vector


patched_head_z_diff = torch.zeros(model.cfg.n_layers, model.cfg.n_heads, device=device, dtype=torch.float32)
for layer in range(model.cfg.n_layers):
    for head_index in range(model.cfg.n_heads):
        hook_fn = partial(patch_head_vector, head_index=head_index, clean_cache=cache)
        patched_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks = [(utils.get_act_name("z", layer, "attn"), hook_fn)],
            return_type="logits"
        )
        patched_logit_diff = logits_to_ave_correct_incorrect_logit_diff(patched_logits, corrupted_prompts, answer_and_wrong_answer_tokens)

        patched_head_z_diff[layer, head_index] = normalize_patched_logit_diff(patched_logit_diff)

In [250]:
imshow(patched_head_z_diff, title="Logit Difference From Patched Head Output", labels={"x":"Head", "y":"Layer"})

In [252]:
patched_head_v_diff = torch.zeros(model.cfg.n_layers, model.cfg.n_heads, device=device, dtype=torch.float32)
for layer in range(model.cfg.n_layers):
    for head_index in range(model.cfg.n_heads):
        hook_fn = partial(patch_head_vector, head_index=head_index, clean_cache=cache)
        patched_logits = model.run_with_hooks(
            corrupted_tokens, 
            fwd_hooks = [(utils.get_act_name("v", layer, "attn"), 
                hook_fn)], 
            return_type="logits"
        )
        patched_logit_diff = logits_to_ave_correct_incorrect_logit_diff(patched_logits, corrupted_prompts, answer_and_wrong_answer_tokens)

        patched_head_v_diff[layer, head_index] = normalize_patched_logit_diff(patched_logit_diff)

In [253]:
imshow(patched_head_v_diff, title="Logit Difference From Patched Head Value", labels={"x":"Head", "y":"Layer"})

In [254]:
head_labels = [f"L{l}H{h}" for l in range(model.cfg.n_layers) for h in range(model.cfg.n_heads)]
scatter(
    x=utils.to_numpy(patched_head_v_diff.flatten()), 
    y=utils.to_numpy(patched_head_z_diff.flatten()), 
    xaxis="Value Patch",
    yaxis="Output Patch",
    caxis="Layer",
    hover_name = head_labels,
    color=einops.repeat(np.arange(model.cfg.n_layers), "layer -> (layer head)", head=model.cfg.n_heads),
    range_x=(-0.5, 0.5),
    range_y=(-0.5, 0.5),
    title="Scatter plot of output patching vs value patching")

In [266]:
def patch_head_pattern(
    corrupted_head_pattern: TT["batch", "head_index", "query_pos", "d_head"],
    hook, 
    head_index, 
    clean_cache):
    corrupted_head_pattern[:, head_index, :, :] = clean_cache[hook.name][:, head_index, :, :]
    return corrupted_head_pattern

patched_head_attn_diff = torch.zeros(model.cfg.n_layers, model.cfg.n_heads, device=device, dtype=torch.float32)
for layer in range(model.cfg.n_layers):
    for head_index in range(model.cfg.n_heads):
        hook_fn = partial(patch_head_pattern, head_index=head_index, clean_cache=cache)
        patched_logits = model.run_with_hooks(
            corrupted_tokens, 
            fwd_hooks = [(utils.get_act_name("attn", layer, "attn"), 
                hook_fn)], 
            return_type="logits"
        )
        patched_logit_diff = logits_to_ave_correct_incorrect_logit_diff(patched_logits, answer_and_wrong_answer_tokens)

        patched_head_attn_diff[layer, head_index] = normalize_patched_logit_diff(patched_logit_diff)

In [267]:
imshow(patched_head_attn_diff, title="Logit Difference From Patched Head Pattern", labels={"x":"Head", "y":"Layer"})
head_labels = [f"L{l}H{h}" for l in range(model.cfg.n_layers) for h in range(model.cfg.n_heads)]
scatter(
    x=utils.to_numpy(patched_head_attn_diff.flatten()), 
    y=utils.to_numpy(patched_head_z_diff.flatten()), 
    hover_name = head_labels,
    xaxis="Attention Patch",
    yaxis="Output Patch",
    title="Scatter plot of output patching vs attention patching")

### Patch MLP Layer 31 by neuron

In [None]:
def patch_neuron_activation(
    corrupted_residual_component: TT["batch", "pos", "d_mlp"],
    hook, 
    neuron, 
    clean_cache):
    corrupted_residual_component[:, :, neuron] = clean_cache[hook.name][:, :, neuron]
    return corrupted_residual_component

patched_neurons_normalized_improvement = torch.zeros(model.cfg.d_mlp, device=device, dtype=torch.float32)
layer = 31
max_neurons = 10000
for neuron in range(model.cfg.d_mlp)[:max_neurons]:
    hook_fn = partial(patch_neuron_activation, neuron=neuron, clean_cache=cache)
    patched_neuron_logits = model.run_with_hooks(
        corrupted_tokens,
        fwd_hooks = [("blocks.31.mlp.hook_post", hook_fn)],
        return_type="logits"
    )
    patched_neuron_correct_incorrect_logit_diff = logits_to_ave_correct_incorrect_logit_diff(patched_neuron_logits, answer_and_wrong_answer_tokens)

    patched_neurons_normalized_improvement[neuron] = normalize_patched_logit_diff(patched_neuron_correct_incorrect_logit_diff)

In [28]:

line(patched_neurons_normalized_improvement[:max_neurons], x=list(range(len(patched_neurons_normalized_improvement))), title="Logit Difference From Patched Neurons in MLP Layer 31", labels={"x":"neuron", "y":"Patch Improvement"})

In [39]:
line(model.blocks[31].mlp.W_in.T[892])

In [287]:
weight_out_for_special_neuron = model.blocks[31].mlp.W_out[892]

weight_in_for_special_neuron = model.blocks[31].mlp.W_in[:, 892]
print('weight_in_for_special_neuron', weight_in_for_special_neuron)
print('weight_in_for_special_neuron', weight_in_for_special_neuron.shape)

weight_in_for_special_neuron tensor([-0.0748,  0.0727, -0.0372,  ...,  0.0109,  0.0110, -0.0261],
       requires_grad=True)
weight_in_for_special_neuron torch.Size([1280])


### Analyze the input and output weights of the MLP in layer 31

In [200]:
# Change the above to a for loop
# words = ['an', ' an', ' picked', ' hello', ' in', ' but', ' cat', ' a', ' had', ' have', ' slog', ' beg', ' al', ' veg', ' soybe', ' cooked', ' hit', ' milkm', ' make', ' watch', ' carton']
words = [' an', ' a', ' cat', ' cooked']
for in_or_out in ['IN', 'OUT']:
    print(f'word dot weight {in_or_out} for special_neuron')
    for word in words:
        token_index = model.tokenizer(word)['input_ids'][0]
        embedding = model.embed.W_E[token_index]
        if in_or_out == 'IN':
            print(f"'{word}'",f"= {torch.dot(embedding, weight_in_for_special_neuron).item():.2f}")
        else:
            print(f"'{word}'",f"= {torch.dot(embedding, weight_out_for_special_neuron).item():.2f}")
    print()

word dot weight IN for special_neuron
' an' = 0.55
' a' = -0.04
' cat' = -0.00
' cooked' = 0.17

word dot weight OUT for special_neuron
' an' = 2.75
' a' = 0.17
' cat' = 0.18
' cooked' = 0.07



### Ablate / Modify the activation of the 892nd neuron in the 31st layer of the MLP

In [169]:

# ablation_prompts = ["After the football match, we return to the studio for post match. Cathode and anode. Analysis"]
ablation_prompts = ["I climbed up the pear tree and picked a pear. I climbed up the apple tree and picked",
                    "I climbed up the pear tree and picked a pear. I climbed up the lemon tree and picked"]
# ablation_prompts = ["Let's understand the cathode and"]
prompt_str_tokens = model.to_str_tokens(ablation_prompts, prepend_bos=True)
print(prompt_str_tokens)
ablation_tokens = model.to_tokens(ablation_prompts, prepend_bos=True).to(device=device)
default_neuron_logits = model(ablation_tokens, return_type="logits")
print_predictions_from_logits(default_neuron_logits, ablation_prompts, show_raw_logits=True)

[['<|endoftext|>', 'I', ' climbed', ' up', ' the', ' pear', ' tree', ' and', ' picked', ' a', ' pear', '.', ' I', ' climbed', ' up', ' the', ' apple', ' tree', ' and', ' picked'], ['<|endoftext|>', 'I', ' climbed', ' up', ' the', ' pear', ' tree', ' and', ' picked', ' a', ' pear', '.', ' I', ' climbed', ' up', ' the', ' lemon', ' tree', ' and', ' picked']]
prompt
I climbed up the pear tree and picked a pear. I climbed up the apple tree and picked
' an' - 64.92% - 20.52
' a' - 24.22% - 19.53
' apples' - 2.78% - 17.37
' two' - 2.43% - 17.23
' another' - 2.07% - 17.07
' some' - 0.71% - 16.00
' three' - 0.40% - 15.44
' up' - 0.31% - 15.17
' the' - 0.25% - 14.97
' one' - 0.24% - 14.91
prompt
I climbed up the pear tree and picked a pear. I climbed up the lemon tree and picked
' a' - 85.11% - 20.22
' an' - 3.18% - 16.93
' two' - 2.52% - 16.70
' some' - 2.05% - 16.50
' another' - 1.68% - 16.29
' up' - 1.19% - 15.95
' three' - 0.62% - 15.31
' lemon' - 0.38% - 14.80
' the' - 0.36% - 14.76
' one'

In [162]:
def zero_neuron_activation(corrupted_residual_component: TT["batch", "pos", "d_mlp"], hook, neuron):
    corrupted_residual_component[:, :, neuron] *= -1
    return corrupted_residual_component

hook_fn = partial(zero_neuron_activation, neuron=892)
model.reset_hooks()
zeroed_neuron_logits = model.run_with_hooks(
    ablation_tokens,
    fwd_hooks = [("blocks.31.mlp.hook_post", hook_fn)],
    return_type="logits"
)
print_predictions_from_logits(zeroed_neuron_logits, ablation_prompts, show_raw_logits=True)

prompt
I climbed up the pear tree and picked a pear. I climbed up the apple tree and picked
' a' - 83.45% - 20.70
' apples' - 3.81% - 17.61
' an' - 2.85% - 17.33
' two' - 2.61% - 17.24
' another' - 2.59% - 17.23
' some' - 0.96% - 16.23
' up' - 0.51% - 15.61
' three' - 0.51% - 15.60
' apple' - 0.32% - 15.13
' one' - 0.31% - 15.12
prompt
I climbed up the pear tree and picked a pear. I climbed up the lemon tree and picked
' a' - 86.36% - 20.26
' two' - 2.44% - 16.69
' an' - 2.28% - 16.63
' some' - 1.97% - 16.48
' another' - 1.61% - 16.28
' up' - 1.17% - 15.96
' three' - 0.61% - 15.31
' lemon' - 0.36% - 14.78
' the' - 0.35% - 14.75
' one' - 0.34% - 14.71


### Modify the weights of the 892nd neuron in the 31st layer of the MLP

In [159]:
modified_model = HookedTransformer.from_pretrained(
    "gpt2-large",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
    device=device,
)

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-large into HookedTransformer


In [160]:
modified_model.blocks[31].mlp.W_in[:, 892] *= -1
print(modified_model.blocks[31].mlp.W_in[:, 892].shape)
print(modified_model.blocks[31].mlp.W_in[:, 892])

torch.Size([1280])
tensor([ 0.0748, -0.0727,  0.0372,  ..., -0.0109, -0.0110,  0.0261],
       requires_grad=True)


In [161]:
default_neuron_logits = modified_model(ablation_tokens, return_type="logits")
print_predictions_from_logits(default_neuron_logits, ablation_prompts, show_raw_logits=True)

prompt
I climbed up the pear tree and picked a pear. I climbed up the apple tree and picked
' a' - 58.31% - 20.24
' an' - 26.18% - 19.44
' apples' - 4.63% - 17.70
' two' - 2.97% - 17.26
' another' - 2.96% - 17.26
' some' - 1.03% - 16.20
' three' - 0.53% - 15.53
' up' - 0.49% - 15.46
' apple' - 0.39% - 15.22
' one' - 0.32% - 15.03
prompt
I climbed up the pear tree and picked a pear. I climbed up the lemon tree and picked
' a' - 85.99% - 20.26
' an' - 2.58% - 16.75
' two' - 2.46% - 16.70
' some' - 2.00% - 16.49
' another' - 1.63% - 16.29
' up' - 1.18% - 15.96
' three' - 0.61% - 15.31
' lemon' - 0.35% - 14.77
' the' - 0.35% - 14.76
' one' - 0.34% - 14.71


### Dot product of the residual and the input weights to the 892nd neuron in the 31st layer of the MLP

In [208]:
def resid_dot_special_neuron_input_weights(resid, per_prompt=False):
    # Only the final logits are relevant for the answer
    print('resid', resid.shape)
    print('weight_in_for_special_neuron', weight_in_for_special_neuron.shape)
    dot_prod = einsum("b d, d -> b", resid, weight_in_for_special_neuron)
    print('dot prod', dot_prod.shape)
    if per_prompt:
        return dot_prod
    else:
        return dot_prod.mean()

In [210]:
final_residual_stream = cache["resid_post", -1]
final_token_residual_stream = final_residual_stream[:, -1, :]
print("Final token stream shape:", final_token_residual_stream.shape)
# Apply LayerNorm scaling
# pos_slice is the subset of the positions we take - here the final token of each prompt
scaled_final_token_residual_stream = cache.apply_ln_to_stack(final_token_residual_stream, layer = -1, pos_slice=-1)
scaled_final_resid_dot_special_neuron_input_weights = resid_dot_special_neuron_input_weights(scaled_final_token_residual_stream, per_prompt=True)
print("Scaled final token stream dot product with special neuron input weights:", scaled_final_resid_dot_special_neuron_input_weights)

# average_logit_diff = einsum("batch d_model, batch d_model -> ", scaled_final_token_residual_stream, logit_diff_directions)/len(prompts)
# print("Calculated average logit diff:", average_logit_diff.item())
# print("Original logit difference:",original_average_logit_diff.item())

Final token stream shape: torch.Size([2, 1280])
resid torch.Size([2, 1280])
weight_in_for_special_neuron torch.Size([1280])
dot prod torch.Size([2])
Scaled final token stream dot product with special neuron input weights: tensor([ 5.8166, -0.0824])


In [304]:
def residual_stack_dot_special_neuron_input_weights(residual_stack: TT["components", "batch", "d_model"], cache: ActivationCache, pos_slice) -> float:
    scaled_residual_stack = cache.apply_ln_to_stack(residual_stack, layer = -1, pos_slice=pos_slice)
    dot_prod = einsum("... batch d_model, d_model -> ...", scaled_residual_stack, weight_in_for_special_neuron)/len(corrupted_long_prompts)
    return dot_prod

In [317]:
# accumulated_residual, labels = original_long_cache.accumulated_resid(layer=-1, incl_mid=True, pos_slice=-1, return_labels=True)
# logit_lens_logit_diffs = residual_stack_dot_special_neuron_input_weights(accumulated_residual, original_long_cache, pos_slice=-1)
# line(logit_lens_logit_diffs, x=np.arange(model.cfg.n_layers*2+1)/2, hover_name=labels, title="Dot Product of Special Neuron Input Weights with Accumulated Residual Stream (Final Token)")

# accumulated_residual, labels = original_long_cache.accumulated_resid(layer=-1, incl_mid=True, pos_slice=-5, return_labels=True)
# logit_lens_logit_diffs = residual_stack_dot_special_neuron_input_weights(accumulated_residual, original_long_cache, pos_slice=-5)
# line(logit_lens_logit_diffs, x=np.arange(model.cfg.n_layers*2+1)/2, hover_name=labels, title="Dot Product of Special Neuron Input Weights with Accumulated Residual Stream (Tree type token)")

# accumulated_residual, labels = original_long_cache.accumulated_resid(layer=-1, incl_mid=True, pos_slice=-4, return_labels=True)
# logit_lens_logit_diffs = residual_stack_dot_special_neuron_input_weights(accumulated_residual, original_long_cache, pos_slice=-4)
# line(logit_lens_logit_diffs, x=np.arange(model.cfg.n_layers*2+1)/2, hover_name=labels, title="Dot Product of Special Neuron Input Weights with Accumulated Residual Stream (Tree type token)")

# accumulated_residual, labels = original_long_cache.accumulated_resid(layer=-1, incl_mid=True, pos_slice=-3, return_labels=True)
# logit_lens_logit_diffs = residual_stack_dot_special_neuron_input_weights(accumulated_residual, original_long_cache, pos_slice=-3)
# line(logit_lens_logit_diffs, x=np.arange(model.cfg.n_layers*2+1)/2, hover_name=labels, title="Dot Product of Special Neuron Input Weights with Accumulated Residual Stream (Tree type token)")

# Convert the above to a loop that plots on one graph
logit_lens_logit_diffs = []
labels = None

prompt_position_labels = [f"{tok}_{i}" for i, tok in enumerate(model.to_str_tokens(original_long_prompts[0]))]
for pos_slice in range(len(prompt_position_labels)):
    accumulated_residual, labels = original_long_cache.accumulated_resid(layer=-1, incl_mid=True, pos_slice=pos_slice, return_labels=True)
    logit_lens_logit_diffs.append(residual_stack_dot_special_neuron_input_weights(accumulated_residual, original_long_cache, pos_slice=pos_slice))

import pandas as pd
pd.options.plotting.backend = "plotly"
lines = torch.stack(logit_lens_logit_diffs).T
lines_np = lines.numpy()
lines_df = pd.DataFrame(lines_np)
# Label the columns with prompt positions
# lines_df.index = prompt_position_labels
lines_df.index = labels
lines_df.columns = prompt_position_labels
print(lines_df)
# lines_df.set_index('Group')
lines_df.plot()
# Add the labels as a column
# lines_df['layers'] = labels
# print('df', lines_df)
# Plot each token in a different with the x axis being the layer and the y axis being the dot product
# px.line(lines_df, line_group=lines_df.index, x=labels)

# px.line(lines_df, x=np.arange(model.cfg.n_layers*2+1)/2, hover_name=labels, title="Dot Product of Special Neuron Input Weights with Accumulated Residual Stream")
# line(logit_lens_logit_diffs, x=np.arange(model.cfg.n_layers*2+1)/2, hover_name=labels, title="Dot Product of Special Neuron Input Weights with Accumulated Residual Stream")

            <|endoftext|>_0       I_1   climbed_2      up_3     the_4  \
0_pre             -0.022101 -0.004150   -0.004878 -0.002987 -0.005285   
0_mid             -0.012796  0.005146    0.014915  0.002726  0.006215   
1_pre             -0.065868  0.023181    0.018384  0.005726  0.016961   
1_mid             -0.036645  0.018500    0.013130  0.005238 -0.006929   
2_pre             -0.239429  0.007985    0.036243 -0.005935 -0.022950   
...                     ...       ...         ...       ...       ...   
34_pre            -1.532895 -1.091031   -0.870601 -0.485370 -1.779105   
34_mid            -1.351125 -1.089431   -0.938128 -0.566494 -1.859525   
35_pre            -1.149317 -1.379076   -0.973520 -0.549864 -1.741600   
35_mid            -0.901680 -1.355920   -0.900106 -0.506206 -1.714589   
final_post        -1.218864 -1.599632   -1.084077 -0.441529 -1.496685   

              pear_5    tree_6     and_7   picked_8       a_9   pear_10  \
0_pre       0.025879  0.001566 -0.010109   0.023

In [219]:
per_layer_residual, labels = cache.decompose_resid(layer=-1, pos_slice=-1, return_labels=True)
per_layer_logit_diffs = residual_stack_dot_special_neuron_input_weights(per_layer_residual, cache)
line(per_layer_logit_diffs, hover_name=labels, title="Dot Product of Special Neuron Input Weights with Residual Stream - Difference From Each Layer")

scaled residual stack torch.Size([74, 2, 1280])
weight_in_for_special_neuron torch.Size([1280])
dot prod torch.Size([74])


In [221]:
per_head_residual, labels = cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)
per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, cache)
per_head_logit_diffs = einops.rearrange(per_head_logit_diffs, "(layer head_index) -> layer head_index", layer=model.cfg.n_layers, head_index=model.cfg.n_heads)
imshow(per_head_logit_diffs.T, labels={"x":"Layer", "y":"Head"}, title="Dot Product of Special Neuron Input Weights with Residual Stream - Difference From Each Head")