# We Found An Neuron

<a target="_blank" href="https://colab.research.google.com/github/UFO-101/an-neuron/blob/main/an_neuron_investigation.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

## Setup

In [1]:
!nvidia-smi -L

GPU 0: NVIDIA A100-SXM4-40GB (UUID: GPU-e3d110c9-ecca-3497-0e60-151e7f746be5)


In [None]:
import plotly.io as pio
try:
    import google.colab
    print("Running as a Colab notebook")
    pio.renderers.default = "colab"
    %pip install transformer-lens fancy-einsum
    %pip install -U kaleido
except:
    print("Running as a Jupyter notebook")
    pio.renderers.default = "vscode"
    from IPython import get_ipython
    ipython = get_ipython()

In [4]:
import torch
from fancy_einsum import einsum
from transformer_lens import HookedTransformer, HookedTransformerConfig, utils, ActivationCache
from torchtyping import TensorType as TT
import plotly.express as px
import numpy as np
import einops
from typing import List, Union, Optional
from functools import partial
import pandas as pd
from pathlib import Path
import urllib.request
from bs4 import BeautifulSoup
from tqdm import tqdm
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false" # https://stackoverflow.com/q/62691279
torch.set_grad_enabled(False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

## Discovering the Neuron

Let's load GPT-2 Large. We use GPT-2 Large because it is frustratingly hard to find prompts that have ' an' as the next predicted token with smaller models; or maybe they just don't have the capability to do so.

In [5]:
model = HookedTransformer.from_pretrained(
    "gpt2-large",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
    device=device,
)

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-large into HookedTransformer


### Choosing the prompt
Even with GPT-2 Large, we have to use an IOI-like prompt `I climbed up the pear tree and picked a pear` in order to induce the model to predict ` an`. Without this sentence, GPT-2 Large would predict `<picked... up>`

In [6]:
an_tok, a_tok = model.to_single_token(" an"), model.to_single_token(" a")
prompts = ["I climbed up the pear tree and picked a pear. I climbed up the apple tree and picked"]
utils.test_prompt(prompts[0], " an", model, prepend_bos=True, top_k=5)

Tokenized prompt: ['<|endoftext|>', 'I', ' climbed', ' up', ' the', ' pear', ' tree', ' and', ' picked', ' a', ' pear', '.', ' I', ' climbed', ' up', ' the', ' apple', ' tree', ' and', ' picked']
Tokenized answer: [' an']


Top 0th token. Logit: 20.52 Prob: 64.92% Token: | an|
Top 1th token. Logit: 19.53 Prob: 24.22% Token: | a|
Top 2th token. Logit: 17.37 Prob:  2.78% Token: | apples|
Top 3th token. Logit: 17.23 Prob:  2.43% Token: | two|
Top 4th token. Logit: 17.07 Prob:  2.07% Token: | another|


In [7]:
def ave_correct_incorrect_logit_diff(logits, correct_tok, incorrect_tok, per_prompt=False):
    """Returns the logit difference between the correct and incorrect answer tokens."""
    final_token_logits = logits[:, -1, :] # Only the final logits are relevant for the answer
    answer_logits = final_token_logits.gather(dim=-1, index=torch.tensor([[correct_tok, incorrect_tok]], device=device))
    answer_logit_diff = answer_logits[:, 0] - answer_logits[:, 1]
    return answer_logit_diff if per_prompt else answer_logit_diff.mean()

tokens = model.to_tokens(prompts, prepend_bos=True).to(device=device)
original_logits, cache = model.run_with_cache(tokens)
original_average_logit_diff = ave_correct_incorrect_logit_diff(original_logits, correct_tok=an_tok, incorrect_tok=a_tok)
print("' an' / ' a' logit difference:", original_average_logit_diff.item())

' an' / ' a' logit difference: 0.9860363006591797


### Logit Lens

In [8]:
answer_residual_directions = model.tokens_to_residual_directions(torch.tensor([(an_tok, a_tok)]))
logit_diff_directions = answer_residual_directions[:, 0] - answer_residual_directions[:, 1]

def residual_stack_to_logit_diff(residual_stack: TT["components", "batch", "d_model"], cache: ActivationCache) -> float:
    scaled_residual_stack = cache.apply_ln_to_stack(residual_stack, layer = -1, pos_slice=-1)
    return einsum("... batch d_model, batch d_model -> ...", scaled_residual_stack, logit_diff_directions)/len(prompts)

accumulated_residual, labels = cache.accumulated_resid(layer=-1, incl_mid=True, pos_slice=-1, return_labels=True)
logit_lens_logit_diffs = residual_stack_to_logit_diff(accumulated_residual, cache)
logit_lens_fig = px.line(y=logit_lens_logit_diffs.cpu(), x=np.arange(model.cfg.n_layers*2+1)/2, hover_name=labels,
               title="Logit Difference From Residual Stream", labels={"x": "Layer", "y": "Logit Difference"})
logit_lens_fig.add_annotation(x=31.5, y=original_average_logit_diff.item(), text="Logit Difference spikes to 1.09",
                   showarrow=True, arrowhead=1, ax=-30, ay=-40)

### Activation Patching by the Layer

In [9]:
corrupted_prompts = ["I climbed up the pear tree and picked a pear. I climbed up the lemon tree and picked"]
corrupted_tokens = model.to_tokens(corrupted_prompts, prepend_bos=True)
corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens, return_type="logits")
corrupted_average_logit_diff = ave_correct_incorrect_logit_diff(corrupted_logits, correct_tok=an_tok, incorrect_tok=a_tok)
print("Corrupted Average Logit Diff", corrupted_average_logit_diff)
print("Clean Average Logit Diff", original_average_logit_diff)

Corrupted Average Logit Diff tensor(-3.2884, device='cuda:0')
Clean Average Logit Diff tensor(0.9860, device='cuda:0')


In [10]:
def patch_resid(corrupted_resid: TT["batch", "pos", "d_model"], hook, pos, clean_cache):
    corrupted_resid[:, pos, :] = clean_cache[hook.name][:, pos, :]
    return corrupted_resid

def normalize_patched_logit_diff(patched_logit_diff):
    # Subtract corrupted logit diff to measure the improvement,
    # divide by the total improvement from clean to corrupted to normalise.
    # 0 means zero change, negative means actively made worse,
    # 1 means totally recovered clean performance, >1 means actively *improved* on clean performance
    return (patched_logit_diff - corrupted_average_logit_diff)/(original_average_logit_diff - corrupted_average_logit_diff)

patched_residual_stream_diff = torch.zeros(model.cfg.n_layers, tokens.shape[1], dtype=torch.float32, device=device)
for layer in tqdm(range(model.cfg.n_layers)):
    for position in range(tokens.shape[1]):
        hook_fn = partial(patch_resid, pos=position, clean_cache=cache)
        patched_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks = [(utils.get_act_name("resid_pre", layer), hook_fn)],
            return_type="logits"
        )
        patched_logit_diff = ave_correct_incorrect_logit_diff(patched_logits, correct_tok=an_tok, incorrect_tok=a_tok)
        patched_residual_stream_diff[layer, position] = normalize_patched_logit_diff(patched_logit_diff)

100%|██████████| 36/36 [00:51<00:00,  1.42s/it]


In [11]:
def imshow_fig(tensor, renderer=None, **kwargs):
    return px.imshow(tensor.cpu(), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", **kwargs)

prompt_token_strs =  [f"{i}. '{tok}'" for i, tok in enumerate(model.to_str_tokens(tokens[0]))]
print(len(prompt_token_strs))
print(patched_residual_stream_diff.shape)
patched_resid_fig = imshow_fig(patched_residual_stream_diff, x=prompt_token_strs,
           title="Logit Difference From Patched Residual Stream", labels={"x":"Token", "y":"Layer"})
patched_resid_fig.show()


20
torch.Size([36, 20])


In [12]:
patched_attn_diff = torch.zeros(model.cfg.n_layers, tokens.shape[1], dtype=torch.float32, device=device)
patched_mlp_diff = torch.zeros(model.cfg.n_layers, tokens.shape[1], dtype=torch.float32, device=device)
for layer in tqdm(range(model.cfg.n_layers)):
    for position in range(tokens.shape[1]):
        hook_fn = partial(patch_resid, pos=position, clean_cache=cache)
        patched_attn_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks = [(utils.get_act_name("attn_out", layer), hook_fn)], 
            return_type="logits"
        )
        patched_attn_logit_diff = ave_correct_incorrect_logit_diff(patched_attn_logits, correct_tok=an_tok, incorrect_tok=a_tok)
        patched_mlp_logits = model.run_with_hooks(
            corrupted_tokens, 
            fwd_hooks = [(utils.get_act_name("mlp_out", layer), hook_fn)], 
            return_type="logits"
        )
        patched_mlp_logit_diff = ave_correct_incorrect_logit_diff(patched_mlp_logits, correct_tok=an_tok, incorrect_tok=a_tok)

        patched_attn_diff[layer, position] = normalize_patched_logit_diff(patched_attn_logit_diff)
        patched_mlp_diff[layer, position] = normalize_patched_logit_diff(patched_mlp_logit_diff)

100%|██████████| 36/36 [01:41<00:00,  2.82s/it]


In [13]:
patched_attn_fig = imshow_fig(patched_attn_diff, x=prompt_token_strs, title="Logit Difference From Patched Attention Layer",
                              labels={"x":"Token", "y":"Layer"})
patched_attn_fig.add_annotation(x=18, y=26, text="Logit Difference spikes to 1.09", showarrow=True, arrowhead=1, ax=-150, ay=-0)
patched_attn_fig.show()

In [14]:
patched_mlp_fig = imshow_fig(patched_mlp_diff, x=prompt_token_strs, title="Logit Difference From Patched MLP Layer",
                             labels={"x":"Token", "y":"Layer"})
patched_mlp_fig.add_annotation(x=15, y=0, text="Significant Logit Diff. for Layer 0 MLP", showarrow=True, arrowhead=1, ax=-150, ay=-10)
patched_mlp_fig.add_annotation(x=18, y=31, text="Significant Logit Diff. for Layer 31 MLP", showarrow=True, arrowhead=1, ax=-150, ay=0)
patched_mlp_fig.show()

## Finding 1: For the apple tree prompt, there is a single neuron that significantly contributes to the “ an” prediction.

In [15]:
def patch_neuron_activation(corrupted_mlp_act: TT["batch", "pos", "d_mlp"], hook, neuron, clean_cache):
    corrupted_mlp_act[:, :, neuron] = clean_cache[hook.name][:, :, neuron]
    return corrupted_mlp_act

patched_neurons_normalized_improvement = torch.zeros(model.cfg.d_mlp, device=device, dtype=torch.float32)
from tqdm import tqdm
for neuron in tqdm(range(model.cfg.d_mlp)):
    hook_fn = partial(patch_neuron_activation, neuron=neuron, clean_cache=cache)
    patched_neuron_logits = model.run_with_hooks(
        corrupted_tokens,
        fwd_hooks = [("blocks.31.mlp.hook_post", hook_fn)],
        return_type="logits"
    )
    patched_neuron_logit_diff = ave_correct_incorrect_logit_diff(patched_neuron_logits, correct_tok=an_tok, incorrect_tok=a_tok)
    patched_neurons_normalized_improvement[neuron] = normalize_patched_logit_diff(patched_neuron_logit_diff)

100%|██████████| 5120/5120 [06:07<00:00, 13.95it/s]


In [16]:
patched_neuron_fig = px.scatter(y=patched_neurons_normalized_improvement.cpu(),
        x=list(range(len(patched_neurons_normalized_improvement))), 
        title="Logit Difference From Patched Neurons in MLP Layer 31", 
        labels={"x":"Neuron", "y":"Patch Improvement"},
        )
patched_neuron_fig.add_annotation(x=1000, y=0.485, text="Neuron 892 stands out", showarrow=True, arrowhead=1, ax=50, ay=40)
patched_neuron_fig.show()

## Finding 2:  In other prompts, the neuron’s activation correlates with the “ an” token being predicted

Choose which token and neuron to investigate. Come back to this again at the end and try some intersting neurons from the final chart.

In [58]:
neuron_layer, neuron_index = 31, 892 # 28, 1921
token_of_interest_str = " an" # " though"
token_of_interest = model.to_single_token(token_of_interest_str)

We load a book as a series of prompts to gather lots of activation data for the neuron.

In [59]:
book_url = "https://github.com/asdkant/bookify-mol/releases/download/c108/Mother.of.Learning.-.nobody103.Domagoj.Kurmaic.html"
book_dir = Path('text_data')
book_dir.mkdir(exist_ok=True)
book_file = book_dir / "Mother_of_Learning.html"
if not book_file.exists():
    print("Downloading book...")
    urllib.request.urlretrieve(book_url, book_file)
soup = BeautifulSoup(book_file.read_text(), 'html.parser')
book_tokens = model.to_tokens(soup.get_text(), prepend_bos=False, truncate=False).to(device=device)

one_less_n_ctx = model.cfg.n_ctx - 1
truncated_length = (book_tokens.shape[1] // one_less_n_ctx) * one_less_n_ctx # Truncate to a multiple of n_ctx - 1
book_tokens = book_tokens[:, :truncated_length].view(-1, one_less_n_ctx)
bos_tokens = torch.full((book_tokens.shape[0], 1), 50256, dtype=torch.long, device=device)
book_tokens = torch.cat([bos_tokens, book_tokens], dim=1) # Add BOS (beginning of sequence) token to each prompt
token_of_interest_count = book_tokens.eq(token_of_interest).sum().item()
print(f"Instances of '{token_of_interest_str}' in book:", token_of_interest_count)
print(f"First few tokens:", model.to_str_tokens(book_tokens[0][:11]))

Instances of ' an' in book: 2251
First few tokens: ['<|endoftext|>', '\n', 'Arc', ' 1', '\n', 'Chapter', ' 1', ':', ' Good', ' Morning', ' Brother']


Run prompts through the model recording the neuron activations and the logit different between our token of interest and the top other token.

In [60]:
neuron_activation_cache = []
def save_neuron_activation(residual_component: TT["batch", "pos", "d_mlp"], hook, neuron):
    global neuron_activation_cache
    acts_to_save = residual_component[:, :, neuron]
    neuron_activation_cache.append(acts_to_save.flatten())

logit_diff_cache = []
def save_logit_diff(residual_component: TT["batch", "pos", "d_model"], hook):
    global logit_diff_cache
    output_logits = torch.einsum("vm, btm -> btv", model.embed.W_E, residual_component)
    logit_of_interest = output_logits[:, :, token_of_interest].clone()
    output_logits[:, :, token_of_interest] = -1e9
    max_other_logit = torch.max(output_logits, dim=-1).values
    logit_diff_cache.append((logit_of_interest - max_other_logit).flatten())

save_neuron_act_hook_fn = partial(save_neuron_activation, neuron=neuron_index)
model.reset_hooks()
for i in tqdm(range(book_tokens.shape[0])): # It's too big to run in one batch
    torch.cuda.empty_cache()
    model.run_with_hooks(
        book_tokens[i:i+1],
        fwd_hooks = [(f"blocks.{neuron_layer}.mlp.hook_post", save_neuron_act_hook_fn),
                     ("ln_final.hook_normalized", save_logit_diff)],
        return_type=None
    )

logit_diff = torch.cat(logit_diff_cache).cpu()
neuron_activations = torch.cat(neuron_activation_cache).cpu()
input_tokens = model.to_str_tokens(book_tokens.flatten(), prepend_bos=False)

100%|██████████| 1005/1005 [02:24<00:00,  6.94it/s]


In [61]:
activation_col = f"Layer {neuron_layer} Neuron {neuron_index} activation"
top_pred_col = "Top prediction"
logit_diff_col = f'"{token_of_interest_str}" Logit - Max Other Logit'

token_logit_neuron_act_df = pd.DataFrame({
    activation_col: neuron_activations,
    'input_tokens': input_tokens,
    top_pred_col: logit_diff > 0,
    logit_diff_col: logit_diff})
token_logit_neuron_act_df = token_logit_neuron_act_df[token_logit_neuron_act_df[activation_col] > 0.0]

logit_activation_fig = px.scatter(
    token_logit_neuron_act_df,
    x=activation_col,
    y=logit_diff_col,
    hover_name='input_tokens',
    color=top_pred_col,
    color_discrete_sequence=["blue", "red"],
    title=f'Layer {neuron_layer} Neuron {neuron_index} activations (>=0) vs. "{token_of_interest_str}" Logit subtract Top Other Logit')
logit_activation_fig.add_annotation(
        x=7, y=-13, showarrow=False,
        text=f"Top predictions: {(logit_diff > 0).sum().item()}<br>Actual occurances: {token_of_interest_count}",
        font=dict(size=14, color="#ffffff"), align="left",
        bordercolor="#c7c7c7", borderwidth=2, borderpad=4, bgcolor="grey"
    )
logit_activation_fig.show()

In [21]:
vocab_strs = [f"'{tok[:10]}'" for i, tok in enumerate(model.to_str_tokens(torch.arange(model.cfg.d_vocab)))]

weight_out_for_special_neuron = model.blocks[neuron_layer].mlp.W_out[neuron_index]
weight_out_affect_on_logits = weight_out_for_special_neuron @ model.unembed.W_U
neuron_dot_vocab_fig = px.scatter(x=vocab_strs, 
        y=weight_out_affect_on_logits.cpu(), 
        labels={"x":"Token", "y":"Dot Product"},
        hover_name=vocab_strs,
        title="Token Embeddings Dot Output Weights of Layer 31 MLP Neuron 892",
        )
sorted_weights = weight_out_affect_on_logits.sort(descending=True)
for i, (index, val) in enumerate(list(zip(sorted_weights.indices, sorted_weights.values))[:7]):
        neuron_dot_vocab_fig.add_annotation(x=index, y=val, text=vocab_strs[index], showarrow=True, ax=-4, ay=-9)
neuron_dot_vocab_fig.show()

## Finding 3: There are other similar neurons that predict for semantically similar tokens

In [22]:
mlp_output_weights = [block.mlp.W_out for block in model.blocks]
mlp_output_weights = torch.cat(mlp_output_weights, dim=0) # (n_layer * d_mlp, d_model)
embedding, neuron_weights = model.embed.W_E.clone().cpu(), mlp_output_weights.clone().cpu() # Too big for GPU memory
weight_similarity = torch.einsum("tk, nk -> tn", embedding, neuron_weights) # (n_tokens, n_layers * d_mlp)
top_2_weights = torch.topk(weight_similarity, 2, dim=1) # (n_tokens, 2)
print('top_2_weights', top_2_weights.indices.shape)
del weight_similarity # Free up some memory

top_2_weights torch.Size([50257, 2])


In [23]:
layer_indices = top_2_weights.indices // model.cfg.d_mlp
neuron_indices = top_2_weights.indices % model.cfg.d_mlp
top_2_weight_diff = top_2_weights.values[:, 0] - top_2_weights.values[:, 1]
hover_labels = [
    f"Layer: {layer_indices[i, 0]}, Neuron: {neuron_indices[i, 0]} - Layer: {layer_indices[i, 1]}, Neuron: {neuron_indices[i, 1]}"
    for i in range(len(top_2_weight_diff))
]

most_similar_neurons_fig = px.scatter(x=vocab_strs, y=top_2_weight_diff, hover_name=hover_labels,
           labels={'x': 'Token', 'y': 'Top 2 Dot Product Difference'},
           title="Difference between Top 2 Most Similar Neuron Output Weights for each Token Embedding")
sorted_weight_diffs = top_2_weight_diff.sort(descending=True)
for i, (index, val) in enumerate(list(zip(sorted_weight_diffs.indices, sorted_weight_diffs.values))[:21:3]):
        most_similar_neurons_fig.add_annotation(x=index, y=val, text=vocab_strs[index], showarrow=True,
                                                ax=110 * ((i + 1) % 2) - 55, ay=-12)
most_similar_neurons_fig.update_layout(yaxis = {"dtick": 1})
most_similar_neurons_fig.show()

### Save figures

In [62]:
all_figs = {
    "logit_lens_fig": logit_lens_fig,
    "patched_resid_fig": patched_resid_fig,
    "patched_attn_fig": patched_attn_fig,
    "patched_mlp_fig": patched_mlp_fig,
    "patched_neuron_fig": patched_neuron_fig,
    f"token_{token_of_interest_str.strip()}_layer_{neuron_layer}_neuron_{neuron_index}_logit_activation_fig": logit_activation_fig,
    f"token_{token_of_interest_str.strip()}_layer_{neuron_layer}_neuron_{neuron_index}_neuron_dot_vocab_fig": neuron_dot_vocab_fig,
    "most_similar_neurons_fig": most_similar_neurons_fig
}

fig_dir = Path('figures')
png_fig_dir, svg_fig_dir, html_fig_dir = fig_dir / "png", fig_dir / "svg", fig_dir / "html"
png_fig_dir.mkdir(parents=True, exist_ok=True)
svg_fig_dir.mkdir(parents=True, exist_ok=True)
html_fig_dir.mkdir(parents=True, exist_ok=True)

 # If this doesn't work, restart the runtime to ensure kaleido is installed :(
for name, fig in all_figs.items():
    fig.write_image(png_fig_dir / f"{name}.png", width=1048, height=512, scale=4)
    fig.write_image(svg_fig_dir / f"{name}.svg", scale=4)
    fig.write_html(html_fig_dir / f"{name}.html", include_plotlyjs="cdn")