<a href="https://colab.research.google.com/github/wlg100/numseqcont_circuit_expms/blob/main/notebook_templates/MLP_expms_template.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" align="left"/></a>&nbsp;Open this notebook in Colab, or run it in a local Jupyter notebook.

## Setup

In [None]:
# Shell command: list the NVIDIA GPU(s) visible to this runtime, to confirm which accelerator is attached.
!nvidia-smi -L

GPU 0: Tesla T4 (UUID: GPU-26d25ecd-0070-7ad2-e038-a2be15a3b0be)


In [None]:
import plotly.io as pio
try:
    import google.colab
    print("Running as a Colab notebook")
    pio.renderers.default = "colab"
    %pip install transformer-lens fancy-einsum
    %pip install -U kaleido # kaleido only works if you restart the runtime. Required to write figures to disk (final cell)
except:
    print("Running as a Jupyter notebook")
    pio.renderers.default = "vscode"
    from IPython import get_ipython
    ipython = get_ipython()

Running as a Colab notebook
Collecting transformer-lens
  Downloading transformer_lens-1.3.0-py3-none-any.whl (101 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m101.8/101.8 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting fancy-einsum
  Downloading fancy_einsum-0.0.3-py3-none-any.whl (6.2 kB)
Collecting datasets>=2.7.1 (from transformer-lens)
  Downloading datasets-2.13.1-py3-none-any.whl (486 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m486.2/486.2 kB[0m [31m14.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting einops>=0.6.0 (from transformer-lens)
  Downloading einops-0.6.1-py3-none-any.whl (42 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting jaxtyping>=0.2.11 (from transformer-lens)
  Downloading jaxtyping-0.2.20-py3-none-any.whl (24 kB)
Collecting numpy>=1.23 (from transformer-lens)
  Downloading numpy-1.25.1-cp310-cp310-manylinux_2_17_

Collecting kaleido
  Downloading kaleido-0.2.1-py2.py3-none-manylinux1_x86_64.whl (79.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.9/79.9 MB[0m [31m6.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: kaleido
Successfully installed kaleido-0.2.1


In [None]:
!pip install 'torchtyping'

Collecting torchtyping
  Downloading torchtyping-0.1.4-py3-none-any.whl (17 kB)
Installing collected packages: torchtyping
Successfully installed torchtyping-0.1.4


In [None]:
import torch
from fancy_einsum import einsum
from transformer_lens import HookedTransformer, HookedTransformerConfig, utils, ActivationCache
from torchtyping import TensorType as TT
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
import einops
from typing import List, Union, Optional
from functools import partial
import pandas as pd
from pathlib import Path
import urllib.request
from bs4 import BeautifulSoup
from tqdm import tqdm
from datasets import load_dataset
import os
# Silence the HuggingFace tokenizers fork/parallelism warning; see https://stackoverflow.com/q/62691279
os.environ.update(TOKENIZERS_PARALLELISM="false")
# Only forward passes are run in this notebook, so autograd is switched off globally.
torch.set_grad_enabled(False)
# Prefer the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cuda')

# Choose Model

In [None]:
# Load GPT-2 small as a TransformerLens HookedTransformer on the selected device.
# The boolean keyword flags select TransformerLens' weight-processing options.
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    device=device,
    fold_ln=True,
    center_writing_weights=True,
    center_unembed=True,
    refactor_factored_attn_matrices=True,
)

Downloading (…)lve/main/config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


# Change Prompt Inputs Here

In [None]:
# Clean prompt(s): the expected completion contrasts Mary's height with John's.
prompts = [
    "John is tall. Mary is",
]

# Corrupted prompt(s): identical except for the adjective describing John.
corrupted_prompts = [
    "John is short. Mary is",
]

# One (correct, incorrect) answer pair per prompt; each answer must be a single token.
answers = [
    (" short", " tall"),
]

In [None]:
# Token ids for each (correct, incorrect) answer pair: one row per prompt, shape (n_prompts, 2).
# Created directly on `device` (instead of calling `.cuda()`) so the notebook also runs on CPU.
answer_tokens = torch.tensor(
    [
        (model.to_single_token(correct), model.to_single_token(incorrect))
        for correct, incorrect in answers
    ],
    device=device,
)

In [None]:
# For each clean prompt, print the model's top-3 next-token predictions alongside the expected answer.
for prompt_idx, prompt in enumerate(prompts):
    correct_answer = answers[prompt_idx][0]
    utils.test_prompt(prompt, correct_answer, model, prepend_bos=True, top_k=3)

Tokenized prompt: ['<|endoftext|>', 'John', ' is', ' tall', '.', ' Mary', ' is']
Tokenized answer: [' short']


Top 0th token. Logit: 16.89 Prob: 21.23% Token: | short|
Top 1th token. Logit: 16.05 Prob:  9.19% Token: | thin|
Top 2th token. Logit: 15.89 Prob:  7.78% Token: | tall|


# Discovering the Neuron



In [None]:
# Tokenize the clean prompt(s) and cache every activation from a clean forward pass.
# `.to(device)` (rather than `.cuda()`) keeps this working on CPU-only runtimes.
tokens = model.to_tokens(prompts, prepend_bos=True).to(device)
original_logits, cache = model.run_with_cache(tokens)

def ave_correct_incorrect_logit_diff(logits, answer_tokens, per_prompt=False):
    """Logit of the correct answer minus the incorrect answer, at the final position.

    Args:
        logits: model output indexed as [batch, pos, vocab]; only the last position is used.
        answer_tokens: integer tensor [batch, 2] holding (correct, incorrect) token ids.
        per_prompt: when True, return one difference per prompt instead of their mean.

    Returns:
        A [batch] tensor if ``per_prompt`` is True, otherwise a 0-d tensor (the mean).
    """
    last_position_logits = logits[:, -1, :]
    picked = torch.gather(last_position_logits, -1, answer_tokens)
    diffs = picked[:, 0] - picked[:, 1]
    return diffs if per_prompt else diffs.mean()

# Clean-run baseline: positive values mean the model prefers the correct answer.
original_average_logit_diff = ave_correct_incorrect_logit_diff(original_logits, answer_tokens)
print("Per prompt logit difference:", ave_correct_incorrect_logit_diff(original_logits, answer_tokens, per_prompt=True))
print("Average logit difference:", original_average_logit_diff.item())

Per prompt logit difference: tensor([1.0032], device='cuda:0')
Average logit difference: 1.0032310485839844


In [None]:
# Residual-stream direction of each answer token (presumably its unembedding vector),
# expected shape (n_prompts, 2, d_model) — TODO confirm.
answer_residual_directions = model.tokens_to_residual_directions(answer_tokens)
# Per-prompt (correct - incorrect) direction; column 0 is the correct answer, column 1 the incorrect one.
logit_diff_directions = torch.sub(answer_residual_directions[:, 0], answer_residual_directions[:, 1])

def residual_stack_to_logit_diff(residual_stack: TT["components", "batch", "d_model"], cache: ActivationCache,
                                 directions: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Project each component of a residual stack onto the logit-difference direction, batch-averaged.

    The stack is first rescaled with ``cache.apply_ln_to_stack(layer=-1, pos_slice=-1)``
    (presumably the final LayerNorm at the last position), then dotted with the per-prompt
    (correct - incorrect) direction and averaged over the batch.

    Args:
        residual_stack: stacked residual contributions, leading ``...`` dims then (batch, d_model).
        cache: activation cache from the run that produced ``residual_stack``.
        directions: optional (batch, d_model) logit-diff directions; defaults to the
            notebook-level ``logit_diff_directions``.

    Returns:
        A tensor with the leading ``...`` dims of ``residual_stack`` (one value per component).
    """
    if directions is None:
        directions = logit_diff_directions
    scaled_residual_stack = cache.apply_ln_to_stack(residual_stack, layer=-1, pos_slice=-1)
    # Average over the batch actually passed in instead of the global `prompts` list,
    # so the result stays correct for any batch size.
    batch_size = scaled_residual_stack.shape[-2]
    return einsum("... batch d_model, batch d_model -> ...", scaled_residual_stack, directions) / batch_size

def imshow_fig(tensor, renderer=None, **kwargs):
    """Build (but do not show) a diverging RdBu heatmap of ``tensor`` centred at zero.

    ``renderer`` is accepted for call compatibility but is not used. Any other keyword
    arguments are forwarded to ``px.imshow``.
    """
    heatmap_style = dict(color_continuous_midpoint=0.0, color_continuous_scale="RdBu")
    return px.imshow(tensor.cpu(), **heatmap_style, **kwargs)

## Activation Patching by Layer

In [None]:
corrupted_tokens = model.to_tokens(corrupted_prompts, prepend_bos=True).to(device)
# The patch hooks copy clean activations into the corrupted run position by position,
# so the clean and corrupted prompts must tokenize to exactly the same shape.
assert corrupted_tokens.shape == tokens.shape, (
    f"clean tokens {tuple(tokens.shape)} and corrupted tokens {tuple(corrupted_tokens.shape)} "
    "must have the same shape for position-wise patching"
)
corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens, return_type="logits")
corrupted_average_logit_diff = ave_correct_incorrect_logit_diff(corrupted_logits, answer_tokens)
print("Corrupted Average Logit Diff", corrupted_average_logit_diff)
print("Clean Average Logit Diff", original_average_logit_diff)

Corrupted Average Logit Diff tensor(0.7222, device='cuda:0')
Clean Average Logit Diff tensor(1.0032, device='cuda:0')


In [None]:
def patch_resid(corrupted_resid: TT["batch", "pos", "d_model"], hook, pos, clean_cache):
    """Forward hook: overwrite position ``pos`` of the corrupted activation with the clean one.

    The clean activation is looked up in ``clean_cache`` under this hook's name. The
    corrupted tensor is modified in place and also returned.
    """
    clean_activation = clean_cache[hook.name]
    corrupted_resid[:, pos, :] = clean_activation[:, pos, :]
    return corrupted_resid

def normalize_patched_logit_diff(patched_logit_diff, clean_logit_diff=None, corrupted_logit_diff=None):
    """Rescale a patched logit diff so that 0 = corrupted baseline and 1 = clean baseline.

    Computes (patched - corrupted) / (clean - corrupted):
      * 0  means patching changed nothing relative to the corrupted run,
      * <0 means patching actively made things worse,
      * 1  means patching fully recovered clean performance,
      * >1 means patching actively *improved* on clean performance.

    Args:
        patched_logit_diff: logit diff (tensor or float) measured on the patched run.
        clean_logit_diff: clean-run baseline; defaults to the notebook-level
            ``original_average_logit_diff``.
        corrupted_logit_diff: corrupted-run baseline; defaults to the notebook-level
            ``corrupted_average_logit_diff``.

    Note: if the two baselines are equal, the denominator is zero.
    """
    if clean_logit_diff is None:
        clean_logit_diff = original_average_logit_diff
    if corrupted_logit_diff is None:
        corrupted_logit_diff = corrupted_average_logit_diff
    # Subtract the corrupted baseline to measure improvement, then divide by the total
    # improvement from corrupted to clean.
    return (patched_logit_diff - corrupted_logit_diff) / (clean_logit_diff - corrupted_logit_diff)

In [None]:
# Patch the clean MLP output into the corrupted run one (layer, position) at a time and record
# the normalized logit-diff recovery in a (n_layers, n_positions) grid.
# Only `mlp_out` is patched: no attention-patching result is consumed by the figure below.
patched_mlp_diff = torch.zeros(model.cfg.n_layers, tokens.shape[1], dtype=torch.float32, device=device)
for layer in tqdm(range(model.cfg.n_layers)):
    mlp_out_hook_name = utils.get_act_name("mlp_out", layer)  # loop-invariant across positions
    for position in range(tokens.shape[1]):
        hook_fn = partial(patch_resid, pos=position, clean_cache=cache)
        patched_mlp_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(mlp_out_hook_name, hook_fn)],
            return_type="logits",
        )
        patched_mlp_logit_diff = ave_correct_incorrect_logit_diff(patched_mlp_logits, answer_tokens)
        patched_mlp_diff[layer, position] = normalize_patched_logit_diff(patched_mlp_logit_diff)

100%|██████████| 12/12 [00:04<00:00,  2.78it/s]


In [None]:
# Column labels for the heatmap. Plotly merges categorical columns whose labels are identical
# (e.g. the two ' is' tokens here), so append i zero-width spaces (chr(8203)) to label i:
# every label becomes unique while rendering identically.
prompt_token_strs = [f"'{tok}'" + chr(8203) * i for i, tok in enumerate(model.to_str_tokens(tokens[0]))]
patched_mlp_fig = imshow_fig(
    patched_mlp_diff,
    x=prompt_token_strs,
    title="Logit Difference From Patched MLP Layer",
    labels={"x": "Token", "y": "Layer"},
)
patched_mlp_fig.show()

# Change MLP Layer Input Here

Pick an MLP layer that stands out in the MLP activation-patching heatmap above, then patch each neuron in that layer individually.

In [None]:
# MLP layer whose neurons are patched individually below (picked from the heatmap above).
layer_to_check = 9
hook_name = "blocks.{}.mlp.hook_post".format(layer_to_check)

In [None]:
# token_to_check
# patched_mlp_fig.add_annotation(x=token_to_check, y=layer_to_check, text="Significant Logit Diff. for MLP Layer"+str(layer_to_check), showarrow=True, arrowhead=1, ax=-150, ay=-10)

# Finding 1: We can discover predictive neurons by activation patching individual neurons

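To run the following, you need:
1. The clean tokens (`tokens`)
2. The clean activation cache (`cache`), from running those tokens through the model
3. The corrupted tokens (`corrupted_tokens`)
4. `normalize_patched_logit_diff` defined (and `hook_name` set for the chosen layer)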

In [None]:
def patch_neuron_activation(corrupted_mlp_act: TT["batch", "pos", "d_mlp"], hook, neuron, clean_cache):
    """Forward hook: copy one neuron's clean activation into the corrupted run.

    For every batch row and every position, entry ``neuron`` of the last axis is overwritten
    in place from ``clean_cache[hook.name]``. All other neurons keep their corrupted values.
    The modified tensor is returned.
    """
    source = clean_cache[hook.name]
    corrupted_mlp_act[:, :, neuron] = source[:, :, neuron]
    return corrupted_mlp_act

# Patch each neuron at `hook_name` on its own and record the normalized logit-diff recovery.
# (`tqdm` is already imported in the imports cell.)
patched_neurons_normalized_improvement = torch.zeros(model.cfg.d_mlp, device=device, dtype=torch.float32)
for neuron in tqdm(range(model.cfg.d_mlp)):
    hook_fn = partial(patch_neuron_activation, neuron=neuron, clean_cache=cache)
    patched_neuron_logits = model.run_with_hooks(
        corrupted_tokens,
        fwd_hooks=[(hook_name, hook_fn)],
        return_type="logits",
    )
    patched_neuron_logit_diff = ave_correct_incorrect_logit_diff(patched_neuron_logits, answer_tokens)
    patched_neurons_normalized_improvement[neuron] = normalize_patched_logit_diff(patched_neuron_logit_diff)

100%|██████████| 3072/3072 [01:17<00:00, 39.56it/s]


In [None]:
# Normalized logit-diff recovery for each neuron of `layer_to_check`, patched one at a time.
patched_neuron_fig = px.scatter(
    x=list(range(len(patched_neurons_normalized_improvement))),
    y=patched_neurons_normalized_improvement.cpu(),
    title=f"Logit Difference From Patched Neurons in MLP Layer {layer_to_check}",
    labels={"x": "Neuron", "y": "Patch Improvement"},
)
patched_neuron_fig.show()

Record the neurons that stand out (largest normalized patching improvement):

In [None]:
num_to_print = 5

# Top-k neurons by normalized patching improvement, largest first. torch.topk avoids a
# Python-level sort that reads thousands of GPU scalars one by one; k is clamped to the tensor size.
top_k = torch.topk(
    patched_neurons_normalized_improvement,
    k=min(num_to_print, patched_neurons_normalized_improvement.numel()),
)
sorted_indices = top_k.indices.tolist()
sorted_values = top_k.values.tolist()

# Print the original neuron indices and their values in two columns.
print("Index\tValue")
for original_index, value in zip(sorted_indices, sorted_values):
    print(f"{original_index}\t{value}")


Index	Value
934	2.4503378868103027
840	0.6991933584213257
2684	0.6940759420394897
1436	0.4804687201976776
513	0.37432342767715454


# Finding 2: The activation of the “_” neuron correlates with the “_” token being predicted.

## Change Tokens and Neurons Input Here

Make sure `tokens_of_interest_strs`, `neuron_layers`, and `neuron_indices` all have the same length: entry *i* of each list describes the same token–neuron pair.

In [None]:
tokens_of_interest_strs = [" short"]
neuron_layers = torch.tensor([9], device=device)
neuron_indices = torch.tensor([934], device=device)
# Later cells pair these lists index by index (token i <-> layer i <-> neuron i), so they must align.
assert len(tokens_of_interest_strs) == len(neuron_layers) == len(neuron_indices), (
    "tokens_of_interest_strs, neuron_layers, and neuron_indices must all have the same length"
)
tokens_of_interest = torch.tensor([model.to_single_token(token_str) for token_str in tokens_of_interest_strs], device=device)

In [None]:
# Load the 10,000-document Pile sample and preview the start of its first document.
# NOTE(review): `dataset` is not used by any later cell in this notebook.
dataset = load_dataset(path="NeelNanda/pile-10k", split="train")
dataset[0]["text"][:150]

Downloading metadata:   0%|          | 0.00/921 [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/373 [00:00<?, ?B/s]

Downloading and preparing dataset None/None (download: 31.72 MiB, generated: 58.43 MiB, post-processed: Unknown size, total: 90.15 MiB) to /root/.cache/huggingface/datasets/NeelNanda___parquet/NeelNanda--pile-10k-72f566e9f7c464ab/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/33.3M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/10000 [00:00<?, ? examples/s]

Dataset parquet downloaded and prepared to /root/.cache/huggingface/datasets/NeelNanda___parquet/NeelNanda--pile-10k-72f566e9f7c464ab/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7. Subsequent calls will reuse this data.


'It is done, and submitted. You can play “Survival of the Tastiest” on Android, and on the web. Playing on the web works, but you have to simulate mult'

## The neuron’s output weights have a high dot-product with which tokens?

In [None]:
# Short quoted label for every vocabulary token (truncated to 10 characters for readability).
vocab_strs = [f"'{tok[:10]}'" for tok in model.to_str_tokens(torch.arange(model.cfg.d_vocab))]

neuron_congruence_with_vocab_figs = []
# One figure per configured (layer, neuron) pair, capped at the first two.
for neuron_layer, neuron_index in list(zip(neuron_layers.tolist(), neuron_indices.tolist()))[:2]:
    weight_out_for_special_neuron = model.blocks[neuron_layer].mlp.W_out[neuron_index]
    # Dot product of the neuron's output weights with every token's embedding row.
    # NOTE(review): this uses the embedding W_E, while the name suggests an effect on logits,
    # which goes through the unembedding W_U — confirm W_E is the intended basis.
    weight_out_affect_on_logits = weight_out_for_special_neuron @ model.embed.W_E.T
    neuron_congruence_with_vocab_fig = px.scatter(
        x=vocab_strs,
        y=weight_out_affect_on_logits.cpu(),
        labels={"x": "Token", "y": "Congruence (W_out • Token)"},
        hover_name=vocab_strs,
        title=f"Layer {neuron_layer} Neuron {neuron_index} Output Congruence for each Token",
    )
    # Label the 7 most congruent tokens. Python ints/floats are passed to plotly instead of
    # (possibly CUDA) tensors, and a distinct loop variable no longer shadows the outer one.
    top_tokens = weight_out_affect_on_logits.topk(7)
    for token_index, value in zip(top_tokens.indices.tolist(), top_tokens.values.tolist()):
        neuron_congruence_with_vocab_fig.add_annotation(
            x=token_index, y=value, text=vocab_strs[token_index], showarrow=True, ax=-4, ay=-9
        )
    neuron_congruence_with_vocab_fig.show()
    neuron_congruence_with_vocab_figs.append(neuron_congruence_with_vocab_fig)

Output hidden; open in https://colab.research.google.com to view.

In [None]:
# Stack every layer's MLP output weights: row (layer * d_mlp + neuron) is that neuron's output vector.
mlp_output_weights = torch.cat([block.mlp.W_out for block in model.blocks], dim=0)  # (n_layer * d_mlp, d_model)

token_congruence_with_each_neuron_figs = []
for i, token_of_interest_str in enumerate(tokens_of_interest_strs[:1]):
    # Dot product of this token's embedding with every neuron's output weights.
    token_of_interest_dot_product = torch.einsum("d, nd -> n", model.embed.W_E[tokens_of_interest[i]], mlp_output_weights)
    # Neuron 0 of each layer is labelled just "Layer L", presumably so the dtick=d_mlp ticks read as layer starts.
    neuron_names = [
        f"Layer {n // model.cfg.d_mlp}" + (f" Neuron {n % model.cfg.d_mlp}" if n % model.cfg.d_mlp != 0 else "")
        for n in range(mlp_output_weights.shape[0])
    ]
    token_congruence_with_each_neuron_fig = px.scatter(
        x=neuron_names,
        y=token_of_interest_dot_product.cpu(),
        labels={"x": "Neuron", "y": "Congruence (W_out • Token)"},
        hover_name=neuron_names,
        title=f"Congruence of '{token_of_interest_str}' Token with each Neuron Output Weights",
    )
    token_congruence_with_each_neuron_fig.update_layout(xaxis={"dtick": model.cfg.d_mlp})
    neuron_layer, neuron_index = int(neuron_layers[i]), int(neuron_indices[i])
    neuron_total_index = neuron_layer * model.cfg.d_mlp + neuron_index
    # The arrow head is drawn at (x, y), so x must be the highlighted neuron's own index.
    token_congruence_with_each_neuron_fig.add_annotation(
        x=neuron_total_index,
        y=token_of_interest_dot_product[neuron_total_index].item(),
        text=f"Layer {neuron_layer} Neuron {neuron_index}",
        showarrow=True,
        ax=-100,
        ay=0,
    )
    token_congruence_with_each_neuron_fig.show()
    token_congruence_with_each_neuron_figs.append(token_congruence_with_each_neuron_fig)

Output hidden; open in https://colab.research.google.com to view.

In [None]:
tokens_of_interest_strs = [" tall"]
neuron_layers = torch.tensor([9], device=device)
neuron_indices = torch.tensor([934], device=device)
# The lists are paired index by index (token i <-> layer i <-> neuron i), so they must align.
assert len(tokens_of_interest_strs) == len(neuron_layers) == len(neuron_indices), (
    "tokens_of_interest_strs, neuron_layers, and neuron_indices must all have the same length"
)
tokens_of_interest = torch.tensor([model.to_single_token(token_str) for token_str in tokens_of_interest_strs], device=device)

# Stack every layer's MLP output weights: row (layer * d_mlp + neuron) is that neuron's output vector.
mlp_output_weights = torch.cat([block.mlp.W_out for block in model.blocks], dim=0)  # (n_layer * d_mlp, d_model)

token_congruence_with_each_neuron_figs = []
for i, token_of_interest_str in enumerate(tokens_of_interest_strs[:1]):
    # Dot product of this token's embedding with every neuron's output weights.
    token_of_interest_dot_product = torch.einsum("d, nd -> n", model.embed.W_E[tokens_of_interest[i]], mlp_output_weights)
    # Neuron 0 of each layer is labelled just "Layer L", presumably so the dtick=d_mlp ticks read as layer starts.
    neuron_names = [
        f"Layer {n // model.cfg.d_mlp}" + (f" Neuron {n % model.cfg.d_mlp}" if n % model.cfg.d_mlp != 0 else "")
        for n in range(mlp_output_weights.shape[0])
    ]
    token_congruence_with_each_neuron_fig = px.scatter(
        x=neuron_names,
        y=token_of_interest_dot_product.cpu(),
        labels={"x": "Neuron", "y": "Congruence (W_out • Token)"},
        hover_name=neuron_names,
        title=f"Congruence of '{token_of_interest_str}' Token with each Neuron Output Weights",
    )
    token_congruence_with_each_neuron_fig.update_layout(xaxis={"dtick": model.cfg.d_mlp})
    neuron_layer, neuron_index = int(neuron_layers[i]), int(neuron_indices[i])
    neuron_total_index = neuron_layer * model.cfg.d_mlp + neuron_index
    # The arrow head is drawn at (x, y), so x must be the highlighted neuron's own index.
    token_congruence_with_each_neuron_fig.add_annotation(
        x=neuron_total_index,
        y=token_of_interest_dot_product[neuron_total_index].item(),
        text=f"Layer {neuron_layer} Neuron {neuron_index}",
        showarrow=True,
        ax=-100,
        ay=0,
    )
    token_congruence_with_each_neuron_fig.show()
    token_congruence_with_each_neuron_figs.append(token_congruence_with_each_neuron_fig)

Output hidden; open in https://colab.research.google.com to view.

# Finding 3: We can use neurons’ output congruence to find specific neurons that predict a token

In [None]:
# embedding, neuron_weights = model.embed.W_E.clone().cpu(), mlp_output_weights.clone().cpu() # Too big for GPU memory
# weight_similarity = torch.einsum("tk, nk -> tn", embedding, neuron_weights) # (n_tokens, n_layers * d_mlp)
# top_2_weights_for_each_token = torch.topk(weight_similarity, 2, dim=1) # (n_tokens, 2)
# print('top_2_weights_for_each_token', top_2_weights_for_each_token.indices.shape)

In [None]:
# layer_indices = top_2_weights_for_each_token.indices // model.cfg.d_mlp
# neuron_indices = top_2_weights_for_each_token.indices % model.cfg.d_mlp
# top_2_weight_diff_for_each_token = top_2_weights_for_each_token.values[:, 0] - top_2_weights_for_each_token.values[:, 1]
# neuron_labels = [
#     f"Layer: {layer_indices[i, 0]}, Neuron: {neuron_indices[i, 0]} - Layer: {layer_indices[i, 1]}, Neuron: {neuron_indices[i, 1]}"
#     for i in range(len(top_2_weight_diff_for_each_token))
# ]

# top_neuron_for_each_token_fig = px.scatter(x=vocab_strs, y=top_2_weights_for_each_token.values[:, 0], hover_name=neuron_labels,
#             labels={'x': 'Token', 'y': 'Congruence'},
#             title="Top Output Congruence for Each Token Embedding")
# sorted_top_dot_products = top_2_weights_for_each_token.values[:, 0].sort(descending=True)
# for i, (index, val) in enumerate(list(zip(sorted_top_dot_products.indices, sorted_top_dot_products.values))[:100]):
#     if index > 1000: # Weird command tokens break plotly
#         top_neuron_for_each_token_fig.add_annotation(x=index, y=val + 1,
#                                                     text=vocab_strs[index], showarrow=False,
#                                                     font=dict(color='darkgreen', size=10),
#                                                     bgcolor='white', borderwidth=1, borderpad=1, bordercolor='lightgrey',
#                                                 )
# top_neuron_for_each_token_fig.show()

# top_2_neuron_diff_for_each_token_fig = px.scatter(x=vocab_strs, y=top_2_weight_diff_for_each_token, hover_name=neuron_labels,
#            labels={'x': 'Token', 'y': 'Top 2 Congruence Difference'},
#            title="Difference between 2 Most Congruent Neurons for each Token")
# sorted_weight_diffs = top_2_weight_diff_for_each_token.sort(descending=True)
# for i, (index, val) in enumerate(list(zip(sorted_weight_diffs.indices, sorted_weight_diffs.values))[:21:3]):
#         top_2_neuron_diff_for_each_token_fig.add_annotation(x=index, y=val, text=vocab_strs[index], showarrow=True,
#                                                 ax=90 * ((i + 1) % 2) - 45, ay=-12)
# top_2_neuron_diff_for_each_token_fig.update_layout(yaxis = {"dtick": 1})
# top_2_neuron_diff_for_each_token_fig.show()

In [None]:
# top_2_tokens_for_each_neuron = torch.topk(weight_similarity, 2, dim=0) # (n_layers * d_mlp, 2)
# print('top_2_tokens', top_2_tokens_for_each_neuron.indices.shape)
# del weight_similarity

In [None]:
# top_2_token_diff_for_each_neuron = top_2_tokens_for_each_neuron.values[0, :] - top_2_tokens_for_each_neuron.values[1, :]
# token_labels = [
#     f"'{vocab_strs[top_2_tokens_for_each_neuron.indices[0, i]]}' - '{vocab_strs[top_2_tokens_for_each_neuron.indices[1, i]]}'"
#     for i in range(len(top_2_token_diff_for_each_neuron))
# ]

# top_2_token_diff_for_each_neuron_fig = px.scatter(x=neuron_names, y=top_2_token_diff_for_each_neuron, hover_name=token_labels,
#            labels={'x': 'Neuron', 'y': 'Top 2 Most Congruent Tokens Difference'},
#            title="Difference between Top 2 Most Congruent Tokens for each Neuron")
# sorted_weight_diffs = top_2_token_diff_for_each_neuron.sort(descending=True)
# top_2_token_diff_for_each_neuron_fig.update_layout(xaxis={"dtick": model.cfg.d_mlp})
# for i, (index, val) in enumerate(list(zip(sorted_weight_diffs.indices, sorted_weight_diffs.values))[:13]):
#         top_2_token_diff_for_each_neuron_fig.add_annotation(x=index, y=val, text=token_labels[index], showarrow=True,
#                                                             font=dict(size=10), ax=100 * (i % 2) - 50, ay=0)
# top_2_token_diff_for_each_neuron_fig.update_layout(yaxis = {"dtick": 1})
# top_2_token_diff_for_each_neuron_fig.show()

### Finding a cleanly associated neuron

In [None]:
# top_neuron_for_each_token = top_2_weights_for_each_token.indices[:, 0].flatten()
# top_token_for_each_neuron = top_2_tokens_for_each_neuron.indices[0, :].flatten()

# top_neuron_for_each_neurons_top_token = top_neuron_for_each_token[top_token_for_each_neuron]
# top_neuron_is_top_token = top_neuron_for_each_neurons_top_token.eq(torch.arange(0, top_neuron_for_each_neurons_top_token.shape[0]))

# mutual_monotokenic_neurons = top_neuron_is_top_token.nonzero().flatten()
# mutual_mononeuronic_tokens = top_token_for_each_neuron[mutual_monotokenic_neurons]

# mutual_exclusive_congruence = top_2_token_diff_for_each_neuron[mutual_monotokenic_neurons] * top_2_weight_diff_for_each_token[mutual_mononeuronic_tokens]
# # Sort by monotokenicity score
# monotokenicity_scores_sorted, monotokenicity_scores_sorted_indices = mutual_exclusive_congruence.sort(descending=True)
# mutual_mononeuronic_tokens = mutual_mononeuronic_tokens[monotokenicity_scores_sorted_indices]
# mutual_monotokenic_neurons = mutual_monotokenic_neurons[monotokenicity_scores_sorted_indices]
# monotoken_index = [neuron_names[i.item()] for i in mutual_monotokenic_neurons]
# top_mutual_exclusive_congruence_pairs_fig = px.bar(
#         x=[vocab_strs[i.item()] for i in mutual_mononeuronic_tokens[:30]],
#         y=monotokenicity_scores_sorted[:30],
#         text=[neuron_names[i.item()] for i in mutual_monotokenic_neurons[:30]],
#         labels={'x': 'Token-Neuron Pair', 'y': 'Mutual Exclusive Congruence Score'},
#         title="Top 30 Token-Neuron Pairs by Mutual Exclusive Congruence"
#     )
# top_mutual_exclusive_congruence_pairs_fig.show()

### Save figures

In [None]:
# all_figs = {
#     "logit_lens_fig": logit_lens_fig,
#     "patched_resid_fig": patched_resid_fig,
#     "patched_attn_fig": patched_attn_fig,
#     "patched_mlp_fig": patched_mlp_fig,
#     "patched_neuron_fig": patched_neuron_fig,
#     "multiline_top_pred_proportion_fig": multiline_top_pred_proportion_fig,
#     "top_neuron_for_each_token": top_neuron_for_each_token_fig,
#     "top_2_neuron_diff_for_each_token_fig": top_2_neuron_diff_for_each_token_fig,
#     "top_2_token_diff_for_each_neuron_fig": top_2_token_diff_for_each_neuron_fig,
#     "top_mutual_exclusive_congruence_pairs": top_mutual_exclusive_congruence_pairs_fig
# }
# token_neuron_pair_fig_lists = {
#     "top_pred_proportion_fig": top_pred_proportion_figs,
#     "neuron_congruence_with_vocab_fig": neuron_congruence_with_vocab_figs,
#     "token_congruence_with_each_neuron_fig": token_congruence_with_each_neuron_figs
# }
# for fig_base_name, fig_list in token_neuron_pair_fig_lists.items():
#     for i, fig in enumerate(fig_list):
#         all_figs[f"{tokens_of_interest_strs[i].strip()}_layer_{neuron_layers[i]}_neuron_{neuron_indices[i]}_{fig_base_name}"] = fig

# # from google.colab import drive
# # drive.mount('/content/drive')
# # fig_dir = Path('/content/drive/MyDrive/An_Neuron/generated_figures')
# fig_dir = Path('figures')
# png_fig_dir, svg_fig_dir, html_fig_dir = fig_dir / "png", fig_dir / "svg", fig_dir / "html"
# png_fig_dir.mkdir(parents=True, exist_ok=True)
# svg_fig_dir.mkdir(parents=True, exist_ok=True)
# html_fig_dir.mkdir(parents=True, exist_ok=True)

# # If this doesn't work, restart the runtime to ensure kaleido is installed :(
# for name, fig in all_figs.items():
#     fig.write_image(png_fig_dir / f"{name}.png", width=1048, height=512, scale=4)
#     fig.write_image(svg_fig_dir / f"{name}.svg", scale=4)
#     fig.write_html(html_fig_dir / f"{name}.html", include_plotlyjs="cdn")