Only keep (not ablate) the query activations (query vector, the output of query weights times inputs) of certain positions. Keep all the key activations; the query positions that were kept will automatically attend to relevant key positions by matrix multiplication.

# Setup
(No need to change anything)

## import libs

In [1]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    # Install another version of node that makes PySvelte work way faster
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    %pip install git+https://github.com/neelnanda-io/PySvelte.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

Running as a Colab notebook
Collecting git+https://github.com/neelnanda-io/TransformerLens.git
  Cloning https://github.com/neelnanda-io/TransformerLens.git to /tmp/pip-req-build-en8z19cx
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/TransformerLens.git /tmp/pip-req-build-en8z19cx
  Resolved https://github.com/neelnanda-io/TransformerLens.git to commit fa287750606075574df2c538058e67d648e2f952
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting accelerate>=0.23.0 (from transformer-lens==0.0.0)
  Downloading accelerate-0.24.1-py3-none-any.whl (261 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m261.4/261.4 kB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting beartype<0.15.0,>=0.14.1 (from transformer-lens==0.0.0)
  Downloading beartype-0.14.1-py3-none-any.whl (739 kB)
[2K     [90m━━━━━━━━

In [2]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

if IN_COLAB or not DEBUG_MODE:
    # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "png"

In [3]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [4]:
# import pysvelte

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

We turn automatic differentiation off, to save GPU memory, as this notebook focuses on model inference not model training.

In [5]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7da58c1e3be0>

Plotting helper functions:

In [6]:
def imshow(tensor, renderer=None, **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", **kwargs).show(renderer)

def line(tensor, renderer=None, **kwargs):
    px.line(y=utils.to_numpy(tensor), **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

## Load Model

Decide which model to use (eg. gpt2-small vs -medium)

In [7]:
model = HookedTransformer.from_pretrained(
    "gpt2-medium",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)

Downloading (…)lve/main/config.json:   0%|          | 0.00/718 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/1.52G [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Loaded pretrained model gpt2-medium into HookedTransformer


## Import functions from repo

In [8]:
!git clone https://github.com/callummcdougall/ARENA_2.0.git

Cloning into 'ARENA_2.0'...
remote: Enumerating objects: 9106, done.[K
remote: Counting objects: 100% (1820/1820), done.[K
remote: Compressing objects: 100% (289/289), done.[K
remote: Total 9106 (delta 1614), reused 1608 (delta 1528), pack-reused 7286[K
Receiving objects: 100% (9106/9106), 155.60 MiB | 16.25 MiB/s, done.
Resolving deltas: 100% (5507/5507), done.


In [9]:
cd ARENA_2.0/chapter1_transformers/exercises/part3_indirect_object_identification

/content/ARENA_2.0/chapter1_transformers/exercises/part3_indirect_object_identification


In [10]:
import ioi_circuit_extraction as ioi_circuit_extraction

# Generate dataset with multiple prompts

In [11]:
class Dataset:
    def __init__(self, prompts, pos_dict, tokenizer, S1_is_first=False):
        self.prompts = prompts
        self.tokenizer = tokenizer
        self.N = len(prompts)
        self.max_len = max(
            [
                len(self.tokenizer(prompt["text"]).input_ids)
                for prompt in self.prompts
            ]
        )
        # all_ids = [prompt["TEMPLATE_IDX"] for prompt in self.ioi_prompts]
        all_ids = [0 for prompt in self.prompts] # only 1 template
        all_ids_ar = np.array(all_ids)
        self.groups = []
        for id in list(set(all_ids)):
            self.groups.append(np.where(all_ids_ar == id)[0])

        texts = [ prompt["text"] for prompt in self.prompts ]
        self.toks = torch.Tensor(self.tokenizer(texts, padding=True).input_ids).type(
            torch.int
        )
        self.io_tokenIDs = [
            self.tokenizer.encode(" " + prompt["corr"])[0] for prompt in self.prompts
        ]
        self.s_tokenIDs = [
            self.tokenizer.encode(" " + prompt["incorr"])[0] for prompt in self.prompts
        ]

        # word_idx: for every prompt, find the token index of each target token and "end"
        # word_idx is a tensor with an element for each prompt. The element is the targ token's ind at that prompt
        self.word_idx = {}
        for targ in [key for key in self.prompts[0].keys() if (key != 'text' and key != 'corr' and key != 'incorr')]:
            targ_lst = []
            for prompt in self.prompts:
                input_text = prompt["text"]
                tokens = model.tokenizer.tokenize(input_text)
                # if S1_is_first and targ == "S1":  # only use this if first token doesn't have space Ġ in front
                #     target_token = prompt[targ]
                # else:
                #     target_token = "Ġ" + prompt[targ]
                # target_index = tokens.index(target_token)
                target_index = pos_dict[targ]
                targ_lst.append(target_index)
            self.word_idx[targ] = torch.tensor(targ_lst)

        targ_lst = []
        for prompt in self.prompts:
            input_text = prompt["text"]
            tokens = self.tokenizer.tokenize(input_text)
            end_token_index = len(tokens) - 1
            targ_lst.append(end_token_index)
        self.word_idx["end"] = torch.tensor(targ_lst)

    def __len__(self):
        return self.N

In [12]:
pos_dict = {
    'S1': 0,
    'S2': 1,
    'S3': 2,
    'S4': 3,
    'S5': 4,
    'S6': 5,
}

In [13]:
def generate_prompts_list(x ,y):
    prompts_list = []
    for i in range(x, y):
        prompt_dict = {
            'S1': str(i),
            'S2': str(i+1),
            'S3': str(i+2),
            'S4': str(i+3),
            'S5': str(i+4),
            'S6': str(i+5),
            'corr': str(i+6),
            'incorr': str(i+5),
            'text': f"{i} {i+1} {i+2} {i+3} {i+4} {i+5}"
        }
        prompts_list.append(prompt_dict)
    return prompts_list

prompts_list = generate_prompts_list(1, 101)
dataset = Dataset(prompts_list, pos_dict, model.tokenizer, S1_is_first=True)

In [14]:
import random

def generate_prompts_list_corr(x ,y):
    prompts_list = []
    for i in range(x, y):
        r1 = random.randint(1, 100)
        r2 = random.randint(1, 100)
        r3 = random.randint(1, 100)
        r4 = random.randint(1, 100)
        r5 = random.randint(1, 100)
        r6 = random.randint(1, 100)
        prompt_dict = {
            'S1': str(r1),
            'S2': str(r2),
            'S3': str(r3),
            'S4': str(r4),
            'S5': str(r5),
            'S6': str(r6),
            'corr': str(r1),
            'incorr': str(i+6),
            'text': f"{r1} {r2} {r3} {r4} {r5} {r6}"
        }
        prompts_list.append(prompt_dict)
    return prompts_list

prompts_list_2 = generate_prompts_list_corr(1, 101)
dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer, S1_is_first=True)
prompts_list_2

[{'S1': '81',
  'S2': '92',
  'S3': '80',
  'S4': '94',
  'S5': '7',
  'S6': '38',
  'corr': '81',
  'incorr': '7',
  'text': '81 92 80 94 7 38'},
 {'S1': '81',
  'S2': '75',
  'S3': '4',
  'S4': '1',
  'S5': '45',
  'S6': '11',
  'corr': '81',
  'incorr': '8',
  'text': '81 75 4 1 45 11'},
 {'S1': '21',
  'S2': '24',
  'S3': '28',
  'S4': '36',
  'S5': '77',
  'S6': '19',
  'corr': '21',
  'incorr': '9',
  'text': '21 24 28 36 77 19'},
 {'S1': '19',
  'S2': '43',
  'S3': '16',
  'S4': '89',
  'S5': '92',
  'S6': '35',
  'corr': '19',
  'incorr': '10',
  'text': '19 43 16 89 92 35'},
 {'S1': '80',
  'S2': '88',
  'S3': '23',
  'S4': '35',
  'S5': '65',
  'S6': '28',
  'corr': '80',
  'incorr': '11',
  'text': '80 88 23 35 65 28'},
 {'S1': '64',
  'S2': '98',
  'S3': '27',
  'S4': '84',
  'S5': '90',
  'S6': '100',
  'corr': '64',
  'incorr': '12',
  'text': '64 98 27 84 90 100'},
 {'S1': '53',
  'S2': '62',
  'S3': '19',
  'S4': '17',
  'S5': '96',
  'S6': '85',
  'corr': '53',
  'inco

Logit diff is correct - incorr token. Here, correct is S5, and incorr is S4.

Because of this, it's possible to have logit diffs HIGHER than the "full circuit" because the correct token will still be at first place, but the logit scores assigned will just be bigger (perhaps incorrect is scored even lower in the non-full circuit with a higher logit diff score)?

# Ablation Expm Functions

In [15]:
from torch import Tensor

def logits_to_ave_logit_diff_2(logits: Float[Tensor, "batch seq d_vocab"], dataset: Dataset, per_prompt=False):
    '''
    Returns logit difference between the correct and incorrect answer.

    If per_prompt=True, return the array of differences rather than the average.
    '''

    # Only the final logits are relevant for the answer
    # Get the logits corresponding to the indirect object / subject tokens respectively
    io_logits: Float[Tensor, "batch"] = logits[range(logits.size(0)), dataset.word_idx["end"], dataset.io_tokenIDs]
    s_logits: Float[Tensor, "batch"] = logits[range(logits.size(0)), dataset.word_idx["end"], dataset.s_tokenIDs]
    # Find logit difference
    answer_logit_diff = io_logits - s_logits
    return answer_logit_diff if per_prompt else answer_logit_diff.mean()

We can also prevent redundant computation of the full circuit score by storing it and just passing it in to the function.

In [16]:
ioi_logits_original, ioi_cache = model.run_with_cache(dataset.toks)
orig_score = logits_to_ave_logit_diff_2(ioi_logits_original, dataset)

In [17]:
model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook
ioi_logits_original, ioi_cache = model.run_with_cache(dataset.toks)
logits_to_ave_logit_diff_2(ioi_logits_original, dataset)

tensor(8.2426, device='cuda:0')

In [18]:
def mean_ablate_by_lst(CIRCUIT, SEQ_POS_TO_KEEP, model, orig_score, print_output=True):
    model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook

    # ioi_logits_original, ioi_cache = model.run_with_cache(dataset.toks)

    model_abl = ioi_circuit_extraction.add_mean_ablation_hook(model, means_dataset=dataset_2, circuit=CIRCUIT, seq_pos_to_keep=SEQ_POS_TO_KEEP)
    ioi_logits_minimal = model_abl(dataset.toks)

    # orig_score = logits_to_ave_logit_diff_2(ioi_logits_original, dataset)

    new_score = logits_to_ave_logit_diff_2(ioi_logits_minimal, dataset)
    if print_output:
        # print(f"Average logit difference (IOI dataset, using entire model): {orig_score:.4f}")
        # print(f"Average logit difference (IOI dataset, only using circuit): {new_score:.4f}")
        print(f"Average logit difference (circuit / full) %: {100 * new_score / orig_score:.4f}")
    # return new_score
    return 100 * new_score / orig_score

In [19]:
lst = [(0, 0)]

CIRCUIT = {
    "number mover": lst,
    "number mover 5": lst,
    "number mover 4": lst,
    "number mover 3": lst,
    "number mover 2": lst,
    "number mover 1": lst,
}

SEQ_POS_TO_KEEP = {
    "number mover": "end",
    "number mover 5": "S5",
    "number mover 4": "S4",
    "number mover 3": "S3",
    "number mover 2": "S2",
    "number mover 1": "S1",
}

model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook
model_abl = ioi_circuit_extraction.add_mean_ablation_hook(model, means_dataset=dataset_2, circuit=CIRCUIT, seq_pos_to_keep=SEQ_POS_TO_KEEP)
ioi_logits_minimal = model_abl(dataset.toks)
logits_to_ave_logit_diff_2(ioi_logits_minimal, dataset)

tensor(-0.0447, device='cuda:0')

# Ablate the model and compare with original

In [20]:
lst = [(layer, head) for layer in range(24) for head in range(16)]

CIRCUIT = {
    "number mover": lst,
    "number mover 5": lst,
    "number mover 4": lst,
    "number mover 3": lst,
    "number mover 2": lst,
    "number mover 1": lst,
}

SEQ_POS_TO_KEEP = {
    "number mover": "end",
    "number mover 5": "S5",
    "number mover 4": "S4",
    "number mover 3": "S3",
    "number mover 2": "S2",
    "number mover 1": "S1",
}

mean_ablate_by_lst(CIRCUIT, SEQ_POS_TO_KEEP, model, orig_score, print_output=False).item()

100.0

In [21]:
def circuit_from_headsList(headsList):
    CIRCUIT = {
    "number mover": headsList,
    "number mover 5": headsList,
    "number mover 4": headsList,
    "number mover 3": headsList,
    "number mover 2": headsList,
    "number mover 1": headsList,
    }

    SEQ_POS_TO_KEEP = {
        "number mover": "end",
        "number mover 5": "S5",
        "number mover 4": "S4",
        "number mover 3": "S3",
        "number mover 2": "S2",
        "number mover 1": "S1",
    }
    return CIRCUIT, SEQ_POS_TO_KEEP

In [22]:
CIRCUIT, SEQ_POS_TO_KEEP = circuit_from_headsList([(0,0)])
mean_ablate_by_lst(CIRCUIT, SEQ_POS_TO_KEEP, model, orig_score, print_output=False).item()

-0.5424507856369019

In [23]:
# Start with full circuit
curr_circuit = [(layer, head) for layer in range(24) for head in range(16)]
threshold = 3  # This is T, a %. if performance is less than T%, allow its removal

for layer in range(23, -1, -1):  # go thru all heads in a layer first
    for head in range(16):
        # Copying the curr_circuit so we can iterate over one and modify the other
        copy_circuit = curr_circuit.copy()

        # Temporarily removing the current tuple from the copied circuit
        copy_circuit.remove((layer, head))

        CIRCUIT, SEQ_POS_TO_KEEP = circuit_from_headsList(copy_circuit)
        new_score = mean_ablate_by_lst(CIRCUIT, SEQ_POS_TO_KEEP, model, orig_score, print_output=False).item()

        # print((layer,head), new_score)
        # If the result is less than the threshold, remove the tuple from the original list
        if (100 - new_score) < threshold:
            curr_circuit.remove((layer, head))

            print("Removed:", (layer, head))
            print(new_score)
            print("\n")

Removed: (23, 0)
100.05462646484375


Removed: (23, 1)
99.90754699707031


Removed: (23, 2)
99.92961883544922


Removed: (23, 3)
99.65225982666016


Removed: (23, 4)
99.30018615722656


Removed: (23, 5)
99.90575408935547


Removed: (23, 6)
100.0451889038086


Removed: (23, 7)
100.4909439086914


Removed: (23, 8)
99.06684875488281


Removed: (23, 9)
99.07476806640625


Removed: (23, 10)
99.26447296142578


Removed: (23, 11)
99.05616760253906


Removed: (23, 12)
99.20416259765625


Removed: (23, 13)
99.84165954589844


Removed: (23, 14)
99.87984466552734


Removed: (23, 15)
99.71666717529297


Removed: (22, 0)
99.68083953857422


Removed: (22, 1)
99.62361145019531


Removed: (22, 2)
100.02324676513672


Removed: (22, 3)
100.0439224243164


Removed: (22, 4)
100.10868835449219


Removed: (22, 5)
100.6133041381836


Removed: (22, 6)
100.76860809326172


Removed: (22, 7)
100.71442413330078


Removed: (22, 8)
100.02642822265625


Removed: (22, 9)
100.20832824707031


Removed: (22, 10)
100.314

In [24]:
curr_circuit

[(0, 1),
 (0, 4),
 (0, 8),
 (0, 9),
 (0, 10),
 (0, 15),
 (1, 2),
 (1, 5),
 (1, 11),
 (1, 14),
 (2, 2),
 (2, 3),
 (2, 4),
 (2, 5),
 (2, 6),
 (2, 7),
 (2, 8),
 (2, 9),
 (2, 10),
 (2, 12),
 (2, 14),
 (2, 15),
 (3, 0),
 (3, 1),
 (3, 2),
 (3, 3),
 (3, 4),
 (3, 5),
 (3, 8),
 (3, 9),
 (3, 10),
 (3, 12),
 (3, 13),
 (3, 15),
 (4, 1),
 (4, 2),
 (4, 11),
 (4, 15),
 (5, 3),
 (5, 4),
 (5, 5),
 (5, 8),
 (5, 11),
 (5, 12),
 (5, 15),
 (6, 0),
 (6, 1),
 (6, 3),
 (6, 4),
 (6, 10),
 (6, 11),
 (6, 12),
 (6, 14),
 (6, 15),
 (7, 13),
 (7, 15),
 (8, 0),
 (8, 1),
 (8, 7),
 (8, 9),
 (8, 10),
 (8, 12),
 (9, 1),
 (9, 3),
 (9, 4),
 (9, 5),
 (9, 9),
 (9, 14),
 (10, 1),
 (10, 3),
 (10, 14),
 (10, 15),
 (11, 0),
 (11, 1),
 (11, 4),
 (11, 5),
 (11, 6),
 (11, 7),
 (11, 12),
 (12, 4),
 (12, 8),
 (12, 12),
 (12, 15),
 (13, 0),
 (13, 6),
 (13, 7),
 (13, 8),
 (13, 13),
 (14, 5),
 (14, 14),
 (15, 8),
 (15, 9),
 (15, 12),
 (15, 13),
 (16, 3),
 (16, 11),
 (17, 3),
 (17, 12),
 (18, 3),
 (18, 9),
 (18, 10),
 (18, 11),
 (18, 13

In [25]:
len(curr_circuit)

107

384 total

## comapre to 246

In [26]:
circ_123 = [(0, 9), (1, 2), (1, 11), (1, 14), (2, 3), (2, 4), (2, 5), (2, 7), (2, 9), (2, 10), (2, 14), (3, 0), (3, 3), (3, 4), (3, 9), (3, 13), (3, 15), (4, 6), (4, 7), (4, 8), (4, 9), (4, 13), (5, 5), (5, 11), (5, 12), (5, 13), (6, 12), (6, 15), (7, 2), (8, 10), (9, 5), (10, 1), (10, 4), (10, 14), (11, 1), (11, 3), (11, 4), (11, 5), (11, 6), (12, 15), (13, 13), (14, 5), (14, 14), (15, 7), (15, 10), (17, 12), (19, 1)]
len(circ_123)

47

In [27]:
circ_246 =  [(0, 2), (0, 3), (0, 4), (0, 5), (0, 9), (0, 10), (0, 14), (1, 2), (1, 4), (1, 7), (1, 14), (2, 3), (2, 4), (2, 5), (2, 7), (2, 8), (2, 9), (2, 15), (3, 0), (3, 3), (3, 13), (3, 14), (3, 15), (4, 2), (4, 6), (4, 8), (4, 10), (4, 11), (5, 8), (6, 14), (6, 15), (7, 2), (7, 11), (7, 13), (8, 0), (9, 3), (9, 4), (9, 5), (9, 6), (9, 12), (9, 15), (10, 1), (10, 4), (10, 9), (10, 10), (10, 13), (10, 14), (11, 1), (11, 4), (11, 5), (11, 8), (12, 1), (12, 4), (12, 12), (12, 13), (12, 15), (13, 5), (13, 12), (13, 13), (14, 5), (14, 14), (15, 5), (15, 7), (15, 11), (15, 12), (15, 15), (16, 6), (16, 7), (16, 9), (16, 11), (16, 13), (16, 14), (17, 0), (17, 1), (17, 12), (18, 3), (18, 11), (18, 13), (19, 1), (19, 4), (20, 0), (20, 1), (20, 14), (21, 0), (21, 2), (21, 7)]
len(circ_246)

86

In [28]:
sorted(list(set(circ_123)-set(circ_246)))

[(1, 11),
 (2, 10),
 (2, 14),
 (3, 4),
 (3, 9),
 (4, 7),
 (4, 9),
 (4, 13),
 (5, 5),
 (5, 11),
 (5, 12),
 (5, 13),
 (6, 12),
 (8, 10),
 (11, 3),
 (11, 6),
 (15, 10)]

In [29]:
sorted(list(set(circ_246)-set(circ_123)))

[(0, 2),
 (0, 3),
 (0, 4),
 (0, 5),
 (0, 10),
 (0, 14),
 (1, 4),
 (1, 7),
 (2, 8),
 (2, 15),
 (3, 14),
 (4, 2),
 (4, 10),
 (4, 11),
 (5, 8),
 (6, 14),
 (7, 11),
 (7, 13),
 (8, 0),
 (9, 3),
 (9, 4),
 (9, 6),
 (9, 12),
 (9, 15),
 (10, 9),
 (10, 10),
 (10, 13),
 (11, 8),
 (12, 1),
 (12, 4),
 (12, 12),
 (12, 13),
 (13, 5),
 (13, 12),
 (15, 5),
 (15, 11),
 (15, 12),
 (15, 15),
 (16, 6),
 (16, 7),
 (16, 9),
 (16, 11),
 (16, 13),
 (16, 14),
 (17, 0),
 (17, 1),
 (18, 3),
 (18, 11),
 (18, 13),
 (19, 4),
 (20, 0),
 (20, 1),
 (20, 14),
 (21, 0),
 (21, 2),
 (21, 7)]

# manually rmv and check for most impt heads

https://colab.research.google.com/drive/12HF5UCvMERizkhOiYJKDziahgVq_3KD9#scrollTo=C2EgKgmJS4qb&line=1&uniqifier=1

In [30]:
# incr_circ = [(0, 10), (0, 11), (1, 2), (1, 6), (1, 7), (1, 8), (1, 9), (1, 14), (2, 4), (2, 5), (2, 6), (2, 7), (2, 9), (2, 10), (2, 15), (3, 4), (3, 5), (3, 10), (3, 11), (3, 12), (3, 13), (3, 15), (4, 1), (4, 2), (4, 3), (4, 13), (5, 0), (5, 1), (5, 4), (5, 5), (5, 8), (5, 11), (5, 13), (5, 15), (6, 0), (6, 4), (6, 5), (6, 10), (6, 11), (6, 12), (6, 13), (6, 14), (6, 15), (7, 4), (7, 6), (7, 15), (8, 0), (8, 10), (8, 12), (9, 14), (9, 15), (10, 3), (10, 14), (10, 15), (11, 0), (11, 1), (11, 2), (11, 3), (11, 4), (11, 5), (11, 6), (11, 7), (12, 8), (12, 15), (13, 0), (13, 13), (14, 5), (14, 11), (14, 14), (15, 5), (15, 7), (15, 12), (15, 13), (16, 8), (16, 11), (16, 14), (16, 15), (17, 3), (18, 13), (19, 12), (20, 1), (20, 7), (20, 14), (21, 0), (21, 1), (21, 7), (21, 14), (22, 8)]
incr_circ = curr_circuit
CIRCUIT, SEQ_POS_TO_KEEP = circuit_from_headsList(incr_circ)
mean_ablate_by_lst(CIRCUIT, SEQ_POS_TO_KEEP, model, orig_score, print_output=True).item()

Average logit difference (circuit / full) %: 97.1925


97.19249725341797

In [31]:
lh_scores = {}
for lh in incr_circ:
    copy_circuit = incr_circ.copy()
    copy_circuit.remove(lh)
    print("removed: " + str(lh))
    CIRCUIT, SEQ_POS_TO_KEEP = circuit_from_headsList(copy_circuit)
    new_score = mean_ablate_by_lst(CIRCUIT, SEQ_POS_TO_KEEP, model, orig_score, print_output=True).item()
    lh_scores[lh] = new_score

removed: (0, 1)
Average logit difference (circuit / full) %: 96.5166
removed: (0, 4)
Average logit difference (circuit / full) %: 96.6474
removed: (0, 8)
Average logit difference (circuit / full) %: 96.6877
removed: (0, 9)
Average logit difference (circuit / full) %: 95.9630
removed: (0, 10)
Average logit difference (circuit / full) %: 96.7188
removed: (0, 15)
Average logit difference (circuit / full) %: 96.8809
removed: (1, 2)
Average logit difference (circuit / full) %: 96.4395
removed: (1, 5)
Average logit difference (circuit / full) %: 96.8441
removed: (1, 11)
Average logit difference (circuit / full) %: 95.4817
removed: (1, 14)
Average logit difference (circuit / full) %: 92.9865
removed: (2, 2)
Average logit difference (circuit / full) %: 96.9801
removed: (2, 3)
Average logit difference (circuit / full) %: 96.4369
removed: (2, 4)
Average logit difference (circuit / full) %: 94.2528
removed: (2, 5)
Average logit difference (circuit / full) %: 96.6296
removed: (2, 6)
Average logit 

In [32]:
lh_scores

{(0, 1): 96.5166244506836,
 (0, 4): 96.64743041992188,
 (0, 8): 96.68769073486328,
 (0, 9): 95.96304321289062,
 (0, 10): 96.71881103515625,
 (0, 15): 96.88088989257812,
 (1, 2): 96.43949890136719,
 (1, 5): 96.84407806396484,
 (1, 11): 95.48165130615234,
 (1, 14): 92.98648071289062,
 (2, 2): 96.98013305664062,
 (2, 3): 96.43694305419922,
 (2, 4): 94.2528076171875,
 (2, 5): 96.62957000732422,
 (2, 6): 97.13822174072266,
 (2, 7): 96.7586441040039,
 (2, 8): 96.91163635253906,
 (2, 9): 96.64895629882812,
 (2, 10): 95.77836608886719,
 (2, 12): 96.69879150390625,
 (2, 14): 96.60031127929688,
 (2, 15): 96.87557220458984,
 (3, 0): 94.92894744873047,
 (3, 1): 96.99434661865234,
 (3, 2): 96.73584747314453,
 (3, 3): 93.64968872070312,
 (3, 4): 96.95536804199219,
 (3, 5): 96.58637237548828,
 (3, 8): 96.63741302490234,
 (3, 9): 96.57608795166016,
 (3, 10): 96.47249603271484,
 (3, 12): 97.07164764404297,
 (3, 13): 94.6766586303711,
 (3, 15): 96.22369384765625,
 (4, 1): 94.40135955810547,
 (4, 2): 95.

In [33]:
dict(sorted(lh_scores.items(), key=lambda item: item[1]))

{(8, 10): 51.31289291381836,
 (14, 14): 74.38078308105469,
 (11, 4): 75.99982452392578,
 (10, 14): 85.2322769165039,
 (10, 3): 88.27239227294922,
 (7, 15): 88.93096923828125,
 (6, 15): 88.94437408447266,
 (6, 1): 90.62203216552734,
 (6, 12): 90.757080078125,
 (13, 13): 92.2644271850586,
 (11, 1): 92.62682342529297,
 (20, 14): 92.7154541015625,
 (15, 12): 92.7745590209961,
 (14, 5): 92.86366271972656,
 (1, 14): 92.98648071289062,
 (5, 8): 93.05807495117188,
 (5, 12): 93.06798553466797,
 (16, 11): 93.26941680908203,
 (5, 11): 93.57185363769531,
 (3, 3): 93.64968872070312,
 (8, 9): 93.92313385009766,
 (8, 12): 94.05313873291016,
 (11, 6): 94.2373275756836,
 (2, 4): 94.2528076171875,
 (7, 13): 94.25381469726562,
 (4, 1): 94.40135955810547,
 (18, 3): 94.50316619873047,
 (11, 5): 94.5951156616211,
 (12, 8): 94.64053344726562,
 (15, 13): 94.64264678955078,
 (3, 13): 94.6766586303711,
 (10, 1): 94.70188903808594,
 (13, 0): 94.7420883178711,
 (9, 14): 94.83311462402344,
 (3, 0): 94.928947448730