# Main Notebook

This notebook contains all the mech interp (mechanistic interpretability) tools I've developed, centered on doing circuit analysis for arbitrary tasks.

## Imports/General Utility Functions

In [None]:
# Requires install of mamba lens
# https://github.com/Phylliida/MambaLens
# pip install git+https://github.com/Phylliida/MambaLens.git
# and also my implementation of ACDC (used for dataset management)
# pip install git+https://github.com/Phylliida/ACDC.git

In [50]:

from functools import partial

import pandas as pd
import plotly.express as px
import torch
import transformer_lens.utils as utils
from tqdm import tqdm
from transformer_lens.hook_points import HookPoint

from acdc import get_pad_token
from acdc.data.utils import generate_dataset
from mamba_lens import HookedMamba

# this notebook appears to be forward-pass only, so disable autograd globally to avoid recording graphs
torch.set_grad_enabled(False)

# from transformer lens
def get_device():
    """
    Return the best available torch device.

    Preference order: CUDA, then Apple MPS (only trusted on PyTorch 2.x or newer), then CPU.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    mps_usable = torch.backends.mps.is_available() and torch.backends.mps.is_built()
    # the leading component of e.g. "2.1.0+cu118" is the major version
    if mps_usable and int(torch.__version__.split(".")[0]) >= 2:
        return torch.device("mps")
    return torch.device("cpu")

# modified from neel nanda's examples
def imshow(tensor, renderer=None, xaxis="", yaxis="", font_size=None, show=True, color_continuous_midpoint=0.0, fix_size=False, **kwargs):
    '''
    Plot a 2D tensor as a heatmap with a diverging red/blue colour scale (plotly ``px.imshow``).

    Args:
        tensor: 2D tensor/array to plot (converted with ``utils.to_numpy``).
        renderer: plotly renderer forwarded to ``fig.show``.
        xaxis, yaxis: axis titles.
        font_size: if given, sets the figure font (this size, black). For each of the ``x`` / ``y``
            tick-label lists passed through kwargs, every label is also shown as a tick.
        show: if True, display the figure and return None; otherwise return the figure.
        color_continuous_midpoint: value placed at the centre (white) of the colour scale.
        fix_size: use a fixed 800x600 layout with the legend in the top-left corner. If more ``y``
            labels than ``model.cfg.n_layers`` are passed, the height is multiplied by
            ``model.cfg.D_conv`` (this reads the global ``model``).
        **kwargs: forwarded to ``px.imshow`` (e.g. ``x``, ``y``, ``title``).
    '''
    fig = px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=color_continuous_midpoint, color_continuous_scale="RdBu", labels={"x":xaxis, "y":yaxis}, **kwargs)
    if font_size is not None:
        # apply the font even when no tick labels were passed
        fig.update_layout(font=dict(size=font_size, color="black"))
        for axis in ("x", "y"):
            if axis in kwargs:
                # one tick per provided label, so plotly doesn't thin them out on dense plots
                fig.update_layout(**{f"{axis}axis": dict(tickmode='array', tickvals=kwargs[axis], ticktext=kwargs[axis])})
    if fix_size:
        # default settings aren't very good, these are better
        plot_args = {
            'width': 800,
            'height': 600,
            "autosize": False,
            'showlegend': True,
            'margin': {"l":0,"r":0,"t":100,"b":0}
        }
        y_labels = kwargs.get('y')
        # tall plots (presumably conv-filter patching results, which have (D_conv-1)*n_layers rows)
        # get extra height; skipped when no y labels are given instead of raising KeyError
        if y_labels is not None and model.cfg.n_layers < len(y_labels):
            plot_args['height'] *= model.cfg.D_conv

        fig.update_layout(**plot_args)
        fig.update_layout(legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="left",
            x=0.01
        ))
    if show:
        fig.show(renderer)
    else:
        return fig

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Show a plotly line chart of ``tensor`` (converted to numpy) with the given axis titles."""
    axis_labels = {"x": xaxis, "y": yaxis}
    fig = px.line(utils.to_numpy(tensor), labels=axis_labels, **kwargs)
    fig.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Show a plotly scatter plot of ``y`` against ``x`` (both converted to numpy) with axis/colour titles."""
    axis_labels = {"x": xaxis, "y": yaxis, "color": caxis}
    x_values = utils.to_numpy(x)
    y_values = utils.to_numpy(y)
    fig = px.scatter(y=y_values, x=x_values, labels=axis_labels, **kwargs)
    fig.show(renderer)

def bar_chart(data, x_labels, y_label, title, font_size=None):
    '''
    Show a plotly bar chart of a tensor, one bar per entry of ``x_labels``.

    Args:
        data: tensor whose first column holds the bar heights (moved to CPU and converted to numpy).
        x_labels: one label per row of ``data``.
        y_label: name of the value column / y axis.
        title: figure title.
        font_size: if given, sets the font size and shows every x label as a tick.
    '''
    # px.bar wants a DataFrame with named rows/columns; pandas names both 0..n-1 by default,
    # so map row i -> x_labels[i] and column 0 -> y_label
    row_names = {idx: label for idx, label in enumerate(x_labels)}
    frame = pd.DataFrame(data.cpu().numpy())
    frame = frame.rename(row_names, axis='rows')
    frame = frame.rename({0: y_label}, axis='columns')
    fig = px.bar(frame, y=y_label, x=x_labels, title=title)
    if font_size is not None:
        tick_settings = dict(tickmode='array', tickvals=x_labels, ticktext=x_labels)
        fig.update_layout(xaxis=tick_settings, font=dict(size=font_size, color="black"))
    fig.show()

def get_batched_index_into(indices):
    '''
    Build advanced-indexing tensors for batched per-row index lists.

    Given data of shape [B,N,V] and ``indices`` of shape [B,N,K], where each entry is an index into
    the V axis, returns ``(batch_idx, row_idx, col_idx)``. Each is a flat long tensor of length
    B*N*K, so ``data[batch_idx, row_idx, col_idx]`` picks out every indexed value in the row-major
    order of ``indices``.

    All three tensors are created on ``indices.device``, so they can be used together directly.
    An empty batch (B == 0) gives three empty tensors.
    '''
    B, N, K = indices.size()
    device = indices.device
    # batch index: each b repeated N*K times
    batch_idx = torch.arange(B, dtype=torch.long, device=device).repeat_interleave(N * K)
    # row index: [0]*K, [1]*K, ..., [N-1]*K, tiled once per batch element
    row_idx = torch.arange(N, dtype=torch.long, device=device).repeat_interleave(K).repeat(B)
    # column index: the requested V-axis indices themselves
    col_idx = indices.flatten()
    return batch_idx, row_idx, col_idx

def get_index_into(indices):
    '''
    Build advanced-indexing tensors for per-row index lists.

    Given data of shape [N,V] and ``indices`` of shape [N,K], where each entry is an index into the
    V axis, returns ``(row_idx, col_idx)``. Both are flat tensors of length N*K, so
    ``data[row_idx, col_idx]`` picks out every indexed value in the row-major order of ``indices``.
    ``row_idx`` is a long tensor created on the same device as ``indices``.
    '''
    num_data, num_per_data = indices.size()
    # we want
    # [0,0,0,...,] num_per_data of these
    # [1,1,1,...,] num_per_data of these
    # ...
    # [num_data-1, num_data-1, ...]
    # already flattened so there is one row index per entry of indices.flatten()
    first_axis_index = torch.arange(num_data, dtype=torch.long, device=indices.device).repeat_interleave(num_per_data)
    second_axis_index = indices.flatten()
    return first_axis_index, second_axis_index

## Load Model

In [5]:

# repeated from the imports cell so this cell also works when run on its own
torch.set_grad_enabled(False)
device = get_device()
print("device", device)
# load the pretrained checkpoint through MambaLens so its activations can be hooked later
model = HookedMamba.from_pretrained(
    "state-spaces/mamba-370m",
    device=device,
)

device cuda


  return self.fget.__get__(instance, owner)()
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Moving model to device:  cuda


## Setup Data

In [71]:
def decode_and_encode(tokenizer, tokens):
    '''
    Normalise ``tokens`` by round-tripping them through ASCII text.

    Decodes to a string, drops every non-ASCII character, and re-encodes. This gets rid of weird
    encoding issues. The returned token ids may differ from the input, which is intentional.
    '''
    text = tokenizer.decode(tokens)
    ascii_text = text.encode("ascii", "ignore").decode("ascii", "ignore")
    return tokenizer.encode(ascii_text)

def copy_data_generator(tokenizer, num_patching_pairs, copy_seq_len, num_repeats):
    '''
    Yield symmetric (uncorrupted, corrupted) prompt pairs for a copy task.

    For each of ``num_patching_pairs`` pairs, this samples ``copy_seq_len`` distinct non-special
    tokens, plus one extra distinct token used as the corruption. The sequence is repeated
    ``num_repeats`` times and then appended once more with its last token cut off. The model has
    to predict that missing token by copying it from the earlier repetitions.

    In the corrupted prompt, the last token of every full repetition is replaced by the extra
    token, so the expected answer becomes that extra token.

    For example, with num_repeats=1:

    uncorrupted:
    a b c d a b c (answer is d)
    corrupted:
    a b c e a b c (answer is e)

    Samples are redrawn until both prompts survive a decode -> ASCII-strip -> encode round trip
    unchanged, so the yielded strings tokenize back to exactly the sampled tokens. The loop retries
    until that happens.

    Yields, per pair, two tuples ``(prompt, [correct_answer], [incorrect_answer])``: first the
    uncorrupted prompt, then the corrupted one with the answers swapped.
    '''
    def to_ascii(text):
        # drop non-ASCII characters to avoid weird encoding issues
        return text.encode("ascii", "ignore").decode("ascii", "ignore")

    def round_trips(expected_ids, text):
        # True if encoding `text` gives back exactly `expected_ids`
        reencoded = torch.tensor(tokenizer.encode(text), dtype=expected_ids.dtype)
        return reencoded.size() == expected_ids.size() and bool(torch.all(reencoded == expected_ids))

    # ignore special tokens like BOS or PAD (of the tokenizer passed in, not a global model)
    special_token_ids = set()
    for token_strs in tokenizer.special_tokens_map.values():
        # most entries are one token string, but some (e.g. additional_special_tokens) are lists
        if not isinstance(token_strs, (list, tuple)):
            token_strs = [token_strs]
        special_token_ids.update(tokenizer.convert_tokens_to_ids(list(token_strs)))
    valid_ids = torch.tensor([tok for tok in range(tokenizer.vocab_size) if tok not in special_token_ids])

    for _ in range(num_patching_pairs):
        while True:
            # sample without replacement
            # one extra so corrupted isn't in sequence
            sampled = valid_ids[torch.randperm(len(valid_ids))[:copy_seq_len+1]].flatten()
            corrupted_id = sampled[-1:]      # 1-element tensor
            data_repeating = sampled[:-1]
            prefix = data_repeating[:-1]     # the sequence without its last token
            uncorrupted_data = torch.cat([data_repeating]*num_repeats + [prefix])
            corrupted_data = torch.cat([prefix, corrupted_id]*num_repeats + [prefix])
            # make sure it encodes and decodes properly
            uncorrupted_prompt = to_ascii(tokenizer.decode(uncorrupted_data))
            corrupted_prompt = to_ascii(tokenizer.decode(corrupted_data))
            uncorrupted_answer = to_ascii(tokenizer.decode(data_repeating[-1:]))
            corrupted_answer = to_ascii(tokenizer.decode(corrupted_id))
            if round_trips(uncorrupted_data, uncorrupted_prompt) and round_trips(corrupted_data, corrupted_prompt):
                break
        yield uncorrupted_prompt, [uncorrupted_answer], [corrupted_answer]
        yield corrupted_prompt, [corrupted_answer], [uncorrupted_answer]

# copy-task dataset settings (see copy_data_generator)
num_patching_pairs = 100
seed = 42
valid_seed = 41
constrain_to_answers = False
has_symmetric_patching = True
varying_data_lengths = False
copy_seq_len = 30
num_repeats = 1 # just immediately copying seems only somewhat doable by this model

data = generate_dataset(model=model,
                        data_generator=copy_data_generator,
                        num_patching_pairs=num_patching_pairs,
                        seed=seed,
                        valid_seed=valid_seed,
                        constrain_to_answers=constrain_to_answers,
                        has_symmetric_patching=has_symmetric_patching,
                        varying_data_lengths=varying_data_lengths,
                        copy_seq_len=copy_seq_len,
                        num_repeats=num_repeats)

# Print the first few patching pairs: row 2*i is the uncorrupted prompt and row 2*i+1 its corrupted
# partner (the order copy_data_generator yields them, as the printout below shows).
# Bounded by the dataset size so a small dataset doesn't raise IndexError.
for i in range(min(10, len(data.data) // 2)):
    uncorrupted_i = i*2
    corrupted_i = i*2+1
    # drop padding: keep tokens up to and including the last real token
    uncorrupted = data.data[uncorrupted_i][:data.last_token_position[uncorrupted_i]+1]
    corrupted = data.data[corrupted_i][:data.last_token_position[corrupted_i]+1]
    print(i)
    print(model.to_str_tokens(uncorrupted))
    print(f"answers   are {model.to_str_tokens(data.correct[uncorrupted_i])}")
    print(f"incorrect are {model.to_str_tokens(data.incorrect[uncorrupted_i])}")
    print(model.to_str_tokens(corrupted))
    print(f"answers   are {model.to_str_tokens(data.correct[corrupted_i])}")
    print(f"incorrect are {model.to_str_tokens(data.incorrect[corrupted_i])}")

0
['<|endoftext|>', ' supposed', 'column', ' Alien', ' looking', 'equal', ' Women', ' Cour', ' incomplete', ' cov', ' inventory', '840', '}$', ' force', ' Plans', 'ycin', ' airport', ' senators', ' Wolf', 'idav', ' nerv', ' Menu', ' gig', ' Statement', 'strom', ' cosmetic', ' receivers', ' keep', 'If', ' indent', ' Command', ' supposed', 'column', ' Alien', ' looking', 'equal', ' Women', ' Cour', ' incomplete', ' cov', ' inventory', '840', '}$', ' force', ' Plans', 'ycin', ' airport', ' senators', ' Wolf', 'idav', ' nerv', ' Menu', ' gig', ' Statement', 'strom', ' cosmetic', ' receivers', ' keep', 'If', ' indent']
answers   are [' Command']
incorrect are [' mile']
['<|endoftext|>', ' supposed', 'column', ' Alien', ' looking', 'equal', ' Women', ' Cour', ' incomplete', ' cov', ' inventory', '840', '}$', ' force', ' Plans', 'ycin', ' airport', ' senators', ' Wolf', 'idav', ' nerv', ' Menu', ' gig', ' Statement', 'strom', ' cosmetic', ' receivers', ' keep', 'If', ' indent', ' mile', ' sup

## Eval Data

In [69]:

print("printing example data points:")
# the pad token is the same for every data point, so look it up once
pad_token = get_pad_token(model.tokenizer)
# show up to 10 data points (fewer if the dataset is smaller)
for b in range(min(10, len(data.data))):
    # because there is padding if lengths vary, this only fetches the tokens that are part of the sequence
    toks = data.data[b][:data.last_token_position[b]+1]
    print(model.tokenizer.decode(toks))
    # skip pad_token entries in the answer lists
    for tok in data.correct[b]:
        if tok != pad_token:
            print(f"  correct answer: {repr(model.tokenizer.decode([tok.item()]))}")
    for tok in data.incorrect[b]:
        if tok != pad_token:
            print(f"  incorrect answer: {repr(model.tokenizer.decode([tok.item()]))}")

# number of highest-logit predictions logging_incorrect_metric prints for each failed data point
TOP_K = 5
from acdc import ACDCEvalData
from acdc import get_pad_token
def _print_answer_prs(header, answer_toks, answer_prs, pad_token, top):
    '''Print each non-pad answer's probability, plus its logit rank when ``top`` (sorted token ids) is given.'''
    print(header)
    for i, tok in enumerate(answer_toks):
        tok_id = tok.item()
        if tok_id == pad_token:
            continue
        print(answer_prs[i].item(), model.tokenizer.decode([tok_id]))
        if top is not None:
            # 0-based position of this token when all logits are sorted high -> low
            top_k_pos = (top==tok_id).nonzero().item()
            print(f" top k pos of {top_k_pos}")

def logging_incorrect_metric(data: ACDCEvalData):
    '''
    Metric for ``data.eval`` that also prints diagnostics for every failed data point.

    For both the patched and corrupted subsets, each row whose ``top_is_correct`` flag is false is
    printed. The printout includes:
    - its tokens;
    - the probability of each non-pad correct and incorrect answer;
    - when ``data.constrain_to_answers`` is False, each answer's logit rank, plus the TOP_K
      highest-logit tokens labelled correct / incorrect / other.

    Pad tokens in the answer lists are ignored, both when printing probabilities and when labelling
    the top tokens.

    Returns:
        ``data.patched.correct_prs[:,0]``: the probability of the first correct answer for each
        patched data point.
    '''
    pad_token = get_pad_token(model.tokenizer)
    for data_subset in [data.patched, data.corrupted]:
        batch, _ = data_subset.data.size()
        for b in range(batch):
            if data_subset.top_is_correct[b].item():
                continue
            # full-vocabulary ranking is only available when outputs aren't constrained to the answers
            top = None
            if not data.constrain_to_answers:
                logits = data_subset.logits[b]
                prs = torch.nn.functional.softmax(logits, dim=0)
                top = torch.argsort(-logits)
            toks = data_subset.data[b][:data_subset.last_token_position[b]+1]
            print("failed on this data point:")
            print(model.to_str_tokens(toks))
            _print_answer_prs("correct prs:", data_subset.correct[b], data_subset.correct_prs[b], pad_token, top)
            _print_answer_prs("incorrect prs:", data_subset.incorrect[b], data_subset.incorrect_prs[b], pad_token, top)
            if top is not None:
                correct_ids = {x.item() for x in data_subset.correct[b]} - {pad_token}
                incorrect_ids = {x.item() for x in data_subset.incorrect[b]} - {pad_token}
                for i, tok in enumerate(top[:TOP_K]):
                    tok_id = tok.item()
                    if tok_id in correct_ids:
                        label = "correct"
                    elif tok_id in incorrect_ids:
                        label = "incorrect"
                    else:
                        label = "other"
                    # label padded to 9 chars so the columns line up
                    print(f"  {label:<9} top {i} token {tok_id} = {repr(model.tokenizer.decode([tok_id]))} logit {logits[tok_id]} prs {prs[tok_id]}")
    return data.patched.correct_prs[:,0]

# evaluate with the logging metric: prints every failed data point; the result is presumably the
# metric's return value (probability of the first correct answer per patched data point)
pr_correct = data.eval(model=model, batch_size=10, metric=logging_incorrect_metric)
print(pr_correct)
print(torch.mean(pr_correct))

printing example data points:
<|endoftext|> supposedcolumn Alien lookingequal Women Cour incomplete cov inventory840}$ force Plansycin airport senators Wolfidav nerv Menu gig Statementstrom cosmetic receivers keepIf indent Command supposedcolumn Alien lookingequal Women Cour incomplete cov inventory840}$ force Plansycin airport senators Wolfidav nerv Menu gig Statementstrom cosmetic receivers keepIf indent
  correct answer: ' Command'
  incorrect answer: ' mile'
<|endoftext|> supposedcolumn Alien lookingequal Women Cour incomplete cov inventory840}$ force Plansycin airport senators Wolfidav nerv Menu gig Statementstrom cosmetic receivers keepIf indent mile supposedcolumn Alien lookingequal Women Cour incomplete cov inventory840}$ force Plansycin airport senators Wolfidav nerv Menu gig Statementstrom cosmetic receivers keepIf indent
  correct answer: ' mile'
  incorrect answer: ' Command'
<|endoftext|> readonly 366educatedYPE Challenge octSV person Mov XIII toddematporaryjquery threads 

failed on this data point:
['<|endoftext|>', ' Kamp', ' results', ' Administ', 'Path', 'ustomed', '146', ' carriers', ' spelled', 'Commun', ' footing', ' Bone', '*;', ' Kar', 'combin', 'complex', 'days', 'bl', ' insufficient', 'aline', ' AX', '099', '209', 'ica', ' moll', 'idas', ' contemplated', 'fm', 'adelphia', ' En', ' preferably', ' Kamp', ' results', ' Administ', 'Path', 'ustomed', '146', ' carriers', ' spelled', 'Commun', ' footing', ' Bone', '*;', ' Kar', 'combin', 'complex', 'days', 'bl', ' insufficient', 'aline', ' AX', '099', '209', 'ica', ' moll', 'idas', ' contemplated', 'fm', 'adelphia', ' En']
correct prs:
0.016819356009364128  preferably
 top k pos of 2
incorrect prs:
4.385238207760267e-05 pose
 top k pos of 2675
  other     top 0 token 46882 = ' Kamp' logit 9.139322280883789 prs 0.08582523465156555
  other     top 1 token 187 = '\n' logit 9.101213455200195 prs 0.08261607587337494
  correct   top 2 token 13027 = ' preferably' logit 7.509539604187012 prs 0.01681935600936

In [75]:

from acdc import accuracy_metric
num_patching_pairs = 100
seed = 42 # base seed; each grid cell below uses its own offset from it
valid_seed = 41
constrain_to_answers = False
has_symmetric_patching = True
varying_data_lengths = False

max_seq_len = 40
max_num_repeats = 10

# Sweep copy-task accuracy: output_data[i, j] is the mean accuracy for copy_seq_len=i+1, num_repeats=j+1.
# NOTE: this leaves `data` bound to the last dataset generated (copy_seq_len=39, num_repeats=9),
# which later cells read.
output_data = torch.zeros([max_seq_len-1, max_num_repeats-1])
base_seed = seed
for i, copy_seq_len in enumerate(range(1,max_seq_len)):
    print(copy_seq_len)
    for j, num_repeats in enumerate(tqdm(range(1,max_num_repeats))):
        # i*max_num_repeats + j is unique per (i, j) because j < max_num_repeats.
        # Starting from base_seed (42 > valid_seed=41) means no cell reuses valid_seed.
        # How generate_dataset uses the two seeds is not visible here.
        seed = base_seed + i*max_num_repeats + j
        data = generate_dataset(model=model,
                    data_generator=copy_data_generator,
                    num_patching_pairs=num_patching_pairs,
                    seed=seed,
                    valid_seed=valid_seed,
                    constrain_to_answers=constrain_to_answers,
                    has_symmetric_patching=has_symmetric_patching,
                    varying_data_lengths=varying_data_lengths,
                    copy_seq_len=copy_seq_len,
                    num_repeats=num_repeats)
        output_data[i,j] = torch.mean(data.eval(model=model, batch_size=20, metric=accuracy_metric)).item()


1


100%|██████████| 9/9 [00:27<00:00,  3.10s/it]


2


100%|██████████| 9/9 [00:35<00:00,  3.99s/it]


3


100%|██████████| 9/9 [00:41<00:00,  4.62s/it]


4


100%|██████████| 9/9 [00:46<00:00,  5.13s/it]


5


100%|██████████| 9/9 [00:58<00:00,  6.53s/it]


6


100%|██████████| 9/9 [01:05<00:00,  7.25s/it]


7


100%|██████████| 9/9 [01:17<00:00,  8.66s/it]


8


100%|██████████| 9/9 [01:20<00:00,  8.96s/it]


9


100%|██████████| 9/9 [01:32<00:00, 10.27s/it]


10


100%|██████████| 9/9 [01:39<00:00, 11.04s/it]


11


100%|██████████| 9/9 [01:46<00:00, 11.86s/it]


12


100%|██████████| 9/9 [02:03<00:00, 13.71s/it]


13


100%|██████████| 9/9 [02:18<00:00, 15.42s/it]


14


100%|██████████| 9/9 [02:22<00:00, 15.86s/it]


15


100%|██████████| 9/9 [02:40<00:00, 17.80s/it]


16


100%|██████████| 9/9 [02:50<00:00, 18.91s/it]


17


100%|██████████| 9/9 [02:14<00:00, 15.00s/it]


18


100%|██████████| 9/9 [01:55<00:00, 12.78s/it]


19


100%|██████████| 9/9 [02:03<00:00, 13.69s/it]


20


100%|██████████| 9/9 [02:12<00:00, 14.70s/it]


21


100%|██████████| 9/9 [02:19<00:00, 15.53s/it]


22


100%|██████████| 9/9 [02:27<00:00, 16.35s/it]


23


100%|██████████| 9/9 [02:35<00:00, 17.28s/it]


24


100%|██████████| 9/9 [02:47<00:00, 18.62s/it]


25


100%|██████████| 9/9 [02:57<00:00, 19.69s/it]


26


100%|██████████| 9/9 [03:06<00:00, 20.69s/it]


27


100%|██████████| 9/9 [03:16<00:00, 21.83s/it]


28


100%|██████████| 9/9 [03:20<00:00, 22.26s/it]


29


100%|██████████| 9/9 [03:36<00:00, 24.07s/it]


30


100%|██████████| 9/9 [03:44<00:00, 24.97s/it]


31


100%|██████████| 9/9 [03:59<00:00, 26.57s/it]


32


100%|██████████| 9/9 [04:15<00:00, 28.38s/it]


33


100%|██████████| 9/9 [04:29<00:00, 29.94s/it]


34


100%|██████████| 9/9 [04:50<00:00, 32.30s/it]


35


100%|██████████| 9/9 [05:13<00:00, 34.80s/it]


36


100%|██████████| 9/9 [05:35<00:00, 37.25s/it]


37


100%|██████████| 9/9 [05:55<00:00, 39.51s/it]


38


100%|██████████| 9/9 [06:04<00:00, 40.47s/it]


39


100%|██████████| 9/9 [06:30<00:00, 43.34s/it]


In [81]:
# transpose so columns are sequence lengths and rows are repeat counts
imshow(
    output_data.T,
    x=list(map(str, range(1, max_seq_len))),
    y=list(map(str, range(1, max_num_repeats))),
    fix_size=True,
    font_size=9,
    title='accuracy of mamba-370m on copy task',
    xaxis='num tokens in sequence',
    yaxis='num times repeated',
)

# Here we see that accuracy is very high as long as the sequence is longer than 3 tokens and is repeated at least twice. It's acceptable when repeated only once, but I'd prefer starting from a higher accuracy.
# I'll use 6 tokens in the sequence, repeated 2 times (so it appears 3 times in total, counting the final truncated copy)
# 1 hr 5 min

## Patching

In [None]:
# Use the first patching pair in `data`: row 0 is an uncorrupted prompt and row 1 its corrupted partner.
# NOTE(review): under Restart & Run All, `data` is whatever was assigned last. The accuracy sweep leaves
# the copy_seq_len=39, num_repeats=9 dataset, not the 6-token / 2-repeat setting planned in the note
# under the accuracy heatmap, so regenerate `data` first if that setting is wanted.
# data.data / data.correct hold token ids, but the next cell passes these values to model.to_tokens,
# model.to_single_token and tokenizer.encode, which take text. So decode them to strings here.
# skip_special_tokens drops the BOS (and any padding); model.to_tokens is assumed to prepend BOS again - TODO confirm.
prompt_uncorrupted = model.tokenizer.decode(data.data[0][:data.last_token_position[0]+1], skip_special_tokens=True)
prompt_corrupted = model.tokenizer.decode(data.data[1][:data.last_token_position[1]+1], skip_special_tokens=True)
# the first entry of each answer list is the primary answer (the metrics above read correct_prs[:,0])
uncorrupted_answer = model.tokenizer.decode([data.correct[0][0].item()])
corrupted_answer = model.tokenizer.decode([data.correct[1][0].item()])

print(prompt_uncorrupted)
print(prompt_corrupted)
print(uncorrupted_answer)
print(corrupted_answer)

In [None]:
# modified from neel nanda's examples

# layer run_patching uses when patching_type == H_N_PATCHING
H_N_PATCHING_LAYER = 39

# default settings aren't very good, these are better
# (not referenced in the visible part of this cell - presumably used by later plotting code)
plot_args = {
    'width': 800,
    'height': 600,
    "autosize": False,
    'showlegend': True,
    'margin': {"l":0,"r":0,"t":100,"b":0}
}

# you can modify this to only run things on a subset of layers
limited_layers = list(range(model.cfg.n_layers))


# sorted, de-duplicated first token id of each answer; assumes the answers are strings - TODO confirm
answer_tokens = sorted(list(set([model.tokenizer.encode(uncorrupted_answer)[0], model.tokenizer.encode(corrupted_answer)[0]])))

# prompt text -> token tensors (2D; presumably [1, L] - later code reads batch element 0)
prompt_uncorrupted_tokens = model.to_tokens(prompt_uncorrupted)
prompt_corrupted_tokens = model.to_tokens(prompt_corrupted)

def uncorrupted_logit_minus_corrupted_logit(logits, uncorrupted_answer, corrupted_answer):
    '''
    Logit difference between the two answers at the final position of batch element 0.

    Args:
        logits: tensor indexed as [B,L,V] (batch, position, vocab).
        uncorrupted_answer, corrupted_answer: answer strings; each must be a single token
            (resolved with ``model.to_single_token``).
    Returns:
        ``logits[0, -1, uncorrupted] - logits[0, -1, corrupted]``
    '''
    answer_ids = [model.to_single_token(answer) for answer in (uncorrupted_answer, corrupted_answer)]
    final_position = logits[0, -1]
    return final_position[answer_ids[0]] - final_position[answer_ids[1]]

def uncorrupted_pr_minus_corrupted_pr(prs, uncorrupted_answer, corrupted_answer):
    '''
    Probability difference between the two answers at the final position of batch element 0.

    Args:
        prs: probabilities indexed as [B,L,V] (batch, position, vocab).
        uncorrupted_answer, corrupted_answer: answer strings; each must be a single token
            (resolved with ``model.to_single_token``).
    Returns:
        ``prs[0, -1, uncorrupted] - prs[0, -1, corrupted]``
    '''
    answer_ids = [model.to_single_token(answer) for answer in (uncorrupted_answer, corrupted_answer)]
    final_position = prs[0, -1]
    return final_position[answer_ids[0]] - final_position[answer_ids[1]]



# Baseline runs on both prompts; differences are read at the last position of batch element 0.
# [B,L,V]
corrupted_logits, corrupted_activations = model.run_with_cache(prompt_corrupted_tokens, only_use_these_layers=limited_layers)
corrupted_logit_diff = uncorrupted_logit_minus_corrupted_logit(logits=corrupted_logits, uncorrupted_answer=uncorrupted_answer, corrupted_answer=corrupted_answer)
corrupted_prs = torch.softmax(corrupted_logits, dim=2)
corrupted_pr_diff = uncorrupted_pr_minus_corrupted_pr(prs=corrupted_prs, uncorrupted_answer=uncorrupted_answer, corrupted_answer=corrupted_answer)


# uncorrupted run with no hooks (fwd_hooks not passed)
# [B,L,V]
uncorrupted_logits = model.run_with_hooks(prompt_uncorrupted_tokens, only_use_these_layers=limited_layers)
uncorrupted_logit_diff = uncorrupted_logit_minus_corrupted_logit(logits=uncorrupted_logits, uncorrupted_answer=uncorrupted_answer, corrupted_answer=corrupted_answer)
uncorrupted_prs = torch.softmax(uncorrupted_logits, dim=2)
uncorrupted_pr_diff = uncorrupted_pr_minus_corrupted_pr(prs=uncorrupted_prs, uncorrupted_answer=uncorrupted_answer, corrupted_answer=corrupted_answer)

# report each answer's logit and probability under both prompts
uncorrupted_index = model.to_single_token(uncorrupted_answer)
corrupted_index = model.to_single_token(corrupted_answer)
print(f'uncorrupted prompt\n{prompt_uncorrupted}')
print(f"{repr(uncorrupted_answer)} logit {uncorrupted_logits[0,-1,uncorrupted_index]}")
print(f"{repr(uncorrupted_answer)} pr {uncorrupted_prs[0,-1,uncorrupted_index]}")
print(f"{repr(corrupted_answer)} logit {uncorrupted_logits[0,-1,corrupted_index]}")
print(f"{repr(corrupted_answer)} pr {uncorrupted_prs[0,-1,corrupted_index]}")
print(f'\ncorrupted prompt\n{prompt_corrupted}')
print(f"{repr(uncorrupted_answer)} logit {corrupted_logits[0,-1,uncorrupted_index]}")
print(f"{repr(uncorrupted_answer)} pr {corrupted_prs[0,-1,uncorrupted_index]}")
print(f"{repr(corrupted_answer)} logit {corrupted_logits[0,-1,corrupted_index]}")
print(f"{repr(corrupted_answer)} pr {corrupted_prs[0,-1,corrupted_index]}")

# L is the prompt length in tokens; position-by-position patching expects both prompts to line up
L = len(prompt_uncorrupted_tokens[0])
if len(prompt_corrupted_tokens[0]) != len(prompt_uncorrupted_tokens[0]):
    raise Exception("Prompts are not the same length") # feel free to comment this out; you can patch prompts of different sizes, it's just a lil sus

# diff is logit of uncorrupted_answer - logit of corrupted_answer
# we expect corrupted_diff to have a negative value (as corrupted should put high pr on corrupted_answer)
# we expect uncorrupted to have a positive value (as uncorrupted should put high pr on uncorrupted_answer)
# thus we can treat these as (rough) min and max possible values
min_logit_diff = corrupted_logit_diff
max_logit_diff = uncorrupted_logit_diff

# same bounds for the probability difference
min_pr_diff = corrupted_pr_diff
max_pr_diff = uncorrupted_pr_diff


def generate_always_hooks():
    '''
    Build the hooks that are always applied.

    They are used for the ``always_logits`` run below and can be passed to ``run_patching`` as
    ``always_hooks``.

    Currently this is one ``position_patching_hook`` on ``blocks.39.hook_layer_input`` per token
    position of the uncorrupted prompt. Each hook covers every batch row: ``slice(None, None)`` is
    the same as ``:``, which means all rows. ``position_patching_hook`` is defined outside the
    visible part of this notebook. The commented-out lines are earlier variants kept for reference.
    '''
    LAYER = 39
    _, seq_len = prompt_uncorrupted_tokens.size()
    layer_input_hook = f'blocks.{LAYER}.hook_layer_input'
    all_rows = slice(None, None)
    hooks = [
        (layer_input_hook, partial(position_patching_hook, layer=LAYER, position=pos, batch=all_rows))
        for pos in range(seq_len)
    ]
    # earlier variant (per position, inside the loop):
    #hooks.append((f'blocks.{LAYER}.hook_h.{pos}', partial(h_patching_hook, layer=LAYER, position=pos, batch=slice(None, None))))

    #ABLATE_POS = 3
    #hooks.append((f'blocks.{35}.hook_layer_input', partial(position_patching_hook, layer=LAYER, position=3, batch=slice(None, None))))
    #hooks.append((f'blocks.{40}.hook_layer_input', partial(position_patching_hook, layer=LAYER, position=3, batch=slice(None, None))))
    #FINAL_POS = 19
    #hooks.append((f'blocks.{47}.hook_layer_input', partial(position_patching_hook, layer=LAYER, position=FINAL_POS, batch=slice(None, None))))
    return hooks
always_hooks = generate_always_hooks()

# uncorrupted prompt with always_hooks applied; fast_ssm/fast_conv are switched off,
# presumably so the hooks see the slow-path intermediate activations - TODO confirm against MambaLens
always_logits = model.run_with_hooks(prompt_uncorrupted_tokens, fwd_hooks=always_hooks, only_use_these_layers=limited_layers, fast_ssm=False, fast_conv=False)


always_prs = torch.softmax(always_logits, dim=2)
always_logit_diff = uncorrupted_logit_minus_corrupted_logit(logits=always_logits,
                                                                uncorrupted_answer=uncorrupted_answer,
                                                                corrupted_answer=corrupted_answer)
# normalize it so
# 0 means min_logit_diff (so 0 means that it is acting like the corrupted model)
# 1 means max_logit_diff (so 1 means that it is acting like the uncorrupted model)
normalized_always_logit_diff = (always_logit_diff-min_logit_diff)/(max_logit_diff - min_logit_diff)
# now flip them, since most interventions will do nothing and thus act like uncorrupted model, visually its better to have that at 0
# so now
# 0 means that it is acting like the uncorrupted model
# 1 means that it is acting like the corrupted model
normalized_always_logit_diff = 1.0 - normalized_always_logit_diff

# same for pr (normalized and flipped in one step)
always_pr_diff = uncorrupted_pr_minus_corrupted_pr(prs=always_prs,
                                                    uncorrupted_answer=uncorrupted_answer,
                                                    corrupted_answer=corrupted_answer)
normalized_always_pr_diff = 1.0-(always_pr_diff-min_pr_diff)/(max_pr_diff - min_pr_diff)

# make token labels that describe the patch
# positions where the prompts agree get "tok_index"; positions that differ get "uncorrupted->corrupted_index"
corrupted_str_tokens = model.to_str_tokens(prompt_corrupted_tokens)
uncorrupted_str_tokens = model.to_str_tokens(prompt_uncorrupted_tokens)
token_labels = []
for index, (corrupted_token, uncorrupted_token) in enumerate(zip(corrupted_str_tokens, uncorrupted_str_tokens)):
    if corrupted_token == uncorrupted_token:
        token_labels.append(f"{corrupted_token}_{index}")
    else:
        token_labels.append(f"{uncorrupted_token}->{corrupted_token}_{index}")

def run_patching(patching_type,
                 patching_hook_name_func,
                 patching_hook_func,
                 batch_size,
                 show_options, 
                 min_logit_diff,
                 max_logit_diff,
                 min_pr_diff,
                 max_pr_diff,
                 token_labels,
                 prompt_uncorrupted_tokens,
                 uncorrupted_answer,
                 corrupted_answer,
                 always_hooks=None,
                 show_plot=True,
                 **kwargs):
    _, L = prompt_uncorrupted_tokens.size()
    torch.cuda.empty_cache()
    hook_title = patching_hook_name_func(layer='{layer}', position='{position}')
    print(f"running patching {patching_type}, using hook {hook_title}")
    global patching_result_logits, patching_result_prs # if you want to access it once this is done running
    n_layers = len(limited_layers)

    num_results = n_layers
    if patching_type == H_N_PATCHING:
        print(f"on layer H_N_PATCHING_LAYER={H_N_PATCHING_LAYER}")
        N = model.cfg.N
        num_results = N
    elif patching_type == CONV_FILTERS_PATCHING:
        D_conv = model.cfg.D_conv
        num_results = (D_conv-1)*n_layers # -1 because the zero one is always zero so we ignore it
    
    patching_result_normalized_logits = torch.zeros((num_results, L), device=model.cfg.device)
    patching_result_normalized_prs = torch.zeros((num_results, L), device=model.cfg.device)

    num_answers = len(answer_tokens)
    patching_result_logits = torch.zeros((num_results, L, num_answers), device=model.cfg.device)
    patching_result_prs = torch.zeros((num_results, L, num_answers), device=model.cfg.device)
    
    hooks = []
    # skipping h needs A_bar stored, so also add that hook
    if patching_type == SKIPPING_H_PATCHING:
        for i, layer in list(enumerate(limited_layers)):
            hooks.append((f'blocks.{layer}.hook_A_bar', partial(A_bar_storage_hook_for_skipping_h, layer=layer)))

    # skipping layer needs layer_input (resid_pre) stored, so also add that hook
    if patching_type == LAYER_SKIPPING:
        for i, layer in list(enumerate(limited_layers)):
            hooks.append((f'blocks.{layer}.hook_resid_pre', partial(layer_input_storage_hook, layer=layer)))
    
    # conv filters works via initializing things, then storing all the stuff we want to hook, then doing all that in place at the same time
    if patching_type == CONV_FILTERS_PATCHING:
        for i, layer in list(enumerate(limited_layers)):
            # reset the storage to empty/initialize stuff
            hooks.append((f'blocks.{layer}.hook_layer_input', better_conv_patching_init_hook))
            # doing all the patches we have stored (below) at the same time
            hooks.append((f'blocks.{layer}.hook_conv', partial(better_conv_patching_hook, input_hook_name=f'blocks.{layer}.hook_in_proj', layer=layer)))

    if not always_hooks is None:
        hooks += always_hooks

    initial_num_hooks = len(hooks)

    
    if patching_type == H_N_PATCHING:
        batch = 0
        indices = []
        for n in range(N):
            for position in range(L):
                patching_hook_name = patching_hook_name_func(layer=H_N_PATCHING_LAYER, position=position)
                if batch_size != BATCH_SIZE_ALL: batch = batch % int(batch_size)
                patching_hook = partial(patching_hook_func, layer=H_N_PATCHING_LAYER, position=position, n=n, batch=batch)
                batch += 1
                indices.append((n,position))
                hooks.append((patching_hook_name, patching_hook))
    elif patching_type == CONV_FILTERS_PATCHING:
        batch = 0
        indices = []
        D_conv = model.cfg.D_conv
        ind = 0
        for i, layer in list(enumerate(limited_layers)):
            for conv_filter_i in range(D_conv):
                if conv_filter_i == 0: continue # this -d_conv-1 filter is always zero for some reason
                for position in range(L):
                    patching_hook_name = patching_hook_name_func(layer=layer, position=position)
                    if batch_size != BATCH_SIZE_ALL: batch = batch % int(batch_size)
                    hooks.append((f'blocks.{layer}.hook_in_proj', partial(better_conv_patching_storage_hook, position=position, layer=layer, conv_filter_i=conv_filter_i, batch=batch)))
                    #patching_hook = partial(patching_hook_func, layer=layer, position=position, batch=batch, conv_filter_i=conv_filter_i)
                    batch += 1
                    indices.append((ind,position))
                    #hooks.append((patching_hook_name, patching_hook))
                ind += 1
    else:
        batch = 0
        indices = []
        for i, layer in list(enumerate(limited_layers)):
            for position in range(L):
                patching_hook_name = patching_hook_name_func(layer=layer, position=position)
                if batch_size != BATCH_SIZE_ALL: batch = batch % int(batch_size)
                patching_hook = partial(patching_hook_func, layer=layer, position=position, batch=batch)
                batch += 1
                indices.append((i,position))
                hooks.append((patching_hook_name, patching_hook))

    
    if batch_size != BATCH_SIZE_ALL:
        V = model.cfg.V
        patched_logits = torch.zeros([len(indices), L, V])
        for batch_start in tqdm(list(range(0, len(indices), int(batch_size)))):
            batch_end = min(len(indices), batch_start+int(batch_size))
            # always do the initial they are for storage
            batch_hooks = hooks[:initial_num_hooks] + hooks[initial_num_hooks+batch_start:initial_num_hooks+batch_end]
            cur_batch_size = batch_end-batch_start
            patched_logits[batch_start:batch_end] = model.run_with_hooks(prompt_uncorrupted_tokens.expand(cur_batch_size,L), fwd_hooks=batch_hooks, only_use_these_layers=limited_layers, **kwargs)
    else:
        # [B,L,V]
        patched_logits = model.run_with_hooks(prompt_uncorrupted_tokens.expand(batch,L), fwd_hooks=hooks, only_use_these_layers=limited_layers, **kwargs)
   
    # [B,L,V]
    patched_prs = torch.softmax(patched_logits, dim=2)
    print("finished patching, plotting...")
    for b, (i,position) in enumerate(indices):
        if corrupted_answer != uncorrupted_answer:
            patched_logit_diff = uncorrupted_logit_minus_corrupted_logit(logits=patched_logits[b:b+1],
                                                                         uncorrupted_answer=uncorrupted_answer,
                                                                         corrupted_answer=corrupted_answer)
            # normalize it so
            # 0 means min_logit_diff (so 0 means that it is acting like the corrupted model)
            # 1 means max_logit_diff (so 1 means that it is acting like the uncorrupted model)
            normalized_patched_logit_diff = (patched_logit_diff-min_logit_diff)/(max_logit_diff - min_logit_diff)
            # now flip them, since most interventions will do nothing and thus act like uncorrupted model, visually its better to have that at 0
            # so now
            # 0 means that it is acting like the uncorrupted model
            # 1 means that it is acting like the corrupted model
            normalized_patched_logit_diff = 1.0 - normalized_patched_logit_diff
            normalized_patched_logit_diff = normalized_patched_logit_diff #normalized_always_logit_diff - normalized_patched_logit_diff
            patching_result_normalized_logits[i, position] = normalized_patched_logit_diff
            
            # same for pr
            patched_pr_diff = uncorrupted_pr_minus_corrupted_pr(prs=patched_prs[b:b+1],
                                                                uncorrupted_answer=uncorrupted_answer,
                                                                corrupted_answer=corrupted_answer)
            normalized_patched_pr_diff = 1.0-(patched_pr_diff-min_pr_diff)/(max_pr_diff - min_pr_diff)
            normalized_patched_pr_diff = normalized_always_pr_diff - normalized_patched_pr_diff
            patching_result_normalized_prs[i, position] = normalized_patched_pr_diff

        for k, answer_token in enumerate(answer_tokens):
            patching_result_logits[i, position, k] = patched_logits[b,-1,answer_token]
            patching_result_prs[i, position, k] = patched_prs[b,-1,answer_token]
    
        
    if patching_type == H_N_PATCHING:
        layer_labels = [str(n) for n in range(N)]
    elif patching_type == CONV_FILTERS_PATCHING:
        layer_labels = []
        for layer in limited_layers:
            for conv_i in range(1, D_conv):
                layer_labels.append(f"layer {layer} conv {conv_i-D_conv+1}")
    else:
        layer_labels = [str(layer) for layer in limited_layers]
    figs = []
    y_axis = 'Layer'
    if patching_type == H_N_PATCHING:
        y_axis = 'N'
    if show_plot:
        if corrupted_answer != uncorrupted_answer:
            if show_options in [SHOW_LOGITS, SHOW_BOTH]:
                figs.append(imfig(patching_result_normalized_logits, x=token_labels, y=layer_labels, xaxis="Position", yaxis=y_axis, title=f"Normalized logit difference after patching {patching_type} using hook {hook_title}", font_size=8))
            if show_options in [SHOW_PR, SHOW_BOTH]:
                figs.append(imfig(patching_result_normalized_prs, x=token_labels, y=layer_labels, xaxis="Position", yaxis=y_axis, title=f"Normalized pr difference after patching {patching_type} using hook {hook_title}", font_size=8))
        
        for k, answer_token in enumerate(answer_tokens):
            if show_options in [SHOW_LOGITS, SHOW_BOTH]:
                figs.append(imfig(patching_result_logits[:,:,k], color_continuous_midpoint=None, x=token_labels, y=layer_labels, xaxis="Position", yaxis="Layer", title=f"Logit of uncorrupted answer {repr(model.tokenizer.decode([answer_token]))} after patching {patching_type} using hook {hook_title}", font_size=8))
            if show_options in [SHOW_PR, SHOW_BOTH]:
                figs.append(imfig(patching_result_prs[:,:,k], x=token_labels, y=layer_labels, xaxis="Position", yaxis=y_axis, title=f"Pr of uncorrupted answer {repr(model.tokenizer.decode([answer_token]))} after patching {patching_type} using hook {hook_title}", font_size=8)) 
        
        for fig in figs:
            plot_args_copy = dict(list(plot_args.items()))
            if patching_type == CONV_FILTERS_PATCHING:
                plot_args_copy['height'] *= D_conv
            fig.update_layout(**plot_args_copy)
            fig.update_layout(legend=dict(
                yanchor="top",
                y=0.99,
                xanchor="left",
                x=0.01
            ))
            fig.show()
    else:
        return layer_labels, y_axis, patching_result_normalized_logits, patching_result_normalized_prs, patching_result_logits, patching_result_prs

## hooks for conv filter patching
def conv_input_storage_hook(
    conv_input: Float[torch.Tensor, "B L E"],
    hook: HookPoint,
    layer: int,
) -> Float[torch.Tensor, "B L E"]:
    """Stash a layer's conv input for `conv_patching_hook` and advance a progress bar.

    Resets the global `storage` dict, which drops any per-filter terms cached for the
    previous layer. It then stores `conv_input` under 'conv_input'.

    A new tqdm bar (total = len(limited_layers)) is started at the lowest layer being
    run, or on first use. The bar is advanced once per call, so it completes after the
    last layer.

    Returns:
        `conv_input`, unchanged.
    """
    global progress # it's slow enough that progress bar is useful
    global storage
    # Start a fresh bar on the first layer of a run. Checking min(limited_layers) rather than
    # layer 0 avoids a NameError when layer 0 is not among the layers being run.
    if 'progress' not in globals() or layer == min(limited_layers):
        if 'progress' in globals():
            progress.close()  # don't leave the previous run's bar open
        progress = tqdm(total=len(limited_layers))
    # count every layer, including the first, so the bar reaches its total
    progress.update(1)
    storage = {}
    storage['conv_input'] = conv_input
    return conv_input

def conv_patching_hook(
    conv_output: Float[torch.Tensor, "B L E"],
    hook: HookPoint,
    layer: int,
    position: int,
    batch: int,
    conv_filter_i: int,
) -> Float[torch.Tensor, "B L E"]:
    """Recompute a layer's conv output with one filter's term patched at one (batch, position).

    The conv is rebuilt as a sum over D_CONV per-filter terms, plus the bias. Each term is
    ``conv_weight[:, 0, i] * padded_input[:, :, i:i+L]`` on the left-zero-padded input.
    For filter ``conv_filter_i``, the term at ``[batch, :, position]`` is replaced by the
    same term computed from row 0 of
    ``corrupted_activations[f'blocks.{layer}.hook_in_proj']``.

    Per-filter terms are cached in the global ``storage`` under ``'filter_{i}'`` and are
    patched in place. Patches from earlier calls therefore stay in effect for later calls,
    until ``storage`` is reset (conv_input_storage_hook resets it).

    Reads ``storage['conv_input']`` (set by conv_input_storage_hook) and the globals
    ``model`` and ``corrupted_activations``. The incoming ``conv_output`` is ignored; the
    recomputed tensor is returned.

    NOTE(review): the CONV_FILTERS_PATCHING branch of the patching loop registers
    better_conv_patching_storage_hook, and its call to this hook is commented out there.
    This hook may only be reachable through the 'conv filters' entry of patching_types.
    """
    global storage
    # presumably [B,L,E] as annotated -- TODO confirm against conv_input_storage_hook's hook point
    conv_input = storage['conv_input']
    B, L, E = conv_input.size()
    conv_input = rearrange(conv_input, 'B L E -> B E L')
    # only row 0 of the corrupted input is read below
    conv_input_corrupted = rearrange(corrupted_activations[f'blocks.{layer}.hook_in_proj'], 'B L E -> B E L')
    
    ### Rebuild the conv as an explicit sum of per-filter terms so one term can be swapped
    # pad D_CONV-1 zeros on the left of the sequence axis: output t then only sees inputs t-D_CONV+1..t
    # [B,E,D_CONV-1+L]
    D_CONV = model.cfg.d_conv
    padded_input = torch.nn.functional.pad(conv_input, (D_CONV-1,0), mode='constant', value=0)
    padded_input_corrupted = torch.nn.functional.pad(conv_input_corrupted, (D_CONV-1,0), mode='constant', value=0)
    output = torch.zeros([B,E,L], device=model.cfg.device)
    # [E,1,D_CONV]
    conv_weight = model.blocks[layer].conv1d.weight
    # [E]
    conv_bias = model.blocks[layer].conv1d.bias
    # The corrupted term is recomputed on every call, which is wasteful, but it avoids
    # depending on the order in which hooks run. Unpatched terms are cached in storage.
    for i in range(D_CONV):
        filter_str = f'filter_{i}'
        if not filter_str in storage:
            # [B,E,L]                      [E,1]                      [B,E,L]
            filter_contribution = conv_weight[:,0,i].view(E,1)*padded_input[:,:,i:i+L]
            storage[filter_str] = filter_contribution
        # this is the cached tensor itself (not a copy), so the in-place patch below persists in storage
        filter_contribution = storage[filter_str]
        if i == conv_filter_i:
            # [B_corrupted,E,L]                         [E,1]                          [B_corrupted,E,L]
            corrupted_filter_contribution = conv_weight[:,0,i].view(E,1)*padded_input_corrupted[:,:,i:i+L]
            # [E]                                                    [E]
            filter_contribution[batch,:,position] = corrupted_filter_contribution[0,:,position]
        storage[filter_str] = filter_contribution
        output += filter_contribution

    # bias is not dependent on input so no reason to patch on it, just apply it as normal
    output += conv_bias.view(E, 1)
    
    output = rearrange(output, 'B E L -> B L E')
    return output


# Scratch state shared by the "better" conv-filter patching hooks.
# Intended order of the hooks within one forward pass:
#   1. better_conv_patching_init_hook empties `storage` and `conv_storage`;
#   2. each better_conv_patching_storage_hook stashes its hooked activation in `storage`
#      and records one (position, conv filter) request for its batch row in `conv_storage`;
#   3. better_conv_patching_hook applies every recorded request in a single pass.
# Batching all requests into one recomputation of the conv is much faster than redoing it per request.
storage = {}
conv_storage = {}

# keys used inside conv_storage
CONV_HOOKS = "conv hooks"      # (CONV_HOOKS, batch_start, batch_end) -> list of patch requests
CONV_BATCHES = "conv batches"  # set of (batch_start, batch_end) ranges with requests
def better_conv_patching_init_hook(
    x,
    hook: HookPoint,
    **kwargs
):
    """Reset the conv-filter patching scratch state and pass `x` through unchanged.

    Empties the global ``storage`` and resets ``conv_storage`` so it holds only an empty
    ``CONV_BATCHES`` set. The storage hooks that follow then start from a clean slate.

    Why reset here: tracking a "current layer" in conv_storage and clearing only on a
    mismatch lets stale entries leak through when the same layer is patched repeatedly.
    Within one forward pass the storage hooks may still record several
    (batch_start, batch_end) ranges, which is why conv_storage keys carry the range.
    """
    global storage
    global conv_storage
    storage = {}
    conv_storage = {CONV_BATCHES: set()}
    return x

def better_conv_patching_storage_hook(
    x,
    hook: HookPoint,
    conv_filter_i: int,
    position: int,
    layer: int,
    batch: int,
    **kwargs,
):
    """Stash the hooked activation and record one conv-filter patch request.

    The activation is saved as ``storage[hook.name]``; the patching loop registers this
    on ``blocks.{layer}.hook_in_proj``. The request ``{"position", "conv_filter_i"}`` is
    appended under ``conv_storage[(CONV_HOOKS, batch, batch + 1)]``, and the range
    ``(batch, batch + 1)`` is added to ``conv_storage[CONV_BATCHES]``.

    The request list is created lazily here rather than in the init hook, because the
    init hook empties conv_storage again before each batch's storage hooks run.

    Returns:
        `x`, unchanged.
    """
    global storage
    global conv_storage
    batch_range = (batch, batch + 1)
    storage[hook.name] = x
    requests = conv_storage.setdefault((CONV_HOOKS, *batch_range), [])
    requests.append({"position": position, "conv_filter_i": conv_filter_i})
    conv_storage[CONV_BATCHES].add(batch_range)
    return x

# NOTE(review): Float is already used in the annotations of hooks defined earlier in this
# cell, and annotations are evaluated when the def runs. So, unless annotations are
# deferred, Float must already be imported by an earlier cell. Consider moving these
# imports to the imports cell at the top of the notebook.
from jaxtyping import Float
from einops import rearrange
# A `global` statement at module level has no effect; it only marks the names shared by the hooks.
global corrupted_activations

global conv_storage
def better_conv_patching_hook(
    conv_output: Float[torch.Tensor, "B L E"],
    hook: HookPoint,
    input_hook_name: str,
    layer: int,
    **kwargs,
) -> Float[torch.Tensor, "B L E"]:
    """Recompute the conv output with every recorded per-filter patch applied.

    The conv is re-expressed as a sum of D_CONV per-filter terms plus the bias. Each term is
    a weighted, shifted copy of the left-zero-padded conv input. This lets a single
    filter's term at a single position be swapped for the term computed from the corrupted
    run (row 0 / broadcast of ``corrupted_activations[input_hook_name]``).

    Patch requests are read from ``conv_storage``, which better_conv_patching_storage_hook
    fills:
      * ``conv_storage[CONV_BATCHES]``: a set of ``(batch_start, batch_end)`` row ranges.
      * ``conv_storage[(CONV_HOOKS, batch_start, batch_end)]``: a list of
        ``{"position": ..., "conv_filter_i": ...}`` requests for that range.
        ``position=None`` swaps the filter's term at every position.
      * The range ``(0, None)`` holds requests that apply to every row.

    The patched rows are written into ``conv_output`` in place, and the tensor is cached as
    ``conv_storage['output']``. Later calls before the next reset of conv_storage return
    that cached tensor.

    Args:
        conv_output: the conv activation; rows in the patched ranges are overwritten.
        hook: the HookPoint (unused).
        input_hook_name: key in ``storage`` / ``corrupted_activations`` holding the conv input.
        layer: index of the block whose conv1d weight and bias are used.

    Returns:
        ``conv_storage['output']`` (``conv_output`` after patching).
    """
    global conv_storage
    global storage
    global corrupted_activations

    # all calls share one output because they write into the same output tensor
    output_key = 'output'
    if output_key in conv_storage:
        return conv_storage[output_key]

    D_CONV = model.cfg.d_conv
    # [E,1,D_CONV]
    conv_weight = model.blocks[layer].conv1d.weight
    # [E]
    conv_bias = model.blocks[layer].conv1d.bias

    # Requests for every row are kept apart from the per-range ones: writing rows [0:None]
    # as a range of its own would overwrite the per-range results (or vice versa).
    apply_to_all_range = (0, None)
    apply_to_all_requests = conv_storage.get((CONV_HOOKS, *apply_to_all_range), [])

    batch_ranges = [r for r in conv_storage[CONV_BATCHES] if r != apply_to_all_range]
    if not batch_ranges and apply_to_all_requests:
        # Only whole-batch requests were recorded. Process them once over all rows;
        # otherwise they would never be applied.
        batch_ranges = [apply_to_all_range]

    for batch_start, batch_end in batch_ranges:
        if (batch_start, batch_end) == apply_to_all_range:
            range_requests = []  # they are exactly apply_to_all_requests
        else:
            range_requests = conv_storage[(CONV_HOOKS, batch_start, batch_end)]

        conv_input_uncorrupted = storage[input_hook_name][batch_start:batch_end]
        conv_input_corrupted = corrupted_activations[input_hook_name]
        B, L, E = conv_input_uncorrupted.size()

        # [B,E,L]
        conv_input_uncorrupted = rearrange(conv_input_uncorrupted, 'B L E -> B E L')
        conv_input_corrupted = rearrange(conv_input_corrupted, 'B L E -> B E L')

        # pad D_CONV-1 zeros in front of the sequence axis
        # [B,E,D_CONV-1+L]
        padded_input_uncorrupted = torch.nn.functional.pad(conv_input_uncorrupted, (D_CONV-1,0), mode='constant', value=0)
        padded_input_corrupted = torch.nn.functional.pad(conv_input_corrupted, (D_CONV-1,0), mode='constant', value=0)

        # unpatched per-filter terms, keyed by filter index
        filter_contributions = {}
        for i in range(D_CONV):
            # [B,E,L]                    [E,1]                      [B,E,L]
            filter_contributions[i] = conv_weight[:,0,i].view(E,1)*padded_input_uncorrupted[:,:,i:i+L]

        # swap in the corrupted term for every requested (filter, position)
        for request in range_requests + apply_to_all_requests:
            position = request['position']
            conv_filter_i = request['conv_filter_i']
            # [B_corrupted,E,L]            [E,1]                                 [B_corrupted,E,L]
            corrupted_filter_contribution = conv_weight[:,0,conv_filter_i].view(E,1)*padded_input_corrupted[:,:,conv_filter_i:conv_filter_i+L]
            if position is None:
                # the whole term comes from the corrupted run (broadcast over rows)
                filter_contributions[conv_filter_i] = corrupted_filter_contribution
            else:
                # [B,E]                                                  [B_corrupted,E]
                filter_contributions[conv_filter_i][:,:,position] = corrupted_filter_contribution[:,:,position]

        # compute the output
        output = torch.zeros([B,E,L], device=model.cfg.device)
        for i in range(D_CONV):
            output += filter_contributions[i]
        # bias is not dependent on input so no reason to patch on it, just apply it as normal
        output += conv_bias.view(E, 1)
        conv_output[batch_start:batch_end] = rearrange(output, 'B E L -> B L E')

    conv_storage[output_key] = conv_output
    return conv_output
           


## hooks for layer skipping
def layer_input_storage_hook(
    layer_input: Float[torch.Tensor, "B L D"],
    hook: HookPoint,
    layer: int,
) -> Float[torch.Tensor, "B L D"]:
    """Reset the global `storage` so it holds only this layer's input (key 'layer_input').

    The stored input is what layer_output_skipping_hook copies back over the layer output.

    Returns:
        `layer_input`, unchanged.
    """
    global storage
    storage = {'layer_input': layer_input}
    return layer_input

def layer_output_skipping_hook(
    layer_output: Float[torch.Tensor, "B L D"],
    hook: HookPoint,
    position: int,
    layer: int,
    batch: int,
) -> Float[torch.Tensor, "B L D"]:
    """Skip the layer at one (batch, position) by overwriting its output with the stored input.

    Reads ``storage['layer_input']``, which layer_input_storage_hook sets. Only the slot
    ``[batch, position]`` is changed, in place.

    Returns:
        `layer_output` with that slot replaced.
    """
    global storage
    saved_input = storage['layer_input']
    layer_output[batch, position] = saved_input[batch, position]
    return layer_output


## hooks for h skipping
def A_bar_storage_hook_for_skipping_h(
    A_bar: Float[torch.Tensor, "B L E N"],
    hook: HookPoint,
    layer: int,
) -> Float[torch.Tensor, "B L E N"]:
    """Reset the global `storage` so it holds only this layer's A_bar (key 'A_bar').

    skipping_h_hook reads it to decay the previous state.

    Returns:
        `A_bar`, unchanged.
    """
    global storage
    storage = {'A_bar': A_bar}
    return A_bar

def skipping_h_hook(
    h: Float[torch.Tensor, "B E N"],
    hook: HookPoint,
    position: int,
    layer: int,
    batch: int,
) -> Float[torch.Tensor, "B E N"]:
    """Rewrite the state at `position` for one batch row so that this position's update is skipped.

    At position 0, row `batch` is zeroed. Otherwise row `batch` becomes
    ``A_bar[batch, position] * h_prev[batch]``: the previous state decayed by this
    position's A_bar, presumably dropping the input term. Here ``h_prev`` is
    ``storage[position - 1]`` and ``A_bar`` is ``storage['A_bar']``, which
    A_bar_storage_hook_for_skipping_h sets.

    The (modified) `h` is saved as ``storage[position]`` and returned.

    NOTE(review): this needs this hook to have run at position-1 in the same forward pass.
    If that hook was not registered (e.g. it landed in a different batch chunk), the
    ``storage[position - 1]`` lookup raises KeyError.
    """
    global storage
    _, E, N = h.size()
    prev_pos = position - 1
    device = model.cfg.device
    if prev_pos >= 0:
        decay = torch.ones((E, N), device=device)
        for skipped_pos in range(prev_pos + 1, position + 1):
            decay *= storage['A_bar'][batch, skipped_pos, :, :]
        h[batch, :, :] = decay * storage[prev_pos][batch, :, :]
    else:
        h[batch, :, :] = torch.zeros((E, N), device=device)
    storage[position] = h
    return h


## Regular patching hooks
def position_patching_hook(
    x: Float[torch.Tensor, "B L D"],
    hook: HookPoint,
    position: int,
    layer: int,
    batch: int,
) -> Float[torch.Tensor, "B L D"]:
    """Patch one (batch, position) slot of an activation with the corrupted run's value.

    Works for any activation whose first two axes are batch and position, e.g. B L D,
    B L E, B L E N or B L N. The donor is row 0 of ``corrupted_activations[hook.name]``.
    `layer` is accepted for a uniform hook signature and is unused.

    Returns:
        `x`, modified in place.
    """
    donor = corrupted_activations[hook.name]
    x[batch, position] = donor[0, position]
    return x

def h_patching_hook(
    h: Float[torch.Tensor, "B E N"],
    hook: HookPoint,
    position: int,
    layer: int,
    batch: int,
) -> Float[torch.Tensor, "B E N"]:
    """Replace the whole state of batch row `batch` with row 0 of the corrupted run's state.

    The hook name already identifies the position (``hook_h.{position}``), so `position`
    and `layer` are unused here.

    Returns:
        `h`, modified in place.
    """
    donor_h = corrupted_activations[hook.name]
    h[batch, :, :] = donor_h[0, :, :]
    return h

def h_n_patching_hook(
    h: Float[torch.Tensor, "B E N"],
    hook: HookPoint,
    position: int,
    layer: int,
    n: int,
    batch: int,
) -> Float[torch.Tensor, "B E N"]:
    """Replace state component `n` (last axis) of batch row `batch` with the corrupted run's value.

    The donor is row 0 of ``corrupted_activations[hook.name]``; `position` and `layer`
    are unused.

    Returns:
        `h`, modified in place.
    """
    donor_h = corrupted_activations[hook.name]
    h[batch, ..., n] = donor_h[0, ..., n]
    return h

SKIPPING_H_PATCHING = 'skipping h'
H_N_PATCHING = 'h_n'
LAYER_SKIPPING = 'skipping layer'
CONV_FILTERS_PATCHING = 'conv filters'


def _layer_hook_name(hook_suffix):
    """Name function for a per-layer hook point: ``blocks.{layer}.{hook_suffix}`` (position ignored)."""
    return lambda layer, position: f'blocks.{layer}.{hook_suffix}'


def _position_hook_name(hook_suffix):
    """Name function for a per-position hook point: ``blocks.{layer}.{hook_suffix}.{position}``."""
    return lambda layer, position: f'blocks.{layer}.{hook_suffix}.{position}'


# patching type label -> (hook name function, hook function)
# Insertion order is the dropdown order, and the first key is the default selection.
patching_types = {
    'resid pre': (_layer_hook_name('hook_resid_pre'), position_patching_hook),
    'layer input': (_layer_hook_name('hook_layer_input'), position_patching_hook),
    'normalized input': (_layer_hook_name('hook_normalized_input'), position_patching_hook),
    'skip': (_layer_hook_name('hook_skip'), position_patching_hook),
    'in proj': (_layer_hook_name('hook_in_proj'), position_patching_hook),
    CONV_FILTERS_PATCHING: (_layer_hook_name('hook_conv'), conv_patching_hook),
    'conv': (_layer_hook_name('hook_conv'), position_patching_hook),
    'delta 1': (_layer_hook_name('hook_delta_1'), position_patching_hook),
    'delta 2': (_layer_hook_name('hook_delta_2'), position_patching_hook),
    'delta': (_layer_hook_name('hook_delta'), position_patching_hook),
    'A_bar': (_layer_hook_name('hook_A_bar'), position_patching_hook),
    'B': (_layer_hook_name('hook_B'), position_patching_hook),
    'B_bar': (_layer_hook_name('hook_B_bar'), position_patching_hook),
    'C': (_layer_hook_name('hook_C'), position_patching_hook),
    'ssm input': (_layer_hook_name('hook_ssm_input'), position_patching_hook),
    SKIPPING_H_PATCHING: (_position_hook_name('hook_h'), skipping_h_hook),
    'h': (_position_hook_name('hook_h'), h_patching_hook),
    H_N_PATCHING: (_position_hook_name('hook_h'), h_n_patching_hook),
    'y': (_layer_hook_name('hook_y'), position_patching_hook),
    'ssm output': (_layer_hook_name('hook_ssm_output'), position_patching_hook),
    'after skip': (_layer_hook_name('hook_after_skip'), position_patching_hook),
    'out proj': (_layer_hook_name('hook_out_proj'), position_patching_hook),
    'resid post': (_layer_hook_name('hook_resid_post'), position_patching_hook),
    LAYER_SKIPPING: (_layer_hook_name('hook_resid_post'), layer_output_skipping_hook),
}

patching_types_keys = [*patching_types]

def choose_patching_type(change):
    """Dropdown observer: remember the newly selected patching type.

    The selection lives on this function object as ``choose_patching_type.patching_type``,
    so other cells can read it. Notifications other than value changes are ignored.
    """
    if not (change['type'] == 'change' and change['name'] == 'value'):
        return
    choose_patching_type.patching_type = change['new']


# start in sync with the dropdown's initial value
choose_patching_type.patching_type = patching_types_keys[0]

patching_type_dropdown = ipywidgets.Dropdown(
    options=patching_types_keys,
    value=patching_types_keys[0],
    description='patching type',
)
patching_type_dropdown.observe(choose_patching_type)
display(patching_type_dropdown)

BATCH_SIZE_ALL = 'all'
# Options: 'all' (a single forward pass over every patch) or an explicit positive size.
# Sizes start at 1 because the patching loop computes `batch % int(batch_size)` and
# `range(0, len(indices), int(batch_size))`, and both fail for a size of 0.
# The upper bound n_layers * D_conv * L is included.
batch_size_keys = [BATCH_SIZE_ALL] + [str(b) for b in range(1, model.cfg.n_layers*model.cfg.D_conv*L + 1)]


def choose_batch_size(change):
    """Dropdown observer: store the selected batch size in `choose_batch_size.batch_size`.

    The stored value is a string: 'all' or a positive integer such as '8'.
    """
    if change['type'] == 'change' and change['name'] == 'value':
        choose_batch_size.batch_size = change['new']


# start in sync with the dropdown's initial value
choose_batch_size.batch_size = batch_size_keys[0]

choose_batch_size_dropdown = ipywidgets.Dropdown(
    options=batch_size_keys,
    value=batch_size_keys[0],
    description='batch size',
)
choose_batch_size_dropdown.observe(choose_batch_size)
display(choose_batch_size_dropdown)

# dropdown labels; the first one is the default
fast_conv_keys = ['True', 'False']


def choose_fast_conv(change):
    """Dropdown observer: store the selection as a bool in `choose_fast_conv.fast_conv`.

    Notifications other than value changes are ignored.
    """
    if not (change['type'] == 'change' and change['name'] == 'value'):
        return
    choose_fast_conv.fast_conv = (change['new'] == 'True')


choose_fast_conv.fast_conv = (fast_conv_keys[0] == 'True')

choose_fast_conv_dropdown = ipywidgets.Dropdown(
    options=fast_conv_keys,
    value=fast_conv_keys[0],
    description='fast conv',
)
choose_fast_conv_dropdown.observe(choose_fast_conv)
display(choose_fast_conv_dropdown)


# dropdown labels; the first one is the default
fast_ssm_keys = ['False', 'True']


def choose_fast_ssm(change):
    """Dropdown observer: store the selection as a bool in `choose_fast_ssm.fast_ssm`.

    Notifications other than value changes are ignored.
    """
    if not (change['type'] == 'change' and change['name'] == 'value'):
        return
    choose_fast_ssm.fast_ssm = (change['new'] == 'True')


choose_fast_ssm.fast_ssm = (fast_ssm_keys[0] == 'True')

choose_fast_ssm_dropdown = ipywidgets.Dropdown(
    options=fast_ssm_keys,
    value=fast_ssm_keys[0],
    description='fast ssm',
)
choose_fast_ssm_dropdown.observe(choose_fast_ssm)
display(choose_fast_ssm_dropdown)

SHOW_PR = 'Pr'
SHOW_LOGITS = 'Logits'
SHOW_BOTH = 'Both'
# dropdown order; the first entry is the default
show_options = [SHOW_LOGITS, SHOW_PR, SHOW_BOTH]


def choose_show_options(change):
    """Dropdown observer: store which plots to show (logits, pr, or both) in `choose_show_options.show_options`.

    Notifications other than value changes are ignored.
    """
    if not (change['type'] == 'change' and change['name'] == 'value'):
        return
    choose_show_options.show_options = change['new']


choose_show_options.show_options = show_options[0]

show_options_dropdown = ipywidgets.Dropdown(
    options=show_options,
    value=show_options[0],
    description='logits or pr',
)
show_options_dropdown.observe(choose_show_options)
display(show_options_dropdown)



def do_patching(arg, show_plot=True):
    """Button callback: run `run_patching` with the current dropdown selections.

    Everything printed or plotted goes into the `output` widget, which is cleared first.
    `arg` is whatever ipywidgets passes to on_click handlers and is unused. The remaining
    inputs (logit/pr bounds, tokens, answers) are read from notebook globals, and
    `generate_always_hooks()` is called once per run.

    Returns:
        Whatever `run_patching` returns.
    """
    # output from a widget callback is only visible inside an Output widget context
    with output:
        clear_output()
        selected_type = choose_patching_type.patching_type
        name_func, hook_func = patching_types[selected_type]
        settings = dict(
            batch_size=choose_batch_size.batch_size,
            fast_ssm=choose_fast_ssm.fast_ssm,
            fast_conv=choose_fast_conv.fast_conv,
            show_options=choose_show_options.show_options,
            show_plot=show_plot,
            min_logit_diff=min_logit_diff,
            max_logit_diff=max_logit_diff,
            min_pr_diff=min_pr_diff,
            max_pr_diff=max_pr_diff,
            token_labels=token_labels,
            prompt_uncorrupted_tokens=prompt_uncorrupted_tokens,
            uncorrupted_answer=uncorrupted_answer,
            corrupted_answer=corrupted_answer,
            always_hooks=generate_always_hooks(),
        )
        return run_patching(
            patching_type=selected_type,
            patching_hook_name_func=name_func,
            patching_hook_func=hook_func,
            **settings,
        )

# Output widget that do_patching writes into: anything displayed from a widget
# callback must be wrapped in `with output:` to be visible.
output = ipywidgets.Output()

patching_button = ipywidgets.Button(description='Run Patching')
patching_button.on_click(do_patching)

# button first, results area below it
display(patching_button)
display(output)