# Setup
(No need to change anything)

In [None]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    # Install another version of node that makes PySvelte work way faster
    # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    # %pip install git+https://github.com/neelnanda-io/PySvelte.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

Running as a Colab notebook
Collecting git+https://github.com/neelnanda-io/TransformerLens.git
  Cloning https://github.com/neelnanda-io/TransformerLens.git to /tmp/pip-req-build-n2fbz_qh
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/TransformerLens.git /tmp/pip-req-build-n2fbz_qh
  Resolved https://github.com/neelnanda-io/TransformerLens.git to commit fa287750606075574df2c538058e67d648e2f952
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting accelerate>=0.23.0 (from transformer-lens==0.0.0)
  Downloading accelerate-0.24.1-py3-none-any.whl (261 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m261.4/261.4 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting beartype<0.15.0,>=0.14.1 (from transformer-lens==0.0.0)
  Downloading beartype-0.14.1-py3-none-any.whl (739 kB)
[2K     [90m━━━━━━━━

In [None]:
# # Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
# import plotly.io as pio

# if IN_COLAB or not DEBUG_MODE:
#     # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.
#     pio.renderers.default = "colab"
# else:
#     pio.renderers.default = "png"

In [None]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [None]:
# import pysvelte

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

We turn automatic differentiation off, to save GPU memory, as this notebook focuses on model inference not model training.

In [None]:
# Globally disable autograd for the rest of the notebook (inference only).
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7871730af850>

Plotting helper functions:

In [None]:
# def imshow(tensor, renderer=None, **kwargs):
#     px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", **kwargs).show(renderer)

# def line(tensor, renderer=None, **kwargs):
#     px.line(y=utils.to_numpy(tensor), **kwargs).show(renderer)

# def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
#     x = utils.to_numpy(x)
#     y = utils.to_numpy(y)
#     px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

## Load Model

In [None]:
# Load GPT-2 small as a HookedTransformer with TransformerLens's weight-processing
# options switched on (see HookedTransformer.from_pretrained for what each flag does).
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)

Downloading (…)lve/main/config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Loaded pretrained model gpt2-small into HookedTransformer


## Import functions from repo

In [None]:
# Fetch the ARENA 2.0 repo; ioi_circuit_extraction is imported from it below.
# NOTE(review): re-running this cell fails once the ARENA_2.0 directory already exists.
!git clone https://github.com/callummcdougall/ARENA_2.0.git

Cloning into 'ARENA_2.0'...
remote: Enumerating objects: 9165, done.[K
remote: Counting objects: 100% (1877/1877), done.[K
remote: Compressing objects: 100% (319/319), done.[K
remote: Total 9165 (delta 1649), reused 1655 (delta 1554), pack-reused 7288[K
Receiving objects: 100% (9165/9165), 156.33 MiB | 23.19 MiB/s, done.
Resolving deltas: 100% (5544/5544), done.


In [None]:
cd ARENA_2.0/chapter1_transformers/exercises/part3_indirect_object_identification

/content/ARENA_2.0/chapter1_transformers/exercises/part3_indirect_object_identification


In [None]:
import ioi_circuit_extraction as ioi_circuit_extraction

# Generate dataset with multiple prompts

In [None]:
class Dataset:
    """Tokenised prompts plus the token positions and answer token ids used for scoring.

    Args:
        prompts: non-empty list of dicts. Each dict has a "text" prompt string, a
            "corr" (correct answer) word, an "incorr" (incorrect answer) word, and
            any number of other keys naming positions of interest (e.g. "S1").
        pos_dict: maps every non-reserved prompt key to its token index; the same
            index is used for every prompt.
        tokenizer: object that is callable on a string or list of strings
            (returning something with ``input_ids``) and that provides ``encode``
            and ``tokenize``.
        S1_is_first: accepted for backward compatibility; currently unused.

    Attributes:
        prompts, tokenizer: the constructor arguments, stored as given.
        N: number of prompts.
        max_len: largest ``input_ids`` length over the individually tokenised prompts.
        groups: list of index arrays, one per template id (a single group here).
        toks: int tensor of padded token ids, one row per prompt.
        io_tokenIDs / s_tokenIDs: first token id of " " + corr / " " + incorr.
        word_idx: dict of position name -> tensor with one token index per prompt;
            "end" holds the index of each prompt's last token.
    """

    # Prompt keys that describe the prompt itself rather than a token position.
    _RESERVED_KEYS = ("text", "corr", "incorr")

    def __init__(self, prompts, pos_dict, tokenizer, S1_is_first=False):
        self.prompts = prompts
        self.tokenizer = tokenizer
        self.N = len(prompts)
        self.max_len = max(
            len(self.tokenizer(prompt["text"]).input_ids) for prompt in self.prompts
        )

        # Only one template is in use, so every prompt gets template id 0 and
        # `groups` ends up holding a single array with all prompt indices.
        template_ids = np.array([0 for _ in self.prompts])
        self.groups = [
            np.where(template_ids == template_id)[0]
            for template_id in set(template_ids.tolist())
        ]

        texts = [prompt["text"] for prompt in self.prompts]
        # Build the integer tensor directly instead of converting through float32.
        self.toks = torch.tensor(
            self.tokenizer(texts, padding=True).input_ids, dtype=torch.int
        )

        # The leading space makes the answer tokenise as a mid-sentence word.
        self.io_tokenIDs = [
            self.tokenizer.encode(" " + prompt["corr"])[0] for prompt in self.prompts
        ]
        self.s_tokenIDs = [
            self.tokenizer.encode(" " + prompt["incorr"])[0] for prompt in self.prompts
        ]

        # word_idx[name][k] is the token index of position `name` in prompt k.
        # Positions come from pos_dict, so they are identical across prompts.
        self.word_idx = {
            targ: torch.tensor([pos_dict[targ] for _ in self.prompts])
            for targ in self.prompts[0].keys()
            if targ not in self._RESERVED_KEYS
        }

        # "end" is the last token of each prompt, taken from this dataset's own
        # tokenizer (previously one loop reached for the global `model.tokenizer`).
        self.word_idx["end"] = torch.tensor(
            [len(self.tokenizer.tokenize(prompt["text"])) - 1 for prompt in self.prompts]
        )

    def __len__(self):
        """Return the number of prompts."""
        return self.N

In [None]:
# Token index of each prompt slot: slot "S<k>" sits at position k - 1.
pos_dict = {f"S{slot}": slot - 1 for slot in range(1, 5)}

In [None]:
def generate_prompts_list(x ,y):
    """Build the eight consecutive-month prompts (January..April up to August..November).

    Each prompt lists four consecutive months; 'corr' is the month that follows
    them and 'incorr' repeats the fourth month.

    Note: `x` and `y` are accepted but ignored; the start index always runs 0..7.
    """
    month_names = (
        'January', 'February', 'March', 'April', 'May', 'June',
        'July', 'August', 'September', 'October', 'November', 'December',
    )
    prompts = []
    for start in range(8):
        first, second, third, fourth = month_names[start:start + 4]
        prompts.append({
            'S1': first,
            'S2': second,
            'S3': third,
            'S4': fourth,
            'corr': month_names[start + 4],
            'incorr': fourth,
            'text': " ".join((first, second, third, fourth)),
        })
    return prompts

# Build the clean (in-order months) dataset.
# Note: generate_prompts_list ignores its two arguments.
prompts_list = generate_prompts_list(1, 11)
dataset = Dataset(prompts_list, pos_dict, model.tokenizer, S1_is_first=True)

In [None]:
import random

def generate_prompts_list_corr(x ,y, seed=None):
    """Build eight corrupted prompts made of randomly chosen months.

    S1 and S2 are drawn independently at random. S3 and S4 are redrawn until S4
    is not the month that directly follows S3 (wrapping December -> January).
    'corr' and 'incorr' are not random: for prompt i they are months[i+4] and
    months[i+3].

    Args:
        x, y: accepted but ignored; eight prompts are always produced.
        seed: optional seed. When given, a private ``random.Random(seed)`` is
            used so the output is reproducible and the global RNG is untouched.
            When None (the default), the module-level ``random`` functions are
            used exactly as before.

    Returns:
        List of eight prompt dicts with keys S1, S2, S3, S4, corr, incorr, text.
    """
    months = ['January', 'February', 'March', 'April', 'May', 'June', 'July', 'August', 'September', 'October', 'November', 'December']
    rng = random if seed is None else random.Random(seed)
    last_index = len(months) - 1

    prompts_list = []
    for i in range(0, 8):
        r1 = rng.choice(months)
        r2 = rng.choice(months)
        while True:
            r3_ind = rng.randint(0, last_index)
            r4_ind = rng.randint(0, last_index)
            # Reject draws where S4 directly follows S3 (index -1 wraps to December).
            if months[r3_ind] != months[r4_ind-1]:
                break
        r3 = months[r3_ind]
        r4 = months[r4_ind]
        prompts_list.append({
            'S1': r1,
            'S2': r2,
            'S3': r3,
            'S4': r4,
            'corr': months[i+4],
            'incorr': months[i+3],
            'text': f"{r1} {r2} {r3} {r4}"
        })
    return prompts_list

# Build the corrupted (random months) dataset and display the sampled prompts.
# Note: generate_prompts_list_corr ignores its two arguments.
prompts_list_2 = generate_prompts_list_corr(1, 11)
dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer, S1_is_first=True)
prompts_list_2

[{'S1': 'August',
  'S2': 'December',
  'S3': 'January',
  'S4': 'May',
  'corr': 'May',
  'incorr': 'April',
  'text': 'August December January May'},
 {'S1': 'December',
  'S2': 'April',
  'S3': 'July',
  'S4': 'May',
  'corr': 'June',
  'incorr': 'May',
  'text': 'December April July May'},
 {'S1': 'March',
  'S2': 'December',
  'S3': 'February',
  'S4': 'September',
  'corr': 'July',
  'incorr': 'June',
  'text': 'March December February September'},
 {'S1': 'May',
  'S2': 'November',
  'S3': 'July',
  'S4': 'October',
  'corr': 'August',
  'incorr': 'July',
  'text': 'May November July October'},
 {'S1': 'July',
  'S2': 'June',
  'S3': 'February',
  'S4': 'June',
  'corr': 'September',
  'incorr': 'August',
  'text': 'July June February June'},
 {'S1': 'November',
  'S2': 'May',
  'S3': 'September',
  'S4': 'June',
  'corr': 'October',
  'incorr': 'September',
  'text': 'November May September June'},
 {'S1': 'July',
  'S2': 'July',
  'S3': 'April',
  'S4': 'February',
  'corr': '

In [None]:
# Hard-coded corrupted prompts: this overwrites the randomly generated
# prompts_list_2 / dataset_2 with a fixed set of eight prompts.
prompts_list_2 = [{'S1': 'October',
  'S2': 'July',
  'S3': 'February',
  'S4': 'May',
  'corr': 'May',
  'incorr': 'April',
  'text': 'October July February May'},
 {'S1': 'August',
  'S2': 'March',
  'S3': 'March',
  'S4': 'March',
  'corr': 'June',
  'incorr': 'May',
  'text': 'August March March March'},
 {'S1': 'May',
  'S2': 'August',
  'S3': 'October',
  'S4': 'July',
  'corr': 'July',
  'incorr': 'June',
  'text': 'May August October July'},
 {'S1': 'October',
  'S2': 'April',
  'S3': 'February',
  'S4': 'February',
  'corr': 'August',
  'incorr': 'July',
  'text': 'October April February February'},
 {'S1': 'April',
  'S2': 'March',
  'S3': 'June',
  'S4': 'September',
  'corr': 'September',
  'incorr': 'August',
  'text': 'April March June September'},
 {'S1': 'April',
  'S2': 'March',
  'S3': 'March',
  'S4': 'August',
  'corr': 'October',
  'incorr': 'September',
  'text': 'April March March August'},
 {'S1': 'August',
  'S2': 'February',
  'S3': 'December',
  'S4': 'September',
  'corr': 'November',
  'incorr': 'October',
  'text': 'August February December September'},
 {'S1': 'November',
  'S2': 'August',
  'S3': 'March',
  'S4': 'December',
  'corr': 'December',
  'incorr': 'November',
  'text': 'November August March December'}]
dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer, S1_is_first=True)

# Ablation Experiment Functions

In [None]:
from torch import Tensor

def logits_to_ave_logit_diff_2(logits: Float[Tensor, "batch seq d_vocab"], dataset: Dataset, per_prompt=False):
    '''
    Logit of the correct answer minus logit of the incorrect answer, read at
    each prompt's final token position.

    The correct / incorrect token ids come from dataset.io_tokenIDs and
    dataset.s_tokenIDs, and the final position from dataset.word_idx["end"].

    With per_prompt=True the per-prompt differences are returned; otherwise
    their mean is returned.
    '''
    batch_rows = range(logits.size(0))
    final_positions = dataset.word_idx["end"]

    def gather_final_logits(token_ids):
        # One logit per prompt: (prompt k, its final position, its token id).
        return logits[batch_rows, final_positions, token_ids]

    answer_logit_diff = gather_final_logits(dataset.io_tokenIDs) - gather_final_logits(dataset.s_tokenIDs)
    if per_prompt:
        return answer_logit_diff
    return answer_logit_diff.mean()

In [None]:
# Baseline: run the un-ablated model on the clean dataset and record its mean
# logit difference; later ablation scores are reported as a percentage of this.
model.reset_hooks(including_permanent=True)
ioi_logits_original, ioi_cache = model.run_with_cache(dataset.toks)
orig_score = logits_to_ave_logit_diff_2(ioi_logits_original, dataset)

In [None]:
def mean_ablate_by_lst(lst, model, orig_score, print_output=True, clean_dataset=None, means_dataset=None):
    """Score the model with a mean-ablation hook built from the head list `lst`.

    The same head list is registered under four circuit group names, each kept
    at a different sequence position ("end", "S3", "S2", "S1"). The actual
    ablation is done by ioi_circuit_extraction.add_mean_ablation_hook.
    NOTE(review): presumably heads outside `lst` are replaced by their mean
    activation over `means_dataset` -- confirm against that helper.

    Args:
        lst: list of (layer, head) tuples to treat as the circuit.
        model: the hooked model; its hooks are reset before the ablation hook is added.
        orig_score: un-ablated mean logit difference used as the denominator.
        print_output: when True, print the resulting percentage.
        clean_dataset: dataset to evaluate on. Defaults to the notebook-level
            `dataset`, which is what this function always used before.
        means_dataset: dataset passed as `means_dataset` to the ablation hook.
            Defaults to the notebook-level `dataset_2`.

    Returns:
        100 * (ablated mean logit difference) / orig_score.
    """
    # Fall back to the notebook globals so existing three/four-argument calls behave as before.
    if clean_dataset is None:
        clean_dataset = dataset
    if means_dataset is None:
        means_dataset = dataset_2

    # Group name -> sequence position at which that group's heads are kept.
    SEQ_POS_TO_KEEP = {
        "number mover": "end",
        # "number mover 4": "S4",
        "number mover 3": "S3",
        "number mover 2": "S2",
        "number mover 1": "S1",
    }
    # Every group uses the same head list.
    CIRCUIT = {group: lst for group in SEQ_POS_TO_KEEP}

    # Must clear hooks left over from a previous mean-ablation run before adding new ones.
    model.reset_hooks(including_permanent=True)

    model = ioi_circuit_extraction.add_mean_ablation_hook(model, means_dataset=means_dataset, circuit=CIRCUIT, seq_pos_to_keep=SEQ_POS_TO_KEEP)
    ioi_logits_minimal = model(clean_dataset.toks)

    new_score = logits_to_ave_logit_diff_2(ioi_logits_minimal, clean_dataset)
    percent_of_full = 100 * new_score / orig_score
    if print_output:
        print(f"Average logit difference (circuit / full) %: {percent_of_full:.4f}")
    return percent_of_full

In [None]:
def find_circuit_forw(curr_circuit=None, orig_score=100, threshold=10):
    """Greedily prune heads from a circuit, visiting layers 0 -> 11.

    Each head still in the circuit is tentatively removed and the remaining
    circuit is scored with mean_ablate_by_lst (a percentage of the full model's
    score). If the drop from 100% is below `threshold`, the removal is kept.

    Args:
        curr_circuit: list of (layer, head) tuples to prune. None or an empty
            list means "start from all 12 x 12 heads". A non-empty list is
            modified in place as well as returned.
        orig_score: un-ablated score forwarded to mean_ablate_by_lst.
        threshold: T, in percentage points; a head is removed when
            (100 - score without it) < T.

    Returns:
        (pruned circuit, score of that pruned circuit). The score is the one
        measured for the circuit actually returned, not for the last rejected
        candidate.
    """
    if not curr_circuit:
        # Start with full circuit (also covers the default None)
        curr_circuit = [(layer, head) for layer in range(12) for head in range(12)]

    circuit_score = None  # score of the currently accepted circuit, once known
    for layer in range(0, 12):
        for head in range(12):
            if (layer, head) not in curr_circuit:
                continue

            # Score the circuit with this one head taken out.
            candidate = curr_circuit.copy()
            candidate.remove((layer, head))
            candidate_score = mean_ablate_by_lst(candidate, model, orig_score, print_output=False).item()

            # Keep the removal only if performance stays within `threshold` of 100%.
            if (100 - candidate_score) < threshold:
                curr_circuit.remove((layer, head))
                circuit_score = candidate_score

                print("\nRemoved:", (layer, head))
                print(candidate_score)

    if circuit_score is None:
        # Nothing was removed, so the unchanged circuit has not been scored yet.
        circuit_score = mean_ablate_by_lst(curr_circuit, model, orig_score, print_output=False).item()

    return curr_circuit, circuit_score

In [None]:
def find_circuit_backw(curr_circuit=None, orig_score=100, threshold=10):
    """Greedily prune heads from a circuit, visiting layers 11 -> 0.

    Within each layer the heads are visited in ascending order. Each head still
    in the circuit is tentatively removed and the remaining circuit is scored
    with mean_ablate_by_lst (a percentage of the full model's score). If the
    drop from 100% is below `threshold`, the removal is kept.

    Args:
        curr_circuit: list of (layer, head) tuples to prune. None or an empty
            list means "start from all 12 x 12 heads". A non-empty list is
            modified in place as well as returned.
        orig_score: un-ablated score forwarded to mean_ablate_by_lst.
        threshold: T, in percentage points; a head is removed when
            (100 - score without it) < T.

    Returns:
        (pruned circuit, score of that pruned circuit). The score is the one
        measured for the circuit actually returned, not for the last rejected
        candidate.
    """
    if not curr_circuit:
        # Start with full circuit (also covers the default None)
        curr_circuit = [(layer, head) for layer in range(12) for head in range(12)]

    circuit_score = None  # score of the currently accepted circuit, once known
    for layer in range(11, -1, -1):  # go thru all heads in a layer first
        for head in range(12):
            if (layer, head) not in curr_circuit:
                continue

            # Score the circuit with this one head taken out.
            candidate = curr_circuit.copy()
            candidate.remove((layer, head))
            candidate_score = mean_ablate_by_lst(candidate, model, orig_score, print_output=False).item()

            # Keep the removal only if performance stays within `threshold` of 100%.
            if (100 - candidate_score) < threshold:
                curr_circuit.remove((layer, head))
                circuit_score = candidate_score

                print("\nRemoved:", (layer, head))
                print(candidate_score)

    if circuit_score is None:
        # Nothing was removed, so the unchanged circuit has not been scored yet.
        circuit_score = mean_ablate_by_lst(curr_circuit, model, orig_score, print_output=False).item()

    return curr_circuit, circuit_score

# Iterative backward/forward pruning, threshold 3

In [None]:
# Alternate backward and forward pruning passes until a pass leaves the circuit unchanged.
threshold = 3
curr_circuit = []  # an empty list makes the first pass start from all heads
iteration = 1  # not named `iter`, which would shadow the built-in for the rest of the notebook
while True:
    print('\nbackw prune, iter ', str(iteration))
    old_circuit = curr_circuit.copy() # save old before finding new one
    curr_circuit, new_score = find_circuit_backw(curr_circuit=curr_circuit, orig_score=orig_score, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    print('\nfwd prune, iter ', str(iteration))
    # Convergence is detected by comparing circuits rather than scores.
    old_circuit = curr_circuit.copy() # save old before finding new one
    curr_circuit, new_score = find_circuit_forw(curr_circuit=curr_circuit, orig_score=orig_score, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    iteration += 1


backw prune, iter  1

Removed: (11, 0)
99.86515808105469

Removed: (11, 1)
99.70606231689453

Removed: (11, 2)
99.49992370605469

Removed: (11, 3)
99.38542938232422

Removed: (11, 4)
99.53096008300781

Removed: (11, 5)
99.47246551513672

Removed: (11, 6)
98.66363525390625

Removed: (11, 7)
98.82321166992188

Removed: (11, 9)
98.42066955566406

Removed: (11, 10)
97.1253662109375

Removed: (11, 11)
97.68685150146484

Removed: (10, 0)
97.6208724975586

Removed: (10, 1)
97.27182006835938

Removed: (10, 4)
97.25872039794922

Removed: (10, 5)
97.16763305664062

Removed: (10, 7)
98.66638946533203

Removed: (10, 8)
98.82259368896484

Removed: (10, 9)
98.74352264404297

Removed: (10, 10)
97.57072448730469

Removed: (10, 11)
97.70975494384766

Removed: (9, 0)
97.74153137207031

Removed: (9, 1)
105.21685791015625

Removed: (9, 2)
104.90479278564453

Removed: (9, 3)
103.66998291015625

Removed: (9, 4)
103.63906860351562

Removed: (9, 5)
110.19441223144531

Removed: (9, 6)
109.9915542602539

Remov

In [None]:
# Snapshot the circuit found with threshold 3 (curr_circuit is reused by later runs).
bf_3 = curr_circuit.copy()
bf_3

[(0, 1),
 (0, 5),
 (1, 0),
 (1, 5),
 (2, 2),
 (2, 7),
 (2, 9),
 (3, 7),
 (3, 10),
 (4, 4),
 (5, 0),
 (5, 1),
 (5, 4),
 (5, 8),
 (6, 0),
 (6, 1),
 (6, 5),
 (6, 9),
 (6, 10),
 (7, 10),
 (7, 11),
 (8, 6),
 (8, 8),
 (8, 11),
 (10, 2),
 (10, 3),
 (11, 8)]

In [None]:
# Number of heads left in the threshold-3 circuit.
len(bf_3)

27

## Remove each head in turn to find the most important heads

In [None]:
# Baseline for the per-head comparison: score of the whole pruned circuit,
# as a percentage of the full model's logit difference.
circ = bf_3
circ_score = mean_ablate_by_lst(circ, model, orig_score, print_output=True).item()

Average logit difference (circuit / full) %: 97.0222


In [None]:
# Leave-one-out: for each head in the circuit, score the circuit without it.
# lh_scores maps (layer, head) -> score (%) of the circuit with that head removed.
lh_scores = {}
for lh in circ:
    copy_circuit = circ.copy()
    copy_circuit.remove(lh)
    print("removed: " + str(lh))
    new_score = mean_ablate_by_lst(copy_circuit, model, orig_score, print_output=True).item()
    lh_scores[lh] = new_score

removed: (0, 1)
Average logit difference (circuit / full) %: 90.5230
removed: (0, 5)
Average logit difference (circuit / full) %: 95.8878
removed: (1, 0)
Average logit difference (circuit / full) %: 95.7988
removed: (1, 5)
Average logit difference (circuit / full) %: 92.9287
removed: (2, 2)
Average logit difference (circuit / full) %: 95.9326
removed: (2, 7)
Average logit difference (circuit / full) %: 96.7999
removed: (2, 9)
Average logit difference (circuit / full) %: 96.4289
removed: (3, 7)
Average logit difference (circuit / full) %: 95.0207
removed: (3, 10)
Average logit difference (circuit / full) %: 96.1482
removed: (4, 4)
Average logit difference (circuit / full) %: 41.1407
removed: (5, 0)
Average logit difference (circuit / full) %: 93.3619
removed: (5, 1)
Average logit difference (circuit / full) %: 96.0624
removed: (5, 4)
Average logit difference (circuit / full) %: 96.5859
removed: (5, 8)
Average logit difference (circuit / full) %: 90.8191
removed: (6, 0)
Average logit dif

In [None]:
# Order heads by leave-one-out score, lowest first (i.e. biggest loss when removed first).
sorted_lh_scores = dict(sorted(lh_scores.items(), key=lambda item: item[1]))
sorted_lh_scores

{(4, 4): 41.14067840576172,
 (6, 10): 81.43862915039062,
 (7, 11): 82.63829803466797,
 (8, 11): 84.38214111328125,
 (0, 1): 90.52302551269531,
 (8, 6): 90.76065826416016,
 (5, 8): 90.81907653808594,
 (7, 10): 91.15576171875,
 (11, 8): 92.72691345214844,
 (8, 8): 92.82693481445312,
 (1, 5): 92.92874145507812,
 (5, 0): 93.36192321777344,
 (6, 9): 94.63748168945312,
 (3, 7): 95.02069091796875,
 (6, 1): 95.37979888916016,
 (10, 2): 95.4301986694336,
 (1, 0): 95.79875183105469,
 (0, 5): 95.88780212402344,
 (2, 2): 95.9326400756836,
 (10, 3): 96.03795623779297,
 (5, 1): 96.06241607666016,
 (3, 10): 96.148193359375,
 (6, 5): 96.29025268554688,
 (2, 9): 96.42891693115234,
 (5, 4): 96.58585357666016,
 (6, 0): 96.64701080322266,
 (2, 7): 96.79991149902344}

In [None]:
# Print score-without-head minus circ_score, rounded to 2 dp
# (negative = removing the head lowered the score).
for lh, score in sorted_lh_scores.items():
    print(lh, -round(circ_score-score, 2))

(4, 4) -55.88
(6, 10) -15.58
(7, 11) -14.38
(8, 11) -12.64
(0, 1) -6.5
(8, 6) -6.26
(5, 8) -6.2
(7, 10) -5.87
(11, 8) -4.3
(8, 8) -4.2
(1, 5) -4.09
(5, 0) -3.66
(6, 9) -2.38
(3, 7) -2.0
(6, 1) -1.64
(10, 2) -1.59
(1, 0) -1.22
(0, 5) -1.13
(2, 2) -1.09
(10, 3) -0.98
(5, 1) -0.96
(3, 10) -0.87
(6, 5) -0.73
(2, 9) -0.59
(5, 4) -0.44
(6, 0) -0.38
(2, 7) -0.22


# Try other tasks' circuits

## Iterative backward/forward pruning, threshold 20

In [None]:
# Alternate backward and forward pruning passes until a pass leaves the circuit unchanged.
threshold = 20
curr_circuit = []  # an empty list makes the first pass start from all heads
iteration = 1  # not named `iter`, which would shadow the built-in for the rest of the notebook
while True:
    print('\nbackw prune, iter ', str(iteration))
    old_circuit = curr_circuit.copy() # save old before finding new one
    curr_circuit, new_score = find_circuit_backw(curr_circuit=curr_circuit, orig_score=orig_score, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    print('\nfwd prune, iter ', str(iteration))
    # Convergence is detected by comparing circuits rather than scores.
    old_circuit = curr_circuit.copy() # save old before finding new one
    curr_circuit, new_score = find_circuit_forw(curr_circuit=curr_circuit, orig_score=orig_score, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    iteration += 1


backw prune, iter  1

Removed: (11, 0)
99.86518859863281

Removed: (11, 1)
99.7060546875

Removed: (11, 2)
99.49995422363281

Removed: (11, 3)
99.38541412353516

Removed: (11, 4)
99.53091430664062

Removed: (11, 5)
99.47248840332031

Removed: (11, 6)
98.66361999511719

Removed: (11, 7)
98.82323455810547

Removed: (11, 8)
95.578125

Removed: (11, 9)
95.16633605957031

Removed: (11, 10)
93.87151336669922

Removed: (11, 11)
94.37239074707031

Removed: (10, 0)
94.30249786376953

Removed: (10, 1)
94.0016098022461

Removed: (10, 2)
92.67064666748047

Removed: (10, 3)
91.27398681640625

Removed: (10, 4)
91.2550048828125

Removed: (10, 5)
91.16800689697266

Removed: (10, 6)
90.92051696777344

Removed: (10, 7)
93.1395492553711

Removed: (10, 8)
93.26133728027344

Removed: (10, 9)
93.16360473632812

Removed: (10, 10)
92.06661987304688

Removed: (10, 11)
92.10260009765625

Removed: (9, 0)
92.12821960449219

Removed: (9, 1)
96.14894104003906

Removed: (9, 2)
95.83226013183594

Removed: (9, 3)
94.

In [None]:
# Snapshot the circuit found with threshold 20 (curr_circuit is reused by later runs).
bf_20 = curr_circuit.copy()
bf_20

[(0, 1),
 (0, 5),
 (1, 5),
 (2, 2),
 (2, 3),
 (2, 5),
 (2, 9),
 (3, 7),
 (3, 10),
 (4, 4),
 (5, 0),
 (5, 1),
 (5, 8),
 (6, 1),
 (6, 9),
 (6, 10),
 (7, 10),
 (7, 11),
 (8, 11)]

In [None]:
# Number of heads left in the threshold-20 circuit.
len(bf_20)

19

### Remove each head in turn to find the most important heads

In [None]:
# Baseline for the per-head comparison: score of the whole pruned circuit,
# as a percentage of the full model's logit difference.
circ = bf_20
circ_score = mean_ablate_by_lst(circ, model, orig_score, print_output=True).item()

Average logit difference (circuit / full) %: 80.0051


In [None]:
# Leave-one-out: for each head in the circuit, score the circuit without it.
# lh_scores maps (layer, head) -> score (%) of the circuit with that head removed.
lh_scores = {}
for lh in circ:
    copy_circuit = circ.copy()
    copy_circuit.remove(lh)
    print("removed: " + str(lh))
    new_score = mean_ablate_by_lst(copy_circuit, model, orig_score, print_output=True).item()
    lh_scores[lh] = new_score

removed: (0, 1)
Average logit difference (circuit / full) %: 74.3332
removed: (0, 5)
Average logit difference (circuit / full) %: 78.9162
removed: (1, 5)
Average logit difference (circuit / full) %: 76.7735
removed: (2, 2)
Average logit difference (circuit / full) %: 78.9474
removed: (2, 3)
Average logit difference (circuit / full) %: 79.7928
removed: (2, 5)
Average logit difference (circuit / full) %: 79.8287
removed: (2, 9)
Average logit difference (circuit / full) %: 79.5261
removed: (3, 7)
Average logit difference (circuit / full) %: 78.4577
removed: (3, 10)
Average logit difference (circuit / full) %: 79.4427
removed: (4, 4)
Average logit difference (circuit / full) %: 35.9684
removed: (5, 0)
Average logit difference (circuit / full) %: 76.7337
removed: (5, 1)
Average logit difference (circuit / full) %: 78.9997
removed: (5, 8)
Average logit difference (circuit / full) %: 76.6089
removed: (6, 1)
Average logit difference (circuit / full) %: 75.3679
removed: (6, 9)
Average logit dif

In [None]:
# Order heads by leave-one-out score, lowest first (i.e. biggest loss when removed first).
sorted_lh_scores = dict(sorted(lh_scores.items(), key=lambda item: item[1]))
sorted_lh_scores

{(4, 4): 35.96836853027344,
 (7, 11): 65.2315673828125,
 (8, 11): 65.68665313720703,
 (6, 10): 67.38734436035156,
 (7, 10): 71.73152160644531,
 (0, 1): 74.33316040039062,
 (6, 1): 75.36791229248047,
 (5, 8): 76.60887145996094,
 (5, 0): 76.73367309570312,
 (1, 5): 76.7734603881836,
 (6, 9): 77.81916809082031,
 (3, 7): 78.45770263671875,
 (0, 5): 78.91622924804688,
 (2, 2): 78.94742584228516,
 (5, 1): 78.9997329711914,
 (3, 10): 79.4427490234375,
 (2, 9): 79.52606201171875,
 (2, 3): 79.79278564453125,
 (2, 5): 79.82872772216797}

In [None]:
# Print score-without-head minus circ_score, rounded to 2 dp
# (negative = removing the head lowered the score).
for lh, score in sorted_lh_scores.items():
    print(lh, -round(circ_score-score, 2))

(4, 4) -44.04
(7, 11) -14.77
(8, 11) -14.32
(6, 10) -12.62
(7, 10) -8.27
(0, 1) -5.67
(6, 1) -4.64
(5, 8) -3.4
(5, 0) -3.27
(1, 5) -3.23
(6, 9) -2.19
(3, 7) -1.55
(0, 5) -1.09
(2, 2) -1.06
(5, 1) -1.01
(3, 10) -0.56
(2, 9) -0.48
(2, 3) -0.21
(2, 5) -0.18


# Try again using the first month of the prompt (index i) as the incorrect logit

In [None]:
def generate_prompts_list(x, y):
    """Build the eight "four consecutive months" prompts for the months task.

    Every prompt is a run of four consecutive months, e.g.
    "January February March April". ``corr`` is the month that continues the
    run and ``incorr`` is the first month of the run (S1).

    Args:
        x, y: Unused. They are kept so existing calls such as
            ``generate_prompts_list(1, 11)`` keep working. The start indices
            are fixed to 0..7 because each prompt needs ``months[i + 4]``.

    Returns:
        A list of 8 dicts with keys ``S1``, ``S2``, ``S3``, ``S4``, ``corr``,
        ``incorr`` and ``text`` (in that order).
    """
    months = ['January', 'February', 'March', 'April', 'May', 'June', 'July', 'August', 'September', 'October', 'November', 'December']
    seq_len = 4  # months shown in the prompt; the next one is the answer
    prompts_list = []
    for i in range(len(months) - seq_len):
        s1, s2, s3, s4 = months[i:i + seq_len]
        prompts_list.append({
            'S1': s1,
            'S2': s2,
            'S3': s3,
            'S4': s4,
            'corr': months[i + seq_len],
            'incorr': s1,
            # The text must end with S4. It previously repeated S3
            # ("January February March March"), which disagreed with the S4
            # field and with the expected continuation in 'corr'.
            'text': f"{s1} {s2} {s3} {s4}",
        })
    return prompts_list

prompts_list = generate_prompts_list(1, 11)
dataset = Dataset(prompts_list, pos_dict, model.tokenizer, S1_is_first=True)

In [None]:
model.reset_hooks(including_permanent=True)
ioi_logits_original, ioi_cache = model.run_with_cache(dataset.toks)
orig_score = logits_to_ave_logit_diff_2(ioi_logits_original, dataset)

## after ipp

In [None]:
# bf 80 after rmv 8,9 and 2,9

circuit = [(0, 1), (2, 2), (4, 4), (5, 0), (5, 1), (5, 4), (5, 6), (6, 6), (6, 9), (6, 10), (7, 7), (7, 11), (8, 8), (9, 1)]

mean_ablate_by_lst(circuit, model, orig_score, print_output=True).item()

Average logit difference (circuit / full) %: 76.4315


76.43152618408203

In [None]:
circuit = [(0, 1), (2, 2), (2, 9), (4, 4), (5, 0), (5, 1), (5, 4), (5, 6), (6, 6), (6, 9), (6, 10), (7, 7), (7, 11), (8, 8), (8, 9), (9, 1)]
mean_ablate_by_lst(circuit, model, orig_score, print_output=True).item()

Average logit difference (circuit / full) %: 80.0188


80.01876068115234

## Iterate backward then forward pruning, threshold 3

In [None]:
# Alternate backward and forward pruning passes until one pass leaves the
# circuit unchanged.
threshold = 3
# An empty list makes the first pruning pass start from all 12x12 heads.
curr_circuit = []
# prev_score is never updated (the update below is commented out), so the
# while-condition only fails if a pass returns exactly 100; in practice the
# loop ends through the `break` statements.
prev_score = 100
new_score = 0
# NOTE: this name shadows the builtin iter() for the rest of the notebook.
iter = 1
while prev_score != new_score:
    print('\nbackw prune, iter ', str(iter))
    # prev_score = new_score # save old score before finding new one
    # Snapshot the circuit first: the pruning helper removes heads from the
    # list it is given, so a copy is needed to detect whether anything changed.
    old_circuit = curr_circuit.copy() # save old before finding new one
    curr_circuit, new_score = find_circuit_backw(curr_circuit=curr_circuit, orig_score=orig_score, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    print('\nfwd prune, iter ', str(iter))
    # track changes in circuit as for some reason it doesn't work with scores
    old_circuit = curr_circuit.copy() # save old before finding new one
    curr_circuit, new_score = find_circuit_forw(curr_circuit=curr_circuit, orig_score=orig_score, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    iter += 1


backw prune, iter  1

Removed: (11, 0)
100.27477264404297

Removed: (11, 1)
100.23351287841797

Removed: (11, 2)
100.46318817138672

Removed: (11, 3)
100.17131042480469

Removed: (11, 4)
100.55448150634766

Removed: (11, 5)
100.61019134521484

Removed: (11, 6)
101.81040954589844

Removed: (11, 7)
101.80139923095703

Removed: (11, 8)
98.91886138916016

Removed: (11, 9)
98.89364624023438

Removed: (11, 11)
99.15338134765625

Removed: (10, 0)
99.13613891601562

Removed: (10, 1)
100.12897491455078

Removed: (10, 2)
100.70376586914062

Removed: (10, 3)
101.1800308227539

Removed: (10, 4)
101.31055450439453

Removed: (10, 5)
101.12251281738281

Removed: (10, 6)
100.38468933105469

Removed: (10, 8)
100.62175750732422

Removed: (10, 9)
100.5606689453125

Removed: (10, 10)
99.8133544921875

Removed: (10, 11)
99.99896240234375

Removed: (9, 0)
100.01261901855469

Removed: (9, 2)
99.6978530883789

Removed: (9, 3)
98.97310638427734

Removed: (9, 4)
99.01006317138672

Removed: (9, 5)
98.3266983032

In [None]:
bf_3 = curr_circuit.copy()
bf_3

[(0, 1),
 (2, 3),
 (2, 5),
 (2, 7),
 (2, 8),
 (2, 9),
 (4, 4),
 (5, 0),
 (5, 6),
 (6, 9),
 (6, 10),
 (7, 8),
 (7, 11),
 (8, 1),
 (8, 6),
 (8, 8),
 (8, 9),
 (9, 1),
 (9, 7),
 (9, 11),
 (10, 7),
 (11, 10)]

In [None]:
len(bf_3)

22

### Remove each head in turn to find the most important heads

In [None]:
circ = bf_3
circ_score = mean_ablate_by_lst(circ, model, orig_score, print_output=True).item()

Average logit difference (circuit / full) %: 97.0229


In [None]:
# Leave-one-out importance: drop each (layer, head) from the circuit in turn
# and record the circuit score (% of the full-model logit difference) that
# remains without it.
lh_scores = {}
for lh in circ:
    # Work on a copy so `circ` itself keeps every head.
    copy_circuit = circ.copy()
    copy_circuit.remove(lh)
    print("removed: " + str(lh))
    new_score = mean_ablate_by_lst(copy_circuit, model, orig_score, print_output=True).item()
    lh_scores[lh] = new_score

removed: (0, 1)
Average logit difference (circuit / full) %: 89.3814
removed: (2, 3)
Average logit difference (circuit / full) %: 96.9478
removed: (2, 5)
Average logit difference (circuit / full) %: 96.8179
removed: (2, 7)
Average logit difference (circuit / full) %: 96.6588
removed: (2, 8)
Average logit difference (circuit / full) %: 96.6361
removed: (2, 9)
Average logit difference (circuit / full) %: 96.5452
removed: (4, 4)
Average logit difference (circuit / full) %: 24.7081
removed: (5, 0)
Average logit difference (circuit / full) %: 93.0373
removed: (5, 6)
Average logit difference (circuit / full) %: 94.2867
removed: (6, 9)
Average logit difference (circuit / full) %: 94.3910
removed: (6, 10)
Average logit difference (circuit / full) %: 88.5307
removed: (7, 8)
Average logit difference (circuit / full) %: 96.6392
removed: (7, 11)
Average logit difference (circuit / full) %: 82.5906
removed: (8, 1)
Average logit difference (circuit / full) %: 93.5788
removed: (8, 6)
Average logit di

In [None]:
sorted_lh_scores = dict(sorted(lh_scores.items(), key=lambda item: item[1]))
sorted_lh_scores

{(4, 4): 24.708080291748047,
 (9, 1): 35.33229064941406,
 (7, 11): 82.59058380126953,
 (6, 10): 88.53067016601562,
 (0, 1): 89.3814468383789,
 (10, 7): 91.84457397460938,
 (8, 6): 92.03407287597656,
 (5, 0): 93.03734588623047,
 (8, 1): 93.57878875732422,
 (8, 9): 93.63392639160156,
 (9, 11): 94.2775650024414,
 (5, 6): 94.28670501708984,
 (6, 9): 94.3909683227539,
 (9, 7): 94.74427032470703,
 (8, 8): 94.98773193359375,
 (11, 10): 95.5536117553711,
 (2, 9): 96.54518127441406,
 (2, 8): 96.63613891601562,
 (7, 8): 96.63919067382812,
 (2, 7): 96.65879821777344,
 (2, 5): 96.81787109375,
 (2, 3): 96.94779205322266}

In [None]:
# For each head, print how much the circuit score changes when that head is
# removed (score without the head minus circ_score, rounded to 2 decimals).
# The most negative value marks the head whose removal hurts the most.
for lh, score in sorted_lh_scores.items():
    print(lh, -round(circ_score-score, 2))

(4, 4) -72.31
(9, 1) -61.69
(7, 11) -14.43
(6, 10) -8.49
(0, 1) -7.64
(10, 7) -5.18
(8, 6) -4.99
(5, 0) -3.99
(8, 1) -3.44
(8, 9) -3.39
(9, 11) -2.75
(5, 6) -2.74
(6, 9) -2.63
(9, 7) -2.28
(8, 8) -2.04
(11, 10) -1.47
(2, 9) -0.48
(2, 8) -0.39
(7, 8) -0.38
(2, 7) -0.36
(2, 5) -0.21
(2, 3) -0.08


## Iterate backward then forward pruning, threshold 20

In [None]:
# Alternate backward and forward pruning passes until one pass leaves the
# circuit unchanged.
threshold = 20
# An empty list makes the first pruning pass start from all 12x12 heads.
curr_circuit = []
# prev_score is never updated (the update below is commented out), so the
# while-condition only fails if a pass returns exactly 100; in practice the
# loop ends through the `break` statements.
prev_score = 100
new_score = 0
# NOTE: this name shadows the builtin iter() for the rest of the notebook.
iter = 1
while prev_score != new_score:
    print('\nbackw prune, iter ', str(iter))
    # prev_score = new_score # save old score before finding new one
    # Snapshot the circuit first: the pruning helper removes heads from the
    # list it is given, so a copy is needed to detect whether anything changed.
    old_circuit = curr_circuit.copy() # save old before finding new one
    curr_circuit, new_score = find_circuit_backw(curr_circuit=curr_circuit, orig_score=orig_score, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    print('\nfwd prune, iter ', str(iter))
    # track changes in circuit as for some reason it doesn't work with scores
    old_circuit = curr_circuit.copy() # save old before finding new one
    curr_circuit, new_score = find_circuit_forw(curr_circuit=curr_circuit, orig_score=orig_score, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    iter += 1


backw prune, iter  1

Removed: (11, 0)
99.81829833984375

Removed: (11, 1)
99.77723693847656

Removed: (11, 2)
100.00584411621094

Removed: (11, 3)
99.71531677246094

Removed: (11, 4)
100.09679412841797

Removed: (11, 5)
100.15222930908203

Removed: (11, 6)
101.34697723388672

Removed: (11, 7)
101.33797454833984

Removed: (11, 8)
98.46856689453125

Removed: (11, 9)
98.4434814453125

Removed: (11, 10)
96.54823303222656

Removed: (11, 11)
96.78971862792969

Removed: (10, 0)
96.77925109863281

Removed: (10, 1)
97.79193115234375

Removed: (10, 2)
98.43711853027344

Removed: (10, 3)
98.9050064086914

Removed: (10, 4)
99.03523254394531

Removed: (10, 5)
98.84487915039062

Removed: (10, 6)
98.0883560180664

Removed: (10, 7)
93.26271057128906

Removed: (10, 8)
93.51444244384766

Removed: (10, 9)
93.46068572998047

Removed: (10, 10)
92.99003601074219

Removed: (10, 11)
93.23729705810547

Removed: (9, 0)
93.25186920166016

Removed: (9, 2)
92.9366683959961

Removed: (9, 3)
92.23115539550781

Rem

In [None]:
bf_20 = curr_circuit.copy()
bf_20

[(0, 1),
 (2, 2),
 (2, 9),
 (4, 4),
 (5, 0),
 (5, 1),
 (5, 4),
 (5, 6),
 (6, 6),
 (6, 9),
 (6, 10),
 (7, 7),
 (7, 11),
 (8, 8),
 (8, 9),
 (9, 1)]

In [None]:
len(bf_20)

16

### Remove each head in turn to find the most important heads

In [None]:
circ = bf_20
circ_score = mean_ablate_by_lst(circ, model, orig_score, print_output=True).item()

Average logit difference (circuit / full) %: 80.0188


In [None]:
# Leave-one-out importance: drop each (layer, head) from the circuit in turn
# and record the circuit score (% of the full-model logit difference) that
# remains without it.
lh_scores = {}
for lh in circ:
    # Work on a copy so `circ` itself keeps every head.
    copy_circuit = circ.copy()
    copy_circuit.remove(lh)
    print("removed: " + str(lh))
    new_score = mean_ablate_by_lst(copy_circuit, model, orig_score, print_output=True).item()
    lh_scores[lh] = new_score

removed: (0, 1)
Average logit difference (circuit / full) %: 73.6033
removed: (2, 2)
Average logit difference (circuit / full) %: 78.1361
removed: (2, 9)
Average logit difference (circuit / full) %: 79.6444
removed: (4, 4)
Average logit difference (circuit / full) %: 20.8427
removed: (5, 0)
Average logit difference (circuit / full) %: 75.6994
removed: (5, 1)
Average logit difference (circuit / full) %: 79.5243
removed: (5, 4)
Average logit difference (circuit / full) %: 79.9790
removed: (5, 6)
Average logit difference (circuit / full) %: 77.3481
removed: (6, 6)
Average logit difference (circuit / full) %: 79.4478
removed: (6, 9)
Average logit difference (circuit / full) %: 77.0189
removed: (6, 10)
Average logit difference (circuit / full) %: 71.9496
removed: (7, 7)
Average logit difference (circuit / full) %: 78.5723
removed: (7, 11)
Average logit difference (circuit / full) %: 65.4579
removed: (8, 8)
Average logit difference (circuit / full) %: 76.9234
removed: (8, 9)
Average logit di

In [None]:
sorted_lh_scores = dict(sorted(lh_scores.items(), key=lambda item: item[1]))
sorted_lh_scores

{(4, 4): 20.842721939086914,
 (9, 1): 21.033855438232422,
 (7, 11): 65.4579086303711,
 (6, 10): 71.94964599609375,
 (0, 1): 73.60332489013672,
 (5, 0): 75.69940185546875,
 (8, 9): 76.79069519042969,
 (8, 8): 76.9233627319336,
 (6, 9): 77.0189208984375,
 (5, 6): 77.34809875488281,
 (2, 2): 78.13613891601562,
 (7, 7): 78.5722885131836,
 (6, 6): 79.44776916503906,
 (5, 1): 79.52426147460938,
 (2, 9): 79.64443969726562,
 (5, 4): 79.97897338867188}

In [None]:
# For each head, print how much the circuit score changes when that head is
# removed (score without the head minus circ_score, rounded to 2 decimals).
# The most negative value marks the head whose removal hurts the most.
for lh, score in sorted_lh_scores.items():
    print(lh, -round(circ_score-score, 2))

(4, 4) -59.18
(9, 1) -58.98
(7, 11) -14.56
(6, 10) -8.07
(0, 1) -6.42
(5, 0) -4.32
(8, 9) -3.23
(8, 8) -3.1
(6, 9) -3.0
(5, 6) -2.67
(2, 2) -1.88
(7, 7) -1.45
(6, 6) -0.57
(5, 1) -0.49
(2, 9) -0.37
(5, 4) -0.04


## Try the circuits found for other tasks

### gt, IOI

In [None]:
# greater-than
circuit = [(0, 1), (0, 3), (0, 5), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11), (9, 1)]
mean_ablate_by_lst(circuit, model, orig_score, print_output=True).item()

Average logit difference (circuit / full) %: 27.5909


27.590892791748047

In [None]:
# IOI
circuit = [(0, 1), (0, 10), (2, 2), (3, 0), (4, 11), (5, 5), (5, 8), (5, 9), (6, 9), (7, 3), (7, 9), (8, 6), (8, 10), (9, 0), (9, 6), (9, 7), (9, 9), (10, 0), (10, 1), (10, 2), (10, 6), (10, 7), (10, 10), (11, 2), (11, 9), (11, 10)]
mean_ablate_by_lst(circuit, model, orig_score, print_output=True).item()

Average logit difference (circuit / full) %: 4.3136


4.313619613647461

### bf 80

In [None]:
# digits
circuit = [(0, 1), (0, 2), (0, 5), (0, 7), (0, 8), (0, 10), (1, 0), (1, 1), (1, 5), (1, 7), (1, 11), (2, 0), (2, 1), (2, 2), (2, 3), (2, 6), (2, 8), (2, 9), (2, 10), (2, 11), (3, 3), (3, 4), (3, 5), (3, 7), (3, 8), (3, 9), (3, 11), (4, 4), (4, 10), (5, 1), (5, 4), (5, 6), (5, 8), (5, 11), (6, 4), (6, 6), (6, 7), (6, 8), (6, 9), (6, 10), (6, 11), (7, 11), (9, 1)]
mean_ablate_by_lst(circuit, model, orig_score, print_output=True).item()

Average logit difference (circuit / full) %: 69.7072


69.70720672607422

In [None]:
# numwords
circuit = [(0, 1), (0, 9), (0, 10), (1, 5), (4, 4), (4, 7), (5, 6), (5, 8), (6, 1), (6, 6), (6, 10), (7, 5), (7, 6), (7, 10), (7, 11), (8, 7), (8, 8), (8, 10), (8, 11), (9, 1)]
mean_ablate_by_lst(circuit, model, orig_score, print_output=True).item()

Average logit difference (circuit / full) %: 70.6112


70.6111831665039

In [None]:
# months
circuit = [(0, 1), (2, 2), (2, 9), (4, 4), (5, 0), (5, 1), (5, 4), (5, 6), (6, 6), (6, 9), (6, 10), (7, 7), (7, 11), (8, 8), (8, 9), (9, 1)]
mean_ablate_by_lst(circuit, model, orig_score, print_output=True).item()

Average logit difference (circuit / full) %: 80.0188


80.01876831054688

### bf 97

In [None]:
# digits incr
# incorr i+3
circuit = [(0, 1), (0, 2), (0, 5), (0, 7), (0, 8), (0, 10), (1, 0), (1, 1), (1, 3), (1, 5), (1, 7), (1, 11), (2, 0), (2, 1), (2, 2), (2, 3), (2, 5), (2, 6), (2, 8), (2, 9), (2, 10), (3, 3), (3, 7), (3, 8), (3, 10), (3, 11), (4, 2), (4, 4), (4, 6), (4, 10), (4, 11), (5, 1), (5, 4), (5, 8), (5, 10), (5, 11), (6, 2), (6, 3), (6, 4), (6, 6), (6, 7), (6, 8), (6, 9), (6, 10), (6, 11), (7, 11), (8, 6), (8, 8), (9, 1), (10, 7), (11, 10)]
mean_ablate_by_lst(circuit, model, orig_score, print_output=True).item()

Average logit difference (circuit / full) %: 82.0537


82.05374145507812

In [None]:
# numwords
# incorr i+3
circuit = [(0, 1), (0, 6), (0, 7), (0, 9), (0, 10), (1, 0), (1, 5), (3, 3), (4, 4), (4, 10), (5, 4), (5, 6), (5, 8), (6, 6), (6, 10), (7, 6), (7, 10), (7, 11), (8, 8), (9, 1), (10, 7)]
mean_ablate_by_lst(circuit, model, orig_score, print_output=True).item()

Average logit difference (circuit / full) %: 70.6149


70.61485290527344

In [None]:
# months
# incorr i
circuit = [(0, 1), (2, 3), (2, 5), (2, 7), (2, 8), (2, 9), (4, 4), (5, 0), (5, 6), (6, 9), (6, 10), (7, 8), (7, 11), (8, 1), (8, 6), (8, 8), (8, 9), (9, 1), (9, 7), (9, 11), (10, 7), (11, 10)]
mean_ablate_by_lst(circuit, model, orig_score, print_output=True).item()

Average logit difference (circuit / full) %: 96.5812


96.5811996459961

# Try again using month i+3 (S4) as the incorrect token

## Generate dataset with multiple prompts

In [None]:
class Dataset:
    """Tokenised prompts plus the per-prompt indices read by the ablation helpers.

    Attributes set by ``__init__``:
        prompts: the list of prompt dicts that was passed in.
        tokenizer: the tokenizer that was passed in.
        N: number of prompts.
        max_len: token length of the longest prompt text.
        groups: list of index arrays, one per prompt template. There is only
            one template here, so it holds a single array covering all prompts.
        toks: int tensor built from ``tokenizer(texts, padding=True).input_ids``.
        io_tokenIDs: first token id of ``" " + prompt["corr"]`` per prompt.
        s_tokenIDs: first token id of ``" " + prompt["incorr"]`` per prompt.
        word_idx: dict mapping every position name (each prompt key other than
            ``text``/``corr``/``incorr``) and ``"end"`` to a tensor holding
            that position's token index for each prompt.
    """

    def __init__(self, prompts, pos_dict, tokenizer, S1_is_first=False):
        """Tokenise ``prompts`` and precompute the lookup tensors.

        Args:
            prompts: non-empty list of dicts, each with at least ``text``,
                ``corr`` and ``incorr`` keys; remaining keys name positions.
            pos_dict: maps each position name to its token index in the prompt.
            tokenizer: callable tokenizer that also provides ``encode`` and
                ``tokenize``.
            S1_is_first: unused; accepted for backward compatibility.

        Raises:
            ValueError / IndexError: if ``prompts`` is empty.
            KeyError: if a position name is missing from ``pos_dict``.
        """
        self.prompts = prompts
        self.tokenizer = tokenizer
        self.N = len(prompts)
        self.max_len = max(
            len(self.tokenizer(prompt["text"]).input_ids) for prompt in self.prompts
        )

        # Only one prompt template exists, so every prompt gets template id 0.
        template_ids = np.zeros(self.N, dtype=int)
        self.groups = [
            np.where(template_ids == template_id)[0]
            for template_id in np.unique(template_ids)
        ]

        texts = [prompt["text"] for prompt in self.prompts]
        # Build the int tensor directly instead of going through a float32
        # torch.Tensor, which cannot represent every large integer id exactly.
        self.toks = torch.tensor(
            self.tokenizer(texts, padding=True).input_ids, dtype=torch.int
        )
        self.io_tokenIDs = [
            self.tokenizer.encode(" " + prompt["corr"])[0] for prompt in self.prompts
        ]
        self.s_tokenIDs = [
            self.tokenizer.encode(" " + prompt["incorr"])[0] for prompt in self.prompts
        ]

        # word_idx: positions are taken from pos_dict (the same index for every
        # prompt) rather than searched for in the tokenised text, so no
        # tokenizer call, and in particular no global `model`, is needed here.
        non_position_keys = ("text", "corr", "incorr")
        self.word_idx = {}
        for targ in self.prompts[0].keys():
            if targ in non_position_keys:
                continue
            self.word_idx[targ] = torch.tensor([pos_dict[targ]] * self.N)

        # "end" is the index of the last token of each prompt.
        self.word_idx["end"] = torch.tensor(
            [len(self.tokenizer.tokenize(prompt["text"])) - 1 for prompt in self.prompts]
        )

    def __len__(self):
        """Return the number of prompts."""
        return self.N

In [None]:
# Token index of each sequence member (S1..S4) inside every prompt.
# Dataset copies these indices into its word_idx tensors.
pos_dict = {
    'S1': 0,
    'S2': 1,
    'S3': 2,
    'S4': 3,
}

In [None]:
def generate_prompts_list(x, y):
    """Build the eight "four consecutive months" prompts for the months task.

    Every prompt is a run of four consecutive months, e.g.
    "January February March April". ``corr`` is the month that continues the
    run and ``incorr`` is the last month shown in the prompt (S4).

    Args:
        x, y: Unused. They are kept so existing calls such as
            ``generate_prompts_list(1, 11)`` keep working. The start indices
            are fixed to 0..7 because each prompt needs ``months[i + 4]``.

    Returns:
        A list of 8 dicts with keys ``S1``, ``S2``, ``S3``, ``S4``, ``corr``,
        ``incorr`` and ``text`` (in that order).
    """
    months = ['January', 'February', 'March', 'April', 'May', 'June', 'July', 'August', 'September', 'October', 'November', 'December']
    seq_len = 4  # months shown in the prompt; the next one is the answer
    prompts_list = []
    for i in range(len(months) - seq_len):
        s1, s2, s3, s4 = months[i:i + seq_len]
        prompts_list.append({
            'S1': s1,
            'S2': s2,
            'S3': s3,
            'S4': s4,
            'corr': months[i + seq_len],
            'incorr': s4,
            # The text must end with S4. It previously repeated S3
            # ("January February March March"), which disagreed with the S4
            # field and with the expected continuation in 'corr'.
            'text': f"{s1} {s2} {s3} {s4}",
        })
    return prompts_list

prompts_list = generate_prompts_list(1, 11)
dataset = Dataset(prompts_list, pos_dict, model.tokenizer, S1_is_first=True)

In [None]:
import random

def generate_prompts_list_corr(x ,y):
    """Build eight prompts made of randomly chosen months.

    The four months in each prompt are drawn at random, with the third and
    fourth redrawn together until the fourth is not the month that follows the
    third. ``corr`` and ``incorr`` do not depend on the random text: they are
    ``months[i + 4]`` and ``months[i + 3]`` for prompt number ``i``.

    Args:
        x, y: Unused; the loop always runs over indices 0..7.

    Returns:
        A list of 8 dicts with keys ``S1``..``S4``, ``corr``, ``incorr``, ``text``.
    """
    months = ['January', 'February', 'March', 'April', 'May', 'June', 'July', 'August', 'September', 'October', 'November', 'December']
    last_index = len(months) - 1
    random_prompts = []
    # for i in range(x, y):
    for i in range(0, 8):
        first = random.choice(months)
        second = random.choice(months)

        # Draw the (third, fourth) pair, rejecting pairs where the fourth month
        # directly follows the third. Index -1 wraps to December, so
        # "December January" is rejected as well.
        third_idx = random.randint(0, last_index)
        fourth_idx = random.randint(0, last_index)
        while months[third_idx] == months[fourth_idx - 1]:
            third_idx = random.randint(0, last_index)
            fourth_idx = random.randint(0, last_index)
        third, fourth = months[third_idx], months[fourth_idx]

        random_prompts.append({
            'S1': first,
            'S2': second,
            'S3': third,
            'S4': fourth,
            'corr': months[i+4],
            'incorr': months[i+3],
            'text': f"{first} {second} {third} {fourth}"
        })
    return random_prompts

prompts_list_2 = generate_prompts_list_corr(1, 11)
dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer, S1_is_first=True)
prompts_list_2

[{'S1': 'June',
  'S2': 'December',
  'S3': 'April',
  'S4': 'November',
  'corr': 'May',
  'incorr': 'April',
  'text': 'June December April November'},
 {'S1': 'July',
  'S2': 'June',
  'S3': 'November',
  'S4': 'February',
  'corr': 'June',
  'incorr': 'May',
  'text': 'July June November February'},
 {'S1': 'February',
  'S2': 'October',
  'S3': 'September',
  'S4': 'July',
  'corr': 'July',
  'incorr': 'June',
  'text': 'February October September July'},
 {'S1': 'January',
  'S2': 'November',
  'S3': 'August',
  'S4': 'March',
  'corr': 'August',
  'incorr': 'July',
  'text': 'January November August March'},
 {'S1': 'January',
  'S2': 'December',
  'S3': 'March',
  'S4': 'August',
  'corr': 'September',
  'incorr': 'August',
  'text': 'January December March August'},
 {'S1': 'July',
  'S2': 'July',
  'S3': 'September',
  'S4': 'February',
  'corr': 'October',
  'incorr': 'September',
  'text': 'July July September February'},
 {'S1': 'July',
  'S2': 'October',
  'S3': 'January'

## Ablation Expm Functions

In [None]:
from torch import Tensor

def logits_to_ave_logit_diff_2(logits: Float[Tensor, "batch seq d_vocab"], dataset: Dataset, per_prompt=False):
    '''
    Logit of the correct token minus logit of the incorrect token, read at the
    last token position of every prompt.

    Args:
        logits: model output, indexed as [prompt, position, vocab id].
        dataset: supplies word_idx["end"] (position to read), io_tokenIDs
            (correct token ids) and s_tokenIDs (incorrect token ids).
        per_prompt: when True, return one difference per prompt; otherwise
            return their mean.
    '''
    prompt_rows = range(logits.size(0))
    end_positions = dataset.word_idx["end"]

    # One logit per prompt: (prompt k, its end position, its target token id)
    correct_logits: Float[Tensor, "batch"] = logits[prompt_rows, end_positions, dataset.io_tokenIDs]
    incorrect_logits: Float[Tensor, "batch"] = logits[prompt_rows, end_positions, dataset.s_tokenIDs]

    logit_diffs = correct_logits - incorrect_logits
    if per_prompt:
        return logit_diffs
    return logit_diffs.mean()

In [None]:
ioi_logits_original, ioi_cache = model.run_with_cache(dataset.toks)
orig_score = logits_to_ave_logit_diff_2(ioi_logits_original, dataset)

In [None]:
def mean_ablate_by_lst(lst, model, orig_score, print_output=True):
    """Score the model with mean-ablation hooks installed for the circuit ``lst``.

    Args:
        lst: list of (layer, head) tuples forming the circuit.
        model: hooked model; its hooks are reset before new ones are added.
        orig_score: full-model average logit difference used as the denominator.
        print_output: when True, print the resulting percentage.

    Returns:
        ``100 * circuit_score / orig_score``, the circuit's average logit
        difference as a percentage of the full model's.

    Reads the notebook globals ``dataset`` (prompts that are scored) and
    ``dataset_2`` (passed as ``means_dataset``). The hooks added here are not
    removed before returning; they stay until the next ``reset_hooks`` call.
    """
    # Sequence position kept for each named head group. The "S4" group
    # ("number mover 4") is intentionally left out, as in the original cell.
    SEQ_POS_TO_KEEP = {
        "number mover": "end",
        "number mover 3": "S3",
        "number mover 2": "S2",
        "number mover 1": "S1",
    }
    # Every group uses the same list of heads.
    CIRCUIT = {group_name: lst for group_name in SEQ_POS_TO_KEEP}

    # Clear hooks left over from an earlier mean-ablation run before adding new ones.
    model.reset_hooks(including_permanent=True)
    # NOTE(review): add_mean_ablation_hook is defined in ioi_circuit_extraction and
    # not visible here; presumably it mean-ablates heads outside CIRCUIT — confirm.
    model = ioi_circuit_extraction.add_mean_ablation_hook(
        model,
        means_dataset=dataset_2,
        circuit=CIRCUIT,
        seq_pos_to_keep=SEQ_POS_TO_KEEP,
    )

    ablated_logits = model(dataset.toks)
    ablated_score = logits_to_ave_logit_diff_2(ablated_logits, dataset)

    percent_of_full = 100 * ablated_score / orig_score
    if print_output:
        print(f"Average logit difference (circuit / full) %: {percent_of_full:.4f}")
    return percent_of_full

In [None]:
def find_circuit_forw(curr_circuit=None, orig_score=100, threshold=10):
    """Greedily prune heads from a circuit, scanning layers 0 -> 11.

    Each head still in the circuit is tentatively removed and the reduced
    circuit is scored with ``mean_ablate_by_lst``. The removal is kept when the
    score stays within ``threshold`` points of 100, i.e.
    ``100 - score < threshold``.

    Args:
        curr_circuit: list of (layer, head) tuples to prune. ``None`` or an
            empty list means "start from all 12x12 heads". A non-empty list is
            pruned in place and is also the list that is returned.
        orig_score: full-model score, passed through to ``mean_ablate_by_lst``.
        threshold: maximum allowed drop (in percentage points below 100).

    Returns:
        ``(circuit, score)``: the pruned circuit and the score of that circuit.
        Previously the score of the last head *tried* was returned, even when
        that removal had been rejected.
    """
    if not curr_circuit:
        # Start with full circuit (covers both None and [])
        curr_circuit = [(layer, head) for layer in range(12) for head in range(12)]

    # Score of the circuit as it currently stands. It is only replaced when a
    # removal is accepted, so the returned score always matches the returned circuit.
    circuit_score = mean_ablate_by_lst(curr_circuit, model, orig_score, print_output=False).item()

    for layer in range(0, 12):
        for head in range(12):
            candidate = (layer, head)
            if candidate not in curr_circuit:
                continue

            # Score a copy of the circuit with this head taken out
            trial_circuit = curr_circuit.copy()
            trial_circuit.remove(candidate)
            trial_score = mean_ablate_by_lst(trial_circuit, model, orig_score, print_output=False).item()

            # Keep the removal only if performance stays close enough to 100%
            if (100 - trial_score) < threshold:
                curr_circuit.remove(candidate)
                circuit_score = trial_score

                print("\nRemoved:", candidate)
                print(trial_score)

    return curr_circuit, circuit_score

In [None]:
def find_circuit_backw(curr_circuit=None, orig_score=100, threshold=10):
    """Greedily prune heads from a circuit, scanning layers 11 -> 0.

    Each head still in the circuit is tentatively removed and the reduced
    circuit is scored with ``mean_ablate_by_lst``. The removal is kept when the
    score stays within ``threshold`` points of 100, i.e.
    ``100 - score < threshold``. Heads within a layer are visited 0 -> 11.

    Args:
        curr_circuit: list of (layer, head) tuples to prune. ``None`` or an
            empty list means "start from all 12x12 heads". A non-empty list is
            pruned in place and is also the list that is returned.
        orig_score: full-model score, passed through to ``mean_ablate_by_lst``.
        threshold: maximum allowed drop (in percentage points below 100).

    Returns:
        ``(circuit, score)``: the pruned circuit and the score of that circuit.
        Previously the score of the last head *tried* was returned, even when
        that removal had been rejected.
    """
    if not curr_circuit:
        # Start with full circuit (covers both None and [])
        curr_circuit = [(layer, head) for layer in range(12) for head in range(12)]

    # Score of the circuit as it currently stands. It is only replaced when a
    # removal is accepted, so the returned score always matches the returned circuit.
    circuit_score = mean_ablate_by_lst(curr_circuit, model, orig_score, print_output=False).item()

    for layer in range(11, -1, -1):  # go thru all heads in a layer first
        for head in range(12):
            candidate = (layer, head)
            if candidate not in curr_circuit:
                continue

            # Score a copy of the circuit with this head taken out
            trial_circuit = curr_circuit.copy()
            trial_circuit.remove(candidate)
            trial_score = mean_ablate_by_lst(trial_circuit, model, orig_score, print_output=False).item()

            # Keep the removal only if performance stays close enough to 100%
            if (100 - trial_score) < threshold:
                curr_circuit.remove(candidate)
                circuit_score = trial_score

                print("\nRemoved:", candidate)
                print(trial_score)

    return curr_circuit, circuit_score

## Iterate backward then forward pruning, threshold 3

In [None]:
# Alternate backward and forward pruning passes until one pass leaves the
# circuit unchanged.
threshold = 3
# An empty list makes the first pruning pass start from all 12x12 heads.
curr_circuit = []
# prev_score is never updated (the update below is commented out), so the
# while-condition only fails if a pass returns exactly 100; in practice the
# loop ends through the `break` statements.
prev_score = 100
new_score = 0
# NOTE: this name shadows the builtin iter() for the rest of the notebook.
iter = 1
while prev_score != new_score:
    print('\nbackw prune, iter ', str(iter))
    # prev_score = new_score # save old score before finding new one
    # Snapshot the circuit first: the pruning helper removes heads from the
    # list it is given, so a copy is needed to detect whether anything changed.
    old_circuit = curr_circuit.copy() # save old before finding new one
    curr_circuit, new_score = find_circuit_backw(curr_circuit=curr_circuit, orig_score=orig_score, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    print('\nfwd prune, iter ', str(iter))
    # track changes in circuit as for some reason it doesn't work with scores
    old_circuit = curr_circuit.copy() # save old before finding new one
    curr_circuit, new_score = find_circuit_forw(curr_circuit=curr_circuit, orig_score=orig_score, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    iter += 1


backw prune, iter  1

Removed: (11, 0)
187.2841339111328

Removed: (11, 1)
186.7925262451172

Removed: (11, 2)
186.4005126953125

Removed: (11, 3)
186.14523315429688

Removed: (11, 4)
186.22230529785156

Removed: (11, 5)
186.20193481445312

Removed: (11, 6)
183.9085235595703

Removed: (11, 7)
184.41839599609375

Removed: (11, 8)
181.02496337890625

Removed: (11, 9)
180.13568115234375

Removed: (11, 10)
178.2689666748047

Removed: (11, 11)
179.22093200683594

Removed: (10, 0)
179.03067016601562

Removed: (10, 1)
179.2295684814453

Removed: (10, 2)
175.9254150390625

Removed: (10, 3)
173.44964599609375

Removed: (10, 4)
173.47613525390625

Removed: (10, 5)
173.3822021484375

Removed: (10, 6)
173.35984802246094

Removed: (10, 7)
177.2030487060547

Removed: (10, 8)
177.2852783203125

Removed: (10, 9)
177.1464385986328

Removed: (10, 10)
176.1490020751953

Removed: (10, 11)
176.11376953125

Removed: (9, 0)
176.1731719970703

Removed: (9, 1)
179.74916076660156

Removed: (9, 2)
179.415710449

KeyboardInterrupt: ignored

In [None]:
bf_3 = curr_circuit.copy()
bf_3

In [None]:
len(bf_3)

## Remove each head in turn to find the most important heads

In [None]:
circ = bf_3
circ_score = mean_ablate_by_lst(circ, model, orig_score, print_output=True).item()

In [None]:
# Leave-one-out importance: drop each (layer, head) from the circuit in turn
# and record the circuit score (% of the full-model logit difference) that
# remains without it.
lh_scores = {}
for lh in circ:
    # Work on a copy so `circ` itself keeps every head.
    copy_circuit = circ.copy()
    copy_circuit.remove(lh)
    print("removed: " + str(lh))
    new_score = mean_ablate_by_lst(copy_circuit, model, orig_score, print_output=True).item()
    lh_scores[lh] = new_score

In [None]:
sorted_lh_scores = dict(sorted(lh_scores.items(), key=lambda item: item[1]))
sorted_lh_scores

In [None]:
# For each head, print how much the circuit score changes when that head is
# removed (score without the head minus circ_score, rounded to 2 decimals).
# The most negative value marks the head whose removal hurts the most.
for lh, score in sorted_lh_scores.items():
    print(lh, -round(circ_score-score, 2))