<a href="https://colab.research.google.com/github/wlg100/numseqcont_circuit_expms/blob/main/notebook_templates/minimal_circuit_template.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" align="left"/></a>&nbsp;or in a local notebook.

# Setup
(No need to change anything)

In [1]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    # Install another version of node that makes PySvelte work way faster
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    %pip install git+https://github.com/neelnanda-io/PySvelte.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

Running as a Colab notebook
Collecting git+https://github.com/neelnanda-io/TransformerLens.git
  Cloning https://github.com/neelnanda-io/TransformerLens.git to /tmp/pip-req-build-h5p6yzcg
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/TransformerLens.git /tmp/pip-req-build-h5p6yzcg
  Resolved https://github.com/neelnanda-io/TransformerLens.git to commit fa287750606075574df2c538058e67d648e2f952
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting accelerate>=0.23.0 (from transformer-lens==0.0.0)
  Downloading accelerate-0.24.0-py3-none-any.whl (260 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m261.0/261.0 kB[0m [31m6.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting beartype<0.15.0,>=0.14.1 (from transformer-lens==0.0.0)
  Downloading beartype-0.14.1-py3-none-any.whl (739 kB)
[2K     [90m━━━━━━━━

In [2]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

# Plotly figures show up either in Colab or in VSCode depending on the renderer,
# so DEBUG_MODE (local development) switches to static PNG output; otherwise use "colab".
pio.renderers.default = "colab" if (IN_COLAB or not DEBUG_MODE) else "png"

In [3]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [4]:
# import pysvelte

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

We turn automatic differentiation off, to save GPU memory, as this notebook focuses on model inference not model training.

In [5]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f60fbf3b640>

Plotting helper functions:

In [6]:
def imshow(tensor, renderer=None, **kwargs):
    """Display `tensor` as a heatmap on a diverging RdBu scale centred at 0."""
    fig = px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        **kwargs,
    )
    fig.show(renderer)

def line(tensor, renderer=None, **kwargs):
    """Plot the values of `tensor` as a line chart (index on the x-axis)."""
    values = utils.to_numpy(tensor)
    fig = px.line(y=values, **kwargs)
    fig.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Scatter-plot `y` against `x`, labelling the x, y and colour axes."""
    xs = utils.to_numpy(x)
    ys = utils.to_numpy(y)
    axis_labels = {"x": xaxis, "y": yaxis, "color": caxis}
    fig = px.scatter(y=ys, x=xs, labels=axis_labels, **kwargs)
    fig.show(renderer)

## Load Model

Decide which model to use (e.g. gpt2-small vs. gpt2-medium).

In [7]:
# Load GPT-2 small with TransformerLens' weight-processing options enabled.
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    fold_ln=True,
    center_writing_weights=True,
    center_unembed=True,
    refactor_factored_attn_matrices=True,
)

Downloading (…)lve/main/config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Loaded pretrained model gpt2-small into HookedTransformer


## Import functions from repo

In [8]:
!git clone https://github.com/callummcdougall/ARENA_2.0.git

Cloning into 'ARENA_2.0'...
remote: Enumerating objects: 9100, done.[K
remote: Counting objects: 100% (1812/1812), done.[K
remote: Compressing objects: 100% (287/287), done.[K
remote: Total 9100 (delta 1606), reused 1601 (delta 1522), pack-reused 7288[K
Receiving objects: 100% (9100/9100), 155.60 MiB | 22.14 MiB/s, done.
Resolving deltas: 100% (5501/5501), done.


In [9]:
cd ARENA_2.0/chapter1_transformers/exercises/part3_indirect_object_identification

/content/ARENA_2.0/chapter1_transformers/exercises/part3_indirect_object_identification


In [10]:
import ioi_circuit_extraction as ioi_circuit_extraction

# Generate dataset with multiple prompts

In [11]:
class Dataset:
    """A batch of number-sequence prompts plus the token metadata used for ablations.

    Args:
        prompts: list of dicts with 'text', 'corr' and 'incorr' keys. Every other
            key of ``prompts[0]`` (e.g. 'S1'..'S4') is treated as a position name
            and looked up in ``pos_dict``.
        pos_dict: maps each position name to a fixed token index, shared by all prompts.
        tokenizer: tokenizer providing ``__call__(..., padding=...)``, ``encode`` and
            ``tokenize``; all tokenization goes through this object.
        S1_is_first: unused; kept so existing callers keep working.

    Attributes:
        toks: int32 tensor of padded token ids, one row per prompt.
        io_tokenIDs: first token id of " " + prompt["corr"] for each prompt.
        s_tokenIDs: first token id of " " + prompt["incorr"] for each prompt.
        word_idx: position name -> tensor of that position for every prompt, plus
            "end" -> index of each prompt's last token.
        groups: list of index arrays, one per template (a single template here).
    """

    def __init__(self, prompts, pos_dict, tokenizer, S1_is_first=False):
        self.prompts = prompts
        self.tokenizer = tokenizer
        self.N = len(prompts)
        self.max_len = max(len(tokenizer(p["text"]).input_ids) for p in prompts)

        # All prompts come from one template, so there is a single group holding
        # every prompt index.
        template_ids = np.zeros(self.N, dtype=int)
        self.groups = [np.where(template_ids == t)[0] for t in np.unique(template_ids)]

        texts = [p["text"] for p in prompts]
        # Build the integer tensor directly instead of going through a float tensor.
        self.toks = torch.tensor(tokenizer(texts, padding=True).input_ids, dtype=torch.int)

        self.io_tokenIDs = [tokenizer.encode(" " + p["corr"])[0] for p in prompts]
        self.s_tokenIDs = [tokenizer.encode(" " + p["incorr"])[0] for p in prompts]

        # Target positions are fixed per name (pos_dict) rather than searched for in
        # each prompt, so every prompt gets the same index. Uses self.tokenizer only;
        # no dependence on a global model.
        meta_keys = ("text", "corr", "incorr")
        self.word_idx = {
            key: torch.tensor([pos_dict[key]] * self.N)
            for key in prompts[0]
            if key not in meta_keys
        }
        # "end": index of the last (unpadded) token of each prompt.
        self.word_idx["end"] = torch.tensor(
            [len(tokenizer.tokenize(p["text"])) - 1 for p in prompts]
        )

    def __len__(self):
        return self.N

Replace io_tokens with the correct answer (the next number in the sequence, e.g. '5' for "1 2 3 4") and s_tokens with the incorrect answer (the sequence's first number, i.e. a repetition).

In [12]:
# Token index of each sequence element S1..S4 in every prompt
# (assumes each number is exactly one token and no BOS is prepended — TODO confirm).
pos_dict = {f"S{pos + 1}": pos for pos in range(4)}

In [13]:
def generate_prompts_list(x, y):
    """Build one prompt of four consecutive integers for each start value in [x, y).

    Each dict holds the four numbers as 'S1'..'S4', the next number in the
    sequence as 'corr', the first number as 'incorr', and the space-joined 'text'.
    Key order matters: Dataset reads position names from the first dict's keys.
    """
    return [
        {
            'S1': str(start),
            'S2': str(start + 1),
            'S3': str(start + 2),
            'S4': str(start + 3),
            'corr': str(start + 4),
            'incorr': str(start),
            'text': " ".join(str(start + offset) for offset in range(4)),
        }
        for start in range(x, y)
    ]

prompts_list = generate_prompts_list(1, 11)
dataset = Dataset(prompts_list, pos_dict, model.tokenizer, S1_is_first=True)

In [14]:
def generate_prompts_list_corr(x, y):
    """Build the corrupted prompts: each start value in [x, y) repeated four times.

    'S1'..'S4' and 'corr' are all the repeated number, 'incorr' is that number + 4,
    and 'text' is the number written four times separated by spaces.
    """
    corrupted = []
    for start in range(x, y):
        number = str(start)
        corrupted.append({
            'S1': number,
            'S2': number,
            'S3': number,
            'S4': number,
            'corr': number,
            'incorr': str(start + 4),
            'text': " ".join([number] * 4),
        })
    return corrupted

prompts_list_2 = generate_prompts_list_corr(1, 11)
dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer, S1_is_first=True)

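The logit diff is the correct token's logit minus the incorrect token's logit. Here, the correct token is the next number in the sequence (S5) and the incorrect token is the first number (S1).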

Because of this, a circuit can score a logit diff HIGHER than the full model: the correct token can still rank first while the gap to the incorrect token grows (perhaps because the partial circuit scores the incorrect token even lower).

# Ablation Expm Functions

In [15]:
from torch import Tensor

def logits_to_ave_logit_diff_2(logits: Float[Tensor, "batch seq d_vocab"], dataset: Dataset, per_prompt=False):
    '''
    Logit of the correct answer minus logit of the incorrect answer, read at each
    prompt's final token position.

    Args:
        logits: model output indexed as [prompt, position, vocab].
        dataset: supplies word_idx["end"] (final position of each prompt) and
            io_tokenIDs / s_tokenIDs (correct / incorrect answer token ids).
        per_prompt: if True, return the per-prompt differences instead of their mean.
    '''
    rows = range(logits.size(0))
    end_pos = dataset.word_idx["end"]
    correct = logits[rows, end_pos, dataset.io_tokenIDs]
    incorrect = logits[rows, end_pos, dataset.s_tokenIDs]
    diff = correct - incorrect
    if per_prompt:
        return diff
    return diff.mean()

In [16]:
def mean_ablate_by_lst(lst, model, print_output=True, orig_score=None):
    """Score a circuit as % of the full model's average logit difference.

    Adds ioi_circuit_extraction's mean-ablation hook (means taken over the global
    `dataset_2`), keeping the heads in `lst` at the positions given in
    SEQ_POS_TO_KEEP, then runs the global `dataset` through the hooked model.

    Args:
        lst: list of (layer, head) tuples to keep.
        model: the HookedTransformer. Its hooks are reset before and after the run,
            so the model is left unhooked when this returns.
        print_output: print the percentage.
        orig_score: optional precomputed full-model logit diff on `dataset`; pass it
            to skip re-running the unablated model on every call.

    Returns:
        100 * ablated_logit_diff / full_logit_diff.
    """
    CIRCUIT = {
        "number mover": lst,
        # "number mover 4": lst,
        "number mover 3": lst,
        "number mover 2": lst,
        "number mover 1": lst,
    }

    SEQ_POS_TO_KEEP = {
        "number mover": "end",
        # "number mover 4": "S4",
        "number mover 3": "S3",
        "number mover 2": "S2",
        "number mover 1": "S1",
    }

    # Start from a clean model: hooks from an earlier ablation must not leak in.
    model.reset_hooks(including_permanent=True)

    if orig_score is None:
        # Plain forward pass: run_with_cache would store every activation for nothing.
        orig_logits = model(dataset.toks)
        orig_score = logits_to_ave_logit_diff_2(orig_logits, dataset)

    try:
        model = ioi_circuit_extraction.add_mean_ablation_hook(model, means_dataset=dataset_2, circuit=CIRCUIT, seq_pos_to_keep=SEQ_POS_TO_KEEP)
        ablated_logits = model(dataset.toks)
    finally:
        # Remove the permanent ablation hook so later uses of `model` are unablated.
        model.reset_hooks(including_permanent=True)

    new_score = logits_to_ave_logit_diff_2(ablated_logits, dataset)
    percent = 100 * new_score / orig_score
    if print_output:
        print(f"Average logit difference (circuit / full) %: {percent:.4f}")
    return percent

We can also prevent redundant computation of the full circuit score by storing it and just passing it in to the function.

# Ablate the model and compare with original

### try full circuit from repeatLast iter fb

In [None]:
curr_circuit = [(0, 1), (0, 3), (0, 5), (0, 7), (0, 8), (0, 9), (0, 10), (1, 0), (1, 5), (3, 0), (3, 3), (3, 7), (3, 10), (3, 11), (4, 4), (4, 6), (4, 7), (4, 8), (4, 10), (4, 11), (5, 4), (5, 5), (5, 9), (6, 1), (6, 6), (6, 10), (7, 6), (7, 10), (7, 11), (8, 1), (8, 2), (8, 6), (8, 8), (9, 1), (9, 5), (10, 7), (11, 10)]
mean_ablate_by_lst(curr_circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: 60.6005


60.60049819946289

In [None]:
curr_circuit = [(9, 1)]
mean_ablate_by_lst(curr_circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: -13.0262


-13.026200294494629

## compare with repeatRandElem

In [17]:
repeatRand_backw_3 = [(0, 0), (0, 1), (0, 2), (0, 3), (0, 8), (0, 9), (0, 10), (1, 0), (1, 5), (1, 7), (2, 0), (2, 2), (2, 9), (3, 0), (4, 2), (4, 4), (4, 7), (4, 8), (4, 9), (4, 10), (5, 0), (5, 3), (5, 4), (5, 5), (5, 6), (6, 1), (6, 3), (6, 4), (6, 6), (6, 8), (6, 10), (7, 10), (7, 11), (8, 11), (9, 1), (10, 1)]
mean_ablate_by_lst(repeatRand_backw_3, model, print_output=True).item()

Average logit difference (circuit / full) %: 66.8810


66.88101196289062

In [19]:
repeatRand_backw_20 = [(0, 1), (0, 2), (0, 3), (0, 9), (0, 10), (0, 11), (1, 0), (1, 5), (1, 7), (1, 8), (2, 0), (2, 2), (2, 7), (2, 9), (3, 0), (3, 3), (4, 4), (4, 6), (4, 7), (4, 9), (4, 10), (4, 11), (5, 0), (5, 3), (5, 4), (5, 5), (5, 6), (5, 8), (6, 1), (6, 6), (6, 9), (6, 10), (7, 10), (7, 11), (9, 1)]
mean_ablate_by_lst(repeatRand_backw_20, model, print_output=True).item()

Average logit difference (circuit / full) %: 51.0900


51.08995056152344

In [18]:
repeatRand_fb_3 = [(0, 1), (0, 2), (0, 8), (0, 9), (0, 10), (1, 0), (1, 5), (2, 0), (2, 2), (2, 6), (2, 7), (3, 0), (3, 11), (4, 4), (4, 6), (4, 7), (4, 8), (4, 9), (4, 10), (5, 0), (5, 3), (5, 4), (5, 5), (5, 6), (6, 1), (6, 6), (6, 10), (7, 10), (7, 11), (8, 8), (8, 9), (8, 11), (9, 1), (11, 10)]
mean_ablate_by_lst(repeatRand_fb_3, model, print_output=True).item()

Average logit difference (circuit / full) %: 60.1871


60.18708038330078

## Work backwards

https://www.notion.so/wlg1/Search-Methods-brainstorm-15a3020ab00b40adb79b0acf3622f5f4?pvs=4#dd6b43247d4945eda1d70ca4d4bae01d

In [None]:
# Start with full circuit
curr_circuit = [(layer, head) for layer in range(12) for head in range(12)]
threshold = 3  # This is T, a %. if performance is less than T%, allow its removal

for layer in range(11, -1, -1):  # go thru all heads in a layer first
    for head in range(12):
        # Copying the curr_circuit so we can iterate over one and modify the other
        copy_circuit = curr_circuit.copy()

        # Temporarily removing the current tuple from the copied circuit
        copy_circuit.remove((layer, head))

        new_score = mean_ablate_by_lst(copy_circuit, model, print_output=False).item()

        # print((layer,head), new_score)
        # If the result is less than the threshold, remove the tuple from the original list
        if (100 - new_score) < threshold:
            curr_circuit.remove((layer, head))

            print("Removed:", (layer, head))
            print(new_score)
            print("\n")

Removed: (11, 0)
98.45150756835938


Removed: (11, 1)
98.24433135986328


Removed: (11, 2)
98.33431243896484


Removed: (11, 3)
97.79175567626953


Removed: (11, 4)
97.27037048339844


Removed: (11, 5)
97.40259552001953


Removed: (11, 6)
97.28413391113281


Removed: (11, 7)
97.25057220458984


Removed: (11, 11)
97.88668823242188


Removed: (10, 0)
97.7774658203125


Removed: (10, 3)
97.6139907836914


Removed: (10, 4)
97.71390533447266


Removed: (10, 5)
97.6951675415039


Removed: (10, 6)
97.82437133789062


Removed: (10, 7)
98.38075256347656


Removed: (10, 8)
99.3049545288086


Removed: (10, 9)
99.44312286376953


Removed: (10, 10)
100.19226837158203


Removed: (10, 11)
99.88433074951172


Removed: (9, 0)
99.68191528320312


Removed: (9, 2)
99.8347396850586


Removed: (9, 3)
101.64493560791016


Removed: (9, 4)
100.82312774658203


Removed: (9, 5)
99.86206817626953


Removed: (9, 6)
99.5550765991211


Removed: (9, 7)
99.68671417236328


Removed: (9, 8)
99.69947052001953


Removed: 

In [None]:
mean_ablate_by_lst(curr_circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: 99.4927


99.49271392822266

In [None]:
curr_circuit

[(0, 1),
 (0, 9),
 (1, 0),
 (1, 5),
 (2, 2),
 (2, 9),
 (2, 10),
 (3, 0),
 (3, 3),
 (3, 6),
 (3, 7),
 (4, 4),
 (4, 8),
 (5, 1),
 (5, 2),
 (5, 3),
 (5, 4),
 (5, 5),
 (5, 6),
 (5, 8),
 (6, 1),
 (6, 3),
 (6, 4),
 (6, 6),
 (6, 9),
 (6, 10),
 (7, 1),
 (7, 2),
 (7, 6),
 (7, 7),
 (7, 10),
 (7, 11),
 (8, 0),
 (8, 1),
 (8, 6),
 (8, 8),
 (8, 11),
 (9, 1),
 (9, 9),
 (10, 1),
 (10, 2),
 (11, 8),
 (11, 9),
 (11, 10)]

In [None]:
backw_3 = [(0, 1), (0, 9), (1, 0), (1, 5), (2, 2), (2, 9), (2, 10), (3, 0), (3, 3), (3, 6), (3, 7), (4, 4), (4, 8), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), (5, 8), (6, 1), (6, 3), (6, 4), (6, 6), (6, 9), (6, 10), (7, 1), (7, 2), (7, 6), (7, 7), (7, 10), (7, 11), (8, 0), (8, 1), (8, 6), (8, 8), (8, 11), (9, 1), (9, 9), (10, 1), (10, 2), (11, 8), (11, 9), (11, 10)]
len(backw_3)

44

Now try 10% threshold:

In [None]:
def find_circuit_backw(threshold=10):
    """Greedy backward pruning starting from all 12 x 12 heads.

    Layers are visited from 11 down to 0 (heads 0..11 within each layer). A head
    is removed permanently when the circuit without it still scores within
    `threshold` percentage points of the full model (100%).

    Returns the list of surviving (layer, head) tuples.

    NOTE: this name is redefined further down the notebook with a different
    signature (curr_circuit, threshold) that returns a (circuit, score) tuple.
    """
    kept = [(layer, head) for layer in range(12) for head in range(12)]
    for layer in reversed(range(12)):
        for head in range(12):
            candidate = (layer, head)
            trial = [h for h in kept if h != candidate]
            score = mean_ablate_by_lst(trial, model, print_output=False).item()
            if (100 - score) < threshold:
                kept.remove(candidate)
                print("Removed:", candidate)
                print(score)
                print("\n")
    return kept

In [None]:
curr_circuit = find_circuit_backw(10)

Removed: (11, 0)
98.45150756835938


Removed: (11, 1)
98.24433135986328


Removed: (11, 2)
98.33431243896484


Removed: (11, 3)
97.79175567626953


Removed: (11, 4)
97.27037048339844


Removed: (11, 5)
97.40259552001953


Removed: (11, 6)
97.28413391113281


Removed: (11, 7)
97.25057220458984


Removed: (11, 8)
95.6801986694336


Removed: (11, 9)
95.31547546386719


Removed: (11, 10)
94.2880630493164


Removed: (11, 11)
94.87696075439453


Removed: (10, 0)
94.77782440185547


Removed: (10, 1)
92.73703002929688


Removed: (10, 3)
92.61138153076172


Removed: (10, 4)
92.71031188964844


Removed: (10, 5)
92.65601348876953


Removed: (10, 6)
92.7713623046875


Removed: (10, 7)
93.2691879272461


Removed: (10, 8)
94.14749908447266


Removed: (10, 9)
94.25298309326172


Removed: (10, 10)
94.97111511230469


Removed: (10, 11)
94.63520812988281


Removed: (9, 0)
94.43431091308594


Removed: (9, 2)
94.5575942993164


Removed: (9, 3)
96.24874114990234


Removed: (9, 4)
95.4171371459961


Removed

In [None]:
curr_circuit

[(0, 1),
 (0, 9),
 (1, 0),
 (1, 5),
 (1, 6),
 (2, 2),
 (2, 8),
 (2, 9),
 (3, 0),
 (3, 2),
 (3, 3),
 (3, 7),
 (3, 8),
 (3, 10),
 (4, 4),
 (5, 1),
 (5, 4),
 (5, 5),
 (5, 6),
 (5, 7),
 (5, 8),
 (5, 10),
 (6, 0),
 (6, 1),
 (6, 3),
 (6, 4),
 (6, 6),
 (6, 9),
 (6, 10),
 (6, 11),
 (7, 0),
 (7, 6),
 (7, 8),
 (7, 10),
 (7, 11),
 (8, 0),
 (8, 1),
 (8, 6),
 (8, 8),
 (8, 11),
 (9, 1),
 (10, 2)]

In [None]:
mean_ablate_by_lst(curr_circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: 92.6940


92.6939926147461

In [None]:
backw_10 = [(0, 1), (0, 9), (1, 0), (1, 5), (1, 6), (2, 2), (2, 8), (2, 9), (3, 0), (3, 2), (3, 3), (3, 7), (3, 8), (3, 10), (4, 4), (5, 1), (5, 4), (5, 5), (5, 6), (5, 7), (5, 8), (5, 10), (6, 0), (6, 1), (6, 3), (6, 4), (6, 6), (6, 9), (6, 10), (6, 11), (7, 0), (7, 6), (7, 8), (7, 10), (7, 11), (8, 0), (8, 1), (8, 6), (8, 8), (8, 11), (9, 1), (10, 2)]
len(backw_10)

42

20%:

In [None]:
%%capture
curr_circuit = find_circuit_backw(20)

In [None]:
curr_circuit

[(0, 1),
 (0, 9),
 (1, 0),
 (1, 5),
 (1, 6),
 (2, 2),
 (2, 8),
 (2, 9),
 (2, 10),
 (3, 0),
 (3, 2),
 (3, 3),
 (3, 7),
 (3, 10),
 (4, 4),
 (4, 10),
 (5, 1),
 (5, 2),
 (5, 3),
 (5, 4),
 (5, 5),
 (5, 6),
 (5, 8),
 (5, 10),
 (6, 1),
 (6, 3),
 (6, 4),
 (6, 6),
 (6, 9),
 (6, 10),
 (7, 2),
 (7, 6),
 (7, 7),
 (7, 10),
 (7, 11),
 (8, 0),
 (8, 6),
 (8, 8),
 (8, 11),
 (9, 1)]

In [None]:
mean_ablate_by_lst(curr_circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: 82.3372


82.33719635009766

In [None]:
backw_20 = [(0, 1), (0, 9), (1, 0), (1, 5), (1, 6), (2, 2), (2, 8), (2, 9), (2, 10), (3, 0), (3, 2), (3, 3), (3, 7), (3, 10), (4, 4), (4, 10), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), (5, 8), (5, 10), (6, 1), (6, 3), (6, 4), (6, 6), (6, 9), (6, 10), (7, 2), (7, 6), (7, 7), (7, 10), (7, 11), (8, 0), (8, 6), (8, 8), (8, 11), (9, 1)]
len(backw_20)

40

### set diffs of the three perf lvls

In [None]:
set(backw_3) - set(backw_10)

{(2, 10),
 (3, 6),
 (4, 8),
 (5, 2),
 (5, 3),
 (7, 1),
 (7, 2),
 (7, 7),
 (9, 9),
 (10, 1),
 (11, 8),
 (11, 9),
 (11, 10)}

In [None]:
set(backw_10) - set(backw_3)

{(1, 6),
 (2, 8),
 (3, 2),
 (3, 8),
 (3, 10),
 (5, 7),
 (5, 10),
 (6, 0),
 (6, 11),
 (7, 0),
 (7, 8)}

In [None]:
set(backw_3) - set(backw_20)

{(3, 6),
 (4, 8),
 (7, 1),
 (8, 1),
 (9, 9),
 (10, 1),
 (10, 2),
 (11, 8),
 (11, 9),
 (11, 10)}

In [None]:
set(backw_10) - set(backw_20)

{(3, 8), (5, 7), (6, 0), (6, 11), (7, 0), (7, 8), (8, 1), (10, 2)}

In [None]:
mean_ablate_by_lst(backw_20, model, print_output=True)

Average logit difference (circuit / full) %: 82.3372


tensor(82.3372, device='cuda:0')

In [None]:
mean_ablate_by_lst(backw_20 + [(10, 2)], model, print_output=True)

Average logit difference (circuit / full) %: 91.6173


tensor(91.6173, device='cuda:0')

In [None]:
mean_ablate_by_lst([x for x in backw_20 if x != (9, 1)], model, print_output=True)

Average logit difference (circuit / full) %: 51.2823


tensor(51.2823, device='cuda:0')

In [None]:
mean_ablate_by_lst([x for x in backw_20 if x != (9, 1)] + [(10, 2)], model, print_output=True)

Average logit difference (circuit / full) %: 62.0274


tensor(62.0274, device='cuda:0')

### set diff w repeatLast circs

In [None]:
repLast_backw_3 = [(0, 1), (0, 3), (0, 5), (0, 7), (0, 9), (0, 10), (1, 0), (1, 4), (1, 5), (2, 2), (2, 8), (2, 9), (3, 0), (3, 2), (3, 3), (3, 7), (4, 4), (4, 7), (4, 10), (5, 1), (5, 3), (5, 4), (5, 5), (5, 6), (5, 8), (5, 9), (6, 1), (6, 4), (6, 6), (6, 10), (6, 11), (7, 6), (7, 10), (7, 11), (8, 0), (8, 5), (8, 6), (8, 8), (9, 1), (10, 7)]
set(backw_3) - set(repLast_backw_3)

{(2, 10),
 (3, 6),
 (4, 8),
 (5, 2),
 (6, 3),
 (6, 9),
 (7, 1),
 (7, 2),
 (7, 7),
 (8, 1),
 (8, 11),
 (9, 9),
 (10, 1),
 (10, 2),
 (11, 8),
 (11, 9),
 (11, 10)}

In [None]:
set(repLast_backw_3) - set(backw_3)

{(0, 3),
 (0, 5),
 (0, 7),
 (0, 10),
 (1, 4),
 (2, 8),
 (3, 2),
 (4, 7),
 (4, 10),
 (5, 9),
 (6, 11),
 (8, 5),
 (10, 7)}

## Prune forwards

In [None]:
# Start with full circuit
curr_circuit = [(layer, head) for layer in range(12) for head in range(12)]
threshold = 3  # This is T, a %. if performance is less than T%, allow its removal

for layer in range(0, 12):
    for head in range(12):
        # Copying the curr_circuit so we can iterate over one and modify the other
        copy_circuit = curr_circuit.copy()

        # Temporarily removing the current tuple from the copied circuit
        copy_circuit.remove((layer, head))

        new_score = mean_ablate_by_lst(copy_circuit, model, print_output=False).item()

        # print((layer,head), new_score)
        # If the result is less than the threshold, remove the tuple from the original list
        if (100 - new_score) < threshold:
            curr_circuit.remove((layer, head))

            print("Removed:", (layer, head))
            print(new_score)
            print("\n")

Removed: (0, 0)
100.00466918945312


Removed: (0, 2)
98.07209777832031


Removed: (0, 4)
97.96208953857422


Removed: (0, 6)
97.41792297363281


Removed: (0, 11)
98.04102325439453


Removed: (1, 1)
97.67564392089844


Removed: (1, 2)
97.67819213867188


Removed: (1, 3)
97.88668823242188


Removed: (1, 4)
97.89542388916016


Removed: (1, 6)
97.8697509765625


Removed: (1, 7)
98.2431640625


Removed: (1, 8)
98.43437194824219


Removed: (1, 9)
98.68045806884766


Removed: (1, 10)
98.94314575195312


Removed: (1, 11)
99.24425506591797


Removed: (2, 0)
99.28617858886719


Removed: (2, 1)
100.14505767822266


Removed: (2, 2)
99.1255111694336


Removed: (2, 3)
99.42776489257812


Removed: (2, 4)
99.11087036132812


Removed: (2, 5)
99.4810562133789


Removed: (2, 6)
99.1651611328125


Removed: (2, 7)
98.68614959716797


Removed: (2, 8)
98.37564086914062


Removed: (2, 9)
97.45429992675781


Removed: (2, 10)
97.83071899414062


Removed: (2, 11)
98.20713806152344


Removed: (3, 1)
98.5207824707

In [None]:
curr_circuit

[(0, 1),
 (0, 3),
 (0, 5),
 (0, 7),
 (0, 8),
 (0, 9),
 (0, 10),
 (1, 0),
 (1, 5),
 (3, 0),
 (3, 3),
 (3, 7),
 (3, 10),
 (3, 11),
 (4, 4),
 (4, 6),
 (4, 7),
 (4, 8),
 (4, 10),
 (4, 11),
 (5, 4),
 (5, 5),
 (5, 9),
 (6, 1),
 (6, 6),
 (6, 10),
 (7, 6),
 (7, 10),
 (7, 11),
 (8, 1),
 (8, 2),
 (8, 6),
 (8, 8),
 (9, 1),
 (9, 5),
 (10, 7),
 (11, 10)]

## prune fwds-backwds iteratively

In [None]:
def find_circuit_forw(curr_circuit=None, threshold=10, score_fn=None, n_layers=12, n_heads=12):
    """Greedy forward pruning: visit layers first-to-last, heads ascending.

    For each (layer, head) still in the circuit, score the circuit without it.
    If the score drops by less than `threshold` percentage points relative to
    the full model (100%), the head is removed for good.

    Args:
        curr_circuit: list of (layer, head) tuples to prune. None or [] means
            start from all n_layers * n_heads heads. The caller's list is not modified.
        threshold: maximum allowed drop (100 - score) for a removal.
        score_fn: callable(circuit) -> float score in %. Defaults to
            mean_ablate_by_lst(circuit, model, print_output=False).item().
        n_layers, n_heads: model dimensions to sweep.

    Returns:
        (circuit, score): the pruned circuit and the score of that circuit. The
        score is not the score of the last rejected trial.
    """
    if score_fn is None:
        def score_fn(circuit):
            return mean_ablate_by_lst(circuit, model, print_output=False).item()

    # The default None used to crash; treat None and [] alike. Work on a copy.
    if curr_circuit:
        circuit = list(curr_circuit)
    else:
        circuit = [(layer, head) for layer in range(n_layers) for head in range(n_heads)]

    circuit_score = None  # score of `circuit` as it currently stands, once known
    for layer in range(n_layers):
        for head in range(n_heads):
            if (layer, head) not in circuit:
                continue

            trial = circuit.copy()
            trial.remove((layer, head))
            new_score = score_fn(trial)

            if (100 - new_score) < threshold:
                circuit = trial
                circuit_score = new_score
                print("\nRemoved:", (layer, head))
                print(new_score)

    if circuit_score is None:
        # Nothing was removed, so the circuit as a whole was never scored.
        circuit_score = score_fn(circuit)
    return circuit, circuit_score

In [None]:
def find_circuit_backw(curr_circuit=None, threshold=10, score_fn=None, n_layers=12, n_heads=12):
    """Greedy backward pruning: visit layers last-to-first, heads ascending within a layer.

    For each (layer, head) still in the circuit, score the circuit without it.
    If the score drops by less than `threshold` percentage points relative to
    the full model (100%), the head is removed for good.

    Args:
        curr_circuit: list of (layer, head) tuples to prune. None or [] means
            start from all n_layers * n_heads heads. The caller's list is not modified.
        threshold: maximum allowed drop (100 - score) for a removal.
        score_fn: callable(circuit) -> float score in %. Defaults to
            mean_ablate_by_lst(circuit, model, print_output=False).item().
        n_layers, n_heads: model dimensions to sweep.

    Returns:
        (circuit, score): the pruned circuit and the score of that circuit. The
        score is not the score of the last rejected trial.
    """
    if score_fn is None:
        def score_fn(circuit):
            return mean_ablate_by_lst(circuit, model, print_output=False).item()

    # The default None used to crash; treat None and [] alike. Work on a copy.
    if curr_circuit:
        circuit = list(curr_circuit)
    else:
        circuit = [(layer, head) for layer in range(n_layers) for head in range(n_heads)]

    circuit_score = None  # score of `circuit` as it currently stands, once known
    for layer in range(n_layers - 1, -1, -1):  # go thru all heads in a layer first
        for head in range(n_heads):
            if (layer, head) not in circuit:
                continue

            trial = circuit.copy()
            trial.remove((layer, head))
            new_score = score_fn(trial)

            if (100 - new_score) < threshold:
                circuit = trial
                circuit_score = new_score
                print("\nRemoved:", (layer, head))
                print(new_score)

    if circuit_score is None:
        # Nothing was removed, so the circuit as a whole was never scored.
        circuit_score = score_fn(circuit)
    return circuit, circuit_score

### iter fwd backw, threshold 3

In [None]:
# Alternate forward and backward pruning passes (threshold 3) until a pass leaves
# the circuit unchanged. Convergence is detected by comparing circuits. The former
# `while prev_score != new_score` compared a never-updated sentinel and could exit
# early whenever a returned score happened to equal exactly 100.
threshold = 3
curr_circuit = []  # [] -> the pruning helpers start from every head
iteration = 1  # not `iter`: that would shadow the builtin for the rest of the notebook
while True:
    print('\nfwd prune, iter ', str(iteration))
    old_circuit = curr_circuit.copy()  # snapshot to detect whether this pass removed anything
    curr_circuit, new_score = find_circuit_forw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    print('\nbackw prune, iter ', str(iteration))
    old_circuit = curr_circuit.copy()
    curr_circuit, new_score = find_circuit_backw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    iteration += 1


fwd prune, iter  1

Removed: (0, 0)
100.09558868408203

Removed: (0, 2)
100.14701843261719

Removed: (0, 3)
98.6058120727539

Removed: (0, 4)
98.61605834960938

Removed: (0, 5)
97.37905883789062

Removed: (0, 6)
97.55424499511719

Removed: (0, 7)
97.67510223388672

Removed: (0, 10)
99.30681610107422

Removed: (0, 11)
99.86624145507812

Removed: (1, 0)
97.25206756591797

Removed: (1, 1)
97.12611389160156

Removed: (1, 2)
97.61646270751953

Removed: (1, 3)
97.06305694580078

Removed: (1, 7)
99.2613525390625

Removed: (1, 8)
99.25433349609375

Removed: (1, 9)
99.36122131347656

Removed: (1, 10)
99.99845123291016

Removed: (1, 11)
104.18968963623047

Removed: (2, 0)
104.28089141845703

Removed: (2, 1)
105.65437316894531

Removed: (2, 2)
104.25637817382812

Removed: (2, 3)
105.09373474121094

Removed: (2, 4)
104.43804168701172

Removed: (2, 5)
105.24530792236328

Removed: (2, 6)
104.80706024169922

Removed: (2, 7)
103.23683166503906

Removed: (2, 8)
102.96487426757812

Removed: (2, 9)
100.

In [None]:
fb_3 = curr_circuit.copy()
fb_3

[(0, 1),
 (1, 5),
 (3, 0),
 (3, 3),
 (3, 6),
 (3, 7),
 (3, 10),
 (4, 4),
 (4, 8),
 (4, 11),
 (5, 4),
 (5, 5),
 (5, 6),
 (6, 6),
 (6, 9),
 (6, 10),
 (6, 11),
 (7, 0),
 (7, 6),
 (7, 7),
 (7, 10),
 (7, 11),
 (8, 0),
 (8, 1),
 (8, 6),
 (8, 8),
 (8, 11),
 (9, 1),
 (10, 1),
 (10, 2),
 (10, 3),
 (10, 5),
 (11, 8),
 (11, 10)]

#### compare

In [None]:
set(backw_3) - set(fb_3)

{(0, 9),
 (1, 0),
 (2, 2),
 (2, 9),
 (2, 10),
 (5, 1),
 (5, 2),
 (5, 3),
 (5, 8),
 (6, 1),
 (6, 3),
 (6, 4),
 (7, 1),
 (7, 2),
 (9, 9),
 (11, 9)}

In [None]:
set(fb_3) - set(backw_3)

{(3, 10), (4, 11), (6, 11), (7, 0), (10, 3), (10, 5)}

### iter fwd backw, threshold 20

In [None]:
# Alternate forward and backward pruning passes (threshold 20) until a pass leaves
# the circuit unchanged. Convergence is detected by comparing circuits. The former
# `while prev_score != new_score` compared a never-updated sentinel and could exit
# early whenever a returned score happened to equal exactly 100.
threshold = 20
curr_circuit = []  # [] -> the pruning helpers start from every head
iteration = 1  # not `iter`: that would shadow the builtin for the rest of the notebook
while True:
    print('\nfwd prune, iter ', str(iteration))
    old_circuit = curr_circuit.copy()  # snapshot to detect whether this pass removed anything
    curr_circuit, new_score = find_circuit_forw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    print('\nbackw prune, iter ', str(iteration))
    old_circuit = curr_circuit.copy()
    curr_circuit, new_score = find_circuit_backw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    iteration += 1

In [None]:
curr_circuit

## prune backwds-fwds iteratively

### iter backw fwd, threshold 3

In [None]:
# Alternate backward and forward pruning passes (threshold 3), backward first,
# until a pass leaves the circuit unchanged. Convergence is detected by comparing
# circuits. The former `while prev_score != new_score` compared a never-updated
# sentinel and could exit early whenever a returned score equalled exactly 100.
threshold = 3
curr_circuit = []  # [] -> the pruning helpers start from every head
iteration = 1  # not `iter`: that would shadow the builtin for the rest of the notebook
while True:
    print('\nbackw prune, iter ', str(iteration))
    old_circuit = curr_circuit.copy()  # snapshot to detect whether this pass removed anything
    curr_circuit, new_score = find_circuit_backw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    print('\nfwd prune, iter ', str(iteration))
    old_circuit = curr_circuit.copy()
    curr_circuit, new_score = find_circuit_forw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    iteration += 1


backw prune, iter  1

Removed: (11, 0)
98.45150756835938

Removed: (11, 1)
98.24433135986328

Removed: (11, 2)
98.33431243896484

Removed: (11, 3)
97.79175567626953

Removed: (11, 4)
97.27037048339844

Removed: (11, 5)
97.40259552001953

Removed: (11, 6)
97.28413391113281

Removed: (11, 7)
97.25057220458984

Removed: (11, 11)
97.88668823242188

Removed: (10, 0)
97.7774658203125

Removed: (10, 3)
97.6139907836914

Removed: (10, 4)
97.71390533447266

Removed: (10, 5)
97.6951675415039

Removed: (10, 6)
97.82437133789062

Removed: (10, 7)
98.38075256347656

Removed: (10, 8)
99.3049545288086

Removed: (10, 9)
99.44312286376953

Removed: (10, 10)
100.19226837158203

Removed: (10, 11)
99.88433074951172

Removed: (9, 0)
99.68191528320312

Removed: (9, 2)
99.8347396850586

Removed: (9, 3)
101.64493560791016

Removed: (9, 4)
100.82312774658203

Removed: (9, 5)
99.86206817626953

Removed: (9, 6)
99.5550765991211

Removed: (9, 7)
99.68671417236328

Removed: (9, 8)
99.69947052001953

Removed: (9, 

In [None]:
bf_3 = curr_circuit.copy()
bf_3

[(0, 1),
 (1, 5),
 (2, 2),
 (2, 9),
 (2, 10),
 (3, 0),
 (3, 3),
 (3, 6),
 (3, 7),
 (4, 4),
 (4, 8),
 (5, 4),
 (5, 5),
 (5, 6),
 (6, 4),
 (6, 6),
 (6, 9),
 (6, 10),
 (7, 2),
 (7, 6),
 (7, 7),
 (7, 10),
 (7, 11),
 (8, 0),
 (8, 1),
 (8, 6),
 (8, 8),
 (8, 11),
 (9, 1),
 (9, 9),
 (10, 1),
 (10, 2),
 (11, 8),
 (11, 9),
 (11, 10)]

#### compare

In [None]:
len(bf_3)

35

In [None]:
len(fb_3)

34

In [None]:
set(backw_3) - set(bf_3)

{(0, 9), (1, 0), (5, 1), (5, 2), (5, 3), (5, 8), (6, 1), (6, 3), (7, 1)}

In [None]:
set(bf_3) - set(backw_3)

set()

In [None]:
set(fb_3) - set(bf_3)

{(3, 10), (4, 11), (6, 11), (7, 0), (10, 3), (10, 5)}

In [None]:
set(bf_3) - set(fb_3)

{(2, 2), (2, 9), (2, 10), (6, 4), (7, 2), (9, 9), (11, 9)}

## etc fns

In [None]:
# base_lst = [(0, 1), (0, 10), (3, 0), (4, 4), (5, 5), (6, 1), (6, 6), (7, 10), (7, 11), (8, 8), (8, 11), (9, 1), (9, 5), (10, 7)]

In [None]:
# import json

# with open("scores.json", "w") as file:
#     json.dump(all_scores, file, default=lambda x: str(x))  # Convert tuples to strings for JSON serialization

In [None]:
# from google.colab import files
# files.download("scores.json")  # or "scores.pkl" or "scores.json" depending on the file you saved

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>