<a href="https://colab.research.google.com/github/wlg100/numseqcont_circuit_expms/blob/main/notebook_templates/minimal_circuit_template.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" align="left"/></a>&nbsp;or in a local notebook.

# Setup
(No need to change anything)

In [None]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    # Install another version of node that makes PySvelte work way faster
    # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    # %pip install git+https://github.com/neelnanda-io/PySvelte.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

Running as a Colab notebook
Collecting git+https://github.com/neelnanda-io/TransformerLens.git
  Cloning https://github.com/neelnanda-io/TransformerLens.git to /tmp/pip-req-build-lnz_nidy
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/TransformerLens.git /tmp/pip-req-build-lnz_nidy
  Resolved https://github.com/neelnanda-io/TransformerLens.git to commit fa287750606075574df2c538058e67d648e2f952
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting accelerate>=0.23.0 (from transformer-lens==0.0.0)
  Downloading accelerate-0.24.0-py3-none-any.whl (260 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m261.0/261.0 kB[0m [31m5.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting beartype<0.15.0,>=0.14.1 (from transformer-lens==0.0.0)
  Downloading beartype-0.14.1-py3-none-any.whl (739 kB)
[2K     [90m━━━━━━━━

In [None]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab
import plotly.io as pio

# Plotly figures show up in either Colab OR VSCode depending on the renderer, so a local
# developer can turn on DEBUG_MODE to get static PNGs; every other case uses "colab".
pio.renderers.default = "png" if ((not IN_COLAB) and DEBUG_MODE) else "colab"

In [None]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [None]:
# import pysvelte

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

We turn automatic differentiation off, to save GPU memory, as this notebook focuses on model inference not model training.

In [None]:
torch.set_grad_enabled(mode=False)  # inference only: no autograd graph, less GPU memory

<torch.autograd.grad_mode.set_grad_enabled at 0x7c591e92f6d0>

Plotting helper functions:

In [None]:
def imshow(tensor, renderer=None, **kwargs):
    """Heatmap of `tensor` on a red-blue scale centred at 0; extra kwargs go to px.imshow."""
    fig = px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        **kwargs,
    )
    fig.show(renderer)


def line(tensor, renderer=None, **kwargs):
    """Line chart of the values in `tensor`; extra kwargs go to px.line."""
    fig = px.line(y=utils.to_numpy(tensor), **kwargs)
    fig.show(renderer)


def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Scatter `y` against `x`, labelling the x, y and colour axes; extra kwargs go to px.scatter."""
    x_np = utils.to_numpy(x)
    y_np = utils.to_numpy(y)
    axis_labels = {"x": xaxis, "y": yaxis, "color": caxis}
    fig = px.scatter(y=y_np, x=x_np, labels=axis_labels, **kwargs)
    fig.show(renderer)

## Load Model

Decide which model to use (e.g. gpt2-small vs. gpt2-medium).

In [None]:
# Weight-processing options for from_pretrained; see the TransformerLens docs for what each does.
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    fold_ln=True,
    center_writing_weights=True,
    center_unembed=True,
    refactor_factored_attn_matrices=True,
)

Downloading (…)lve/main/config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Loaded pretrained model gpt2-small into HookedTransformer


## Import functions from repo

In [None]:
# Clone ARENA 2.0 to get the `ioi_circuit_extraction` module imported a few cells below.
# NOTE(review): this tracks the repo's default branch (no commit pinned), and `git clone`
# may fail on re-run if ARENA_2.0 already exists in the current directory.
!git clone https://github.com/callummcdougall/ARENA_2.0.git

Cloning into 'ARENA_2.0'...
remote: Enumerating objects: 9106, done.[K
remote: Counting objects: 100% (1820/1820), done.[K
remote: Compressing objects: 100% (289/289), done.[K
remote: Total 9106 (delta 1614), reused 1608 (delta 1528), pack-reused 7286[K
Receiving objects: 100% (9106/9106), 155.60 MiB | 24.62 MiB/s, done.
Resolving deltas: 100% (5507/5507), done.


In [None]:
cd ARENA_2.0/chapter1_transformers/exercises/part3_indirect_object_identification

/content/ARENA_2.0/chapter1_transformers/exercises/part3_indirect_object_identification


In [None]:
import ioi_circuit_extraction as ioi_circuit_extraction

# Test prompts

In [None]:
example_prompt, example_answer = "five four three two", " one"
# prepend_bos=True: the prompt is scored with a leading <|endoftext|> token.
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'five', ' four', ' three', ' two']
Tokenized answer: [' one']


Top 0th token. Logit: 13.84 Prob: 18.13% Token: | two|
Top 1th token. Logit: 13.69 Prob: 15.50% Token: | three|
Top 2th token. Logit: 13.57 Prob: 13.81% Token: | one|
Top 3th token. Logit: 12.59 Prob:  5.17% Token: | five|
Top 4th token. Logit: 12.57 Prob:  5.08% Token: | four|
Top 5th token. Logit: 11.84 Prob:  2.44% Token: |
|
Top 6th token. Logit: 11.70 Prob:  2.12% Token: | seven|
Top 7th token. Logit: 11.68 Prob:  2.08% Token: | six|
Top 8th token. Logit: 11.00 Prob:  1.06% Token: | a|
Top 9th token. Logit: 10.73 Prob:  0.81% Token: | eight|


In [None]:
example_prompt, example_answer = "six five four three two", " one"
# prepend_bos=True: the prompt is scored with a leading <|endoftext|> token.
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'six', ' five', ' four', ' three', ' two']
Tokenized answer: [' one']


Top 0th token. Logit: 14.36 Prob: 21.06% Token: | two|
Top 1th token. Logit: 14.20 Prob: 17.86% Token: | three|
Top 2th token. Logit: 14.11 Prob: 16.31% Token: | one|
Top 3th token. Logit: 12.81 Prob:  4.44% Token: | four|
Top 4th token. Logit: 12.76 Prob:  4.23% Token: | six|
Top 5th token. Logit: 12.68 Prob:  3.92% Token: | seven|
Top 6th token. Logit: 12.56 Prob:  3.46% Token: | five|
Top 7th token. Logit: 11.31 Prob:  1.00% Token: | eight|
Top 8th token. Logit: 11.27 Prob:  0.96% Token: | nine|
Top 9th token. Logit: 11.22 Prob:  0.91% Token: |
|


In [None]:
example_prompt, example_answer = "4 6 8 10 12", " 14"
# prepend_bos=True: the prompt is scored with a leading <|endoftext|> token.
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', '4', ' 6', ' 8', ' 10', ' 12']
Tokenized answer: [' 14']


Top 0th token. Logit: 17.82 Prob: 23.88% Token: | 14|
Top 1th token. Logit: 17.81 Prob: 23.71% Token: | 13|
Top 2th token. Logit: 17.30 Prob: 14.19% Token: | 15|
Top 3th token. Logit: 16.49 Prob:  6.31% Token: | 16|
Top 4th token. Logit: 16.23 Prob:  4.89% Token: | 12|
Top 5th token. Logit: 15.97 Prob:  3.77% Token: |
|
Top 6th token. Logit: 15.87 Prob:  3.41% Token: | 17|
Top 7th token. Logit: 15.75 Prob:  3.01% Token: | 18|
Top 8th token. Logit: 15.43 Prob:  2.20% Token: | 19|
Top 9th token. Logit: 14.76 Prob:  1.12% Token: | 11|


# Generate dataset with multiple prompts

`io_tokenIDs` hold the correct answer (S11, the next number in the sequence) and `s_tokenIDs` hold the incorrect answer (S10, the last number in the prompt, i.e. a repeat).

In [None]:
class Dataset:
    """Tokenized number-sequence prompts plus the token positions used for mean ablation.

    Each prompt is a dict with a "text" string and one string entry per sequence element
    ("S1", "S2", ...). All prompts are assumed to share the keys of ``prompts[0]``.

    Args:
        prompts: non-empty list of prompt dicts.
        tokenizer: tokenizer providing ``__call__`` (with ``padding=``), ``encode`` and
            ``tokenize``; a GPT-2-style tokenizer that marks a leading space with "Ġ".
        S1_is_first: if True, look up "S1" without the "Ġ" prefix (it is the first word
            of the text, so it has no leading space).

    Attributes:
        N: number of prompts.
        max_len: length of the longest tokenized prompt.
        groups: list of index arrays, one per template (all prompts share template 0).
        toks: int32 tensor of padded token ids, one row per prompt.
        io_tokenIDs: first token id of " " + prompt["S11"] per prompt (the "correct"
            token in the logit diff).
        s_tokenIDs: first token id of " " + prompt["S10"] per prompt (the "incorrect"
            token in the logit diff).
        word_idx: maps every prompt key except "text"/"S11", plus "end", to a tensor of
            that token's position in each prompt; "end" is the last token's position.

    Raises:
        ValueError: if ``prompts`` is empty.
    """

    def __init__(self, prompts, tokenizer, S1_is_first=False):
        if not prompts:
            raise ValueError("Dataset needs at least one prompt")
        self.prompts = prompts
        self.tokenizer = tokenizer
        self.N = len(prompts)
        self.max_len = max(
            len(self.tokenizer(prompt["text"]).input_ids) for prompt in self.prompts
        )

        # all_ids = [prompt["TEMPLATE_IDX"] for prompt in self.ioi_prompts]
        all_ids = [0] * self.N  # only 1 template
        all_ids_ar = np.array(all_ids)
        self.groups = [
            np.where(all_ids_ar == template_id)[0] for template_id in sorted(set(all_ids))
        ]

        texts = [prompt["text"] for prompt in self.prompts]
        # Build the integer tensor directly instead of round-tripping through float32.
        self.toks = torch.tensor(
            self.tokenizer(texts, padding=True).input_ids, dtype=torch.int
        )
        self.io_tokenIDs = [
            self.tokenizer.encode(" " + prompt["S11"])[0] for prompt in self.prompts
        ]
        self.s_tokenIDs = [
            self.tokenizer.encode(" " + prompt["S10"])[0] for prompt in self.prompts
        ]

        # Tokenize each prompt once, always with this dataset's own tokenizer
        # (previously the global `model.tokenizer` was used here).
        tokenized = [self.tokenizer.tokenize(prompt["text"]) for prompt in self.prompts]

        # word_idx: for every target key, a tensor holding that token's index in each prompt.
        self.word_idx = {}
        target_keys = [key for key in self.prompts[0] if key not in ("text", "S11")]
        for targ in target_keys:
            positions = []
            for prompt, tokens in zip(self.prompts, tokenized):
                if S1_is_first and targ == "S1":  # first word has no space Ġ in front
                    target_token = prompt[targ]
                else:
                    target_token = "Ġ" + prompt[targ]
                # list.index returns the FIRST match: when a value repeats inside a prompt
                # (e.g. S9 == S10 in the corrupted prompts) both keys get the same position.
                positions.append(tokens.index(target_token))
            self.word_idx[targ] = torch.tensor(positions)

        self.word_idx["end"] = torch.tensor([len(tokens) - 1 for tokens in tokenized])

    def __len__(self):
        return self.N

NOTE: if the model does not already assign the correct (predicted) token a high probability, mean ablation can produce scores above the full model's. Make sure the correct token's probability is in the high 90s (%).

With 4-element sequences, GPT-2 small usually assigns only about 50% probability to the correct token for decreasing sequences, unlike increasing ones. Decreasing sequences often need about 10 elements to reach the mid-90s (the level increasing sequences reach with 4 elements), so we use 10 elements. See:

https://colab.research.google.com/drive/1ahWI9e0NMeAjdFNnj2vEj4d4aYIGsoNP#scrollTo=Nkbv00d0Wn2z&line=1&uniqifier=1

In [None]:
# Starting numbers of the prompts generated below: 21 down to 12.
for i in reversed(range(12, 22)):
    print(i)

21
20
19
18
17
16
15
14
13
12


In [None]:
def generate_prompts_list(x, y):
    """Build descending 10-number prompts, one per start value in range(x, y, -1).

    Each dict maps "S1".."S11" to the strings of start, start-1, ..., start-10.
    "text" joins the first ten of them with spaces; S11 is the next number, i.e. the answer.
    """
    prompts = []
    for start in range(x, y, -1):
        values = list(range(start, start - 11, -1))  # S1..S11
        prompt = {f"S{pos}": str(v) for pos, v in enumerate(values, start=1)}
        prompt["text"] = " ".join(map(str, values[:10]))
        prompts.append(prompt)
    return prompts

# Clean prompts: descending runs starting at 21, 20, ..., 12. S1 is the first word (no leading space).
prompts_list = generate_prompts_list(21, 11)
dataset = Dataset(prompts_list, tokenizer=model.tokenizer, S1_is_first=True)

In [None]:
def generate_prompts_list_corr(x, y):
    """Corrupted counterpart of the clean prompts, one per start value in range(x, y, -1).

    "S1".."S9" are start, ..., start-8, "S10" repeats start-8, and "S11" is start-9.
    "text" joins S1..S10 with spaces, so each prompt ends on a repeated number.
    """
    def make_prompt(start):
        values = [start - offset for offset in range(9)]  # S1..S9
        values += [start - 8, start - 9]                  # S10 (repeat of S9), S11
        prompt = {f"S{pos}": str(v) for pos, v in enumerate(values, start=1)}
        prompt["text"] = " ".join(str(v) for v in values[:10])
        return prompt

    return [make_prompt(start) for start in range(x, y, -1)]

# Corrupted prompts (last number repeated); used as the means dataset for mean ablation.
prompts_list_2 = generate_prompts_list_corr(21, 11)
dataset_2 = Dataset(prompts_list_2, tokenizer=model.tokenizer, S1_is_first=True)

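The logit diff is logit(correct token) − logit(incorrect token) at the final position. Here, the correct token is S11 (the next number in the sequence) and the incorrect token is S10 (the last number in the prompt).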

Because of this, a pruned circuit can have a logit diff HIGHER than the full circuit. The correct token can still be ranked first while its gap to the incorrect token grows; for example, the pruned circuit may score the incorrect token even lower, giving a larger logit diff.

# Ablation Expm Functions

In [None]:
from torch import Tensor

def logits_to_ave_logit_diff_2(logits: Float[Tensor, "batch seq d_vocab"], dataset: Dataset, per_prompt=False):
    '''
    Logit difference, at each prompt's final token (dataset.word_idx["end"]), between the
    correct token (dataset.io_tokenIDs) and the incorrect token (dataset.s_tokenIDs).

    Returns the mean over prompts, or the per-prompt differences if per_prompt=True.
    '''
    batch_idx = range(logits.size(0))
    # Logits at each prompt's final position: (batch, d_vocab)
    final_logits = logits[batch_idx, dataset.word_idx["end"]]
    correct_logits = final_logits[batch_idx, dataset.io_tokenIDs]
    incorrect_logits = final_logits[batch_idx, dataset.s_tokenIDs]
    diffs = correct_logits - incorrect_logits
    if per_prompt:
        return diffs
    return diffs.mean()

In [None]:
def mean_ablate_by_lst(lst, model, print_output=True, orig_score=None):
    """Score a candidate circuit relative to the full model using mean ablation.

    The heads in `lst` are passed as the circuit to keep at the final position ("end")
    and at positions S1..S10. `ioi_circuit_extraction.add_mean_ablation_hook` is used with
    `dataset_2` as the means dataset. Presumably it replaces every other head's output
    with its mean over `dataset_2`; see that module.

    Args:
        lst: list of (layer, head) tuples forming the circuit to keep.
        model: HookedTransformer to evaluate. All of its hooks (including permanent ones)
            are reset before scoring and again before returning.
        print_output: if True, print the percentage score.
        orig_score: optional precomputed full-model logit diff on `dataset` (as returned by
            `logits_to_ave_logit_diff_2`). Passing it skips one forward pass per call.

    Returns:
        0-d tensor: 100 * (circuit logit diff) / (full-model logit diff).
    """
    # Keep the same head list at the final position and at every sequence element S10..S1.
    circuit = {"number mover": lst, **{f"number mover {k}": lst for k in range(10, 0, -1)}}
    seq_pos_to_keep = {"number mover": "end", **{f"number mover {k}": f"S{k}" for k in range(10, 0, -1)}}

    # Start from a clean model: an earlier call or cell may have left an ablation hook on it.
    model.reset_hooks(including_permanent=True)
    if orig_score is None:
        # Plain forward pass; the activation cache from run_with_cache was never used.
        ioi_logits_original = model(dataset.toks)
        orig_score = logits_to_ave_logit_diff_2(ioi_logits_original, dataset)

    try:
        model = ioi_circuit_extraction.add_mean_ablation_hook(
            model, means_dataset=dataset_2, circuit=circuit, seq_pos_to_keep=seq_pos_to_keep
        )
        ioi_logits_minimal = model(dataset.toks)
    finally:
        # Don't leave the ablation hook installed on the caller's model.
        model.reset_hooks(including_permanent=True)

    new_score = logits_to_ave_logit_diff_2(ioi_logits_minimal, dataset)
    percent = 100 * new_score / orig_score
    if print_output:
        print(f"Average logit difference (circuit / full) %: {percent:.4f}")
    return percent

We can also prevent redundant computation of the full circuit score by storing it and just passing it into the function.

We can also prevent redundant computation by storing the original logits instead of re-computing each time.

In [None]:
# Sanity check: keeping all 12 x 12 heads should reproduce the full model, i.e. 100%.
curr_circuit = list(itertools.product(range(12), range(12)))
mean_ablate_by_lst(curr_circuit, model, print_output=False).item()

100.0

In [None]:
# Clear any mean-ablation hooks left by earlier cells before taking the full-model baseline.
model.reset_hooks(including_permanent=True)
ioi_logits_original, ioi_cache = model.run_with_cache(dataset.toks)
orig_score = logits_to_ave_logit_diff_2(logits=ioi_logits_original, dataset=dataset)
orig_score

tensor(3.7557, device='cuda:0')

# Ablate the model and compare with original

## Work backwards

https://www.notion.so/wlg1/Search-Methods-brainstorm-15a3020ab00b40adb79b0acf3622f5f4?pvs=4#dd6b43247d4945eda1d70ca4d4bae01d

In [None]:
# Backward pruning from the full circuit: walk layers 11 -> 0 (heads 0..11 within each layer)
# and drop every head whose removal costs less than `threshold` percentage points.
curr_circuit = [(layer, head) for layer in range(12) for head in range(12)]
threshold = 3  # This is T, a %. if performance drops by less than T%, allow its removal

for layer in reversed(range(12)):
    for head in range(12):
        candidate = (layer, head)
        # Score the current circuit with this head left out; curr_circuit itself is untouched.
        copy_circuit = [h for h in curr_circuit if h != candidate]
        new_score = mean_ablate_by_lst(copy_circuit, model, print_output=False).item()

        if 100 - new_score < threshold:
            curr_circuit.remove(candidate)
            print("\nRemoved:", candidate)
            print(new_score)


Removed: (11, 1)
99.579345703125

Removed: (11, 2)
99.49618530273438

Removed: (11, 3)
99.52944946289062

Removed: (11, 4)
99.99576568603516

Removed: (11, 5)
99.1644515991211

Removed: (11, 6)
98.92974090576172

Removed: (11, 7)
98.90752410888672

Removed: (11, 8)
98.03706359863281

Removed: (11, 9)
97.40703582763672

Removed: (11, 10)
97.48768615722656

Removed: (11, 11)
97.08882141113281

Removed: (10, 0)
97.125

Removed: (10, 4)
97.19386291503906

Removed: (10, 6)
97.27295684814453

Removed: (10, 8)
97.31184387207031

Removed: (10, 9)
97.21022033691406

Removed: (10, 10)
97.4030990600586

Removed: (10, 11)
97.34559631347656

Removed: (9, 0)
97.21297454833984

Removed: (9, 1)
97.07715606689453

Removed: (9, 2)
97.11175537109375

Removed: (9, 4)
97.18354034423828

Removed: (9, 6)
97.24261474609375

Removed: (9, 7)
97.54266357421875

Removed: (9, 8)
97.41954040527344

Removed: (9, 9)
102.97260284423828

Removed: (9, 10)
102.77648162841797

Removed: (9, 11)
101.89380645751953

Removed

In [None]:
mean_ablate_by_lst(curr_circuit, model).item()  # print_output defaults to True

Average logit difference (circuit / full) %: 98.6286


98.62860870361328

In [None]:
# Re-score the pruned circuit by hand: keep `lst` at the final position and at S1..S10.
lst = curr_circuit

CIRCUIT = {"number mover": lst, **{f"number mover {k}": lst for k in range(10, 0, -1)}}
SEQ_POS_TO_KEEP = {"number mover": "end", **{f"number mover {k}": f"S{k}" for k in range(10, 0, -1)}}

model = ioi_circuit_extraction.add_mean_ablation_hook(
    model, means_dataset=dataset_2, circuit=CIRCUIT, seq_pos_to_keep=SEQ_POS_TO_KEEP
)
ioi_logits_minimal = model(dataset.toks)
new_score = logits_to_ave_logit_diff_2(ioi_logits_minimal, dataset)
new_score

tensor(3.7042, device='cuda:0')

In [None]:
# Keep only head (9, 1) at the final position and at S1..S10.
lst = [(9, 1)]

CIRCUIT = {"number mover": lst, **{f"number mover {k}": lst for k in range(10, 0, -1)}}
SEQ_POS_TO_KEEP = {"number mover": "end", **{f"number mover {k}": f"S{k}" for k in range(10, 0, -1)}}

model = ioi_circuit_extraction.add_mean_ablation_hook(
    model, means_dataset=dataset_2, circuit=CIRCUIT, seq_pos_to_keep=SEQ_POS_TO_KEEP
)
ioi_logits_minimal = model(dataset.toks)
new_score = logits_to_ave_logit_diff_2(ioi_logits_minimal, dataset)
new_score

tensor(0.2837, device='cuda:0')

In [None]:
# Heads kept by the 3% backward pruning pass above.
curr_circuit

[(0, 1),
 (0, 2),
 (0, 7),
 (0, 9),
 (1, 0),
 (1, 2),
 (1, 3),
 (1, 4),
 (1, 5),
 (1, 8),
 (2, 0),
 (2, 2),
 (2, 3),
 (2, 4),
 (2, 5),
 (2, 8),
 (2, 9),
 (3, 0),
 (3, 3),
 (3, 7),
 (3, 8),
 (3, 10),
 (3, 11),
 (4, 6),
 (4, 11),
 (5, 0),
 (5, 1),
 (5, 5),
 (5, 6),
 (6, 6),
 (6, 7),
 (6, 8),
 (6, 9),
 (7, 10),
 (7, 11),
 (8, 1),
 (8, 6),
 (8, 8),
 (9, 3),
 (9, 5),
 (10, 1),
 (10, 2),
 (10, 3),
 (10, 5),
 (10, 7),
 (11, 0)]

### compare to incr seq circuit

In [None]:
# Decreasing-sequence circuit: the same heads as the 3% backward-pruning `curr_circuit` output above.
decr_circ = [(0, 1), (0, 2), (0, 7), (0, 9), (1, 0), (1, 2), (1, 3), (1, 4), (1, 5), (1, 8), (2, 0), (2, 2), (2, 3), (2, 4), (2, 5), (2, 8), (2, 9), (3, 0), (3, 3), (3, 7), (3, 8), (3, 10), (3, 11), (4, 6), (4, 11), (5, 0), (5, 1), (5, 5), (5, 6), (6, 6), (6, 7), (6, 8), (6, 9), (7, 10), (7, 11), (8, 1), (8, 6), (8, 8), (9, 3), (9, 5), (10, 1), (10, 2), (10, 3), (10, 5), (10, 7), (11, 0)]
len(decr_circ)

46

In [None]:
# Increasing-sequence circuit, hard-coded for comparison.
# NOTE(review): record where it came from (notebook / threshold) so it can be regenerated.
incr_circ = [(0, 1), (0, 3), (0, 5), (0, 7), (0, 9), (0, 10), (1, 0), (1, 4), (1, 5), (2, 2), (2, 8), (2, 9), (3, 0), (3, 2), (3, 3), (3, 7), (4, 4), (4, 7), (4, 10), (5, 1), (5, 3), (5, 4), (5, 5), (5, 6), (5, 8), (5, 9), (6, 1), (6, 4), (6, 6), (6, 10), (6, 11), (7, 6), (7, 10), (7, 11), (8, 0), (8, 5), (8, 6), (8, 8), (9, 1), (10, 7)]
len(incr_circ)

40

In [None]:
# Heads in the decreasing-sequence circuit but not in the increasing one.
# Sorted so the list has a stable (layer, head) order rather than arbitrary set order.
decr_not_incr = sorted(set(decr_circ) - set(incr_circ))
print(len(decr_not_incr))

25


In [None]:
# Heads in the increasing-sequence circuit but not in the decreasing one.
# Sorted so the list has a stable (layer, head) order rather than arbitrary set order.
incr_not_decr = sorted(set(incr_circ) - set(decr_circ))
print(len(incr_not_decr))

19


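So only about half of the heads are shared: 21 of the 46 decreasing-circuit heads and 21 of the 40 increasing-circuit heads. How does this compare to circuits for non-number tasks?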

### 10% threshold:

In [None]:
def find_circuit_backw(threshold=10):
    """Greedy backward pruning starting from all 12 x 12 heads.

    Layers are visited from 11 down to 0 (heads 0..11 within each layer). A head is dropped
    when leaving it out of the current circuit lowers the score from `mean_ablate_by_lst`
    by less than `threshold` percentage points (so scores above 100 always qualify).

    NOTE: a later cell redefines `find_circuit_backw(curr_circuit=None, threshold=10)`,
    which returns a (circuit, score) tuple; calls run after that cell use the new version.

    Returns:
        The pruned list of (layer, head) tuples.
    """
    curr_circuit = [(layer, head) for layer in range(12) for head in range(12)]

    for layer in reversed(range(12)):
        for head in range(12):
            candidate = (layer, head)
            trial_circuit = [h for h in curr_circuit if h != candidate]
            new_score = mean_ablate_by_lst(trial_circuit, model, print_output=False).item()

            if 100 - new_score < threshold:
                curr_circuit.remove(candidate)
                print("Removed:", candidate)
                print(new_score)
                print("\n")

    return curr_circuit

In [None]:
curr_circuit = find_circuit_backw(threshold=10)

Removed: (11, 0)
96.00396728515625


Removed: (11, 1)
95.59683227539062


Removed: (11, 2)
95.50833892822266


Removed: (11, 3)
95.51766967773438


Removed: (11, 4)
95.98452758789062


Removed: (11, 5)
95.16443634033203


Removed: (11, 6)
94.9478759765625


Removed: (11, 7)
94.8895492553711


Removed: (11, 8)
94.04351806640625


Removed: (11, 9)
93.41999053955078


Removed: (11, 10)
93.49880981445312


Removed: (11, 11)
93.18167114257812


Removed: (10, 0)
93.20730590820312


Removed: (10, 1)
92.98998260498047


Removed: (10, 2)
91.66492462158203


Removed: (10, 3)
91.17400360107422


Removed: (10, 4)
91.14872741699219


Removed: (10, 5)
90.63921356201172


Removed: (10, 6)
90.67035675048828


Removed: (10, 8)
90.59402465820312


Removed: (10, 9)
90.47775268554688


Removed: (10, 10)
90.64414978027344


Removed: (10, 11)
90.60283660888672


Removed: (9, 0)
90.50228118896484


Removed: (9, 1)
90.39917755126953


Removed: (9, 2)
90.39645385742188


Removed: (9, 4)
90.29425811767578


Rem

Try this method on the greater-than task to see whether it recovers a circuit similar to the one in the paper.

In [None]:
# Heads kept by find_circuit_backw with a 10% threshold.
curr_circuit

[(0, 1),
 (0, 2),
 (0, 3),
 (0, 9),
 (0, 10),
 (1, 0),
 (1, 1),
 (1, 5),
 (1, 8),
 (2, 2),
 (2, 3),
 (2, 4),
 (2, 5),
 (2, 8),
 (2, 9),
 (3, 0),
 (3, 1),
 (3, 3),
 (3, 7),
 (3, 8),
 (3, 10),
 (3, 11),
 (4, 6),
 (4, 7),
 (4, 8),
 (4, 11),
 (5, 1),
 (5, 5),
 (5, 6),
 (6, 1),
 (6, 7),
 (6, 8),
 (6, 9),
 (7, 10),
 (7, 11),
 (8, 1),
 (8, 8),
 (9, 3),
 (9, 5),
 (9, 7),
 (10, 7)]

In [None]:
%%capture
curr_circuit = find_circuit_backw(20)

In [None]:
# Heads kept by find_circuit_backw with a 20% threshold.
curr_circuit

[(0, 1),
 (0, 2),
 (0, 3),
 (0, 5),
 (0, 9),
 (0, 11),
 (1, 0),
 (1, 1),
 (1, 5),
 (2, 0),
 (2, 2),
 (2, 3),
 (2, 4),
 (2, 8),
 (2, 9),
 (3, 0),
 (3, 3),
 (3, 7),
 (3, 8),
 (3, 10),
 (4, 7),
 (4, 8),
 (4, 10),
 (4, 11),
 (5, 1),
 (5, 5),
 (5, 6),
 (5, 9),
 (6, 1),
 (6, 9),
 (7, 10),
 (7, 11),
 (8, 6),
 (8, 8),
 (8, 10),
 (10, 7)]

## Mean-ablate the circuit pruned by iterative path patching

From:

https://colab.research.google.com/drive/1onREXMNmc9ks0xpwDslUX2pdG0RSYtWS#scrollTo=ehsYSXYO_25N&line=6&uniqifier=1

In [None]:
# Circuit found by iterative path patching (linked notebook above), scored with mean ablation.
test_circ = [(0, 1), (3, 0), (4, 4), (5, 5), (5, 8), (6, 6), (7, 11), (9, 1)]
mean_ablate_by_lst(lst=test_circ, model=model, print_output=True).item()

Average logit difference (circuit / full) %: 11.0121


11.012136459350586

From:

https://colab.research.google.com/drive/1onREXMNmc9ks0xpwDslUX2pdG0RSYtWS#scrollTo=V8JWdlVokmpL&line=6&uniqifier=1

In [None]:
# Larger path-patching circuit (linked notebook above), scored with mean ablation.
test_circ = [
    (0, 1), (0, 5), (0, 10), (1, 5), (3, 0), (4, 4), (4, 8), (5, 1), (5, 4), (5, 5), (5, 8),
    (6, 1), (6, 6), (6, 9), (6, 10), (7, 6), (7, 10), (7, 11), (8, 8), (9, 1), (10, 7),
]
mean_ablate_by_lst(lst=test_circ, model=model, print_output=True).item()

Average logit difference (circuit / full) %: 15.9552


15.955167770385742

## Prune forwards

In [None]:
# Forward pruning from the full circuit: same greedy rule as the backward pass,
# but visiting layers 0 -> 11 (heads 0..11 within each layer).
curr_circuit = [(layer, head) for layer in range(12) for head in range(12)]
threshold = 3  # This is T, a %. if performance drops by less than T%, allow its removal

for layer, head in itertools.product(range(12), range(12)):
    # Score the current circuit with this head left out; curr_circuit itself is untouched.
    copy_circuit = [h for h in curr_circuit if h != (layer, head)]
    new_score = mean_ablate_by_lst(copy_circuit, model, print_output=False).item()

    if 100 - new_score < threshold:
        curr_circuit.remove((layer, head))
        print("Removed:", (layer, head))
        print(new_score)
        print("\n")

Removed: (0, 0)
98.1362533569336


Removed: (0, 3)
97.31721496582031


Removed: (0, 4)
97.5737075805664


Removed: (0, 8)
98.42399597167969


Removed: (0, 10)
97.3432388305664


Removed: (1, 4)
98.35460662841797


Removed: (1, 6)
98.62120819091797


Removed: (1, 7)
99.97992706298828


Removed: (1, 8)
100.36349487304688


Removed: (1, 9)
99.07020568847656


Removed: (1, 10)
98.22621154785156


Removed: (1, 11)
98.57820129394531


Removed: (2, 0)
97.25369262695312


Removed: (2, 1)
98.31084442138672


Removed: (2, 3)
98.36981201171875


Removed: (2, 5)
97.4117202758789


Removed: (2, 6)
98.24701690673828


Removed: (2, 7)
98.62277221679688


Removed: (2, 8)
98.3319091796875


Removed: (2, 10)
99.87090301513672


Removed: (2, 11)
100.121826171875


Removed: (3, 1)
98.31285095214844


Removed: (3, 2)
98.00856018066406


Removed: (3, 4)
100.95858001708984


Removed: (3, 5)
101.34181213378906


Removed: (3, 6)
98.85667419433594


Removed: (3, 8)
98.59002685546875


Removed: (3, 9)
98.5112915

In [None]:
# Heads kept by the 3% forward pruning pass above.
curr_circuit

[(0, 1),
 (0, 2),
 (0, 5),
 (0, 6),
 (0, 7),
 (0, 9),
 (0, 11),
 (1, 0),
 (1, 1),
 (1, 2),
 (1, 3),
 (1, 5),
 (2, 2),
 (2, 4),
 (2, 9),
 (3, 0),
 (3, 3),
 (3, 7),
 (3, 10),
 (3, 11),
 (4, 7),
 (4, 11),
 (5, 1),
 (5, 2),
 (5, 5),
 (5, 6),
 (5, 8),
 (5, 11),
 (6, 1),
 (6, 4),
 (6, 5),
 (6, 7),
 (6, 9),
 (7, 2),
 (7, 10),
 (7, 11),
 (8, 0),
 (8, 1),
 (8, 5),
 (8, 6),
 (8, 7),
 (8, 8),
 (8, 10),
 (9, 5),
 (10, 7),
 (11, 0),
 (11, 8),
 (11, 9),
 (11, 11)]

## Prune forwards, then backwards, iteratively: functions

In [None]:
def find_circuit_forw(curr_circuit=None, threshold=10):
    """Greedy forward pruning: visit heads from layer 0 to 11 and drop those that barely matter.

    Args:
        curr_circuit: starting list of (layer, head) tuples. None or [] means start from the
            full model (all 12 x 12 heads). The caller's list is not modified.
        threshold: T, in percentage points. A head is removed when leaving it out of the
            current circuit lowers the `mean_ablate_by_lst` score by less than T, so scores
            above 100 always allow removal.

    Returns:
        (circuit, score): the pruned circuit and the `mean_ablate_by_lst` score of that circuit.
    """
    if not curr_circuit:
        # Start with full circuit (previously the default None crashed on `in None` below).
        curr_circuit = [(layer, head) for layer in range(12) for head in range(12)]
    else:
        # Work on a copy so the caller's list isn't mutated behind its back.
        curr_circuit = list(curr_circuit)

    for layer in range(12):
        for head in range(12):
            if (layer, head) not in curr_circuit:
                continue

            # Circuit with this head left out; curr_circuit itself is untouched.
            trial_circuit = [h for h in curr_circuit if h != (layer, head)]
            trial_score = mean_ablate_by_lst(trial_circuit, model, print_output=False).item()

            # Remove the head if dropping it costs less than `threshold` percentage points.
            if (100 - trial_score) < threshold:
                curr_circuit.remove((layer, head))
                print("\nRemoved:", (layer, head))
                print(trial_score)

    # Score the circuit actually being returned. The old code returned the last *trial*
    # score, which belongs to a different circuit whenever the last head tried was kept.
    new_score = mean_ablate_by_lst(curr_circuit, model, print_output=False).item()
    return curr_circuit, new_score

In [None]:
def find_circuit_backw(curr_circuit=None, threshold=10, n_layers=12, n_heads=12):
    """Greedily prune heads from a circuit, visiting layers last to first.

    Layers are visited from ``n_layers - 1`` down to 0. Within a layer, heads are tried
    in ascending order, so all heads of a layer are tried before moving to the previous
    layer. Each head still in the circuit is tentatively dropped, and the remaining
    circuit is scored with ``mean_ablate_by_lst`` (circuit / full logit difference,
    in %). If the drop ``100 - score`` is below ``threshold``, the head is removed for
    good, and later candidates are tested against the already-shrunk circuit.

    Args:
        curr_circuit: list of ``(layer, head)`` tuples to start from. ``None`` or an
            empty list means "start from every head". The caller's list is not modified.
        threshold: a percentage T. A head may be removed if removing it costs less than
            T percentage points.
        n_layers: number of layers to visit and include in the full circuit (default 12).
        n_heads: number of heads per layer (default 12).

    Returns:
        ``(circuit, score)``: the pruned circuit and the ``mean_ablate_by_lst`` score of
        exactly that circuit.
    """
    if not curr_circuit:
        # Start with full circuit
        circuit = [(layer, head) for layer in range(n_layers) for head in range(n_heads)]
    else:
        # Work on a copy so the caller's list is left untouched
        circuit = list(curr_circuit)

    circuit_score = None  # score of `circuit` as it currently stands; None until known
    for layer in range(n_layers - 1, -1, -1):  # go thru all heads in a layer first
        for head in range(n_heads):
            if (layer, head) not in circuit:
                continue

            # Score the circuit with just this one head dropped
            candidate = circuit.copy()
            candidate.remove((layer, head))
            new_score = mean_ablate_by_lst(candidate, model, print_output=False).item()

            # Keep the removal if performance drops by less than `threshold` points
            if (100 - new_score) < threshold:
                circuit = candidate
                circuit_score = new_score

                print("\nRemoved:", (layer, head))
                print(new_score)

    if circuit_score is None:
        # Nothing was removed, so no trial scored the returned circuit itself
        circuit_score = mean_ablate_by_lst(circuit, model, print_output=False).item()

    return circuit, circuit_score

In [None]:
# Hard-coded starting circuit, one row per layer, refined with a single backward
# pruning pass at a 2-point threshold.
curr_circuit = [
    (0, 1), (0, 3), (0, 5), (0, 7), (0, 8), (0, 9), (0, 10),
    (1, 0), (1, 5),
    (3, 0), (3, 3), (3, 7), (3, 10), (3, 11),
    (4, 4), (4, 6), (4, 7), (4, 8), (4, 10), (4, 11),
    (5, 4), (5, 5), (5, 9),
    (6, 1), (6, 6), (6, 10),
    (7, 6), (7, 10), (7, 11),
    (8, 1), (8, 2), (8, 6), (8, 8),
    (9, 1), (9, 5),
    (10, 7),
    (11, 10),
]
curr_circuit, new_score = find_circuit_backw(threshold=2, curr_circuit=curr_circuit)

In [None]:
# Follow-up forward pass (2-point threshold) on whatever `curr_circuit` currently holds
curr_circuit, new_score = find_circuit_forw(threshold=2, curr_circuit=curr_circuit)

### Iterate forward/backward pruning, threshold 2

In [None]:
# Alternate forward and backward pruning passes until one pass leaves the circuit
# unchanged. Convergence is detected by comparing circuits, not scores.
threshold = 2
curr_circuit = []
iteration = 1
while True:
    print('\nfwd prune, iter ', str(iteration))
    old_circuit = curr_circuit.copy()  # snapshot to detect a no-op pass
    curr_circuit, new_score = find_circuit_forw(curr_circuit=curr_circuit, threshold=threshold)
    # Stop on a no-op pass. Also stop if everything was pruned, because the pruning
    # functions treat an empty circuit as "start from the full circuit".
    if curr_circuit == old_circuit or not curr_circuit:
        break
    print('\nbackw prune, iter ', str(iteration))
    old_circuit = curr_circuit.copy()  # snapshot to detect a no-op pass
    curr_circuit, new_score = find_circuit_backw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit or not curr_circuit:
        break
    iteration += 1


fwd prune, iter  1

Removed: (0, 0)
98.1362533569336

Removed: (0, 4)
98.33680725097656

Removed: (0, 8)
99.32208251953125

Removed: (0, 11)
98.65502166748047

Removed: (1, 2)
98.06524658203125

Removed: (1, 4)
99.14366149902344

Removed: (1, 6)
99.34202575683594

Removed: (1, 7)
100.14029693603516

Removed: (1, 8)
100.11001586914062

Removed: (1, 9)
100.07388305664062

Removed: (1, 10)
99.47722625732422

Removed: (1, 11)
100.22380828857422

Removed: (2, 0)
99.36809539794922

Removed: (2, 1)
100.16586303710938

Removed: (2, 3)
100.25131225585938

Removed: (2, 5)
98.904296875

Removed: (2, 6)
99.13590240478516

Removed: (2, 7)
99.78890991210938

Removed: (2, 8)
98.92216491699219

Removed: (2, 10)
100.53916931152344

Removed: (2, 11)
100.54846954345703

Removed: (3, 1)
99.47002410888672

Removed: (3, 2)
99.65966796875

Removed: (3, 4)
103.1480712890625

Removed: (3, 5)
103.5404281616211

Removed: (3, 6)
100.66483306884766

Removed: (3, 8)
100.59037780761719

Removed: (3, 9)
100.42975616

In [None]:
# Display the circuit the forward/backward pruning loop converged to
curr_circuit

[(0, 1),
 (0, 3),
 (0, 5),
 (0, 7),
 (0, 9),
 (1, 0),
 (1, 5),
 (2, 2),
 (2, 4),
 (2, 9),
 (3, 0),
 (3, 3),
 (3, 7),
 (3, 10),
 (4, 6),
 (4, 7),
 (4, 10),
 (4, 11),
 (5, 1),
 (5, 5),
 (5, 6),
 (6, 1),
 (6, 7),
 (6, 9),
 (7, 2),
 (7, 10),
 (7, 11),
 (8, 1),
 (8, 6),
 (8, 8),
 (8, 10),
 (9, 5),
 (10, 7),
 (11, 0),
 (11, 8),
 (11, 11)]

Strangely, pruning forward first instead of backward first yields heads that appear in the "forward-first" circuit but not in the "backward-first" one, such as 11.11. The resulting circuit is therefore highly sensitive to the order in which candidate heads are tested, and it does not always converge to the same circuit. **TODO:** devise a more robust method that does.

### Compare to the increasing-sequence circuit

In [None]:
# Circuit found by the threshold-2 forward/backward pruning loop above,
# written one row per layer as (layer, head) tuples.
decr_circ = [
    (0, 1), (0, 3), (0, 5), (0, 7), (0, 9),
    (1, 0), (1, 5),
    (2, 2), (2, 4), (2, 9),
    (3, 0), (3, 3), (3, 7), (3, 10),
    (4, 6), (4, 7), (4, 10), (4, 11),
    (5, 1), (5, 5), (5, 6),
    (6, 1), (6, 7), (6, 9),
    (7, 2), (7, 10), (7, 11),
    (8, 1), (8, 6), (8, 8), (8, 10),
    (9, 5),
    (10, 7),
    (11, 0), (11, 8), (11, 11),
]
len(decr_circ)

36

The circuit below was obtained from: https://colab.research.google.com/drive/1CHRn-AMko9RNrl1bqiCwB7DS-rz1CoBP#scrollTo=e8OFeKuxzM3R

In [None]:
# Increasing-sequence circuit from the Colab notebook linked above,
# written one row per layer as (layer, head) tuples.
incr_circ = [
    (0, 1), (0, 3), (0, 5), (0, 7), (0, 9), (0, 10),
    (1, 0), (1, 5),
    (2, 2),
    (3, 0), (3, 3), (3, 7), (3, 10),
    (4, 4), (4, 6), (4, 7), (4, 8), (4, 10), (4, 11),
    (5, 4), (5, 5), (5, 8), (5, 9),
    (6, 1), (6, 6), (6, 10),
    (7, 6), (7, 10), (7, 11),
    (8, 0), (8, 6), (8, 8),
    (9, 1), (9, 5),
    (10, 7),
]
len(incr_circ)

35

In [None]:
# Heads in `decr_circ` but not in `incr_circ`. A comprehension (instead of a set
# difference) keeps decr_circ's (layer, head) order rather than arbitrary hash order.
incr_heads = set(incr_circ)
decr_not_incr = [lh for lh in decr_circ if lh not in incr_heads]
print(len(decr_not_incr))

12


In [None]:
# Show the heads present in decr_circ but absent from incr_circ
decr_not_incr

[(2, 4),
 (8, 1),
 (11, 0),
 (8, 10),
 (5, 1),
 (6, 7),
 (2, 9),
 (5, 6),
 (7, 2),
 (11, 8),
 (6, 9),
 (11, 11)]

In [None]:
# Heads in `incr_circ` but not in `decr_circ`. A comprehension (instead of a set
# difference) keeps incr_circ's (layer, head) order rather than arbitrary hash order.
decr_heads = set(decr_circ)
incr_not_decr = [lh for lh in incr_circ if lh not in decr_heads]
print(len(incr_not_decr))
incr_not_decr

11


[(4, 4),
 (0, 10),
 (5, 8),
 (5, 4),
 (8, 0),
 (7, 6),
 (6, 10),
 (4, 8),
 (6, 6),
 (5, 9),
 (9, 1)]

### Iterate forward/backward pruning, threshold 25

In [None]:
# Alternate forward and backward pruning passes until one pass leaves the circuit
# unchanged. Convergence is detected by comparing circuits, not scores.
threshold = 25
curr_circuit = []
iteration = 1
while True:
    print('\nfwd prune, iter ', str(iteration))
    old_circuit = curr_circuit.copy()  # snapshot to detect a no-op pass
    curr_circuit, new_score = find_circuit_forw(curr_circuit=curr_circuit, threshold=threshold)
    # Stop on a no-op pass. Also stop if everything was pruned, because the pruning
    # functions treat an empty circuit as "start from the full circuit".
    if curr_circuit == old_circuit or not curr_circuit:
        break
    print('\nbackw prune, iter ', str(iteration))
    old_circuit = curr_circuit.copy()  # snapshot to detect a no-op pass
    curr_circuit, new_score = find_circuit_backw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit or not curr_circuit:
        break
    iteration += 1


fwd prune, iter  1

Removed: (0, 0)
98.1362533569336

Removed: (0, 1)
79.41375732421875

Removed: (0, 2)
78.06918334960938

Removed: (0, 3)
75.37031555175781

Removed: (0, 4)
76.41148376464844

Removed: (0, 6)
76.55220794677734

Removed: (0, 7)
76.40447235107422

Removed: (0, 8)
75.88678741455078

Removed: (0, 10)
84.5561752319336

Removed: (0, 11)
83.57477569580078

Removed: (1, 0)
79.58251953125

Removed: (1, 1)
77.78614044189453

Removed: (1, 2)
76.90454864501953

Removed: (1, 3)
75.53748321533203

Removed: (1, 4)
76.80133056640625

Removed: (1, 6)
77.70860290527344

Removed: (1, 7)
79.79562377929688

Removed: (1, 8)
79.15569305419922

Removed: (1, 9)
77.77128601074219

Removed: (1, 10)
78.7978515625

Removed: (2, 0)
76.75865173339844

Removed: (2, 1)
76.01905822753906

Removed: (2, 6)
75.34153747558594

Removed: (2, 7)
76.35138702392578

Removed: (2, 11)
77.48845672607422

Removed: (3, 1)
75.66424560546875

Removed: (3, 2)
77.00819396972656

Removed: (3, 4)
80.5550537109375

Remov

In [None]:
# Display the circuit the threshold-25 pruning loop converged to
curr_circuit

[(0, 5),
 (0, 9),
 (1, 5),
 (1, 11),
 (2, 2),
 (2, 3),
 (2, 4),
 (2, 5),
 (2, 8),
 (2, 9),
 (2, 10),
 (3, 0),
 (3, 3),
 (3, 7),
 (3, 10),
 (3, 11),
 (4, 6),
 (4, 11),
 (5, 5),
 (6, 1),
 (6, 9),
 (6, 11),
 (7, 6),
 (7, 7),
 (7, 8),
 (7, 10),
 (7, 11),
 (8, 0),
 (8, 1),
 (8, 2),
 (8, 6),
 (8, 8),
 (9, 5),
 (11, 0),
 (11, 8)]

# Remove each head in turn to find the most important heads

The circuit below is taken from: https://colab.research.google.com/drive/1odPpf7w_gBG8ZfAB2L6SXZszsDUk1CGA#scrollTo=ET--8aulD8pE&line=1&uniqifier=1

In [None]:
# Circuit `decr_circ` (from the notebook linked above), one row per layer, scored
# against the full model with all other heads mean-ablated.
decr_circ = [
    (0, 1), (0, 3), (0, 5), (0, 7), (0, 9),
    (1, 0), (1, 5),
    (2, 2), (2, 4), (2, 9),
    (3, 0), (3, 3), (3, 7), (3, 10),
    (4, 6), (4, 7), (4, 10), (4, 11),
    (5, 1), (5, 5), (5, 6),
    (6, 1), (6, 7), (6, 9),
    (7, 2), (7, 10), (7, 11),
    (8, 1), (8, 6), (8, 8), (8, 10),
    (9, 5),
    (10, 7),
    (11, 0), (11, 8), (11, 11),
]
mean_ablate_by_lst(decr_circ, model, print_output=True).item()

Average logit difference (circuit / full) %: 98.0855


98.08546447753906

In [None]:
# Knock out each head of the circuit in turn and report the score without it.
# The lower the score, the more the circuit depends on that head.
# Scores are also kept in `score_without_head` ((layer, head) -> score), so the heads
# can be ranked afterwards instead of being read off the printed log.
score_without_head = {}
for lh in decr_circ:
    copy_circuit = decr_circ.copy()
    copy_circuit.remove(lh)
    print("removed: " + str(lh))
    new_score = mean_ablate_by_lst(copy_circuit, model, print_output=True).item()
    score_without_head[lh] = new_score

removed: (0, 1)
Average logit difference (circuit / full) %: 82.5277
removed: (0, 3)
Average logit difference (circuit / full) %: 97.1011
removed: (0, 5)
Average logit difference (circuit / full) %: 97.3105
removed: (0, 7)
Average logit difference (circuit / full) %: 95.5662
removed: (0, 9)
Average logit difference (circuit / full) %: 97.4591
removed: (1, 0)
Average logit difference (circuit / full) %: 93.7403
removed: (1, 5)
Average logit difference (circuit / full) %: 95.2882
removed: (2, 2)
Average logit difference (circuit / full) %: 93.2919
removed: (2, 4)
Average logit difference (circuit / full) %: 96.4548
removed: (2, 9)
Average logit difference (circuit / full) %: 94.6283
removed: (3, 0)
Average logit difference (circuit / full) %: 93.9613
removed: (3, 3)
Average logit difference (circuit / full) %: 92.7457
removed: (3, 7)
Average logit difference (circuit / full) %: 94.1664
removed: (3, 10)
Average logit difference (circuit / full) %: 93.3460
removed: (4, 6)
Average logit dif

# Try removing combinations of heads

In [None]:
# Remove all layer-8 heads: (8, 1), (8, 6), (8, 8), (8, 10).
# Bound to a new name so `decr_circ`, which the single-head knockout loop iterates
# over, is not overwritten.
circ_minus_L8 = [
    (0, 1), (0, 3), (0, 5), (0, 7), (0, 9),
    (1, 0), (1, 5),
    (2, 2), (2, 4), (2, 9),
    (3, 0), (3, 3), (3, 7), (3, 10),
    (4, 6), (4, 7), (4, 10), (4, 11),
    (5, 1), (5, 5), (5, 6),
    (6, 1), (6, 7), (6, 9),
    (7, 2), (7, 10), (7, 11),
    (9, 5),
    (10, 7),
    (11, 0), (11, 8), (11, 11),
]
mean_ablate_by_lst(circ_minus_L8, model, print_output=True).item()

Average logit difference (circuit / full) %: 81.0462


81.04621887207031

In [None]:
# Remove all heads in layers 8-11: (8, 1), (8, 6), (8, 8), (8, 10), (9, 5), (10, 7),
# (11, 0), (11, 8), (11, 11).
# Bound to a new name so `decr_circ`, which the single-head knockout loop iterates
# over, is not overwritten.
circ_minus_L8_11 = [
    (0, 1), (0, 3), (0, 5), (0, 7), (0, 9),
    (1, 0), (1, 5),
    (2, 2), (2, 4), (2, 9),
    (3, 0), (3, 3), (3, 7), (3, 10),
    (4, 6), (4, 7), (4, 10), (4, 11),
    (5, 1), (5, 5), (5, 6),
    (6, 1), (6, 7), (6, 9),
    (7, 2), (7, 10), (7, 11),
]
mean_ablate_by_lst(circ_minus_L8_11, model, print_output=True).item()

Average logit difference (circuit / full) %: 65.1191


65.11909484863281

In [None]:
# Remove head 7.11 and all heads in layers 8-11: (8, 1), (8, 6), (8, 8), (8, 10),
# (9, 5), (10, 7), (11, 0), (11, 8), (11, 11).
# Bound to a new name so `decr_circ`, which the single-head knockout loop iterates
# over, is not overwritten.
circ_minus_7_11_L8_11 = [
    (0, 1), (0, 3), (0, 5), (0, 7), (0, 9),
    (1, 0), (1, 5),
    (2, 2), (2, 4), (2, 9),
    (3, 0), (3, 3), (3, 7), (3, 10),
    (4, 6), (4, 7), (4, 10), (4, 11),
    (5, 1), (5, 5), (5, 6),
    (6, 1), (6, 7), (6, 9),
    (7, 2), (7, 10),
]
mean_ablate_by_lst(circ_minus_7_11_L8_11, model, print_output=True).item()

Average logit difference (circuit / full) %: 55.8640


55.864009857177734

In [None]:
# Remove heads 7.10 and 7.11 and all heads in layers 8-11: (8, 1), (8, 6), (8, 8),
# (8, 10), (9, 5), (10, 7), (11, 0), (11, 8), (11, 11).
# Bound to a new name so `decr_circ`, which the single-head knockout loop iterates
# over, is not overwritten.
circ_minus_7_10_7_11_L8_11 = [
    (0, 1), (0, 3), (0, 5), (0, 7), (0, 9),
    (1, 0), (1, 5),
    (2, 2), (2, 4), (2, 9),
    (3, 0), (3, 3), (3, 7), (3, 10),
    (4, 6), (4, 7), (4, 10), (4, 11),
    (5, 1), (5, 5), (5, 6),
    (6, 1), (6, 7), (6, 9),
    (7, 2),
]
mean_ablate_by_lst(circ_minus_7_10_7_11_L8_11, model, print_output=True).item()

Average logit difference (circuit / full) %: 25.0441


25.044111251831055