<a href="https://colab.research.google.com/github/wlg100/numseqcont_circuit_expms/blob/main/notebook_templates/minimal_circuit_template.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" align="left"/></a>&nbsp;or in a local notebook.

# Setup
(No need to change anything)

In [1]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    # Install another version of node that makes PySvelte work way faster
    # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    # %pip install git+https://github.com/neelnanda-io/PySvelte.git
except ImportError:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Automatically reload edited HookedTransformer code without restarting the kernel
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

Running as a Colab notebook

In [2]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

if IN_COLAB or not DEBUG_MODE:
    # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "png"

In [None]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [None]:
# import pysvelte

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

We turn automatic differentiation off to save GPU memory, since this notebook focuses on model inference, not model training.

In [None]:
torch.set_grad_enabled(False)

Plotting helper functions:

In [None]:
def imshow(tensor, renderer=None, **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", **kwargs).show(renderer)

def line(tensor, renderer=None, **kwargs):
    px.line(y=utils.to_numpy(tensor), **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

## Load Model

Decide which model to use (e.g., gpt2-small vs. gpt2-medium).

In [None]:
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)

## Import functions from repo

In [None]:
!git clone https://github.com/callummcdougall/ARENA_2.0.git

In [None]:
%cd ARENA_2.0/chapter1_transformers/exercises/part3_indirect_object_identification

In [None]:
import ioi_circuit_extraction

# Test prompts

In [None]:
# modeltest = HookedTransformer.from_pretrained("gpt2-small")

Loaded pretrained model gpt2-small into HookedTransformer


In [None]:
# example_prompt = "The war lasted from the year 1750 to the year 17"
# example_answer = " 51"
# utils.test_prompt(example_prompt, example_answer, modeltest, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'The', ' war', ' lasted', ' from', ' the', ' year', ' 17', '50', ' to', ' the', ' year', ' 17']
Tokenized answer: [' 51']


Top 0th token. Logit: 30.08 Prob: 26.90% Token: |60|
Top 1th token. Logit: 29.10 Prob: 10.12% Token: |75|
Top 2th token. Logit: 29.02 Prob:  9.33% Token: |70|
Top 3th token. Logit: 28.62 Prob:  6.29% Token: |90|
Top 4th token. Logit: 28.43 Prob:  5.19% Token: |80|
Top 5th token. Logit: 28.28 Prob:  4.45% Token: |50|
Top 6th token. Logit: 27.82 Prob:  2.83% Token: |55|
Top 7th token. Logit: 27.41 Prob:  1.87% Token: |65|
Top 8th token. Logit: 27.34 Prob:  1.75% Token: |76|
Top 9th token. Logit: 27.17 Prob:  1.47% Token: |71|


# Generate dataset with multiple prompts

Greater-than paper, p. 2: uses 10k examples of the form "The <noun> lasted from the year XXYY to the year XX".

120 random nouns; years from 1000 to 1899, with YY in {2, ..., 98} inclusive.

Here, we use the same noun (war) and the same prefix XX (17), so there are 97 possible examples. If we start at YY = 11 to avoid a leading '0' on single-digit values, there are 98 - 11 + 1 = 88 examples.

Since the noun and XX are fixed, every clean prompt maps to the same corrupted prompt (YY = 01), so our 01-corrupted dataset only needs 1 sample.
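A minimal sketch of the counting above, assuming the same fixed template that `generate_prompts_list` uses (noun "war", prefix "17"). `TEMPLATE`, `clean_prompts`, and `corrupted_prompt` are hypothetical helpers for illustration, not part of the notebook. Python's `range` excludes its upper bound, so YY = 11..98 is `range(11, 99)`.

```python
# Hypothetical helpers: enumerate clean prompts and the single "01" corrupted prompt,
# assuming the fixed noun "war" and century prefix "17" used in this notebook.
TEMPLATE = "The war lasted from the year 17{YY} to the year 17"

def clean_prompts(start, stop):
    """Clean prompts for YY in range(start, stop); stop is exclusive."""
    return [TEMPLATE.format(YY=YY) for YY in range(start, stop)]

def corrupted_prompt():
    """Every clean prompt shares one corrupted counterpart with YY = 01."""
    return TEMPLATE.format(YY="01")

prompts = clean_prompts(11, 99)   # YY = 11 ... 98 inclusive
print(len(prompts))               # 88
print(prompts[0])                 # The war lasted from the year 1711 to the year 17
print(corrupted_prompt())         # The war lasted from the year 1701 to the year 17
```

The same helper gives 97 prompts for `range(2, 99)` and 80 for the `range(10, 90)` actually used below.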

In [None]:
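def get_prompts_pos_dicts(input_text, YY):
    """Tokenize `input_text` and return (pos_dict, prompt_dict).

    The token that decodes to str(YY) is keyed 'YY'; every other token is keyed
    'T<index>'. pos_dict maps each key to its position; prompt_dict maps each key
    to the decoded token string, plus 'text' for the full prompt. Uses the global `model`.
    """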
    pos_dict = {}
    prompt_dict = {}
    tokens_list = model.tokenizer(input_text)['input_ids']

    for index, token in enumerate(tokens_list):
        token_as_string = model.tokenizer.decode(token)
        if token_as_string == str(YY):
            key = 'YY'
        else:
            key = 'T'+str(index)
        # key = 'T'+str(index)
        pos_dict[key] = index
        prompt_dict[key] = token_as_string
    prompt_dict['text'] = input_text

    return pos_dict, prompt_dict
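
The keying scheme in `get_prompts_pos_dicts` can be illustrated without loading GPT-2. This sketch assumes the tokenization shown in the test-prompt output above, minus the `<|endoftext|>` BOS token (the `Dataset` class tokenizes without one, which is why sequences have length 12). `build_pos_dict` is a hypothetical name. It shows why YY lands at index 7 (so there is no `T7` key) and why the `' 17'` tokens never match, since they carry a leading space.

```python
def build_pos_dict(token_strings, YY):
    """Map each position to a key: 'YY' for the YY token, 'T<index>' otherwise.

    Mirrors the keying logic of get_prompts_pos_dicts, but takes already-decoded
    token strings so it runs without a tokenizer (an assumption for illustration).
    """
    pos_dict = {}
    for index, tok in enumerate(token_strings):
        key = "YY" if tok == str(YY) else f"T{index}"
        pos_dict[key] = index
    return pos_dict

# Tokens of "The war lasted from the year 1750 to the year 17" without BOS,
# taken from the test-prompt output earlier in this notebook.
tokens = ['The', ' war', ' lasted', ' from', ' the', ' year', ' 17', '50',
          ' to', ' the', ' year', ' 17']
pos = build_pos_dict(tokens, 50)
print(pos["YY"])          # 7
print(len(tokens) - 1)    # 11, the position Dataset stores as word_idx["end"]
```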

In [None]:
class Dataset:
    def __init__(self, prompts, pos_dict, tokenizer, YY_int_list):
        self.prompts = prompts
        self.tokenizer = tokenizer
        self.N = len(prompts)
        self.max_len = max(
            [
                len(self.tokenizer(prompt["text"]).input_ids)
                for prompt in self.prompts
            ]
        )
        all_ids = [0 for prompt in self.prompts] # only 1 template
        all_ids_ar = np.array(all_ids)
        self.groups = []
        for id in list(set(all_ids)):
            self.groups.append(np.where(all_ids_ar == id)[0])

        texts = [ prompt["text"] for prompt in self.prompts ]
        self.toks = torch.Tensor(self.tokenizer(texts, padding=True).input_ids).type(
            torch.int
        )

        # list of YY as int for each prompt
        self.YY_int_list = YY_int_list

        # word_idx maps each position key (e.g. 'YY', 'T3', 'end') to a tensor with one
        # entry per prompt: that key's token index in the prompt. All prompts share one
        # template, so every prompt uses the same index from pos_dict.
        self.word_idx = {}
        for targ in [key for key in self.prompts[0].keys() if key not in ('text', 'corr', 'incorr')]:
            targ_lst = [pos_dict[targ] for _ in self.prompts]
            self.word_idx[targ] = torch.tensor(targ_lst)

        # 'end' is the index of the last token of each prompt (tokenized without BOS)
        targ_lst = []
        for prompt in self.prompts:
            tokens = self.tokenizer.tokenize(prompt["text"])
            end_token_index = len(tokens) - 1
            targ_lst.append(end_token_index)
        self.word_idx["end"] = torch.tensor(targ_lst)

    def __len__(self):
        return self.N

In [None]:
def generate_prompts_list(x, y):
    prompts_list = []
    for YY in range(x, y):
        input_text = f'The war lasted from the year 17{YY} to the year 17'
        pos_dict, prompt_dict = get_prompts_pos_dicts(input_text, YY)
        prompts_list.append(prompt_dict)
    return pos_dict, prompts_list

# pos_dict, prompts_list = generate_prompts_list(45, 55)
pos_dict, prompts_list = generate_prompts_list(10, 90)
YY_int_list = [i for i in range(10, 90)]
dataset = Dataset(prompts_list, pos_dict, model.tokenizer, YY_int_list)

In [None]:
def generate_prompts_list_corr(x, y):
    """Build the single corrupted prompt (YY = '01'); x and y are unused, kept for symmetry."""
    prompts_list = []
    YY = '01'
    input_text = f'The war lasted from the year 17{YY} to the year 17'
    pos_dict, prompt_dict = get_prompts_pos_dicts(input_text, YY)
    prompts_list.append(prompt_dict)
    return pos_dict, prompts_list

# prompts_list = generate_prompts_list(45, 55)
pos_dict, prompts_list_2 = generate_prompts_list_corr(10, 90)
dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer, YY_int_list)

# Ablation Expm Functions

In [None]:
# obtain the logits of each number between YY and 99, where YY is a two digit integer

import torch

def get_logits_for_range(logits, start_num, end_num):
    """
    :param logits: Logits for a single prompt, with dimensions [seq len, vocab size]
    :param start_num: The starting number (inclusive)
    :param end_num: The ending number (inclusive)
    :return: A 1-D tensor with the last-position logits of the tokens for start_num..end_num
    """
    # Vocab IDs of the number tokens start_num..end_num
    # NOTE: str(num) gives '0'..'9' for num < 10, not the two-digit '00'..'09'
    indices = []
    for num in range(start_num, end_num + 1):
        num_as_vocabID = model.tokenizer(str(num))['input_ids'][0]
        indices.append(num_as_vocabID)

    # Extract the logits at the last sequence position for these vocab IDs
    logits_for_range = logits[-1, indices]

    return logits_for_range

Each sample in the batch needs a different set of indices, because each sample has a different YY.
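For a concrete picture: for a given YY the "greater" set is YY..99 and the "less" set is 00..YY-1, so both the membership and the size of each set change with YY. The pure-Python sketch below uses a hypothetical `answer_sets` helper. It also shows that `str(n)`, which `get_logits_for_range` uses, yields '0'..'9' rather than '00'..'09' for n < 10; whether the single-digit or two-digit tokens are the intended targets is worth checking.

```python
def answer_sets(YY, zero_pad=True):
    """Return (greater, less) answer strings for one sample.

    greater: YY..99 inclusive; less: 0..YY-1. With zero_pad=True the strings are
    two-digit ('05'); with zero_pad=False they match str(n) ('5').
    """
    fmt = (lambda n: f"{n:02d}") if zero_pad else str
    greater = [fmt(n) for n in range(YY, 100)]
    less = [fmt(n) for n in range(0, YY)]
    return greater, less

g, l = answer_sets(10)
print(len(g), len(l))                          # 90 10
print(l[:3])                                   # ['00', '01', '02']
print(answer_sets(10, zero_pad=False)[1][:3])  # ['0', '1', '2']
```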

In [None]:
model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook
ioi_logits_original, ioi_cache = model.run_with_cache(dataset.toks)
ioi_logits_original.size()

torch.Size([87, 12, 50257])

In [None]:
ioi_logits_original.size(0)

87

In [None]:
ioi_logits_original[0].size()

torch.Size([12, 50257])

In [None]:
YY = dataset.YY_int_list[0]

In [None]:
ioi_logits_original[0, :, :].size()

torch.Size([12, 50257])

In [None]:
logits_greaterThan = get_logits_for_range(ioi_logits_original[0, :, :], YY, 99)
logits_greaterThan.size() # output size is 1 dim tensor with size number of tokens between YY and 99 inclusive

torch.Size([89])

In [None]:
sum(logits_greaterThan)

tensor(2033.8448, device='cuda:0')

In [None]:
answer_logit_perSamp = []
answer_logit_perSamp.append(sum(logits_greaterThan).item())
sum(answer_logit_perSamp)/len(answer_logit_perSamp)

2033.8448486328125

In [None]:
from torch import Tensor

def logits_to_ave_logit_diff_2(logits: Float[Tensor, "batch seq d_vocab"], dataset: Dataset, per_prompt=False):
    '''
    For each prompt, returns (sum of logits for numbers YY..99) minus (sum of logits for
    numbers 00..YY-1), read at the last sequence position.
    If per_prompt=True, return the list of per-prompt differences rather than their average.
    '''
    # YY differs per prompt, so the greater-than and less-than sets are built per sample
    answer_logit_perSamp = []
    for samp_id in range(logits.size(0)):
        YY = dataset.YY_int_list[samp_id]

        # "Correct" logits: every number >= YY
        logits_greaterThan = get_logits_for_range(logits[samp_id, :, :], YY, 99)
        logits_greaterThan_sum = sum(logits_greaterThan)

        # "Incorrect" logits: every number < YY
        logits_lessThan = get_logits_for_range(logits[samp_id, :, :], 0, YY-1)
        logits_lessThan_sum = sum(logits_lessThan)

        # Sum over YY..99 minus sum over 00..YY-1
        GL_diff = logits_greaterThan_sum - logits_lessThan_sum
        answer_logit_perSamp.append(GL_diff.item())

    return answer_logit_perSamp if per_prompt else sum(answer_logit_perSamp)/len(answer_logit_perSamp)
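
The metric reduces to simple arithmetic once the last-position logits of the number tokens are known. A toy sketch, assuming a hypothetical list where entry n is the logit of the token for n (in the notebook these come from `get_logits_for_range`): if the model puts logit 1 on every number >= YY and 0 elsewhere, the per-sample difference is just the size of the greater set, 100 - YY. Because raw logits are summed over sets of unequal size, the metric's scale depends on YY.

```python
def greater_minus_less(number_logits, YY):
    """Sum of logits for YY..99 minus sum for 0..YY-1 (same split as logits_to_ave_logit_diff_2).

    number_logits: list of 100 floats; number_logits[n] is the logit of the token for n.
    """
    return sum(number_logits[YY:100]) - sum(number_logits[:YY])

def average_metric(batch_number_logits, YY_list):
    """Average the per-sample differences, as the notebook does when per_prompt=False."""
    scores = [greater_minus_less(lg, YY) for lg, YY in zip(batch_number_logits, YY_list)]
    return sum(scores) / len(scores)

# Toy "perfect" model: logit 1.0 for every answer >= YY, 0.0 otherwise
YY_list = [20, 50, 80]
batch = [[1.0 if n >= YY else 0.0 for n in range(100)] for YY in YY_list]
print([greater_minus_less(lg, YY) for lg, YY in zip(batch, YY_list)])  # [80.0, 50.0, 20.0]
print(average_metric(batch, YY_list))                                  # 50.0
```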

In [None]:
orig_score = logits_to_ave_logit_diff_2(ioi_logits_original, dataset)
orig_score

94.07917592717313

In [None]:
CIRCUIT = {}
SEQ_POS_TO_KEEP = {}
lst =  [(0, 1), (0, 3), (0, 5), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11), (9, 1)]
for ind, key in enumerate(pos_dict.keys()):
    headName = "head" + str(ind)
    CIRCUIT[headName] = lst
    SEQ_POS_TO_KEEP[headName] = key
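
In this cell every key of `pos_dict` gets its own entry holding the same head list, so the listed heads are kept at every sequence position; the `"head" + str(ind)` names only label position groups and do not refer to attention heads. A sketch of that dictionary shape, assuming the 12 position keys this template produces (`T0`..`T6`, `YY`, `T8`..`T11`); `build_circuit_dicts` is a hypothetical helper.

```python
def build_circuit_dicts(heads, position_keys):
    """Give every sequence-position key its own circuit entry with the same heads."""
    circuit, seq_pos_to_keep = {}, {}
    for ind, key in enumerate(position_keys):
        label = f"head{ind}"          # a label for the position group, not a head
        circuit[label] = list(heads)
        seq_pos_to_keep[label] = key
    return circuit, seq_pos_to_keep

keys = [f"T{i}" for i in range(7)] + ["YY"] + [f"T{i}" for i in range(8, 12)]
heads = [(0, 1), (0, 3), (0, 5), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11), (9, 1)]
circuit_demo, seq_pos = build_circuit_dicts(heads, keys)
print(len(circuit_demo), seq_pos["head7"])   # 12 YY
```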

In [None]:
model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook
model = ioi_circuit_extraction.add_mean_ablation_hook(model, means_dataset=dataset_2, circuit=CIRCUIT, seq_pos_to_keep=SEQ_POS_TO_KEEP)
ioi_logits_minimal = model(dataset.toks)  # make sure text in clean vs corr have same num tokens for each prompt
new_score = logits_to_ave_logit_diff_2(ioi_logits_minimal, dataset)

In [None]:
new_score

38.141064874057115

In [None]:
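def mean_ablate_by_lst(lst, model, print_output=True):
    """Keep the heads in `lst` at every position in `pos_dict`, mean-ablate the other heads
    with means from `dataset_2` (via ioi_circuit_extraction.add_mean_ablation_hook), and
    return the circuit score on `dataset` as a percentage of the full-model score.
    Relies on the globals `pos_dict`, `dataset`, and `dataset_2`."""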
    # CIRCUIT = {
    #     "number mover": lst,
    #     "number mover 2": lst,
    # }

    # SEQ_POS_TO_KEEP = {
    #     "number mover": "end",
    #     "number mover 2": "YY",
    # }
    CIRCUIT = {}
    SEQ_POS_TO_KEEP = {}

    for ind, key in enumerate(pos_dict.keys()):
        headName = "head" + str(ind)
        CIRCUIT[headName] = lst
        SEQ_POS_TO_KEEP[headName] = key

    model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook

    ioi_logits_original, ioi_cache = model.run_with_cache(dataset.toks)

    model = ioi_circuit_extraction.add_mean_ablation_hook(model, means_dataset=dataset_2, circuit=CIRCUIT, seq_pos_to_keep=SEQ_POS_TO_KEEP)
    ioi_logits_minimal = model(dataset.toks)  # make sure text in clean vs corr have same num tokens for each prompt

    orig_score = logits_to_ave_logit_diff_2(ioi_logits_original, dataset)
    new_score = logits_to_ave_logit_diff_2(ioi_logits_minimal, dataset)
    if print_output:
        # print(f"Average logit difference (IOI dataset, using entire model): {orig_score:.4f}")
        # print(f"Average logit difference (IOI dataset, only using circuit): {new_score:.4f}")
        print(f"Average logit difference (circuit / full) %: {100 * new_score / orig_score:.4f}")
    # return new_score
    return (100 * new_score / orig_score)

# Try circuits from other tasks

In [None]:
# fb 80, digits incr
# https://colab.research.google.com/drive/1mFWmGAKtigFcqqWWMCwU7wWQY2HT5ZOo#scrollTo=lJEY-Zs2g_a5&line=1&uniqifier=1
circuit = [(1, 5), (3, 3), (3, 7), (3, 10), (3, 11), (4, 4), (4, 6), (4, 7), (4, 10), (4, 11), (5, 0), (5, 2), (5, 3), (5, 4), (5, 6), (6, 3), (6, 8), (6, 10), (7, 0), (7, 2), (7, 7), (7, 8), (7, 10), (7, 11), (8, 0), (8, 1), (8, 6), (8, 8), (8, 9), (8, 11), (9, 1), (11, 8)]
mean_ablate_by_lst(circuit, model, print_output=True)

Average logit difference (circuit / full) %: 64.1249


64.12491445516194

In [None]:
# fb 80, numwords
# https://colab.research.google.com/drive/1QTv-4osLHadCAay0beew-xlXszPCG88s#scrollTo=563kZf_4r_mw&line=2&uniqifier=1
circuit = [(3, 2), (4, 4), (4, 8), (4, 10), (4, 11), (5, 5), (5, 6), (5, 7), (5, 8), (6, 1), (6, 7), (6, 9), (6, 10), (7, 0), (7, 2), (7, 5), (7, 6), (7, 7), (7, 8), (7, 10), (7, 11), (8, 0), (8, 1), (8, 6), (8, 8), (8, 11), (9, 1), (10, 2)]
mean_ablate_by_lst(circuit, model, print_output=True)

Average logit difference (circuit / full) %: 81.7744


81.77439646862463

In [None]:
# fb 80, months
# https://colab.research.google.com/drive/1lhQqlizYGMC11vzp6I9mJ3dyxIr8tV3l#scrollTo=563kZf_4r_mw&line=2&uniqifier=1
circuit = [(4, 4), (7, 11), (8, 6), (8, 9), (8, 11), (9, 1), (9, 5), (11, 10)]
mean_ablate_by_lst(circuit, model, print_output=True)

Average logit difference (circuit / full) %: 48.0494


48.049414541567856

In [None]:
# greater-than
circuit = [(0, 1), (0, 3), (0, 5), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11), (9, 1)]
mean_ablate_by_lst(circuit, model, print_output=True)

Average logit difference (circuit / full) %: 85.3844


85.38436010989578

In [None]:
# IOI
circuit = [(0, 1), (0, 10), (2, 2), (3, 0), (4, 11), (5, 5), (5, 8), (5, 9), (6, 9), (7, 3), (7, 9), (8, 6), (8, 10), (9, 0), (9, 6), (9, 7), (9, 9), (10, 0), (10, 1), (10, 2), (10, 6), (10, 7), (10, 10), (11, 2), (11, 9), (11, 10)]
mean_ablate_by_lst(circuit, model, print_output=True)

Average logit difference (circuit / full) %: 28.5169


28.51689814059981

# Tests of the ablation setup

Run on a single sample (YY = 50) to check that the manual computation and `mean_ablate_by_lst` give the same score (they must; otherwise there is a bug).
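The check compares two routes to the same number: the manual ratio `100 * new_score / orig_score` and the percentage printed by `mean_ablate_by_lst`. A small sketch using the single-sample (YY = 50) scores reported in the cells that follow; `percent_of_full` is a hypothetical helper, and the variable names avoid overwriting the notebook's `orig_score`/`new_score`.

```python
import math

def percent_of_full(new_score, orig_score):
    """Circuit score as a percentage of the full-model score."""
    return 100 * new_score / orig_score

# Scores reported for the single YY = 50 prompt in this notebook
orig_score_yy50 = 303.3883056640625
new_score_yy50 = 258.4454345703125
pct = percent_of_full(new_score_yy50, orig_score_yy50)
print(f"{pct:.4f}")   # 85.1864, matching the mean_ablate_by_lst printout
assert math.isclose(pct, 85.1864, abs_tol=1e-4)
```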

In [None]:
pos_dict, prompts_list = generate_prompts_list(50, 51)
YY_int_list = [i for i in range(50, 51)]
dataset = Dataset(prompts_list, pos_dict, model.tokenizer, YY_int_list)

pos_dict, prompts_list_2 = generate_prompts_list_corr(50, 51)
dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer, YY_int_list)

greater_than = [(0, 1), (0, 3), (0, 5), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11), (9, 1)]

In [None]:
model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook
ioi_logits_original, ioi_cache = model.run_with_cache(dataset.toks)
orig_score = logits_to_ave_logit_diff_2(ioi_logits_original, dataset)
orig_score

303.3883056640625

In [None]:
model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook
model = ioi_circuit_extraction.add_mean_ablation_hook(model, means_dataset=dataset_2, circuit=CIRCUIT, seq_pos_to_keep=SEQ_POS_TO_KEEP)
ioi_logits_minimal = model(dataset.toks)  # make sure text in clean vs corr have same num tokens for each prompt
new_score = logits_to_ave_logit_diff_2(ioi_logits_minimal, dataset)

In [None]:
new_score

258.4454345703125

In [None]:
(100 * new_score / orig_score)

85.18635350977746

In [None]:
new_score = mean_ablate_by_lst(greater_than, model, print_output=True)

Average logit difference (circuit / full) %: 85.1864


## Check on 45 to 55

In [None]:
pos_dict, prompts_list = generate_prompts_list(45, 55)
YY_int_list = [i for i in range(45, 55)]
dataset = Dataset(prompts_list, pos_dict, model.tokenizer, YY_int_list)

pos_dict, prompts_list_2 = generate_prompts_list_corr(45, 55)
dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer, YY_int_list)

new_score = mean_ablate_by_lst(greater_than, model, print_output=True)

Average logit difference (circuit / full) %: 90.6458


## Check on 11 to 98

In [None]:
x, y = 11, 98
pos_dict, prompts_list = generate_prompts_list(x, y)
YY_int_list = [i for i in range(x, y)]
dataset = Dataset(prompts_list, pos_dict, model.tokenizer, YY_int_list)

pos_dict, prompts_list_2 = generate_prompts_list_corr(x, y)
dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer, YY_int_list)

new_score = mean_ablate_by_lst(greater_than, model, print_output=True)

Average logit difference (circuit / full) %: 40.5413


## Check on 11 to 90

In [None]:
x, y = 11, 90
pos_dict, prompts_list = generate_prompts_list(x, y)
YY_int_list = [i for i in range(x, y)]
dataset = Dataset(prompts_list, pos_dict, model.tokenizer, YY_int_list)

pos_dict, prompts_list_2 = generate_prompts_list_corr(x, y)
dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer, YY_int_list)

new_score = mean_ablate_by_lst(greater_than, model, print_output=True)

Average logit difference (circuit / full) %: 84.0732


In [None]:
x, y = 11, 97
pos_dict, prompts_list = generate_prompts_list(x, y)
YY_int_list = [i for i in range(x, y)]
dataset = Dataset(prompts_list, pos_dict, model.tokenizer, YY_int_list)

pos_dict, prompts_list_2 = generate_prompts_list_corr(x, y)
dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer, YY_int_list)

new_score = mean_ablate_by_lst(greater_than, model, print_output=True)

Average logit difference (circuit / full) %: 53.3356


In [None]:
x, y = 11, 93
pos_dict, prompts_list = generate_prompts_list(x, y)
YY_int_list = [i for i in range(x, y)]
dataset = Dataset(prompts_list, pos_dict, model.tokenizer, YY_int_list)

pos_dict, prompts_list_2 = generate_prompts_list_corr(x, y)
dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer, YY_int_list)

new_score = mean_ablate_by_lst(greater_than, model, print_output=True)

Average logit difference (circuit / full) %: 76.6722


In [None]:
x, y = 11, 91
pos_dict, prompts_list = generate_prompts_list(x, y)
YY_int_list = [i for i in range(x, y)]
dataset = Dataset(prompts_list, pos_dict, model.tokenizer, YY_int_list)

pos_dict, prompts_list_2 = generate_prompts_list_corr(x, y)
dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer, YY_int_list)

new_score = mean_ablate_by_lst(greater_than, model, print_output=True)

Average logit difference (circuit / full) %: 82.1590


In [None]:
x, y = 11, 80
pos_dict, prompts_list = generate_prompts_list(x, y)
YY_int_list = [i for i in range(x, y)]
dataset = Dataset(prompts_list, pos_dict, model.tokenizer, YY_int_list)

pos_dict, prompts_list_2 = generate_prompts_list_corr(x, y)
dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer, YY_int_list)

new_score = mean_ablate_by_lst(greater_than, model, print_output=True)

Average logit difference (circuit / full) %: 94.5349


In [None]:
x, y = 11, 89
pos_dict, prompts_list = generate_prompts_list(x, y)
YY_int_list = [i for i in range(x, y)]
dataset = Dataset(prompts_list, pos_dict, model.tokenizer, YY_int_list)

pos_dict, prompts_list_2 = generate_prompts_list_corr(x, y)
dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer, YY_int_list)

new_score = mean_ablate_by_lst(greater_than, model, print_output=True)

Average logit difference (circuit / full) %: 85.8438


In [None]:
x, y = 11, 85
pos_dict, prompts_list = generate_prompts_list(x, y)
YY_int_list = [i for i in range(x, y)]
dataset = Dataset(prompts_list, pos_dict, model.tokenizer, YY_int_list)

pos_dict, prompts_list_2 = generate_prompts_list_corr(x, y)
dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer, YY_int_list)

new_score = mean_ablate_by_lst(greater_than, model, print_output=True)

Average logit difference (circuit / full) %: 91.1493


In [None]:
x, y = 11, 70
pos_dict, prompts_list = generate_prompts_list(x, y)
YY_int_list = [i for i in range(x, y)]
dataset = Dataset(prompts_list, pos_dict, model.tokenizer, YY_int_list)

pos_dict, prompts_list_2 = generate_prompts_list_corr(x, y)
dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer, YY_int_list)

new_score = mean_ablate_by_lst(greater_than, model, print_output=True)

Average logit difference (circuit / full) %: 98.9650


In [None]:
x, y = 11, 60
pos_dict, prompts_list = generate_prompts_list(x, y)
YY_int_list = [i for i in range(x, y)]
dataset = Dataset(prompts_list, pos_dict, model.tokenizer, YY_int_list)

pos_dict, prompts_list_2 = generate_prompts_list_corr(x, y)
dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer, YY_int_list)

new_score = mean_ablate_by_lst(greater_than, model, print_output=True)

Average logit difference (circuit / full) %: 101.2740


In [None]:
x, y = 10, 90
pos_dict, prompts_list = generate_prompts_list(x, y)
YY_int_list = [i for i in range(x, y)]
dataset = Dataset(prompts_list, pos_dict, model.tokenizer, YY_int_list)

pos_dict, prompts_list_2 = generate_prompts_list_corr(x, y)
dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer, YY_int_list)

new_score = mean_ablate_by_lst(greater_than, model, print_output=True)

Average logit difference (circuit / full) %: 85.3843


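So the closer y gets to 99, the lower the percentage becomes, likely because of the imbalance between the less-than and greater-than sets: as YY grows, the greater-than set (YY..99) shrinks while the less-than set (00..YY-1) grows, and the metric sums raw logits over both. To avoid skewing the data either way, stick with 10 to 90.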
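The imbalance can be made concrete by counting answer tokens: for a given YY the greater set has 100 - YY members and the less set has YY, so averaging over a dataset's YY values shows which side dominates. A pure-Python sketch with a hypothetical `mean_set_sizes` helper: for YY in 10..89 the two sides are nearly balanced, while 11..97 tilts toward the less-than side. This is only a size count; it does not by itself explain the exact percentages above.

```python
def mean_set_sizes(start, stop):
    """Average sizes of the greater (YY..99) and less (00..YY-1) sets for YY in range(start, stop)."""
    YYs = range(start, stop)
    mean_greater = sum(100 - YY for YY in YYs) / len(YYs)
    mean_less = sum(YY for YY in YYs) / len(YYs)
    return mean_greater, mean_less

print(mean_set_sizes(10, 90))  # (50.5, 49.5)
print(mean_set_sizes(11, 98))  # (46.0, 54.0)
```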

In [None]:
dataset.N

80

## Test Greater-Than vs other circuits

See how the greater-than circuit performs on the greater-than task; the result should be similar to the paper's. If not, either the greater-than paper has issues (less likely) or this mean-ablation code/setup was not generalized correctly (more likely).

In [None]:
greater_than = [(0, 1), (0, 3), (0, 5), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11), (9, 1)]
new_score = mean_ablate_by_lst(greater_than, model, print_output=True)
new_score

Average logit difference (circuit / full) %: 85.3843


85.38428847602803

## Sanity check tests

In [None]:
greater_than = []
new_score = mean_ablate_by_lst(greater_than, model, print_output=True)

Average logit difference (circuit / full) %: 15.7910


With no heads kept, the score is likely still nonzero because only attention heads are mean-ablated, so the MLPs still contribute.

In [None]:
greater_than = [(0, 1), (0, 3), (9,1)]
new_score = mean_ablate_by_lst(greater_than, model, print_output=True)

Average logit difference (circuit / full) %: 40.9717


In [None]:
greater_than = [(layer, head) for layer in range(12) for head in range(12)]
new_score = mean_ablate_by_lst(greater_than, model, print_output=True)

Average logit difference (circuit / full) %: 100.0000


In [None]:
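import random
num_of_tuples = 9  # Number of random (layer, head) tuples to keep
# Note: unseeded, so the score varies between runs; randint(0, 9) is inclusive, so only
# layers and heads 0 to 9 are sampled (GPT-2 small has 0 to 11), and duplicates are possible.
greater_than = [(random.randint(0, 9), random.randint(0, 9)) for _ in range(num_of_tuples)]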
new_score = mean_ablate_by_lst(greater_than, model, print_output=True)

Average logit difference (circuit / full) %: 15.4973


### Add heads to the original paper circuit

In [None]:
greater_than = [(0, 1), (0, 3), (0, 5), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11), (9,1)] + [(10, 7)]
new_score = mean_ablate_by_lst(greater_than, model, print_output=True)

Average logit difference (circuit / full) %: 85.3383


In [None]:
greater_than = [(0, 1), (0, 3), (0, 5), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11), (9,1)] + [(layer, head) for layer in range(0) for head in range(12)]
new_score = mean_ablate_by_lst(greater_than, model, print_output=True)

Average logit difference (circuit / full) %: 85.3843


In [None]:
greater_than = [(0, 1), (0, 3), (0, 5), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11), (9,1)] + [(layer, head) for layer in range(0, 4) for head in range(12)]
new_score = mean_ablate_by_lst(greater_than, model, print_output=True)

Average logit difference (circuit / full) %: 81.8693


In [None]:
greater_than = [(0, 1), (0, 3), (0, 5), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11), (9,1)] + [(layer, head) for layer in range(0, 6) for head in range(12)]
new_score = mean_ablate_by_lst(greater_than, model, print_output=True)

Average logit difference (circuit / full) %: 85.6789


In [None]:
greater_than = [(0, 1), (0, 3), (0, 5), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11), (9,1)] + [(layer, head) for layer in range(0, 9) for head in range(12)]
new_score = mean_ablate_by_lst(greater_than, model, print_output=True)

Average logit difference (circuit / full) %: 91.4473


In [None]:
greater_than = [(0, 1), (0, 3), (0, 5), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11), (9,1)] + [(layer, head) for layer in range(8, 9) for head in range(12)]
new_score = mean_ablate_by_lst(greater_than, model, print_output=True)

Average logit difference (circuit / full) %: 90.2036


In [None]:
greater_than = [(0, 1), (0, 3), (0, 5), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11), (9,1)] + [(layer, head) for layer in range(5, 9) for head in range(12)]
new_score = mean_ablate_by_lst(greater_than, model, print_output=True)

Average logit difference (circuit / full) %: 94.2599


# Ablate by sequence position

In [None]:
def mean_ablate_by_seqpos(lst, model, print_output=True):
    CIRCUIT = {}
    SEQ_POS_TO_KEEP = {}
    for ind, key in enumerate(["YY", "end"]):
        headName = "head" + str(ind)
        CIRCUIT[headName] = lst
        SEQ_POS_TO_KEEP[headName] = key

    model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook

    ioi_logits_original, ioi_cache = model.run_with_cache(dataset.toks)

    model = ioi_circuit_extraction.add_mean_ablation_hook(model, means_dataset=dataset_2, circuit=CIRCUIT, seq_pos_to_keep=SEQ_POS_TO_KEEP)
    ioi_logits_minimal = model(dataset.toks)  # make sure text in clean vs corr have same num tokens for each prompt

    orig_score = logits_to_ave_logit_diff_2(ioi_logits_original, dataset)
    new_score = logits_to_ave_logit_diff_2(ioi_logits_minimal, dataset)
    if print_output:
        # print(f"Average logit difference (IOI dataset, using entire model): {orig_score:.4f}")
        # print(f"Average logit difference (IOI dataset, only using circuit): {new_score:.4f}")
        print(f"Average logit difference (circuit / full) %: {100 * new_score / orig_score:.4f}")
    return new_score

In [None]:
greater_than = [(0, 1), (0, 3), (0, 5), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11), (9, 1)]
new_score = mean_ablate_by_seqpos(greater_than, model, print_output=True)

Average logit difference (circuit / full) %: 90.9373


In [None]:
def mean_ablate_by_seqpos(lst, model, print_output=True):
    # NOTE: this version ignores `lst`; the per-position circuit is hard-coded below
    CIRCUIT = {
        "end": [(0, 1), (0, 3), (0, 5), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11), (9, 1)],
        "YY": [(0, 1), (0, 3), (0, 5), (5, 5), (6, 1), (6, 9)],
    }

    SEQ_POS_TO_KEEP = {
        "end": "end",
        "YY": "YY",
    }
    # CIRCUIT = {}
    # SEQ_POS_TO_KEEP = {}

    # # for ind, key in enumerate(pos_dict.keys()):
    # # ind = 7
    # # key = "T7"

    # for ind, key in enumerate(["T7", "end"]):
    #     headName = "head" + str(ind)
    #     CIRCUIT[headName] = lst
    #     SEQ_POS_TO_KEEP[headName] = key

    model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook

    ioi_logits_original, ioi_cache = model.run_with_cache(dataset.toks)

    model = ioi_circuit_extraction.add_mean_ablation_hook(model, means_dataset=dataset_2, circuit=CIRCUIT, seq_pos_to_keep=SEQ_POS_TO_KEEP)
    ioi_logits_minimal = model(dataset.toks)  # make sure text in clean vs corr have same num tokens for each prompt

    orig_score = logits_to_ave_logit_diff_2(ioi_logits_original, dataset)
    new_score = logits_to_ave_logit_diff_2(ioi_logits_minimal, dataset)
    if print_output:
        # print(f"Average logit difference (IOI dataset, using entire model): {orig_score:.4f}")
        # print(f"Average logit difference (IOI dataset, only using circuit): {new_score:.4f}")
        print(f"Average logit difference (circuit / full) %: {100 * new_score / orig_score:.4f}")
    return new_score

In [None]:
greater_than = [(0, 1), (0, 3), (0, 5), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11), (9, 1)]
new_score = mean_ablate_by_seqpos(greater_than, model, print_output=True)

Average logit difference (circuit / full) %: 91.0982


In [None]:
def mean_ablate_by_seqpos(lst, model, print_output=True):
    # NOTE: this version ignores `lst`; the per-position circuit is hard-coded below
    CIRCUIT = {
        "end": [(9, 1)],
        "YY": [(0, 1), (0, 3), (0, 5), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11)],
    }

    SEQ_POS_TO_KEEP = {
        "end": "end",
        "YY": "YY",
    }
    # CIRCUIT = {}
    # SEQ_POS_TO_KEEP = {}

    # # for ind, key in enumerate(pos_dict.keys()):
    # # ind = 7
    # # key = "T7"

    # for ind, key in enumerate(["T7", "end"]):
    #     headName = "head" + str(ind)
    #     CIRCUIT[headName] = lst
    #     SEQ_POS_TO_KEEP[headName] = key

    model.reset_hooks(including_permanent=True)  # remove hooks left by a previous mean-ablation run

    ioi_logits_original, ioi_cache = model.run_with_cache(dataset.toks)

    model = ioi_circuit_extraction.add_mean_ablation_hook(model, means_dataset=dataset_2, circuit=CIRCUIT, seq_pos_to_keep=SEQ_POS_TO_KEEP)
    ioi_logits_minimal = model(dataset.toks)  # clean (dataset) and corrupted (dataset_2) prompts must have the same number of tokens

    orig_score = logits_to_ave_logit_diff_2(ioi_logits_original, dataset)
    new_score = logits_to_ave_logit_diff_2(ioi_logits_minimal, dataset)
    if print_output:
        # print(f"Average logit difference (IOI dataset, using entire model): {orig_score:.4f}")
        # print(f"Average logit difference (IOI dataset, only using circuit): {new_score:.4f}")
        print(f"Average logit difference (circuit / full) %: {100 * new_score / orig_score:.4f}")
    return new_score

In [None]:
greater_than = [(0, 1), (0, 3), (0, 5), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11), (9, 1)]
new_score = mean_ablate_by_seqpos(greater_than, model, print_output=True)

Average logit difference (circuit / full) %: 36.1281


In [None]:
def mean_ablate_by_seqpos(lst, model, print_output=True):
    # NOTE: `lst` is not used; the heads to keep are hard-coded in CIRCUIT below
    CIRCUIT = {
        "end": [(0, 1), (0, 3), (0, 5), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11), (9, 1)],
        # "YY": [(0, 1), (0, 3), (0, 5), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11)],
    }

    SEQ_POS_TO_KEEP = {
        "end": "end",
        # "YY": "T7",
    }
    # CIRCUIT = {}
    # SEQ_POS_TO_KEEP = {}

    # # for ind, key in enumerate(pos_dict.keys()):
    # # ind = 7
    # # key = "T7"

    # for ind, key in enumerate(["T7", "end"]):
    #     headName = "head" + str(ind)
    #     CIRCUIT[headName] = lst
    #     SEQ_POS_TO_KEEP[headName] = key

    model.reset_hooks(including_permanent=True)  # remove hooks left by a previous mean-ablation run

    ioi_logits_original, ioi_cache = model.run_with_cache(dataset.toks)

    model = ioi_circuit_extraction.add_mean_ablation_hook(model, means_dataset=dataset_2, circuit=CIRCUIT, seq_pos_to_keep=SEQ_POS_TO_KEEP)
    ioi_logits_minimal = model(dataset.toks)  # clean (dataset) and corrupted (dataset_2) prompts must have the same number of tokens

    orig_score = logits_to_ave_logit_diff_2(ioi_logits_original, dataset)
    new_score = logits_to_ave_logit_diff_2(ioi_logits_minimal, dataset)
    if print_output:
        # print(f"Average logit difference (IOI dataset, using entire model): {orig_score:.4f}")
        # print(f"Average logit difference (IOI dataset, only using circuit): {new_score:.4f}")
        print(f"Average logit difference (circuit / full) %: {100 * new_score / orig_score:.4f}")
    return new_score

In [None]:
greater_than = [(0, 1), (0, 3), (0, 5), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11), (9, 1)]
new_score = mean_ablate_by_seqpos(greater_than, model, print_output=True)

Average logit difference (circuit / full) %: 82.2822
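The three variants above differ only in how `CIRCUIT` and `SEQ_POS_TO_KEEP` are filled in, and the commented-out loop hints at building them programmatically. A minimal sketch of such a builder, assuming (as in the hard-coded dicts) that each group name maps to a list of `(layer, head)` tuples in `CIRCUIT` and to a position label such as `"end"` or `"YY"` in `SEQ_POS_TO_KEEP`. `build_seqpos_dicts` is a hypothetical helper, not part of `ioi_circuit_extraction`; its two outputs would be passed to `add_mean_ablation_hook` in the same way as the cells above.

```python
from typing import Dict, List, Tuple

Head = Tuple[int, int]  # (layer, head index)


def build_seqpos_dicts(
    groups: Dict[str, Tuple[List[Head], str]]
) -> Tuple[Dict[str, List[Head]], Dict[str, str]]:
    """Hypothetical helper: split {group_name: (heads, position_label)}
    into the CIRCUIT and SEQ_POS_TO_KEEP dicts used above."""
    circuit: Dict[str, List[Head]] = {}
    seq_pos_to_keep: Dict[str, str] = {}
    for name, (heads, pos_label) in groups.items():
        circuit[name] = list(heads)        # heads belonging to this group
        seq_pos_to_keep[name] = pos_label  # position label for this group
    return circuit, seq_pos_to_keep


# Rebuild the dicts of the second variant: (9, 1) at "end", the rest at "YY"
greater_than = [(0, 1), (0, 3), (0, 5), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11), (9, 1)]
CIRCUIT, SEQ_POS_TO_KEEP = build_seqpos_dicts({
    "end": ([h for h in greater_than if h == (9, 1)], "end"),
    "YY": ([h for h in greater_than if h != (9, 1)], "YY"),
})
print(CIRCUIT)
print(SEQ_POS_TO_KEEP)
```

Keeping the position labels as data makes it possible to sweep over head-to-position assignments without redefining the ablation function for every experiment.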


# Prune backwards

This needs modifying because the previous "scores" were for just ONE token, whereas here they are sums over MULTIPLE tokens, so raw scores can be greater than 100. Computing 100 - new_score on a raw score is therefore meaningless; the ablation function must instead return new_score / orig_score as a percentage of the full model's score.
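To make the threshold test concrete, here is a minimal sketch of the normalization, assuming `orig_score` and `new_score` are the raw (multi-token) logit-difference scores of the full model and of the ablated circuit. The numbers are made up for illustration and are not results from this notebook. Once the score is a percentage of the full model, `100 - percent` is the performance drop in percentage points, which is what the pruning loops compare against `threshold`; a percentage above 100 gives a negative drop and always passes.

```python
def percent_of_full(new_score: float, orig_score: float) -> float:
    """Circuit score as a percentage of the full-model score."""
    return 100 * new_score / orig_score


def can_remove(new_score: float, orig_score: float, threshold: float) -> bool:
    """True if the ablated circuit is within `threshold` points of the full model."""
    drop = 100 - percent_of_full(new_score, orig_score)
    return drop < threshold


# Illustrative raw scores summed over several tokens (both exceed 100)
orig_score, new_score = 250.0, 245.0
print(percent_of_full(new_score, orig_score))          # 98.0
print(can_remove(new_score, orig_score, threshold=3))  # True
print(can_remove(200.0, orig_score, threshold=3))      # False (80% of full)
```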

In [None]:
# Start with full circuit
curr_circuit = [(layer, head) for layer in range(12) for head in range(12)]
threshold = 3  # T, in percentage points: remove a head if the circuit stays within T points of the full model

for layer in range(11, -1, -1):  # go from the last layer to the first, all heads in a layer at a time
    for head in range(12):
        # Work on a copy so the removal is only tentative
        copy_circuit = curr_circuit.copy()
        copy_circuit.remove((layer, head))

        new_score = mean_ablate_by_lst(copy_circuit, model, print_output=False)

        # print((layer,head), new_score)
        # If the performance drop is less than the threshold, remove the head for good
        if (100 - new_score) < threshold:
            curr_circuit.remove((layer, head))

            print("Removed:", (layer, head))
            print(new_score)
            print("\n")

Removed: (11, 0)
98.43448974119529


Removed: (11, 1)
98.57008671326153


Removed: (11, 2)
98.62260420787868


Removed: (11, 3)
98.6062238146972


Removed: (11, 4)
98.36458469047959


Removed: (11, 5)
98.38953423263045


Removed: (11, 6)
98.4319515733463


Removed: (11, 7)
98.41585562771179


Removed: (11, 9)
98.37635180830902


Removed: (11, 10)
97.97124838318094


Removed: (11, 11)
97.95907738715684


Removed: (10, 0)
97.99258141603245


Removed: (10, 1)
98.03429630193449


Removed: (10, 2)
98.29856913340076


Removed: (10, 3)
98.21821869956787


Removed: (10, 5)
98.2302404072824


Removed: (10, 6)
98.3037494377429


Removed: (10, 8)
98.6005337974124


Removed: (10, 9)
98.58944967359925


Removed: (10, 10)
98.43119020296747


Removed: (10, 11)
98.42999483014583


Removed: (9, 0)
98.46547293032926


Removed: (9, 2)
98.56369237505905


Removed: (9, 3)
98.42277274158404


Removed: (9, 4)
98.51721625842676


Removed: (9, 5)
99.17450975905143


Removed: (9, 6)
101.00485987831254


Removed

In [None]:
mean_ablate_by_lst(curr_circuit, model, print_output=True)

Average logit difference (circuit / full) %: 99.4358


99.43581415243364

In [None]:
backw_3 = curr_circuit.copy()
backw_3

[(0, 3),
 (0, 5),
 (5, 1),
 (5, 5),
 (6, 9),
 (7, 10),
 (8, 8),
 (8, 11),
 (9, 1),
 (10, 4),
 (10, 7),
 (11, 8)]

In [None]:
len(backw_3)

12
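The backward pass above is a greedy, single-pass search: heads are visited from the last layer to the first, and a head is dropped whenever the circuit without it stays within `threshold` points of the full model. A self-contained sketch of the same loop, with `score_fn` standing in for `mean_ablate_by_lst(circuit, model, print_output=False)` (assumption: it returns the circuit's score as a percentage of the full model). The additive toy scorer is hypothetical and only makes the sketch runnable; it also shows that a looser threshold can change *which* heads survive, not just how many.

```python
from typing import Callable, List, Tuple

Head = Tuple[int, int]  # (layer, head index)


def prune_backward(
    score_fn: Callable[[List[Head]], float],
    n_layers: int = 12,
    n_heads: int = 12,
    threshold: float = 3.0,
) -> List[Head]:
    """Greedy one-pass pruning from the last layer to the first.

    score_fn(circuit) must return performance as a % of the full model.
    """
    circuit = [(layer, head) for layer in range(n_layers) for head in range(n_heads)]
    for layer in range(n_layers - 1, -1, -1):
        for head in range(n_heads):
            candidate = [x for x in circuit if x != (layer, head)]  # tentative removal
            if 100 - score_fn(candidate) < threshold:
                circuit = candidate  # keep the removal
    return circuit


# Hypothetical toy scorer: four heads carry the whole (additive) score
IMPORTANT = {(0, 1): 2.0, (5, 5): 40.0, (6, 9): 18.0, (9, 1): 40.0}


def toy_score(circuit: List[Head]) -> float:
    return sum(IMPORTANT.get(h, 0.0) for h in circuit)


print(prune_backward(toy_score, threshold=3.0))   # [(5, 5), (6, 9), (9, 1)]
print(prune_backward(toy_score, threshold=20.0))  # [(0, 1), (5, 5), (9, 1)]
```

At `threshold=20` the head (6, 9) is visited first and removed, which uses up most of the budget, so (0, 1) can no longer be removed later: the result depends on visiting order.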

## prune by 10% threshold

In [None]:
# Start with full circuit
curr_circuit = [(layer, head) for layer in range(12) for head in range(12)]
threshold = 10  # T, in percentage points: remove a head if the circuit stays within T points of the full model

for layer in range(11, -1, -1):  # go from the last layer to the first, all heads in a layer at a time
    for head in range(12):
        # Work on a copy so the removal is only tentative
        copy_circuit = curr_circuit.copy()
        copy_circuit.remove((layer, head))

        new_score = mean_ablate_by_lst(copy_circuit, model, print_output=False)

        # print((layer,head), new_score)
        # If the performance drop is less than the threshold, remove the head for good
        if (100 - new_score) < threshold:
            curr_circuit.remove((layer, head))

            print("Removed:", (layer, head))
            print(new_score)
            print("\n")

Removed: (11, 0)
98.43448974119529


Removed: (11, 1)
98.57008671326153


Removed: (11, 2)
98.62260420787868


Removed: (11, 3)
98.6062238146972


Removed: (11, 4)
98.36458469047959


Removed: (11, 5)
98.38953423263045


Removed: (11, 6)
98.4319515733463


Removed: (11, 7)
98.41585562771179


Removed: (11, 8)
96.93358933008298


Removed: (11, 9)
96.89825012134473


Removed: (11, 10)
96.4948229906641


Removed: (11, 11)
96.48109806328915


Removed: (10, 0)
96.54533868900167


Removed: (10, 1)
96.58525491735963


Removed: (10, 2)
96.81215342052467


Removed: (10, 3)
96.76442034964302


Removed: (10, 5)
96.77290019221297


Removed: (10, 6)
96.8510606198617


Removed: (10, 7)
95.40291789759321


Removed: (10, 8)
95.67583692146007


Removed: (10, 9)
95.67065821663553


Removed: (10, 10)
95.54388684952765


Removed: (10, 11)
95.54294073486571


Removed: (9, 0)
95.5713479009008


Removed: (9, 2)
95.66676632372227


Removed: (9, 3)
95.5422540086417


Removed: (9, 4)
95.6389141902953


Removed:

In [None]:
mean_ablate_by_lst(curr_circuit, model, print_output=True)

Average logit difference (circuit / full) %: 92.2317


92.23174415643436

In [None]:
backw_10 = curr_circuit.copy()
backw_10

[(0, 1), (5, 1), (5, 5), (6, 9), (7, 10), (8, 11), (9, 1), (10, 4)]

In [None]:
len(backw_10)

8

In [None]:
greater_than = [(0, 1), (0, 3), (0, 5), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11), (9, 1)]

In [None]:
mean_ablate_by_lst([(0, 1), (5, 5), (6, 9), (7, 10), (8, 11), (9, 1)], model, print_output=True)

Average logit difference (circuit / full) %: 81.5705


81.57053386373906

# Prune fwds-backwds iteratively

In [None]:
def find_circuit_forw(curr_circuit=None, threshold=10):
    # threshold is T, in percentage points: a head is removed if the circuit
    # without it stays within T points of the full model (100 - score < T)
    # Uses the global `model`; a non-empty curr_circuit is modified in place
    if not curr_circuit:  # handles both None and []
        # Start with full circuit
        curr_circuit = [(layer, head) for layer in range(12) for head in range(12)]

    new_score = None  # stays None if no head is tested
    for layer in range(0, 12):  # go from the first layer to the last
        for head in range(12):
            if (layer, head) not in curr_circuit:
                continue

            # Work on a copy so the removal is only tentative
            copy_circuit = curr_circuit.copy()
            copy_circuit.remove((layer, head))

            new_score = mean_ablate_by_lst(copy_circuit, model, print_output=False)

            # print((layer,head), new_score)
            # If the performance drop is less than the threshold, remove the head for good
            if (100 - new_score) < threshold:
                curr_circuit.remove((layer, head))

                print("\nRemoved:", (layer, head))
                print(new_score)

    # new_score is the score of the last head tested, not necessarily of the returned circuit
    return curr_circuit, new_score

In [None]:
def find_circuit_backw(curr_circuit=None, threshold=10):
    # threshold is T, in percentage points: a head is removed if the circuit
    # without it stays within T points of the full model (100 - score < T)
    # Uses the global `model`; a non-empty curr_circuit is modified in place
    if not curr_circuit:  # handles both None and []
        # Start with full circuit
        curr_circuit = [(layer, head) for layer in range(12) for head in range(12)]

    new_score = None  # stays None if no head is tested
    for layer in range(11, -1, -1):  # go from the last layer to the first, all heads in a layer at a time
        for head in range(12):
            if (layer, head) not in curr_circuit:
                continue

            # Work on a copy so the removal is only tentative
            copy_circuit = curr_circuit.copy()
            copy_circuit.remove((layer, head))

            new_score = mean_ablate_by_lst(copy_circuit, model, print_output=False)

            # If the performance drop is less than the threshold, remove the head for good
            if (100 - new_score) < threshold:
                curr_circuit.remove((layer, head))

                print("\nRemoved:", (layer, head))
                print(new_score)

    # new_score is the score of the last head tested, not necessarily of the returned circuit
    return curr_circuit, new_score
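`find_circuit_forw` and `find_circuit_backw` differ only in the order in which heads are tried, and the driver loops below alternate them until a pass removes nothing. A self-contained sketch of that scheme, with a hypothetical `score_fn` standing in for `mean_ablate_by_lst(circuit, model, print_output=False)` and a made-up additive toy scorer (not data from this notebook). If a pass removes nothing, every single-head removal fails against the current circuit, so a pass in the other direction would also remove nothing: the result is a fixed point. Unlike the two functions above, this sketch returns the score of the final circuit.

```python
from typing import Callable, List, Sequence, Tuple

Head = Tuple[int, int]  # (layer, head index)
ScoreFn = Callable[[List[Head]], float]


def prune_pass(circuit: List[Head], score_fn: ScoreFn,
               order: Sequence[Head], threshold: float) -> List[Head]:
    """One greedy pass: try removing each head in `order`, keeping removals
    that leave the score within `threshold` points of 100."""
    circuit = list(circuit)
    for head in order:
        if head not in circuit:
            continue
        candidate = [x for x in circuit if x != head]
        if 100 - score_fn(candidate) < threshold:
            circuit = candidate
    return circuit


def prune_fwd_backw(score_fn: ScoreFn, n_layers: int = 12, n_heads: int = 12,
                    threshold: float = 3.0, backward_first: bool = False):
    forward = [(layer, head) for layer in range(n_layers) for head in range(n_heads)]
    # Backward order as in find_circuit_backw: last layer first, heads ascending
    backward = [(layer, head) for layer in range(n_layers - 1, -1, -1) for head in range(n_heads)]
    orders = [backward, forward] if backward_first else [forward, backward]

    circuit = list(forward)  # start from the full model
    step = 0
    while True:
        new_circuit = prune_pass(circuit, score_fn, orders[step % 2], threshold)
        if new_circuit == circuit:  # nothing removed: fixed point reached
            break
        circuit = new_circuit
        step += 1
    return circuit, score_fn(circuit)  # score of the final circuit


# Hypothetical toy scorer: four heads carry the whole (additive) score
IMPORTANT = {(0, 1): 2.0, (5, 5): 40.0, (6, 9): 18.0, (9, 1): 40.0}


def toy_score(circuit: List[Head]) -> float:
    return sum(IMPORTANT.get(h, 0.0) for h in circuit)


print(prune_fwd_backw(toy_score, threshold=20.0))
print(prune_fwd_backw(toy_score, threshold=20.0, backward_first=True))
```

On this toy scorer the two orders disagree at `threshold=20` (forward-first keeps (6, 9), backward-first keeps (0, 1)), a small-scale version of the order dependence seen when comparing the fwd-backw and backw-fwd circuits.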

### iter fwd backw, threshold 3

In [None]:
threshold = 3
curr_circuit = []  # empty list: the first pass starts from the full model
iteration = 1
# Stop as soon as a pass removes no heads. Circuits are compared instead of
# scores because the returned score belongs to the last head tested (and
# prev_score was never updated), so it is not a reliable convergence signal.
while True:
    print('\nfwd prune, iter ', str(iteration))
    old_circuit = curr_circuit.copy()  # save old circuit before pruning
    curr_circuit, new_score = find_circuit_forw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    print('\nbackw prune, iter ', str(iteration))
    old_circuit = curr_circuit.copy()  # save old circuit before pruning
    curr_circuit, new_score = find_circuit_backw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    iteration += 1


fwd prune, iter  1

Removed: (0, 0)
100.20776640653463

Removed: (0, 2)
101.16978347244907

Removed: (0, 3)
99.25190188507842

Removed: (0, 4)
99.42954244391403

Removed: (0, 6)
99.39908496289875

Removed: (0, 7)
98.74450424412802

Removed: (0, 8)
98.57850017585096

Removed: (0, 9)
98.4728363096126

Removed: (0, 10)
99.19824446758318

Removed: (0, 11)
99.02801927235299

Removed: (1, 0)
99.2378589202868

Removed: (1, 1)
99.30176897902813

Removed: (1, 2)
99.26472841673328

Removed: (1, 3)
99.01089030493353

Removed: (1, 4)
99.12583803792336

Removed: (1, 5)
100.31772204508688

Removed: (1, 6)
100.17648490736778

Removed: (1, 7)
100.65617970320078

Removed: (1, 8)
100.6652345723488

Removed: (1, 9)
100.56871381531292

Removed: (1, 10)
100.59008070462035

Removed: (1, 11)
100.92439187993585

Removed: (2, 0)
101.07729323318533

Removed: (2, 1)
102.02257641252812

Removed: (2, 2)
102.1973778276192

Removed: (2, 3)
102.17504456309356

Removed: (2, 4)
100.44306104418172

Removed: (2, 5)
100.

In [None]:
fb_3 = curr_circuit.copy()
fb_3

[(0, 1), (5, 1), (5, 5), (6, 9), (7, 10), (8, 8), (8, 11), (9, 1), (10, 4)]

In [None]:
mean_ablate_by_lst(fb_3, model, print_output=True)

Average logit difference (circuit / full) %: 97.8576


97.85758452965342

#### compare

In [None]:
set(backw_3) - set(fb_3)

{(0, 3), (0, 5), (10, 7), (11, 8)}

In [None]:
set(fb_3) - set(backw_3)

{(0, 1)}

### iter fwd backw, threshold 20

In [None]:
# threshold = 20
# curr_circuit = []
# prev_score = 100
# new_score = 0
# iter = 1
# while prev_score != new_score:
#     print('\nfwd prune, iter ', str(iter))
#     # track changes in circuit as for some reason it doesn't work with scores
#     old_circuit = curr_circuit.copy() # save old before finding new one
#     curr_circuit, new_score = find_circuit_forw(curr_circuit=curr_circuit, threshold=threshold)
#     if curr_circuit == old_circuit:
#         break
#     print('\nbackw prune, iter ', str(iter))
#     # prev_score = new_score # save old score before finding new one
#     old_circuit = curr_circuit.copy() # save old before finding new one
#     curr_circuit, new_score = find_circuit_backw(curr_circuit=curr_circuit, threshold=threshold)
#     if curr_circuit == old_circuit:
#         break
#     iter += 1

In [None]:
# curr_circuit

# Prune backwds-fwds iteratively

### iter backw fwd, threshold 3

In [None]:
threshold = 3
curr_circuit = []  # empty list: the first pass starts from the full model
iteration = 1
# Stop as soon as a pass removes no heads. Circuits are compared instead of
# scores because the returned score belongs to the last head tested (and
# prev_score was never updated), so it is not a reliable convergence signal.
while True:
    print('\nbackw prune, iter ', str(iteration))
    old_circuit = curr_circuit.copy()  # save old circuit before pruning
    curr_circuit, new_score = find_circuit_backw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    print('\nfwd prune, iter ', str(iteration))
    old_circuit = curr_circuit.copy()  # save old circuit before pruning
    curr_circuit, new_score = find_circuit_forw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    iteration += 1


backw prune, iter  1

Removed: (11, 0)
98.43448974119529

Removed: (11, 1)
98.57008671326153

Removed: (11, 2)
98.62260420787868

Removed: (11, 3)
98.6062238146972

Removed: (11, 4)
98.36458469047959

Removed: (11, 5)
98.38953423263045

Removed: (11, 6)
98.4319515733463

Removed: (11, 7)
98.41585562771179

Removed: (11, 9)
98.37635180830902

Removed: (11, 10)
97.97124838318094

Removed: (11, 11)
97.95907738715684

Removed: (10, 0)
97.99258141603245

Removed: (10, 1)
98.03429630193449

Removed: (10, 2)
98.29856913340076

Removed: (10, 3)
98.21821869956787

Removed: (10, 5)
98.2302404072824

Removed: (10, 6)
98.3037494377429

Removed: (10, 8)
98.6005337974124

Removed: (10, 9)
98.58944967359925

Removed: (10, 10)
98.43119020296747

Removed: (10, 11)
98.42999483014583

Removed: (9, 0)
98.46547293032926

Removed: (9, 2)
98.56369237505905

Removed: (9, 3)
98.42277274158404

Removed: (9, 4)
98.51721625842676

Removed: (9, 5)
99.17450975905143

Removed: (9, 6)
101.00485987831254

Removed: (9

In [None]:
bf_3 = curr_circuit.copy()
bf_3

[(0, 5),
 (5, 1),
 (5, 5),
 (6, 9),
 (7, 10),
 (8, 8),
 (8, 11),
 (9, 1),
 (10, 4),
 (11, 8)]

#### compare

In [None]:
len(bf_3)

10

In [None]:
len(fb_3)

9

In [None]:
set(backw_3) - set(bf_3)

{(0, 3), (10, 7)}

In [None]:
set(bf_3) - set(backw_3)

set()

In [None]:
set(fb_3) - (set(fb_3) - set(bf_3))

{(5, 1), (5, 5), (6, 9), (7, 10), (8, 8), (8, 11), (9, 1), (10, 4)}

In [None]:
set(bf_3) - set(fb_3)

{(0, 5), (11, 8)}

Get the score of fb_3 without the heads it has that bf_3 doesn't have.

This is the set intersection of fb_3 and bf_3, so removing from bf_3 the heads that fb_3 lacks gives the same set.
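A quick check of that identity, using the two circuits printed above: `A - (A - B)` is exactly `A & B`, which is symmetric in the two sets. Sorting the result gives a deterministic `(layer, head)` order before passing it to `mean_ablate_by_lst`.

```python
fb_3 = [(0, 1), (5, 1), (5, 5), (6, 9), (7, 10), (8, 8), (8, 11), (9, 1), (10, 4)]
bf_3 = [(0, 5), (5, 1), (5, 5), (6, 9), (7, 10), (8, 8), (8, 11), (9, 1), (10, 4), (11, 8)]

common = set(fb_3) & set(bf_3)  # heads found by both pruning orders
assert common == set(fb_3) - (set(fb_3) - set(bf_3))
assert common == set(bf_3) - (set(bf_3) - set(fb_3))

print(sorted(common))
print(sorted(set(fb_3) ^ set(bf_3)))  # heads found by only one of the two orders
```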

In [None]:
mean_ablate_by_lst(list(set(fb_3) - (set(fb_3) - set(bf_3))), model, print_output=True)

Average logit difference (circuit / full) %: 91.7935


91.79345500643358

In [None]:
mean_ablate_by_lst(list(set(bf_3) - (set(bf_3) - set(fb_3))), model, print_output=True)

Average logit difference (circuit / full) %: 91.7935


91.79345500643358

In [None]:
(set(fb_3) - (set(fb_3) - set(bf_3))) == (set(bf_3) - (set(bf_3) - set(fb_3)))

True