# We Found An Neuron

## Setup

In [1]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = True
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    # Install another version of node that makes PySvelte work way faster
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    %pip install git+https://github.com/neelnanda-io/PySvelte.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

Running as a Jupyter notebook - intended for development only!


  ipython.magic("load_ext autoreload")
  ipython.magic("autoreload 2")


In [2]:
# Choose the Plotly renderer for the current environment (Colab vs VSCode/Jupyter).
import plotly.io as pio

# Because of rendering quirks, a figure shows up in Colab OR in VSCode depending on the renderer,
# which is a pain when developing demos. DEBUG_MODE=False forces the Colab renderer everywhere.
use_colab_renderer = IN_COLAB or not DEBUG_MODE
pio.renderers.default = "colab" if use_colab_renderer else "vscode"

In [3]:
import torch
from fancy_einsum import einsum
from transformer_lens import HookedTransformer, HookedTransformerConfig, utils, ActivationCache
from torchtyping import TensorType as TT
import plotly.express as px
import numpy as np
import einops
from typing import List, Union, Optional
import pysvelte
from IPython.display import HTML
from functools import partial
import pandas as pd


IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html



In [4]:
# Inference/analysis only: switch autograd off globally (this also permits the in-place weight edits later on).
torch.set_grad_enabled(mode=False)

<torch.autograd.grad_mode.set_grad_enabled at 0x1074fcf40>

In [80]:
def imshow(tensor, renderer=None, **kwargs):
    """Show a zero-centred red/blue heatmap of `tensor` using the given Plotly renderer."""
    imshow_fig(tensor, **kwargs).show(renderer)

def imshow_fig(tensor, renderer=None, **kwargs):
    """Build (without showing) a zero-centred red/blue heatmap of `tensor`.

    `renderer` is unused; it is accepted only for signature parity with `imshow`.
    Extra kwargs are forwarded to `plotly.express.imshow`.
    """
    return px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", **kwargs)

def line(tensor, renderer=None, **kwargs):
    """Show a line plot of `tensor` (as y values) using the given Plotly renderer."""
    line_fig(tensor, **kwargs).show(renderer)

def line_fig(tensor, renderer=None, **kwargs):
    """Build (without showing) a line plot of `tensor`; `renderer` is unused (signature parity)."""
    return px.line(y=utils.to_numpy(tensor), **kwargs)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Show a scatter plot of y against x with the given axis/colour labels."""
    scatter_fig(x, y, xaxis=xaxis, yaxis=yaxis, caxis=caxis, **kwargs).show(renderer)

def scatter_fig(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Build (without showing) a scatter plot of y against x; `renderer` is unused (signature parity)."""
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    return px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs)

In [6]:
# Quick smoke test that the Plotly renderer chosen above actually displays a figure.
# NOTE(review): not used by the analysis; safe to delete once rendering is confirmed.
line(np.arange(5))

In [7]:
# Prefer the GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cpu')

## Introduction

Let's set up GPT-2 Large. We use GPT-2 Large because, with smaller models, it is frustratingly hard to find prompts for which ' an' is the predicted next token — possibly because smaller models lack the capability altogether.

In [8]:
# GPT-2 Large, loaded with the same weight-processing options used by every model in this notebook.
MODEL_NAME = "gpt2-large"
model = HookedTransformer.from_pretrained(
    MODEL_NAME,
    device=device,
    fold_ln=True,
    center_writing_weights=True,
    center_unembed=True,
    refactor_factored_attn_matrices=True,
)

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-large into HookedTransformer


Even with GPT-2 Large, we have to prefix an IOI-like sentence, `I climbed up the pear tree and picked a pear.`, to induce the model to predict ` an`. Without this sentence, GPT-2 Large predicts ` up` (as in "picked ... up").

In [9]:
import os
# Without the first (pear) sentence, the model continues with ' up':
#   "I climbed up the apple tree and picked"  -> ' up'
example_answer = " an"
example_prompt = (
    "I climbed up the pear tree and picked a pear. "
    "I climbed up the apple tree and picked"
)  # top prediction: ' an'
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'I', ' climbed', ' up', ' the', ' pear', ' tree', ' and', ' picked', ' a', ' pear', '.', ' I', ' climbed', ' up', ' the', ' apple', ' tree', ' and', ' picked']
Tokenized answer: [' an']


Top 0th token. Logit: 20.52 Prob: 64.92% Token: | an|
Top 1th token. Logit: 19.53 Prob: 24.22% Token: | a|
Top 2th token. Logit: 17.37 Prob:  2.78% Token: | apples|
Top 3th token. Logit: 17.23 Prob:  2.43% Token: | two|
Top 4th token. Logit: 17.07 Prob:  2.07% Token: | another|
Top 5th token. Logit: 16.00 Prob:  0.71% Token: | some|
Top 6th token. Logit: 15.44 Prob:  0.40% Token: | three|
Top 7th token. Logit: 15.17 Prob:  0.31% Token: | up|
Top 8th token. Logit: 14.97 Prob:  0.25% Token: | the|
Top 9th token. Logit: 14.91 Prob:  0.24% Token: | one|


In [86]:
# Exploratory check: the token id of ' though'. Nothing else in this notebook uses it.
model.to_single_token(' though')

996

In [10]:
prompts = ["I climbed up the pear tree and picked a pear. I climbed up the apple tree and picked",]
answers = [ (" an", " a"), ]
# One (correct, wrong) token-id pair per prompt -> shape [batch, 2].
# Created on `device` so it can index logits computed on the same device (e.g. via gather);
# a CPU index tensor against CUDA logits would raise.
answer_tokens = torch.tensor(
    [
        (model.to_single_token(correct_answer), model.to_single_token(wrong_answer))
        for correct_answer, wrong_answer in answers
    ],
    device=device,
)

print(prompts)
print(answers)
print(answer_tokens)

['I climbed up the pear tree and picked a pear. I climbed up the apple tree and picked']
[(' an', ' a')]
tensor([[281, 257]])


In [11]:
def print_predictions_from_logits(logits_to_print, prediction_prompts, show_raw_logits=False, top_k=5, prepend_bos=True):
    """Print the top-k next-token predictions for each prompt in a batch of logits.

    For each prompt, the logits at that prompt's final token position are read, and the
    top `top_k` tokens are printed with their softmax probability (and optionally raw logit).

    Args:
        logits_to_print: logits indexed as [batch, position, vocab], in the same order as the prompts.
        prediction_prompts: the prompt strings that produced `logits_to_print`.
        show_raw_logits: also print each token's raw logit.
        top_k: number of predictions to print per prompt.
        prepend_bos: must match how the prompts were tokenized to produce the logits; it is used
            both for the printed tokens and for locating the final position.
    """
    for i, prompt in enumerate(prediction_prompts):
        print(f'prompt: {model.to_str_tokens(prompt, prepend_bos=prepend_bos)}')
        # Tokenize with the same BOS setting, otherwise the final index is off by one.
        prompt_token_index = model.to_tokens(prompt, prepend_bos=prepend_bos).shape[1] - 1
        print('prompt_token_index:', prompt_token_index)
        final_logits = logits_to_print[i, prompt_token_index]
        top_probs, top_token_ids = final_logits.softmax(dim=-1).topk(top_k)
        for prob, token_id in zip(top_probs, top_token_ids):
            print(f"'{model.tokenizer.decode(token_id.item())}'", '-', f"{prob.item() * 100:.2f}%", end='')
            if show_raw_logits:
                # Read the logit of this exact token, so logit and probability always describe the same token.
                print(' -', f"{final_logits[token_id].item():.2f}", end='')
            print()

In [12]:
# Clean run: logits plus a full activation cache, reused by the attribution and patching cells below.
tokens = model.to_tokens(prompts, prepend_bos=True)
tokens = tokens.to(device=device)
original_logits, cache = model.run_with_cache(tokens)
print_predictions_from_logits(
    original_logits,
    prompts,
    show_raw_logits=True,
)

prompt: ['<|endoftext|>', 'I', ' climbed', ' up', ' the', ' pear', ' tree', ' and', ' picked', ' a', ' pear', '.', ' I', ' climbed', ' up', ' the', ' apple', ' tree', ' and', ' picked']
prompt_token_index: 19
' an' - 64.92% - 20.52
' a' - 24.22% - 19.53
' apples' - 2.78% - 17.37
' two' - 2.43% - 17.23
' another' - 2.07% - 17.07


In [13]:
def logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=False):
    """Logit of the correct answer minus logit of the wrong answer, at the last position.

    Args:
        logits: [batch, pos, vocab] logits.
        answer_tokens: [batch, 2] token ids; column 0 is the correct answer, column 1 the wrong one.
        per_prompt: if True return the [batch] differences, otherwise their mean.
    """
    # The answer is predicted from the final position only.
    last_pos_logits = logits[:, -1, :]
    picked = last_pos_logits.gather(dim=-1, index=answer_tokens)
    diffs = picked[:, 0] - picked[:, 1]
    return diffs if per_prompt else diffs.mean()

per_prompt_logit_diff = logits_to_ave_logit_diff(original_logits, answer_tokens, per_prompt=True)
print("Per prompt logit difference:", per_prompt_logit_diff)
# Baseline used to normalise every patching experiment below.
original_average_logit_diff = logits_to_ave_logit_diff(original_logits, answer_tokens)
print("Average logit difference:", original_average_logit_diff.item())

Per prompt logit difference: tensor([0.9860])
Average logit difference: 0.9860420227050781


## Direct Logit Attribution

In [14]:
# Unembedding vectors of each (correct, wrong) answer token: [batch, 2, d_model].
answer_residual_directions = model.tokens_to_residual_directions(answer_tokens)
print("Answer residual directions shape:", answer_residual_directions.shape)
correct_directions = answer_residual_directions[:, 0]
wrong_directions = answer_residual_directions[:, 1]
# Dotting the residual stream with this direction gives the (correct - wrong) logit difference.
logit_diff_directions = correct_directions - wrong_directions
print("Logit difference directions shape:", logit_diff_directions.shape)

Answer residual directions shape: torch.Size([1, 2, 1280])
Logit difference directions shape: torch.Size([1, 1280])


In [15]:
# cache syntax - resid_post is the residual stream at the end of the layer, -1 gets the final layer. The general syntax is [activation_name, layer_index, sub_layer_type].
final_residual_stream = cache["resid_post", -1]
print("Final residual stream shape:", final_residual_stream.shape)
# Only the final position predicts the answer.
final_token_residual_stream = final_residual_stream[:, -1, :]
# Apply the final LayerNorm scaling; pos_slice=-1 says the stack holds only each prompt's final position.
scaled_final_token_residual_stream = cache.apply_ln_to_stack(
    final_token_residual_stream, layer=-1, pos_slice=-1
)
# Project onto the (correct - wrong) direction, summed over the batch, then averaged.
summed_logit_diff = einsum(
    "batch d_model, batch d_model -> ",
    scaled_final_token_residual_stream,
    logit_diff_directions,
)
average_logit_diff = summed_logit_diff / len(prompts)
print("Calculated average logit diff:", average_logit_diff.item())
print("Original logit difference:", original_average_logit_diff.item())

Final residual stream shape: torch.Size([1, 20, 1280])
Calculated average logit diff: 1.25741708278656
Original logit difference: 0.9860420227050781


## Logit Lens

In [16]:
def residual_stack_to_logit_diff(
    residual_stack: TT["components", "batch", "d_model"],
    cache: ActivationCache,
    directions: Optional[TT["batch", "d_model"]] = None,
) -> TT["components"]:
    """Average (correct - wrong) logit difference contributed by each component of a residual stack.

    Args:
        residual_stack: stacked residual components at the final position, [components, batch, d_model].
        cache: activation cache used to apply the final LayerNorm scaling.
        directions: per-prompt logit-difference directions [batch, d_model]. Defaults to the
            notebook-level `logit_diff_directions`.

    Returns:
        A tensor with one averaged logit difference per component (not a Python float).
    """
    if directions is None:
        directions = logit_diff_directions
    scaled_residual_stack = cache.apply_ln_to_stack(residual_stack, layer = -1, pos_slice=-1)
    # Average over the stack's own batch dimension rather than a global prompt list.
    batch_size = scaled_residual_stack.shape[-2]
    return einsum("... batch d_model, batch d_model -> ...", scaled_residual_stack, directions) / batch_size

In [62]:
accumulated_residual, labels = cache.accumulated_resid(layer=-1, incl_mid=True, pos_slice=-1, return_labels=True)
logit_lens_logit_diffs = residual_stack_to_logit_diff(accumulated_residual, cache)
# Consecutive entries are half a layer apart; derive the axis from the data so lengths always match.
layer_axis = np.arange(len(logit_lens_logit_diffs)) / 2
fig = line_fig(logit_lens_logit_diffs, x=layer_axis, hover_name=labels, title="Logit Difference From Accumulated Residual Stream", labels={"x": "Layer", "y": "Logit Difference"})
# Annotate the actual value at the spike instead of a hard-coded number at an unrelated y.
spike_layer = 31.5
spike_value = logit_lens_logit_diffs[int(spike_layer * 2)].item()
fig.add_annotation(x=spike_layer, y=spike_value, text=f"Logit Difference spikes to {spike_value:.2f}", showarrow=True, arrowhead=1, ax=-30, ay=-40)

In [18]:
# Split the final residual stream into each layer's direct contribution, then score each piece.
per_layer_residual, labels = cache.decompose_resid(layer=-1, pos_slice=-1, return_labels=True)
per_layer_logit_diffs = residual_stack_to_logit_diff(per_layer_residual, cache)
line(
    per_layer_logit_diffs,
    hover_name=labels,
    title="Logit Difference From Each Layer",
)

## Activation Patching by the Layer

In [19]:
# Corrupted run: 'apple' -> 'lemon', so the logit diff flips towards ' a'.
corrupted_prompts = [
    "I climbed up the pear tree and picked a pear. "
    "I climbed up the lemon tree and picked",
]
corrupted_tokens = model.to_tokens(corrupted_prompts, prepend_bos=True)
corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens, return_type="logits")
corrupted_average_logit_diff = logits_to_ave_logit_diff(corrupted_logits, answer_tokens)
for caption, value in (
    ("Corrupted Average Logit Diff", corrupted_average_logit_diff),
    ("Clean Average Logit Diff", original_average_logit_diff),
):
    print(caption, value)
print(corrupted_prompts)

Corrupted Average Logit Diff tensor(-3.2884)
Clean Average Logit Diff tensor(0.9860)
['I climbed up the pear tree and picked a pear. I climbed up the lemon tree and picked']


In [20]:
def patch_residual_component(
    corrupted_residual_component: TT["batch", "pos", "d_model"],
    hook,
    pos,
    clean_cache):
    """Hook: overwrite the activation at position `pos` with the clean run's value at the same hook."""
    clean_activation = clean_cache[hook.name]
    corrupted_residual_component[:, pos, :] = clean_activation[:, pos, :]
    return corrupted_residual_component

def normalize_patched_logit_diff(patched_logit_diff):
    """Fraction of the corrupted->clean logit-diff gap recovered by a patch.

    0 means no change, negative means the patch made things worse, 1 means clean performance
    is fully recovered, and >1 means the patch *improved* on clean performance.
    """
    recovered = patched_logit_diff - corrupted_average_logit_diff
    full_gap = original_average_logit_diff - corrupted_average_logit_diff
    return recovered / full_gap

n_positions = tokens.shape[1]
patched_residual_stream_diff = torch.zeros(model.cfg.n_layers, n_positions, dtype=torch.float32) #device="cuda") # TODO: Re-enable CUDA on colab
# One corrupted forward pass per (layer, position), each patching a single clean resid_pre slice.
for layer in range(model.cfg.n_layers):
    resid_pre_name = utils.get_act_name("resid_pre", layer)
    for position in range(n_positions):
        hook_fn = partial(patch_residual_component, pos=position, clean_cache=cache)
        patched_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(resid_pre_name, hook_fn)],
            return_type="logits",
        )
        patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)
        patched_residual_stream_diff[layer, position] = normalize_patched_logit_diff(patched_logit_diff)

In [64]:
clean_str_tokens = model.to_str_tokens(tokens[0])
# "<token>_<index>" keeps the labels unique when a token repeats in the prompt.
prompt_position_labels = [f"{tok}_{i}" for i, tok in enumerate(clean_str_tokens)]
imshow(
    patched_residual_stream_diff,
    x=prompt_position_labels,
    title="Logit Difference From Patched Residual Stream",
    labels={"x": "Position", "y": "Layer"},
)

In [22]:
patched_attn_diff = torch.zeros(model.cfg.n_layers, tokens.shape[1], dtype=torch.float32)#, device="cuda") # TODO: Re-enable CUDA on colab
patched_mlp_diff = torch.zeros(model.cfg.n_layers, tokens.shape[1], dtype=torch.float32)#, device="cuda") # TODO: Re-enable CUDA on colab

def _patched_normalized_diff(act_name, layer_idx, patch_hook):
    """Run the corrupted prompt with one activation patched from the clean cache; return the normalized logit diff."""
    logits = model.run_with_hooks(
        corrupted_tokens,
        fwd_hooks=[(utils.get_act_name(act_name, layer_idx), patch_hook)],
        return_type="logits",
    )
    return normalize_patched_logit_diff(logits_to_ave_logit_diff(logits, answer_tokens))

for layer in range(model.cfg.n_layers):
    for position in range(tokens.shape[1]):
        # Same clean-cache patch hook for both sublayer outputs at this (layer, position).
        hook_fn = partial(patch_residual_component, pos=position, clean_cache=cache)
        patched_attn_diff[layer, position] = _patched_normalized_diff("attn_out", layer, hook_fn)
        patched_mlp_diff[layer, position] = _patched_normalized_diff("mlp_out", layer, hook_fn)

In [67]:
# Heatmap of how much patching each attention layer's output (per position) recovers the clean logit diff.
fig = imshow_fig(patched_attn_diff, x=prompt_position_labels, title="Logit Difference From Patched Attention Layer", labels={"x":"Position", "y":"Layer"})
fig

NameError: name 'imshow_fig' is not defined

In [79]:
fig = imshow_fig(
    patched_mlp_diff,
    x=prompt_position_labels,
    title="Logit Difference From Patched MLP Layer",
    labels={"x": "Position", "y": "Layer"},
)
fig.add_annotation(
    x=18,
    y=31,
    text="Significant Logit Diff. for Layer 31 MLP",
    showarrow=True,
    arrowhead=1,
    ax=-150,
    ay=0,
)

## Activation Patching by the Neuron

In [41]:
def patch_neuron_activation(
    corrupted_residual_component: TT["batch", "pos", "d_mlp"],
    hook,
    neuron,
    clean_cache):
    """Hook: replace one neuron's activation, at every position, with its value from the clean run."""
    clean_values = clean_cache[hook.name][:, :, neuron]
    corrupted_residual_component[:, :, neuron] = clean_values
    return corrupted_residual_component

patched_neurons_normalized_improvement = torch.zeros(model.cfg.d_mlp, device=device, dtype=torch.float32)
layer = 31
max_neurons = 10000  # cap on neurons to patch (for quicker runs); any value >= d_mlp patches every neuron
# Derive the hook name from `layer`, so changing `layer` actually changes which MLP gets patched.
mlp_post_hook_name = f"blocks.{layer}.mlp.hook_post"
for neuron in range(min(max_neurons, model.cfg.d_mlp)):
    hook_fn = partial(patch_neuron_activation, neuron=neuron, clean_cache=cache)
    patched_neuron_logits = model.run_with_hooks(
        corrupted_tokens,
        fwd_hooks = [(mlp_post_hook_name, hook_fn)],
        return_type="logits"
    )
    patched_neuron_logit_diff = logits_to_ave_logit_diff(patched_neuron_logits, answer_tokens)

    patched_neurons_normalized_improvement[neuron] = normalize_patched_logit_diff(patched_neuron_logit_diff)

In [85]:

special_neuron = 892
plotted_improvement = patched_neurons_normalized_improvement[:max_neurons]
fig = scatter_fig(y=plotted_improvement,
        # x must match y's length, so index the plotted slice rather than the full tensor.
        x=list(range(len(plotted_improvement))),
        title="Logit Difference From Patched Neurons in MLP Layer 31",
        xaxis="neuron",
        yaxis="Patch Improvement",
        )

# Point the arrow at the neuron's actual (index, value) instead of a hand-placed approximation.
fig.add_annotation(x=special_neuron, y=plotted_improvement[special_neuron].item(), text="Neuron 892 stands out", showarrow=True, arrowhead=1, ax=50, ay=40)

In [44]:
# Parameters of M31N892: its row of W_out (what it writes to the residual stream) and its
# column of W_in (what it reads from it); both have d_model entries.
weight_out_for_special_neuron = model.blocks[31].mlp.W_out[892]
weight_in_for_special_neuron = model.blocks[31].mlp.W_in[:, 892]
print('weight_out_for_special_neuron', weight_out_for_special_neuron)
print('weight_out_for_special_neuron shape', weight_out_for_special_neuron.shape)

weight_out_for_special_neuron tensor([-0.2095,  0.1758, -0.0947,  ...,  0.0269,  0.0166,  0.0753],
       requires_grad=True)
weight_out_for_special_neuron torch.Size([1280])


## Visualising the Neuron's output weights (projected onto the unembedding)

In [45]:
# Project the neuron's output weights through the unembedding: one value per vocabulary token,
# showing which next-token logits the neuron pushes up or down when it fires.
weight_out_for_special_neuron = model.blocks[31].mlp.W_out[892]
# Separate name, so the position labels (`prompt_position_labels`) used by earlier plots are not clobbered.
vocab_labels = [f"{tok}_{i}" for i, tok in enumerate(model.to_str_tokens(torch.arange(model.cfg.d_vocab)))]
weight_out_effect_on_logits = weight_out_for_special_neuron @ model.unembed.W_U
scatter(x=vocab_labels,
        y=weight_out_effect_on_logits,
        hover_name=vocab_labels,
        title="Output Weights @ Unembed Weights of Neuron 892 in MLP Layer 31",
        )

## Ablate / Modify the activation of neuron 892 in the layer-31 MLP (M31N892)

In [32]:

# Baseline for the ablation experiments: the unmodified model's predictions on the clean prompt.
ablation_prompts = [
    "I climbed up the pear tree and picked a pear. "
    "I climbed up the apple tree and picked",
]
prompt_str_tokens = model.to_str_tokens(ablation_prompts, prepend_bos=True)
print(prompt_str_tokens)
ablation_tokens = model.to_tokens(ablation_prompts, prepend_bos=True)
ablation_tokens = ablation_tokens.to(device=device)
default_neuron_logits = model(ablation_tokens, return_type="logits")
print_predictions_from_logits(default_neuron_logits, ablation_prompts, show_raw_logits=True)

[['<|endoftext|>', 'I', ' climbed', ' up', ' the', ' pear', ' tree', ' and', ' picked', ' a', ' pear', '.', ' I', ' climbed', ' up', ' the', ' apple', ' tree', ' and', ' picked']]
prompt: ['<|endoftext|>', 'I', ' climbed', ' up', ' the', ' pear', ' tree', ' and', ' picked', ' a', ' pear', '.', ' I', ' climbed', ' up', ' the', ' apple', ' tree', ' and', ' picked']
prompt_token_index: 19
' an' - 64.92% - 20.52
' a' - 24.22% - 19.53
' apples' - 2.78% - 17.37
' two' - 2.43% - 17.23
' another' - 2.07% - 17.07


In [33]:
def negate_neuron_activation(corrupted_residual_component: TT["batch", "pos", "d_mlp"], hook, neuron):
    """Hook: flip the sign of one neuron's activation at every position.

    This multiplies the activation by -1; it does *not* zero-ablate it. The predictions
    printed below are those of the sign-flipped neuron.

    Args:
        corrupted_residual_component: activation at the hook point, indexed [batch, pos, neuron].
        hook: the hook point (unused, but part of the hook call signature).
        neuron: index of the neuron to negate.

    Returns:
        The same tensor, modified in place.
    """
    corrupted_residual_component[:, :, neuron] *= -1
    return corrupted_residual_component

hook_fn = partial(negate_neuron_activation, neuron=892)
model.reset_hooks()
negated_neuron_logits = model.run_with_hooks(
    ablation_tokens,
    fwd_hooks = [("blocks.31.mlp.hook_post", hook_fn)],
    return_type="logits"
)
print_predictions_from_logits(negated_neuron_logits, ablation_prompts, show_raw_logits=True)

prompt: ['<|endoftext|>', 'I', ' climbed', ' up', ' the', ' pear', ' tree', ' and', ' picked', ' a', ' pear', '.', ' I', ' climbed', ' up', ' the', ' apple', ' tree', ' and', ' picked']
prompt_token_index: 19
' a' - 83.45% - 20.70
' apples' - 3.81% - 17.61
' an' - 2.85% - 17.33
' two' - 2.61% - 17.24
' another' - 2.59% - 17.23


## Ablating M31N892 by setting its input weights to zero

In [34]:
# A second, independently loaded copy whose weights can be edited without touching `model`.
modified_model_kwargs = dict(
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
    device=device,
)
modified_model = HookedTransformer.from_pretrained("gpt2-large", **modified_model_kwargs)

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-large into HookedTransformer


In [35]:
# Zero-ablate M31N892 via its weights. Zeroing only the W_in column leaves the pre-activation
# equal to the bias b_in[892], so the neuron would still output the constant act(b_in[892]).
# Zeroing the bias as well makes the pre-activation exactly 0 (GELU-family activations map 0 to 0).
modified_model.blocks[31].mlp.W_in[:, 892] *= 0
modified_model.blocks[31].mlp.b_in[892] *= 0
print(modified_model.blocks[31].mlp.W_in[:, 892].shape)
print(modified_model.blocks[31].mlp.W_in[:, 892])

torch.Size([1280])
tensor([-0., 0., -0.,  ..., 0., 0., -0.], requires_grad=True)


In [None]:
# Same prompt through `modified_model` (M31N892's input weights zeroed).
ablation_prompts = ["I climbed up the pear tree and picked a pear. I climbed up the apple tree and picked",]
ablation_tokens = model.to_tokens(ablation_prompts, prepend_bos=True).to(device=device)
# Separate name, so the unmodified baseline stored in `default_neuron_logits` is not overwritten.
modified_model_logits = modified_model(ablation_tokens, return_type="logits")
print_predictions_from_logits(modified_model_logits, ablation_prompts, show_raw_logits=True)

## Moving the neuron to another layer

In [None]:
# Scratch model: each iteration of the next cell resets it from `model` and transplants the neuron.
copy_neuron_model_kwargs = dict(
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
    device=device,
)
copy_neuron_model = HookedTransformer.from_pretrained("gpt2-large", **copy_neuron_model_kwargs)

In [None]:
def zero_neuron_activation(corrupted_residual_component: TT["batch", "pos", "d_mlp"], hook, neuron):
    """Hook: zero-ablate one neuron's activation at every position.

    Assigns 0 (rather than multiplying by 0) so the result is exactly zero even for inf/NaN inputs.
    """
    corrupted_residual_component[:, :, neuron] = 0
    return corrupted_residual_component

special_layer = 31         # layer M31N892 lives in
special_neuron = 892       # its index within that layer's MLP
neuron_to_overwrite = 892  # slot in the destination layer that receives the transplanted neuron
tokens_to_plot = [" an", " a", " apples", " two", " another", " some", " up", " three", " one"]
token_ids_to_plot = [copy_neuron_model.to_single_token(tok) for tok in tokens_to_plot]

# The neuron's full parameter set, read from the unmodified model. The input bias must move together
# with the weights; otherwise the transplanted neuron would use the destination slot's own b_in.
source_mlp = model.blocks[special_layer].mlp
special_w_in = source_mlp.W_in[:, special_neuron].clone()
special_b_in = source_mlp.b_in[special_neuron].clone()
special_w_out = source_mlp.W_out[special_neuron, :].clone()

# Control hooks: zero the original neuron, and the destination slot in layer i.
ablate_source_hook_fn = partial(zero_neuron_activation, neuron=special_neuron)
ablate_dest_hook_fn = partial(zero_neuron_activation, neuron=neuron_to_overwrite)

moved_neuron_logits_list = []
ablate_both_logits_list = []
n_layers = copy_neuron_model.cfg.n_layers
for i in range(n_layers):
    print(f"{i}/{n_layers}")
    # Start every iteration from the unmodified weights.
    copy_neuron_model.load_state_dict(model.state_dict())
    if i != special_layer:
        # Transplant the neuron into layer i ...
        dest_mlp = copy_neuron_model.blocks[i].mlp
        dest_mlp.W_in[:, neuron_to_overwrite] = special_w_in
        dest_mlp.b_in[neuron_to_overwrite] = special_b_in
        dest_mlp.W_out[neuron_to_overwrite, :] = special_w_out
        # ... and silence the original: with its input weights AND bias zeroed, its pre-activation is exactly 0.
        orig_mlp = copy_neuron_model.blocks[special_layer].mlp
        orig_mlp.W_in[:, special_neuron] = 0
        orig_mlp.b_in[special_neuron] = 0
    copy_neuron_logits = copy_neuron_model(ablation_tokens, return_type="logits")
    # .item() -> plain floats, so the result lists build numeric DataFrames downstream.
    moved_neuron_logits_list.append([copy_neuron_logits[0, -1, tok_id].item() for tok_id in token_ids_to_plot])

    # Control: in the unmodified model, zero-ablate both the original neuron and the slot it was copied into.
    ablate_both_logits = model.run_with_hooks(
        ablation_tokens,
        fwd_hooks=[
            (f"blocks.{special_layer}.mlp.hook_post", ablate_source_hook_fn),
            (f"blocks.{i}.mlp.hook_post", ablate_dest_hook_fn),
        ],
        return_type="logits",
    )
    ablate_both_logits_list.append([ablate_both_logits[0, -1, tok_id].item() for tok_id in token_ids_to_plot])


In [None]:
pd.options.plotting.backend = "plotly"

# One wide frame: one row per destination layer, one column per (token, experiment).
# axis=1 puts the two experiments side by side on the shared layer index. axis=0 would instead
# stack them into NaN-padded rows with duplicated indices.
results = pd.concat(
    [
        pd.DataFrame(logit_rows, columns=[f"{tkn} - {experiment}" for tkn in tokens_to_plot])
        for logit_rows, experiment in [
            (moved_neuron_logits_list, "moved neuron"),
            (ablate_both_logits_list, "ablate both"),
        ]
    ],
    axis=1,
).astype(float)  # also converts any 0-d tensors to plain floats
results.index.name = "layer"

# Build a fresh frame rather than assigning into a column slice (avoids SettingWithCopyWarning).
a_an_results = pd.DataFrame({
    "moved logit(an)-logit(a)": results[" an - moved neuron"] - results[" a - moved neuron"],
    "control logit(an)-logit(a)": results[" an - ablate both"] - results[" a - ablate both"],
})

fig = a_an_results.plot()
fig.add_vline(x=31)  # the neuron's original layer
fig.show()