# Setup
(No need to change anything)

In [1]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    # Install another version of node that makes PySvelte work way faster
    # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    # %pip install git+https://github.com/neelnanda-io/PySvelte.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

Running as a Colab notebook
Collecting git+https://github.com/neelnanda-io/TransformerLens.git
  Cloning https://github.com/neelnanda-io/TransformerLens.git to /tmp/pip-req-build-6lxpfmh9
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/TransformerLens.git /tmp/pip-req-build-6lxpfmh9
  Resolved https://github.com/neelnanda-io/TransformerLens.git to commit fa287750606075574df2c538058e67d648e2f952
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting accelerate>=0.23.0 (from transformer-lens==0.0.0)
  Downloading accelerate-0.24.1-py3-none-any.whl (261 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m261.4/261.4 kB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting beartype<0.15.0,>=0.14.1 (from transformer-lens==0.0.0)
  Downloading beartype-0.14.1-py3-none-any.whl (739 kB)
[2K     [90m━━━━━━━━

In [2]:
# # Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
# import plotly.io as pio

# if IN_COLAB or not DEBUG_MODE:
#     # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.
#     pio.renderers.default = "colab"
# else:
#     pio.renderers.default = "png"

In [3]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [4]:
# import pysvelte

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

We turn automatic differentiation off, to save GPU memory, as this notebook focuses on model inference not model training.

In [5]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f47426a7e50>

Plotting helper functions:

In [6]:
# def imshow(tensor, renderer=None, **kwargs):
#     px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", **kwargs).show(renderer)

# def line(tensor, renderer=None, **kwargs):
#     px.line(y=utils.to_numpy(tensor), **kwargs).show(renderer)

# def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
#     x = utils.to_numpy(x)
#     y = utils.to_numpy(y)
#     px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

## Load Model

In [7]:
# Load pretrained GPT-2 small wrapped as a HookedTransformer.
# The keyword flags are weight-processing options of from_pretrained
# (see the TransformerLens documentation for their exact effect).
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)

Downloading (…)lve/main/config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Loaded pretrained model gpt2-small into HookedTransformer


## Import functions from repo

In [8]:
!git clone https://github.com/callummcdougall/ARENA_2.0.git

Cloning into 'ARENA_2.0'...
remote: Enumerating objects: 9106, done.[K
remote: Counting objects: 100% (1818/1818), done.[K
remote: Compressing objects: 100% (288/288), done.[K
remote: Total 9106 (delta 1611), reused 1607 (delta 1527), pack-reused 7288[K
Receiving objects: 100% (9106/9106), 155.60 MiB | 29.89 MiB/s, done.
Resolving deltas: 100% (5506/5506), done.


In [9]:
cd ARENA_2.0/chapter1_transformers/exercises/part3_indirect_object_identification

/content/ARENA_2.0/chapter1_transformers/exercises/part3_indirect_object_identification


In [10]:
import ioi_circuit_extraction as ioi_circuit_extraction

# Generate dataset with multiple prompts

In [11]:
class Dataset:
    """Tokenised prompt collection plus per-prompt position indices.

    Each prompt is a dict holding a ``"text"`` string, ``"corr"`` and
    ``"incorr"`` answer strings, and any number of extra keys (e.g. ``"S1"``)
    that name target positions inside the prompt.

    Attributes:
        prompts: the prompt dicts passed in.
        tokenizer: the tokenizer passed in.
        N: number of prompts.
        max_len: largest ``input_ids`` length over the individual prompt texts.
        groups: list of index arrays, one per template id (a single template here).
        toks: int32 tensor of padded token ids for all prompt texts.
        io_tokenIDs: first token id of each prompt's ``"corr"`` string.
        s_tokenIDs: first token id of each prompt's ``"incorr"`` string.
        word_idx: dict mapping each target key, plus ``"end"``, to a tensor
            holding that position for every prompt.
    """

    def __init__(self, prompts, pos_dict, tokenizer, S1_is_first=False):
        """Tokenise ``prompts`` and record target / end positions.

        Args:
            prompts: non-empty list of prompt dicts (see class docstring). The
                target keys are taken from the first prompt.
            pos_dict: maps every target key to its token position; the same
                position is used for every prompt.
            tokenizer: callable tokenizer that also exposes ``encode`` and
                ``tokenize`` (the notebook passes ``model.tokenizer``).
            S1_is_first: accepted for backward compatibility; currently unused.

        Raises:
            ValueError: if ``prompts`` is empty.
            KeyError: if a target key is missing from ``pos_dict``.
        """
        if len(prompts) == 0:
            raise ValueError("Dataset requires at least one prompt")

        self.prompts = prompts
        self.tokenizer = tokenizer
        self.N = len(prompts)

        texts = [prompt["text"] for prompt in self.prompts]
        self.max_len = max(len(self.tokenizer(text).input_ids) for text in texts)

        # only 1 template, so every prompt shares template id 0
        all_ids = [0 for _ in self.prompts]
        all_ids_ar = np.array(all_ids)
        self.groups = [
            np.where(all_ids_ar == template_id)[0] for template_id in set(all_ids)
        ]

        # Build the integer tensor directly instead of round-tripping through float32.
        self.toks = torch.tensor(
            self.tokenizer(texts, padding=True).input_ids, dtype=torch.int
        )
        self.io_tokenIDs = [
            self.tokenizer.encode(prompt["corr"])[0] for prompt in self.prompts
        ]
        self.s_tokenIDs = [
            self.tokenizer.encode(prompt["incorr"])[0] for prompt in self.prompts
        ]

        # word_idx: for every target key, one position per prompt. Positions come
        # straight from pos_dict, so no tokenisation (and no global `model`) is needed.
        target_keys = [
            key
            for key in self.prompts[0].keys()
            if key not in ("text", "corr", "incorr")
        ]
        self.word_idx = {
            targ: torch.tensor([pos_dict[targ]] * self.N) for targ in target_keys
        }

        # "end" is the index of the last token of each (unpadded) prompt text.
        self.word_idx["end"] = torch.tensor(
            [len(self.tokenizer.tokenize(text)) - 1 for text in texts]
        )

    def __len__(self):
        """Return the number of prompts."""
        return self.N

In [12]:
# Token position of each target key within a prompt; Dataset copies these
# positions into word_idx for every prompt.
# assumes each number word is exactly one token, so S1..S4 sit at positions 0..3 — TODO confirm
pos_dict = {
    'S1': 0,
    'S2': 1,
    'S3': 2,
    'S4': 3,
}

In [13]:
def generate_prompts_list(x ,y):
    """Build one prompt dict per start index in ``range(x, y)``.

    Each prompt uses four consecutive number words as S1..S4 (concatenated to
    form ``'text'``), the fifth consecutive word as the correct answer, and the
    S4 word as the incorrect answer.

    Raises:
        IndexError: if a start index leaves fewer than five words available.
    """
    words = [' one', ' two', ' three', ' four', ' five', ' six', ' seven', ' eight', ' nine', ' ten', ' eleven', ' twelve', ' thirteen', ' fourteen', ' fifteen', ' sixteen', ' seventeen', ' eighteen', ' nineteen', ' twenty']
    return [
        {
            'S1': words[start],
            'S2': words[start + 1],
            'S3': words[start + 2],
            'S4': words[start + 3],
            'corr': words[start + 4],
            'incorr': words[start + 3],  # this is arbitrary
            'text': "".join(words[start + offset] for offset in range(4)),
        }
        for start in range(x, y)
    ]

# Build the 16 clean prompts (start indices 0-15) and wrap them in a Dataset.
# NOTE(review): S1_is_first is accepted by Dataset but only referenced in commented-out code there.
prompts_list = generate_prompts_list(0, 16)
dataset = Dataset(prompts_list, pos_dict, model.tokenizer, S1_is_first=True)

In [14]:
# import random

# def generate_prompts_list_corr(x ,y):
#     words = [' one', ' two', ' three', ' four', ' five', ' six', ' seven', ' eight', ' nine', ' ten', ' eleven', ' twelve', ' thirteen', ' fourteen', ' fifteen', ' sixteen', ' seventeen', ' eighteen', ' nineteen', ' twenty']
#     prompts_list = []
#     for i in range(x, y):
#         r1 = random.choice(words)
#         r2 = random.choice(words)
#         while True:
#             r3_ind = random.randint(0,len(words)-1)
#             r4_ind = random.randint(0,len(words)-1)
#             if words[r3_ind] != words[r4_ind-1]:
#                 break
#         r3 = words[r3_ind]
#         r4 = words[r4_ind]
#         prompt_dict = {
#             'S1': str(r1),
#             'S2': str(r2),
#             'S3': str(r3),
#             'S4': str(r4),
#             'corr': str(r1),
#             'incorr': str(r4),
#             'text': f"{r1}{r2}{r3}{r4}"
#         }
#         prompts_list.append(prompt_dict)
#     return prompts_list

# prompts_list_2 = generate_prompts_list_corr(0, 16)
# dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer, S1_is_first=True)

In [15]:
# prompts_list_2

In [16]:
prompts_list_2 = [{'S1': ' ten',
  'S2': ' two',
  'S3': ' thirteen',
  'S4': ' four',
  'corr': ' ten',
  'incorr': ' four',
  'text': ' ten two thirteen four'},
 {'S1': ' fifteen',
  'S2': ' seven',
  'S3': ' eight',
  'S4': ' eleven',
  'corr': ' fifteen',
  'incorr': ' eleven',
  'text': ' fifteen seven eight eleven'},
 {'S1': ' eight',
  'S2': ' eight',
  'S3': ' eight',
  'S4': ' one',
  'corr': ' eight',
  'incorr': ' one',
  'text': ' eight eight eight one'},
 {'S1': ' sixteen',
  'S2': ' eleven',
  'S3': ' thirteen',
  'S4': ' sixteen',
  'corr': ' sixteen',
  'incorr': ' sixteen',
  'text': ' sixteen eleven thirteen sixteen'},
 {'S1': ' eight',
  'S2': ' fifteen',
  'S3': ' three',
  'S4': ' twenty',
  'corr': ' eight',
  'incorr': ' twenty',
  'text': ' eight fifteen three twenty'},
 {'S1': ' fourteen',
  'S2': ' three',
  'S3': ' four',
  'S4': ' seven',
  'corr': ' fourteen',
  'incorr': ' seven',
  'text': ' fourteen three four seven'},
 {'S1': ' seventeen',
  'S2': ' twelve',
  'S3': ' nineteen',
  'S4': ' ten',
  'corr': ' seventeen',
  'incorr': ' ten',
  'text': ' seventeen twelve nineteen ten'},
 {'S1': ' ten',
  'S2': ' ten',
  'S3': ' six',
  'S4': ' three',
  'corr': ' ten',
  'incorr': ' three',
  'text': ' ten ten six three'},
 {'S1': ' nine',
  'S2': ' one',
  'S3': ' one',
  'S4': ' thirteen',
  'corr': ' nine',
  'incorr': ' thirteen',
  'text': ' nine one one thirteen'},
 {'S1': ' eleven',
  'S2': ' ten',
  'S3': ' four',
  'S4': ' eighteen',
  'corr': ' eleven',
  'incorr': ' eighteen',
  'text': ' eleven ten four eighteen'},
 {'S1': ' twenty',
  'S2': ' seven',
  'S3': ' twelve',
  'S4': ' fourteen',
  'corr': ' twenty',
  'incorr': ' fourteen',
  'text': ' twenty seven twelve fourteen'},
 {'S1': ' thirteen',
  'S2': ' eight',
  'S3': ' ten',
  'S4': ' one',
  'corr': ' thirteen',
  'incorr': ' one',
  'text': ' thirteen eight ten one'},
 {'S1': ' three',
  'S2': ' sixteen',
  'S3': ' seven',
  'S4': ' six',
  'corr': ' three',
  'incorr': ' six',
  'text': ' three sixteen seven six'},
 {'S1': ' twenty',
  'S2': ' one',
  'S3': ' sixteen',
  'S4': ' nine',
  'corr': ' twenty',
  'incorr': ' nine',
  'text': ' twenty one sixteen nine'},
 {'S1': ' nineteen',
  'S2': ' eight',
  'S3': ' thirteen',
  'S4': ' ten',
  'corr': ' nineteen',
  'incorr': ' ten',
  'text': ' nineteen eight thirteen ten'},
 {'S1': ' four',
  'S2': ' fourteen',
  'S3': ' three',
  'S4': ' one',
  'corr': ' four',
  'incorr': ' one',
  'text': ' four fourteen three one'}]

# Dataset built from the hard-coded prompts_list_2; mean_ablate_by_lst passes it
# as means_dataset to the mean-ablation hook.
dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer, S1_is_first=True)

# Ablation Experiment Functions

In [17]:
from torch import Tensor

def logits_to_ave_logit_diff_2(logits: Float[Tensor, "batch seq d_vocab"], dataset: Dataset, per_prompt=False):
    '''
    Logit difference (correct answer minus incorrect answer) at each prompt's
    "end" position.

    With per_prompt=True the per-prompt differences are returned; otherwise
    their mean is returned.
    '''
    batch_rows = range(logits.size(0))
    # Only the logits at each prompt's final position matter for the answer.
    final_pos = dataset.word_idx["end"]

    correct_logits = logits[batch_rows, final_pos, dataset.io_tokenIDs]
    incorrect_logits = logits[batch_rows, final_pos, dataset.s_tokenIDs]

    diff = correct_logits - incorrect_logits
    if per_prompt:
        return diff
    return diff.mean()

In [18]:
def mean_ablate_by_lst(lst, model, print_output=True):
    """Score a head list as a percentage of the unablated model's logit difference.

    The same head list ``lst`` is assigned to every circuit role; the roles keep
    the positions "end", "S4", "S3", "S2" and "S1" respectively. The model is
    run once without hooks and once after
    ``ioi_circuit_extraction.add_mean_ablation_hook`` has been applied with
    ``dataset_2`` as the means dataset (presumably everything outside the
    circuit is mean-ablated — see that module for the exact behaviour).

    Relies on the notebook globals ``dataset`` and ``dataset_2``.

    Args:
        lst: heads to keep, as ``(layer, head)`` tuples.
        model: the HookedTransformer to evaluate; all of its hooks, including
            permanent ones, are reset first.
        print_output: if True, print the percentage.

    Returns:
        ``100 * ablated_logit_diff / original_logit_diff`` (a 0-dim tensor).
    """
    CIRCUIT = {
        "number mover": lst,
        "number mover 4": lst,
        "number mover 3": lst,
        "number mover 2": lst,
        "number mover 1": lst,
    }

    SEQ_POS_TO_KEEP = {
        "number mover": "end",
        "number mover 4": "S4",
        "number mover 3": "S3",
        "number mover 2": "S2",
        "number mover 1": "S1",
    }

    model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook

    # Only the logits are needed, so use a plain forward pass rather than
    # run_with_cache (which would also build an activation cache that is never read).
    ioi_logits_original = model(dataset.toks)

    model = ioi_circuit_extraction.add_mean_ablation_hook(model, means_dataset=dataset_2, circuit=CIRCUIT, seq_pos_to_keep=SEQ_POS_TO_KEEP)
    ioi_logits_minimal = model(dataset.toks)

    orig_score = logits_to_ave_logit_diff_2(ioi_logits_original, dataset)
    new_score = logits_to_ave_logit_diff_2(ioi_logits_minimal, dataset)
    if print_output:
        print(f"Average logit difference (circuit / full) %: {100 * new_score / orig_score:.4f}")
    return 100 * new_score / orig_score

We can also avoid recomputing the full model's score on every call by computing it once and passing it in to the function.

In [19]:
def find_circuit_forw(curr_circuit=None, threshold=10, score_fn=None):
    """Greedily prune heads, visiting layers from 0 up to 11.

    For each head still in the circuit, the circuit is scored without that
    head; the head is dropped when ``(100 - score) < threshold``, i.e. when
    removing it costs fewer than ``threshold`` percentage points.

    Args:
        curr_circuit: list of ``(layer, head)`` tuples to prune. ``None`` or an
            empty list means "start from all 12 x 12 heads". The list passed in
            is not modified.
        threshold: maximum tolerated drop, in percentage points.
        score_fn: optional callable mapping a head list to a float score in
            percent. Defaults to scoring with the notebook's
            ``mean_ablate_by_lst`` and global ``model``.

    Returns:
        ``(pruned_circuit, last_score)``. ``last_score`` is the score of the
        last candidate that was evaluated (whether or not that removal was
        accepted), or ``None`` if no candidate was evaluated.
    """
    if score_fn is None:
        def score_fn(circuit):
            return mean_ablate_by_lst(circuit, model, print_output=False).item()

    if not curr_circuit:
        # Start with full circuit
        curr_circuit = [(layer, head) for layer in range(12) for head in range(12)]
    else:
        # Work on a copy so the caller's list is never mutated.
        curr_circuit = list(curr_circuit)

    new_score = None
    for layer in range(0, 12):
        for head in range(12):
            if (layer, head) not in curr_circuit:
                continue

            # Candidate circuit with the current head temporarily removed.
            copy_circuit = curr_circuit.copy()
            copy_circuit.remove((layer, head))

            new_score = score_fn(copy_circuit)

            # Accept the removal when the performance drop stays under the threshold.
            if (100 - new_score) < threshold:
                curr_circuit.remove((layer, head))

                print("\nRemoved:", (layer, head))
                print(new_score)

    return curr_circuit, new_score

In [20]:
def find_circuit_backw(curr_circuit=None, threshold=10):
    # threshold is T, a %. if performance is less than T%, allow its removal
    if curr_circuit == []:
        # Start with full circuit
        curr_circuit = [(layer, head) for layer in range(12) for head in range(12)]

    for layer in range(11, -1, -1):  # go thru all heads in a layer first
        for head in range(12):
            if (layer, head) not in curr_circuit:
                continue

            # Copying the curr_circuit so we can iterate over one and modify the other
            copy_circuit = curr_circuit.copy()

            # Temporarily removing the current tuple from the copied circuit
            copy_circuit.remove((layer, head))

            new_score = mean_ablate_by_lst(copy_circuit, model, print_output=False).item()

            # If the result is less than the threshold, remove the tuple from the original list
            if (100 - new_score) < threshold:
                curr_circuit.remove((layer, head))

                print("\nRemoved:", (layer, head))
                print(new_score)

    return curr_circuit, new_score

# Iterative backward then forward pruning, threshold 3

In [None]:
# Alternate backward and forward pruning passes until a pass leaves the circuit unchanged.
threshold = 3
curr_circuit = []  # an empty list makes the first find_circuit_backw call start from all 12 x 12 heads
prev_score = 100
new_score = 0
iter = 1  # NOTE(review): shadows the builtin iter() for the rest of the kernel session
# NOTE(review): prev_score is never updated, so this condition only ends the loop if a pass
# returns exactly 100; otherwise the loop ends through the `break`s below.
while prev_score != new_score:
    print('\nbackw prune, iter ', str(iter))
    # prev_score = new_score # save old score before finding new one
    old_circuit = curr_circuit.copy() # save old before finding new one
    curr_circuit, new_score = find_circuit_backw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    print('\nfwd prune, iter ', str(iter))
    # Convergence is detected by comparing circuits, because comparing scores did not work
    old_circuit = curr_circuit.copy() # save old before finding new one
    curr_circuit, new_score = find_circuit_forw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    iter += 1


backw prune, iter  1

Removed: (11, 0)
98.4991455078125

Removed: (11, 1)
98.94969177246094

Removed: (11, 2)
99.69479370117188

Removed: (11, 3)
99.9080810546875

Removed: (11, 4)
101.00467681884766

Removed: (11, 5)
100.99285888671875

Removed: (11, 6)
101.09783935546875

Removed: (11, 7)
101.03890228271484

Removed: (11, 8)
100.28399658203125

Removed: (11, 9)
100.366455078125

Removed: (11, 10)
97.58168029785156

Removed: (11, 11)
98.09981536865234

Removed: (10, 0)
98.25293731689453

Removed: (10, 1)
100.01850128173828

Removed: (10, 2)
116.3747329711914

Removed: (10, 3)
117.53196716308594

Removed: (10, 4)
117.40097045898438

Removed: (10, 5)
117.49262237548828

Removed: (10, 6)
117.62872314453125

Removed: (10, 8)
117.43827056884766

Removed: (10, 9)
117.27503967285156

Removed: (10, 10)
117.79998016357422

Removed: (10, 11)
118.33454895019531

Removed: (9, 0)
118.6170883178711

Removed: (9, 2)
118.35466003417969

Removed: (9, 3)
118.36912536621094

Removed: (9, 4)
118.1252212

In [None]:
# Keep a copy of the circuit produced by the threshold-3 pruning loop and display it.
bf_3 = curr_circuit.copy()
bf_3

[(0, 1),
 (0, 6),
 (0, 7),
 (0, 9),
 (0, 10),
 (1, 0),
 (1, 5),
 (3, 3),
 (4, 4),
 (4, 10),
 (5, 4),
 (5, 6),
 (5, 8),
 (6, 6),
 (6, 10),
 (7, 6),
 (7, 10),
 (7, 11),
 (8, 8),
 (9, 1),
 (10, 7)]

In [None]:
len(bf_3)

21

## Remove each head in turn to find the most important heads

In [None]:
# Baseline: score of the whole pruned circuit, as a % of the unablated model's logit difference.
circ = bf_3  # alias of bf_3, not a copy
circ_score = mean_ablate_by_lst(circ, model, print_output=True).item()

Average logit difference (circuit / full) %: 97.1837


In [None]:
# Leave-one-out: score the circuit with each single head removed.
lh_scores = {}  # (layer, head) -> score (%) of the circuit without that head
for lh in circ:
    copy_circuit = circ.copy()  # copy so that circ itself is never modified
    copy_circuit.remove(lh)
    print("removed: " + str(lh))
    new_score = mean_ablate_by_lst(copy_circuit, model, print_output=True).item()
    lh_scores[lh] = new_score

removed: (0, 1)
Average logit difference (circuit / full) %: 82.3840
removed: (0, 6)
Average logit difference (circuit / full) %: 94.6417
removed: (0, 7)
Average logit difference (circuit / full) %: 94.1867
removed: (0, 9)
Average logit difference (circuit / full) %: 90.7150
removed: (0, 10)
Average logit difference (circuit / full) %: 89.3442
removed: (1, 0)
Average logit difference (circuit / full) %: 95.3434
removed: (1, 5)
Average logit difference (circuit / full) %: 92.5516
removed: (3, 3)
Average logit difference (circuit / full) %: 93.0725
removed: (4, 4)
Average logit difference (circuit / full) %: 63.8274
removed: (4, 10)
Average logit difference (circuit / full) %: 92.0628
removed: (5, 4)
Average logit difference (circuit / full) %: 96.6786
removed: (5, 6)
Average logit difference (circuit / full) %: 93.8249
removed: (5, 8)
Average logit difference (circuit / full) %: 94.5884
removed: (6, 6)
Average logit difference (circuit / full) %: 94.0663
removed: (6, 10)
Average logit d

In [None]:
# Order heads by the score obtained without them, lowest first
# (so the heads whose removal hurts most come first).
sorted_lh_scores = dict(sorted(lh_scores.items(), key=lambda item: item[1]))
sorted_lh_scores

{(9, 1): 51.14466094970703,
 (4, 4): 63.82744598388672,
 (10, 7): 77.2159194946289,
 (7, 11): 77.69091796875,
 (0, 1): 82.38401794433594,
 (8, 8): 86.17359161376953,
 (0, 10): 89.34420013427734,
 (0, 9): 90.71499633789062,
 (4, 10): 92.06282043457031,
 (6, 10): 92.093017578125,
 (1, 5): 92.55159759521484,
 (3, 3): 93.07250213623047,
 (5, 6): 93.82491302490234,
 (6, 6): 94.06629180908203,
 (0, 7): 94.1867446899414,
 (7, 10): 94.35436248779297,
 (5, 8): 94.58840942382812,
 (0, 6): 94.64169311523438,
 (7, 6): 94.70846557617188,
 (1, 0): 95.34339904785156,
 (5, 4): 96.67864227294922}

In [None]:
# Print each head with the change in score caused by removing it
# (score without the head minus circ_score), rounded to 2 decimals.
for lh, score in sorted_lh_scores.items():
    print(lh, -round(circ_score-score, 2))

(9, 1) -46.04
(4, 4) -33.36
(10, 7) -19.97
(7, 11) -19.49
(0, 1) -14.8
(8, 8) -11.01
(0, 10) -7.84
(0, 9) -6.47
(4, 10) -5.12
(6, 10) -5.09
(1, 5) -4.63
(3, 3) -4.11
(5, 6) -3.36
(6, 6) -3.12
(0, 7) -3.0
(7, 10) -2.83
(5, 8) -2.6
(0, 6) -2.54
(7, 6) -2.48
(1, 0) -1.84
(5, 4) -0.51


## Try circuits found for other tasks

### Greater-than and IOI circuits

In [None]:
# Head list labelled as the greater-than circuit (not derived in this notebook),
# scored on this notebook's `dataset` via mean_ablate_by_lst.
circuit = [(0, 1), (0, 3), (0, 5), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11), (9, 1)]
mean_ablate_by_lst(circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: 13.4454


13.445417404174805

In [None]:
# Head list labelled as the IOI circuit (not derived in this notebook),
# scored on this notebook's `dataset` via mean_ablate_by_lst.
circuit = [(0, 1), (0, 10), (2, 2), (3, 0), (4, 11), (5, 5), (5, 8), (5, 9), (6, 9), (7, 3), (7, 9), (8, 6), (8, 10), (9, 0), (9, 6), (9, 7), (9, 9), (10, 0), (10, 1), (10, 2), (10, 6), (10, 7), (10, 10), (11, 2), (11, 9), (11, 10)]
mean_ablate_by_lst(circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: 6.5866


6.586550235748291

### Circuits from backward-forward pruning at threshold 3 (bf 97)

In [None]:
# Head list labelled "digits incr" (presumably the increasing-digits task circuit — TODO confirm),
# scored on this notebook's `dataset` via mean_ablate_by_lst.
circuit = [(0, 1), (0, 2), (0, 5), (0, 7), (0, 8), (0, 10), (1, 0), (1, 1), (1, 3), (1, 5), (1, 7), (1, 11), (2, 0), (2, 1), (2, 2), (2, 3), (2, 5), (2, 6), (2, 8), (2, 9), (2, 10), (3, 3), (3, 7), (3, 8), (3, 10), (3, 11), (4, 2), (4, 4), (4, 6), (4, 10), (4, 11), (5, 1), (5, 4), (5, 8), (5, 10), (5, 11), (6, 2), (6, 3), (6, 4), (6, 6), (6, 7), (6, 8), (6, 9), (6, 10), (6, 11), (7, 11), (8, 6), (8, 8), (9, 1), (10, 7), (11, 10)]
mean_ablate_by_lst(circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: 98.7817


98.78166198730469

In [None]:
# Number-words circuit: the same head list as the threshold-3 pruning result displayed earlier in this notebook.
circuit = [(0, 1), (0, 6), (0, 7), (0, 9), (0, 10), (1, 0), (1, 5), (3, 3), (4, 4), (4, 10), (5, 4), (5, 6), (5, 8), (6, 6), (6, 10), (7, 6), (7, 10), (7, 11), (8, 8), (9, 1), (10, 7)]
mean_ablate_by_lst(circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: 97.1837


97.18368530273438

In [None]:
# Head list labelled as the months circuit (not derived in this notebook).
# "incorr i": presumably the variant found with the first sequence token as the incorrect answer — TODO confirm
circuit = [(0, 1), (2, 3), (2, 5), (2, 7), (2, 8), (2, 9), (4, 4), (5, 0), (5, 6), (6, 9), (6, 10), (7, 8), (7, 11), (8, 1), (8, 6), (8, 8), (8, 9), (9, 1), (9, 7), (9, 11), (10, 7), (11, 10)]
mean_ablate_by_lst(circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: 70.5720


70.57201385498047

# Iterative backward then forward pruning, threshold 20

In [None]:
# Alternate backward and forward pruning passes until a pass leaves the circuit unchanged.
threshold = 20
curr_circuit = []  # an empty list makes the first find_circuit_backw call start from all 12 x 12 heads
prev_score = 100
new_score = 0
iter = 1  # NOTE(review): shadows the builtin iter() for the rest of the kernel session
# NOTE(review): prev_score is never updated, so this condition only ends the loop if a pass
# returns exactly 100; otherwise the loop ends through the `break`s below.
while prev_score != new_score:
    print('\nbackw prune, iter ', str(iter))
    # prev_score = new_score # save old score before finding new one
    old_circuit = curr_circuit.copy() # save old before finding new one
    curr_circuit, new_score = find_circuit_backw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    print('\nfwd prune, iter ', str(iter))
    # Convergence is detected by comparing circuits, because comparing scores did not work
    old_circuit = curr_circuit.copy() # save old before finding new one
    curr_circuit, new_score = find_circuit_forw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    iter += 1


backw prune, iter  1

Removed: (11, 0)
98.7055435180664

Removed: (11, 1)
98.90824127197266

Removed: (11, 2)
99.56969451904297

Removed: (11, 3)
99.80233001708984

Removed: (11, 4)
100.62274932861328

Removed: (11, 5)
100.61015319824219

Removed: (11, 6)
100.69552612304688

Removed: (11, 7)
100.64871978759766

Removed: (11, 8)
99.81178283691406

Removed: (11, 9)
99.73619079589844

Removed: (11, 10)
96.95419311523438

Removed: (11, 11)
97.49852752685547

Removed: (10, 0)
97.66179656982422

Removed: (10, 1)
99.14281463623047

Removed: (10, 2)
113.74662780761719

Removed: (10, 3)
114.81901550292969

Removed: (10, 4)
114.64929962158203

Removed: (10, 5)
114.65435028076172

Removed: (10, 6)
114.78545379638672

Removed: (10, 7)
94.1048583984375

Removed: (10, 8)
93.35231018066406

Removed: (10, 9)
93.20446014404297

Removed: (10, 10)
93.72042083740234

Removed: (10, 11)
94.05353546142578

Removed: (9, 0)
94.1783447265625

Removed: (9, 2)
93.95365905761719

Removed: (9, 3)
94.05609893798828

In [None]:
# Keep a copy of the circuit produced by the threshold-20 pruning loop and display it.
# NOTE(review): this overwrites the earlier bf_3 (threshold-3 result); the name no longer matches the threshold.
bf_3 = curr_circuit.copy()
bf_3

[(0, 1),
 (0, 9),
 (0, 10),
 (1, 5),
 (4, 4),
 (4, 7),
 (5, 6),
 (5, 8),
 (6, 1),
 (6, 6),
 (6, 10),
 (7, 5),
 (7, 6),
 (7, 10),
 (7, 11),
 (8, 7),
 (8, 8),
 (8, 10),
 (8, 11),
 (9, 1)]

In [None]:
len(bf_3)

20

## Remove each head in turn to find the most important heads

In [None]:
# Baseline: score of the whole pruned circuit, as a % of the unablated model's logit difference.
circ = bf_3  # alias of bf_3, not a copy
circ_score = mean_ablate_by_lst(circ, model, print_output=True).item()

Average logit difference (circuit / full) %: 80.1992


In [None]:
# Leave-one-out: score the circuit with each single head removed.
lh_scores = {}  # (layer, head) -> score (%) of the circuit without that head
for lh in circ:
    copy_circuit = circ.copy()  # copy so that circ itself is never modified
    copy_circuit.remove(lh)
    print("removed: " + str(lh))
    new_score = mean_ablate_by_lst(copy_circuit, model, print_output=True).item()
    lh_scores[lh] = new_score

removed: (0, 1)
Average logit difference (circuit / full) %: 69.2554
removed: (0, 9)
Average logit difference (circuit / full) %: 75.2409
removed: (0, 10)
Average logit difference (circuit / full) %: 68.9665
removed: (1, 5)
Average logit difference (circuit / full) %: 67.9158
removed: (4, 4)
Average logit difference (circuit / full) %: 48.0506
removed: (4, 7)
Average logit difference (circuit / full) %: 78.1008
removed: (5, 6)
Average logit difference (circuit / full) %: 76.3769
removed: (5, 8)
Average logit difference (circuit / full) %: 76.8698
removed: (6, 1)
Average logit difference (circuit / full) %: 75.7282
removed: (6, 6)
Average logit difference (circuit / full) %: 76.9300
removed: (6, 10)
Average logit difference (circuit / full) %: 74.5152
removed: (7, 5)
Average logit difference (circuit / full) %: 79.4163
removed: (7, 6)
Average logit difference (circuit / full) %: 78.3269
removed: (7, 10)
Average logit difference (circuit / full) %: 79.7978
removed: (7, 11)
Average logit 

In [None]:
# Order heads by the score obtained without them, lowest first
# (so the heads whose removal hurts most come first).
sorted_lh_scores = dict(sorted(lh_scores.items(), key=lambda item: item[1]))
sorted_lh_scores

{(9, 1): 34.479007720947266,
 (4, 4): 48.050559997558594,
 (7, 11): 59.81097412109375,
 (1, 5): 67.91581726074219,
 (0, 10): 68.96647644042969,
 (0, 1): 69.25543975830078,
 (8, 8): 69.53819274902344,
 (6, 10): 74.51518249511719,
 (0, 9): 75.24092864990234,
 (6, 1): 75.72819519042969,
 (8, 11): 76.19505310058594,
 (5, 6): 76.37689208984375,
 (5, 8): 76.86983489990234,
 (6, 6): 76.92996978759766,
 (4, 7): 78.10084533691406,
 (7, 6): 78.3268814086914,
 (8, 7): 78.79082489013672,
 (7, 5): 79.41632080078125,
 (7, 10): 79.79784393310547,
 (8, 10): 79.92594909667969}

In [None]:
# Print each head with the change in score caused by removing it
# (score without the head minus circ_score), rounded to 2 decimals.
for lh, score in sorted_lh_scores.items():
    print(lh, -round(circ_score-score, 2))

(9, 1) -45.72
(4, 4) -32.15
(7, 11) -20.39
(1, 5) -12.28
(0, 10) -11.23
(0, 1) -10.94
(8, 8) -10.66
(6, 10) -5.68
(0, 9) -4.96
(6, 1) -4.47
(8, 11) -4.0
(5, 6) -3.82
(5, 8) -3.33
(6, 6) -3.27
(4, 7) -2.1
(7, 6) -1.87
(8, 7) -1.41
(7, 5) -0.78
(7, 10) -0.4
(8, 10) -0.27


## Try circuits found for other tasks

### Circuits from backward-forward pruning at threshold 20 (bf 80)

In [None]:
# Head list labelled as the digits circuit (not derived in this notebook),
# scored on this notebook's `dataset` via mean_ablate_by_lst.
circuit = [(0, 1), (0, 2), (0, 5), (0, 7), (0, 8), (0, 10), (1, 0), (1, 1), (1, 5), (1, 7), (1, 11), (2, 0), (2, 1), (2, 2), (2, 3), (2, 6), (2, 8), (2, 9), (2, 10), (2, 11), (3, 3), (3, 4), (3, 5), (3, 7), (3, 8), (3, 9), (3, 11), (4, 4), (4, 10), (5, 1), (5, 4), (5, 6), (5, 8), (5, 11), (6, 4), (6, 6), (6, 7), (6, 8), (6, 9), (6, 10), (6, 11), (7, 11), (9, 1)]
mean_ablate_by_lst(circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: 55.5454


55.54539108276367

In [None]:
# Number-words circuit: the same head list as the threshold-20 pruning result displayed earlier in this notebook.
circuit = [(0, 1), (0, 9), (0, 10), (1, 5), (4, 4), (4, 7), (5, 6), (5, 8), (6, 1), (6, 6), (6, 10), (7, 5), (7, 6), (7, 10), (7, 11), (8, 7), (8, 8), (8, 10), (8, 11), (9, 1)]
mean_ablate_by_lst(circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: 80.1992


80.1991958618164

In [22]:
# Head list labelled as the months circuit (not derived in this notebook),
# scored on this notebook's `dataset` via mean_ablate_by_lst.
circuit = [(0, 1), (2, 2), (2, 9), (4, 4), (5, 0), (5, 1), (5, 4), (5, 6), (6, 6), (6, 9), (6, 10), (7, 7), (7, 11), (8, 8), (8, 9), (9, 1)]
mean_ablate_by_lst(circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: 50.2363


50.23625564575195

# Try again using the first sequence token (index i) as the incorrect answer

In [None]:
def generate_prompts_list(x ,y):
    """Build one prompt dict per start index in ``range(x, y)``.

    Four consecutive number words form S1..S4 (and, concatenated, ``'text'``);
    the fifth consecutive word is the correct answer and the S1 word is the
    incorrect answer.

    Raises:
        IndexError: if a start index leaves fewer than five words available.
    """
    words = [' one', ' two', ' three', ' four', ' five', ' six', ' seven', ' eight', ' nine', ' ten', ' eleven', ' twelve', ' thirteen', ' fourteen', ' fifteen', ' sixteen', ' seventeen', ' eighteen', ' nineteen', ' twenty']
    prompts_list = []
    for start in range(x, y):
        first, second, third, fourth = (words[start + k] for k in range(4))
        answer = words[start + 4]
        prompts_list.append({
            'S1': first,
            'S2': second,
            'S3': third,
            'S4': fourth,
            'corr': answer,
            'incorr': first,  # this is arbitrary
            'text': first + second + third + fourth,
        })
    return prompts_list

# Rebuild the 16 clean prompts with the redefined generator; this replaces the global `dataset`.
prompts_list = generate_prompts_list(0, 16)
dataset = Dataset(prompts_list, pos_dict, model.tokenizer, S1_is_first=True)

In [None]:
# Compute the unablated model's average logit difference once, on the rebuilt `dataset`.
# Hooks are reset first so that no hook left over from an earlier ablation run affects this pass.
model.reset_hooks(including_permanent=True)
ioi_logits_original, ioi_cache = model.run_with_cache(dataset.toks)
orig_score = logits_to_ave_logit_diff_2(ioi_logits_original, dataset)

## Ablation Experiment Functions

In [None]:
from torch import Tensor

def logits_to_ave_logit_diff_2(logits: Float[Tensor, "batch seq d_vocab"], dataset: Dataset, per_prompt=False):
    '''
    Difference between the correct-answer and incorrect-answer logits, read at
    each prompt's "end" position.

    Returns the per-prompt differences when per_prompt=True, else their mean.
    '''
    rows = range(logits.size(0))
    end_positions = dataset.word_idx["end"]  # only the final position is relevant for the answer

    logit_diff = (
        logits[rows, end_positions, dataset.io_tokenIDs]
        - logits[rows, end_positions, dataset.s_tokenIDs]
    )
    return logit_diff if per_prompt else logit_diff.mean()

In [None]:
# Same computation as the cell before the function definition above:
# unablated average logit difference on `dataset`, with all hooks reset first.
model.reset_hooks(including_permanent=True)
ioi_logits_original, ioi_cache = model.run_with_cache(dataset.toks)
orig_score = logits_to_ave_logit_diff_2(ioi_logits_original, dataset)

In [None]:
def mean_ablate_by_lst(lst, model, orig_score, print_output=True):
    '''
    Score the model with a mean-ablation hook built from the head list `lst`.

    The same list of (layer, head) tuples is registered under four circuit
    groups, each paired with one sequence position to keep ("end", "S3", "S2",
    "S1"). A fifth group for "S4" is deliberately left disabled, as in the
    original experiment.

    Reads the notebook-level `dataset` (inputs) and `dataset_2` (means dataset),
    and leaves the ablation hook installed on `model` when it returns; hooks are
    reset at the start of each call.

    NOTE(review): presumably add_mean_ablation_hook mean-ablates everything
    outside the given circuit -- confirm in ioi_circuit_extraction.

    Returns 100 * (ablated logit diff) / orig_score, and prints it when
    print_output is True.
    '''
    # Position kept for each circuit group; ("number mover 4" -> "S4" is disabled).
    positions_by_group = {
        "number mover": "end",
        "number mover 3": "S3",
        "number mover 2": "S2",
        "number mover 1": "S1",
    }
    # Every group shares the very same head list object.
    heads_by_group = dict.fromkeys(positions_by_group, lst)

    # Hooks from a previous mean-ablation run must be cleared before adding new ones.
    model.reset_hooks(including_permanent=True)

    ablated_model = ioi_circuit_extraction.add_mean_ablation_hook(
        model,
        means_dataset=dataset_2,
        circuit=heads_by_group,
        seq_pos_to_keep=positions_by_group,
    )
    ablated_logits = ablated_model(dataset.toks)

    pct_of_full = 100 * logits_to_ave_logit_diff_2(ablated_logits, dataset) / orig_score
    if print_output:
        print(f"Average logit difference (circuit / full) %: {pct_of_full:.4f}")
    return pct_of_full

In [None]:
def find_circuit_forw(curr_circuit=None, orig_score=100, threshold=10, n_layers=12, n_heads=12):
    '''
    Greedily prune heads from a circuit, visiting layers from first to last.

    Each head still in the circuit is tentatively removed and the remaining
    heads are scored with mean_ablate_by_lst (a percentage of orig_score). The
    removal is kept when the loss relative to 100% is below `threshold`, i.e.
    when (100 - score) < threshold.

    Args:
        curr_circuit: list of (layer, head) tuples to prune. None or an empty
            list means "start from every head in the n_layers x n_heads grid".
            A non-empty list is pruned in place.
        orig_score: unablated score, forwarded to mean_ablate_by_lst.
        threshold: T, in percent; a head is dropped if removing it costs less
            than T percentage points.
        n_layers, n_heads: size of the head grid to scan (defaults match the
            12 x 12 grid this notebook was written for).

    Returns:
        (curr_circuit, new_score). new_score is the score of the LAST candidate
        evaluated, whether or not that removal was kept, so it is not
        necessarily the score of the returned circuit. It is None when no head
        was evaluated. Uses the notebook-level `model`.
    '''
    # `not curr_circuit` covers both the None default and an explicit [];
    # the previous `== []` check let None through and crashed on `not in None`.
    if not curr_circuit:
        # Start with full circuit
        curr_circuit = [(layer, head) for layer in range(n_layers) for head in range(n_heads)]

    new_score = None  # stays None if the scan never evaluates a head

    for layer in range(n_layers):
        for head in range(n_heads):
            if (layer, head) not in curr_circuit:
                continue

            # Score a copy with this head removed, leaving curr_circuit intact for now.
            copy_circuit = curr_circuit.copy()
            copy_circuit.remove((layer, head))

            new_score = mean_ablate_by_lst(copy_circuit, model, orig_score, print_output=False).item()

            # Keep the removal only if the performance loss is under the threshold.
            if (100 - new_score) < threshold:
                curr_circuit.remove((layer, head))

                print("\nRemoved:", (layer, head))
                print(new_score)

    return curr_circuit, new_score

In [None]:
def find_circuit_backw(curr_circuit=None, orig_score=100, threshold=10, n_layers=12, n_heads=12):
    '''
    Greedily prune heads from a circuit, visiting layers from last to first.

    Within each layer all heads are visited in ascending order. Each head still
    in the circuit is tentatively removed and the remaining heads are scored
    with mean_ablate_by_lst (a percentage of orig_score). The removal is kept
    when the loss relative to 100% is below `threshold`, i.e. when
    (100 - score) < threshold.

    Args:
        curr_circuit: list of (layer, head) tuples to prune. None or an empty
            list means "start from every head in the n_layers x n_heads grid".
            A non-empty list is pruned in place.
        orig_score: unablated score, forwarded to mean_ablate_by_lst.
        threshold: T, in percent; a head is dropped if removing it costs less
            than T percentage points.
        n_layers, n_heads: size of the head grid to scan (defaults match the
            12 x 12 grid this notebook was written for).

    Returns:
        (curr_circuit, new_score). new_score is the score of the LAST candidate
        evaluated, whether or not that removal was kept, so it is not
        necessarily the score of the returned circuit. It is None when no head
        was evaluated. Uses the notebook-level `model`.
    '''
    # `not curr_circuit` covers both the None default and an explicit [];
    # the previous `== []` check let None through and crashed on `not in None`.
    if not curr_circuit:
        # Start with full circuit
        curr_circuit = [(layer, head) for layer in range(n_layers) for head in range(n_heads)]

    new_score = None  # stays None if the scan never evaluates a head

    for layer in range(n_layers - 1, -1, -1):  # go thru all heads in a layer first
        for head in range(n_heads):
            if (layer, head) not in curr_circuit:
                continue

            # Score a copy with this head removed, leaving curr_circuit intact for now.
            copy_circuit = curr_circuit.copy()
            copy_circuit.remove((layer, head))

            new_score = mean_ablate_by_lst(copy_circuit, model, orig_score, print_output=False).item()

            # Keep the removal only if the performance loss is under the threshold.
            if (100 - new_score) < threshold:
                curr_circuit.remove((layer, head))

                print("\nRemoved:", (layer, head))
                print(new_score)

    return curr_circuit, new_score

## Iterative backward/forward pruning, threshold 3

In [None]:
# Alternate backward and forward pruning passes until a pass leaves the circuit unchanged.
threshold = 3
curr_circuit = []  # an empty list makes the pruning functions start from every (layer, head)
prev_score = 100
new_score = 0
# Named iter_num rather than `iter` so the builtin iter() is not shadowed for the
# rest of the kernel session.
iter_num = 1
# prev_score is never updated (see the disabled line below), so this condition is
# effectively always true; the loop really ends via the "circuit unchanged" breaks.
while prev_score != new_score:
    print('\nbackw prune, iter ', str(iter_num))
    # prev_score = new_score # save old score before finding new one
    old_circuit = curr_circuit.copy() # save old before finding new one
    curr_circuit, new_score = find_circuit_backw(curr_circuit=curr_circuit, orig_score=orig_score, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    print('\nfwd prune, iter ', str(iter_num))
    # Convergence is tracked on the circuit, not the score: the returned score is
    # that of the last candidate tried, not of the returned circuit.
    old_circuit = curr_circuit.copy() # save old before finding new one
    curr_circuit, new_score = find_circuit_forw(curr_circuit=curr_circuit, orig_score=orig_score, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    iter_num += 1


backw prune, iter  1

Removed: (11, 0)
98.98297882080078

Removed: (11, 1)
98.61659240722656

Removed: (11, 2)
98.79859924316406

Removed: (11, 3)
98.71098327636719

Removed: (11, 4)
98.97198486328125

Removed: (11, 5)
99.14219665527344

Removed: (11, 6)
99.1951675415039

Removed: (11, 7)
99.2170181274414

Removed: (11, 8)
97.68964385986328

Removed: (11, 9)
97.57614135742188

Removed: (11, 11)
97.53639221191406

Removed: (10, 0)
97.62326049804688

Removed: (10, 1)
97.38711547851562

Removed: (10, 3)
97.76373291015625

Removed: (10, 4)
97.93522644042969

Removed: (10, 5)
98.15135955810547

Removed: (10, 6)
98.24663543701172

Removed: (10, 7)
97.98377990722656

Removed: (10, 8)
97.54168701171875

Removed: (10, 9)
97.22599792480469

Removed: (10, 10)
97.47850799560547

Removed: (10, 11)
97.57877349853516

Removed: (9, 0)
97.62334442138672

Removed: (9, 2)
97.48445892333984

Removed: (9, 3)
97.49215698242188

Removed: (9, 4)
97.67486572265625

Removed: (9, 6)
97.63304138183594

Removed: 

In [None]:
# Snapshot the pruned circuit (backward/forward pruning, threshold 3) so later
# cells that reassign or mutate curr_circuit cannot change it.
bf_3 = curr_circuit.copy()
bf_3

[(0, 1),
 (0, 6),
 (0, 7),
 (0, 9),
 (0, 10),
 (3, 3),
 (3, 6),
 (4, 4),
 (4, 6),
 (4, 7),
 (4, 10),
 (5, 0),
 (5, 1),
 (5, 6),
 (6, 6),
 (6, 9),
 (6, 10),
 (7, 10),
 (7, 11),
 (8, 0),
 (8, 1),
 (8, 2),
 (8, 6),
 (8, 8),
 (8, 11),
 (9, 1),
 (9, 5),
 (9, 11),
 (10, 2),
 (11, 10)]

In [None]:
# Number of (layer, head) entries left in the pruned circuit.
len(bf_3)

30

## Remove each head in turn to find the most important heads

In [None]:
# Reference score of the whole pruned circuit, as a percentage of the unablated model.
# Note: circ is an alias of bf_3 (same list object), not a copy.
circ = bf_3
circ_score = mean_ablate_by_lst(circ, model, orig_score, print_output=True).item()

Average logit difference (circuit / full) %: 97.0740


In [None]:
# Leave-one-out: drop each head from the circuit in turn and record the resulting
# score (percentage of the unablated model), keyed by the removed (layer, head).
lh_scores = {}
for lh in circ:
    # Work on a copy so circ itself is never modified while iterating over it.
    copy_circuit = circ.copy()
    copy_circuit.remove(lh)
    print("removed: " + str(lh))
    new_score = mean_ablate_by_lst(copy_circuit, model, orig_score, print_output=True).item()
    lh_scores[lh] = new_score

removed: (0, 1)
Average logit difference (circuit / full) %: 85.7242
removed: (0, 6)
Average logit difference (circuit / full) %: 93.1335
removed: (0, 7)
Average logit difference (circuit / full) %: 94.5368
removed: (0, 9)
Average logit difference (circuit / full) %: 93.0598
removed: (0, 10)
Average logit difference (circuit / full) %: 94.0437
removed: (3, 3)
Average logit difference (circuit / full) %: 94.2309
removed: (3, 6)
Average logit difference (circuit / full) %: 96.0041
removed: (4, 4)
Average logit difference (circuit / full) %: 67.7054
removed: (4, 6)
Average logit difference (circuit / full) %: 96.6399
removed: (4, 7)
Average logit difference (circuit / full) %: 96.1064
removed: (4, 10)
Average logit difference (circuit / full) %: 93.0666
removed: (5, 0)
Average logit difference (circuit / full) %: 95.5455
removed: (5, 1)
Average logit difference (circuit / full) %: 96.6853
removed: (5, 6)
Average logit difference (circuit / full) %: 91.4235
removed: (6, 6)
Average logit di

In [None]:
# Order heads by ascending leave-one-out score: the heads whose removal hurts
# performance most come first.
sorted_lh_scores = dict(sorted(lh_scores.items(), key=lambda item: item[1]))
sorted_lh_scores

{(9, 1): 57.387290954589844,
 (4, 4): 67.70535278320312,
 (7, 11): 81.2074203491211,
 (0, 1): 85.72423553466797,
 (8, 11): 86.89118194580078,
 (10, 2): 87.67463684082031,
 (8, 8): 90.22347259521484,
 (5, 6): 91.42350769042969,
 (6, 10): 92.093505859375,
 (7, 10): 92.38101959228516,
 (0, 9): 93.05979919433594,
 (4, 10): 93.06657409667969,
 (0, 6): 93.133544921875,
 (8, 6): 93.2705307006836,
 (0, 10): 94.04373168945312,
 (3, 3): 94.23091888427734,
 (6, 9): 94.27133178710938,
 (0, 7): 94.53684997558594,
 (5, 0): 95.54553985595703,
 (6, 6): 95.5802230834961,
 (11, 10): 95.73572540283203,
 (3, 6): 96.00411224365234,
 (8, 1): 96.10184478759766,
 (4, 7): 96.10635375976562,
 (8, 0): 96.21290588378906,
 (9, 5): 96.3659896850586,
 (9, 11): 96.5845947265625,
 (8, 2): 96.63079071044922,
 (4, 6): 96.6398696899414,
 (5, 1): 96.68527221679688}

In [None]:
# Print, per head, the change in score relative to the full pruned circuit
# (score - circ_score, rounded to 2 decimals); negative means removing the head hurt.
for lh, score in sorted_lh_scores.items():
    print(lh, -round(circ_score-score, 2))

(9, 1) -39.69
(4, 4) -29.37
(7, 11) -15.87
(0, 1) -11.35
(8, 11) -10.18
(10, 2) -9.4
(8, 8) -6.85
(5, 6) -5.65
(6, 10) -4.98
(7, 10) -4.69
(0, 9) -4.01
(4, 10) -4.01
(0, 6) -3.94
(8, 6) -3.8
(0, 10) -3.03
(3, 3) -2.84
(6, 9) -2.8
(0, 7) -2.54
(5, 0) -1.53
(6, 6) -1.49
(11, 10) -1.34
(3, 6) -1.07
(8, 1) -0.97
(4, 7) -0.97
(8, 0) -0.86
(9, 5) -0.71
(9, 11) -0.49
(8, 2) -0.44
(4, 6) -0.43
(5, 1) -0.39


## Try circuits from other tasks

### Greater-than, IOI

In [None]:
# greater-than
# Score this hard-coded head list on the current dataset, as a percentage of the
# unablated model (orig_score).
circuit = [(0, 1), (0, 3), (0, 5), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11), (9, 1)]
mean_ablate_by_lst(circuit, model, orig_score, print_output=True).item()

Average logit difference (circuit / full) %: 38.2239


38.22393035888672

In [None]:
# IOI
# Score this hard-coded head list on the current dataset, as a percentage of the
# unablated model (orig_score).
circuit = [(0, 1), (0, 10), (2, 2), (3, 0), (4, 11), (5, 5), (5, 8), (5, 9), (6, 9), (7, 3), (7, 9), (8, 6), (8, 10), (9, 0), (9, 6), (9, 7), (9, 9), (10, 0), (10, 1), (10, 2), (10, 6), (10, 7), (10, 10), (11, 2), (11, 9), (11, 10)]
mean_ablate_by_lst(circuit, model, orig_score, print_output=True).item()

Average logit difference (circuit / full) %: 9.2381


9.23812198638916

### bf 80

In [None]:
# # fb 80, digits incr
# # https://colab.research.google.com/drive/1mFWmGAKtigFcqqWWMCwU7wWQY2HT5ZOo#scrollTo=lJEY-Zs2g_a5&line=1&uniqifier=1
# circuit = [(1, 5), (3, 3), (3, 7), (3, 10), (3, 11), (4, 4), (4, 6), (4, 7), (4, 10), (4, 11), (5, 0), (5, 2), (5, 3), (5, 4), (5, 6), (6, 3), (6, 8), (6, 10), (7, 0), (7, 2), (7, 7), (7, 8), (7, 10), (7, 11), (8, 0), (8, 1), (8, 6), (8, 8), (8, 9), (8, 11), (9, 1), (11, 8)]
# mean_ablate_by_lst(circuit, model, print_output=True).item()

In [None]:
# # fb 80, numwords
# # https://colab.research.google.com/drive/1QTv-4osLHadCAay0beew-xlXszPCG88s#scrollTo=563kZf_4r_mw&line=2&uniqifier=1
# circuit = [(3, 2), (4, 4), (4, 8), (4, 10), (4, 11), (5, 5), (5, 6), (5, 7), (5, 8), (6, 1), (6, 7), (6, 9), (6, 10), (7, 0), (7, 2), (7, 5), (7, 6), (7, 7), (7, 8), (7, 10), (7, 11), (8, 0), (8, 1), (8, 6), (8, 8), (8, 11), (9, 1), (10, 2)]
# mean_ablate_by_lst(circuit, model, print_output=True).item()

In [None]:
# # fb 80, months
# # https://colab.research.google.com/drive/1lhQqlizYGMC11vzp6I9mJ3dyxIr8tV3l#scrollTo=563kZf_4r_mw&line=2&uniqifier=1
# circuit = [(4, 4), (7, 11), (8, 6), (8, 9), (8, 11), (9, 1), (9, 5), (11, 10)]
# mean_ablate_by_lst(circuit, model, print_output=True).item()

### bf 97

In [None]:
# digits incr
# incorr i+3
# Score this hard-coded head list on the current dataset, as a percentage of the
# unablated model (orig_score).
circuit = [(0, 1), (0, 2), (0, 5), (0, 7), (0, 8), (0, 10), (1, 0), (1, 1), (1, 3), (1, 5), (1, 7), (1, 11), (2, 0), (2, 1), (2, 2), (2, 3), (2, 5), (2, 6), (2, 8), (2, 9), (2, 10), (3, 3), (3, 7), (3, 8), (3, 10), (3, 11), (4, 2), (4, 4), (4, 6), (4, 10), (4, 11), (5, 1), (5, 4), (5, 8), (5, 10), (5, 11), (6, 2), (6, 3), (6, 4), (6, 6), (6, 7), (6, 8), (6, 9), (6, 10), (6, 11), (7, 11), (8, 6), (8, 8), (9, 1), (10, 7), (11, 10)]
mean_ablate_by_lst(circuit, model, orig_score, print_output=True).item()

Average logit difference (circuit / full) %: 60.5607


60.560691833496094

In [None]:
# numwords
# incorr i+3
# Score this hard-coded head list on the current dataset, as a percentage of the
# unablated model (orig_score).
circuit = [(0, 1), (0, 6), (0, 7), (0, 9), (0, 10), (1, 0), (1, 5), (3, 3), (4, 4), (4, 10), (5, 4), (5, 6), (5, 8), (6, 6), (6, 10), (7, 6), (7, 10), (7, 11), (8, 8), (9, 1), (10, 7)]
mean_ablate_by_lst(circuit, model, orig_score, print_output=True).item()

Average logit difference (circuit / full) %: 64.2862


64.28623962402344

In [None]:
# months
# incorr i
# Score this hard-coded head list on the current dataset, as a percentage of the
# unablated model (orig_score).
circuit = [(0, 1), (2, 3), (2, 5), (2, 7), (2, 8), (2, 9), (4, 4), (5, 0), (5, 6), (6, 9), (6, 10), (7, 8), (7, 11), (8, 1), (8, 6), (8, 8), (8, 9), (9, 1), (9, 7), (9, 11), (10, 7), (11, 10)]
mean_ablate_by_lst(circuit, model, orig_score, print_output=True).item()

Average logit difference (circuit / full) %: 53.6491


53.649051666259766

In [None]:
# digits
# incorr i
# NOTE(review): this head list is identical to the one in the "months" cell
# (and so is its recorded output); it looks like the digits circuit was never
# pasted in -- confirm and replace with the intended list.
circuit = [(0, 1), (2, 3), (2, 5), (2, 7), (2, 8), (2, 9), (4, 4), (5, 0), (5, 6), (6, 9), (6, 10), (7, 8), (7, 11), (8, 1), (8, 6), (8, 8), (8, 9), (9, 1), (9, 7), (9, 11), (10, 7), (11, 10)]
mean_ablate_by_lst(circuit, model, orig_score, print_output=True).item()

Average logit difference (circuit / full) %: 53.6491


53.649051666259766