<a href="https://colab.research.google.com/github/wlg100/numseqcont_circuit_expms/blob/main/notebook_templates/minimal_circuit_template.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" align="left"/></a>&nbsp;or in a local notebook.

# Setup
(No need to change anything)

In [1]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    # Install another version of node that makes PySvelte work way faster
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    %pip install git+https://github.com/neelnanda-io/PySvelte.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

Running as a Colab notebook
Collecting git+https://github.com/neelnanda-io/TransformerLens.git
  Cloning https://github.com/neelnanda-io/TransformerLens.git to /tmp/pip-req-build-kro7ru3z
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/TransformerLens.git /tmp/pip-req-build-kro7ru3z
  Resolved https://github.com/neelnanda-io/TransformerLens.git to commit fa287750606075574df2c538058e67d648e2f952
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting accelerate>=0.23.0 (from transformer-lens==0.0.0)
  Downloading accelerate-0.24.1-py3-none-any.whl (261 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m261.4/261.4 kB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting beartype<0.15.0,>=0.14.1 (from transformer-lens==0.0.0)
  Downloading beartype-0.14.1-py3-none-any.whl (739 kB)
[2K     [90m━━━━━━━━

In [2]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

if IN_COLAB or not DEBUG_MODE:
    # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "png"

In [3]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [4]:
# import pysvelte

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

We turn automatic differentiation off, to save GPU memory, as this notebook focuses on model inference not model training.

In [5]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7c33b2137ee0>

Plotting helper functions:

In [6]:
def imshow(tensor, renderer=None, **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", **kwargs).show(renderer)

def line(tensor, renderer=None, **kwargs):
    px.line(y=utils.to_numpy(tensor), **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

## Load Model

Decide which model to use (eg. gpt2-small vs -medium)

In [7]:
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)

Downloading (…)lve/main/config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Loaded pretrained model gpt2-small into HookedTransformer


## Import functions from repo

In [8]:
!git clone https://github.com/callummcdougall/ARENA_2.0.git

Cloning into 'ARENA_2.0'...
remote: Enumerating objects: 9106, done.[K
remote: Counting objects: 100% (1818/1818), done.[K
remote: Compressing objects: 100% (288/288), done.[K
remote: Total 9106 (delta 1611), reused 1607 (delta 1527), pack-reused 7288[K
Receiving objects: 100% (9106/9106), 155.60 MiB | 21.02 MiB/s, done.
Resolving deltas: 100% (5506/5506), done.


In [9]:
cd ARENA_2.0/chapter1_transformers/exercises/part3_indirect_object_identification

/content/ARENA_2.0/chapter1_transformers/exercises/part3_indirect_object_identification


In [10]:
import ioi_circuit_extraction as ioi_circuit_extraction

# Generate dataset with multiple prompts

In [11]:
class Dataset:
    def __init__(self, prompts, pos_dict, tokenizer, S1_is_first=False):
        self.prompts = prompts
        self.tokenizer = tokenizer
        self.N = len(prompts)
        self.max_len = max(
            [
                len(self.tokenizer(prompt["text"]).input_ids)
                for prompt in self.prompts
            ]
        )
        # all_ids = [prompt["TEMPLATE_IDX"] for prompt in self.ioi_prompts]
        all_ids = [0 for prompt in self.prompts] # only 1 template
        all_ids_ar = np.array(all_ids)
        self.groups = []
        for id in list(set(all_ids)):
            self.groups.append(np.where(all_ids_ar == id)[0])

        texts = [ prompt["text"] for prompt in self.prompts ]
        self.toks = torch.Tensor(self.tokenizer(texts, padding=True).input_ids).type(
            torch.int
        )
        self.io_tokenIDs = [
            self.tokenizer.encode(" " + prompt["corr"])[0] for prompt in self.prompts
        ]
        self.s_tokenIDs = [
            self.tokenizer.encode(" " + prompt["incorr"])[0] for prompt in self.prompts
        ]

        # word_idx: for every prompt, find the token index of each target token and "end"
        # word_idx is a tensor with an element for each prompt. The element is the targ token's ind at that prompt
        self.word_idx = {}
        for targ in [key for key in self.prompts[0].keys() if (key != 'text' and key != 'corr' and key != 'incorr')]:
            targ_lst = []
            for prompt in self.prompts:
                input_text = prompt["text"]
                tokens = model.tokenizer.tokenize(input_text)
                # if S1_is_first and targ == "S1":  # only use this if first token doesn't have space Ġ in front
                #     target_token = prompt[targ]
                # else:
                #     target_token = "Ġ" + prompt[targ]
                # target_index = tokens.index(target_token)
                target_index = pos_dict[targ]
                targ_lst.append(target_index)
            self.word_idx[targ] = torch.tensor(targ_lst)

        targ_lst = []
        for prompt in self.prompts:
            input_text = prompt["text"]
            tokens = self.tokenizer.tokenize(input_text)
            end_token_index = len(tokens) - 1
            targ_lst.append(end_token_index)
        self.word_idx["end"] = torch.tensor(targ_lst)

    def __len__(self):
        return self.N

Repalce io_tokens with correct answer (next, which is '5') and s_tokens with incorrect (current, which repeats)

In [12]:
pos_dict = {
    'S1': 0,
    'S2': 1,
    'S3': 2,
    'S4': 3,
}

In [13]:
def generate_prompts_list(x ,y):
    prompts_list = []
    for i in range(x, y):
        prompt_dict = {
            'S1': str(i),
            'S2': str(i+1),
            'S3': str(i+2),
            'S4': str(i+3),
            'corr': str(i+4),
            'incorr': str(i),
            'text': f"{i} {i+1} {i+2} {i+3}"
        }
        prompts_list.append(prompt_dict)
    return prompts_list

prompts_list = generate_prompts_list(1, 101)
dataset = Dataset(prompts_list, pos_dict, model.tokenizer, S1_is_first=True)

In [14]:
import random

def generate_prompts_list_corr(x ,y):
    prompts_list = []
    for i in range(x, y):
        rDecade = random.randint(1, 9)*10
        r1 = random.randint(rDecade, 10+rDecade)
        r2 = random.randint(rDecade, 10+rDecade)
        r3 = random.randint(rDecade, 10+rDecade)
        r4 = random.randint(rDecade, 10+rDecade)
        prompt_dict = {
            'S1': str(r1),
            'S2': str(r2),
            'S3': str(r3),
            'S4': str(r4),
            'corr': str(r1),
            'incorr': str(i+4),
            'text': f"{r1} {r2} {r3} {r4}"
        }
        prompts_list.append(prompt_dict)
    return prompts_list

prompts_list_2 = generate_prompts_list_corr(1, 101)
dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer, S1_is_first=True)
prompts_list_2

[{'S1': '29',
  'S2': '23',
  'S3': '30',
  'S4': '29',
  'corr': '29',
  'incorr': '5',
  'text': '29 23 30 29'},
 {'S1': '77',
  'S2': '80',
  'S3': '80',
  'S4': '80',
  'corr': '77',
  'incorr': '6',
  'text': '77 80 80 80'},
 {'S1': '65',
  'S2': '60',
  'S3': '61',
  'S4': '61',
  'corr': '65',
  'incorr': '7',
  'text': '65 60 61 61'},
 {'S1': '14',
  'S2': '20',
  'S3': '10',
  'S4': '14',
  'corr': '14',
  'incorr': '8',
  'text': '14 20 10 14'},
 {'S1': '37',
  'S2': '36',
  'S3': '30',
  'S4': '32',
  'corr': '37',
  'incorr': '9',
  'text': '37 36 30 32'},
 {'S1': '53',
  'S2': '59',
  'S3': '54',
  'S4': '50',
  'corr': '53',
  'incorr': '10',
  'text': '53 59 54 50'},
 {'S1': '50',
  'S2': '56',
  'S3': '55',
  'S4': '51',
  'corr': '50',
  'incorr': '11',
  'text': '50 56 55 51'},
 {'S1': '33',
  'S2': '39',
  'S3': '37',
  'S4': '32',
  'corr': '33',
  'incorr': '12',
  'text': '33 39 37 32'},
 {'S1': '43',
  'S2': '47',
  'S3': '48',
  'S4': '44',
  'corr': '43',
  'in

Logit diff is correct - incorr token. Here, correct is S5, and incorr is S4.

Because of this, it's possible to have logit diffs HIGHER than the "full circuit" because the correct token will still be at first place, but the logit scores assigned will just be bigger (perhaps incorrect is scored even lower in the non-full circuit with a higher logit diff score)?

# Ablation Expm Functions

In [15]:
from torch import Tensor

def logits_to_ave_logit_diff_2(logits: Float[Tensor, "batch seq d_vocab"], dataset: Dataset, per_prompt=False):
    '''
    Returns logit difference between the correct and incorrect answer.

    If per_prompt=True, return the array of differences rather than the average.
    '''

    # Only the final logits are relevant for the answer
    # Get the logits corresponding to the indirect object / subject tokens respectively
    io_logits: Float[Tensor, "batch"] = logits[range(logits.size(0)), dataset.word_idx["end"], dataset.io_tokenIDs]
    s_logits: Float[Tensor, "batch"] = logits[range(logits.size(0)), dataset.word_idx["end"], dataset.s_tokenIDs]
    # Find logit difference
    answer_logit_diff = io_logits - s_logits
    return answer_logit_diff if per_prompt else answer_logit_diff.mean()

In [16]:
ioi_logits_original, ioi_cache = model.run_with_cache(dataset.toks)
orig_score = logits_to_ave_logit_diff_2(ioi_logits_original, dataset)

In [17]:
def mean_ablate_by_lst(lst, model, orig_score, print_output=True):
    CIRCUIT = {
        "number mover": lst,
        # "number mover 4": lst,
        "number mover 3": lst,
        "number mover 2": lst,
        "number mover 1": lst,
    }

    SEQ_POS_TO_KEEP = {
        "number mover": "end",
        # "number mover 4": "S4",
        "number mover 3": "S3",
        "number mover 2": "S2",
        "number mover 1": "S1",
    }

    model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook

    # ioi_logits_original, ioi_cache = model.run_with_cache(dataset.toks)

    model = ioi_circuit_extraction.add_mean_ablation_hook(model, means_dataset=dataset_2, circuit=CIRCUIT, seq_pos_to_keep=SEQ_POS_TO_KEEP)
    ioi_logits_minimal = model(dataset.toks)

    # orig_score = logits_to_ave_logit_diff_2(ioi_logits_original, dataset)
    new_score = logits_to_ave_logit_diff_2(ioi_logits_minimal, dataset)
    if print_output:
        # print(f"Average logit difference (IOI dataset, using entire model): {orig_score:.4f}")
        # print(f"Average logit difference (IOI dataset, only using circuit): {new_score:.4f}")
        print(f"Average logit difference (circuit / full) %: {100 * new_score / orig_score:.4f}")
    # return new_score
    return 100 * new_score / orig_score

In [18]:
def find_circuit_forw(curr_circuit=None, threshold=10):
    # threshold is T, a %. if performance is less than T%, allow its removal
    if curr_circuit == []:
        # Start with full circuit
        curr_circuit = [(layer, head) for layer in range(12) for head in range(12)]

    for layer in range(0, 12):
        for head in range(12):
            if (layer, head) not in curr_circuit:
                continue

            # Copying the curr_circuit so we can iterate over one and modify the other
            copy_circuit = curr_circuit.copy()

            # Temporarily removing the current tuple from the copied circuit
            copy_circuit.remove((layer, head))

            new_score = mean_ablate_by_lst(copy_circuit, model, orig_score, print_output=False).item()

            # print((layer,head), new_score)
            # If the result is less than the threshold, remove the tuple from the original list
            if (100 - new_score) < threshold:
                curr_circuit.remove((layer, head))

                print("\nRemoved:", (layer, head))
                print(new_score)

    return curr_circuit, new_score

In [19]:
def find_circuit_backw(curr_circuit=None, threshold=10):
    # threshold is T, a %. if performance is less than T%, allow its removal
    if curr_circuit == []:
        # Start with full circuit
        curr_circuit = [(layer, head) for layer in range(12) for head in range(12)]

    for layer in range(11, -1, -1):  # go thru all heads in a layer first
        for head in range(12):
            if (layer, head) not in curr_circuit:
                continue

            # Copying the curr_circuit so we can iterate over one and modify the other
            copy_circuit = curr_circuit.copy()

            # Temporarily removing the current tuple from the copied circuit
            copy_circuit.remove((layer, head))

            new_score = mean_ablate_by_lst(copy_circuit, model, orig_score, print_output=False).item()

            # If the result is less than the threshold, remove the tuple from the original list
            if (100 - new_score) < threshold:
                curr_circuit.remove((layer, head))

                print("\nRemoved:", (layer, head))
                print(new_score)

    return curr_circuit, new_score

# try other tasks circs

In [None]:
# fb 80, digits incr
# https://colab.research.google.com/drive/1mFWmGAKtigFcqqWWMCwU7wWQY2HT5ZOo#scrollTo=lJEY-Zs2g_a5&line=1&uniqifier=1
circuit = [(1, 5), (3, 3), (3, 7), (3, 10), (3, 11), (4, 4), (4, 6), (4, 7), (4, 10), (4, 11), (5, 0), (5, 2), (5, 3), (5, 4), (5, 6), (6, 3), (6, 8), (6, 10), (7, 0), (7, 2), (7, 7), (7, 8), (7, 10), (7, 11), (8, 0), (8, 1), (8, 6), (8, 8), (8, 9), (8, 11), (9, 1), (11, 8)]
mean_ablate_by_lst(circuit, model, orig_score, print_output=True).item()

Average logit difference (circuit / full) %: 82.9496


82.94963073730469

In [None]:
# fb 80, numwords
# https://colab.research.google.com/drive/1QTv-4osLHadCAay0beew-xlXszPCG88s#scrollTo=563kZf_4r_mw&line=2&uniqifier=1
circuit = [(3, 2), (4, 4), (4, 8), (4, 10), (4, 11), (5, 5), (5, 6), (5, 7), (5, 8), (6, 1), (6, 7), (6, 9), (6, 10), (7, 0), (7, 2), (7, 5), (7, 6), (7, 7), (7, 8), (7, 10), (7, 11), (8, 0), (8, 1), (8, 6), (8, 8), (8, 11), (9, 1), (10, 2)]
mean_ablate_by_lst(circuit, model, orig_score, print_output=True).item()

Average logit difference (circuit / full) %: 48.0918


48.09180450439453

In [None]:
# fb 80, months
# https://colab.research.google.com/drive/1lhQqlizYGMC11vzp6I9mJ3dyxIr8tV3l#scrollTo=563kZf_4r_mw&line=2&uniqifier=1
circuit = [(4, 4), (7, 11), (8, 6), (8, 9), (8, 11), (9, 1), (9, 5), (11, 10)]
mean_ablate_by_lst(circuit, model, orig_score, print_output=True).item()

Average logit difference (circuit / full) %: 28.4399


28.43991470336914

In [None]:
# greater-than
circuit = [(0, 1), (0, 3), (0, 5), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11), (9, 1)]
mean_ablate_by_lst(circuit, model, orig_score, print_output=True).item()

Average logit difference (circuit / full) %: 12.5404


12.540437698364258

In [None]:
# IOI
circuit = [(0, 1), (0, 10), (2, 2), (3, 0), (4, 11), (5, 5), (5, 8), (5, 9), (6, 9), (7, 3), (7, 9), (8, 6), (8, 10), (9, 0), (9, 6), (9, 7), (9, 9), (10, 0), (10, 1), (10, 2), (10, 6), (10, 7), (10, 10), (11, 2), (11, 9), (11, 10)]
mean_ablate_by_lst(circuit, model, orig_score, print_output=True).item()

Average logit difference (circuit / full) %: 2.0818


2.0818023681640625

# by seqpos

In [None]:
# def mean_ablate_by_lst(CIRCUIT, SEQ_POS_TO_KEEP, model, orig_score, print_output=True):
#     model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook

#     # ioi_logits_original, ioi_cache = model.run_with_cache(dataset.toks)

#     model_abl = ioi_circuit_extraction.add_mean_ablation_hook(model, means_dataset=dataset_2, circuit=CIRCUIT, seq_pos_to_keep=SEQ_POS_TO_KEEP)
#     ioi_logits_minimal = model_abl(dataset.toks)

#     # orig_score = logits_to_ave_logit_diff_2(ioi_logits_original, dataset)

#     new_score = logits_to_ave_logit_diff_2(ioi_logits_minimal, dataset)
#     if print_output:
#         # print(f"Average logit difference (IOI dataset, using entire model): {orig_score:.4f}")
#         # print(f"Average logit difference (IOI dataset, only using circuit): {new_score:.4f}")
#         print(f"Average logit difference (circuit / full) %: {100 * new_score / orig_score:.4f}")
#     # return new_score
#     return 100 * new_score / orig_score

### try full circuit from repeatLast iter fb

In [None]:
curr_circuit = [(0, 1), (0, 3), (0, 5), (0, 7), (0, 8), (0, 9), (0, 10), (1, 0), (1, 5), (3, 0), (3, 3), (3, 7), (3, 10), (3, 11), (4, 4), (4, 6), (4, 7), (4, 8), (4, 10), (4, 11), (5, 4), (5, 5), (5, 9), (6, 1), (6, 6), (6, 10), (7, 6), (7, 10), (7, 11), (8, 1), (8, 2), (8, 6), (8, 8), (9, 1), (9, 5), (10, 7), (11, 10)]
mean_ablate_by_lst(curr_circuit, model, orig_score, print_output=True).item()

Average logit difference (circuit / full) %: 55.3844


55.38440704345703

In [None]:
curr_circuit = [(9, 1)]
mean_ablate_by_lst(curr_circuit, model, orig_score, print_output=True).item()

Average logit difference (circuit / full) %: 1.4608


1.4607776403427124

# Ablate the model and compare with original

## Prune backwards

In [None]:
# Start with full circuit
curr_circuit = [(layer, head) for layer in range(12) for head in range(12)]
threshold = 3  # This is T, a %. if performance is less than T%, allow its removal

for layer in range(11, -1, -1):  # go thru all heads in a layer first
    for head in range(12):
        # Copying the curr_circuit so we can iterate over one and modify the other
        copy_circuit = curr_circuit.copy()

        # Temporarily removing the current tuple from the copied circuit
        copy_circuit.remove((layer, head))

        new_score = mean_ablate_by_lst(copy_circuit, model, orig_score, print_output=False).item()

        # print((layer,head), new_score)
        # If the result is less than the threshold, remove the tuple from the original list
        if (100 - new_score) < threshold:
            curr_circuit.remove((layer, head))

            print("Removed:", (layer, head))
            print(new_score)
            print("\n")

Removed: (11, 0)
97.85265350341797


Removed: (11, 1)
97.83056640625


Removed: (11, 2)
98.08329010009766


Removed: (11, 3)
97.8642807006836


Removed: (11, 4)
98.09476470947266


Removed: (11, 5)
98.19351196289062


Removed: (11, 6)
98.30775451660156


Removed: (11, 7)
98.28131866455078


Removed: (11, 9)
98.12042999267578


Removed: (11, 11)
99.49564361572266


Removed: (10, 0)
99.5240249633789


Removed: (10, 1)
99.18901824951172


Removed: (10, 2)
99.78797149658203


Removed: (10, 3)
99.53922271728516


Removed: (10, 4)
99.09219360351562


Removed: (10, 5)
98.38536834716797


Removed: (10, 6)
98.42536926269531


Removed: (10, 7)
98.42119598388672


Removed: (10, 8)
98.67456817626953


Removed: (10, 9)
98.50550079345703


Removed: (10, 10)
98.72576904296875


Removed: (10, 11)
98.79467010498047


Removed: (9, 0)
98.8804702758789


Removed: (9, 2)
98.96058654785156


Removed: (9, 3)
98.78217315673828


Removed: (9, 4)
98.94293212890625


Removed: (9, 5)
99.04121398925781


Removed: 

In [None]:
mean_ablate_by_lst(curr_circuit, model, orig_score, print_output=True)

Average logit difference (circuit / full) %: 97.8223


tensor(97.8223, device='cuda:0')

In [None]:
backw_3 = curr_circuit.copy()
backw_3

[(0, 1),
 (0, 2),
 (0, 4),
 (0, 5),
 (0, 9),
 (0, 10),
 (0, 11),
 (1, 0),
 (1, 1),
 (1, 2),
 (1, 3),
 (1, 4),
 (1, 5),
 (1, 7),
 (1, 8),
 (2, 2),
 (2, 6),
 (2, 8),
 (2, 10),
 (3, 1),
 (3, 3),
 (3, 6),
 (3, 7),
 (3, 8),
 (3, 9),
 (4, 4),
 (4, 9),
 (4, 10),
 (5, 1),
 (5, 2),
 (5, 3),
 (5, 4),
 (5, 6),
 (5, 8),
 (5, 11),
 (6, 1),
 (6, 3),
 (6, 8),
 (6, 9),
 (6, 10),
 (6, 11),
 (7, 2),
 (7, 6),
 (7, 8),
 (7, 11),
 (8, 0),
 (8, 5),
 (8, 6),
 (8, 8),
 (8, 11),
 (9, 1),
 (11, 8),
 (11, 10)]

In [None]:
len(backw_3)

53

Now try 10% threshold:

In [None]:
def find_circuit_backw(threshold=10):
    # threshold is T, a %. if performance is less than T%, allow its removal
    # Start with full circuit
    curr_circuit = [(layer, head) for layer in range(12) for head in range(12)]

    for layer in range(11, -1, -1):  # go thru all heads in a layer first
        for head in range(12):
            # Copying the curr_circuit so we can iterate over one and modify the other
            copy_circuit = curr_circuit.copy()

            # Temporarily removing the current tuple from the copied circuit
            copy_circuit.remove((layer, head))

            new_score = mean_ablate_by_lst(copy_circuit, model, orig_score,print_output=False).item()

            # If the result is less than the threshold, remove the tuple from the original list
            if (100 - new_score) < threshold:
                curr_circuit.remove((layer, head))

                print("Removed:", (layer, head))
                print(new_score)
                print("\n")

    return curr_circuit

In [None]:
curr_circuit = find_circuit_backw(10)

Removed: (11, 0)
97.85265350341797


Removed: (11, 1)
97.83056640625


Removed: (11, 2)
98.08329010009766


Removed: (11, 3)
97.8642807006836


Removed: (11, 4)
98.09476470947266


Removed: (11, 5)
98.19351196289062


Removed: (11, 6)
98.30775451660156


Removed: (11, 7)
98.28131866455078


Removed: (11, 8)
95.95122528076172


Removed: (11, 9)
95.80072021484375


Removed: (11, 10)
94.64726257324219


Removed: (11, 11)
95.88154602050781


Removed: (10, 0)
95.92694091796875


Removed: (10, 1)
95.60623168945312


Removed: (10, 2)
96.33126831054688


Removed: (10, 3)
96.11638641357422


Removed: (10, 4)
95.68040466308594


Removed: (10, 5)
94.96817779541016


Removed: (10, 6)
95.02008819580078


Removed: (10, 7)
94.5572280883789


Removed: (10, 8)
94.81536865234375


Removed: (10, 9)
94.65853881835938


Removed: (10, 10)
94.8770523071289


Removed: (10, 11)
94.94637298583984


Removed: (9, 0)
95.02518463134766


Removed: (9, 2)
95.09461212158203


Removed: (9, 3)
94.93053436279297


Remove

In [None]:
backw_10 = curr_circuit.copy()
backw_10

[(0, 1),
 (0, 3),
 (0, 4),
 (0, 5),
 (0, 8),
 (0, 9),
 (0, 10),
 (1, 0),
 (1, 1),
 (1, 3),
 (1, 4),
 (1, 5),
 (1, 7),
 (1, 8),
 (2, 0),
 (2, 2),
 (2, 3),
 (2, 4),
 (2, 6),
 (2, 8),
 (2, 10),
 (3, 3),
 (3, 7),
 (3, 8),
 (3, 11),
 (4, 4),
 (4, 6),
 (4, 10),
 (4, 11),
 (5, 0),
 (5, 1),
 (5, 2),
 (5, 3),
 (5, 4),
 (5, 5),
 (5, 6),
 (5, 8),
 (5, 11),
 (6, 2),
 (6, 3),
 (6, 4),
 (6, 6),
 (6, 8),
 (6, 10),
 (6, 11),
 (7, 8),
 (7, 9),
 (7, 11),
 (8, 5),
 (8, 6),
 (8, 11),
 (9, 1)]

In [None]:
mean_ablate_by_lst(curr_circuit, model, orig_score, print_output=True)

Average logit difference (circuit / full) %: 90.1919


tensor(90.1919, device='cuda:0')

In [None]:
len(backw_10)

52

20%:

In [None]:
# %%capture
# curr_circuit = find_circuit_backw(20)

In [None]:
# backw_20 = curr_circuit.copy()
# backw_20

In [None]:
# mean_ablate_by_lst(curr_circuit, model, orig_score, print_output=True)

In [None]:
# len(backw_20)

### set diffs of the three perf lvls

In [None]:
set(backw_3) - set(backw_10)

{(0, 2),
 (0, 11),
 (1, 2),
 (3, 1),
 (3, 6),
 (3, 9),
 (4, 9),
 (6, 1),
 (6, 9),
 (7, 2),
 (7, 6),
 (8, 0),
 (8, 8),
 (11, 8),
 (11, 10)}

In [None]:
set(backw_10) - set(backw_3)

{(0, 3),
 (0, 8),
 (2, 0),
 (2, 3),
 (2, 4),
 (3, 11),
 (4, 6),
 (4, 11),
 (5, 0),
 (5, 5),
 (6, 2),
 (6, 4),
 (6, 6),
 (7, 9)}

In [None]:
# set(backw_3) - set(backw_20)

In [None]:
# set(backw_10) - set(backw_20)

## Prune forwards

In [None]:
# # Start with full circuit
# curr_circuit = [(layer, head) for layer in range(12) for head in range(12)]
# threshold = 3  # This is T, a %. if performance is less than T%, allow its removal

# for layer in range(0, 12):
#     for head in range(12):
#         # Copying the curr_circuit so we can iterate over one and modify the other
#         copy_circuit = curr_circuit.copy()

#         # Temporarily removing the current tuple from the copied circuit
#         copy_circuit.remove((layer, head))

#         new_score = mean_ablate_by_lst(copy_circuit, model, orig_score, print_output=False).item()

#         # print((layer,head), new_score)
#         # If the result is less than the threshold, remove the tuple from the original list
#         if (100 - new_score) < threshold:
#             curr_circuit.remove((layer, head))

#             print("Removed:", (layer, head))
#             print(new_score)
#             print("\n")

## Prune fwds-backwds iteratively

### iter fwd backw, threshold 3

In [None]:
threshold = 3
curr_circuit = []
prev_score = 100
new_score = 0
iter = 1
while prev_score != new_score:
    print('\nfwd prune, iter ', str(iter))
    # track changes in circuit as for some reason it doesn't work with scores
    old_circuit = curr_circuit.copy() # save old before finding new one
    curr_circuit, new_score = find_circuit_forw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    print('\nbackw prune, iter ', str(iter))
    # prev_score = new_score # save old score before finding new one
    old_circuit = curr_circuit.copy() # save old before finding new one
    curr_circuit, new_score = find_circuit_backw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    iter += 1


fwd prune, iter  1

Removed: (0, 0)
100.7425765991211

Removed: (0, 2)
99.4460220336914

Removed: (0, 3)
99.01410675048828

Removed: (0, 4)
99.1076889038086

Removed: (0, 5)
98.4770278930664

Removed: (0, 6)
98.77283477783203

Removed: (0, 7)
98.28561401367188

Removed: (0, 8)
98.60443115234375

Removed: (0, 11)
97.68379974365234

Removed: (1, 1)
97.444091796875

Removed: (1, 2)
97.00460052490234

Removed: (1, 6)
97.26143646240234

Removed: (1, 9)
97.76024627685547

Removed: (1, 10)
97.0715560913086

Removed: (2, 0)
97.0306396484375

Removed: (2, 3)
97.39745330810547

Removed: (2, 4)
97.24488067626953

Removed: (2, 5)
97.24573516845703

Removed: (2, 6)
97.56053924560547

Removed: (2, 7)
98.39006042480469

Removed: (2, 8)
97.12518310546875

Removed: (2, 11)
97.07171630859375

Removed: (3, 0)
98.21617126464844

Removed: (3, 1)
97.15900421142578

Removed: (3, 2)
97.48379516601562

Removed: (3, 4)
98.41328430175781

Removed: (3, 5)
97.61367797851562

Removed: (3, 6)
97.11524963378906

Rem

In [None]:
fb_3 = curr_circuit.copy()
fb_3

[(0, 1),
 (0, 9),
 (0, 10),
 (1, 0),
 (1, 3),
 (1, 4),
 (1, 5),
 (1, 7),
 (1, 8),
 (1, 11),
 (2, 1),
 (2, 2),
 (2, 9),
 (2, 10),
 (3, 3),
 (3, 7),
 (3, 8),
 (3, 9),
 (4, 4),
 (4, 6),
 (4, 9),
 (4, 10),
 (4, 11),
 (5, 0),
 (5, 1),
 (5, 2),
 (5, 3),
 (5, 4),
 (5, 5),
 (5, 6),
 (5, 8),
 (6, 1),
 (6, 3),
 (6, 9),
 (6, 10),
 (6, 11),
 (7, 0),
 (7, 6),
 (7, 8),
 (7, 11),
 (8, 0),
 (8, 5),
 (8, 6),
 (8, 8),
 (8, 11),
 (9, 1),
 (9, 10),
 (9, 11),
 (10, 5),
 (11, 0),
 (11, 8),
 (11, 10)]

In [None]:
mean_ablate_by_lst(fb_3, model, orig_score, print_output=True)

Average logit difference (circuit / full) %: 97.0029


tensor(97.0029, device='cuda:0')

In [None]:
len(fb_3)

52

#### compare

In [None]:
set(backw_3) - set(fb_3)

{(0, 2),
 (0, 4),
 (0, 5),
 (0, 11),
 (1, 1),
 (1, 2),
 (2, 6),
 (2, 8),
 (3, 1),
 (3, 6),
 (5, 11),
 (6, 8),
 (7, 2)}

In [None]:
set(fb_3) - set(backw_3)

{(1, 11),
 (2, 1),
 (2, 9),
 (4, 6),
 (4, 11),
 (5, 0),
 (5, 5),
 (7, 0),
 (9, 10),
 (9, 11),
 (10, 5),
 (11, 0)}

### iter fwd backw, threshold 10

In [None]:
threshold = 10
curr_circuit = []
prev_score = 100
new_score = 0
iter = 1
while prev_score != new_score:
    print('\nfwd prune, iter ', str(iter))
    # track changes in circuit as for some reason it doesn't work with scores
    old_circuit = curr_circuit.copy() # save old before finding new one
    curr_circuit, new_score = find_circuit_forw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    print('\nbackw prune, iter ', str(iter))
    # prev_score = new_score # save old score before finding new one
    old_circuit = curr_circuit.copy() # save old before finding new one
    curr_circuit, new_score = find_circuit_backw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    iter += 1


fwd prune, iter  1

Removed: (0, 0)
99.63545227050781

Removed: (0, 1)
97.8546142578125

Removed: (0, 2)
93.77347564697266

Removed: (0, 3)
93.27129364013672

Removed: (0, 4)
93.41110229492188

Removed: (0, 5)
92.48627471923828

Removed: (0, 6)
94.19954681396484

Removed: (0, 7)
92.69132995605469

Removed: (0, 8)
90.80460357666016

Removed: (0, 11)
91.21514892578125

Removed: (1, 1)
90.18220520019531

Removed: (1, 2)
90.00753784179688

Removed: (1, 4)
90.07615661621094

Removed: (1, 6)
90.2310562133789

Removed: (1, 9)
90.85907745361328

Removed: (1, 10)
90.97980499267578

Removed: (2, 1)
90.61381530761719

Removed: (2, 4)
90.8145523071289

Removed: (2, 5)
90.23743438720703

Removed: (2, 7)
90.24185180664062

Removed: (2, 9)
90.32540130615234

Removed: (2, 11)
90.30589294433594

Removed: (3, 1)
90.02891540527344

Removed: (3, 2)
90.7135238647461

Removed: (3, 4)
91.86970520019531

Removed: (3, 5)
91.53072357177734

Removed: (3, 6)
90.48258209228516

Removed: (3, 9)
90.23042297363281



In [None]:
curr_circuit

[(0, 9),
 (0, 10),
 (1, 0),
 (1, 3),
 (1, 5),
 (1, 7),
 (1, 8),
 (1, 11),
 (2, 0),
 (2, 2),
 (2, 3),
 (2, 6),
 (2, 8),
 (2, 10),
 (3, 0),
 (3, 3),
 (3, 7),
 (3, 8),
 (3, 11),
 (4, 2),
 (4, 4),
 (4, 6),
 (4, 9),
 (4, 10),
 (4, 11),
 (5, 0),
 (5, 1),
 (5, 2),
 (5, 3),
 (5, 4),
 (5, 5),
 (5, 6),
 (5, 8),
 (6, 1),
 (6, 2),
 (6, 3),
 (6, 9),
 (6, 10),
 (6, 11),
 (7, 0),
 (7, 2),
 (7, 6),
 (7, 8),
 (7, 10),
 (7, 11),
 (8, 0),
 (8, 6),
 (8, 7),
 (8, 8),
 (8, 11),
 (9, 1),
 (11, 8)]

In [None]:
fb_10 = curr_circuit.copy()
len(fb_10)

52

#### loop rmv and check for most impt heads

In [None]:
circ = fb_10
mean_ablate_by_lst(circ, model, orig_score, print_output=True).item()

Average logit difference (circuit / full) %: 90.0007


90.00074005126953

In [None]:
lh_scores = {}
for lh in circ:
    copy_circuit = circ.copy()
    copy_circuit.remove(lh)
    print("removed: " + str(lh))
    new_score = mean_ablate_by_lst(copy_circuit, model, orig_score, print_output=True).item()
    lh_scores[lh] = new_score
lh_scores

removed: (0, 9)
Average logit difference (circuit / full) %: 84.0520
removed: (0, 10)
Average logit difference (circuit / full) %: 88.8834
removed: (1, 0)
Average logit difference (circuit / full) %: 87.0976
removed: (1, 3)
Average logit difference (circuit / full) %: 89.9764
removed: (1, 5)
Average logit difference (circuit / full) %: 77.0306
removed: (1, 7)
Average logit difference (circuit / full) %: 89.0403
removed: (1, 8)
Average logit difference (circuit / full) %: 89.5353
removed: (1, 11)
Average logit difference (circuit / full) %: 81.5163
removed: (2, 0)
Average logit difference (circuit / full) %: 88.5532
removed: (2, 2)
Average logit difference (circuit / full) %: 88.6880
removed: (2, 3)
Average logit difference (circuit / full) %: 89.3275
removed: (2, 6)
Average logit difference (circuit / full) %: 88.5417
removed: (2, 8)
Average logit difference (circuit / full) %: 89.1752
removed: (2, 10)
Average logit difference (circuit / full) %: 89.0117
removed: (3, 0)
Average logit d

{(0, 9): 84.0519790649414,
 (0, 10): 88.8834228515625,
 (1, 0): 87.0976333618164,
 (1, 3): 89.97638702392578,
 (1, 5): 77.03057098388672,
 (1, 7): 89.040283203125,
 (1, 8): 89.53528594970703,
 (1, 11): 81.51630401611328,
 (2, 0): 88.55322265625,
 (2, 2): 88.68795776367188,
 (2, 3): 89.32754516601562,
 (2, 6): 88.54174041748047,
 (2, 8): 89.17516326904297,
 (2, 10): 89.0117416381836,
 (3, 0): 88.91248321533203,
 (3, 3): 88.50071716308594,
 (3, 7): 88.37588500976562,
 (3, 8): 89.33065795898438,
 (3, 11): 88.67903900146484,
 (4, 2): 89.86676025390625,
 (4, 4): 54.29780197143555,
 (4, 6): 89.15436553955078,
 (4, 9): 89.65997314453125,
 (4, 10): 82.01959991455078,
 (4, 11): 88.50983428955078,
 (5, 0): 87.88685607910156,
 (5, 1): 89.89642333984375,
 (5, 2): 88.94784545898438,
 (5, 3): 89.31974792480469,
 (5, 4): 87.45087432861328,
 (5, 5): 89.35881042480469,
 (5, 6): 88.36607360839844,
 (5, 8): 87.88916015625,
 (6, 1): 89.3843994140625,
 (6, 2): 89.60713958740234,
 (6, 3): 88.78077697753906,

In [None]:
dict(sorted(lh_scores.items(), key=lambda item: item[1]))

{(4, 4): 54.29780197143555,
 (9, 1): 70.10905456542969,
 (1, 5): 77.03057098388672,
 (8, 11): 80.89986419677734,
 (7, 11): 81.49203491210938,
 (1, 11): 81.51630401611328,
 (4, 10): 82.01959991455078,
 (0, 9): 84.0519790649414,
 (6, 10): 85.72323608398438,
 (8, 6): 87.06271362304688,
 (1, 0): 87.0976333618164,
 (5, 4): 87.45087432861328,
 (8, 0): 87.4754867553711,
 (7, 0): 87.53032684326172,
 (7, 10): 87.60797882080078,
 (11, 8): 87.84275817871094,
 (5, 0): 87.88685607910156,
 (5, 8): 87.88916015625,
 (5, 6): 88.36607360839844,
 (3, 7): 88.37588500976562,
 (3, 3): 88.50071716308594,
 (4, 11): 88.50983428955078,
 (2, 6): 88.54174041748047,
 (2, 0): 88.55322265625,
 (3, 11): 88.67903900146484,
 (2, 2): 88.68795776367188,
 (6, 9): 88.75731658935547,
 (6, 3): 88.78077697753906,
 (7, 8): 88.85366821289062,
 (0, 10): 88.8834228515625,
 (3, 0): 88.91248321533203,
 (5, 2): 88.94784545898438,
 (2, 10): 89.0117416381836,
 (7, 6): 89.02937316894531,
 (1, 7): 89.040283203125,
 (4, 6): 89.1543655395

### iter fwd backw, threshold 20

In [None]:
threshold = 20
curr_circuit = []
prev_score = 100
new_score = 0
iter = 1
while prev_score != new_score:
    print('\nfwd prune, iter ', str(iter))
    # track changes in circuit as for some reason it doesn't work with scores
    old_circuit = curr_circuit.copy() # save old before finding new one
    curr_circuit, new_score = find_circuit_forw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    print('\nbackw prune, iter ', str(iter))
    # prev_score = new_score # save old score before finding new one
    old_circuit = curr_circuit.copy() # save old before finding new one
    curr_circuit, new_score = find_circuit_backw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    iter += 1


fwd prune, iter  1

Removed: (0, 0)
99.63545227050781

Removed: (0, 1)
97.8546142578125

Removed: (0, 2)
93.77347564697266

Removed: (0, 3)
93.27129364013672

Removed: (0, 4)
93.41110229492188

Removed: (0, 5)
92.48627471923828

Removed: (0, 6)
94.19954681396484

Removed: (0, 7)
92.69132995605469

Removed: (0, 8)
90.80460357666016

Removed: (0, 9)
84.43527221679688

Removed: (0, 10)
91.81703186035156

Removed: (0, 11)
92.19535827636719

Removed: (1, 0)
89.43771362304688

Removed: (1, 1)
88.7485580444336

Removed: (1, 2)
88.58959197998047

Removed: (1, 3)
87.7369384765625

Removed: (1, 4)
87.66887664794922

Removed: (1, 6)
87.9463119506836

Removed: (1, 7)
87.06246185302734

Removed: (1, 8)
86.33380126953125

Removed: (1, 9)
87.08216857910156

Removed: (1, 10)
87.17735290527344

Removed: (1, 11)
87.16203308105469

Removed: (2, 0)
85.96482849121094

Removed: (2, 1)
85.83142852783203

Removed: (2, 2)
84.40343475341797

Removed: (2, 3)
83.68087005615234

Removed: (2, 4)
83.70391082763672


In [None]:
curr_circuit

[(1, 5),
 (3, 3),
 (3, 7),
 (3, 10),
 (3, 11),
 (4, 4),
 (4, 6),
 (4, 7),
 (4, 10),
 (4, 11),
 (5, 0),
 (5, 2),
 (5, 3),
 (5, 4),
 (5, 6),
 (6, 3),
 (6, 8),
 (6, 10),
 (7, 0),
 (7, 2),
 (7, 7),
 (7, 8),
 (7, 10),
 (7, 11),
 (8, 0),
 (8, 1),
 (8, 6),
 (8, 8),
 (8, 9),
 (8, 11),
 (9, 1),
 (11, 8)]

In [None]:
fb_20 = curr_circuit.copy()
len(fb_20)

32

#### loop rmv and check for most impt heads

In [None]:
circ = fb_20
mean_ablate_by_lst(circ, model, orig_score, print_output=True).item()

Average logit difference (circuit / full) %: 80.0162


80.01622772216797

In [None]:
lh_scores = {}
for lh in circ:
    copy_circuit = circ.copy()
    copy_circuit.remove(lh)
    print("removed: " + str(lh))
    new_score = mean_ablate_by_lst(copy_circuit, model, orig_score, print_output=True).item()
    lh_scores[lh] = new_score

In [None]:
dict(sorted(lh_scores.items(), key=lambda item: item[1]))

{(4, 4): 44.0450553894043,
 (9, 1): 59.55636215209961,
 (1, 5): 63.676727294921875,
 (7, 11): 67.86986541748047,
 (8, 11): 69.75837707519531,
 (4, 10): 71.58486938476562,
 (7, 10): 77.24897003173828,
 (5, 6): 77.36872100830078,
 (6, 10): 77.62753295898438,
 (8, 6): 77.66889190673828,
 (5, 4): 78.35254669189453,
 (7, 0): 78.39317321777344,
 (7, 2): 78.40939331054688,
 (11, 8): 78.55787658691406,
 (5, 0): 78.76732635498047,
 (6, 8): 78.90299224853516,
 (6, 3): 79.09789276123047,
 (8, 0): 79.11331176757812,
 (5, 3): 79.12004089355469,
 (7, 8): 79.15681457519531,
 (4, 7): 79.21139526367188,
 (8, 1): 79.2576675415039,
 (3, 11): 79.30284118652344,
 (8, 8): 79.31413269042969,
 (4, 11): 79.39554595947266,
 (3, 7): 79.53402709960938,
 (3, 10): 79.6149673461914,
 (5, 2): 79.66963195800781,
 (3, 3): 79.68397521972656,
 (4, 6): 79.69416809082031,
 (7, 7): 79.86949920654297,
 (8, 9): 79.87236022949219}

In [None]:
lh_scores = {(4, 4): 44.0450553894043,
 (9, 1): 59.55636215209961,
 (1, 5): 63.676727294921875,
 (7, 11): 67.86986541748047,
 (8, 11): 69.75837707519531,
 (4, 10): 71.58486938476562,
 (7, 10): 77.24897003173828,}

for lh, score in lh_scores.items():
    print(lh, round(80-score, 2))

(4, 4) 35.95
(9, 1) 20.44
(1, 5) 16.32
(7, 11) 12.13
(8, 11) 10.24
(4, 10) 8.42
(7, 10) 2.75


In [None]:
fb_rand_20, rand_score = find_circuit_rand(fb_20, threshold=25, numRandIters=30)

TypeError: ignored

In [None]:
fb_25_scores = {key: value for key, value in lh_scores.items() if value < 75}
fb_25_scores

{(1, 5): 63.676727294921875,
 (4, 4): 44.0450553894043,
 (4, 10): 71.58486938476562,
 (7, 11): 67.86986541748047,
 (8, 11): 69.75837707519531,
 (9, 1): 59.55636215209961}

In [None]:
circ = list(fb_25_scores.keys())
mean_ablate_by_lst(circ, model, orig_score, print_output=True).item()

Average logit difference (circuit / full) %: 42.9680


42.967979431152344

We are removing multiple at a time, not just one head, which is why the score is not 75%. All those heads that don't work "as one" work "together".

### iter fwd backw, threshold 25

In [None]:
threshold = 25
curr_circuit = []
prev_score = 100
new_score = 0
iter = 1
while prev_score != new_score:
    print('\nfwd prune, iter ', str(iter))
    # track changes in circuit as for some reason it doesn't work with scores
    old_circuit = curr_circuit.copy() # save old before finding new one
    curr_circuit, new_score = find_circuit_forw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    print('\nbackw prune, iter ', str(iter))
    # prev_score = new_score # save old score before finding new one
    old_circuit = curr_circuit.copy() # save old before finding new one
    curr_circuit, new_score = find_circuit_backw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    iter += 1


fwd prune, iter  1

Removed: (0, 0)
99.63545227050781

Removed: (0, 1)
97.8546142578125

Removed: (0, 2)
93.77347564697266

Removed: (0, 3)
93.27129364013672

Removed: (0, 4)
93.41110229492188

Removed: (0, 5)
92.48627471923828

Removed: (0, 6)
94.19954681396484

Removed: (0, 7)
92.69132995605469

Removed: (0, 8)
90.80460357666016

Removed: (0, 9)
84.43527221679688

Removed: (0, 10)
91.81703186035156

Removed: (0, 11)
92.19535827636719

Removed: (1, 0)
89.43771362304688

Removed: (1, 1)
88.7485580444336

Removed: (1, 2)
88.58959197998047

Removed: (1, 3)
87.7369384765625

Removed: (1, 4)
87.66887664794922

Removed: (1, 6)
87.9463119506836

Removed: (1, 7)
87.06246185302734

Removed: (1, 8)
86.33380126953125

Removed: (1, 9)
87.08216857910156

Removed: (1, 10)
87.17735290527344

Removed: (1, 11)
87.16203308105469

Removed: (2, 0)
85.96482849121094

Removed: (2, 1)
85.83142852783203

Removed: (2, 2)
84.40343475341797

Removed: (2, 3)
83.68087005615234

Removed: (2, 4)
83.70391082763672


In [None]:
curr_circuit

[(1, 5),
 (3, 10),
 (4, 4),
 (4, 7),
 (4, 10),
 (4, 11),
 (5, 0),
 (5, 3),
 (5, 4),
 (5, 6),
 (6, 8),
 (6, 10),
 (7, 2),
 (7, 8),
 (7, 10),
 (7, 11),
 (8, 0),
 (8, 1),
 (8, 6),
 (8, 8),
 (8, 11),
 (9, 1),
 (11, 10)]

In [None]:
fb_25 = curr_circuit.copy()
len(fb_25)

23

#### loop rmv and check for most impt heads

In [None]:
circ = fb_25
mean_ablate_by_lst(circ, model, orig_score, print_output=True).item()

Average logit difference (circuit / full) %: 75.0805


75.0804672241211

In [None]:
lh_scores = {}
for lh in circ:
    copy_circuit = circ.copy()
    copy_circuit.remove(lh)
    print("removed: " + str(lh))
    new_score = mean_ablate_by_lst(copy_circuit, model, orig_score, print_output=True).item()
    lh_scores[lh] = new_score

removed: (1, 5)
Average logit difference (circuit / full) %: 58.4193
removed: (3, 10)
Average logit difference (circuit / full) %: 74.6676
removed: (4, 4)
Average logit difference (circuit / full) %: 43.9454
removed: (4, 7)
Average logit difference (circuit / full) %: 74.6017
removed: (4, 10)
Average logit difference (circuit / full) %: 66.3976
removed: (4, 11)
Average logit difference (circuit / full) %: 73.4884
removed: (5, 0)
Average logit difference (circuit / full) %: 74.4927
removed: (5, 3)
Average logit difference (circuit / full) %: 74.0895
removed: (5, 4)
Average logit difference (circuit / full) %: 73.6528
removed: (5, 6)
Average logit difference (circuit / full) %: 71.2443
removed: (6, 8)
Average logit difference (circuit / full) %: 73.0884
removed: (6, 10)
Average logit difference (circuit / full) %: 72.8877
removed: (7, 2)
Average logit difference (circuit / full) %: 72.7953
removed: (7, 8)
Average logit difference (circuit / full) %: 74.2235
removed: (7, 10)
Average logit

In [None]:
dict(sorted(lh_scores.items(), key=lambda item: item[1]))

{(4, 4): 43.945350646972656,
 (9, 1): 55.33553695678711,
 (1, 5): 58.41929244995117,
 (7, 11): 62.95230484008789,
 (8, 11): 64.5703353881836,
 (4, 10): 66.3975601196289,
 (5, 6): 71.24427032470703,
 (7, 10): 71.80767822265625,
 (7, 2): 72.79530334472656,
 (6, 10): 72.8876953125,
 (8, 6): 72.96757507324219,
 (6, 8): 73.08843231201172,
 (4, 11): 73.48839569091797,
 (5, 4): 73.6528091430664,
 (11, 10): 73.99994659423828,
 (5, 3): 74.08948516845703,
 (7, 8): 74.22348022460938,
 (8, 0): 74.29803466796875,
 (8, 1): 74.34387969970703,
 (8, 8): 74.45545959472656,
 (5, 0): 74.49271392822266,
 (4, 7): 74.60174560546875,
 (3, 10): 74.66755676269531}

## Prune backwds-fwds iteratively

### iter fwd backw, threshold 3

In [20]:
threshold = 3
curr_circuit = []
prev_score = 100
new_score = 0
iter = 1
while prev_score != new_score:
    print('\nbackw prune, iter ', str(iter))
    # prev_score = new_score # save old score before finding new one
    old_circuit = curr_circuit.copy() # save old before finding new one
    curr_circuit, new_score = find_circuit_backw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    print('\nfwd prune, iter ', str(iter))
    # track changes in circuit as for some reason it doesn't work with scores
    old_circuit = curr_circuit.copy() # save old before finding new one
    curr_circuit, new_score = find_circuit_forw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    iter += 1


backw prune, iter  1

Removed: (11, 0)
98.94367218017578

Removed: (11, 1)
98.86380767822266

Removed: (11, 2)
99.14659118652344

Removed: (11, 3)
99.03900909423828

Removed: (11, 4)
99.36112213134766

Removed: (11, 5)
99.41478729248047

Removed: (11, 6)
99.54025268554688

Removed: (11, 7)
99.49662780761719

Removed: (11, 8)
97.3982925415039

Removed: (11, 9)
97.30856323242188

Removed: (11, 11)
98.52055358886719

Removed: (10, 0)
98.54439544677734

Removed: (10, 1)
98.11260223388672

Removed: (10, 2)
99.50712585449219

Removed: (10, 3)
99.3063735961914

Removed: (10, 4)
99.43148040771484

Removed: (10, 5)
99.10381317138672

Removed: (10, 6)
99.18989562988281

Removed: (10, 7)
98.46072387695312

Removed: (10, 8)
98.5821304321289

Removed: (10, 9)
98.50953674316406

Removed: (10, 10)
98.78120422363281

Removed: (10, 11)
98.71550750732422

Removed: (9, 0)
98.85236358642578

Removed: (9, 2)
98.90916442871094

Removed: (9, 3)
98.90960693359375

Removed: (9, 4)
98.38214111328125

Removed: 

In [21]:
bf_3 = curr_circuit.copy()
bf_3

[(0, 1),
 (0, 2),
 (0, 3),
 (0, 5),
 (0, 8),
 (0, 9),
 (0, 10),
 (1, 0),
 (1, 5),
 (1, 7),
 (2, 0),
 (2, 2),
 (2, 3),
 (2, 4),
 (2, 6),
 (2, 9),
 (3, 3),
 (3, 7),
 (3, 11),
 (4, 4),
 (4, 7),
 (4, 10),
 (5, 2),
 (5, 3),
 (5, 4),
 (5, 6),
 (5, 8),
 (5, 11),
 (6, 1),
 (6, 3),
 (6, 6),
 (6, 8),
 (6, 10),
 (7, 2),
 (7, 6),
 (7, 8),
 (7, 10),
 (7, 11),
 (8, 6),
 (8, 8),
 (8, 11),
 (9, 1),
 (11, 10)]

In [24]:
len(bf_3)

43

#### loop rmv and check for most impt heads

In [25]:
circ = bf_3
mean_ablate_by_lst(circ, model, orig_score, print_output=True).item()

Average logit difference (circuit / full) %: 97.0464


97.04641723632812

In [26]:
lh_scores = {}
for lh in circ:
    copy_circuit = circ.copy()
    copy_circuit.remove(lh)
    print("removed: " + str(lh))
    new_score = mean_ablate_by_lst(copy_circuit, model, orig_score, print_output=True).item()
    lh_scores[lh] = new_score

removed: (0, 1)
Average logit difference (circuit / full) %: 85.6726
removed: (0, 2)
Average logit difference (circuit / full) %: 93.0492
removed: (0, 3)
Average logit difference (circuit / full) %: 95.4401
removed: (0, 5)
Average logit difference (circuit / full) %: 95.8171
removed: (0, 8)
Average logit difference (circuit / full) %: 96.0902
removed: (0, 9)
Average logit difference (circuit / full) %: 88.2736
removed: (0, 10)
Average logit difference (circuit / full) %: 93.3338
removed: (1, 0)
Average logit difference (circuit / full) %: 95.2886
removed: (1, 5)
Average logit difference (circuit / full) %: 83.2106
removed: (1, 7)
Average logit difference (circuit / full) %: 96.9281
removed: (2, 0)
Average logit difference (circuit / full) %: 96.5791
removed: (2, 2)
Average logit difference (circuit / full) %: 96.1768
removed: (2, 3)
Average logit difference (circuit / full) %: 96.8371
removed: (2, 4)
Average logit difference (circuit / full) %: 96.9129
removed: (2, 6)
Average logit dif

In [27]:
dict(sorted(lh_scores.items(), key=lambda item: item[1]))

{(4, 4): 61.07541275024414,
 (7, 11): 75.43741607666016,
 (9, 1): 78.79251098632812,
 (1, 5): 83.21063995361328,
 (0, 1): 85.67262268066406,
 (8, 11): 85.94454956054688,
 (0, 9): 88.27363586425781,
 (4, 10): 88.40310668945312,
 (0, 2): 93.04920959472656,
 (0, 10): 93.3338394165039,
 (5, 6): 93.94637298583984,
 (8, 6): 94.6375503540039,
 (6, 10): 94.64041137695312,
 (7, 10): 94.6871566772461,
 (6, 6): 95.02118682861328,
 (5, 4): 95.10240936279297,
 (1, 0): 95.28858184814453,
 (0, 3): 95.44009399414062,
 (4, 7): 95.59111785888672,
 (11, 10): 95.67446899414062,
 (7, 6): 95.779296875,
 (7, 2): 95.78099060058594,
 (0, 5): 95.81710815429688,
 (5, 3): 95.88018035888672,
 (7, 8): 95.90811157226562,
 (5, 8): 96.0057144165039,
 (0, 8): 96.09024810791016,
 (6, 3): 96.10076141357422,
 (2, 2): 96.1767807006836,
 (6, 8): 96.19751739501953,
 (5, 11): 96.27434539794922,
 (2, 6): 96.3456802368164,
 (5, 2): 96.34602355957031,
 (3, 3): 96.37776184082031,
 (8, 8): 96.4482192993164,
 (2, 0): 96.57907104492

# loop rmv and check for most impt heads

https://colab.research.google.com/drive/12HF5UCvMERizkhOiYJKDziahgVq_3KD9#scrollTo=C2EgKgmJS4qb&line=1&uniqifier=1

In [None]:
mean_ablate_by_lst(fb_3, model, orig_score, print_output=True).item()

In [None]:
for lh in fb_3:
    copy_circuit = fb_3.copy()
    copy_circuit.remove(lh)
    print("removed: " + str(lh))
    new_score = mean_ablate_by_lst(copy_circuit, model, orig_score, print_output=True).item()

# random removal

Because order seems to matter for fb and bf, try random

Stop condition: if choice of X previous heads don't end up in a better score, stop.

In [None]:
import random

def find_circuit_rand(curr_circuit=None, threshold=10, numRandIters=200):
    # threshold is T, a %. if performance is less than T%, allow its removal
    if curr_circuit == None:
        # Start with full circuit
        curr_circuit = [(layer, head) for layer in range(12) for head in range(12)]

    for r in range(numRandIters):  #a little bit more than total number of heads
        # Copying the curr_circuit so we can iterate over one and modify the other
        copy_circuit = curr_circuit.copy()

        randNum = random.randint(0, len(curr_circuit) - 1)
        lh = curr_circuit[randNum]

        # Temporarily removing the current tuple from the copied circuit
        copy_circuit.remove(lh)

        new_score = mean_ablate_by_lst(copy_circuit, model, orig_score, print_output=False).item()

        # If the result is less than the threshold, remove the tuple from the original list
        if (100 - new_score) < threshold:
            curr_circuit.remove(lh)

            print("\nRemoved:", lh)
            print(new_score)

    return curr_circuit, new_score

In [None]:
rand_10, rand_score = find_circuit_rand(None, threshold=10)
rand_score


Removed: (2, 5)
99.83625793457031

Removed: (8, 10)
102.23005676269531

Removed: (6, 6)
101.97724151611328

Removed: (8, 8)
101.46814727783203

Removed: (6, 7)
101.4791488647461

Removed: (3, 8)
101.03436279296875

Removed: (5, 7)
101.07064056396484

Removed: (9, 11)
100.80884552001953

Removed: (10, 11)
100.70403289794922

Removed: (7, 10)
102.40557098388672

Removed: (0, 9)
101.17261505126953

Removed: (6, 1)
100.37947845458984

Removed: (8, 7)
99.57455444335938

Removed: (9, 2)
99.63677215576172

Removed: (8, 3)
100.70015716552734

Removed: (0, 5)
100.18648529052734

Removed: (7, 6)
98.87080383300781

Removed: (6, 5)
99.7287826538086

Removed: (4, 3)
99.5683364868164

Removed: (4, 1)
99.43795013427734

Removed: (7, 5)
98.3880386352539

Removed: (1, 0)
95.69098663330078

Removed: (4, 0)
95.67157745361328

Removed: (0, 2)
90.1860122680664

Removed: (2, 7)
90.2947769165039

Removed: (2, 4)
90.08573150634766

Removed: (0, 6)
91.59649658203125

Removed: (11, 0)
90.78836822509766

Remove

90.72700500488281

In [None]:
rand_10

[(0, 0),
 (0, 1),
 (0, 3),
 (0, 4),
 (0, 7),
 (0, 8),
 (0, 10),
 (1, 1),
 (1, 2),
 (1, 3),
 (1, 5),
 (1, 11),
 (2, 0),
 (2, 2),
 (2, 8),
 (3, 5),
 (3, 6),
 (3, 7),
 (3, 10),
 (3, 11),
 (4, 4),
 (4, 6),
 (4, 10),
 (4, 11),
 (5, 0),
 (5, 4),
 (5, 5),
 (5, 6),
 (5, 11),
 (6, 0),
 (6, 3),
 (6, 8),
 (6, 9),
 (6, 10),
 (7, 0),
 (7, 2),
 (7, 8),
 (7, 11),
 (8, 0),
 (8, 1),
 (8, 5),
 (8, 6),
 (8, 11),
 (9, 1),
 (9, 3),
 (9, 4),
 (9, 6),
 (9, 8),
 (10, 3),
 (10, 5),
 (10, 6),
 (10, 8),
 (11, 3),
 (11, 6),
 (11, 7),
 (11, 8),
 (11, 10)]

In [None]:
len(rand_10)

57

## bf after rand

In [None]:
threshold = 3
curr_circuit = rand_10
prev_score = 100
new_score = 0
iter = 1
while prev_score != new_score:
    print('\nbackw prune, iter ', str(iter))
    # prev_score = new_score # save old score before finding new one
    old_circuit = curr_circuit.copy() # save old before finding new one
    curr_circuit, new_score = find_circuit_backw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    print('\nfwd prune, iter ', str(iter))
    # track changes in circuit as for some reason it doesn't work with scores
    old_circuit = curr_circuit.copy() # save old before finding new one
    curr_circuit, new_score = find_circuit_forw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    iter += 1


backw prune, iter  1


## loop rmv and check for most impt heads

In [None]:
circ = rand_10
mean_ablate_by_lst(circ, model, orig_score, print_output=True).item()

Average logit difference (circuit / full) %: 90.7270


90.72700500488281

In [None]:
lh_scores = {}
for lh in circ:
    copy_circuit = circ.copy()
    copy_circuit.remove(lh)
    print("removed: " + str(lh))
    new_score = mean_ablate_by_lst(copy_circuit, model, orig_score, print_output=True).item()
    lh_scores[lh] = new_score
lh_scores

In [None]:
dict(sorted(lh_scores.items(), key=lambda item: item[1]))

{(4, 4): 44.29812240600586,
 (9, 1): 64.92707061767578,
 (1, 5): 74.46099090576172,
 (8, 11): 77.79611206054688,
 (7, 11): 81.18801879882812,
 (4, 10): 81.23968505859375,
 (0, 1): 82.17684936523438,
 (1, 11): 82.28862762451172,
 (5, 6): 86.2139892578125,
 (8, 6): 87.38471984863281,
 (6, 9): 88.07061004638672,
 (7, 0): 88.13016510009766,
 (5, 0): 88.3578109741211,
 (5, 4): 88.50399780273438,
 (11, 8): 88.52482604980469,
 (4, 11): 88.58031463623047,
 (6, 10): 88.68748474121094,
 (6, 8): 88.74179077148438,
 (3, 7): 88.83816528320312,
 (0, 8): 89.2793960571289,
 (9, 6): 89.29234313964844,
 (8, 0): 89.35371398925781,
 (7, 8): 89.46964263916016,
 (11, 10): 89.4909439086914,
 (6, 3): 89.51261901855469,
 (2, 2): 89.5793228149414,
 (1, 1): 89.70718383789062,
 (3, 6): 89.84967041015625,
 (8, 1): 89.85130310058594,
 (0, 3): 89.8616714477539,
 (2, 0): 89.9072036743164,
 (8, 5): 89.94767761230469,
 (7, 2): 89.95777893066406,
 (9, 4): 90.01729583740234,
 (10, 5): 90.12210845947266,
 (2, 8): 90.13710

## compare

In [None]:
set.intersection(set(fb_10), set(rand_10))

{(0, 10),
 (1, 3),
 (1, 5),
 (1, 11),
 (2, 0),
 (2, 2),
 (2, 8),
 (3, 7),
 (3, 11),
 (4, 4),
 (4, 6),
 (4, 10),
 (4, 11),
 (5, 0),
 (5, 4),
 (5, 5),
 (5, 6),
 (6, 3),
 (6, 9),
 (6, 10),
 (7, 0),
 (7, 2),
 (7, 8),
 (7, 11),
 (8, 0),
 (8, 6),
 (8, 11),
 (9, 1),
 (11, 8)}

In [None]:
mean_ablate_by_lst(set.intersection(set(fb_10), set(rand_10)), model, orig_score, print_output=True).item()

Average logit difference (circuit / full) %: 69.0322


69.03221130371094

In [None]:
set.intersection(set(fb_20), set(rand_10))

{(1, 5),
 (3, 7),
 (3, 10),
 (3, 11),
 (4, 4),
 (4, 6),
 (4, 10),
 (4, 11),
 (5, 0),
 (5, 4),
 (5, 6),
 (6, 3),
 (6, 8),
 (6, 10),
 (7, 0),
 (7, 2),
 (7, 8),
 (7, 11),
 (8, 0),
 (8, 1),
 (8, 6),
 (8, 11),
 (9, 1),
 (11, 8)}

In [None]:
mean_ablate_by_lst(set.intersection(set(fb_20), set(rand_10)), model, orig_score, print_output=True).item()

Average logit difference (circuit / full) %: 74.3014


74.30139923095703

In [None]:
len(set.intersection(set(fb_20), set(rand_10)))

24

# compare with desc

from: https://colab.research.google.com/drive/1odPpf7w_gBG8ZfAB2L6SXZszsDUk1CGA#scrollTo=ET--8aulD8pE&line=1&uniqifier=1

In [None]:
decr_circ = [(0, 1), (0, 3), (0, 5), (0, 7), (0, 9), (1, 0), (1, 5), (2, 2), (2, 4), (2, 9), (3, 0), (3, 3), (3, 7), (3, 10), (4, 6), (4, 7), (4, 10), (4, 11), (5, 1), (5, 5), (5, 6), (6, 1), (6, 7), (6, 9), (7, 2), (7, 10), (7, 11), (8, 1), (8, 6), (8, 8), (8, 10), (9, 5), (10, 7), (11, 0), (11, 8), (11, 11)]

In [None]:
set(incr_circ) - set(decr_circ)

In [None]:
set(decr_circ) - set(incr_circ)

# manual removal

In [None]:
# https://colab.research.google.com/drive/1CHRn-AMko9RNrl1bqiCwB7DS-rz1CoBP#scrollTo=KZiVdGTC6QlP&line=1&uniqifier=1
# V1 plus L0, L2, L3, L6 minus 6.3, 6.4
circuit = [(0, 1), (0, 3), (0, 5), (0, 7), (0, 9), (0, 10), (1, 5), (2, 2), (2, 9), (3, 0), (3, 3), (3, 7), (4, 4), (5, 5), (6, 1), (6, 6), (6, 9), (6, 10), (7, 10), (7, 11), (8, 8), (9, 1), (10, 7)]
mean_ablate_by_lst(circuit, model, orig_score, print_output=False).item()

41.985618591308594