<a href="https://colab.research.google.com/github/wlg100/numseqcont_circuit_expms/blob/main/notebook_templates/minimal_circuit_template.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" align="left"/></a>&nbsp;or in a local notebook.

# Setup
(No need to change anything)

In [None]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    # Install another version of node that makes PySvelte work way faster
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    %pip install git+https://github.com/neelnanda-io/PySvelte.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

Running as a Colab notebook
Collecting git+https://github.com/neelnanda-io/TransformerLens.git
  Cloning https://github.com/neelnanda-io/TransformerLens.git to /tmp/pip-req-build-o3tsl1as
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/TransformerLens.git /tmp/pip-req-build-o3tsl1as
  Resolved https://github.com/neelnanda-io/TransformerLens.git to commit fa287750606075574df2c538058e67d648e2f952
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting accelerate>=0.23.0 (from transformer-lens==0.0.0)
  Downloading accelerate-0.24.0-py3-none-any.whl (260 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m261.0/261.0 kB[0m [31m4.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting beartype<0.15.0,>=0.14.1 (from transformer-lens==0.0.0)
  Downloading beartype-0.14.1-py3-none-any.whl (739 kB)
[2K     [90m━━━━━━━━

In [None]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

# Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode
# depending on the renderer - this is bad for developing demos! Thus creating a debug mode:
# static PNGs are used only for a local (non-Colab) run with DEBUG_MODE switched on.
pio.renderers.default = "png" if (DEBUG_MODE and not IN_COLAB) else "colab"

In [None]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [None]:
# import pysvelte

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

We turn automatic differentiation off, to save GPU memory, as this notebook focuses on model inference not model training.

In [None]:
torch.set_grad_enabled(False)

Plotting helper functions:

In [None]:
def imshow(tensor, renderer=None, **kwargs):
    """Show `tensor` as a heatmap on a diverging RdBu colour scale centred at 0."""
    fig = px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        **kwargs,
    )
    fig.show(renderer)

def line(tensor, renderer=None, **kwargs):
    """Plot `tensor` as a line chart (values on the y axis)."""
    fig = px.line(y=utils.to_numpy(tensor), **kwargs)
    fig.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Scatter-plot `y` against `x`, labelling the x, y and colour axes."""
    axis_labels = {"x": xaxis, "y": yaxis, "color": caxis}
    fig = px.scatter(x=utils.to_numpy(x), y=utils.to_numpy(y), labels=axis_labels, **kwargs)
    fig.show(renderer)

## Load Model

Decide which model to use (e.g. gpt2-small vs. gpt2-medium).

In [None]:
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)

## Import functions from repo

In [None]:
!git clone https://github.com/callummcdougall/ARENA_2.0.git

In [None]:
cd ARENA_2.0/chapter1_transformers/exercises/part3_indirect_object_identification

In [None]:
import ioi_circuit_extraction as ioi_circuit_extraction

# Generate dataset with multiple prompts

### test prompts

In [None]:
modeltest = HookedTransformer.from_pretrained("gpt2-small")

Loaded pretrained model gpt2-small into HookedTransformer


In [None]:
example_prompt = "The war lasted from the year 1750 to the year 17"
example_answer = " 51"
utils.test_prompt(example_prompt, example_answer, modeltest, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'The', ' war', ' lasted', ' from', ' the', ' year', ' 17', '50', ' to', ' the', ' year', ' 17']
Tokenized answer: [' 51']


Top 0th token. Logit: 30.08 Prob: 26.90% Token: |60|
Top 1th token. Logit: 29.10 Prob: 10.12% Token: |75|
Top 2th token. Logit: 29.02 Prob:  9.33% Token: |70|
Top 3th token. Logit: 28.62 Prob:  6.29% Token: |90|
Top 4th token. Logit: 28.43 Prob:  5.19% Token: |80|
Top 5th token. Logit: 28.28 Prob:  4.45% Token: |50|
Top 6th token. Logit: 27.82 Prob:  2.83% Token: |55|
Top 7th token. Logit: 27.41 Prob:  1.87% Token: |65|
Top 8th token. Logit: 27.34 Prob:  1.75% Token: |76|
Top 9th token. Logit: 27.17 Prob:  1.47% Token: |71|


In [None]:
example_prompt = "The war lasted from the year 1701 to the year 17"
example_answer = " 51"
utils.test_prompt(example_prompt, example_answer, modeltest, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'The', ' war', ' lasted', ' from', ' the', ' year', ' 17', '01', ' to', ' the', ' year', ' 17']
Tokenized answer: [' 51']


Top 0th token. Logit: 26.03 Prob:  5.33% Token: |20|
Top 1th token. Logit: 26.00 Prob:  5.19% Token: |15|
Top 2th token. Logit: 25.89 Prob:  4.62% Token: |02|
Top 3th token. Logit: 25.85 Prob:  4.45% Token: |12|
Top 4th token. Logit: 25.75 Prob:  4.02% Token: |10|
Top 5th token. Logit: 25.64 Prob:  3.62% Token: |18|
Top 6th token. Logit: 25.58 Prob:  3.38% Token: |05|
Top 7th token. Logit: 25.56 Prob:  3.32% Token: |03|
Top 8th token. Logit: 25.53 Prob:  3.23% Token: |16|
Top 9th token. Logit: 25.52 Prob:  3.20% Token: |04|


#### less-than using in-context

See p. 22 of the greater-than paper: "we address the tasks 'The <noun> ended in the year 17YY
and started in the year 17' and 'The <noun> lasted from the year 7YY BC to the year 7', which do
use our circuit, but should not do so."

GPT-2 small completes these prompts with 'greater-than' years (e.g. 60 after 1750), even though the correct answer is less than YY.

In [None]:
example_prompt = "The war ended in the year 1750 and started in the year 17"
example_answer = " 49"
utils.test_prompt(example_prompt, example_answer, modeltest, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'The', ' war', ' ended', ' in', ' the', ' year', ' 17', '50', ' and', ' started', ' in', ' the', ' year', ' 17']
Tokenized answer: [' 49']


Top 0th token. Logit: 26.42 Prob: 18.75% Token: |60|
Top 1th token. Logit: 25.20 Prob:  5.55% Token: |50|
Top 2th token. Logit: 25.19 Prob:  5.51% Token: |61|
Top 3th token. Logit: 25.03 Prob:  4.66% Token: |51|
Top 4th token. Logit: 24.97 Prob:  4.38% Token: |55|
Top 5th token. Logit: 24.74 Prob:  3.51% Token: |52|
Top 6th token. Logit: 24.68 Prob:  3.30% Token: |75|
Top 7th token. Logit: 24.62 Prob:  3.10% Token: |59|
Top 8th token. Logit: 24.62 Prob:  3.10% Token: |70|
Top 9th token. Logit: 24.60 Prob:  3.05% Token: |56|


In [None]:
example_prompt = "The war ended in the year 1790 and started in the year 1780. The war ended in the year 1750 and started in the year 17"
example_answer = " 49"
utils.test_prompt(example_prompt, example_answer, modeltest, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'The', ' war', ' ended', ' in', ' the', ' year', ' 17', '90', ' and', ' started', ' in', ' the', ' year', ' 17', '80', '.', ' The', ' war', ' ended', ' in', ' the', ' year', ' 17', '50', ' and', ' started', ' in', ' the', ' year', ' 17']
Tokenized answer: [' 49']


Top 0th token. Logit: 26.13 Prob: 16.52% Token: |80|
Top 1th token. Logit: 25.79 Prob: 11.82% Token: |60|
Top 2th token. Logit: 25.79 Prob: 11.77% Token: |70|
Top 3th token. Logit: 25.65 Prob: 10.25% Token: |90|
Top 4th token. Logit: 25.36 Prob:  7.65% Token: |50|
Top 5th token. Logit: 24.92 Prob:  4.94% Token: |75|
Top 6th token. Logit: 24.43 Prob:  3.01% Token: |85|
Top 7th token. Logit: 24.25 Prob:  2.53% Token: |40|
Top 8th token. Logit: 24.04 Prob:  2.04% Token: |55|
Top 9th token. Logit: 24.00 Prob:  1.97% Token: |65|


In [None]:
example_prompt = "90 is less than 100. 80 is less than 90. 70 is less than"
example_answer = " 80"
utils.test_prompt(example_prompt, example_answer, modeltest, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', '90', ' is', ' less', ' than', ' 100', '.', ' 80', ' is', ' less', ' than', ' 90', '.', ' 70', ' is', ' less', ' than']
Tokenized answer: [' 80']


Top 0th token. Logit: 22.02 Prob: 95.21% Token: | 70|
Top 1th token. Logit: 17.60 Prob:  1.15% Token: | 90|
Top 2th token. Logit: 17.60 Prob:  1.15% Token: | 80|
Top 3th token. Logit: 16.81 Prob:  0.52% Token: | 75|
Top 4th token. Logit: 16.29 Prob:  0.31% Token: | 20|
Top 5th token. Logit: 16.06 Prob:  0.25% Token: | 65|
Top 6th token. Logit: 15.80 Prob:  0.19% Token: | 50|
Top 7th token. Logit: 15.63 Prob:  0.16% Token: | 60|
Top 8th token. Logit: 15.00 Prob:  0.08% Token: | 40|
Top 9th token. Logit: 14.66 Prob:  0.06% Token: | 85|


## test tokenizer to make pos_dict, prompt_dict

In [None]:
model.tokenizer('1701')

{'input_ids': [1558, 486], 'attention_mask': [1, 1]}

In [None]:
model.tokenizer('01')

{'input_ids': [486], 'attention_mask': [1]}

In [None]:
model.tokenizer.convert_tokens_to_string(model.tokenizer.convert_ids_to_tokens(model.tokenizer('1701')['input_ids']))

'1701'

In [None]:
# model.tokenizer.decode(486)
model.tokenizer.decode([486])

'01'

In [None]:
len(model.tokenizer()['input_ids'])

11

### get rid of "Ġ" +

In [None]:
model.tokenizer("1711")

{'input_ids': [1558, 1157], 'attention_mask': [1, 1]}

In [None]:
model.tokenizer.tokenize("1711")

['17', '11']

Because `11` directly follows `17` with no space between them, its token has no leading space. So don't prepend the `Ġ` (space) marker when looking it up.

In [None]:
def get_prompts_pos_dicts(input_text, YY, tokenizer=None):
    """Tokenize `input_text` and build per-token lookup dicts keyed "T0", "T1", ...

    :param input_text: prompt string to tokenize (no BOS is added here).
    :param YY: currently unused (the code that keyed the YY token specially is disabled);
        kept so existing callers keep working.
    :param tokenizer: tokenizer to use; defaults to the notebook-global `model.tokenizer`.
    :return: (pos_dict, prompt_dict). pos_dict maps "T{i}" -> i; prompt_dict maps
        "T{i}" -> decoded token i, plus prompt_dict["text"] = input_text.
    """
    if tokenizer is None:
        tokenizer = model.tokenizer

    pos_dict = {}
    prompt_dict = {}
    for index, token in enumerate(tokenizer(input_text)['input_ids']):
        key = f'T{index}'
        pos_dict[key] = index
        prompt_dict[key] = tokenizer.decode(token)
    prompt_dict['text'] = input_text

    return pos_dict, prompt_dict

In [None]:
input_text = 'The war lasted from the year 1750 to the year 17'
get_prompts_pos_dicts(input_text, '50')

({'T0': 0,
  'T1': 1,
  'T2': 2,
  'T3': 3,
  'T4': 4,
  'T5': 5,
  'T6': 6,
  'T7': 7,
  'T8': 8,
  'T9': 9,
  'T10': 10,
  'T11': 11},
 {'T0': 'The',
  'T1': ' war',
  'T2': ' lasted',
  'T3': ' from',
  'T4': ' the',
  'T5': ' year',
  'T6': ' 17',
  'T7': '50',
  'T8': ' to',
  'T9': ' the',
  'T10': ' year',
  'T11': ' 17',
  'text': 'The war lasted from the year 1750 to the year 17'})

## make datasets

In [None]:
class Dataset:
    """Tokenized batch of prompts plus, per prompt, the token index of each named position.

    Attributes set in __init__:
        prompts: the prompt dicts passed in.
        tokenizer: tokenizer used for every tokenization in this class.
        N: number of prompts.
        max_len: longest tokenized (unpadded) prompt length.
        groups: list of index arrays, one per template id; all prompts use template 0.
        toks: int32 tensor of padded token ids, one row per prompt.
        word_idx: dict mapping each position key (and "end") to a tensor with that key's
            token index for every prompt.
    """

    def __init__(self, prompts, pos_dict, tokenizer, S1_is_first=False):
        """
        :param prompts: list of prompt dicts with a "text" key. Every other key of the first
            prompt (except "corr"/"incorr") is treated as a position key and looked up in pos_dict.
        :param pos_dict: mapping from position key to token index, shared by all prompts
            (assumes every prompt tokenizes to the same layout).
        :param tokenizer: HuggingFace-style tokenizer (callable returning `.input_ids`, plus `.tokenize`).
        :param S1_is_first: unused; kept for signature compatibility.
        """
        self.prompts = prompts
        self.tokenizer = tokenizer
        self.N = len(prompts)
        self.max_len = max(
            len(self.tokenizer(prompt["text"]).input_ids) for prompt in self.prompts
        )

        # Only one template is used, so every prompt belongs to group 0.
        all_ids = [0] * self.N
        all_ids_ar = np.array(all_ids)
        self.groups = [np.where(all_ids_ar == group_id)[0] for group_id in sorted(set(all_ids))]

        texts = [prompt["text"] for prompt in self.prompts]
        self.toks = torch.tensor(self.tokenizer(texts, padding=True).input_ids, dtype=torch.int)

        # word_idx[targ][i] is the token index of position `targ` in prompt i. Since pos_dict is
        # shared by all prompts, the index is the same for every prompt.
        excluded_keys = {"text", "corr", "incorr"}
        self.word_idx = {}
        for targ in self.prompts[0].keys():
            if targ in excluded_keys:
                continue
            self.word_idx[targ] = torch.tensor([pos_dict[targ]] * self.N)

        # "end" is the last real (unpadded) token of each prompt.
        self.word_idx["end"] = torch.tensor(
            [len(self.tokenizer.tokenize(prompt["text"])) - 1 for prompt in self.prompts]
        )

    def __len__(self):
        return self.N

In [None]:
def generate_prompts_list(x, y):
    """Build one "The war lasted from the year 17YY to the year 17" prompt per YY in range(x, y).

    :param x: first YY (inclusive), an int in [0, 99].
    :param y: last YY (exclusive).
    :return: (pos_dict, prompts_list). pos_dict is taken from the last prompt, which assumes
        every prompt tokenizes to the same number of tokens.
    :raises ValueError: if range(x, y) is empty (there would be no pos_dict to return).
    """
    if x >= y:
        raise ValueError(f"generate_prompts_list needs x < y, got x={x}, y={y}")
    prompts_list = []
    for YY in range(x, y):
        # Zero-pad so single-digit YY gives e.g. "1705" rather than "175".
        YY_str = f'{YY:02d}'
        input_text = f'The war lasted from the year 17{YY_str} to the year 17'
        pos_dict, prompt_dict = get_prompts_pos_dicts(input_text, YY_str)
        prompts_list.append(prompt_dict)
    return pos_dict, prompts_list

# prompts_list = generate_prompts_list(45, 55)
pos_dict, prompts_list = generate_prompts_list(50, 51)
# prompts_list
dataset = Dataset(prompts_list, pos_dict, model.tokenizer, S1_is_first=True)

In [None]:
def generate_prompts_list_corr(x, y):
    """Build the corrupted prompt list used as the mean-ablation reference dataset.

    `x` and `y` are accepted for symmetry with generate_prompts_list but are ignored:
    the list always holds a single prompt with YY = '01'.

    :return: (pos_dict, prompts_list) for that single prompt.
    """
    corrupted_YY = '01'
    corrupted_text = f'The war lasted from the year 17{corrupted_YY} to the year 17'
    pos_dict, prompt_dict = get_prompts_pos_dicts(corrupted_text, corrupted_YY)
    return pos_dict, [prompt_dict]

# prompts_list = generate_prompts_list(45, 55)
pos_dict, prompts_list_2 = generate_prompts_list_corr(50, 51)
# prompts_list_2
dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer, S1_is_first=True)

## obtain the logits of each number between YY and 99

In [None]:
logits = torch.randn(32, 100, 256)  # [batch size, seq len, vocab size]
logits[range(logits.size(0)), [99]*logits.size(0), [5]*logits.size(0)] == logits[range(logits.size(0)), 99, 5]

tensor([True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True])

In [None]:
logits[range(logits.size(0)), 99, 5].size() # logits[range(logits.size(0)), dataset.word_idx["end"], dataset.io_tokenIDs]

torch.Size([32])

In [None]:
# obtain the logits of each number between YY and 99, where YY is a two digit integer

import torch

def get_logits_for_range(logits, start_num, end_num, vocab=None, tokenizer=None):
    """
    Gather the last-position logits of the two-digit number tokens start_num..end_num.

    :param logits: logits tensor with dimensions [batch size, seq len, vocab size]. Only the last
        sequence position is read; in a padded batch that may be a pad position (TODO confirm).
    :param start_num: first number of the range (inclusive), 0-99
    :param end_num: last number of the range (inclusive), 0-99
    :param vocab: unused; kept so existing callers keep working
    :param tokenizer: maps number strings to token ids; defaults to the global `model.tokenizer`
    :return: tensor of shape [batch size, end_num - start_num + 1]
    """
    if tokenizer is None:
        tokenizer = model.tokenizer
    # Years are written "17YY", so YY is always two digits: 5 must map to the token "05", not "5".
    # NOTE(review): assumes "00".."99" each encode to a single token; only the first id is used.
    indices = [tokenizer(f"{num:02d}")['input_ids'][0] for num in range(start_num, end_num + 1)]
    return logits[:, -1, indices]

# Example usage (random logits, so the values themselves are meaningless).
# Only the last position is read, so a short seq len avoids a ~640 MB allocation.
example_logits = torch.randn(32, 4, 50000)  # [batch size, seq len, vocab size]
vocab = {str(i): i for i in range(50000)}  # Example vocab mapping (unused by get_logits_for_range)
YY = 87
logits_greaterThan = get_logits_for_range(example_logits, YY, 99, vocab)
print(logits_greaterThan.shape)  # Should be [batch size, 99-YY+1], i.e. [32, 13]

In [None]:
logits_greaterThan_sum = logits_greaterThan.sum(dim=1)
logits_greaterThan_sum.size()

torch.Size([32])

In [None]:
logits_greaterThan_sum.mean()

tensor(-0.2856)

In [None]:
# int(prompts[input_ind]['YY'])  # YY in an input
# we want the first token > int(prompts[batch_ind]['YY'])
# search thru entire vocab space until find token > int(prompts[input_ind]['YY'])
# how do we convert index in vocab space to the token it represents?

# greater_than_Y_idx = (io_logits > Y).nonzero(as_tuple=True)[0].item()
# first_token_greater_than_Y = dataset.io_tokenIDs[greater_than_Y_idx]

# Ablation Expm Functions

In [None]:
from torch import Tensor

def logits_to_ave_logit_diff_2(logits: Float[Tensor, "batch seq d_vocab"], dataset: Dataset, per_prompt=False, YY: int = 50):
    '''
    Returns the greater-than logit score: the sum of last-position logits for the number tokens
    YY..99 minus the sum for tokens 00..YY-1.
    If per_prompt=True, return the per-prompt scores rather than their mean.

    :param logits: model logits; only the last sequence position is used.
    :param dataset: currently unused (dataset.word_idx["end"] is not consulted); kept for
        signature compatibility.
    :param YY: two-digit start year; defaults to 50, matching the prompts built above.
    '''
    # NOTE(review): YY itself is counted as "greater"; confirm this matches the intended metric.
    # vocab is not used by get_logits_for_range, so pass None instead of relying on a global.
    logits_greaterThan_sum = get_logits_for_range(logits, YY, 99, None).sum(dim=1)
    logits_lessThan_sum = get_logits_for_range(logits, 0, YY - 1, None).sum(dim=1)

    # Find logit difference of corr minus incorr
    answer_logit_diff = logits_greaterThan_sum - logits_lessThan_sum
    return answer_logit_diff if per_prompt else answer_logit_diff.mean()

In [None]:
def mean_ablate_by_lst(lst, model, print_output=True):
    """Score the circuit `lst`, keeping it at every token position in the global `pos_dict`.

    One CIRCUIT entry per position is passed to ioi_circuit_extraction.add_mean_ablation_hook,
    with the global `dataset_2` as the means dataset. Heads outside the circuit are presumably
    mean-ablated. The hooked model is evaluated on the global `dataset`.

    :param lst: list of (layer, head) tuples to keep.
    :param model: HookedTransformer; all hooks (including permanent ones) are reset first. The
        ablation hooks added here are presumably still attached on return.
    :param print_output: if True, also run the unhooked model and print the circuit score as a
        percentage of the full-model score.
    :return: the circuit's raw (unnormalized) score from logits_to_ave_logit_diff_2.
    """
    CIRCUIT = {f"head{ind}": lst for ind, _ in enumerate(pos_dict)}
    SEQ_POS_TO_KEEP = {f"head{ind}": key for ind, key in enumerate(pos_dict)}

    model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook

    # The full-model score is only needed for the printed ratio. Skipping it otherwise saves a
    # forward pass per call, which matters in the pruning loop. The cache was never used.
    if print_output:
        orig_score = logits_to_ave_logit_diff_2(model(dataset.toks), dataset)

    model = ioi_circuit_extraction.add_mean_ablation_hook(model, means_dataset=dataset_2, circuit=CIRCUIT, seq_pos_to_keep=SEQ_POS_TO_KEEP)
    ioi_logits_minimal = model(dataset.toks)  # make sure text in clean vs corr have same num tokens for each prompt

    new_score = logits_to_ave_logit_diff_2(ioi_logits_minimal, dataset)
    if print_output:
        print(f"Average logit difference (circuit / full) %: {100 * new_score / orig_score:.4f}")
    return new_score

## test fns

In [None]:
# Keep every head (12 layers x 12 heads) at every token position in pos_dict, i.e. the full model.
CIRCUIT = {
    f"head{ind}": [(layer, head) for layer in range(12) for head in range(12)]
    for ind in range(len(pos_dict))
}
SEQ_POS_TO_KEEP = {f"head{ind}": key for ind, key in enumerate(pos_dict)}

In [None]:
model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook

ioi_logits_original, ioi_cache = model.run_with_cache(dataset.toks)

model = ioi_circuit_extraction.add_mean_ablation_hook(model, means_dataset=dataset_2, circuit=CIRCUIT, seq_pos_to_keep=SEQ_POS_TO_KEEP)
ioi_logits_minimal = model(dataset.toks)

In [None]:
YY = 50

logits_greaterThan = get_logits_for_range(ioi_logits_original, YY, 99, vocab)
logits_greaterThan_sum = logits_greaterThan.sum(dim=1)

# get the wrong logits; anything less than YY
logits_lessThan = get_logits_for_range(ioi_logits_original, 00, YY-1, vocab)
logits_lessThan_sum = logits_lessThan.sum(dim=1)

# Find logit difference of corr minus incorr; sum up all tokens between YY and 99, minus sum of all YY and 00
answer_logit_diff = logits_greaterThan_sum - logits_lessThan_sum

In [None]:
logits_greaterThan.size()

torch.Size([1, 50])

# Ablate the model tests

## Test Greater-Than vs other circuits

See how greater-than circuit performs on the greater-than task; it should be similar to the paper. Else, either greater-than paper has issues (less likely) or this mean ablation code/setup was not generalized correctly (more likely).

In [None]:
greater_than = [(0, 1), (0, 3), (0, 5), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11), (9, 1)]
new_score = mean_ablate_by_lst(greater_than, model, print_output=True)

Average logit difference (circuit / full) %: 85.1862


Did it get it right? As a sanity check, let's try an incomplete circuit.

In [None]:
greater_than = []
new_score = mean_ablate_by_lst(greater_than, model, print_output=True)

Average logit difference (circuit / full) %: 1.8475


The small remaining score is likely due to the MLPs.

In [None]:
greater_than = [(0, 1)]
new_score = mean_ablate_by_lst(greater_than, model, print_output=True)

Average logit difference (circuit / full) %: 1.8473


In [None]:
greater_than = [(0, 1), (0, 3), (9,1)]
new_score = mean_ablate_by_lst(greater_than, model, print_output=True)

Average logit difference (circuit / full) %: 30.9436


In [None]:
num_of_tuples = 9  # Number of (layer, head) pairs to sample; same size as the paper's circuit
rng = random.Random(0)  # fixed seed so the random baseline is reproducible
all_heads = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
# Sample distinct heads from all layers/heads (randint(0, 9) never hit 10-11 and could repeat heads).
greater_than = rng.sample(all_heads, num_of_tuples)
new_score = mean_ablate_by_lst(greater_than, model, print_output=True)

Average logit difference (circuit / full) %: 2.1236


In [None]:
greater_than = [(layer, head) for layer in range(12) for head in range(12)]
new_score = mean_ablate_by_lst(greater_than, model, print_output=True)

Average logit difference (circuit / full) %: 100.0000


### add heads to orig paper circ

In [None]:
greater_than = [(0, 1), (0, 3), (0, 5), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11), (9,1)] + [(10, 7)]
new_score = mean_ablate_by_lst(greater_than, model, print_output=True)

Average logit difference (circuit / full) %: 84.0478


In [None]:
greater_than = [(0, 1), (0, 3), (0, 5), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11), (9,1)] + [(layer, head) for layer in range(0) for head in range(12)]
new_score = mean_ablate_by_lst(greater_than, model, print_output=True)

Average logit difference (circuit / full) %: 85.1862


In [None]:
greater_than = [(0, 1), (0, 3), (0, 5), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11), (9,1)] + [(layer, head) for layer in range(0, 4) for head in range(12)]
new_score = mean_ablate_by_lst(greater_than, model, print_output=True)

Average logit difference (circuit / full) %: 81.5903


In [None]:
greater_than = [(0, 1), (0, 3), (0, 5), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11), (9,1)] + [(layer, head) for layer in range(0, 6) for head in range(12)]
new_score = mean_ablate_by_lst(greater_than, model, print_output=True)

Average logit difference (circuit / full) %: 88.9415


In [None]:
greater_than = [(0, 1), (0, 3), (0, 5), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11), (9,1)] + [(layer, head) for layer in range(0, 9) for head in range(12)]
new_score = mean_ablate_by_lst(greater_than, model, print_output=True)

Average logit difference (circuit / full) %: 97.0480


In [None]:
greater_than = [(0, 1), (0, 3), (0, 5), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11), (9,1)] + [(layer, head) for layer in range(8, 9) for head in range(12)]
new_score = mean_ablate_by_lst(greater_than, model, print_output=True)

Average logit difference (circuit / full) %: 92.0383


In [None]:
greater_than = [(0, 1), (0, 3), (0, 5), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11), (9,1)] + [(layer, head) for layer in range(5, 9) for head in range(12)]
new_score = mean_ablate_by_lst(greater_than, model, print_output=True)

Average logit difference (circuit / full) %: 97.0697


# Ablate by seq pos

In [None]:
def mean_ablate_by_seqpos(lst, model, print_output=True):
    """Score the circuit `lst`, keeping it only at the YY token ("T7") and the last position ("end").

    "T7" is the '50' token of the clean prompt (see the get_prompts_pos_dicts output above).
    Means come from the global `dataset_2`, and the hooked model is scored on the global `dataset`.

    :param lst: list of (layer, head) tuples to keep at both positions.
    :param model: HookedTransformer; all hooks are reset first.
    :param print_output: if True, also run the unhooked model and print the % of full score.
    :return: the circuit's raw (unnormalized) score from logits_to_ave_logit_diff_2.
    """
    CIRCUIT = {}
    SEQ_POS_TO_KEEP = {}
    for ind, key in enumerate(["T7", "end"]):
        headName = "head" + str(ind)
        CIRCUIT[headName] = lst
        SEQ_POS_TO_KEEP[headName] = key

    model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook

    # Full-model score is only needed for the printed ratio; the cache was never used.
    if print_output:
        orig_score = logits_to_ave_logit_diff_2(model(dataset.toks), dataset)

    model = ioi_circuit_extraction.add_mean_ablation_hook(model, means_dataset=dataset_2, circuit=CIRCUIT, seq_pos_to_keep=SEQ_POS_TO_KEEP)
    ioi_logits_minimal = model(dataset.toks)  # make sure text in clean vs corr have same num tokens for each prompt

    new_score = logits_to_ave_logit_diff_2(ioi_logits_minimal, dataset)
    if print_output:
        print(f"Average logit difference (circuit / full) %: {100 * new_score / orig_score:.4f}")
    return new_score

In [None]:
greater_than = [(0, 1), (0, 3), (0, 5), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11), (9, 1)]
new_score = mean_ablate_by_seqpos(greater_than, model, print_output=True)

Average logit difference (circuit / full) %: 85.1939


In [None]:
def mean_ablate_by_seqpos(lst, model, print_output=True, circuit=None, seq_pos_to_keep=None):
    """Score a hand-picked per-position circuit: all 9 paper heads at "end", the first 6 at "T7".

    NOTE(review): `lst` is ignored; the circuit is hard-coded below. Pass `circuit` and
    `seq_pos_to_keep` to override it. "T7" is the YY ('50') token of the clean prompt.

    :param model: HookedTransformer; all hooks are reset first.
    :param print_output: if True, also run the unhooked model and print the % of full score.
    :param circuit: optional {name: [(layer, head), ...]} replacing the hard-coded circuit.
    :param seq_pos_to_keep: optional {name: position key} matching `circuit`.
    :return: the circuit's raw (unnormalized) score from logits_to_ave_logit_diff_2.
    """
    if circuit is None:
        circuit = {
            "end": [(0, 1), (0, 3), (0, 5), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11), (9, 1)],
            "YY": [(0, 1), (0, 3), (0, 5), (5, 5), (6, 1), (6, 9)],
        }
    if seq_pos_to_keep is None:
        seq_pos_to_keep = {
            "end": "end",
            "YY": "T7",
        }

    model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook

    # Full-model score is only needed for the printed ratio; the cache was never used.
    if print_output:
        orig_score = logits_to_ave_logit_diff_2(model(dataset.toks), dataset)

    model = ioi_circuit_extraction.add_mean_ablation_hook(model, means_dataset=dataset_2, circuit=circuit, seq_pos_to_keep=seq_pos_to_keep)
    ioi_logits_minimal = model(dataset.toks)  # make sure text in clean vs corr have same num tokens for each prompt

    new_score = logits_to_ave_logit_diff_2(ioi_logits_minimal, dataset)
    if print_output:
        print(f"Average logit difference (circuit / full) %: {100 * new_score / orig_score:.4f}")
    return new_score

In [None]:
greater_than = [(0, 1), (0, 3), (0, 5), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11), (9, 1)]
new_score = mean_ablate_by_seqpos(greater_than, model, print_output=True)

Average logit difference (circuit / full) %: 85.3446


In [None]:
def mean_ablate_by_seqpos(lst, model, print_output=True, circuit=None, seq_pos_to_keep=None):
    """Score a hand-picked per-position circuit: only (9, 1) at "end", the other 8 paper heads at "T7".

    NOTE(review): `lst` is ignored; the circuit is hard-coded below. Pass `circuit` and
    `seq_pos_to_keep` to override it. "T7" is the YY ('50') token of the clean prompt.

    :param model: HookedTransformer; all hooks are reset first.
    :param print_output: if True, also run the unhooked model and print the % of full score.
    :param circuit: optional {name: [(layer, head), ...]} replacing the hard-coded circuit.
    :param seq_pos_to_keep: optional {name: position key} matching `circuit`.
    :return: the circuit's raw (unnormalized) score from logits_to_ave_logit_diff_2.
    """
    if circuit is None:
        circuit = {
            "end": [(9, 1)],
            "YY": [(0, 1), (0, 3), (0, 5), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11)],
        }
    if seq_pos_to_keep is None:
        seq_pos_to_keep = {
            "end": "end",
            "YY": "T7",
        }

    model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook

    # Full-model score is only needed for the printed ratio; the cache was never used.
    if print_output:
        orig_score = logits_to_ave_logit_diff_2(model(dataset.toks), dataset)

    model = ioi_circuit_extraction.add_mean_ablation_hook(model, means_dataset=dataset_2, circuit=circuit, seq_pos_to_keep=seq_pos_to_keep)
    ioi_logits_minimal = model(dataset.toks)  # make sure text in clean vs corr have same num tokens for each prompt

    new_score = logits_to_ave_logit_diff_2(ioi_logits_minimal, dataset)
    if print_output:
        print(f"Average logit difference (circuit / full) %: {100 * new_score / orig_score:.4f}")
    return new_score

In [None]:
greater_than = [(0, 1), (0, 3), (0, 5), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11), (9, 1)]
new_score = mean_ablate_by_seqpos(greater_than, model, print_output=True)

Average logit difference (circuit / full) %: 43.3762


In [None]:
def mean_ablate_by_seqpos(lst, model, print_output=True, circuit=None, seq_pos_to_keep=None):
    """Score a hand-picked circuit: all 9 paper heads kept only at the "end" position.

    NOTE(review): `lst` is ignored; the circuit is hard-coded below. Pass `circuit` and
    `seq_pos_to_keep` to override it.

    :param model: HookedTransformer; all hooks are reset first.
    :param print_output: if True, also run the unhooked model and print the % of full score.
    :param circuit: optional {name: [(layer, head), ...]} replacing the hard-coded circuit.
    :param seq_pos_to_keep: optional {name: position key} matching `circuit`.
    :return: the circuit's raw (unnormalized) score from logits_to_ave_logit_diff_2.
    """
    if circuit is None:
        circuit = {
            "end": [(0, 1), (0, 3), (0, 5), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11), (9, 1)],
        }
    if seq_pos_to_keep is None:
        seq_pos_to_keep = {
            "end": "end",
        }

    model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook

    # Full-model score is only needed for the printed ratio; the cache was never used.
    if print_output:
        orig_score = logits_to_ave_logit_diff_2(model(dataset.toks), dataset)

    model = ioi_circuit_extraction.add_mean_ablation_hook(model, means_dataset=dataset_2, circuit=circuit, seq_pos_to_keep=seq_pos_to_keep)
    ioi_logits_minimal = model(dataset.toks)  # make sure text in clean vs corr have same num tokens for each prompt

    new_score = logits_to_ave_logit_diff_2(ioi_logits_minimal, dataset)
    if print_output:
        print(f"Average logit difference (circuit / full) %: {100 * new_score / orig_score:.4f}")
    return new_score

In [None]:
greater_than = [(0, 1), (0, 3), (0, 5), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11), (9, 1)]
new_score = mean_ablate_by_seqpos(greater_than, model, print_output=True)

Average logit difference (circuit / full) %: 72.7873


# Prune backwards

In [None]:
# Start with full circuit
curr_circuit = [(layer, head) for layer in range(12) for head in range(12)]
threshold = 3  # This is T, a %: a head is removed if dropping it costs less than T% of the full score

# mean_ablate_by_lst returns the raw (unnormalized) circuit score, so normalize it by the
# full (unablated) model's score before comparing against a percentage threshold.
model.reset_hooks(including_permanent=True)
full_score = logits_to_ave_logit_diff_2(model(dataset.toks), dataset).item()

for layer in range(11, -1, -1):  # go backwards through layers, visiting all heads in a layer first
    for head in range(12):
        # Try the current circuit without this head; curr_circuit is only changed if we keep the removal.
        copy_circuit = [lh for lh in curr_circuit if lh != (layer, head)]

        new_score = mean_ablate_by_lst(copy_circuit, model, print_output=False).item()
        new_perc = 100 * new_score / full_score

        # If removing the head costs less than `threshold` percentage points, drop it for good
        if (100 - new_perc) < threshold:
            curr_circuit.remove((layer, head))

            print("Removed:", (layer, head))
            print(new_perc)
            print("\n")

In [None]:
mean_ablate_by_lst(curr_circuit, model, print_output=True)

Average logit difference (circuit / full) %: 103.1818


tensor(103.1818, device='cuda:0')

In [None]:
backw_3 = curr_circuit.copy()
backw_3

[(0, 0),
 (0, 1),
 (0, 2),
 (0, 3),
 (0, 5),
 (0, 9),
 (0, 10),
 (1, 0),
 (1, 1),
 (1, 5),
 (1, 6),
 (1, 7),
 (1, 8),
 (2, 2),
 (2, 8),
 (2, 10),
 (3, 0),
 (3, 3),
 (3, 7),
 (3, 8),
 (4, 3),
 (4, 4),
 (4, 6),
 (4, 7),
 (4, 8),
 (4, 9),
 (4, 10),
 (4, 11),
 (5, 0),
 (5, 2),
 (5, 3),
 (5, 4),
 (5, 6),
 (5, 8),
 (5, 11),
 (6, 1),
 (6, 3),
 (6, 9),
 (6, 10),
 (6, 11),
 (7, 0),
 (7, 9),
 (7, 10),
 (7, 11),
 (8, 11),
 (9, 1),
 (11, 10)]

In [None]:
len(backw_3)

47