<a href="https://colab.research.google.com/github/wlg100/numseqcont_circuit_expms/blob/main/notebook_templates/minimal_circuit_template.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" align="left"/></a>&nbsp;or in a local notebook.

# Setup
(No need to change anything)

In [1]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    # Install another version of node that makes PySvelte work way faster
    # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    # %pip install git+https://github.com/neelnanda-io/PySvelte.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

Running as a Colab notebook
Collecting git+https://github.com/neelnanda-io/TransformerLens.git
  Cloning https://github.com/neelnanda-io/TransformerLens.git to /tmp/pip-req-build-8wt8uhwj
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/TransformerLens.git /tmp/pip-req-build-8wt8uhwj
  Resolved https://github.com/neelnanda-io/TransformerLens.git to commit fa287750606075574df2c538058e67d648e2f952
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting accelerate>=0.23.0 (from transformer-lens==0.0.0)
  Downloading accelerate-0.24.1-py3-none-any.whl (261 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m261.4/261.4 kB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting beartype<0.15.0,>=0.14.1 (from transformer-lens==0.0.0)
  Downloading beartype-0.14.1-py3-none-any.whl (739 kB)
[2K     [90m━━━━━━━━

In [2]:
# # Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
# import plotly.io as pio

# if IN_COLAB or not DEBUG_MODE:
#     # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.
#     pio.renderers.default = "colab"
# else:
#     pio.renderers.default = "png"

In [3]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [4]:
# import pysvelte

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

We turn automatic differentiation off, to save GPU memory, as this notebook focuses on model inference not model training.

In [5]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x79a69aff7a60>

Plotting helper functions:

In [6]:
# def imshow(tensor, renderer=None, **kwargs):
#     px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", **kwargs).show(renderer)

# def line(tensor, renderer=None, **kwargs):
#     px.line(y=utils.to_numpy(tensor), **kwargs).show(renderer)

# def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
#     x = utils.to_numpy(x)
#     y = utils.to_numpy(y)
#     px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

## Load Model

Decide which model to use (eg. gpt2-small vs -medium)

In [7]:
# Load GPT-2 small through TransformerLens with its weight-processing options
# switched on (see the HookedTransformer.from_pretrained docs for each flag).
# `model` is used as a notebook-level global by the cells below.
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)

Downloading (…)lve/main/config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Loaded pretrained model gpt2-small into HookedTransformer


## Import functions from repo

In [8]:
!git clone https://github.com/callummcdougall/ARENA_2.0.git

Cloning into 'ARENA_2.0'...
remote: Enumerating objects: 9106, done.[K
remote: Counting objects: 100% (1818/1818), done.[K
remote: Compressing objects: 100% (288/288), done.[K
remote: Total 9106 (delta 1611), reused 1607 (delta 1527), pack-reused 7288[K
Receiving objects: 100% (9106/9106), 155.60 MiB | 15.57 MiB/s, done.
Resolving deltas: 100% (5506/5506), done.


In [9]:
cd ARENA_2.0/chapter1_transformers/exercises/part3_indirect_object_identification

/content/ARENA_2.0/chapter1_transformers/exercises/part3_indirect_object_identification


In [10]:
import ioi_circuit_extraction as ioi_circuit_extraction

# Generate dataset with multiple prompts

In [67]:
class Dataset:
    """Tokenized batch of prompts plus the per-prompt indices used for scoring.

    Each prompt is a dict with a "text" entry (the prompt string), "corr" and
    "incorr" entries (the correct / incorrect continuation strings) and any
    number of further keys naming positions of interest (e.g. "S1").

    Attributes set by ``__init__``:
        prompts, tokenizer: stored as passed in.
        N: number of prompts.
        max_len: longest prompt, in tokens, when each prompt is tokenized alone.
        groups: list of index arrays, one per template id (a single group here,
            because every prompt is assigned template id 0).
        toks: integer tensor of padded token ids, one row per prompt.
        io_tokenIDs / s_tokenIDs: first token id of each prompt's "corr" /
            "incorr" string.
        word_idx: dict mapping every position key, plus "end", to a tensor
            holding that position's token index for each prompt.
    """

    # Prompt keys that carry text/answers rather than a sequence position.
    _NON_POSITION_KEYS = ("text", "corr", "incorr")

    def __init__(self, prompts, pos_dict, tokenizer, S1_is_first=False):
        """Tokenize ``prompts`` and build the index tensors.

        Args:
            prompts: non-empty list of prompt dicts (see class docstring).
            pos_dict: maps each position key to its token index; the same
                index is used for every prompt.
            tokenizer: callable tokenizer that also provides ``encode`` and
                ``tokenize`` (a HuggingFace-style tokenizer in this notebook).
            S1_is_first: accepted for backward compatibility; not used.

        Raises:
            ValueError: if ``prompts`` is empty.
            KeyError: if a position key of the first prompt is missing from
                ``pos_dict``.
        """
        if len(prompts) == 0:
            raise ValueError("Dataset requires at least one prompt")

        self.prompts = prompts
        self.tokenizer = tokenizer
        self.N = len(prompts)
        self.max_len = max(
            len(self.tokenizer(prompt["text"]).input_ids) for prompt in self.prompts
        )

        # Only one template is used, so every prompt gets template id 0.
        all_ids = [0] * self.N
        all_ids_ar = np.array(all_ids)
        self.groups = [
            np.where(all_ids_ar == template_id)[0] for template_id in set(all_ids)
        ]

        texts = [prompt["text"] for prompt in self.prompts]
        # Build the integer tensor directly instead of round-tripping the ids
        # through a float tensor.
        self.toks = torch.tensor(
            self.tokenizer(texts, padding=True).input_ids, dtype=torch.int
        )
        self.io_tokenIDs = [
            self.tokenizer.encode(prompt["corr"])[0] for prompt in self.prompts
        ]
        self.s_tokenIDs = [
            self.tokenizer.encode(prompt["incorr"])[0] for prompt in self.prompts
        ]

        # word_idx: for every position key, a tensor with one entry per prompt.
        # Positions come straight from pos_dict (not from searching the token
        # list), so they are identical across prompts and need no tokenization.
        self.word_idx = {}
        for targ in self.prompts[0].keys():
            if targ in self._NON_POSITION_KEYS:
                continue
            self.word_idx[targ] = torch.tensor([pos_dict[targ]] * self.N)

        # "end" is the index of each prompt's last token. Uses the tokenizer
        # passed to the constructor rather than a notebook-global model.
        self.word_idx["end"] = torch.tensor(
            [
                len(self.tokenizer.tokenize(prompt["text"])) - 1
                for prompt in self.prompts
            ]
        )

    def __len__(self):
        """Return the number of prompts."""
        return self.N

In [68]:
# Token index of each position key within a prompt. Dataset copies these values
# into word_idx for every prompt, so they are shared by all prompts.
# NOTE(review): assumes each number word is exactly one token — confirm for the tokenizer in use.
pos_dict = {
    'S1': 0,
    'S2': 1,
    'S3': 2,
    'S4': 3,
}

In [88]:
def generate_prompts_list(x, y):
    """Build prompts of four consecutive number words, one per start index in ``range(x, y)``.

    Each prompt dict holds the four words under 'S1'..'S4', their concatenation
    under 'text', the next word in the sequence under 'corr', and the fourth
    word again under 'incorr' (an arbitrary choice of wrong answer).

    Args:
        x: first start index into the word list (inclusive).
        y: last start index (exclusive).

    Returns:
        List of prompt dicts; empty when ``range(x, y)`` is empty.

    Raises:
        IndexError: if any start index is negative or leaves fewer than five
            words available. Negative starts are rejected explicitly because
            Python's negative indexing would otherwise wrap around and silently
            produce non-consecutive sequences.
    """
    words = [' one', ' two', ' three', ' four', ' five', ' six', ' seven', ' eight', ' nine', ' ten', ' eleven', ' twelve', ' thirteen', ' fourteen', ' fifteen', ' sixteen', ' seventeen', ' eighteen', ' nineteen', ' twenty']
    window = 5  # four prompt words plus the correct continuation
    last_start = len(words) - window

    starts = range(x, y)
    if len(starts) > 0 and (starts[0] < 0 or starts[-1] > last_start):
        raise IndexError(
            f"start indices must lie in [0, {last_start}]; got range({x}, {y})"
        )

    prompts_list = []
    for i in starts:
        s1, s2, s3, s4, answer = words[i:i + window]
        prompts_list.append({
            'S1': s1,
            'S2': s2,
            'S3': s3,
            'S4': s4,
            'corr': answer,
            'incorr': s4,  # this is arbitrary
            'text': f"{s1}{s2}{s3}{s4}",
        })
    return prompts_list

# Build the 16 consecutive number-word prompts (start indices 0..15) and tokenize them.
# `dataset` is read as a global by mean_ablate_by_lst. S1_is_first is accepted by
# Dataset but not used by its constructor.
prompts_list = generate_prompts_list(0, 16)
dataset = Dataset(prompts_list, pos_dict, model.tokenizer, S1_is_first=True)

In [70]:
# import random

# def generate_prompts_list_corr(x ,y):
#     words = [' one', ' two', ' three', ' four', ' five', ' six', ' seven', ' eight', ' nine', ' ten', ' eleven', ' twelve', ' thirteen', ' fourteen', ' fifteen', ' sixteen', ' seventeen', ' eighteen', ' nineteen', ' twenty']
#     prompts_list = []
#     for i in range(x, y):
#         r1 = random.choice(words)
#         r2 = random.choice(words)
#         while True:
#             r3_ind = random.randint(0,len(words)-1)
#             r4_ind = random.randint(0,len(words)-1)
#             if words[r3_ind] != words[r4_ind-1]:
#                 break
#         r3 = words[r3_ind]
#         r4 = words[r4_ind]
#         prompt_dict = {
#             'S1': str(r1),
#             'S2': str(r2),
#             'S3': str(r3),
#             'S4': str(r4),
#             'corr': str(r1),
#             'incorr': str(r4),
#             'text': f"{r1}{r2}{r3}{r4}"
#         }
#         prompts_list.append(prompt_dict)
#     return prompts_list

# # prompts_list_2 = generate_prompts_list_corr(0, 8)
# prompts_list_2 = generate_prompts_list_corr(0, 16)
# dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer, S1_is_first=True)

In [91]:
# Hard-coded prompts of number words in no sequential order; `dataset_2` built from
# them is passed as means_dataset to the mean-ablation hook in mean_ablate_by_lst.
# NOTE(review): looks like a frozen sample from the commented-out
# generate_prompts_list_corr generator above (corr == S1, incorr == S4) — confirm.
# NOTE(review): the fourth entry has identical 'corr' and 'incorr' (' sixteen'),
# which would matter only if this dataset were ever scored directly.
prompts_list_2 = [{'S1': ' ten',
  'S2': ' two',
  'S3': ' thirteen',
  'S4': ' four',
  'corr': ' ten',
  'incorr': ' four',
  'text': ' ten two thirteen four'},
 {'S1': ' fifteen',
  'S2': ' seven',
  'S3': ' eight',
  'S4': ' eleven',
  'corr': ' fifteen',
  'incorr': ' eleven',
  'text': ' fifteen seven eight eleven'},
 {'S1': ' eight',
  'S2': ' eight',
  'S3': ' eight',
  'S4': ' one',
  'corr': ' eight',
  'incorr': ' one',
  'text': ' eight eight eight one'},
 {'S1': ' sixteen',
  'S2': ' eleven',
  'S3': ' thirteen',
  'S4': ' sixteen',
  'corr': ' sixteen',
  'incorr': ' sixteen',
  'text': ' sixteen eleven thirteen sixteen'},
 {'S1': ' eight',
  'S2': ' fifteen',
  'S3': ' three',
  'S4': ' twenty',
  'corr': ' eight',
  'incorr': ' twenty',
  'text': ' eight fifteen three twenty'},
 {'S1': ' fourteen',
  'S2': ' three',
  'S3': ' four',
  'S4': ' seven',
  'corr': ' fourteen',
  'incorr': ' seven',
  'text': ' fourteen three four seven'},
 {'S1': ' seventeen',
  'S2': ' twelve',
  'S3': ' nineteen',
  'S4': ' ten',
  'corr': ' seventeen',
  'incorr': ' ten',
  'text': ' seventeen twelve nineteen ten'},
 {'S1': ' ten',
  'S2': ' ten',
  'S3': ' six',
  'S4': ' three',
  'corr': ' ten',
  'incorr': ' three',
  'text': ' ten ten six three'},
 {'S1': ' nine',
  'S2': ' one',
  'S3': ' one',
  'S4': ' thirteen',
  'corr': ' nine',
  'incorr': ' thirteen',
  'text': ' nine one one thirteen'},
 {'S1': ' eleven',
  'S2': ' ten',
  'S3': ' four',
  'S4': ' eighteen',
  'corr': ' eleven',
  'incorr': ' eighteen',
  'text': ' eleven ten four eighteen'},
 {'S1': ' twenty',
  'S2': ' seven',
  'S3': ' twelve',
  'S4': ' fourteen',
  'corr': ' twenty',
  'incorr': ' fourteen',
  'text': ' twenty seven twelve fourteen'},
 {'S1': ' thirteen',
  'S2': ' eight',
  'S3': ' ten',
  'S4': ' one',
  'corr': ' thirteen',
  'incorr': ' one',
  'text': ' thirteen eight ten one'},
 {'S1': ' three',
  'S2': ' sixteen',
  'S3': ' seven',
  'S4': ' six',
  'corr': ' three',
  'incorr': ' six',
  'text': ' three sixteen seven six'},
 {'S1': ' twenty',
  'S2': ' one',
  'S3': ' sixteen',
  'S4': ' nine',
  'corr': ' twenty',
  'incorr': ' nine',
  'text': ' twenty one sixteen nine'},
 {'S1': ' nineteen',
  'S2': ' eight',
  'S3': ' thirteen',
  'S4': ' ten',
  'corr': ' nineteen',
  'incorr': ' ten',
  'text': ' nineteen eight thirteen ten'},
 {'S1': ' four',
  'S2': ' fourteen',
  'S3': ' three',
  'S4': ' one',
  'corr': ' four',
  'incorr': ' one',
  'text': ' four fourteen three one'}]

dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer, S1_is_first=True)

# Ablation Expm Functions

In [16]:
from torch import Tensor

def logits_to_ave_logit_diff_2(logits: Float[Tensor, "batch seq d_vocab"], dataset: Dataset, per_prompt=False):
    '''
    Logit of the correct answer minus logit of the incorrect answer, read at
    each prompt's final token position.

    With per_prompt=True the per-prompt differences are returned as a tensor;
    otherwise their mean is returned.
    '''
    batch_rows = range(logits.size(0))
    # Position of the last prompt token for every row; only that position
    # predicts the answer.
    final_pos = dataset.word_idx["end"]

    # dataset.io_tokenIDs / s_tokenIDs hold the correct / incorrect answer token ids.
    correct_logits = logits[batch_rows, final_pos, dataset.io_tokenIDs]
    incorrect_logits = logits[batch_rows, final_pos, dataset.s_tokenIDs]

    per_prompt_diff = correct_logits - incorrect_logits
    if per_prompt:
        return per_prompt_diff
    return per_prompt_diff.mean()

In [17]:
def mean_ablate_by_lst(lst, model, print_output=True, orig_score=None):
    """Score a candidate circuit as a percentage of the full model's logit difference.

    The model is run on the global ``dataset`` twice: once unmodified and once
    after ``ioi_circuit_extraction.add_mean_ablation_hook`` has been applied
    with ``lst`` as the circuit and the global ``dataset_2`` as means dataset.
    NOTE(review): presumably the hook mean-ablates every head not in ``lst`` —
    confirm against the ARENA implementation.

    Args:
        lst: list of (layer, head) tuples to keep; used for every sequence position.
        model: the hooked model. All hooks, including permanent ones, are reset
            on entry; the ablation hook is still attached on return.
        print_output: if True, print the resulting percentage.
        orig_score: optional precomputed full-model average logit difference.
            When given, the unablated forward pass is skipped, which avoids
            redundant work when scoring many circuits.

    Returns:
        ``100 * ablated_score / full_score``, in the form the score helper
        produces (a 0-d tensor in this notebook, hence callers' ``.item()``).
    """
    # The same head list is registered under every circuit role ...
    CIRCUIT = {
        "number mover": lst,
        "number mover 4": lst,
        "number mover 3": lst,
        "number mover 2": lst,
        "number mover 1": lst,
    }

    # ... and each role is tied to one sequence position.
    SEQ_POS_TO_KEEP = {
        "number mover": "end",
        "number mover 4": "S4",
        "number mover 3": "S3",
        "number mover 2": "S2",
        "number mover 1": "S1",
    }

    model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook

    if orig_score is None:
        # A plain forward pass is enough: no activation cache is needed here.
        orig_score = logits_to_ave_logit_diff_2(model(dataset.toks), dataset)

    model = ioi_circuit_extraction.add_mean_ablation_hook(model, means_dataset=dataset_2, circuit=CIRCUIT, seq_pos_to_keep=SEQ_POS_TO_KEEP)
    new_score = logits_to_ave_logit_diff_2(model(dataset.toks), dataset)

    percent_of_full = 100 * new_score / orig_score
    if print_output:
        print(f"Average logit difference (circuit / full) %: {percent_of_full:.4f}")
    return percent_of_full

We can also prevent redundant computation of the full circuit score by storing it and just passing it in to the function.

In [18]:
def find_circuit_forw(curr_circuit=None, threshold=10):
    """Greedily prune heads from a circuit, visiting layers from first to last.

    Each head still in the circuit is tentatively removed and the reduced
    circuit is scored with ``mean_ablate_by_lst`` (percent of full-model
    performance, using the global ``model``). The removal is kept when the
    score stays within ``threshold`` percentage points of 100, i.e. when
    ``100 - score < threshold``; scores above 100 therefore always allow removal.

    Args:
        curr_circuit: list of (layer, head) tuples to prune. ``None`` or an
            empty list means "start from all 12 x 12 heads". A non-empty list
            is pruned in place.
        threshold: T, a percentage; a head is removed if dropping it costs
            less than T points.

    Returns:
        ``(circuit, score)``: the pruned circuit and the score of that circuit.
    """
    if not curr_circuit:
        # Start with full circuit (also covers the default of None)
        curr_circuit = [(layer, head) for layer in range(12) for head in range(12)]

    # Score of the circuit as it currently stands; set on every accepted removal.
    circuit_score = None

    for layer in range(0, 12):
        for head in range(12):
            if (layer, head) not in curr_circuit:
                continue

            # Trial circuit without this head; the real circuit is only
            # modified once the trial passes the threshold.
            candidate = curr_circuit.copy()
            candidate.remove((layer, head))

            new_score = mean_ablate_by_lst(candidate, model, print_output=False).item()

            if (100 - new_score) < threshold:
                curr_circuit.remove((layer, head))
                circuit_score = new_score

                print("\nRemoved:", (layer, head))
                print(new_score)

    if circuit_score is None:
        # Nothing was removed, so no trial score describes the returned
        # circuit; evaluate it once so the returned score is always defined.
        circuit_score = mean_ablate_by_lst(curr_circuit, model, print_output=False).item()

    return curr_circuit, circuit_score

In [19]:
def find_circuit_backw(curr_circuit=None, threshold=10):
    """Greedily prune heads from a circuit, visiting layers from last to first.

    Each head still in the circuit is tentatively removed and the reduced
    circuit is scored with ``mean_ablate_by_lst`` (percent of full-model
    performance, using the global ``model``). The removal is kept when the
    score stays within ``threshold`` percentage points of 100, i.e. when
    ``100 - score < threshold``; scores above 100 therefore always allow removal.

    Args:
        curr_circuit: list of (layer, head) tuples to prune. ``None`` or an
            empty list means "start from all 12 x 12 heads". A non-empty list
            is pruned in place.
        threshold: T, a percentage; a head is removed if dropping it costs
            less than T points.

    Returns:
        ``(circuit, score)``: the pruned circuit and the score of that circuit.
    """
    if not curr_circuit:
        # Start with full circuit (also covers the default of None)
        curr_circuit = [(layer, head) for layer in range(12) for head in range(12)]

    # Score of the circuit as it currently stands; set on every accepted removal.
    circuit_score = None

    for layer in range(11, -1, -1):  # go thru all heads in a layer first
        for head in range(12):
            if (layer, head) not in curr_circuit:
                continue

            # Trial circuit without this head; the real circuit is only
            # modified once the trial passes the threshold.
            candidate = curr_circuit.copy()
            candidate.remove((layer, head))

            new_score = mean_ablate_by_lst(candidate, model, print_output=False).item()

            if (100 - new_score) < threshold:
                curr_circuit.remove((layer, head))
                circuit_score = new_score

                print("\nRemoved:", (layer, head))
                print(new_score)

    if circuit_score is None:
        # Nothing was removed, so no trial score describes the returned
        # circuit; evaluate it once so the returned score is always defined.
        circuit_score = mean_ablate_by_lst(curr_circuit, model, print_output=False).item()

    return curr_circuit, circuit_score

# Try circuits from other tasks

## gt, IOI

In [None]:
# greater-than
# Hand-specified (layer, head) list, scored on the global `dataset`;
# mean_ablate_by_lst prints and returns the circuit / full-model logit-diff percentage.
circuit = [(0, 1), (0, 3), (0, 5), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11), (9, 1)]
mean_ablate_by_lst(circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: 20.8534


20.853376388549805

In [None]:
# IOI
# Hand-specified (layer, head) list, scored on the global `dataset`;
# mean_ablate_by_lst prints and returns the circuit / full-model logit-diff percentage.
circuit = [(0, 1), (0, 10), (2, 2), (3, 0), (4, 11), (5, 5), (5, 8), (5, 9), (6, 9), (7, 3), (7, 9), (8, 6), (8, 10), (9, 0), (9, 6), (9, 7), (9, 9), (10, 0), (10, 1), (10, 2), (10, 6), (10, 7), (10, 10), (11, 2), (11, 9), (11, 10)]
mean_ablate_by_lst(circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: 12.2694


12.269357681274414

## fb 80

In [None]:
# fb 80, digits incr
# https://colab.research.google.com/drive/1mFWmGAKtigFcqqWWMCwU7wWQY2HT5ZOo#scrollTo=lJEY-Zs2g_a5&line=1&uniqifier=1
# Head list taken from the notebook linked above, scored here on the global `dataset`.
circuit = [(1, 5), (3, 3), (3, 7), (3, 10), (3, 11), (4, 4), (4, 6), (4, 7), (4, 10), (4, 11), (5, 0), (5, 2), (5, 3), (5, 4), (5, 6), (6, 3), (6, 8), (6, 10), (7, 0), (7, 2), (7, 7), (7, 8), (7, 10), (7, 11), (8, 0), (8, 1), (8, 6), (8, 8), (8, 9), (8, 11), (9, 1), (11, 8)]
mean_ablate_by_lst(circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: 74.7523


74.7523422241211

In [None]:
# fb 80, numwords
# https://colab.research.google.com/drive/1QTv-4osLHadCAay0beew-xlXszPCG88s#scrollTo=563kZf_4r_mw&line=2&uniqifier=1
# Head list taken from the notebook linked above, scored here on the global `dataset`.
circuit = [(3, 2), (4, 4), (4, 8), (4, 10), (4, 11), (5, 5), (5, 6), (5, 7), (5, 8), (6, 1), (6, 7), (6, 9), (6, 10), (7, 0), (7, 2), (7, 5), (7, 6), (7, 7), (7, 8), (7, 10), (7, 11), (8, 0), (8, 1), (8, 6), (8, 8), (8, 11), (9, 1), (10, 2)]
mean_ablate_by_lst(circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: 81.9778


81.9777603149414

In [None]:
# fb 80, months
# https://colab.research.google.com/drive/1lhQqlizYGMC11vzp6I9mJ3dyxIr8tV3l#scrollTo=563kZf_4r_mw&line=2&uniqifier=1
# Head list taken from the notebook linked above, scored here on the global `dataset`.
circuit = [(4, 4), (7, 11), (8, 6), (8, 9), (8, 11), (9, 1), (9, 5), (11, 10)]
mean_ablate_by_lst(circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: 36.0939


36.09392547607422

## bf 97

In [71]:
# digits incr
# Hand-specified (layer, head) list, scored on the global `dataset`;
# mean_ablate_by_lst prints and returns the circuit / full-model logit-diff percentage.
circuit = [(0, 1), (0, 2), (0, 3), (0, 5), (0, 8), (0, 9), (0, 10), (1, 0), (1, 5), (1, 7), (2, 0), (2, 2), (2, 3), (2, 4), (2, 6), (2, 9), (3, 3), (3, 7), (3, 11), (4, 4), (4, 7), (4, 10), (5, 2), (5, 3), (5, 4), (5, 6), (5, 8), (5, 11), (6, 1), (6, 3), (6, 6), (6, 8), (6, 10), (7, 2), (7, 6), (7, 8), (7, 10), (7, 11), (8, 6), (8, 8), (8, 11), (9, 1), (11, 10)]
mean_ablate_by_lst(circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: 87.6954


87.69536590576172

In [72]:
# numwords
# https://colab.research.google.com/drive/1QTv-4osLHadCAay0beew-xlXszPCG88s#scrollTo=563kZf_4r_mw&line=2&uniqifier=1
# Head list taken from the notebook linked above, scored here on the global `dataset`.
circuit = [(0, 1), (1, 0), (1, 5), (3, 2), (4, 4), (4, 8), (4, 10), (5, 4), (5, 6), (5, 8), (6, 9), (6, 10), (7, 7), (7, 8), (7, 10), (7, 11), (8, 0), (8, 1), (8, 6), (8, 8), (8, 11), (9, 1), (9, 5), (10, 2), (11, 8)]
mean_ablate_by_lst(circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: 66.7360


66.7359848022461

In [73]:
# months
# https://colab.research.google.com/drive/1lhQqlizYGMC11vzp6I9mJ3dyxIr8tV3l#scrollTo=563kZf_4r_mw&line=2&uniqifier=1
# circuit = [(4, 4), (7, 11), (8, 6), (8, 9), (8, 11), (9, 1), (9, 5), (11, 10)]
# mean_ablate_by_lst(circuit, model, print_output=True).item()

# Iterative forward-then-backward pruning, threshold 3

In [26]:
# Alternate forward and backward pruning passes until a pass leaves the circuit unchanged.
threshold = 3
curr_circuit = []  # an empty list makes the first pruning pass start from all heads
pass_num = 1  # named so the builtin `iter` is not shadowed for the rest of the kernel
# Convergence is detected by comparing circuits before and after each pass; the
# former `prev_score != new_score` condition never changed because prev_score
# was never updated.
while True:
    print('\nfwd prune, iter ', str(pass_num))
    old_circuit = curr_circuit.copy() # save old before finding new one
    curr_circuit, new_score = find_circuit_forw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    print('\nbackw prune, iter ', str(pass_num))
    old_circuit = curr_circuit.copy() # save old before finding new one
    curr_circuit, new_score = find_circuit_backw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    pass_num += 1


fwd prune, iter  1

Removed: (0, 0)
102.97616577148438

Removed: (0, 2)
106.65188598632812

Removed: (0, 3)
106.6739273071289

Removed: (0, 4)
110.32544708251953

Removed: (0, 5)
106.00032806396484

Removed: (0, 6)
110.14366149902344

Removed: (0, 7)
111.17089080810547

Removed: (0, 8)
112.52427673339844

Removed: (0, 9)
114.35295867919922

Removed: (0, 10)
111.60527801513672

Removed: (0, 11)
112.82598114013672

Removed: (1, 0)
113.95602416992188

Removed: (1, 1)
113.67182159423828

Removed: (1, 2)
113.4958724975586

Removed: (1, 3)
112.94862365722656

Removed: (1, 4)
112.47454071044922

Removed: (1, 6)
112.90814208984375

Removed: (1, 7)
111.85087585449219

Removed: (1, 8)
111.29533386230469

Removed: (1, 9)
109.3895034790039

Removed: (1, 10)
110.20083618164062

Removed: (1, 11)
108.8719482421875

Removed: (2, 0)
108.31248474121094

Removed: (2, 1)
108.94599151611328

Removed: (2, 2)
108.74954986572266

Removed: (2, 3)
107.06128692626953

Removed: (2, 4)
105.51287078857422

Removed

In [27]:
fb_3 = curr_circuit.copy()
len(fb_3)

26

In [28]:
fb_3

[(0, 1),
 (1, 5),
 (3, 6),
 (3, 10),
 (4, 4),
 (4, 5),
 (4, 7),
 (4, 8),
 (4, 9),
 (4, 10),
 (4, 11),
 (5, 2),
 (5, 3),
 (5, 8),
 (6, 1),
 (6, 6),
 (6, 10),
 (7, 6),
 (7, 8),
 (7, 9),
 (7, 11),
 (8, 8),
 (8, 10),
 (8, 11),
 (9, 1),
 (10, 7)]

#### Remove each head in turn to find the most important heads

In [29]:
circ = fb_3
mean_ablate_by_lst(circ, model, print_output=True).item()

Average logit difference (circuit / full) %: 97.0569


97.056884765625

In [30]:
# Leave-one-out check: drop each head from `circ` in turn and record the score
# of the remaining circuit, keyed by the removed (layer, head). Lower scores
# mean the removed head mattered more.
lh_scores = {}
for lh in circ:
    copy_circuit = circ.copy()  # keep `circ` itself intact
    copy_circuit.remove(lh)
    print("removed: " + str(lh))
    new_score = mean_ablate_by_lst(copy_circuit, model, print_output=True).item()
    lh_scores[lh] = new_score

removed: (0, 1)
Average logit difference (circuit / full) %: 81.2147
removed: (1, 5)
Average logit difference (circuit / full) %: 82.4977
removed: (3, 6)
Average logit difference (circuit / full) %: 96.1502
removed: (3, 10)
Average logit difference (circuit / full) %: 93.6352
removed: (4, 4)
Average logit difference (circuit / full) %: 63.5501
removed: (4, 5)
Average logit difference (circuit / full) %: 96.6572
removed: (4, 7)
Average logit difference (circuit / full) %: 96.0463
removed: (4, 8)
Average logit difference (circuit / full) %: 96.3436
removed: (4, 9)
Average logit difference (circuit / full) %: 96.7544
removed: (4, 10)
Average logit difference (circuit / full) %: 94.8076
removed: (4, 11)
Average logit difference (circuit / full) %: 95.4238
removed: (5, 2)
Average logit difference (circuit / full) %: 96.5898
removed: (5, 3)
Average logit difference (circuit / full) %: 96.7835
removed: (5, 8)
Average logit difference (circuit / full) %: 96.7259
removed: (6, 1)
Average logit d

In [31]:
dict(sorted(lh_scores.items(), key=lambda item: item[1]))

{(9, 1): 53.0628662109375,
 (4, 4): 63.55010986328125,
 (7, 11): 74.47904968261719,
 (10, 7): 78.7812271118164,
 (0, 1): 81.21467590332031,
 (1, 5): 82.49767303466797,
 (8, 8): 85.69908142089844,
 (6, 10): 90.48758697509766,
 (8, 11): 90.9936294555664,
 (7, 6): 93.22164916992188,
 (6, 6): 93.53330993652344,
 (3, 10): 93.63520050048828,
 (6, 1): 93.73930358886719,
 (4, 10): 94.80760955810547,
 (4, 11): 95.42378234863281,
 (4, 7): 96.04631042480469,
 (3, 6): 96.15017700195312,
 (4, 8): 96.34359741210938,
 (5, 2): 96.58976745605469,
 (4, 5): 96.6572036743164,
 (5, 8): 96.72591400146484,
 (8, 10): 96.74845123291016,
 (4, 9): 96.7543716430664,
 (7, 9): 96.78247833251953,
 (5, 3): 96.78353118896484,
 (7, 8): 96.82975769042969}

# Iterative forward-then-backward pruning, threshold 20

In [None]:
# Alternate forward and backward pruning passes until a pass leaves the circuit unchanged.
threshold = 20
curr_circuit = []  # an empty list makes the first pruning pass start from all heads
pass_num = 1  # named so the builtin `iter` is not shadowed for the rest of the kernel
# Convergence is detected by comparing circuits before and after each pass; the
# former `prev_score != new_score` condition never changed because prev_score
# was never updated.
while True:
    print('\nfwd prune, iter ', str(pass_num))
    old_circuit = curr_circuit.copy() # save old before finding new one
    curr_circuit, new_score = find_circuit_forw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    print('\nbackw prune, iter ', str(pass_num))
    old_circuit = curr_circuit.copy() # save old before finding new one
    curr_circuit, new_score = find_circuit_backw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    pass_num += 1


fwd prune, iter  1

Removed: (0, 0)
99.6773910522461

Removed: (0, 1)
87.98929595947266

Removed: (0, 2)
88.6512451171875

Removed: (0, 3)
86.90736389160156

Removed: (0, 4)
87.87069702148438

Removed: (0, 5)
85.47981262207031

Removed: (0, 6)
85.468994140625

Removed: (0, 7)
85.32069396972656

Removed: (0, 8)
84.54505920410156

Removed: (0, 9)
84.93225860595703

Removed: (0, 10)
87.507568359375

Removed: (0, 11)
87.77931213378906

Removed: (1, 0)
87.68569946289062

Removed: (1, 1)
86.57756805419922

Removed: (1, 2)
86.78600311279297

Removed: (1, 3)
86.860107421875

Removed: (1, 4)
86.6047592163086

Removed: (1, 5)
81.6650619506836

Removed: (1, 6)
81.76591491699219

Removed: (1, 7)
82.36595916748047

Removed: (1, 8)
83.85513305664062

Removed: (1, 9)
83.27083587646484

Removed: (1, 10)
83.19534301757812

Removed: (1, 11)
82.02559661865234

Removed: (2, 0)
82.97765350341797

Removed: (2, 1)
84.57157135009766

Removed: (2, 2)
83.71525573730469

Removed: (2, 3)
82.80513000488281

Remov

In [None]:
fb_20 = curr_circuit.copy()
fb_20

[(3, 2),
 (4, 4),
 (4, 8),
 (4, 10),
 (4, 11),
 (5, 5),
 (5, 6),
 (5, 7),
 (5, 8),
 (6, 1),
 (6, 7),
 (6, 9),
 (6, 10),
 (7, 0),
 (7, 2),
 (7, 5),
 (7, 6),
 (7, 7),
 (7, 8),
 (7, 10),
 (7, 11),
 (8, 0),
 (8, 1),
 (8, 6),
 (8, 8),
 (8, 11),
 (9, 1),
 (10, 2)]

In [None]:
mean_ablate_by_lst(fb_20, model, print_output=True)

Average logit difference (circuit / full) %: 80.0140


tensor(80.0140, device='cuda:0')

In [None]:
len(fb_20)

28

#### Remove each head in turn to find the most important heads

In [None]:
circ = fb_20
mean_ablate_by_lst(circ, model, print_output=True).item()

Average logit difference (circuit / full) %: 80.0140


80.01403045654297

In [None]:
# Leave-one-out check: drop each head from `circ` in turn and record the score
# of the remaining circuit, keyed by the removed (layer, head). Lower scores
# mean the removed head mattered more.
lh_scores = {}
for lh in circ:
    copy_circuit = circ.copy()  # keep `circ` itself intact
    copy_circuit.remove(lh)
    print("removed: " + str(lh))
    new_score = mean_ablate_by_lst(copy_circuit, model, print_output=True).item()
    lh_scores[lh] = new_score

removed: (3, 2)
Average logit difference (circuit / full) %: 78.5359
removed: (4, 4)
Average logit difference (circuit / full) %: 57.5247
removed: (4, 8)
Average logit difference (circuit / full) %: 79.5581
removed: (4, 10)
Average logit difference (circuit / full) %: 77.7767
removed: (4, 11)
Average logit difference (circuit / full) %: 78.8688
removed: (5, 5)
Average logit difference (circuit / full) %: 79.1987
removed: (5, 6)
Average logit difference (circuit / full) %: 74.9998
removed: (5, 7)
Average logit difference (circuit / full) %: 79.7443
removed: (5, 8)
Average logit difference (circuit / full) %: 77.3542
removed: (6, 1)
Average logit difference (circuit / full) %: 78.4455
removed: (6, 7)
Average logit difference (circuit / full) %: 79.5844
removed: (6, 9)
Average logit difference (circuit / full) %: 79.8339
removed: (6, 10)
Average logit difference (circuit / full) %: 75.6201
removed: (7, 0)
Average logit difference (circuit / full) %: 79.8278
removed: (7, 2)
Average logit d

In [None]:
dict(sorted(lh_scores.items(), key=lambda item: item[1]))

{(9, 1): 52.848934173583984,
 (4, 4): 57.52471923828125,
 (7, 11): 65.4083023071289,
 (10, 2): 67.60511016845703,
 (8, 8): 71.76107788085938,
 (8, 11): 73.92449188232422,
 (5, 6): 74.99983978271484,
 (6, 10): 75.62007904052734,
 (7, 10): 77.34323120117188,
 (5, 8): 77.35415649414062,
 (8, 6): 77.66063690185547,
 (4, 10): 77.77672576904297,
 (7, 6): 78.29383087158203,
 (6, 1): 78.44552612304688,
 (3, 2): 78.53585815429688,
 (8, 0): 78.65070343017578,
 (7, 2): 78.7663345336914,
 (8, 1): 78.78251647949219,
 (4, 11): 78.86881256103516,
 (5, 5): 79.1987075805664,
 (7, 7): 79.32286071777344,
 (4, 8): 79.5581283569336,
 (6, 7): 79.58438873291016,
 (7, 5): 79.66378021240234,
 (5, 7): 79.74430847167969,
 (7, 8): 79.74847412109375,
 (7, 0): 79.8277816772461,
 (6, 9): 79.83391571044922}

# Iterative backward-then-forward pruning, threshold 3

In [20]:
# Alternate backward and forward pruning passes until a pass leaves the circuit unchanged.
threshold = 3
curr_circuit = []  # an empty list makes the first pruning pass start from all heads
pass_num = 1  # named so the builtin `iter` is not shadowed for the rest of the kernel
# Convergence is detected by comparing circuits before and after each pass; the
# former `prev_score != new_score` condition never changed because prev_score
# was never updated.
while True:
    print('\nbackw prune, iter ', str(pass_num))
    old_circuit = curr_circuit.copy() # save old before finding new one
    curr_circuit, new_score = find_circuit_backw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    print('\nfwd prune, iter ', str(pass_num))
    old_circuit = curr_circuit.copy() # save old before finding new one
    curr_circuit, new_score = find_circuit_forw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    pass_num += 1


backw prune, iter  1

Removed: (11, 0)
98.7602310180664

Removed: (11, 1)
99.11569213867188

Removed: (11, 2)
99.78870391845703

Removed: (11, 3)
99.87152099609375

Removed: (11, 4)
100.74242401123047

Removed: (11, 5)
100.7147445678711

Removed: (11, 6)
100.79930877685547

Removed: (11, 7)
100.76506805419922

Removed: (11, 8)
100.13632202148438

Removed: (11, 9)
100.12635040283203

Removed: (11, 10)
97.10176086425781

Removed: (11, 11)
97.61882019042969

Removed: (10, 0)
97.8098373413086

Removed: (10, 1)
99.5031967163086

Removed: (10, 2)
114.83074188232422

Removed: (10, 3)
116.00149536132812

Removed: (10, 4)
115.85767364501953

Removed: (10, 5)
115.88167572021484

Removed: (10, 6)
116.03074645996094

Removed: (10, 8)
115.3823013305664

Removed: (10, 9)
115.25849914550781

Removed: (10, 10)
115.7605209350586

Removed: (10, 11)
116.23766326904297

Removed: (9, 0)
116.51197814941406

Removed: (9, 2)
116.07962799072266

Removed: (9, 3)
116.22007751464844

Removed: (9, 4)
116.05059814

In [21]:
bf_3 = curr_circuit.copy()
bf_3

[(0, 1),
 (0, 7),
 (0, 8),
 (0, 9),
 (0, 10),
 (1, 5),
 (3, 3),
 (4, 4),
 (4, 6),
 (4, 7),
 (4, 10),
 (4, 11),
 (5, 1),
 (5, 6),
 (6, 1),
 (6, 6),
 (6, 10),
 (7, 10),
 (7, 11),
 (8, 8),
 (9, 1),
 (10, 7)]

In [25]:
len(bf_3)

22

#### Loop: remove each head in turn and check which heads are most important

In [22]:
circ = bf_3
mean_ablate_by_lst(circ, model, print_output=True).item()

Average logit difference (circuit / full) %: 97.1164


97.11639404296875

In [23]:
# Leave-one-out check: for every head in `circ`, score the circuit with that
# single head taken out. The printed label suggests the score is the circuit's
# logit difference as a % of the full model's -- TODO confirm against
# mean_ablate_by_lst. A lower score means the removed head mattered more.
lh_scores = {}
for lh in circ:
    # Work on a fresh copy so `circ` itself is never modified.
    copy_circuit = list(circ)
    copy_circuit.remove(lh)
    print(f"removed: {lh}")
    new_score = mean_ablate_by_lst(copy_circuit, model, print_output=True).item()
    lh_scores[lh] = new_score

removed: (0, 1)
Average logit difference (circuit / full) %: 80.6055
removed: (0, 7)
Average logit difference (circuit / full) %: 94.9209
removed: (0, 8)
Average logit difference (circuit / full) %: 96.1987
removed: (0, 9)
Average logit difference (circuit / full) %: 93.2235
removed: (0, 10)
Average logit difference (circuit / full) %: 88.5607
removed: (1, 5)
Average logit difference (circuit / full) %: 86.5919
removed: (3, 3)
Average logit difference (circuit / full) %: 93.3704
removed: (4, 4)
Average logit difference (circuit / full) %: 65.4092
removed: (4, 6)
Average logit difference (circuit / full) %: 96.7369
removed: (4, 7)
Average logit difference (circuit / full) %: 96.3039
removed: (4, 10)
Average logit difference (circuit / full) %: 91.6029
removed: (4, 11)
Average logit difference (circuit / full) %: 96.0994
removed: (5, 1)
Average logit difference (circuit / full) %: 96.9964
removed: (5, 6)
Average logit difference (circuit / full) %: 93.5298
removed: (6, 1)
Average logit d

In [24]:
dict(sorted(lh_scores.items(), key=lambda item: item[1]))

{(9, 1): 54.06932830810547,
 (4, 4): 65.40917205810547,
 (7, 11): 73.4704360961914,
 (10, 7): 77.78365325927734,
 (0, 1): 80.60545349121094,
 (8, 8): 86.26164245605469,
 (1, 5): 86.59187316894531,
 (0, 10): 88.5606918334961,
 (6, 10): 90.8823471069336,
 (4, 10): 91.60287475585938,
 (6, 6): 92.8971939086914,
 (7, 10): 93.00191497802734,
 (0, 9): 93.2234878540039,
 (6, 1): 93.24954986572266,
 (3, 3): 93.37042999267578,
 (5, 6): 93.52978515625,
 (0, 7): 94.92088317871094,
 (4, 11): 96.09938049316406,
 (0, 8): 96.19866943359375,
 (4, 7): 96.30391693115234,
 (4, 6): 96.73689270019531,
 (5, 1): 96.99639892578125}

# compare fb and bf

In [32]:
print(len(fb_3))
print(len(bf_3))

26
22


In [33]:
set(fb_3) - set(bf_3)

{(3, 6),
 (3, 10),
 (4, 5),
 (4, 8),
 (4, 9),
 (5, 2),
 (5, 3),
 (5, 8),
 (7, 6),
 (7, 8),
 (7, 9),
 (8, 10),
 (8, 11)}

In [34]:
set(bf_3) - set(fb_3)

{(0, 7), (0, 8), (0, 9), (0, 10), (3, 3), (4, 6), (5, 1), (5, 6), (7, 10)}

# iter backw fwd, threshold 3, rand DS 2

In [74]:
threshold = 3
curr_circuit = []
new_score = 0
# Round counter; deliberately not called `iter`, which would shadow the Python
# builtin for every later cell in this kernel.
prune_round = 1
# Alternate backward and forward pruning passes until one pass leaves the
# circuit unchanged. Convergence is judged on the circuit itself: the previous
# score-based loop condition never updated its reference score, so it could
# not detect convergence (and could stop early on a score of exactly 100).
while True:
    print('\nbackw prune, iter ', str(prune_round))
    # Snapshot before pruning so the result can be compared afterwards.
    old_circuit = curr_circuit.copy()
    curr_circuit, new_score = find_circuit_backw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    print('\nfwd prune, iter ', str(prune_round))
    old_circuit = curr_circuit.copy()
    curr_circuit, new_score = find_circuit_forw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    prune_round += 1


backw prune, iter  1

Removed: (11, 0)
98.70532989501953

Removed: (11, 1)
98.90830993652344

Removed: (11, 2)
99.5697021484375

Removed: (11, 3)
99.80223083496094

Removed: (11, 4)
100.62265014648438

Removed: (11, 5)
100.61029815673828

Removed: (11, 6)
100.6954345703125

Removed: (11, 7)
100.64874267578125

Removed: (11, 8)
99.81201934814453

Removed: (11, 9)
99.73603820800781

Removed: (11, 11)
100.2791748046875

Removed: (10, 0)
100.45072174072266

Removed: (10, 1)
101.93017578125

Removed: (10, 2)
116.63346862792969

Removed: (10, 3)
117.73070526123047

Removed: (10, 4)
117.59419250488281

Removed: (10, 5)
117.6452865600586

Removed: (10, 6)
117.7850112915039

Removed: (10, 8)
117.15862274169922

Removed: (10, 9)
116.9596176147461

Removed: (10, 10)
117.43917846679688

Removed: (10, 11)
117.76687622070312

Removed: (9, 0)
117.90486907958984

Removed: (9, 2)
117.58502960205078

Removed: (9, 3)
117.49701690673828

Removed: (9, 4)
117.26477813720703

Removed: (9, 5)
111.69598388671

In [75]:
bf_3 = curr_circuit.copy()
bf_3

[(0, 1),
 (0, 7),
 (0, 9),
 (0, 10),
 (1, 5),
 (4, 4),
 (4, 10),
 (5, 4),
 (5, 6),
 (5, 8),
 (6, 1),
 (6, 6),
 (6, 10),
 (7, 10),
 (7, 11),
 (8, 8),
 (9, 1),
 (10, 7),
 (11, 10)]

In [76]:
len(bf_3)

19

#### Loop: remove each head in turn and check which heads are most important

In [77]:
circ = bf_3
mean_ablate_by_lst(circ, model, print_output=True).item()

Average logit difference (circuit / full) %: 97.4716


97.47164154052734

In [95]:
bf_3  # using i-1

[(0, 1),
 (0, 6),
 (0, 7),
 (0, 9),
 (0, 10),
 (3, 3),
 (4, 4),
 (4, 6),
 (4, 7),
 (4, 10),
 (5, 0),
 (5, 1),
 (5, 4),
 (5, 6),
 (6, 6),
 (6, 9),
 (6, 10),
 (7, 10),
 (7, 11),
 (8, 0),
 (8, 1),
 (8, 2),
 (8, 6),
 (8, 8),
 (8, 11),
 (9, 1),
 (9, 9),
 (9, 10),
 (9, 11),
 (10, 2),
 (11, 8),
 (11, 10)]

In [94]:
circ = bf_3
mean_ablate_by_lst(circ, model, print_output=True).item()

Average logit difference (circuit / full) %: 53.8096


53.80961608886719

Remove every head after 9.1 and try again.

In [97]:
bf_3 = [(0, 1), (0, 6), (0, 7), (0, 9), (0, 10), (3, 3), (4, 4), (4, 6), (4, 7), (4, 10), (5, 0), (5, 1), (5, 4), (5, 6), (6, 6), (6, 9), (6, 10), (7, 10), (7, 11), (8, 0), (8, 1), (8, 2), (8, 6), (8, 8), (8, 11), (9, 1)]
circ = bf_3
mean_ablate_by_lst(circ, model, print_output=True).item()

Average logit difference (circuit / full) %: 73.0040


73.00395202636719

In [96]:
bf_3 = [(0, 1), (0, 7), (0, 9), (0, 10), (1, 5), (4, 4), (4, 10), (5, 4), (5, 6), (5, 8), (6, 1), (6, 6), (6, 10), (7, 10), (7, 11), (8, 8), (9, 1), (10, 7), (11, 10)]
circ = bf_3
mean_ablate_by_lst(circ, model, print_output=True).item()

Average logit difference (circuit / full) %: 97.4716


97.47164154052734

Try using the circuit found on the first random dataset.

In [98]:
bf_3 = [(0, 1), (0, 7), (0, 8), (0, 9), (0, 10), (1, 5), (3, 3), (4, 4), (4, 6), (4, 7), (4, 10), (4, 11), (5, 1), (5, 6), (6, 1), (6, 6), (6, 10), (7, 10), (7, 11), (8, 8), (9, 1), (10, 7)]
circ = bf_3
mean_ablate_by_lst(circ, model, print_output=True).item()

Average logit difference (circuit / full) %: 94.3886


94.38860321044922

back to this one

In [110]:
bf_3 = [(0, 1), (0, 7), (0, 9), (0, 10), (1, 5), (4, 4), (4, 10), (5, 4), (5, 6), (5, 8), (6, 1), (6, 6), (6, 10), (7, 10), (7, 11), (8, 8), (9, 1), (10, 7), (11, 10)]
circ = bf_3
mean_ablate_by_lst(circ, model, print_output=True).item()

Average logit difference (circuit / full) %: 97.4716


97.47164154052734

In [109]:
bf_3 = [(0, 1), (0, 7), (0, 9), (0, 10), (1, 5), (4, 4), (4, 10), (5, 4), (5, 6), (5, 8), (6, 1), (6, 6), (6, 10), (7, 10), (7, 11), (8, 8), (9, 1), (10, 7)]
circ = bf_3
mean_ablate_by_lst(circ, model, print_output=True).item()

Average logit difference (circuit / full) %: 94.4813


94.4813461303711

In [78]:
# Leave-one-out check: for every head in `circ`, score the circuit with that
# single head taken out. The printed label suggests the score is the circuit's
# logit difference as a % of the full model's -- TODO confirm against
# mean_ablate_by_lst. A lower score means the removed head mattered more.
lh_scores = {}
for lh in circ:
    # Work on a fresh copy so `circ` itself is never modified.
    copy_circuit = list(circ)
    copy_circuit.remove(lh)
    print(f"removed: {lh}")
    new_score = mean_ablate_by_lst(copy_circuit, model, print_output=True).item()
    lh_scores[lh] = new_score

removed: (0, 1)
Average logit difference (circuit / full) %: 83.8326
removed: (0, 7)
Average logit difference (circuit / full) %: 94.7826
removed: (0, 9)
Average logit difference (circuit / full) %: 93.3176
removed: (0, 10)
Average logit difference (circuit / full) %: 88.6842
removed: (1, 5)
Average logit difference (circuit / full) %: 88.3025
removed: (4, 4)
Average logit difference (circuit / full) %: 68.9785
removed: (4, 10)
Average logit difference (circuit / full) %: 95.8280
removed: (5, 4)
Average logit difference (circuit / full) %: 96.5551
removed: (5, 6)
Average logit difference (circuit / full) %: 93.3702
removed: (5, 8)
Average logit difference (circuit / full) %: 95.2576
removed: (6, 1)
Average logit difference (circuit / full) %: 93.1253
removed: (6, 6)
Average logit difference (circuit / full) %: 94.1710
removed: (6, 10)
Average logit difference (circuit / full) %: 91.7072
removed: (7, 10)
Average logit difference (circuit / full) %: 94.9703
removed: (7, 11)
Average logit

In [79]:
dict(sorted(lh_scores.items(), key=lambda item: item[1]))

{(9, 1): 50.99702835083008,
 (4, 4): 68.97853088378906,
 (7, 11): 77.05406951904297,
 (10, 7): 78.75873565673828,
 (0, 1): 83.83258819580078,
 (8, 8): 85.85089874267578,
 (1, 5): 88.30247497558594,
 (0, 10): 88.68418884277344,
 (6, 10): 91.70716094970703,
 (6, 1): 93.12532043457031,
 (0, 9): 93.31764221191406,
 (5, 6): 93.3702392578125,
 (6, 6): 94.17102813720703,
 (11, 10): 94.4813461303711,
 (0, 7): 94.7826156616211,
 (7, 10): 94.97032165527344,
 (5, 8): 95.25759887695312,
 (4, 10): 95.82804107666016,
 (5, 4): 96.55508422851562}

## try other tasks circs

### gt, IOI

In [99]:
# greater-than
circuit = [(0, 1), (0, 3), (0, 5), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11), (9, 1)]
mean_ablate_by_lst(circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: 15.0553


15.055303573608398

In [100]:
# IOI
circuit = [(0, 1), (0, 10), (2, 2), (3, 0), (4, 11), (5, 5), (5, 8), (5, 9), (6, 9), (7, 3), (7, 9), (8, 6), (8, 10), (9, 0), (9, 6), (9, 7), (9, 9), (10, 0), (10, 1), (10, 2), (10, 6), (10, 7), (10, 10), (11, 2), (11, 9), (11, 10)]
mean_ablate_by_lst(circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: -0.1438


-0.1438295692205429

### fb 80

In [101]:
# fb 80, digits incr
# https://colab.research.google.com/drive/1mFWmGAKtigFcqqWWMCwU7wWQY2HT5ZOo#scrollTo=lJEY-Zs2g_a5&line=1&uniqifier=1
circuit = [(1, 5), (3, 3), (3, 7), (3, 10), (3, 11), (4, 4), (4, 6), (4, 7), (4, 10), (4, 11), (5, 0), (5, 2), (5, 3), (5, 4), (5, 6), (6, 3), (6, 8), (6, 10), (7, 0), (7, 2), (7, 7), (7, 8), (7, 10), (7, 11), (8, 0), (8, 1), (8, 6), (8, 8), (8, 9), (8, 11), (9, 1), (11, 8)]
mean_ablate_by_lst(circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: 58.1310


58.1309814453125

In [102]:
# fb 80, numwords
# https://colab.research.google.com/drive/1QTv-4osLHadCAay0beew-xlXszPCG88s#scrollTo=563kZf_4r_mw&line=2&uniqifier=1
circuit = [(3, 2), (4, 4), (4, 8), (4, 10), (4, 11), (5, 5), (5, 6), (5, 7), (5, 8), (6, 1), (6, 7), (6, 9), (6, 10), (7, 0), (7, 2), (7, 5), (7, 6), (7, 7), (7, 8), (7, 10), (7, 11), (8, 0), (8, 1), (8, 6), (8, 8), (8, 11), (9, 1), (10, 2)]
mean_ablate_by_lst(circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: 38.1503


38.15029525756836

In [103]:
# fb 80, months
# https://colab.research.google.com/drive/1lhQqlizYGMC11vzp6I9mJ3dyxIr8tV3l#scrollTo=563kZf_4r_mw&line=2&uniqifier=1
circuit = [(4, 4), (7, 11), (8, 6), (8, 9), (8, 11), (9, 1), (9, 5), (11, 10)]
mean_ablate_by_lst(circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: 37.6235


37.6235237121582

### bf 97

In [104]:
# digits incr
circuit = [(0, 1), (0, 2), (0, 3), (0, 5), (0, 8), (0, 9), (0, 10), (1, 0), (1, 5), (1, 7), (2, 0), (2, 2), (2, 3), (2, 4), (2, 6), (2, 9), (3, 3), (3, 7), (3, 11), (4, 4), (4, 7), (4, 10), (5, 2), (5, 3), (5, 4), (5, 6), (5, 8), (5, 11), (6, 1), (6, 3), (6, 6), (6, 8), (6, 10), (7, 2), (7, 6), (7, 8), (7, 10), (7, 11), (8, 6), (8, 8), (8, 11), (9, 1), (11, 10)]
mean_ablate_by_lst(circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: 87.6954


87.69536590576172

In [107]:
# numwords
# https://colab.research.google.com/drive/1QTv-4osLHadCAay0beew-xlXszPCG88s#scrollTo=563kZf_4r_mw&line=2&uniqifier=1
circuit = bf_3
mean_ablate_by_lst(circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: 94.3886


94.38860321044922

In [106]:
# months
# https://colab.research.google.com/drive/1lhQqlizYGMC11vzp6I9mJ3dyxIr8tV3l#scrollTo=563kZf_4r_mw&line=2&uniqifier=1
# circuit = [(4, 4), (7, 11), (8, 6), (8, 9), (8, 11), (9, 1), (9, 5), (11, 10)]
# mean_ablate_by_lst(circuit, model, print_output=True).item()

# Measure clean logit diff between i+4 (correct token) and i-1 (incorrect token), instead of using i+3 as the incorrect token

In [80]:
def generate_prompts_list(x ,y):
    """Build one prompt dict for each start index in ``range(x, y)``.

    The prompts are runs of four consecutive number words taken from a fixed
    list of 20 words (' one' .. ' twenty'; each carries a leading space).

    Each dict contains:
      * 'S1'..'S4': the four consecutive words starting at the index,
      * 'corr':     the word that follows them (index + 4),
      * 'incorr':   the word just before the run (index - 1). The choice is
                    arbitrary; for index 0 Python's negative indexing makes
                    this the last word, ' twenty',
      * 'text':     'S1'..'S4' concatenated.

    Raises IndexError when a start index needs a word beyond the end of the
    list (for non-negative starts: when y > 16).
    """
    words = [' one', ' two', ' three', ' four', ' five', ' six', ' seven', ' eight', ' nine', ' ten', ' eleven', ' twelve', ' thirteen', ' fourteen', ' fifteen', ' sixteen', ' seventeen', ' eighteen', ' nineteen', ' twenty']

    def build_prompt(start):
        # The four visible sequence members, in order.
        seq = [words[start + offset] for offset in range(4)]
        return {
            'S1': seq[0],
            'S2': seq[1],
            'S3': seq[2],
            'S4': seq[3],
            'corr': words[start + 4],
            'incorr': words[start - 1],  # this is arbitrary
            'text': "".join(seq),
        }

    return [build_prompt(start) for start in range(x, y)]

prompts_list = generate_prompts_list(0, 16)
dataset = Dataset(prompts_list, pos_dict, model.tokenizer, S1_is_first=True)

## iter backw fwd, threshold 3, rand DS 2

In [81]:
threshold = 3
curr_circuit = []
new_score = 0
# Round counter; deliberately not called `iter`, which would shadow the Python
# builtin for every later cell in this kernel.
prune_round = 1
# Alternate backward and forward pruning passes until one pass leaves the
# circuit unchanged. Convergence is judged on the circuit itself: the previous
# score-based loop condition never updated its reference score, so it could
# not detect convergence (and could stop early on a score of exactly 100).
while True:
    print('\nbackw prune, iter ', str(prune_round))
    # Snapshot before pruning so the result can be compared afterwards.
    old_circuit = curr_circuit.copy()
    curr_circuit, new_score = find_circuit_backw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    print('\nfwd prune, iter ', str(prune_round))
    old_circuit = curr_circuit.copy()
    curr_circuit, new_score = find_circuit_forw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    prune_round += 1


backw prune, iter  1

Removed: (11, 0)
99.36969757080078

Removed: (11, 1)
98.79235076904297

Removed: (11, 2)
98.73162841796875

Removed: (11, 3)
98.41702270507812

Removed: (11, 4)
97.87065887451172

Removed: (11, 5)
98.01094055175781

Removed: (11, 6)
98.01509857177734

Removed: (11, 7)
98.06578826904297

Removed: (11, 9)
97.89273834228516

Removed: (11, 11)
97.76563262939453

Removed: (10, 0)
97.75312042236328

Removed: (10, 1)
97.57929992675781

Removed: (10, 3)
97.40526580810547

Removed: (10, 4)
97.37287139892578

Removed: (10, 5)
97.47503662109375

Removed: (10, 6)
97.48052215576172

Removed: (10, 7)
99.69751739501953

Removed: (10, 8)
99.37645721435547

Removed: (10, 9)
99.1032485961914

Removed: (10, 10)
99.2474365234375

Removed: (10, 11)
99.3924789428711

Removed: (9, 0)
99.38703155517578

Removed: (9, 2)
99.03852081298828

Removed: (9, 3)
98.57138061523438

Removed: (9, 4)
98.59954071044922

Removed: (9, 5)
97.41523742675781

Removed: (9, 6)
97.32701873779297

Removed: (9

In [82]:
bf_3 = curr_circuit.copy()
bf_3

[(0, 1),
 (0, 6),
 (0, 7),
 (0, 9),
 (0, 10),
 (3, 3),
 (4, 4),
 (4, 6),
 (4, 7),
 (4, 10),
 (5, 0),
 (5, 1),
 (5, 4),
 (5, 6),
 (6, 6),
 (6, 9),
 (6, 10),
 (7, 10),
 (7, 11),
 (8, 0),
 (8, 1),
 (8, 2),
 (8, 6),
 (8, 8),
 (8, 11),
 (9, 1),
 (9, 9),
 (9, 10),
 (9, 11),
 (10, 2),
 (11, 8),
 (11, 10)]

In [83]:
len(bf_3)

32

#### Loop: remove each head in turn and check which heads are most important

In [84]:
circ = bf_3
mean_ablate_by_lst(circ, model, print_output=True).item()

Average logit difference (circuit / full) %: 97.0527


97.05271911621094

In [85]:
# Leave-one-out check: for every head in `circ`, score the circuit with that
# single head taken out. The printed label suggests the score is the circuit's
# logit difference as a % of the full model's -- TODO confirm against
# mean_ablate_by_lst. A lower score means the removed head mattered more.
lh_scores = {}
for lh in circ:
    # Work on a fresh copy so `circ` itself is never modified.
    copy_circuit = list(circ)
    copy_circuit.remove(lh)
    print(f"removed: {lh}")
    new_score = mean_ablate_by_lst(copy_circuit, model, print_output=True).item()
    lh_scores[lh] = new_score

removed: (0, 1)
Average logit difference (circuit / full) %: 86.5970
removed: (0, 6)
Average logit difference (circuit / full) %: 93.3924
removed: (0, 7)
Average logit difference (circuit / full) %: 94.6353
removed: (0, 9)
Average logit difference (circuit / full) %: 93.6825
removed: (0, 10)
Average logit difference (circuit / full) %: 94.9744
removed: (3, 3)
Average logit difference (circuit / full) %: 94.2974
removed: (4, 4)
Average logit difference (circuit / full) %: 71.6683
removed: (4, 6)
Average logit difference (circuit / full) %: 96.3988
removed: (4, 7)
Average logit difference (circuit / full) %: 95.8448
removed: (4, 10)
Average logit difference (circuit / full) %: 94.4813
removed: (5, 0)
Average logit difference (circuit / full) %: 95.5036
removed: (5, 1)
Average logit difference (circuit / full) %: 96.8617
removed: (5, 4)
Average logit difference (circuit / full) %: 95.9291
removed: (5, 6)
Average logit difference (circuit / full) %: 91.7710
removed: (6, 6)
Average logit di

In [86]:
dict(sorted(lh_scores.items(), key=lambda item: item[1]))

{(9, 1): 65.91766357421875,
 (4, 4): 71.6683349609375,
 (7, 11): 84.63774108886719,
 (10, 2): 85.13801574707031,
 (0, 1): 86.59697723388672,
 (8, 11): 87.88027954101562,
 (8, 8): 89.4716567993164,
 (5, 6): 91.7709732055664,
 (6, 10): 92.83840942382812,
 (7, 10): 93.17102813720703,
 (0, 6): 93.3924331665039,
 (0, 9): 93.68250274658203,
 (8, 6): 93.90229034423828,
 (6, 9): 94.00465393066406,
 (3, 3): 94.29740142822266,
 (4, 10): 94.48126220703125,
 (0, 7): 94.63526916503906,
 (0, 10): 94.97438049316406,
 (11, 8): 95.22659301757812,
 (5, 0): 95.50364685058594,
 (8, 1): 95.56914520263672,
 (4, 7): 95.84475708007812,
 (5, 4): 95.9291000366211,
 (6, 6): 95.97230529785156,
 (11, 10): 96.06417846679688,
 (8, 2): 96.12086486816406,
 (8, 0): 96.23420715332031,
 (4, 6): 96.39884185791016,
 (9, 9): 96.40127563476562,
 (9, 11): 96.47161102294922,
 (9, 10): 96.82534790039062,
 (5, 1): 96.86174011230469}

# test prompts

In [49]:
modeltest = HookedTransformer.from_pretrained("gpt2")

Loaded pretrained model gpt2 into HookedTransformer


In [50]:
example_prompt = " six seven eight nine"
example_answer = " ten"
utils.test_prompt(example_prompt, example_answer, modeltest, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', ' six', ' seven', ' eight', ' nine']
Tokenized answer: [' ten']


Top 0th token. Logit: 15.01 Prob: 32.39% Token: | ten|
Top 1th token. Logit: 14.25 Prob: 15.09% Token: | 10|
Top 2th token. Logit: 13.41 Prob:  6.57% Token: | nine|
Top 3th token. Logit: 13.36 Prob:  6.25% Token: | eight|
Top 4th token. Logit: 12.34 Prob:  2.24% Token: | seven|
Top 5th token. Logit: 12.32 Prob:  2.20% Token: | twelve|
Top 6th token. Logit: 12.25 Prob:  2.06% Token: | five|
Top 7th token. Logit: 12.15 Prob:  1.85% Token: | six|
Top 8th token. Logit: 11.93 Prob:  1.49% Token: | 12|
Top 9th token. Logit: 11.92 Prob:  1.48% Token: |
|


In [51]:
example_prompt = " thirteen fourteen fifteen sixteen"
example_answer = " seventeen"
utils.test_prompt(example_prompt, example_answer, modeltest, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', ' thirteen', ' fourteen', ' fifteen', ' sixteen']
Tokenized answer: [' seventeen']


Top 0th token. Logit: 16.13 Prob: 39.38% Token: | seventeen|
Top 1th token. Logit: 15.03 Prob: 13.03% Token: | sixteen|
Top 2th token. Logit: 14.70 Prob:  9.42% Token: | eighteen|
Top 3th token. Logit: 14.61 Prob:  8.55% Token: | twenty|
Top 4th token. Logit: 13.89 Prob:  4.19% Token: | nineteen|
Top 5th token. Logit: 13.30 Prob:  2.33% Token: | thirteen|
Top 6th token. Logit: 13.04 Prob:  1.78% Token: | fifty|
Top 7th token. Logit: 12.47 Prob:  1.01% Token: | thirty|
Top 8th token. Logit: 12.43 Prob:  0.97% Token: | fourteen|
Top 9th token. Logit: 12.42 Prob:  0.96% Token: | seventy|


In [52]:
example_prompt = " nineteen two seventeen seventeen"
example_answer = " seventeen"
utils.test_prompt(example_prompt, example_answer, modeltest, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', ' nineteen', ' two', ' seventeen', ' seventeen']
Tokenized answer: [' seventeen']


Top 0th token. Logit: 15.35 Prob: 19.37% Token: | eighteen|
Top 1th token. Logit: 15.18 Prob: 16.27% Token: | sixteen|
Top 2th token. Logit: 14.94 Prob: 12.87% Token: | twenty|
Top 3th token. Logit: 14.20 Prob:  6.14% Token: | seventeen|
Top 4th token. Logit: 13.58 Prob:  3.30% Token: | fourteen|
Top 5th token. Logit: 13.36 Prob:  2.64% Token: | thirteen|
Top 6th token. Logit: 13.36 Prob:  2.64% Token: | eight|
Top 7th token. Logit: 13.27 Prob:  2.42% Token: | twelve|
Top 8th token. Logit: 13.26 Prob:  2.39% Token: | nineteen|
Top 9th token. Logit: 13.20 Prob:  2.25% Token: | fifteen|


In [53]:
example_prompt = " seventeen seventeen seventeen seventeen"
example_answer = " seventeen"
utils.test_prompt(example_prompt, example_answer, modeltest, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', ' seventeen', ' seventeen', ' seventeen', ' seventeen']
Tokenized answer: [' seventeen']


Top 0th token. Logit: 15.69 Prob: 43.89% Token: | seventeen|
Top 1th token. Logit: 14.03 Prob:  8.35% Token: | sixteen|
Top 2th token. Logit: 13.97 Prob:  7.88% Token: | eighteen|
Top 3th token. Logit: 13.83 Prob:  6.82% Token: | thirteen|
Top 4th token. Logit: 12.83 Prob:  2.52% Token: | fifteen|
Top 5th token. Logit: 12.81 Prob:  2.48% Token: | 17|
Top 6th token. Logit: 12.72 Prob:  2.26% Token: | twenty|
Top 7th token. Logit: 12.70 Prob:  2.20% Token: | nineteen|
Top 8th token. Logit: 12.26 Prob:  1.42% Token: |
|
Top 9th token. Logit: 12.21 Prob:  1.35% Token: | fourteen|


In [55]:
model.tokenizer('seventeen')

{'input_ids': [325, 1151, 6429], 'attention_mask': [1, 1, 1]}

In [56]:
model.tokenizer(' seventeen')

{'input_ids': [38741], 'attention_mask': [1]}

In [None]:

model.tokenizer.decode([1440])

# repeat new

Replace the sequence with a repeating random number (in the last two positions) and predict that number as the ‘incorr’ token.

To erase information about all numbers: given the last token T, use T+2 repeated in the last two positions (mod len(words) for numwords or months). Then either randomize the first two positions, use T+2 again, or use T+4 and T+7 (giving T+4, T+7, T+2, T+2) to reduce randomness. If this doesn’t make a difference, don’t use it.

In [44]:
words = [' one', ' two', ' three', ' four', ' five', ' six', ' seven', ' eight', ' nine', ' ten', ' eleven', ' twelve', ' thirteen', ' fourteen', ' fifteen', ' sixteen', ' seventeen', ' eighteen', ' nineteen', ' twenty']
(18+4)%len(words)

2

In [45]:
words[18]

' nineteen'

In [46]:
words[(18+4)%len(words)]

' three'

In [57]:
import random

def generate_prompts_list_corr(x ,y):
    """Build one repeated-word prompt dict for each index in ``range(x, y)``.

    For index i, the word at position (i - 1) modulo the list length is
    repeated four times. Words come from a fixed list of 20 number words
    (' one' .. ' twenty'; each carries a leading space).

    Each dict contains:
      * 'S1'..'S4': the same repeated word,
      * 'corr':     that repeated word,
      * 'incorr':   words[i + 4],
      * 'text':     the repeated word concatenated four times.

    Raises IndexError when i + 4 is past the end of the word list.
    """
    words = [' one', ' two', ' three', ' four', ' five', ' six', ' seven', ' eight', ' nine', ' ten', ' eleven', ' twelve', ' thirteen', ' fourteen', ' fifteen', ' sixteen', ' seventeen', ' eighteen', ' nineteen', ' twenty']
    n_words = len(words)
    prompts_list = []
    for i in range(x, y):
        # An earlier (disabled) variant took T = i + 3 and used the words at
        # (T+4), (T+7), (T+2), (T+2) modulo the list length instead.
        repeated = words[(i - 1) % n_words]
        prompts_list.append({
            'S1': repeated,
            'S2': repeated,
            'S3': repeated,
            'S4': repeated,
            'corr': repeated,
            'incorr': words[i + 4],
            'text': repeated * 4,
        })
    return prompts_list

# prompts_list_2 = generate_prompts_list_corr(0, 8)
prompts_list_2 = generate_prompts_list_corr(0, 16)
dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer, S1_is_first=True)
prompts_list_2

[{'S1': ' twenty',
  'S2': ' twenty',
  'S3': ' twenty',
  'S4': ' twenty',
  'corr': ' twenty',
  'incorr': ' five',
  'text': ' twenty twenty twenty twenty'},
 {'S1': ' one',
  'S2': ' one',
  'S3': ' one',
  'S4': ' one',
  'corr': ' one',
  'incorr': ' six',
  'text': ' one one one one'},
 {'S1': ' two',
  'S2': ' two',
  'S3': ' two',
  'S4': ' two',
  'corr': ' two',
  'incorr': ' seven',
  'text': ' two two two two'},
 {'S1': ' three',
  'S2': ' three',
  'S3': ' three',
  'S4': ' three',
  'corr': ' three',
  'incorr': ' eight',
  'text': ' three three three three'},
 {'S1': ' four',
  'S2': ' four',
  'S3': ' four',
  'S4': ' four',
  'corr': ' four',
  'incorr': ' nine',
  'text': ' four four four four'},
 {'S1': ' five',
  'S2': ' five',
  'S3': ' five',
  'S4': ' five',
  'corr': ' five',
  'incorr': ' ten',
  'text': ' five five five five'},
 {'S1': ' six',
  'S2': ' six',
  'S3': ' six',
  'S4': ' six',
  'corr': ' six',
  'incorr': ' eleven',
  'text': ' six six six six'

In [58]:
# def generate_prompts_list(x ,y, prompts_list_2):
def generate_prompts_list(x ,y):
    """Return prompt dicts for runs of four consecutive number words.

    One dict is produced per start index i in ``range(x, y)``, drawing on a
    fixed list of 20 number words (' one' .. ' twenty'; leading space kept).
    Keys: 'S1'..'S4' are words[i] .. words[i+3], 'corr' is words[i+4],
    'incorr' is words[i-1] (for i == 0 negative indexing yields ' twenty'),
    and 'text' is 'S1'..'S4' joined together.

    Raises IndexError when a needed index is past the end of the list.

    NOTE(review): this has the same body as the earlier definition of
    generate_prompts_list in this notebook; the later one wins.
    """
    words = [' one', ' two', ' three', ' four', ' five', ' six', ' seven', ' eight', ' nine', ' ten', ' eleven', ' twelve', ' thirteen', ' fourteen', ' fifteen', ' sixteen', ' seventeen', ' eighteen', ' nineteen', ' twenty']
    prompts_list = []
    for i in range(x, y):
        first, second, third, fourth = (words[i + step] for step in range(4))
        target = words[i + 4]
        # Disabled alternative: take 'incorr' from prompts_list_2[i]['corr'].
        distractor = words[i - 1]
        prompts_list.append({
            'S1': first,
            'S2': second,
            'S3': third,
            'S4': fourth,
            'corr': target,
            'incorr': distractor,
            'text': first + second + third + fourth,
        })
    return prompts_list

# prompts_list = generate_prompts_list(0, 16, prompts_list_2)
prompts_list = generate_prompts_list(0, 16)
dataset = Dataset(prompts_list, pos_dict, model.tokenizer, S1_is_first=True)
prompts_list

[{'S1': ' one',
  'S2': ' two',
  'S3': ' three',
  'S4': ' four',
  'corr': ' five',
  'incorr': ' twenty',
  'text': ' one two three four'},
 {'S1': ' two',
  'S2': ' three',
  'S3': ' four',
  'S4': ' five',
  'corr': ' six',
  'incorr': ' one',
  'text': ' two three four five'},
 {'S1': ' three',
  'S2': ' four',
  'S3': ' five',
  'S4': ' six',
  'corr': ' seven',
  'incorr': ' two',
  'text': ' three four five six'},
 {'S1': ' four',
  'S2': ' five',
  'S3': ' six',
  'S4': ' seven',
  'corr': ' eight',
  'incorr': ' three',
  'text': ' four five six seven'},
 {'S1': ' five',
  'S2': ' six',
  'S3': ' seven',
  'S4': ' eight',
  'corr': ' nine',
  'incorr': ' four',
  'text': ' five six seven eight'},
 {'S1': ' six',
  'S2': ' seven',
  'S3': ' eight',
  'S4': ' nine',
  'corr': ' ten',
  'incorr': ' five',
  'text': ' six seven eight nine'},
 {'S1': ' seven',
  'S2': ' eight',
  'S3': ' nine',
  'S4': ' ten',
  'corr': ' eleven',
  'incorr': ' six',
  'text': ' seven eight nine 

## iter backw fwd, threshold 3

In [59]:
threshold = 3
curr_circuit = []
new_score = 0
# Round counter; deliberately not called `iter`, which would shadow the Python
# builtin for every later cell in this kernel.
prune_round = 1
# Alternate backward and forward pruning passes until one pass leaves the
# circuit unchanged. Convergence is judged on the circuit itself: the previous
# score-based loop condition never updated its reference score, so it could
# not detect convergence (and could stop early on a score of exactly 100).
while True:
    print('\nbackw prune, iter ', str(prune_round))
    # Snapshot before pruning so the result can be compared afterwards.
    old_circuit = curr_circuit.copy()
    curr_circuit, new_score = find_circuit_backw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    print('\nfwd prune, iter ', str(prune_round))
    old_circuit = curr_circuit.copy()
    curr_circuit, new_score = find_circuit_forw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    prune_round += 1


backw prune, iter  1

Removed: (11, 0)
100.07949829101562

Removed: (11, 1)
99.67402648925781

Removed: (11, 2)
99.62825775146484

Removed: (11, 3)
99.40247344970703

Removed: (11, 4)
98.77851104736328

Removed: (11, 5)
98.82231140136719

Removed: (11, 6)
98.7921371459961

Removed: (11, 7)
98.77710723876953

Removed: (11, 8)
97.26992797851562

Removed: (11, 9)
97.09835815429688

Removed: (11, 11)
97.01774597167969

Removed: (10, 4)
97.24371337890625

Removed: (10, 5)
97.09815216064453

Removed: (10, 6)
97.12135314941406

Removed: (10, 7)
98.55481719970703

Removed: (10, 8)
98.4288330078125

Removed: (10, 9)
98.31587219238281

Removed: (10, 10)
98.59809875488281

Removed: (10, 11)
98.8136978149414

Removed: (9, 0)
98.71856689453125

Removed: (9, 2)
98.44058990478516

Removed: (9, 3)
99.04450225830078

Removed: (9, 4)
99.18749237060547

Removed: (9, 5)
97.87513732910156

Removed: (9, 6)
97.79296112060547

Removed: (9, 7)
98.1976089477539

Removed: (9, 8)
98.17428588867188

Removed: (9, 

In [60]:
bf_3 = curr_circuit.copy()
bf_3

[(0, 1),
 (3, 0),
 (4, 4),
 (5, 5),
 (6, 1),
 (6, 7),
 (6, 8),
 (6, 9),
 (6, 10),
 (7, 6),
 (7, 10),
 (7, 11),
 (8, 1),
 (8, 6),
 (8, 8),
 (8, 11),
 (9, 1),
 (10, 1),
 (10, 2),
 (10, 3),
 (11, 10)]

In [61]:
len(bf_3)

21

#### Loop: remove each head in turn and check which heads are most important

In [62]:
circ = bf_3
mean_ablate_by_lst(circ, model, print_output=True).item()

Average logit difference (circuit / full) %: 97.1206


97.1206283569336

In [63]:
# Leave-one-out check: for every head in `circ`, score the circuit with that
# single head taken out. The printed label suggests the score is the circuit's
# logit difference as a % of the full model's -- TODO confirm against
# mean_ablate_by_lst. A lower score means the removed head mattered more.
lh_scores = {}
for lh in circ:
    # Work on a fresh copy so `circ` itself is never modified.
    copy_circuit = list(circ)
    copy_circuit.remove(lh)
    print(f"removed: {lh}")
    new_score = mean_ablate_by_lst(copy_circuit, model, print_output=True).item()
    lh_scores[lh] = new_score

removed: (0, 1)
Average logit difference (circuit / full) %: 82.3755
removed: (3, 0)
Average logit difference (circuit / full) %: 90.3189
removed: (4, 4)
Average logit difference (circuit / full) %: 62.6386
removed: (5, 5)
Average logit difference (circuit / full) %: 81.3340
removed: (6, 1)
Average logit difference (circuit / full) %: 93.4345
removed: (6, 7)
Average logit difference (circuit / full) %: 91.6200
removed: (6, 8)
Average logit difference (circuit / full) %: 96.0655
removed: (6, 9)
Average logit difference (circuit / full) %: 96.4013
removed: (6, 10)
Average logit difference (circuit / full) %: 83.8460
removed: (7, 6)
Average logit difference (circuit / full) %: 91.4450
removed: (7, 10)
Average logit difference (circuit / full) %: 90.4626
removed: (7, 11)
Average logit difference (circuit / full) %: 81.7064
removed: (8, 1)
Average logit difference (circuit / full) %: 94.8787
removed: (8, 6)
Average logit difference (circuit / full) %: 94.5803
removed: (8, 8)
Average logit d

In [64]:
dict(sorted(lh_scores.items(), key=lambda item: item[1]))

{(4, 4): 62.63862609863281,
 (9, 1): 64.31694030761719,
 (5, 5): 81.33404541015625,
 (7, 11): 81.70638275146484,
 (0, 1): 82.37551879882812,
 (6, 10): 83.84600830078125,
 (10, 2): 83.94091796875,
 (8, 11): 84.35529327392578,
 (8, 8): 87.6913070678711,
 (3, 0): 90.31892395019531,
 (7, 10): 90.46261596679688,
 (7, 6): 91.44503784179688,
 (6, 7): 91.62000274658203,
 (6, 1): 93.43449401855469,
 (8, 6): 94.58032989501953,
 (8, 1): 94.8786849975586,
 (11, 10): 95.69868469238281,
 (10, 3): 95.95584106445312,
 (6, 8): 96.06547546386719,
 (10, 1): 96.08500671386719,
 (6, 9): 96.40132904052734}

## try other tasks circs

### gt, IOI

In [None]:
# greater-than
circuit = [(0, 1), (0, 3), (0, 5), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11), (9, 1)]
mean_ablate_by_lst(circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: 20.8534


20.853376388549805

In [None]:
# IOI
circuit = [(0, 1), (0, 10), (2, 2), (3, 0), (4, 11), (5, 5), (5, 8), (5, 9), (6, 9), (7, 3), (7, 9), (8, 6), (8, 10), (9, 0), (9, 6), (9, 7), (9, 9), (10, 0), (10, 1), (10, 2), (10, 6), (10, 7), (10, 10), (11, 2), (11, 9), (11, 10)]
mean_ablate_by_lst(circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: 12.2694


12.269357681274414

### fb 80

In [87]:
# fb 80, digits incr
# https://colab.research.google.com/drive/1mFWmGAKtigFcqqWWMCwU7wWQY2HT5ZOo#scrollTo=lJEY-Zs2g_a5&line=1&uniqifier=1
circuit = [(1, 5), (3, 3), (3, 7), (3, 10), (3, 11), (4, 4), (4, 6), (4, 7), (4, 10), (4, 11), (5, 0), (5, 2), (5, 3), (5, 4), (5, 6), (6, 3), (6, 8), (6, 10), (7, 0), (7, 2), (7, 7), (7, 8), (7, 10), (7, 11), (8, 0), (8, 1), (8, 6), (8, 8), (8, 9), (8, 11), (9, 1), (11, 8)]
mean_ablate_by_lst(circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: 64.4987


64.4986572265625

In [None]:
# fb 80, numwords
# https://colab.research.google.com/drive/1QTv-4osLHadCAay0beew-xlXszPCG88s#scrollTo=563kZf_4r_mw&line=2&uniqifier=1
circuit = [(3, 2), (4, 4), (4, 8), (4, 10), (4, 11), (5, 5), (5, 6), (5, 7), (5, 8), (6, 1), (6, 7), (6, 9), (6, 10), (7, 0), (7, 2), (7, 5), (7, 6), (7, 7), (7, 8), (7, 10), (7, 11), (8, 0), (8, 1), (8, 6), (8, 8), (8, 11), (9, 1), (10, 2)]
mean_ablate_by_lst(circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: 81.9778


81.9777603149414

In [None]:
# fb 80, months
# https://colab.research.google.com/drive/1lhQqlizYGMC11vzp6I9mJ3dyxIr8tV3l#scrollTo=563kZf_4r_mw&line=2&uniqifier=1
circuit = [(4, 4), (7, 11), (8, 6), (8, 9), (8, 11), (9, 1), (9, 5), (11, 10)]
mean_ablate_by_lst(circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: 36.0939


36.09392547607422

### bf 97

In [65]:
# digits incr
circuit = [(0, 1), (0, 2), (0, 3), (0, 5), (0, 8), (0, 9), (0, 10), (1, 0), (1, 5), (1, 7), (2, 0), (2, 2), (2, 3), (2, 4), (2, 6), (2, 9), (3, 3), (3, 7), (3, 11), (4, 4), (4, 7), (4, 10), (5, 2), (5, 3), (5, 4), (5, 6), (5, 8), (5, 11), (6, 1), (6, 3), (6, 6), (6, 8), (6, 10), (7, 2), (7, 6), (7, 8), (7, 10), (7, 11), (8, 6), (8, 8), (8, 11), (9, 1), (11, 10)]
mean_ablate_by_lst(circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: 71.6404


71.64041137695312

In [66]:
# numwords
# https://colab.research.google.com/drive/1QTv-4osLHadCAay0beew-xlXszPCG88s#scrollTo=563kZf_4r_mw&line=2&uniqifier=1
circuit = [(0, 1), (1, 0), (1, 5), (3, 2), (4, 4), (4, 8), (4, 10), (5, 4), (5, 6), (5, 8), (6, 9), (6, 10), (7, 7), (7, 8), (7, 10), (7, 11), (8, 0), (8, 1), (8, 6), (8, 8), (8, 11), (9, 1), (9, 5), (10, 2), (11, 8)]
mean_ablate_by_lst(circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: 86.4105


86.41048431396484

In [None]:
# months
# https://colab.research.google.com/drive/1lhQqlizYGMC11vzp6I9mJ3dyxIr8tV3l#scrollTo=563kZf_4r_mw&line=2&uniqifier=1
# circuit = [(4, 4), (7, 11), (8, 6), (8, 9), (8, 11), (9, 1), (9, 5), (11, 10)]
# mean_ablate_by_lst(circuit, model, print_output=True).item()