<a href="https://colab.research.google.com/github/wlg100/numseqcont_circuit_expms/blob/main/notebook_templates/minimal_circuit_template.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" align="left"/></a>&nbsp;or in a local notebook.

# Setup
(No need to change anything)

In [1]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    # Install another version of node that makes PySvelte work way faster
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    %pip install git+https://github.com/neelnanda-io/PySvelte.git
except ImportError:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Automatically reload the HookedTransformer code as it's edited, without restarting the kernel
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

Running as a Colab notebook
Collecting git+https://github.com/neelnanda-io/TransformerLens.git
  Cloning https://github.com/neelnanda-io/TransformerLens.git to /tmp/pip-req-build-5niggyxn
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/TransformerLens.git /tmp/pip-req-build-5niggyxn
  Resolved https://github.com/neelnanda-io/TransformerLens.git to commit fa287750606075574df2c538058e67d648e2f952
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Collecting accelerate>=0.23.0 (from transformer-lens==0.0.0)
  Downloading accelerate-0.24.0-py3-none-any.whl (260 kB)
Collecting beartype<0.15.0,>=0.14.1 (from transformer-lens==0.0.0)
  Downloading beartype-0.14.1-py3-none-any.whl (739 kB)

In [2]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

if IN_COLAB or not DEBUG_MODE:
    # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "png"

In [3]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [4]:
# import pysvelte

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

We turn automatic differentiation off, to save GPU memory, as this notebook focuses on model inference not model training.

In [5]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x79ff04cb3850>

Plotting helper functions:

In [6]:
def imshow(tensor, renderer=None, **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", **kwargs).show(renderer)

def line(tensor, renderer=None, **kwargs):
    px.line(y=utils.to_numpy(tensor), **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

## Load Model

Decide which model to use (e.g., gpt2-small vs. gpt2-medium).

In [7]:
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)

Loaded pretrained model gpt2-small into HookedTransformer


## Import functions from repo

In [8]:
!git clone https://github.com/callummcdougall/ARENA_2.0.git

Cloning into 'ARENA_2.0'...
remote: Enumerating objects: 9100, done.
remote: Counting objects: 100% (1814/1814), done.
remote: Compressing objects: 100% (288/288), done.
remote: Total 9100 (delta 1609), reused 1602 (delta 1523), pack-reused 7286
Receiving objects: 100% (9100/9100), 155.60 MiB | 33.59 MiB/s, done.
Resolving deltas: 100% (5502/5502), done.


In [9]:
cd ARENA_2.0/chapter1_transformers/exercises/part3_indirect_object_identification

/content/ARENA_2.0/chapter1_transformers/exercises/part3_indirect_object_identification


In [10]:
import ioi_circuit_extraction as ioi_circuit_extraction

# Generate dataset with multiple prompts

In [11]:
class Dataset:
    def __init__(self, prompts, pos_dict, tokenizer, S1_is_first=False):
        self.prompts = prompts
        self.tokenizer = tokenizer
        self.N = len(prompts)
        self.max_len = max(
            [
                len(self.tokenizer(prompt["text"]).input_ids)
                for prompt in self.prompts
            ]
        )
        # all_ids = [prompt["TEMPLATE_IDX"] for prompt in self.ioi_prompts]
        all_ids = [0 for prompt in self.prompts] # only 1 template
        all_ids_ar = np.array(all_ids)
        self.groups = []
        for id in list(set(all_ids)):
            self.groups.append(np.where(all_ids_ar == id)[0])

        texts = [ prompt["text"] for prompt in self.prompts ]
        self.toks = torch.Tensor(self.tokenizer(texts, padding=True).input_ids).type(
            torch.int
        )
        self.io_tokenIDs = [
            self.tokenizer.encode(" " + prompt["corr"])[0] for prompt in self.prompts
        ]
        self.s_tokenIDs = [
            self.tokenizer.encode(" " + prompt["incorr"])[0] for prompt in self.prompts
        ]

        # word_idx: for every prompt, the token position of each target key (S1, S2, ...) and of "end".
        # Each entry is a tensor with one position per prompt. Positions are read from pos_dict,
        # which assumes every prompt follows the same template with one token position per number.
        self.word_idx = {}
        for targ in [key for key in self.prompts[0].keys() if key not in ("text", "corr", "incorr")]:
            targ_lst = []
            for prompt in self.prompts:
                input_text = prompt["text"]
                # Use the tokenizer passed to the constructor, not the global model's
                tokens = self.tokenizer.tokenize(input_text)
                # if S1_is_first and targ == "S1":  # only use this if first token doesn't have space Ġ in front
                #     target_token = prompt[targ]
                # else:
                #     target_token = "Ġ" + prompt[targ]
                # target_index = tokens.index(target_token)
                target_index = pos_dict[targ]
                targ_lst.append(target_index)
            self.word_idx[targ] = torch.tensor(targ_lst)

        targ_lst = []
        for prompt in self.prompts:
            input_text = prompt["text"]
            tokens = self.tokenizer.tokenize(input_text)
            end_token_index = len(tokens) - 1
            targ_lst.append(end_token_index)
        self.word_idx["end"] = torch.tensor(targ_lst)

    def __len__(self):
        return self.N

Replace io_tokens with the correct answer (the next number, which is '5' for the prompt "1 2 3 4") and s_tokens with the incorrect answer (a number that repeats from the prompt).

In [12]:
pos_dict = {
    'S1': 0,
    'S2': 1,
    'S3': 2,
    'S4': 3,
}

In [13]:
def generate_prompts_list(x, y):
    prompts_list = []
    for i in range(x, y):
        prompt_dict = {
            'S1': str(i),
            'S2': str(i+1),
            'S3': str(i+2),
            'S4': str(i+3),
            'corr': str(i+4),
            'incorr': str(i),
            'text': f"{i} {i+1} {i+2} {i+3}"
        }
        prompts_list.append(prompt_dict)
    return prompts_list

prompts_list = generate_prompts_list(1, 11)
dataset = Dataset(prompts_list, pos_dict, model.tokenizer, S1_is_first=True)

In [14]:
import random

def generate_prompts_list_corr(x, y):
    prompts_list = []
    for i in range(x, y):
        r1 = random.randint(15, 30)
        r2 = random.randint(15, 30)
        r3 = random.randint(15, 30)
        r4 = random.randint(15, 30)
        prompt_dict = {
            'S1': str(r1),
            'S2': str(r2),
            'S3': str(r3),
            'S4': str(r4),
            'corr': str(r1),
            'incorr': str(i+4),
            'text': f"{r1} {r2} {r3} {r4}"
        }
        prompts_list.append(prompt_dict)
    return prompts_list

prompts_list_2 = generate_prompts_list_corr(1, 11)
dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer, S1_is_first=True)
prompts_list_2

[{'S1': '20',
  'S2': '22',
  'S3': '26',
  'S4': '25',
  'corr': '20',
  'incorr': '5',
  'text': '20 22 26 25'},
 {'S1': '29',
  'S2': '16',
  'S3': '17',
  'S4': '20',
  'corr': '29',
  'incorr': '6',
  'text': '29 16 17 20'},
 {'S1': '20',
  'S2': '20',
  'S3': '19',
  'S4': '29',
  'corr': '20',
  'incorr': '7',
  'text': '20 20 19 29'},
 {'S1': '16',
  'S2': '16',
  'S3': '20',
  'S4': '25',
  'corr': '16',
  'incorr': '8',
  'text': '16 16 20 25'},
 {'S1': '20',
  'S2': '25',
  'S3': '22',
  'S4': '30',
  'corr': '20',
  'incorr': '9',
  'text': '20 25 22 30'},
 {'S1': '22',
  'S2': '17',
  'S3': '22',
  'S4': '27',
  'corr': '22',
  'incorr': '10',
  'text': '22 17 22 27'},
 {'S1': '17',
  'S2': '15',
  'S3': '29',
  'S4': '22',
  'corr': '17',
  'incorr': '11',
  'text': '17 15 29 22'},
 {'S1': '20',
  'S2': '20',
  'S3': '23',
  'S4': '30',
  'corr': '20',
  'incorr': '12',
  'text': '20 20 23 30'},
 {'S1': '15',
  'S2': '28',
  'S3': '24',
  'S4': '25',
  'corr': '15',
  'in

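Logit diff is the correct-token logit minus the incorrect-token logit. Here, the correct token is S5 (the next number), and the incorrect token is the prompt's `incorr` entry, which `generate_prompts_list` sets to S1.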

Because the score is a difference of two logits, a pruned circuit can score HIGHER than the "full circuit" (the unablated model), i.e. above 100%: the correct token still ranks first, but the gap between the two logits is larger, perhaps because the incorrect token's logit drops even further once the other heads are ablated.
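
As a minimal sketch of that effect, the toy numbers below (made up for illustration, not taken from the model) show a case where ablation lowers the correct token's logit and still raises the logit diff. The helper `ave_logit_diff` is a hypothetical plain-Python stand-in for `logits_to_ave_logit_diff_2`, operating on one final-position logit vector per prompt.

```python
def ave_logit_diff(logits, corr_ids, incorr_ids):
    """Mean of (correct-token logit - incorrect-token logit) over prompts.

    `logits` holds one final-position logit vector per prompt.
    """
    diffs = [row[c] - row[w] for row, c, w in zip(logits, corr_ids, incorr_ids)]
    return sum(diffs) / len(diffs)


# Toy vocabulary of 3 tokens; token 0 is correct and token 1 is incorrect in both prompts.
full_logits = [[5.0, 3.0, 0.0], [6.0, 2.0, 0.0]]
# After a hypothetical ablation the correct logit drops a little,
# but the incorrect logit drops much more.
ablated_logits = [[4.5, 1.0, 0.0], [5.5, 0.5, 0.0]]
corr_ids, incorr_ids = [0, 0], [1, 1]

full_score = ave_logit_diff(full_logits, corr_ids, incorr_ids)
ablated_score = ave_logit_diff(ablated_logits, corr_ids, incorr_ids)
percent_of_full = 100 * ablated_score / full_score
print(f"{percent_of_full:.1f}%")  # above 100% even though the correct logit fell
```

So a score above 100% does not by itself mean the pruned circuit is "better" at the task; it only means the gap between the two chosen tokens grew.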

# Ablation Expm Functions

In [15]:
from torch import Tensor

def logits_to_ave_logit_diff_2(logits: Float[Tensor, "batch seq d_vocab"], dataset: Dataset, per_prompt=False):
    '''
    Returns logit difference between the correct and incorrect answer.

    If per_prompt=True, return the array of differences rather than the average.
    '''

    # Only the logits at the final position are relevant for the answer
    # Get the logits of the correct (io) and incorrect (s) tokens respectively
    io_logits: Float[Tensor, "batch"] = logits[range(logits.size(0)), dataset.word_idx["end"], dataset.io_tokenIDs]
    s_logits: Float[Tensor, "batch"] = logits[range(logits.size(0)), dataset.word_idx["end"], dataset.s_tokenIDs]
    # Logit difference: correct minus incorrect
    answer_logit_diff = io_logits - s_logits
    return answer_logit_diff if per_prompt else answer_logit_diff.mean()

In [16]:
def mean_ablate_by_lst(lst, model, print_output=True):
    CIRCUIT = {
        "number mover": lst,
        # "number mover 4": lst,
        "number mover 3": lst,
        "number mover 2": lst,
        "number mover 1": lst,
    }

    SEQ_POS_TO_KEEP = {
        "number mover": "end",
        # "number mover 4": "S4",
        "number mover 3": "S3",
        "number mover 2": "S2",
        "number mover 1": "S1",
    }

    model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook

    ioi_logits_original, ioi_cache = model.run_with_cache(dataset.toks)

    model = ioi_circuit_extraction.add_mean_ablation_hook(model, means_dataset=dataset_2, circuit=CIRCUIT, seq_pos_to_keep=SEQ_POS_TO_KEEP)
    ioi_logits_minimal = model(dataset.toks)

    orig_score = logits_to_ave_logit_diff_2(ioi_logits_original, dataset)
    new_score = logits_to_ave_logit_diff_2(ioi_logits_minimal, dataset)
    if print_output:
        # print(f"Average logit difference (IOI dataset, using entire model): {orig_score:.4f}")
        # print(f"Average logit difference (IOI dataset, only using circuit): {new_score:.4f}")
        print(f"Average logit difference (circuit / full) %: {100 * new_score / orig_score:.4f}")
    # return new_score
    return 100 * new_score / orig_score

We could also avoid recomputing the full-model score on every call by computing it once and passing it into the function.
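
One way this could look, as a sketch only: the full-model score is computed once and captured in a closure. The names `make_percent_scorer`, `toy_full_score` and `toy_ablated_score` are hypothetical; in the notebook the two callables would wrap the `run_with_cache` call and the mean-ablated forward pass from `mean_ablate_by_lst`.

```python
def make_percent_scorer(score_full, score_ablated):
    """Return a function mapping a circuit to 100 * ablated / full.

    score_full() is evaluated once, here, and reused on every later call.
    """
    full = score_full()

    def percent_of_full(circuit):
        return 100 * score_ablated(circuit) / full

    return percent_of_full


calls = {"full": 0}  # counts how often the expensive full-model pass runs


def toy_full_score():
    # Hypothetical stand-in for the unablated model's average logit diff.
    calls["full"] += 1
    return 4.0


def toy_ablated_score(circuit):
    # Hypothetical stand-in: each kept head contributes 1.0 to the logit diff.
    return float(len(circuit))


scorer = make_percent_scorer(toy_full_score, toy_ablated_score)
results = [scorer([(9, 1)]), scorer([(9, 1), (4, 4)])]
print(results, calls["full"])
```

Since the pruning loops call the scoring function once per head, this would save one full forward pass per call.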

# Ablate the model and compare with original

### try full circuit from repeatLast iter fb

In [17]:
curr_circuit = [(0, 1), (0, 3), (0, 5), (0, 7), (0, 8), (0, 9), (0, 10), (1, 0), (1, 5), (3, 0), (3, 3), (3, 7), (3, 10), (3, 11), (4, 4), (4, 6), (4, 7), (4, 8), (4, 10), (4, 11), (5, 4), (5, 5), (5, 9), (6, 1), (6, 6), (6, 10), (7, 6), (7, 10), (7, 11), (8, 1), (8, 2), (8, 6), (8, 8), (9, 1), (9, 5), (10, 7), (11, 10)]
mean_ablate_by_lst(curr_circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: 43.7334


43.73337173461914

In [18]:
curr_circuit = [(9, 1)]
mean_ablate_by_lst(curr_circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: 1.6951


1.695137619972229

## compare with repeatRandElem

In [19]:
repeatRand_backw_3 = [(0, 0), (0, 1), (0, 2), (0, 3), (0, 8), (0, 9), (0, 10), (1, 0), (1, 5), (1, 7), (2, 0), (2, 2), (2, 9), (3, 0), (4, 2), (4, 4), (4, 7), (4, 8), (4, 9), (4, 10), (5, 0), (5, 3), (5, 4), (5, 5), (5, 6), (6, 1), (6, 3), (6, 4), (6, 6), (6, 8), (6, 10), (7, 10), (7, 11), (8, 11), (9, 1), (10, 1)]
mean_ablate_by_lst(repeatRand_backw_3, model, print_output=True).item()

Average logit difference (circuit / full) %: 70.7167


70.71671295166016

In [20]:
repeatRand_backw_20 = [(0, 1), (0, 2), (0, 3), (0, 9), (0, 10), (0, 11), (1, 0), (1, 5), (1, 7), (1, 8), (2, 0), (2, 2), (2, 7), (2, 9), (3, 0), (3, 3), (4, 4), (4, 6), (4, 7), (4, 9), (4, 10), (4, 11), (5, 0), (5, 3), (5, 4), (5, 5), (5, 6), (5, 8), (6, 1), (6, 6), (6, 9), (6, 10), (7, 10), (7, 11), (9, 1)]
mean_ablate_by_lst(repeatRand_backw_20, model, print_output=True).item()

Average logit difference (circuit / full) %: 62.4968


62.4968147277832

In [21]:
repeatRand_fb_3 = [(0, 1), (0, 2), (0, 8), (0, 9), (0, 10), (1, 0), (1, 5), (2, 0), (2, 2), (2, 6), (2, 7), (3, 0), (3, 11), (4, 4), (4, 6), (4, 7), (4, 8), (4, 9), (4, 10), (5, 0), (5, 3), (5, 4), (5, 5), (5, 6), (6, 1), (6, 6), (6, 10), (7, 10), (7, 11), (8, 8), (8, 9), (8, 11), (9, 1), (11, 10)]
mean_ablate_by_lst(repeatRand_fb_3, model, print_output=True).item()

Average logit difference (circuit / full) %: 58.8458


58.84584045410156

## Prune backwards

In [22]:
# Start with full circuit
curr_circuit = [(layer, head) for layer in range(12) for head in range(12)]
threshold = 3  # T, in percentage points: a head may be removed if the score without it stays within T points of 100%

for layer in range(11, -1, -1):  # go thru all heads in a layer first
    for head in range(12):
        # Copying the curr_circuit so we can iterate over one and modify the other
        copy_circuit = curr_circuit.copy()

        # Temporarily removing the current tuple from the copied circuit
        copy_circuit.remove((layer, head))

        new_score = mean_ablate_by_lst(copy_circuit, model, print_output=False).item()

        # print((layer,head), new_score)
        # If the drop from 100% is smaller than the threshold, remove the head from the circuit for good
        if (100 - new_score) < threshold:
            curr_circuit.remove((layer, head))

            print("Removed:", (layer, head))
            print(new_score)
            print("\n")

Removed: (11, 0)
98.16059112548828


Removed: (11, 1)
98.0081558227539


Removed: (11, 2)
98.06500244140625


Removed: (11, 3)
97.54293060302734


Removed: (11, 4)
99.056396484375


Removed: (11, 5)
99.13742065429688


Removed: (11, 6)
99.3099594116211


Removed: (11, 7)
99.19345092773438


Removed: (11, 8)
98.11976623535156


Removed: (11, 9)
97.83706665039062


Removed: (11, 11)
99.45267486572266


Removed: (10, 0)
99.34770965576172


Removed: (10, 1)
98.8477783203125


Removed: (10, 2)
101.41998291015625


Removed: (10, 3)
101.2891845703125


Removed: (10, 4)
101.08541107177734


Removed: (10, 5)
101.35785675048828


Removed: (10, 6)
101.4205551147461


Removed: (10, 7)
103.82918548583984


Removed: (10, 8)
103.9638442993164


Removed: (10, 9)
104.15493774414062


Removed: (10, 10)
104.01830291748047


Removed: (10, 11)
104.1028823852539


Removed: (9, 0)
104.28934478759766


Removed: (9, 2)
104.42497253417969


Removed: (9, 3)
105.49421691894531


Removed: (9, 4)
105.81078338623047

In [23]:
mean_ablate_by_lst(curr_circuit, model, print_output=True)

Average logit difference (circuit / full) %: 103.1818


tensor(103.1818, device='cuda:0')

In [24]:
backw_3 = curr_circuit.copy()
backw_3

[(0, 0),
 (0, 1),
 (0, 2),
 (0, 3),
 (0, 5),
 (0, 9),
 (0, 10),
 (1, 0),
 (1, 1),
 (1, 5),
 (1, 6),
 (1, 7),
 (1, 8),
 (2, 2),
 (2, 8),
 (2, 10),
 (3, 0),
 (3, 3),
 (3, 7),
 (3, 8),
 (4, 3),
 (4, 4),
 (4, 6),
 (4, 7),
 (4, 8),
 (4, 9),
 (4, 10),
 (4, 11),
 (5, 0),
 (5, 2),
 (5, 3),
 (5, 4),
 (5, 6),
 (5, 8),
 (5, 11),
 (6, 1),
 (6, 3),
 (6, 9),
 (6, 10),
 (6, 11),
 (7, 0),
 (7, 9),
 (7, 10),
 (7, 11),
 (8, 11),
 (9, 1),
 (11, 10)]

In [25]:
len(backw_3)

47

Now try 10% threshold:

In [26]:
def find_circuit_backw(threshold=10):
    # threshold is T, in percentage points: a head may be removed if the score without it stays within T points of 100%
    # Start with full circuit
    curr_circuit = [(layer, head) for layer in range(12) for head in range(12)]

    for layer in range(11, -1, -1):  # go thru all heads in a layer first
        for head in range(12):
            # Copying the curr_circuit so we can iterate over one and modify the other
            copy_circuit = curr_circuit.copy()

            # Temporarily removing the current tuple from the copied circuit
            copy_circuit.remove((layer, head))

            new_score = mean_ablate_by_lst(copy_circuit, model, print_output=False).item()

            # If the drop from 100% is smaller than the threshold, remove the head from the circuit for good
            if (100 - new_score) < threshold:
                curr_circuit.remove((layer, head))

                print("Removed:", (layer, head))
                print(new_score)
                print("\n")

    return curr_circuit
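
To see what the backward sweep does without loading a model, here is a minimal sketch that runs the same greedy rule against a toy scoring function. `prune_backward`, `IMPORTANCE` and `toy_score` are hypothetical stand-ins: the toy score simply adds up a made-up contribution per head, whereas `mean_ablate_by_lst` measures a real ablated forward pass.

```python
def prune_backward(score_fn, n_layers, n_heads, threshold):
    """Greedy sweep from the last layer to the first.

    A head is dropped when the circuit without it still scores within
    `threshold` percentage points of 100 (the full-model score).
    """
    circuit = [(layer, head) for layer in range(n_layers) for head in range(n_heads)]
    for layer in range(n_layers - 1, -1, -1):
        for head in range(n_heads):
            candidate = [x for x in circuit if x != (layer, head)]
            if (100 - score_fn(candidate)) < threshold:
                circuit = candidate  # accept the removal and keep going
    return circuit


# Hypothetical contribution of each head, in percentage points of the full score.
IMPORTANCE = {(0, 0): 40.0, (0, 1): 2.0, (1, 0): 1.0, (1, 1): 57.0}


def toy_score(circuit):
    return sum(IMPORTANCE[head] for head in circuit)


circuit_t3 = prune_backward(toy_score, n_layers=2, n_heads=2, threshold=3)
circuit_t5 = prune_backward(toy_score, n_layers=2, n_heads=2, threshold=5)
print(circuit_t3)
print(circuit_t5)
```

Each candidate is scored against the already-pruned circuit, so losses accumulate: with a threshold of 3, head (0, 1) survives because removing it after (1, 0) would bring the total drop to exactly 3 points, which is not strictly below the threshold.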

In [27]:
curr_circuit = find_circuit_backw(10)

Removed: (11, 0)
98.16059112548828


Removed: (11, 1)
98.0081558227539


Removed: (11, 2)
98.06500244140625


Removed: (11, 3)
97.54293060302734


Removed: (11, 4)
99.056396484375


Removed: (11, 5)
99.13742065429688


Removed: (11, 6)
99.3099594116211


Removed: (11, 7)
99.19345092773438


Removed: (11, 8)
98.11976623535156


Removed: (11, 9)
97.83706665039062


Removed: (11, 10)
96.54454803466797


Removed: (11, 11)
98.15660858154297


Removed: (10, 0)
98.06269836425781


Removed: (10, 1)
97.5815658569336


Removed: (10, 2)
100.13105010986328


Removed: (10, 3)
99.99195098876953


Removed: (10, 4)
99.80280303955078


Removed: (10, 5)
100.05796813964844


Removed: (10, 6)
100.1131820678711


Removed: (10, 7)
102.08489990234375


Removed: (10, 8)
102.23777770996094


Removed: (10, 9)
102.39921569824219


Removed: (10, 10)
102.27693939208984


Removed: (10, 11)
102.38016510009766


Removed: (9, 0)
102.56954193115234


Removed: (9, 2)
102.70443725585938


Removed: (9, 3)
103.735282897949

In [28]:
backw_10 = curr_circuit.copy()
backw_10

[(0, 0),
 (0, 1),
 (0, 2),
 (0, 3),
 (0, 5),
 (0, 9),
 (0, 10),
 (1, 0),
 (1, 1),
 (1, 3),
 (1, 4),
 (1, 5),
 (1, 6),
 (1, 7),
 (1, 8),
 (2, 0),
 (2, 2),
 (2, 8),
 (2, 9),
 (2, 10),
 (3, 0),
 (3, 3),
 (3, 7),
 (3, 8),
 (3, 11),
 (4, 4),
 (4, 6),
 (4, 8),
 (4, 9),
 (4, 10),
 (4, 11),
 (5, 0),
 (5, 2),
 (5, 3),
 (5, 4),
 (5, 6),
 (5, 8),
 (5, 11),
 (6, 2),
 (6, 3),
 (6, 4),
 (6, 8),
 (6, 9),
 (6, 10),
 (6, 11),
 (7, 10),
 (7, 11),
 (8, 11),
 (9, 1)]

In [29]:
mean_ablate_by_lst(curr_circuit, model, print_output=True)

Average logit difference (circuit / full) %: 97.9794


tensor(97.9794, device='cuda:0')

In [30]:
len(backw_10)

49

20%:

In [31]:
%%capture
curr_circuit = find_circuit_backw(20)

In [32]:
backw_20 = curr_circuit.copy()
backw_20

[(0, 0),
 (0, 1),
 (0, 2),
 (0, 3),
 (0, 5),
 (0, 9),
 (0, 10),
 (1, 0),
 (1, 1),
 (1, 2),
 (1, 5),
 (1, 6),
 (1, 7),
 (1, 8),
 (2, 0),
 (2, 2),
 (2, 8),
 (2, 9),
 (2, 10),
 (3, 0),
 (3, 3),
 (3, 7),
 (3, 11),
 (4, 3),
 (4, 4),
 (4, 6),
 (4, 8),
 (4, 9),
 (4, 10),
 (4, 11),
 (5, 0),
 (5, 2),
 (5, 3),
 (5, 4),
 (5, 6),
 (5, 8),
 (5, 11),
 (6, 1),
 (6, 3),
 (6, 9),
 (6, 10),
 (7, 10),
 (7, 11),
 (9, 1)]

In [33]:
mean_ablate_by_lst(curr_circuit, model, print_output=True)

Average logit difference (circuit / full) %: 87.0955


tensor(87.0955, device='cuda:0')

In [34]:
len(backw_20)

44

### set diffs of the three perf lvls

In [35]:
set(backw_3) - set(backw_10)

{(4, 3), (4, 7), (6, 1), (7, 0), (7, 9), (11, 10)}

In [36]:
set(backw_10) - set(backw_3)

{(1, 3), (1, 4), (2, 0), (2, 9), (3, 11), (6, 2), (6, 4), (6, 8)}

In [37]:
set(backw_3) - set(backw_20)

{(3, 8), (4, 7), (6, 11), (7, 0), (7, 9), (8, 11), (11, 10)}

In [38]:
set(backw_10) - set(backw_20)

{(1, 3), (1, 4), (3, 8), (6, 2), (6, 4), (6, 8), (6, 11), (8, 11)}
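
The pairwise set differences above could also be collected in one pass. The sketch below is one possible helper (`compare_circuits` is a hypothetical name, and the two circuits are small illustrative lists, not the ones found above); it reports shared heads, heads unique to each circuit, and the Jaccard overlap for every pair.

```python
from itertools import combinations


def compare_circuits(circuits):
    """For each pair of named circuits, report shared heads, heads unique
    to each circuit, and the Jaccard overlap |A & B| / |A | B|."""
    report = {}
    for (name_a, a), (name_b, b) in combinations(circuits.items(), 2):
        a, b = set(a), set(b)
        report[(name_a, name_b)] = {
            "shared": sorted(a & b),
            "only_first": sorted(a - b),
            "only_second": sorted(b - a),
            "jaccard": len(a & b) / len(a | b),
        }
    return report


# Small illustrative circuits, chosen only to keep the output short.
toy = {
    "A": [(0, 1), (4, 4), (9, 1)],
    "B": [(0, 1), (9, 1), (10, 2)],
}
report = compare_circuits(toy)
print(report[("A", "B")])
```

Passing `{"backw_3": backw_3, "backw_10": backw_10, "backw_20": backw_20}` would cover all three threshold levels at once.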

In [39]:
mean_ablate_by_lst(backw_20, model, print_output=True)

Average logit difference (circuit / full) %: 87.0955


tensor(87.0955, device='cuda:0')

In [40]:
mean_ablate_by_lst(backw_20 + [(10, 2)], model, print_output=True)

Average logit difference (circuit / full) %: 84.1322


tensor(84.1322, device='cuda:0')

In [41]:
mean_ablate_by_lst([x for x in backw_20 if x != (9, 1)], model, print_output=True)

Average logit difference (circuit / full) %: 48.0052


tensor(48.0052, device='cuda:0')

In [42]:
mean_ablate_by_lst([x for x in backw_20 if x != (9, 1)] + [(10, 2)], model, print_output=True)

Average logit difference (circuit / full) %: 46.4977


tensor(46.4977, device='cuda:0')

### set diff w repeatLast and repeatFirstAll circs

In [43]:
repeatFirstAll_backw_3 = [(0, 1), (0, 9), (1, 0), (1, 5), (2, 2), (2, 9), (2, 10), (3, 0), (3, 3), (3, 6), (3, 7), (4, 4), (4, 8), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), (5, 8), (6, 1), (6, 3), (6, 4), (6, 6), (6, 9), (6, 10), (7, 1), (7, 2), (7, 6), (7, 7), (7, 10), (7, 11), (8, 0), (8, 1), (8, 6), (8, 8), (8, 11), (9, 1), (9, 9), (10, 1), (10, 2), (11, 8), (11, 9), (11, 10)]
repeatFirstAll_backw_10 = [(0, 1), (0, 9), (1, 0), (1, 5), (1, 6), (2, 2), (2, 8), (2, 9), (3, 0), (3, 2), (3, 3), (3, 7), (3, 8), (3, 10), (4, 4), (5, 1), (5, 4), (5, 5), (5, 6), (5, 7), (5, 8), (5, 10), (6, 0), (6, 1), (6, 3), (6, 4), (6, 6), (6, 9), (6, 10), (6, 11), (7, 0), (7, 6), (7, 8), (7, 10), (7, 11), (8, 0), (8, 1), (8, 6), (8, 8), (8, 11), (9, 1), (10, 2)]
repeatFirstAll_backw_20 = [(0, 1), (0, 9), (1, 0), (1, 5), (1, 6), (2, 2), (2, 8), (2, 9), (2, 10), (3, 0), (3, 2), (3, 3), (3, 7), (3, 10), (4, 4), (4, 10), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), (5, 8), (5, 10), (6, 1), (6, 3), (6, 4), (6, 6), (6, 9), (6, 10), (7, 2), (7, 6), (7, 7), (7, 10), (7, 11), (8, 0), (8, 6), (8, 8), (8, 11), (9, 1)]

In [44]:
mean_ablate_by_lst(repeatFirstAll_backw_3, model, print_output=True)

Average logit difference (circuit / full) %: 37.1344


tensor(37.1344, device='cuda:0')

In [45]:
mean_ablate_by_lst(repeatFirstAll_backw_10, model, print_output=True)

Average logit difference (circuit / full) %: 32.2838


tensor(32.2838, device='cuda:0')

In [46]:
mean_ablate_by_lst(repeatFirstAll_backw_20, model, print_output=True)

Average logit difference (circuit / full) %: 48.2804


tensor(48.2804, device='cuda:0')

In [47]:
repLast_backw_3 = [(0, 1), (0, 3), (0, 5), (0, 7), (0, 9), (0, 10), (1, 0), (1, 4), (1, 5), (2, 2), (2, 8), (2, 9), (3, 0), (3, 2), (3, 3), (3, 7), (4, 4), (4, 7), (4, 10), (5, 1), (5, 3), (5, 4), (5, 5), (5, 6), (5, 8), (5, 9), (6, 1), (6, 4), (6, 6), (6, 10), (6, 11), (7, 6), (7, 10), (7, 11), (8, 0), (8, 5), (8, 6), (8, 8), (9, 1), (10, 7)]
set(backw_3) - set(repLast_backw_3)

{(0, 0),
 (0, 2),
 (1, 1),
 (1, 6),
 (1, 7),
 (1, 8),
 (2, 10),
 (3, 8),
 (4, 3),
 (4, 6),
 (4, 8),
 (4, 9),
 (4, 11),
 (5, 0),
 (5, 2),
 (5, 11),
 (6, 3),
 (6, 9),
 (7, 0),
 (7, 9),
 (8, 11),
 (11, 10)}

In [48]:
mean_ablate_by_lst(repLast_backw_3, model, print_output=True)

Average logit difference (circuit / full) %: 51.0269


tensor(51.0269, device='cuda:0')

In [49]:
set(repLast_backw_3) - set(backw_3)

{(0, 7),
 (1, 4),
 (2, 9),
 (3, 2),
 (5, 1),
 (5, 5),
 (5, 9),
 (6, 4),
 (6, 6),
 (7, 6),
 (8, 0),
 (8, 5),
 (8, 6),
 (8, 8),
 (10, 7)}

In [50]:
set(backw_3) - set(repeatFirstAll_backw_3)

{(0, 0),
 (0, 2),
 (0, 3),
 (0, 5),
 (0, 10),
 (1, 1),
 (1, 6),
 (1, 7),
 (1, 8),
 (2, 8),
 (3, 8),
 (4, 3),
 (4, 6),
 (4, 7),
 (4, 9),
 (4, 10),
 (4, 11),
 (5, 0),
 (5, 11),
 (6, 11),
 (7, 0),
 (7, 9)}

In [51]:
set(repeatFirstAll_backw_3) - set(backw_3)

{(2, 9),
 (3, 6),
 (5, 1),
 (5, 5),
 (6, 4),
 (6, 6),
 (7, 1),
 (7, 2),
 (7, 6),
 (7, 7),
 (8, 0),
 (8, 1),
 (8, 6),
 (8, 8),
 (9, 9),
 (10, 1),
 (10, 2),
 (11, 8),
 (11, 9)}

## Prune forwards

In [52]:
# # Start with full circuit
# curr_circuit = [(layer, head) for layer in range(12) for head in range(12)]
# threshold = 3  # This is T, a %. if performance is less than T%, allow its removal

# for layer in range(0, 12):
#     for head in range(12):
#         # Copying the curr_circuit so we can iterate over one and modify the other
#         copy_circuit = curr_circuit.copy()

#         # Temporarily removing the current tuple from the copied circuit
#         copy_circuit.remove((layer, head))

#         new_score = mean_ablate_by_lst(copy_circuit, model, print_output=False).item()

#         # print((layer,head), new_score)
#         # If the result is less than the threshold, remove the tuple from the original list
#         if (100 - new_score) < threshold:
#             curr_circuit.remove((layer, head))

#             print("Removed:", (layer, head))
#             print(new_score)
#             print("\n")

## Prune fwds-backwds iteratively

In [53]:
def find_circuit_forw(curr_circuit=None, threshold=10):
    # threshold is T, in percentage points: a head may be removed if the score without it stays within T points of 100%
    if not curr_circuit:  # None or an empty list
        # Start with full circuit
        curr_circuit = [(layer, head) for layer in range(12) for head in range(12)]

    for layer in range(0, 12):
        for head in range(12):
            if (layer, head) not in curr_circuit:
                continue

            # Copying the curr_circuit so we can iterate over one and modify the other
            copy_circuit = curr_circuit.copy()

            # Temporarily removing the current tuple from the copied circuit
            copy_circuit.remove((layer, head))

            new_score = mean_ablate_by_lst(copy_circuit, model, print_output=False).item()

            # print((layer,head), new_score)
            # If the drop from 100% is smaller than the threshold, remove the head from the circuit for good
            if (100 - new_score) < threshold:
                curr_circuit.remove((layer, head))

                print("\nRemoved:", (layer, head))
                print(new_score)

    return curr_circuit, new_score

In [54]:
def find_circuit_backw(curr_circuit=None, threshold=10):
    # threshold is T, in percentage points: a head may be removed if the score without it stays within T points of 100%
    if not curr_circuit:  # None or an empty list
        # Start with full circuit
        curr_circuit = [(layer, head) for layer in range(12) for head in range(12)]

    for layer in range(11, -1, -1):  # go thru all heads in a layer first
        for head in range(12):
            if (layer, head) not in curr_circuit:
                continue

            # Copying the curr_circuit so we can iterate over one and modify the other
            copy_circuit = curr_circuit.copy()

            # Temporarily removing the current tuple from the copied circuit
            copy_circuit.remove((layer, head))

            new_score = mean_ablate_by_lst(copy_circuit, model, print_output=False).item()

            # If the drop from 100% is smaller than the threshold, remove the head from the circuit for good
            if (100 - new_score) < threshold:
                curr_circuit.remove((layer, head))

                print("\nRemoved:", (layer, head))
                print(new_score)

    return curr_circuit, new_score
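
The two sweeps are meant to be alternated until one of them removes nothing. Below is a minimal, model-free sketch of that fixed-point loop. All names (`sweep`, `prune_until_stable`, `IMPORTANCE`, `toy_score`) are hypothetical, and the toy score includes one head with a negative contribution so that a second sweep has something left to remove.

```python
def sweep(circuit, score_fn, threshold, order):
    """One greedy pass: visit heads in `order`, dropping each head whose removal
    keeps the score within `threshold` points of 100. Returns a new list."""
    circuit = list(circuit)
    for head in order:
        if head not in circuit:
            continue
        candidate = [x for x in circuit if x != head]
        if (100 - score_fn(candidate)) < threshold:
            circuit = candidate
    return circuit


def prune_until_stable(all_heads, score_fn, threshold):
    """Alternate forward and backward sweeps until a sweep removes nothing."""
    forward = sorted(all_heads)
    backward = sorted(all_heads, key=lambda lh: (-lh[0], lh[1]))  # last layer first
    circuit = list(forward)
    while True:
        for order in (forward, backward):
            pruned = sweep(circuit, score_fn, threshold, order)
            if pruned == circuit:
                return circuit
            circuit = pruned


# Hypothetical contributions; head (1, 0) hurts the score, so removing it gives 103%.
IMPORTANCE = {(0, 0): 4.0, (0, 1): 50.0, (1, 0): -3.0, (1, 1): 49.0}


def toy_score(circuit):
    return sum(IMPORTANCE[head] for head in circuit)


one_pass = sweep(sorted(IMPORTANCE), toy_score, 3, sorted(IMPORTANCE))
stable = prune_until_stable(IMPORTANCE, toy_score, threshold=3)
print(one_pass)
print(stable, toy_score(stable))
```

In this toy case the forward sweep keeps (0, 0) because it is tested while the score is still 100%, then removes (1, 0), which lifts the score above 100%. The following backward sweep can then drop (0, 0) as well, which is the kind of order dependence that motivates alternating the two directions.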

### iter fwd backw, threshold 3

In [55]:
threshold = 3
curr_circuit = []
prev_score = 100
new_score = 0
iter = 1
while prev_score != new_score:
    print('\nfwd prune, iter ', str(iter))
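    # Stop on an unchanged circuit rather than an unchanged score: the returned new_score belongs to the last head tested (whose removal may have been rejected), and prev_score is never updated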
    old_circuit = curr_circuit.copy() # save old before finding new one
    curr_circuit, new_score = find_circuit_forw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    print('\nbackw prune, iter ', str(iter))
    # prev_score = new_score # save old score before finding new one
    old_circuit = curr_circuit.copy() # save old before finding new one
    curr_circuit, new_score = find_circuit_backw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    iter += 1


fwd prune, iter  1

Removed: (0, 0)
97.68428039550781

Removed: (0, 6)
99.71448516845703

Removed: (0, 8)
100.186279296875

Removed: (0, 11)
102.14154052734375

Removed: (1, 1)
102.31873321533203

Removed: (1, 2)
101.02670288085938

Removed: (1, 3)
101.02294921875

Removed: (1, 4)
101.57904052734375

Removed: (1, 6)
101.14369201660156

Removed: (1, 8)
99.4608383178711

Removed: (1, 9)
101.89488220214844

Removed: (1, 10)
103.35824584960938

Removed: (1, 11)
105.71910095214844

Removed: (2, 0)
105.6504898071289

Removed: (2, 1)
105.01492309570312

Removed: (2, 2)
97.0966567993164

Removed: (2, 4)
98.16194915771484

Removed: (2, 5)
98.88838195800781

Removed: (2, 6)
97.89149475097656

Removed: (2, 11)
98.76406860351562

Removed: (3, 1)
99.09585571289062

Removed: (3, 2)
99.56168365478516

Removed: (3, 4)
102.0413818359375

Removed: (3, 5)
101.81766510009766

Removed: (3, 6)
98.82493591308594

Removed: (3, 9)
98.73748016357422

Removed: (3, 10)
103.85990905761719

Removed: (3, 11)
100.90

In [56]:
fb_3 = curr_circuit.copy()
fb_3

[(0, 1),
 (0, 2),
 (0, 3),
 (0, 5),
 (0, 7),
 (0, 9),
 (0, 10),
 (1, 0),
 (1, 5),
 (1, 7),
 (2, 8),
 (2, 9),
 (2, 10),
 (3, 0),
 (3, 3),
 (3, 7),
 (3, 8),
 (4, 4),
 (4, 7),
 (4, 8),
 (4, 9),
 (4, 10),
 (4, 11),
 (5, 0),
 (5, 2),
 (5, 4),
 (5, 6),
 (5, 8),
 (5, 10),
 (6, 1),
 (6, 3),
 (6, 9),
 (6, 10),
 (7, 0),
 (7, 8),
 (7, 10),
 (7, 11),
 (8, 6),
 (8, 11),
 (9, 1)]

In [57]:
mean_ablate_by_lst(fb_3, model, print_output=True)

Average logit difference (circuit / full) %: 97.1427


tensor(97.1427, device='cuda:0')

In [58]:
mean_ablate_by_lst(fb_3 + [(6, 9)], model, print_output=True)

Average logit difference (circuit / full) %: 97.1427


tensor(97.1427, device='cuda:0')

#### compare

In [59]:
set(backw_3) - set(fb_3)

{(0, 0),
 (1, 1),
 (1, 6),
 (1, 8),
 (2, 2),
 (4, 3),
 (4, 6),
 (5, 3),
 (5, 11),
 (6, 11),
 (7, 9),
 (11, 10)}

In [60]:
set(fb_3) - set(backw_3)

{(0, 7), (2, 9), (5, 10), (7, 8), (8, 6)}

### iter fwd backw, threshold 20

In [61]:
# threshold = 20
# curr_circuit = []
# prev_score = 100
# new_score = 0
# iter = 1
# while prev_score != new_score:
#     print('\nfwd prune, iter ', str(iter))
#     # track changes in circuit as for some reason it doesn't work with scores
#     old_circuit = curr_circuit.copy() # save old before finding new one
#     curr_circuit, new_score = find_circuit_forw(curr_circuit=curr_circuit, threshold=threshold)
#     if curr_circuit == old_circuit:
#         break
#     print('\nbackw prune, iter ', str(iter))
#     # prev_score = new_score # save old score before finding new one
#     old_circuit = curr_circuit.copy() # save old before finding new one
#     curr_circuit, new_score = find_circuit_backw(curr_circuit=curr_circuit, threshold=threshold)
#     if curr_circuit == old_circuit:
#         break
#     iter += 1

In [62]:
# curr_circuit

## Prune backwards-forwards iteratively

### iter backw fwd, threshold 3

In [63]:
threshold = 3
curr_circuit = []
prev_score = 100
new_score = 0
iter = 1
# prev_score is never updated below, so this condition stays true unless new_score is exactly 100;
# in practice the loop ends at one of the break statements, when a pass leaves the circuit unchanged
while prev_score != new_score:
    print('\nbackw prune, iter ', str(iter))
    # prev_score = new_score # save old score before finding new one
    old_circuit = curr_circuit.copy() # save old circuit before finding new one
    curr_circuit, new_score = find_circuit_backw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    print('\nfwd prune, iter ', str(iter))
    # detect convergence by comparing circuits; comparing scores did not work
    old_circuit = curr_circuit.copy() # save old circuit before finding new one
    curr_circuit, new_score = find_circuit_forw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    iter += 1


backw prune, iter  1

Removed: (11, 0)
98.16059112548828

Removed: (11, 1)
98.0081558227539

Removed: (11, 2)
98.06500244140625

Removed: (11, 3)
97.54293060302734

Removed: (11, 4)
99.056396484375

Removed: (11, 5)
99.13742065429688

Removed: (11, 6)
99.3099594116211

Removed: (11, 7)
99.19345092773438

Removed: (11, 8)
98.11976623535156

Removed: (11, 9)
97.83706665039062

Removed: (11, 11)
99.45267486572266

Removed: (10, 0)
99.34770965576172

Removed: (10, 1)
98.8477783203125

Removed: (10, 2)
101.41998291015625

Removed: (10, 3)
101.2891845703125

Removed: (10, 4)
101.08541107177734

Removed: (10, 5)
101.35785675048828

Removed: (10, 6)
101.4205551147461

Removed: (10, 7)
103.82918548583984

Removed: (10, 8)
103.9638442993164

Removed: (10, 9)
104.15493774414062

Removed: (10, 10)
104.01830291748047

Removed: (10, 11)
104.1028823852539

Removed: (9, 0)
104.28934478759766

Removed: (9, 2)
104.42497253417969

Removed: (9, 3)
105.49421691894531

Removed: (9, 4)
105.81078338623047

R

In [64]:
bf_3 = curr_circuit.copy()
bf_3

[(0, 1),
 (0, 2),
 (0, 3),
 (0, 5),
 (0, 9),
 (0, 10),
 (1, 0),
 (1, 5),
 (1, 7),
 (2, 2),
 (2, 8),
 (2, 10),
 (3, 3),
 (3, 7),
 (3, 8),
 (4, 4),
 (4, 6),
 (4, 7),
 (4, 8),
 (4, 9),
 (4, 10),
 (4, 11),
 (5, 0),
 (5, 2),
 (5, 3),
 (5, 4),
 (5, 6),
 (5, 8),
 (6, 1),
 (6, 3),
 (6, 9),
 (6, 10),
 (6, 11),
 (7, 0),
 (7, 9),
 (7, 10),
 (7, 11),
 (8, 11),
 (9, 1),
 (11, 10)]
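
The alternating loop that produced `bf_3` can be sketched in isolation on a toy problem. Everything below is hypothetical: `toy_score`, `prune_pass`, `alternate_prune`, and the head weights are stand-ins and are not the notebook's `mean_ablate_by_lst`, `find_circuit_backw`, or `find_circuit_forw`. The removal rule (drop a head if the score stays within `threshold` points of 100) is an assumption inferred from the printed scores, not a confirmed description of those functions.

```python
from itertools import product

N_LAYERS, N_HEADS = 3, 2
ALL_HEADS = list(product(range(N_LAYERS), range(N_HEADS)))

# Hypothetical contribution of each (layer, head) to the score, in percent.
TOY_WEIGHTS = {
    (0, 0): 40.0, (0, 1): 1.0,
    (1, 0): 0.5, (1, 1): 30.0,
    (2, 0): 28.0, (2, 1): 0.5,
}


def toy_score(circuit):
    """Toy stand-in for a circuit score: sum of the kept heads' weights."""
    return sum(TOY_WEIGHTS[head] for head in circuit)


def prune_pass(circuit, order, threshold):
    """Try removing heads in `order`; keep a removal if the score stays
    within `threshold` points of 100 (assumed criterion)."""
    circuit = list(circuit)
    for head in order:
        if head not in circuit:
            continue
        trial = [h for h in circuit if h != head]
        if 100 - toy_score(trial) < threshold:
            circuit = trial
    return circuit


def alternate_prune(threshold):
    """Alternate backward and forward passes until one pass changes nothing."""
    circuit = list(ALL_HEADS)
    orders = [ALL_HEADS[::-1], ALL_HEADS]  # backward pass first, then forward
    step = 0
    while True:
        new_circuit = prune_pass(circuit, orders[step % 2], threshold)
        if new_circuit == circuit:  # compare circuits, not scores
            return circuit
        circuit = new_circuit
        step += 1


final_circuit = alternate_prune(threshold=3)
print(final_circuit, toy_score(final_circuit))
```

Because the toy score is purely additive, the first backward pass already reaches the fixed point and the forward pass removes nothing. With a real model, heads interact, so a later pass in the opposite direction can still find removable heads.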

#### compare

In [65]:
len(bf_3)

40

In [66]:
len(fb_3)

40

In [67]:
set(backw_3) - set(bf_3)

{(0, 0), (1, 1), (1, 6), (1, 8), (3, 0), (4, 3), (5, 11)}

In [68]:
set(bf_3) - set(backw_3)

set()

In [69]:
set(fb_3) - set(bf_3)

{(0, 7), (2, 9), (3, 0), (5, 10), (7, 8), (8, 6)}

In [70]:
set(bf_3) - set(fb_3)

{(2, 2), (4, 6), (5, 3), (6, 11), (7, 9), (11, 10)}

Get the score of fb_3 without the nodes that it has but bf_3 doesn't.

`set(fb_3) - (set(fb_3) - set(bf_3))` is the set intersection of fb_3 and bf_3: https://chat.openai.com/c/c15f48a7-226b-4c89-8ad9-a39a471867f5

In [72]:
mean_ablate_by_lst(list(set(fb_3) - (set(fb_3) - set(bf_3))), model, print_output=True)

Average logit difference (circuit / full) %: 82.4827


tensor(82.4827, device='cuda:0')

In [73]:
mean_ablate_by_lst(list(set(bf_3) - (set(bf_3) - set(fb_3))), model, print_output=True)

Average logit difference (circuit / full) %: 82.4827


tensor(82.4827, device='cuda:0')

In [74]:
(set(fb_3) - (set(fb_3) - set(bf_3))) == (set(bf_3) - (set(bf_3) - set(fb_3)))

True
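
The comparisons in this section (heads only in one circuit, heads only in the other, heads in both) can be gathered into one helper built on Python's set operators. This is a minimal sketch: `compare_circuits` is a hypothetical name, not a function from this notebook, and `toy_a` and `toy_b` are made-up lists rather than the real `fb_3` and `bf_3`. The last line checks that `A - (A - B)` equals the built-in intersection `A & B`.

```python
def compare_circuits(a, b):
    """Split two lists of (layer, head) tuples into unique and shared heads."""
    a, b = set(a), set(b)
    return {
        "only_a": sorted(a - b),  # heads in a but not in b
        "only_b": sorted(b - a),  # heads in b but not in a
        "shared": sorted(a & b),  # heads in both circuits
    }


# Made-up example circuits, for illustration only.
toy_a = [(0, 1), (0, 7), (3, 0), (4, 4)]
toy_b = [(0, 1), (2, 2), (4, 4)]

result = compare_circuits(toy_a, toy_b)
print(result)

# A - (A - B) is the same set as the intersection A & B.
assert set(toy_a) - (set(toy_a) - set(toy_b)) == set(toy_a) & set(toy_b)
```

The `shared` entry is already a list, so it could be passed to a scoring call in the same way `list(...)` is used in the cells above.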