<a href="https://colab.research.google.com/github/wlg100/numseqcont_circuit_expms/blob/main/notebook_templates/minimal_circuit_template.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" align="left"/></a>&nbsp;or in a local notebook.

# Setup
(No need to change anything)

In [1]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    # Install another version of node that makes PySvelte work way faster
    # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    # %pip install git+https://github.com/neelnanda-io/PySvelte.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

Running as a Colab notebook
Collecting git+https://github.com/neelnanda-io/TransformerLens.git
  Cloning https://github.com/neelnanda-io/TransformerLens.git to /tmp/pip-req-build-f5rdk69j
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/TransformerLens.git /tmp/pip-req-build-f5rdk69j
  Resolved https://github.com/neelnanda-io/TransformerLens.git to commit fa287750606075574df2c538058e67d648e2f952
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting accelerate>=0.23.0 (from transformer-lens==0.0.0)
  Downloading accelerate-0.24.1-py3-none-any.whl (261 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m261.4/261.4 kB[0m [31m4.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting beartype<0.15.0,>=0.14.1 (from transformer-lens==0.0.0)
  Downloading beartype-0.14.1-py3-none-any.whl (739 kB)
[2K     [90m━━━━━━━━

In [2]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

if IN_COLAB or not DEBUG_MODE:
    # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "png"

In [3]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [4]:
# import pysvelte

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

We turn automatic differentiation off, to save GPU memory, as this notebook focuses on model inference not model training.

In [5]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x78cb41223ca0>

Plotting helper functions:

In [6]:
def imshow(tensor, renderer=None, **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", **kwargs).show(renderer)

def line(tensor, renderer=None, **kwargs):
    px.line(y=utils.to_numpy(tensor), **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

## Load Model

Decide which model to use (eg. gpt2-small vs -medium)

In [7]:
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)

Downloading (…)lve/main/config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Loaded pretrained model gpt2-small into HookedTransformer


## Import functions from repo

In [8]:
!git clone https://github.com/callummcdougall/ARENA_2.0.git

Cloning into 'ARENA_2.0'...
remote: Enumerating objects: 9106, done.[K
remote: Counting objects: 100% (1820/1820), done.[K
remote: Compressing objects: 100% (289/289), done.[K
remote: Total 9106 (delta 1614), reused 1608 (delta 1528), pack-reused 7286[K
Receiving objects: 100% (9106/9106), 155.60 MiB | 23.20 MiB/s, done.
Resolving deltas: 100% (5507/5507), done.


In [9]:
cd ARENA_2.0/chapter1_transformers/exercises/part3_indirect_object_identification

/content/ARENA_2.0/chapter1_transformers/exercises/part3_indirect_object_identification


In [10]:
import ioi_circuit_extraction as ioi_circuit_extraction

# Generate dataset with multiple prompts

In [11]:
class Dataset:
    def __init__(self, prompts, pos_dict, tokenizer, S1_is_first=False):
        self.prompts = prompts
        self.tokenizer = tokenizer
        self.N = len(prompts)
        self.max_len = max(
            [
                len(self.tokenizer(prompt["text"]).input_ids)
                for prompt in self.prompts
            ]
        )
        # all_ids = [prompt["TEMPLATE_IDX"] for prompt in self.ioi_prompts]
        all_ids = [0 for prompt in self.prompts] # only 1 template
        all_ids_ar = np.array(all_ids)
        self.groups = []
        for id in list(set(all_ids)):
            self.groups.append(np.where(all_ids_ar == id)[0])

        texts = [ prompt["text"] for prompt in self.prompts ]
        self.toks = torch.Tensor(self.tokenizer(texts, padding=True).input_ids).type(
            torch.int
        )
        self.io_tokenIDs = [
            self.tokenizer.encode(" " + prompt["corr"])[0] for prompt in self.prompts
        ]
        self.s_tokenIDs = [
            self.tokenizer.encode(" " + prompt["incorr"])[0] for prompt in self.prompts
        ]

        # word_idx: for every prompt, find the token index of each target token and "end"
        # word_idx is a tensor with an element for each prompt. The element is the targ token's ind at that prompt
        self.word_idx = {}
        for targ in [key for key in self.prompts[0].keys() if (key != 'text' and key != 'corr' and key != 'incorr')]:
            targ_lst = []
            for prompt in self.prompts:
                input_text = prompt["text"]
                tokens = model.tokenizer.tokenize(input_text)
                # if S1_is_first and targ == "S1":  # only use this if first token doesn't have space Ġ in front
                #     target_token = prompt[targ]
                # else:
                #     target_token = "Ġ" + prompt[targ]
                # target_index = tokens.index(target_token)
                target_index = pos_dict[targ]
                targ_lst.append(target_index)
            self.word_idx[targ] = torch.tensor(targ_lst)

        targ_lst = []
        for prompt in self.prompts:
            input_text = prompt["text"]
            tokens = self.tokenizer.tokenize(input_text)
            end_token_index = len(tokens) - 1
            targ_lst.append(end_token_index)
        self.word_idx["end"] = torch.tensor(targ_lst)

    def __len__(self):
        return self.N

Repalce io_tokens with correct answer (next, which is '5') and s_tokens with incorrect (current, which repeats)

In [12]:
pos_dict = {
    'S1': 0,
    'S2': 1,
    'S3': 2,
    'S4': 3,
}

In [13]:
def generate_prompts_list(x ,y):
    months = ['January', 'February', 'March', 'April', 'May', 'June', 'July', 'August', 'September', 'October', 'November', 'December']
    prompts_list = []
    # for i in range(x, y):
    for i in range(0, 8):
        prompt_dict = {
            'S1': months[i],
            'S2': months[i+1],
            'S3': months[i+2],
            'S4': months[i+3],
            'corr': months[i+4] if i+4 < len(months) else 'None',
            'incorr': months[i],
            'text': f"{months[i]} {months[i+1]} {months[i+2]} {months[i+2]}"
        }
        prompts_list.append(prompt_dict)
    return prompts_list

prompts_list = generate_prompts_list(1, 11)
dataset = Dataset(prompts_list, pos_dict, model.tokenizer, S1_is_first=True)

In [14]:
prompts_list

[{'S1': 'January',
  'S2': 'February',
  'S3': 'March',
  'S4': 'April',
  'corr': 'May',
  'incorr': 'January',
  'text': 'January February March March'},
 {'S1': 'February',
  'S2': 'March',
  'S3': 'April',
  'S4': 'May',
  'corr': 'June',
  'incorr': 'February',
  'text': 'February March April April'},
 {'S1': 'March',
  'S2': 'April',
  'S3': 'May',
  'S4': 'June',
  'corr': 'July',
  'incorr': 'March',
  'text': 'March April May May'},
 {'S1': 'April',
  'S2': 'May',
  'S3': 'June',
  'S4': 'July',
  'corr': 'August',
  'incorr': 'April',
  'text': 'April May June June'},
 {'S1': 'May',
  'S2': 'June',
  'S3': 'July',
  'S4': 'August',
  'corr': 'September',
  'incorr': 'May',
  'text': 'May June July July'},
 {'S1': 'June',
  'S2': 'July',
  'S3': 'August',
  'S4': 'September',
  'corr': 'October',
  'incorr': 'June',
  'text': 'June July August August'},
 {'S1': 'July',
  'S2': 'August',
  'S3': 'September',
  'S4': 'October',
  'corr': 'November',
  'incorr': 'July',
  'text':

In [18]:
import random

def generate_prompts_list_corr(x ,y):
    months = ['January', 'February', 'March', 'April', 'May', 'June', 'July', 'August', 'September', 'October', 'November', 'December']
    prompts_list = []
    # for i in range(x, y):
    for i in range(0, 8):
        r1 = random.choice(months)
        r2 = random.choice(months)
        while True:
            r3_ind = random.randint(0,len(months)-1)
            r4_ind = random.randint(0,len(months)-1)
            if months[r3_ind] != months[r4_ind-1]:
                break
        r3 = months[r3_ind]
        r4 = months[r4_ind]
        prompt_dict = {
            'S1': str(r1),
            'S2': str(r2),
            'S3': str(r3),
            'S4': str(r4),
            'corr': str(r1),
            'incorr': str(r4),
            'text': f"{r1} {r2} {r3} {r4}"
        }
        prompts_list.append(prompt_dict)
    return prompts_list

prompts_list_2 = generate_prompts_list_corr(1, 11)
dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer, S1_is_first=True)
prompts_list_2

[{'S1': 'May',
  'S2': 'June',
  'S3': 'December',
  'S4': 'June',
  'corr': 'May',
  'incorr': 'June',
  'text': 'May June December June'},
 {'S1': 'April',
  'S2': 'February',
  'S3': 'March',
  'S4': 'October',
  'corr': 'April',
  'incorr': 'October',
  'text': 'April February March October'},
 {'S1': 'August',
  'S2': 'April',
  'S3': 'December',
  'S4': 'December',
  'corr': 'August',
  'incorr': 'December',
  'text': 'August April December December'},
 {'S1': 'October',
  'S2': 'December',
  'S3': 'February',
  'S4': 'November',
  'corr': 'October',
  'incorr': 'November',
  'text': 'October December February November'},
 {'S1': 'January',
  'S2': 'December',
  'S3': 'May',
  'S4': 'August',
  'corr': 'January',
  'incorr': 'August',
  'text': 'January December May August'},
 {'S1': 'April',
  'S2': 'April',
  'S3': 'April',
  'S4': 'September',
  'corr': 'April',
  'incorr': 'September',
  'text': 'April April April September'},
 {'S1': 'November',
  'S2': 'October',
  'S3': 'J

Logit diff is correct - incorr token. Here, correct is S5, and incorr is S4.

Because of this, it's possible to have logit diffs HIGHER than the "full circuit" because the correct token will still be at first place, but the logit scores assigned will just be bigger (perhaps incorrect is scored even lower in the non-full circuit with a higher logit diff score)?

# Ablation Expm Functions

In [19]:
from torch import Tensor

def logits_to_ave_logit_diff_2(logits: Float[Tensor, "batch seq d_vocab"], dataset: Dataset, per_prompt=False):
    '''
    Returns logit difference between the correct and incorrect answer.

    If per_prompt=True, return the array of differences rather than the average.
    '''

    # Only the final logits are relevant for the answer
    # Get the logits corresponding to the indirect object / subject tokens respectively
    io_logits: Float[Tensor, "batch"] = logits[range(logits.size(0)), dataset.word_idx["end"], dataset.io_tokenIDs]
    s_logits: Float[Tensor, "batch"] = logits[range(logits.size(0)), dataset.word_idx["end"], dataset.s_tokenIDs]
    # Find logit difference
    answer_logit_diff = io_logits - s_logits
    return answer_logit_diff if per_prompt else answer_logit_diff.mean()

In [20]:
def mean_ablate_by_lst(lst, model, print_output=True):
    CIRCUIT = {
        "number mover": lst,
        # "number mover 4": lst,
        "number mover 3": lst,
        "number mover 2": lst,
        "number mover 1": lst,
    }

    SEQ_POS_TO_KEEP = {
        "number mover": "end",
        # "number mover 4": "S4",
        "number mover 3": "S3",
        "number mover 2": "S2",
        "number mover 1": "S1",
    }

    model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook

    ioi_logits_original, ioi_cache = model.run_with_cache(dataset.toks)

    model = ioi_circuit_extraction.add_mean_ablation_hook(model, means_dataset=dataset_2, circuit=CIRCUIT, seq_pos_to_keep=SEQ_POS_TO_KEEP)
    ioi_logits_minimal = model(dataset.toks)

    orig_score = logits_to_ave_logit_diff_2(ioi_logits_original, dataset)
    new_score = logits_to_ave_logit_diff_2(ioi_logits_minimal, dataset)
    if print_output:
        # print(f"Average logit difference (IOI dataset, using entire model): {orig_score:.4f}")
        # print(f"Average logit difference (IOI dataset, only using circuit): {new_score:.4f}")
        print(f"Average logit difference (circuit / full) %: {100 * new_score / orig_score:.4f}")
    # return new_score
    return 100 * new_score / orig_score

In [21]:
def find_circuit_forw(curr_circuit=None, threshold=10):
    # threshold is T, a %. if performance is less than T%, allow its removal
    if curr_circuit == []:
        # Start with full circuit
        curr_circuit = [(layer, head) for layer in range(12) for head in range(12)]

    for layer in range(0, 12):
        for head in range(12):
            if (layer, head) not in curr_circuit:
                continue

            # Copying the curr_circuit so we can iterate over one and modify the other
            copy_circuit = curr_circuit.copy()

            # Temporarily removing the current tuple from the copied circuit
            copy_circuit.remove((layer, head))

            new_score = mean_ablate_by_lst(copy_circuit, model, print_output=False).item()

            # print((layer,head), new_score)
            # If the result is less than the threshold, remove the tuple from the original list
            if (100 - new_score) < threshold:
                curr_circuit.remove((layer, head))

                print("\nRemoved:", (layer, head))
                print(new_score)

    return curr_circuit, new_score

In [22]:
def find_circuit_backw(curr_circuit=None, threshold=10):
    # threshold is T, a %. if performance is less than T%, allow its removal
    if curr_circuit == []:
        # Start with full circuit
        curr_circuit = [(layer, head) for layer in range(12) for head in range(12)]

    for layer in range(11, -1, -1):  # go thru all heads in a layer first
        for head in range(12):
            if (layer, head) not in curr_circuit:
                continue

            # Copying the curr_circuit so we can iterate over one and modify the other
            copy_circuit = curr_circuit.copy()

            # Temporarily removing the current tuple from the copied circuit
            copy_circuit.remove((layer, head))

            new_score = mean_ablate_by_lst(copy_circuit, model, print_output=False).item()

            # If the result is less than the threshold, remove the tuple from the original list
            if (100 - new_score) < threshold:
                curr_circuit.remove((layer, head))

                print("\nRemoved:", (layer, head))
                print(new_score)

    return curr_circuit, new_score

We can also prevent redundant computation of the full circuit score by storing it and just passing it in to the function.

# Ablate the model and compare with original

### try incr digits circ

https://colab.research.google.com/drive/1mFWmGAKtigFcqqWWMCwU7wWQY2HT5ZOo#scrollTo=lJEY-Zs2g_a5&line=1&uniqifier=1

iter fwd backw, threshold 20

In [23]:
curr_circuit = [(1, 5), (3, 3), (3, 7), (3, 10), (3, 11), (4, 4), (4, 6), (4, 7), (4, 10), (4, 11), (5, 0), (5, 2), (5, 3), (5, 4), (5, 6), (6, 3), (6, 8), (6, 10), (7, 0), (7, 2), (7, 7), (7, 8), (7, 10), (7, 11), (8, 0), (8, 1), (8, 6), (8, 8), (8, 9), (8, 11), (9, 1), (11, 8)]
mean_ablate_by_lst(curr_circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: 96.5418


96.5418472290039

In [24]:
curr_circuit = [(9, 1)]
mean_ablate_by_lst(curr_circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: 42.9089


42.908897399902344

In [25]:
curr_circuit = [(1, 1)]
mean_ablate_by_lst(curr_circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: 7.6947


7.694738388061523

In [26]:
curr_circuit = []
mean_ablate_by_lst(curr_circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: 7.7786


7.778631687164307

## Prune backwards

In [None]:
# Start with full circuit
curr_circuit = [(layer, head) for layer in range(12) for head in range(12)]
threshold = 3  # This is T, a %. if performance is less than T%, allow its removal

for layer in range(11, -1, -1):  # go thru all heads in a layer first
    for head in range(12):
        # Copying the curr_circuit so we can iterate over one and modify the other
        copy_circuit = curr_circuit.copy()

        # Temporarily removing the current tuple from the copied circuit
        copy_circuit.remove((layer, head))

        new_score = mean_ablate_by_lst(copy_circuit, model, print_output=False).item()

        # print((layer,head), new_score)
        # If the result is less than the threshold, remove the tuple from the original list
        if (100 - new_score) < threshold:
            curr_circuit.remove((layer, head))

            print("Removed:", (layer, head))
            print(new_score)
            print("\n")

Removed: (11, 0)
99.77244567871094


Removed: (11, 1)
99.6295166015625


Removed: (11, 2)
99.65966796875


Removed: (11, 3)
99.51300811767578


Removed: (11, 4)
99.73229217529297


Removed: (11, 5)
99.77629089355469


Removed: (11, 6)
99.81084442138672


Removed: (11, 7)
99.97366333007812


Removed: (11, 9)
99.88330841064453


Removed: (11, 10)
98.41625213623047


Removed: (11, 11)
98.8663330078125


Removed: (10, 0)
98.8320083618164


Removed: (10, 1)
99.31415557861328


Removed: (10, 2)
98.99121856689453


Removed: (10, 3)
98.36695098876953


Removed: (10, 4)
98.44320678710938


Removed: (10, 5)
98.33000946044922


Removed: (10, 6)
98.05836486816406


Removed: (10, 7)
97.35382843017578


Removed: (10, 8)
97.62989044189453


Removed: (10, 9)
97.55817413330078


Removed: (10, 10)
97.22981262207031


Removed: (10, 11)
97.37190246582031


Removed: (9, 0)
97.42068481445312


Removed: (9, 2)
97.38284301757812


Removed: (9, 4)
97.38111877441406


Removed: (9, 5)
99.59906768798828


Removed

In [None]:
curr_circuit

[(0, 0),
 (0, 1),
 (0, 3),
 (0, 5),
 (1, 0),
 (2, 2),
 (2, 3),
 (2, 4),
 (2, 5),
 (4, 0),
 (4, 1),
 (4, 4),
 (4, 6),
 (4, 7),
 (4, 8),
 (5, 0),
 (5, 8),
 (6, 9),
 (6, 10),
 (7, 3),
 (7, 5),
 (7, 6),
 (7, 7),
 (7, 8),
 (7, 11),
 (8, 6),
 (8, 8),
 (8, 9),
 (8, 11),
 (9, 1),
 (9, 3),
 (9, 11),
 (11, 8)]

In [None]:
mean_ablate_by_lst(curr_circuit, model, print_output=True)

Average logit difference (circuit / full) %: 98.4271


tensor(98.4271, device='cuda:0')

In [None]:
len(curr_circuit)

33

In [None]:
backw_3 = curr_circuit.copy()
backw_3

[(0, 1),
 (0, 5),
 (4, 4),
 (4, 8),
 (5, 0),
 (6, 9),
 (6, 10),
 (7, 5),
 (7, 11),
 (8, 6),
 (8, 8),
 (8, 9),
 (8, 11),
 (9, 1),
 (10, 7),
 (11, 10)]

In [None]:
len(backw_3)

16

Now try 10% threshold:

In [None]:
def find_circuit_backw(threshold=10):
    # threshold is T, a %. if performance is less than T%, allow its removal
    # Start with full circuit
    curr_circuit = [(layer, head) for layer in range(12) for head in range(12)]

    for layer in range(11, -1, -1):  # go thru all heads in a layer first
        for head in range(12):
            # Copying the curr_circuit so we can iterate over one and modify the other
            copy_circuit = curr_circuit.copy()

            # Temporarily removing the current tuple from the copied circuit
            copy_circuit.remove((layer, head))

            new_score = mean_ablate_by_lst(copy_circuit, model, print_output=False).item()

            # If the result is less than the threshold, remove the tuple from the original list
            if (100 - new_score) < threshold:
                curr_circuit.remove((layer, head))

                print("Removed:", (layer, head))
                print(new_score)
                print("\n")

    return curr_circuit

In [None]:
curr_circuit = find_circuit_backw(10)

Removed: (11, 0)
99.78748321533203


Removed: (11, 1)
99.6786117553711


Removed: (11, 2)
99.76538848876953


Removed: (11, 3)
99.75588989257812


Removed: (11, 4)
99.9180679321289


Removed: (11, 5)
99.94261932373047


Removed: (11, 6)
100.0663070678711


Removed: (11, 7)
100.22264099121094


Removed: (11, 8)
97.96847534179688


Removed: (11, 9)
97.76232147216797


Removed: (11, 10)
96.1368637084961


Removed: (11, 11)
96.40777587890625


Removed: (10, 0)
96.45589447021484


Removed: (10, 1)
96.91429138183594


Removed: (10, 2)
96.4229965209961


Removed: (10, 3)
96.04180145263672


Removed: (10, 4)
96.06706237792969


Removed: (10, 5)
95.9412841796875


Removed: (10, 6)
96.0024185180664


Removed: (10, 7)
94.67149353027344


Removed: (10, 8)
94.81072998046875


Removed: (10, 9)
94.75505065917969


Removed: (10, 10)
94.89752960205078


Removed: (10, 11)
95.10176849365234


Removed: (9, 0)
95.11750793457031


Removed: (9, 2)
95.02252197265625


Removed: (9, 3)
94.62155151367188


Remov

In [None]:
backw_10 = curr_circuit.copy()
backw_10

[(0, 1),
 (1, 0),
 (2, 2),
 (4, 4),
 (4, 7),
 (4, 8),
 (5, 0),
 (5, 1),
 (6, 9),
 (6, 10),
 (7, 11),
 (8, 11),
 (9, 1)]

In [None]:
mean_ablate_by_lst(curr_circuit, model, print_output=True)

Average logit difference (circuit / full) %: 90.7358


tensor(90.7358, device='cuda:0')

In [None]:
len(backw_10)

13

20%:

In [None]:
%%capture
curr_circuit = find_circuit_backw(20)

In [None]:
backw_20 = curr_circuit.copy()
backw_20

In [None]:
mean_ablate_by_lst(curr_circuit, model, print_output=True)

In [None]:
len(backw_20)

### set diffs of the three perf lvls

In [None]:
set(backw_3) - set(backw_10)

In [None]:
set(backw_10) - set(backw_3)

In [None]:
set(backw_3) - set(backw_20)

In [None]:
set(backw_10) - set(backw_20)

## Prune forwards

In [None]:
# # Start with full circuit
# curr_circuit = [(layer, head) for layer in range(12) for head in range(12)]
# threshold = 3  # This is T, a %. if performance is less than T%, allow its removal

# for layer in range(0, 12):
#     for head in range(12):
#         # Copying the curr_circuit so we can iterate over one and modify the other
#         copy_circuit = curr_circuit.copy()

#         # Temporarily removing the current tuple from the copied circuit
#         copy_circuit.remove((layer, head))

#         new_score = mean_ablate_by_lst(copy_circuit, model, print_output=False).item()

#         # print((layer,head), new_score)
#         # If the result is less than the threshold, remove the tuple from the original list
#         if (100 - new_score) < threshold:
#             curr_circuit.remove((layer, head))

#             print("Removed:", (layer, head))
#             print(new_score)
#             print("\n")

## Prune fwds-backwds iteratively

### iter fwd backw, threshold 3

In [None]:
threshold = 3
curr_circuit = []
prev_score = 100
new_score = 0
iter = 1
while prev_score != new_score:
    print('\nfwd prune, iter ', str(iter))
    # track changes in circuit as for some reason it doesn't work with scores
    old_circuit = curr_circuit.copy() # save old before finding new one
    curr_circuit, new_score = find_circuit_forw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    print('\nbackw prune, iter ', str(iter))
    # prev_score = new_score # save old score before finding new one
    old_circuit = curr_circuit.copy() # save old before finding new one
    curr_circuit, new_score = find_circuit_backw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    iter += 1


fwd prune, iter  1

Removed: (0, 0)
98.89445495605469

Removed: (0, 1)
98.38172149658203

Removed: (0, 2)
97.99978637695312

Removed: (0, 3)
98.59839630126953

Removed: (0, 4)
99.2354507446289

Removed: (0, 5)
98.79082489013672

Removed: (0, 6)
98.61289978027344

Removed: (0, 7)
98.08831024169922

Removed: (0, 8)
98.00670623779297

Removed: (0, 9)
97.748291015625

Removed: (0, 10)
99.19110870361328

Removed: (0, 11)
99.34892272949219

Removed: (1, 0)
98.88248443603516

Removed: (1, 1)
98.90789794921875

Removed: (1, 2)
99.0909194946289

Removed: (1, 3)
98.59939575195312

Removed: (1, 4)
98.48860931396484

Removed: (1, 5)
97.42794036865234

Removed: (1, 6)
97.32422637939453

Removed: (1, 7)
97.27044677734375

Removed: (1, 8)
97.0673599243164

Removed: (1, 9)
97.24008178710938

Removed: (1, 10)
97.49906158447266

Removed: (1, 11)
97.39950561523438

Removed: (2, 0)
97.56668853759766

Removed: (2, 1)
97.47888946533203

Removed: (2, 2)
97.61055755615234

Removed: (2, 3)
97.81840515136719



In [None]:
fb_3 = curr_circuit.copy()
fb_3

[(4, 4),
 (4, 7),
 (4, 8),
 (4, 11),
 (5, 0),
 (5, 3),
 (5, 4),
 (5, 5),
 (5, 8),
 (5, 10),
 (5, 11),
 (6, 0),
 (6, 7),
 (6, 9),
 (6, 10),
 (7, 5),
 (7, 6),
 (7, 7),
 (7, 10),
 (7, 11),
 (8, 2),
 (8, 3),
 (8, 6),
 (8, 7),
 (8, 8),
 (8, 9),
 (8, 11),
 (9, 1),
 (9, 3),
 (9, 11),
 (10, 3),
 (10, 7),
 (10, 10),
 (11, 0),
 (11, 8),
 (11, 10)]

In [None]:
mean_ablate_by_lst(fb_3, model, print_output=True)

Average logit difference (circuit / full) %: 97.0279


tensor(97.0279, device='cuda:0')

In [None]:
mean_ablate_by_lst(fb_3 + [(6, 9)], model, print_output=True)

Average logit difference (circuit / full) %: 97.0279


tensor(97.0279, device='cuda:0')

In [None]:
len(fb_3)

36

#### compare

In [None]:
set(backw_3) - set(fb_3)

NameError: ignored

In [None]:
set(fb_3) - set(backw_3)

### iter fwd backw, threshold 20

In [None]:
threshold = 20
curr_circuit = []
prev_score = 100
new_score = 0
iter = 1
while prev_score != new_score:
    print('\nfwd prune, iter ', str(iter))
    # track changes in circuit as for some reason it doesn't work with scores
    old_circuit = curr_circuit.copy() # save old before finding new one
    curr_circuit, new_score = find_circuit_forw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    print('\nbackw prune, iter ', str(iter))
    # prev_score = new_score # save old score before finding new one
    old_circuit = curr_circuit.copy() # save old before finding new one
    curr_circuit, new_score = find_circuit_backw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    iter += 1

In [None]:
fb_3 = curr_circuit.copy()
fb_3

[(4, 4),
 (4, 7),
 (4, 8),
 (4, 11),
 (5, 0),
 (5, 3),
 (5, 4),
 (5, 5),
 (5, 8),
 (5, 10),
 (5, 11),
 (6, 0),
 (6, 7),
 (6, 9),
 (6, 10),
 (7, 5),
 (7, 6),
 (7, 7),
 (7, 10),
 (7, 11),
 (8, 2),
 (8, 3),
 (8, 6),
 (8, 7),
 (8, 8),
 (8, 9),
 (8, 11),
 (9, 1),
 (9, 3),
 (9, 11),
 (10, 3),
 (10, 7),
 (10, 10),
 (11, 0),
 (11, 8),
 (11, 10)]

In [None]:
mean_ablate_by_lst(fb_3, model, print_output=True)

Average logit difference (circuit / full) %: 97.0279


tensor(97.0279, device='cuda:0')

In [None]:
mean_ablate_by_lst(fb_3 + [(6, 9)], model, print_output=True)

Average logit difference (circuit / full) %: 97.0279


tensor(97.0279, device='cuda:0')

In [None]:
len(fb_3)

36

## Prune backwds-fwds iteratively

### iter fwd backw, threshold 3

In [None]:
threshold = 3
curr_circuit = []
prev_score = 100
new_score = 0
iter = 1
while prev_score != new_score:
    print('\nbackw prune, iter ', str(iter))
    # prev_score = new_score # save old score before finding new one
    old_circuit = curr_circuit.copy() # save old before finding new one
    curr_circuit, new_score = find_circuit_backw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    print('\nfwd prune, iter ', str(iter))
    # track changes in circuit as for some reason it doesn't work with scores
    old_circuit = curr_circuit.copy() # save old before finding new one
    curr_circuit, new_score = find_circuit_forw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    iter += 1


backw prune, iter  1

Removed: (11, 0)
99.77244567871094

Removed: (11, 1)
99.6295166015625

Removed: (11, 2)
99.65966796875

Removed: (11, 3)
99.51300811767578

Removed: (11, 4)
99.73229217529297

Removed: (11, 5)
99.77629089355469

Removed: (11, 6)
99.81084442138672

Removed: (11, 7)
99.97366333007812

Removed: (11, 9)
99.88330841064453

Removed: (11, 10)
98.41625213623047

Removed: (11, 11)
98.8663330078125

Removed: (10, 0)
98.8320083618164

Removed: (10, 1)
99.31415557861328

Removed: (10, 2)
98.99121856689453

Removed: (10, 3)
98.36695098876953

Removed: (10, 4)
98.44320678710938

Removed: (10, 5)
98.33000946044922

Removed: (10, 6)
98.05836486816406

Removed: (10, 7)
97.35382843017578

Removed: (10, 8)
97.62989044189453

Removed: (10, 9)
97.55817413330078

Removed: (10, 10)
97.22981262207031

Removed: (10, 11)
97.37190246582031

Removed: (9, 0)
97.42068481445312

Removed: (9, 2)
97.38284301757812

Removed: (9, 4)
97.38111877441406

Removed: (9, 5)
99.59906768798828

Removed: (9

In [None]:
bf_3 = curr_circuit.copy()
bf_3

[(0, 1),
 (0, 5),
 (1, 0),
 (2, 2),
 (2, 4),
 (2, 5),
 (4, 0),
 (4, 4),
 (4, 6),
 (4, 7),
 (4, 8),
 (5, 0),
 (5, 8),
 (6, 9),
 (6, 10),
 (7, 3),
 (7, 5),
 (7, 6),
 (7, 7),
 (7, 8),
 (7, 11),
 (8, 6),
 (8, 8),
 (8, 9),
 (8, 11),
 (9, 1),
 (9, 3),
 (9, 11),
 (11, 8)]

In [None]:
len(bf_3)

29

In [None]:
backw_3 = [(0, 1), (0, 5), (4, 4), (4, 8), (5, 0), (6, 9), (6, 10), (7, 5), (7, 11), (8, 6), (8, 8), (8, 9), (8, 11), (9, 1), (10, 7), (11, 10)]

In [None]:
len(backw_3)

16

In [None]:
mean_ablate_by_lst(bf_3, model, print_output=True)

Average logit difference (circuit / full) %: 97.0140


tensor(97.0140, device='cuda:0')

In [None]:
mean_ablate_by_lst(backw_3, model, print_output=True)

Average logit difference (circuit / full) %: 86.6864


tensor(86.6864, device='cuda:0')

#### compare

In [None]:
len(bf_3)

29

In [None]:
len(fb_3)

36

In [None]:
set(backw_3) - set(bf_3)

{(10, 7), (11, 10)}

In [None]:
set(bf_3) - set(backw_3)

{(1, 0),
 (2, 2),
 (2, 4),
 (2, 5),
 (4, 0),
 (4, 6),
 (4, 7),
 (5, 8),
 (7, 3),
 (7, 6),
 (7, 7),
 (7, 8),
 (9, 3),
 (9, 11),
 (11, 8)}

In [None]:
set(fb_3) - (set(fb_3) - set(bf_3))

{(4, 4),
 (4, 7),
 (4, 8),
 (5, 0),
 (5, 8),
 (6, 9),
 (6, 10),
 (7, 5),
 (7, 6),
 (7, 7),
 (7, 11),
 (8, 6),
 (8, 8),
 (8, 9),
 (8, 11),
 (9, 1),
 (9, 3),
 (9, 11),
 (11, 8)}

In [None]:
set(bf_3) - set(fb_3)

{(0, 1),
 (0, 5),
 (1, 0),
 (2, 2),
 (2, 4),
 (2, 5),
 (4, 0),
 (4, 6),
 (7, 3),
 (7, 8)}

Get score of fb_3 without nodes it has that bf_3 doesn't have

this is set intersection: https://chat.openai.com/c/c15f48a7-226b-4c89-8ad9-a39a471867f5

In [None]:
mean_ablate_by_lst(list(set(fb_3) - (set(fb_3) - set(bf_3))), model, print_output=True)

Average logit difference (circuit / full) %: 86.5529


tensor(86.5529, device='cuda:0')

In [None]:
mean_ablate_by_lst(list(set(bf_3) - (set(bf_3) - set(fb_3))), model, print_output=True)

Average logit difference (circuit / full) %: 86.5529


tensor(86.5529, device='cuda:0')

In [None]:
(set(fb_3) - (set(fb_3) - set(bf_3))) == (set(bf_3) - (set(bf_3) - set(fb_3)))

True

# manually rmv and check for most impt heads

In [None]:
incr_digits_circ = [(0, 1), (0, 2), (0, 3), (0, 5), (0, 9), (0, 10), (1, 0), (1, 5), (1, 7), (2, 2), (2, 8), (2, 10), (3, 3), (3, 7), (3, 8), (4, 4), (4, 6), (4, 7), (4, 8), (4, 9), (4, 10), (4, 11), (5, 0), (5, 2), (5, 3), (5, 4), (5, 6), (5, 8), (6, 1), (6, 3), (6, 9), (6, 10), (6, 11), (7, 0), (7, 9), (7, 10), (7, 11), (8, 11), (9, 1), (11, 10)]
mean_ablate_by_lst(incr_digits_circ, model, print_output=True).item()

Average logit difference (circuit / full) %: 88.4129


88.41292572021484

https://colab.research.google.com/drive/12HF5UCvMERizkhOiYJKDziahgVq_3KD9#scrollTo=C2EgKgmJS4qb&line=1&uniqifier=1

In [None]:
months_circuit = [(0, 1), (0, 5), (1, 0), (2, 2), (2, 4), (2, 5), (4, 0), (4, 4), (4, 6), (4, 7), (4, 8), (5, 0), (5, 8), (6, 9), (6, 10), (7, 3), (7, 5), (7, 6), (7, 7), (7, 8), (7, 11), (8, 6), (8, 8), (8, 9), (8, 11), (9, 1), (9, 3), (9, 11), (11, 8)]
mean_ablate_by_lst(months_circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: 99.4110


99.41104888916016

In [None]:
for lh in months_circuit:
    copy_circuit = months_circuit.copy()
    copy_circuit.remove(lh)
    print("removed: " + str(lh))
    new_score = mean_ablate_by_lst(copy_circuit, model, print_output=True).item()

removed: (0, 1)
Average logit difference (circuit / full) %: 95.4160
removed: (0, 5)
Average logit difference (circuit / full) %: 98.8310
removed: (1, 0)
Average logit difference (circuit / full) %: 98.9989
removed: (2, 2)
Average logit difference (circuit / full) %: 99.2236
removed: (2, 4)
Average logit difference (circuit / full) %: 99.3856
removed: (2, 5)
Average logit difference (circuit / full) %: 99.0607
removed: (4, 0)
Average logit difference (circuit / full) %: 99.1574
removed: (4, 4)
Average logit difference (circuit / full) %: 65.2456
removed: (4, 6)
Average logit difference (circuit / full) %: 99.2335
removed: (4, 7)
Average logit difference (circuit / full) %: 98.8957
removed: (4, 8)
Average logit difference (circuit / full) %: 98.4582
removed: (5, 0)
Average logit difference (circuit / full) %: 96.6405
removed: (5, 8)
Average logit difference (circuit / full) %: 97.8162
removed: (6, 9)
Average logit difference (circuit / full) %: 98.6372
removed: (6, 10)
Average logit dif

# compare with desc

from: https://colab.research.google.com/drive/1odPpf7w_gBG8ZfAB2L6SXZszsDUk1CGA#scrollTo=ET--8aulD8pE&line=1&uniqifier=1

In [None]:
decr_circ = [(0, 1), (0, 3), (0, 5), (0, 7), (0, 9), (1, 0), (1, 5), (2, 2), (2, 4), (2, 9), (3, 0), (3, 3), (3, 7), (3, 10), (4, 6), (4, 7), (4, 10), (4, 11), (5, 1), (5, 5), (5, 6), (6, 1), (6, 7), (6, 9), (7, 2), (7, 10), (7, 11), (8, 1), (8, 6), (8, 8), (8, 10), (9, 5), (10, 7), (11, 0), (11, 8), (11, 11)]

In [None]:
set(incr_circ) - set(decr_circ)

In [None]:
set(decr_circ) - set(incr_circ)