<a href="https://colab.research.google.com/github/wlg100/numseqcont_circuit_expms/blob/main/notebook_templates/minimal_circuit_template.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" align="left"/></a>&nbsp;or in a local notebook.

# Setup
(No need to change anything)

In [1]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    # Install another version of node that makes PySvelte work way faster
    # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    # %pip install git+https://github.com/neelnanda-io/PySvelte.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

Running as a Colab notebook
Collecting git+https://github.com/neelnanda-io/TransformerLens.git
  Cloning https://github.com/neelnanda-io/TransformerLens.git to /tmp/pip-req-build-qpdo1o_w
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/TransformerLens.git /tmp/pip-req-build-qpdo1o_w
  Resolved https://github.com/neelnanda-io/TransformerLens.git to commit fa287750606075574df2c538058e67d648e2f952
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting accelerate>=0.23.0 (from transformer-lens==0.0.0)
  Downloading accelerate-0.24.1-py3-none-any.whl (261 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m261.4/261.4 kB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting beartype<0.15.0,>=0.14.1 (from transformer-lens==0.0.0)
  Downloading beartype-0.14.1-py3-none-any.whl (739 kB)
[2K     [90m━━━━━━━━

In [2]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

if IN_COLAB or not DEBUG_MODE:
    # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "png"

In [3]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [4]:
# import pysvelte

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

We turn automatic differentiation off, to save GPU memory, as this notebook focuses on model inference not model training.

In [5]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x78e83e32b6d0>

Plotting helper functions:

In [6]:
def imshow(tensor, renderer=None, **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", **kwargs).show(renderer)

def line(tensor, renderer=None, **kwargs):
    px.line(y=utils.to_numpy(tensor), **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

## Load Model

Decide which model to use (eg. gpt2-small vs -medium)

In [7]:
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)

Downloading (…)lve/main/config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Loaded pretrained model gpt2-small into HookedTransformer


## Import functions from repo

In [8]:
!git clone https://github.com/callummcdougall/ARENA_2.0.git

Cloning into 'ARENA_2.0'...
remote: Enumerating objects: 9106, done.[K
remote: Counting objects: 100% (1820/1820), done.[K
remote: Compressing objects: 100% (289/289), done.[K
remote: Total 9106 (delta 1614), reused 1608 (delta 1528), pack-reused 7286[K
Receiving objects: 100% (9106/9106), 155.60 MiB | 13.07 MiB/s, done.
Resolving deltas: 100% (5507/5507), done.


In [9]:
cd ARENA_2.0/chapter1_transformers/exercises/part3_indirect_object_identification

/content/ARENA_2.0/chapter1_transformers/exercises/part3_indirect_object_identification


In [10]:
import ioi_circuit_extraction as ioi_circuit_extraction

# test prompts

In [68]:
modeltest = HookedTransformer.from_pretrained("gpt2")

Loaded pretrained model gpt2 into HookedTransformer


In [69]:
example_prompt = " eleven twelve thirteen fourteen"
example_answer = " fifteen"
utils.test_prompt(example_prompt, example_answer, modeltest, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', ' eleven', ' twelve', ' thirteen', ' fourteen']
Tokenized answer: [' fifteen']


Top 0th token. Logit: 17.69 Prob: 45.98% Token: | fifteen|
Top 1th token. Logit: 16.41 Prob: 12.86% Token: | sixteen|
Top 2th token. Logit: 16.01 Prob:  8.59% Token: | fourteen|
Top 3th token. Logit: 15.49 Prob:  5.12% Token: | twenty|
Top 4th token. Logit: 14.70 Prob:  2.31% Token: | thirty|
Top 5th token. Logit: 14.65 Prob:  2.21% Token: | thirteen|
Top 6th token. Logit: 14.64 Prob:  2.19% Token: | seventeen|
Top 7th token. Logit: 14.39 Prob:  1.71% Token: | eighteen|
Top 8th token. Logit: 14.13 Prob:  1.31% Token: | eight|
Top 9th token. Logit: 13.98 Prob:  1.12% Token: | forty|


In [70]:
example_prompt = "eleven twelve thirteen fourteen"
example_answer = " fifteen"
utils.test_prompt(example_prompt, example_answer, modeltest, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'ele', 'ven', ' twelve', ' thirteen', ' fourteen']
Tokenized answer: [' fifteen']


Top 0th token. Logit: 16.30 Prob: 19.64% Token: | fifteen|
Top 1th token. Logit: 16.14 Prob: 16.64% Token: | fourteen|
Top 2th token. Logit: 15.92 Prob: 13.40% Token: | sixteen|
Top 3th token. Logit: 15.28 Prob:  7.07% Token: | thirteen|
Top 4th token. Logit: 14.93 Prob:  4.98% Token: | twenty|
Top 5th token. Logit: 14.87 Prob:  4.70% Token: | eighteen|
Top 6th token. Logit: 14.87 Prob:  4.69% Token: | seventeen|
Top 7th token. Logit: 14.41 Prob:  2.96% Token: | twelve|
Top 8th token. Logit: 14.30 Prob:  2.66% Token: | thirty|
Top 9th token. Logit: 13.95 Prob:  1.87% Token: | eight|


In [None]:
example_prompt = "fifteen sixteen seventeen eighteen"
example_answer = " nineteen"
utils.test_prompt(example_prompt, example_answer, modeltest, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'fif', 'teen', ' sixteen', ' seventeen', ' eighteen']
Tokenized answer: [' nineteen']


Top 0th token. Logit: 16.00 Prob: 29.44% Token: | nineteen|
Top 1th token. Logit: 15.54 Prob: 18.58% Token: | twenty|
Top 2th token. Logit: 15.08 Prob: 11.79% Token: | eighteen|
Top 3th token. Logit: 14.78 Prob:  8.69% Token: | seventeen|
Top 4th token. Logit: 13.50 Prob:  2.43% Token: | thirty|
Top 5th token. Logit: 13.20 Prob:  1.78% Token: | seventy|
Top 6th token. Logit: 13.13 Prob:  1.67% Token: | sixteen|
Top 7th token. Logit: 13.02 Prob:  1.50% Token: | fifteen|
Top 8th token. Logit: 12.92 Prob:  1.36% Token: | nineteenth|
Top 9th token. Logit: 12.69 Prob:  1.08% Token: | 19|


In [None]:
example_prompt = "fifteen sixteen seventeen eighteen nineteen twenty twenty-"
example_answer = "one"
utils.test_prompt(example_prompt, example_answer, modeltest, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'fif', 'teen', ' sixteen', ' seventeen', ' eighteen', ' nineteen', ' twenty', ' twenty', '-']
Tokenized answer: [' one']


Top 0th token. Logit: 18.18 Prob: 31.67% Token: |one|
Top 1th token. Logit: 17.21 Prob: 11.96% Token: |four|
Top 2th token. Logit: 17.17 Prob: 11.58% Token: |two|
Top 3th token. Logit: 16.98 Prob:  9.51% Token: |five|
Top 4th token. Logit: 16.73 Prob:  7.43% Token: |nine|
Top 5th token. Logit: 16.64 Prob:  6.77% Token: |first|
Top 6th token. Logit: 16.32 Prob:  4.94% Token: |three|
Top 7th token. Logit: 16.12 Prob:  4.04% Token: |seven|
Top 8th token. Logit: 15.82 Prob:  2.99% Token: |eight|
Top 9th token. Logit: 15.75 Prob:  2.80% Token: |six|


In [None]:
example_prompt = "fifteen sixteen seventeen eighteen nineteen twenty twenty-one twenty-"
example_answer = "two"
utils.test_prompt(example_prompt, example_answer, modeltest, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'fif', 'teen', ' sixteen', ' seventeen', ' eighteen', ' nineteen', ' twenty', ' twenty', '-', 'one', ' twenty', '-']
Tokenized answer: [' two']


Top 0th token. Logit: 19.88 Prob: 43.81% Token: |one|
Top 1th token. Logit: 19.11 Prob: 20.35% Token: |two|
Top 2th token. Logit: 18.01 Prob:  6.79% Token: |three|
Top 3th token. Logit: 17.87 Prob:  5.90% Token: |nine|
Top 4th token. Logit: 17.85 Prob:  5.75% Token: |five|
Top 5th token. Logit: 17.73 Prob:  5.14% Token: |four|
Top 6th token. Logit: 17.47 Prob:  3.93% Token: |seven|
Top 7th token. Logit: 17.24 Prob:  3.14% Token: |six|
Top 8th token. Logit: 16.50 Prob:  1.50% Token: |eight|
Top 9th token. Logit: 16.43 Prob:  1.39% Token: |first|


In [None]:
example_prompt = "fifteen sixteen seventeen eighteen nineteen twenty twentyone twenty-"
example_answer = "two"
utils.test_prompt(example_prompt, example_answer, modeltest, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'fif', 'teen', ' sixteen', ' seventeen', ' eighteen', ' nineteen', ' twenty', ' twenty', 'one', ' twenty', '-']
Tokenized answer: [' two']


Top 0th token. Logit: 20.11 Prob: 50.19% Token: |one|
Top 1th token. Logit: 18.64 Prob: 11.54% Token: |two|
Top 2th token. Logit: 18.30 Prob:  8.22% Token: |four|
Top 3th token. Logit: 18.04 Prob:  6.34% Token: |five|
Top 4th token. Logit: 17.75 Prob:  4.74% Token: |three|
Top 5th token. Logit: 17.59 Prob:  4.03% Token: |first|
Top 6th token. Logit: 17.54 Prob:  3.82% Token: |nine|
Top 7th token. Logit: 17.38 Prob:  3.27% Token: |six|
Top 8th token. Logit: 17.31 Prob:  3.05% Token: |seven|
Top 9th token. Logit: 16.66 Prob:  1.60% Token: |eight|


In [140]:
example_prompt = " twenty twentyone twenty"
example_answer = "two"
utils.test_prompt(example_prompt, example_answer, modeltest, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', ' twenty', ' twenty', 'one', ' twenty']
Tokenized answer: [' two']


Top 0th token. Logit: 15.09 Prob: 30.94% Token: |-|
Top 1th token. Logit: 14.56 Prob: 18.19% Token: | one|
Top 2th token. Logit: 14.03 Prob: 10.75% Token: |two|
Top 3th token. Logit: 13.49 Prob:  6.26% Token: | two|
Top 4th token. Logit: 12.84 Prob:  3.25% Token: |one|
Top 5th token. Logit: 12.84 Prob:  3.24% Token: | twenty|
Top 6th token. Logit: 11.94 Prob:  1.32% Token: | five|
Top 7th token. Logit: 11.62 Prob:  0.96% Token: | three|
Top 8th token. Logit: 11.62 Prob:  0.96% Token: | seven|
Top 9th token. Logit: 11.59 Prob:  0.94% Token: |three|


In [141]:
example_prompt = " sixty sixtyone sixtytwo sixtythree sixty"
example_answer = "four"
utils.test_prompt(example_prompt, example_answer, modeltest, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', ' sixty', ' sixty', 'one', ' sixty', 'two', ' s', 'ixt', 'yth', 'ree', ' sixty']
Tokenized answer: [' four']


Top 0th token. Logit: 16.08 Prob: 41.34% Token: |three|
Top 1th token. Logit: 14.59 Prob:  9.31% Token: |seven|
Top 2th token. Logit: 14.41 Prob:  7.77% Token: |two|
Top 3th token. Logit: 14.31 Prob:  7.05% Token: | three|
Top 4th token. Logit: 13.63 Prob:  3.59% Token: |one|
Top 5th token. Logit: 13.53 Prob:  3.22% Token: |-|
Top 6th token. Logit: 13.43 Prob:  2.93% Token: | one|
Top 7th token. Logit: 13.20 Prob:  2.32% Token: |four|
Top 8th token. Logit: 12.96 Prob:  1.82% Token: |five|
Top 9th token. Logit: 12.71 Prob:  1.43% Token: |six|


# Generate dataset with multiple prompts

In [11]:
class Dataset:
    def __init__(self, prompts, pos_dict, tokenizer, S1_is_first=False):
        self.prompts = prompts
        self.tokenizer = tokenizer
        self.N = len(prompts)
        self.max_len = max(
            [
                len(self.tokenizer(prompt["text"]).input_ids)
                for prompt in self.prompts
            ]
        )
        # all_ids = [prompt["TEMPLATE_IDX"] for prompt in self.ioi_prompts]
        all_ids = [0 for prompt in self.prompts] # only 1 template
        all_ids_ar = np.array(all_ids)
        self.groups = []
        for id in list(set(all_ids)):
            self.groups.append(np.where(all_ids_ar == id)[0])

        texts = [ prompt["text"] for prompt in self.prompts ]
        self.toks = torch.Tensor(self.tokenizer(texts, padding=True).input_ids).type(
            torch.int
        )
        self.io_tokenIDs = [
            self.tokenizer.encode(" " + prompt["corr"])[0] for prompt in self.prompts
        ]
        self.s_tokenIDs = [
            self.tokenizer.encode(" " + prompt["incorr"])[0] for prompt in self.prompts
        ]

        # word_idx: for every prompt, find the token index of each target token and "end"
        # word_idx is a tensor with an element for each prompt. The element is the targ token's ind at that prompt
        self.word_idx = {}
        for targ in [key for key in self.prompts[0].keys() if (key != 'text' and key != 'corr' and key != 'incorr')]:
            targ_lst = []
            for prompt in self.prompts:
                input_text = prompt["text"]
                tokens = model.tokenizer.tokenize(input_text)
                # if S1_is_first and targ == "S1":  # only use this if first token doesn't have space Ġ in front
                #     target_token = prompt[targ]
                # else:
                #     target_token = "Ġ" + prompt[targ]
                # target_index = tokens.index(target_token)
                target_index = pos_dict[targ]
                targ_lst.append(target_index)
            self.word_idx[targ] = torch.tensor(targ_lst)

        targ_lst = []
        for prompt in self.prompts:
            input_text = prompt["text"]
            tokens = self.tokenizer.tokenize(input_text)
            end_token_index = len(tokens) - 1
            targ_lst.append(end_token_index)
        self.word_idx["end"] = torch.tensor(targ_lst)

    def __len__(self):
        return self.N

Repalce io_tokens with correct answer (next, which is '5') and s_tokens with incorrect (current, which repeats)

In [12]:
pos_dict = {
    'S1': 0,
    'S2': 1,
    'S3': 2,
    'S4': 3,
}

In [62]:
def generate_prompts_list(x ,y):
    # words = ['one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten', 'eleven', 'twelve', 'thirteen', 'fourteen', 'fifteen', 'sixteen', 'seventeen', 'eighteen', 'nineteen', 'twenty']
    # words = ['one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten']
    # words = ['one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten', 'eleven', 'twelve', 'thirteen', 'fourteen', 'fifteen', 'sixteen', 'seventeen', 'eighteen', 'nineteen']
    # words = ['one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten', 'eleven', 'twelve', 'thirteen']
    words = ['one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten', 'eleven', 'twelve']
    prompts_list = []
    for i in range(x, y):
        prompt_dict = {
            'S1': words[i],
            'S2': words[i+1],
            'S3': words[i+2],
            'S4': words[i+3],
            'corr': words[i+4],
            'incorr': words[i],  # this is arbitrary
            'text': f"{words[i]} {words[i+1]} {words[i+2]} {words[i+3]}"
        }
        prompts_list.append(prompt_dict)
    return prompts_list

# prompts_list = generate_prompts_list(0, 6)
# prompts_list = generate_prompts_list(0, 15)
prompts_list = generate_prompts_list(0, 8)
dataset = Dataset(prompts_list, pos_dict, model.tokenizer, S1_is_first=True)
prompts_list

[{'S1': 'one',
  'S2': 'two',
  'S3': 'three',
  'S4': 'four',
  'corr': 'five',
  'incorr': 'one',
  'text': 'one two three four'},
 {'S1': 'two',
  'S2': 'three',
  'S3': 'four',
  'S4': 'five',
  'corr': 'six',
  'incorr': 'two',
  'text': 'two three four five'},
 {'S1': 'three',
  'S2': 'four',
  'S3': 'five',
  'S4': 'six',
  'corr': 'seven',
  'incorr': 'three',
  'text': 'three four five six'},
 {'S1': 'four',
  'S2': 'five',
  'S3': 'six',
  'S4': 'seven',
  'corr': 'eight',
  'incorr': 'four',
  'text': 'four five six seven'},
 {'S1': 'five',
  'S2': 'six',
  'S3': 'seven',
  'S4': 'eight',
  'corr': 'nine',
  'incorr': 'five',
  'text': 'five six seven eight'},
 {'S1': 'six',
  'S2': 'seven',
  'S3': 'eight',
  'S4': 'nine',
  'corr': 'ten',
  'incorr': 'six',
  'text': 'six seven eight nine'},
 {'S1': 'seven',
  'S2': 'eight',
  'S3': 'nine',
  'S4': 'ten',
  'corr': 'eleven',
  'incorr': 'seven',
  'text': 'seven eight nine ten'},
 {'S1': 'eight',
  'S2': 'nine',
  'S3': 't

In [45]:
import random

def generate_prompts_list_corr(x ,y):
    # words = ['one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten', 'eleven', 'twelve', 'thirteen', 'fourteen', 'fifteen', 'sixteen', 'seventeen', 'eighteen', 'nineteen', 'twenty']
    # words = ['one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten']
    # words = ['one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten', 'eleven', 'twelve', 'thirteen', 'fourteen', 'fifteen', 'sixteen', 'seventeen', 'eighteen', 'nineteen']
    # words = ['one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten', 'eleven', 'twelve', 'thirteen']
    words = ['one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten', 'eleven', 'twelve']
    prompts_list = []
    for i in range(x, y):
        r1 = random.choice(words)
        r2 = random.choice(words)
        while True:
            r3_ind = random.randint(0,len(words)-1)
            r4_ind = random.randint(0,len(words)-1)
            if words[r3_ind] != words[r4_ind-1]:
                break
        r3 = words[r3_ind]
        r4 = words[r4_ind]
        prompt_dict = {
            'S1': str(r1),
            'S2': str(r2),
            'S3': str(r3),
            'S4': str(r4),
            'corr': str(r1),
            'incorr': str(r4),
            'text': f"{r1} {r2} {r3} {r4}"
        }
        prompts_list.append(prompt_dict)
    return prompts_list

prompts_list_2 = generate_prompts_list_corr(0, 8)
dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer, S1_is_first=True)
# prompts_list_2

In [52]:
prompts_list_2

[{'S1': 'eight',
  'S2': 'two',
  'S3': 'four',
  'S4': 'three',
  'corr': 'eight',
  'incorr': 'three',
  'text': 'eight two four three'},
 {'S1': 'seven',
  'S2': 'seven',
  'S3': 'five',
  'S4': 'four',
  'corr': 'seven',
  'incorr': 'four',
  'text': 'seven seven five four'},
 {'S1': 'eleven',
  'S2': 'twelve',
  'S3': 'twelve',
  'S4': 'six',
  'corr': 'eleven',
  'incorr': 'six',
  'text': 'eleven twelve twelve six'},
 {'S1': 'eight',
  'S2': 'six',
  'S3': 'two',
  'S4': 'four',
  'corr': 'eight',
  'incorr': 'four',
  'text': 'eight six two four'},
 {'S1': 'one',
  'S2': 'two',
  'S3': 'one',
  'S4': 'five',
  'corr': 'one',
  'incorr': 'five',
  'text': 'one two one five'},
 {'S1': 'ten',
  'S2': 'seven',
  'S3': 'nine',
  'S4': 'seven',
  'corr': 'ten',
  'incorr': 'seven',
  'text': 'ten seven nine seven'},
 {'S1': 'eight',
  'S2': 'four',
  'S3': 'seven',
  'S4': 'two',
  'corr': 'eight',
  'incorr': 'two',
  'text': 'eight four seven two'},
 {'S1': 'two',
  'S2': 'three',


In [50]:
dataset.toks.shape

torch.Size([8, 4])

In [51]:
dataset_2.toks.shape

torch.Size([8, 5])

In [53]:
dataset_2.toks

tensor([[26022,   734,  1440,  1115, 50256],
        [26548,  3598,  1936,  1440, 50256],
        [11129,   574, 14104, 14104,  2237],
        [26022,  2237,   734,  1440, 50256],
        [  505,   734,   530,  1936, 50256],
        [ 1452,  3598,  5193,  3598, 50256],
        [26022,  1440,  3598,   734, 50256],
        [11545,  1115, 14104,  3624, 50256]], dtype=torch.int32)

In [54]:
model.tokenizer.decode([26022])

'eight'

In [55]:
model.tokenizer.decode([50256])

'<|endoftext|>'

In [56]:
model.tokenizer.decode([2237])

' six'

In [57]:
model.tokenizer.decode([14104])

' twelve'

In [58]:
model.tokenizer.decode([11129,   574, 14104, 14104,  2237])

'eleven twelve twelve six'

In [60]:
model.tokenizer.decode([11129])

'ele'

In [59]:
model.tokenizer.decode([574])

'ven'

In [61]:
dataset.toks

tensor([[  505,   734,  1115,  1440],
        [11545,  1115,  1440,  1936],
        [15542,  1440,  1936,  2237],
        [14337,  1936,  2237,  3598],
        [13261,  2237,  3598,  3624],
        [19412,  3598,  3624,  5193],
        [26548,  3624,  5193,  3478],
        [26022,  5193,  3478, 22216]], dtype=torch.int32)

In [63]:
model.tokenizer.decode([26022,  5193,  3478, 22216])

'eight nine ten eleven'

In [65]:
model.tokenizer('eleven')

{'input_ids': [11129, 574], 'attention_mask': [1, 1]}

In [66]:
model.tokenizer('ten eleven')

{'input_ids': [1452, 22216], 'attention_mask': [1, 1]}

In [67]:
model.tokenizer(' eleven')

{'input_ids': [22216], 'attention_mask': [1]}

In [None]:
# [{'S1': 'ten',
#   'S2': 'seven',
#   'S3': 'five',
#   'S4': 'five',
#   'corr': 'ten',
#   'incorr': 'five',
#   'text': 'ten seven five five'},
#  {'S1': 'two',
#   'S2': 'one',
#   'S3': 'ten',
#   'S4': 'four',
#   'corr': 'two',
#   'incorr': 'four',
#   'text': 'two one ten four'},
#  {'S1': 'one',
#   'S2': 'seven',
#   'S3': 'five',
#   'S4': 'one',
#   'corr': 'one',
#   'incorr': 'one',
#   'text': 'one seven five one'},
#  {'S1': 'one',
#   'S2': 'six',
#   'S3': 'three',
#   'S4': 'five',
#   'corr': 'one',
#   'incorr': 'five',
#   'text': 'one six three five'},
#  {'S1': 'five',
#   'S2': 'nine',
#   'S3': 'seven',
#   'S4': 'four',
#   'corr': 'five',
#   'incorr': 'four',
#   'text': 'five nine seven four'},
#  {'S1': 'eight',
#   'S2': 'four',
#   'S3': 'ten',
#   'S4': 'four',
#   'corr': 'eight',
#   'incorr': 'four',
#   'text': 'eight four ten four'}]

In [22]:
prompts_list_2 = [{'S1': 'ten',
  'S2': 'seven',
  'S3': 'nine',
  'S4': 'seven',
  'corr': 'ten',
  'incorr': 'seven',
  'text': 'ten seven nine seven'},
 {'S1': 'six',
  'S2': 'ten',
  'S3': 'six',
  'S4': 'five',
  'corr': 'six',
  'incorr': 'five',
  'text': 'six ten six five'},
 {'S1': 'six',
  'S2': 'eight',
  'S3': 'six',
  'S4': 'three',
  'corr': 'six',
  'incorr': 'three',
  'text': 'six eight six three'},
 {'S1': 'two',
  'S2': 'two',
  'S3': 'six',
  'S4': 'one',
  'corr': 'two',
  'incorr': 'one',
  'text': 'two two six one'},
 {'S1': 'nine',
  'S2': 'five',
  'S3': 'nine',
  'S4': 'two',
  'corr': 'nine',
  'incorr': 'two',
  'text': 'nine five nine two'},
 {'S1': 'five',
  'S2': 'ten',
  'S3': 'eight',
  'S4': 'eight',
  'corr': 'five',
  'incorr': 'eight',
  'text': 'five ten eight eight'}]

dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer, S1_is_first=True)

# Ablation Expm Functions

In [15]:
from torch import Tensor

def logits_to_ave_logit_diff_2(logits: Float[Tensor, "batch seq d_vocab"], dataset: Dataset, per_prompt=False):
    '''
    Returns logit difference between the correct and incorrect answer.

    If per_prompt=True, return the array of differences rather than the average.
    '''

    # Only the final logits are relevant for the answer
    # Get the logits corresponding to the indirect object / subject tokens respectively
    io_logits: Float[Tensor, "batch"] = logits[range(logits.size(0)), dataset.word_idx["end"], dataset.io_tokenIDs]
    s_logits: Float[Tensor, "batch"] = logits[range(logits.size(0)), dataset.word_idx["end"], dataset.s_tokenIDs]
    # Find logit difference
    answer_logit_diff = io_logits - s_logits
    return answer_logit_diff if per_prompt else answer_logit_diff.mean()

In [16]:
def mean_ablate_by_lst(lst, model, print_output=True):
    CIRCUIT = {
        "number mover": lst,
        "number mover 4": lst,
        "number mover 3": lst,
        "number mover 2": lst,
        "number mover 1": lst,
    }

    SEQ_POS_TO_KEEP = {
        "number mover": "end",
        "number mover 4": "S4",
        "number mover 3": "S3",
        "number mover 2": "S2",
        "number mover 1": "S1",
    }

    model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook

    ioi_logits_original, ioi_cache = model.run_with_cache(dataset.toks)

    model = ioi_circuit_extraction.add_mean_ablation_hook(model, means_dataset=dataset_2, circuit=CIRCUIT, seq_pos_to_keep=SEQ_POS_TO_KEEP)
    ioi_logits_minimal = model(dataset.toks)

    orig_score = logits_to_ave_logit_diff_2(ioi_logits_original, dataset)
    new_score = logits_to_ave_logit_diff_2(ioi_logits_minimal, dataset)
    if print_output:
        # print(f"Average logit difference (IOI dataset, using entire model): {orig_score:.4f}")
        # print(f"Average logit difference (IOI dataset, only using circuit): {new_score:.4f}")
        print(f"Average logit difference (circuit / full) %: {100 * new_score / orig_score:.4f}")
    # return new_score
    return 100 * new_score / orig_score

We can also prevent redundant computation of the full circuit score by storing it and just passing it in to the function.

In [17]:
def find_circuit_forw(curr_circuit=None, threshold=10):
    # threshold is T, a %. if performance is less than T%, allow its removal
    if curr_circuit == []:
        # Start with full circuit
        curr_circuit = [(layer, head) for layer in range(12) for head in range(12)]

    for layer in range(0, 12):
        for head in range(12):
            if (layer, head) not in curr_circuit:
                continue

            # Copying the curr_circuit so we can iterate over one and modify the other
            copy_circuit = curr_circuit.copy()

            # Temporarily removing the current tuple from the copied circuit
            copy_circuit.remove((layer, head))

            new_score = mean_ablate_by_lst(copy_circuit, model, print_output=False).item()

            # print((layer,head), new_score)
            # If the result is less than the threshold, remove the tuple from the original list
            if (100 - new_score) < threshold:
                curr_circuit.remove((layer, head))

                print("\nRemoved:", (layer, head))
                print(new_score)

    return curr_circuit, new_score

In [18]:
def find_circuit_backw(curr_circuit=None, threshold=10):
    # threshold is T, a %. if performance is less than T%, allow its removal
    if curr_circuit == []:
        # Start with full circuit
        curr_circuit = [(layer, head) for layer in range(12) for head in range(12)]

    for layer in range(11, -1, -1):  # go thru all heads in a layer first
        for head in range(12):
            if (layer, head) not in curr_circuit:
                continue

            # Copying the curr_circuit so we can iterate over one and modify the other
            copy_circuit = curr_circuit.copy()

            # Temporarily removing the current tuple from the copied circuit
            copy_circuit.remove((layer, head))

            new_score = mean_ablate_by_lst(copy_circuit, model, print_output=False).item()

            # If the result is less than the threshold, remove the tuple from the original list
            if (100 - new_score) < threshold:
                curr_circuit.remove((layer, head))

                print("\nRemoved:", (layer, head))
                print(new_score)

    return curr_circuit, new_score

# test

In [29]:
model.tokenizer('12')

{'input_ids': [1065], 'attention_mask': [1]}

In [47]:
# fb 80, digits incr
# https://colab.research.google.com/drive/1mFWmGAKtigFcqqWWMCwU7wWQY2HT5ZOo#scrollTo=lJEY-Zs2g_a5&line=1&uniqifier=1
circuit = [(1, 5), (3, 3), (3, 7), (3, 10), (3, 11), (4, 4), (4, 6), (4, 7), (4, 10), (4, 11), (5, 0), (5, 2), (5, 3), (5, 4), (5, 6), (6, 3), (6, 8), (6, 10), (7, 0), (7, 2), (7, 7), (7, 8), (7, 10), (7, 11), (8, 0), (8, 1), (8, 6), (8, 8), (8, 9), (8, 11), (9, 1), (11, 8)]
mean_ablate_by_lst(circuit, model, print_output=True).item()

RuntimeError: ignored

# all space in front

In [82]:
def generate_prompts_list(x ,y):
    # words = ['one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten', 'eleven', 'twelve', 'thirteen', 'fourteen', 'fifteen', 'sixteen', 'seventeen', 'eighteen', 'nineteen', 'twenty']
    words = [' one', ' two', ' three', ' four', ' five', ' six', ' seven', ' eight', ' nine', ' ten', ' eleven', ' twelve']
    prompts_list = []
    for i in range(x, y):
        prompt_dict = {
            'S1': words[i],
            'S2': words[i+1],
            'S3': words[i+2],
            'S4': words[i+3],
            'corr': words[i+4],
            'incorr': words[i+3],  # this is arbitrary
            'text': f"{words[i]}{words[i+1]}{words[i+2]}{words[i+3]}"
        }
        prompts_list.append(prompt_dict)
    return prompts_list

# prompts_list = generate_prompts_list(0, 6)
# prompts_list = generate_prompts_list(0, 15)
prompts_list = generate_prompts_list(0, 8)
dataset = Dataset(prompts_list, pos_dict, model.tokenizer, S1_is_first=True)
prompts_list

[{'S1': ' one',
  'S2': ' two',
  'S3': ' three',
  'S4': ' four',
  'corr': ' five',
  'incorr': ' four',
  'text': ' one two three four'},
 {'S1': ' two',
  'S2': ' three',
  'S3': ' four',
  'S4': ' five',
  'corr': ' six',
  'incorr': ' five',
  'text': ' two three four five'},
 {'S1': ' three',
  'S2': ' four',
  'S3': ' five',
  'S4': ' six',
  'corr': ' seven',
  'incorr': ' six',
  'text': ' three four five six'},
 {'S1': ' four',
  'S2': ' five',
  'S3': ' six',
  'S4': ' seven',
  'corr': ' eight',
  'incorr': ' seven',
  'text': ' four five six seven'},
 {'S1': ' five',
  'S2': ' six',
  'S3': ' seven',
  'S4': ' eight',
  'corr': ' nine',
  'incorr': ' eight',
  'text': ' five six seven eight'},
 {'S1': ' six',
  'S2': ' seven',
  'S3': ' eight',
  'S4': ' nine',
  'corr': ' ten',
  'incorr': ' nine',
  'text': ' six seven eight nine'},
 {'S1': ' seven',
  'S2': ' eight',
  'S3': ' nine',
  'S4': ' ten',
  'corr': ' eleven',
  'incorr': ' ten',
  'text': ' seven eight nine 

In [73]:
import random

def generate_prompts_list_corr(x ,y):
    # words = ['one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten', 'eleven', 'twelve', 'thirteen', 'fourteen', 'fifteen', 'sixteen', 'seventeen', 'eighteen', 'nineteen', 'twenty']
    # words = ['one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten']
    # words = ['one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten', 'eleven', 'twelve', 'thirteen', 'fourteen', 'fifteen', 'sixteen', 'seventeen', 'eighteen', 'nineteen']
    # words = ['one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten', 'eleven', 'twelve', 'thirteen']
    words = [' one', ' two', ' three', ' four', ' five', ' six', ' seven', ' eight', ' nine', ' ten', ' eleven', ' twelve']
    prompts_list = []
    for i in range(x, y):
        r1 = random.choice(words)
        r2 = random.choice(words)
        while True:
            r3_ind = random.randint(0,len(words)-1)
            r4_ind = random.randint(0,len(words)-1)
            if words[r3_ind] != words[r4_ind-1]:
                break
        r3 = words[r3_ind]
        r4 = words[r4_ind]
        prompt_dict = {
            'S1': str(r1),
            'S2': str(r2),
            'S3': str(r3),
            'S4': str(r4),
            'corr': str(r1),
            'incorr': str(r4),
            'text': f"{r1} {r2} {r3} {r4}"
        }
        prompts_list.append(prompt_dict)
    return prompts_list

prompts_list_2 = generate_prompts_list_corr(0, 8)
dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer, S1_is_first=True)
# prompts_list_2

In [75]:
dataset_2.toks

tensor([[ 1440,   220,  3478,   220, 14104,   220,  3624],
        [ 1440,   220,  1936,   220, 22216,   220,  1936],
        [  734,   220,  1115,   220,  3478,   220,  1936],
        [ 3598,   220,  3598,   220,  1936,   220,  5193],
        [ 1440,   220, 22216,   220,  1936,   220,  1936],
        [ 2237,   220,  3624,   220,  1115,   220,  1936],
        [ 5193,   220,  3598,   220,  2237,   220,  1440],
        [ 3624,   220,  5193,   220, 22216,   220,  2237]], dtype=torch.int32)

In [76]:
model.tokenizer.decode([1440,   220,  3478,   220, 14104,   220,  3624])

' four  ten  twelve  eight'

In [77]:
model.tokenizer.decode([1440])

' four'

In [78]:
model.tokenizer.decode([220,  3478])

'  ten'

In [74]:
# fb 80, digits incr
# https://colab.research.google.com/drive/1mFWmGAKtigFcqqWWMCwU7wWQY2HT5ZOo#scrollTo=lJEY-Zs2g_a5&line=1&uniqifier=1
circuit = [(1, 5), (3, 3), (3, 7), (3, 10), (3, 11), (4, 4), (4, 6), (4, 7), (4, 10), (4, 11), (5, 0), (5, 2), (5, 3), (5, 4), (5, 6), (6, 3), (6, 8), (6, 10), (7, 0), (7, 2), (7, 7), (7, 8), (7, 10), (7, 11), (8, 0), (8, 1), (8, 6), (8, 8), (8, 9), (8, 11), (9, 1), (11, 8)]
mean_ablate_by_lst(circuit, model, print_output=True).item()

RuntimeError: ignored

Make sure to get rid of the spaces too!

In [81]:
import random

def generate_prompts_list_corr(x ,y):
    # words = ['one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten', 'eleven', 'twelve', 'thirteen', 'fourteen', 'fifteen', 'sixteen', 'seventeen', 'eighteen', 'nineteen', 'twenty']
    # words = ['one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten']
    # words = ['one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten', 'eleven', 'twelve', 'thirteen', 'fourteen', 'fifteen', 'sixteen', 'seventeen', 'eighteen', 'nineteen']
    # words = ['one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten', 'eleven', 'twelve', 'thirteen']
    words = [' one', ' two', ' three', ' four', ' five', ' six', ' seven', ' eight', ' nine', ' ten', ' eleven', ' twelve']
    prompts_list = []
    for i in range(x, y):
        r1 = random.choice(words)
        r2 = random.choice(words)
        while True:
            r3_ind = random.randint(0,len(words)-1)
            r4_ind = random.randint(0,len(words)-1)
            if words[r3_ind] != words[r4_ind-1]:
                break
        r3 = words[r3_ind]
        r4 = words[r4_ind]
        prompt_dict = {
            'S1': str(r1),
            'S2': str(r2),
            'S3': str(r3),
            'S4': str(r4),
            'corr': str(r1),
            'incorr': str(r4),
            'text': f"{r1}{r2}{r3}{r4}"
        }
        prompts_list.append(prompt_dict)
    return prompts_list

prompts_list_2 = generate_prompts_list_corr(0, 8)
dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer, S1_is_first=True)
prompts_list_2

[{'S1': ' ten',
  'S2': ' eleven',
  'S3': ' five',
  'S4': ' two',
  'corr': ' ten',
  'incorr': ' two',
  'text': ' ten eleven five two'},
 {'S1': ' eight',
  'S2': ' four',
  'S3': ' twelve',
  'S4': ' five',
  'corr': ' eight',
  'incorr': ' five',
  'text': ' eight four twelve five'},
 {'S1': ' ten',
  'S2': ' two',
  'S3': ' nine',
  'S4': ' two',
  'corr': ' ten',
  'incorr': ' two',
  'text': ' ten two nine two'},
 {'S1': ' two',
  'S2': ' twelve',
  'S3': ' twelve',
  'S4': ' five',
  'corr': ' two',
  'incorr': ' five',
  'text': ' two twelve twelve five'},
 {'S1': ' twelve',
  'S2': ' one',
  'S3': ' twelve',
  'S4': ' twelve',
  'corr': ' twelve',
  'incorr': ' twelve',
  'text': ' twelve one twelve twelve'},
 {'S1': ' five',
  'S2': ' seven',
  'S3': ' five',
  'S4': ' four',
  'corr': ' five',
  'incorr': ' four',
  'text': ' five seven five four'},
 {'S1': ' six',
  'S2': ' nine',
  'S3': ' five',
  'S4': ' two',
  'corr': ' six',
  'incorr': ' two',
  'text': ' six nine

In [83]:
# fb 80, digits incr
# https://colab.research.google.com/drive/1mFWmGAKtigFcqqWWMCwU7wWQY2HT5ZOo#scrollTo=lJEY-Zs2g_a5&line=1&uniqifier=1
circuit = [(1, 5), (3, 3), (3, 7), (3, 10), (3, 11), (4, 4), (4, 6), (4, 7), (4, 10), (4, 11), (5, 0), (5, 2), (5, 3), (5, 4), (5, 6), (6, 3), (6, 8), (6, 10), (7, 0), (7, 2), (7, 7), (7, 8), (7, 10), (7, 11), (8, 0), (8, 1), (8, 6), (8, 8), (8, 9), (8, 11), (9, 1), (11, 8)]
mean_ablate_by_lst(circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: nan


nan

In [86]:
lst = circuit
CIRCUIT = {
        "number mover": lst,
        "number mover 4": lst,
        "number mover 3": lst,
        "number mover 2": lst,
        "number mover 1": lst,
    }

SEQ_POS_TO_KEEP = {
    "number mover": "end",
    "number mover 4": "S4",
    "number mover 3": "S3",
    "number mover 2": "S2",
    "number mover 1": "S1",
}


model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook

ioi_logits_original, ioi_cache = model.run_with_cache(dataset.toks)

model = ioi_circuit_extraction.add_mean_ablation_hook(model, means_dataset=dataset_2, circuit=CIRCUIT, seq_pos_to_keep=SEQ_POS_TO_KEEP)
ioi_logits_minimal = model(dataset.toks)

orig_score = logits_to_ave_logit_diff_2(ioi_logits_original, dataset)
new_score = logits_to_ave_logit_diff_2(ioi_logits_minimal, dataset)

In [92]:
io_logits = ioi_logits_original[range(ioi_logits_original.size(0)), dataset.word_idx["end"], dataset.io_tokenIDs]

In [93]:
s_logits = ioi_logits_original[range(ioi_logits_original.size(0)), dataset.word_idx["end"], dataset.s_tokenIDs]
# Find logit difference
answer_logit_diff = io_logits - s_logits
answer_logit_diff.mean()

tensor(0., device='cuda:0')

In [95]:
io_logits

tensor([5.8010, 5.0481, 6.4777, 6.5491, 7.3650, 6.9085, 7.0106, 7.5116],
       device='cuda:0')

In [94]:
s_logits

tensor([5.8010, 5.0481, 6.4777, 6.5491, 7.3650, 6.9085, 7.0106, 7.5116],
       device='cuda:0')

In [96]:
dataset.io_tokenIDs

[220, 220, 220, 220, 220, 220, 220, 220]

In [97]:
dataset.s_tokenIDs

[220, 220, 220, 220, 220, 220, 220, 220]

In [98]:
model.tokenizer.decode([220])

' '

In [87]:
orig_score

tensor(0., device='cuda:0')

In [88]:
new_score

tensor(0., device='cuda:0')

In [100]:
words = [' one', ' two', ' three', ' four', ' five', ' six', ' seven', ' eight', ' nine', ' ten', ' eleven', ' twelve']

i=0
print(words[i+4])
print(words[i+3])

 five
 four


In [102]:
model.tokenizer.encode(" " + words[i+4])[0]

220

In [103]:
model.tokenizer.encode(" " + words[i+4])

[220, 1936]

The issue is that it tries to add a space in front! Redo dataset making without spaces

# redo dataset without spaces

In [109]:
class Dataset:
    def __init__(self, prompts, pos_dict, tokenizer, S1_is_first=False):
        self.prompts = prompts
        self.tokenizer = tokenizer
        self.N = len(prompts)
        self.max_len = max(
            [
                len(self.tokenizer(prompt["text"]).input_ids)
                for prompt in self.prompts
            ]
        )
        # all_ids = [prompt["TEMPLATE_IDX"] for prompt in self.ioi_prompts]
        all_ids = [0 for prompt in self.prompts] # only 1 template
        all_ids_ar = np.array(all_ids)
        self.groups = []
        for id in list(set(all_ids)):
            self.groups.append(np.where(all_ids_ar == id)[0])

        texts = [ prompt["text"] for prompt in self.prompts ]
        self.toks = torch.Tensor(self.tokenizer(texts, padding=True).input_ids).type(
            torch.int
        )
        self.io_tokenIDs = [
            self.tokenizer.encode(prompt["corr"])[0] for prompt in self.prompts
        ]
        self.s_tokenIDs = [
            self.tokenizer.encode(prompt["incorr"])[0] for prompt in self.prompts
        ]

        # word_idx: for every prompt, find the token index of each target token and "end"
        # word_idx is a tensor with an element for each prompt. The element is the targ token's ind at that prompt
        self.word_idx = {}
        for targ in [key for key in self.prompts[0].keys() if (key != 'text' and key != 'corr' and key != 'incorr')]:
            targ_lst = []
            for prompt in self.prompts:
                input_text = prompt["text"]
                tokens = model.tokenizer.tokenize(input_text)
                # if S1_is_first and targ == "S1":  # only use this if first token doesn't have space Ġ in front
                #     target_token = prompt[targ]
                # else:
                #     target_token = "Ġ" + prompt[targ]
                # target_index = tokens.index(target_token)
                target_index = pos_dict[targ]
                targ_lst.append(target_index)
            self.word_idx[targ] = torch.tensor(targ_lst)

        targ_lst = []
        for prompt in self.prompts:
            input_text = prompt["text"]
            tokens = self.tokenizer.tokenize(input_text)
            end_token_index = len(tokens) - 1
            targ_lst.append(end_token_index)
        self.word_idx["end"] = torch.tensor(targ_lst)

    def __len__(self):
        return self.N

In [119]:
def generate_prompts_list(x ,y):
    words = [' one', ' two', ' three', ' four', ' five', ' six', ' seven', ' eight', ' nine', ' ten', ' eleven', ' twelve', ' thirteen', ' fourteen', ' fifteen', ' sixteen', ' seventeen', ' eighteen', ' nineteen', ' twenty']
    # words = [' one', ' two', ' three', ' four', ' five', ' six', ' seven', ' eight', ' nine', ' ten', ' eleven', ' twelve']
    prompts_list = []
    for i in range(x, y):
        prompt_dict = {
            'S1': words[i],
            'S2': words[i+1],
            'S3': words[i+2],
            'S4': words[i+3],
            'corr': words[i+4],
            'incorr': words[i+3],  # this is arbitrary
            'text': f"{words[i]}{words[i+1]}{words[i+2]}{words[i+3]}"
        }
        prompts_list.append(prompt_dict)
    return prompts_list

# prompts_list = generate_prompts_list(0, 8)
prompts_list = generate_prompts_list(0, 16)
dataset = Dataset(prompts_list, pos_dict, model.tokenizer, S1_is_first=True)

In [136]:
import random

def generate_prompts_list_corr(x ,y):
    words = [' one', ' two', ' three', ' four', ' five', ' six', ' seven', ' eight', ' nine', ' ten', ' eleven', ' twelve', ' thirteen', ' fourteen', ' fifteen', ' sixteen', ' seventeen', ' eighteen', ' nineteen', ' twenty']
    # words = [' one', ' two', ' three', ' four', ' five', ' six', ' seven', ' eight', ' nine', ' ten', ' eleven', ' twelve']
    prompts_list = []
    for i in range(x, y):
        r1 = random.choice(words)
        r2 = random.choice(words)
        while True:
            r3_ind = random.randint(0,len(words)-1)
            r4_ind = random.randint(0,len(words)-1)
            if words[r3_ind] != words[r4_ind-1]:
                break
        r3 = words[r3_ind]
        r4 = words[r4_ind]
        prompt_dict = {
            'S1': str(r1),
            'S2': str(r2),
            'S3': str(r3),
            'S4': str(r4),
            'corr': str(r1),
            'incorr': str(r4),
            'text': f"{r1}{r2}{r3}{r4}"
        }
        prompts_list.append(prompt_dict)
    return prompts_list

# prompts_list_2 = generate_prompts_list_corr(0, 8)
prompts_list_2 = generate_prompts_list_corr(0, 16)
dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer, S1_is_first=True)

In [137]:
circuit = [(1, 5), (3, 3), (3, 7), (3, 10), (3, 11), (4, 4), (4, 6), (4, 7), (4, 10), (4, 11), (5, 0), (5, 2), (5, 3), (5, 4), (5, 6), (6, 3), (6, 8), (6, 10), (7, 0), (7, 2), (7, 7), (7, 8), (7, 10), (7, 11), (8, 0), (8, 1), (8, 6), (8, 8), (8, 9), (8, 11), (9, 1), (11, 8)]
mean_ablate_by_lst(circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: 57.2228


57.222774505615234

# try other tasks circs

## gt, IOI

In [None]:
# greater-than
circuit = [(0, 1), (0, 3), (0, 5), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11), (9, 1)]
mean_ablate_by_lst(circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: 20.8534


20.853376388549805

In [None]:
# IOI
circuit = [(0, 1), (0, 10), (2, 2), (3, 0), (4, 11), (5, 5), (5, 8), (5, 9), (6, 9), (7, 3), (7, 9), (8, 6), (8, 10), (9, 0), (9, 6), (9, 7), (9, 9), (10, 0), (10, 1), (10, 2), (10, 6), (10, 7), (10, 10), (11, 2), (11, 9), (11, 10)]
mean_ablate_by_lst(circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: 12.2694


12.269357681274414

## fb 80

In [None]:
# fb 80, digits incr
# https://colab.research.google.com/drive/1mFWmGAKtigFcqqWWMCwU7wWQY2HT5ZOo#scrollTo=lJEY-Zs2g_a5&line=1&uniqifier=1
circuit = [(1, 5), (3, 3), (3, 7), (3, 10), (3, 11), (4, 4), (4, 6), (4, 7), (4, 10), (4, 11), (5, 0), (5, 2), (5, 3), (5, 4), (5, 6), (6, 3), (6, 8), (6, 10), (7, 0), (7, 2), (7, 7), (7, 8), (7, 10), (7, 11), (8, 0), (8, 1), (8, 6), (8, 8), (8, 9), (8, 11), (9, 1), (11, 8)]
mean_ablate_by_lst(circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: 74.7523


74.7523422241211

In [None]:
# fb 80, numwords
# https://colab.research.google.com/drive/1QTv-4osLHadCAay0beew-xlXszPCG88s#scrollTo=563kZf_4r_mw&line=2&uniqifier=1
circuit = [(3, 2), (4, 4), (4, 8), (4, 10), (4, 11), (5, 5), (5, 6), (5, 7), (5, 8), (6, 1), (6, 7), (6, 9), (6, 10), (7, 0), (7, 2), (7, 5), (7, 6), (7, 7), (7, 8), (7, 10), (7, 11), (8, 0), (8, 1), (8, 6), (8, 8), (8, 11), (9, 1), (10, 2)]
mean_ablate_by_lst(circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: 81.9778


81.9777603149414

In [None]:
# fb 80, months
# https://colab.research.google.com/drive/1lhQqlizYGMC11vzp6I9mJ3dyxIr8tV3l#scrollTo=563kZf_4r_mw&line=2&uniqifier=1
circuit = [(4, 4), (7, 11), (8, 6), (8, 9), (8, 11), (9, 1), (9, 5), (11, 10)]
mean_ablate_by_lst(circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: 36.0939


36.09392547607422

## bf 97

In [19]:
# digits incr
circuit = [(0, 1), (0, 2), (0, 3), (0, 5), (0, 8), (0, 9), (0, 10), (1, 0), (1, 5), (1, 7), (2, 0), (2, 2), (2, 3), (2, 4), (2, 6), (2, 9), (3, 3), (3, 7), (3, 11), (4, 4), (4, 7), (4, 10), (5, 2), (5, 3), (5, 4), (5, 6), (5, 8), (5, 11), (6, 1), (6, 3), (6, 6), (6, 8), (6, 10), (7, 2), (7, 6), (7, 8), (7, 10), (7, 11), (8, 6), (8, 8), (8, 11), (9, 1), (11, 10)]
mean_ablate_by_lst(circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: 74.0674


74.06744384765625

In [23]:
# numwords
# https://colab.research.google.com/drive/1QTv-4osLHadCAay0beew-xlXszPCG88s#scrollTo=563kZf_4r_mw&line=2&uniqifier=1
circuit = [(0, 1), (1, 0), (1, 5), (3, 2), (4, 4), (4, 8), (4, 10), (5, 4), (5, 6), (5, 8), (6, 9), (6, 10), (7, 7), (7, 8), (7, 10), (7, 11), (8, 0), (8, 1), (8, 6), (8, 8), (8, 11), (9, 1), (9, 5), (10, 2), (11, 8)]
mean_ablate_by_lst(circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: 93.0875


93.0875015258789

In [21]:
# months
# https://colab.research.google.com/drive/1lhQqlizYGMC11vzp6I9mJ3dyxIr8tV3l#scrollTo=563kZf_4r_mw&line=2&uniqifier=1
# circuit = [(4, 4), (7, 11), (8, 6), (8, 9), (8, 11), (9, 1), (9, 5), (11, 10)]
# mean_ablate_by_lst(circuit, model, print_output=True).item()

# Prune fwds-backwds iteratively

### iter fwd backw, threshold 3

In [None]:
threshold = 3
curr_circuit = []
prev_score = 100
new_score = 0
iter = 1
while prev_score != new_score:
    print('\nfwd prune, iter ', str(iter))
    # track changes in circuit as for some reason it doesn't work with scores
    old_circuit = curr_circuit.copy() # save old before finding new one
    curr_circuit, new_score = find_circuit_forw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    print('\nbackw prune, iter ', str(iter))
    # prev_score = new_score # save old score before finding new one
    old_circuit = curr_circuit.copy() # save old before finding new one
    curr_circuit, new_score = find_circuit_backw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    iter += 1


fwd prune, iter  1

Removed: (0, 0)
99.6773910522461

Removed: (0, 2)
100.02043914794922

Removed: (0, 3)
98.45348358154297

Removed: (0, 4)
99.20835876464844

Removed: (0, 5)
97.02346801757812

Removed: (0, 6)
97.01580810546875

Removed: (0, 9)
98.982666015625

Removed: (0, 10)
98.61641693115234

Removed: (0, 11)
98.89505004882812

Removed: (1, 0)
98.81217956542969

Removed: (1, 1)
97.33100891113281

Removed: (1, 2)
97.43998718261719

Removed: (1, 3)
98.00090026855469

Removed: (1, 4)
97.64048767089844

Removed: (1, 6)
97.73833465576172

Removed: (1, 7)
99.25526428222656

Removed: (1, 8)
100.56314849853516

Removed: (1, 9)
100.19857025146484

Removed: (1, 10)
100.46278381347656

Removed: (1, 11)
100.2510757446289

Removed: (2, 0)
100.60169219970703

Removed: (2, 1)
101.95899200439453

Removed: (2, 2)
101.20800018310547

Removed: (2, 3)
100.92581939697266

Removed: (2, 4)
99.52203369140625

Removed: (2, 5)
100.21537017822266

Removed: (2, 6)
99.31548309326172

Removed: (2, 7)
100.2766

In [None]:
fb_3 = curr_circuit.copy()
fb_3

[(0, 1),
 (0, 7),
 (0, 8),
 (1, 5),
 (4, 4),
 (4, 7),
 (4, 9),
 (4, 10),
 (5, 4),
 (5, 6),
 (5, 8),
 (6, 6),
 (6, 10),
 (7, 2),
 (7, 6),
 (7, 7),
 (7, 8),
 (7, 10),
 (7, 11),
 (8, 0),
 (8, 6),
 (8, 8),
 (8, 11),
 (9, 1),
 (9, 5),
 (10, 2)]

#### loop rmv and check for most impt heads

In [None]:
circ = fb_3
mean_ablate_by_lst(circ, model, print_output=True).item()

Average logit difference (circuit / full) %: 97.1143


97.11428833007812

In [None]:
lh_scores = {}
for lh in circ:
    copy_circuit = circ.copy()
    copy_circuit.remove(lh)
    print("removed: " + str(lh))
    new_score = mean_ablate_by_lst(copy_circuit, model, print_output=True).item()
    lh_scores[lh] = new_score

removed: (0, 1)
Average logit difference (circuit / full) %: 85.3010
removed: (0, 7)
Average logit difference (circuit / full) %: 96.2253
removed: (0, 8)
Average logit difference (circuit / full) %: 96.7433
removed: (1, 5)
Average logit difference (circuit / full) %: 92.4179
removed: (4, 4)
Average logit difference (circuit / full) %: 68.2523
removed: (4, 7)
Average logit difference (circuit / full) %: 96.2778
removed: (4, 9)
Average logit difference (circuit / full) %: 96.9809
removed: (4, 10)
Average logit difference (circuit / full) %: 95.4617
removed: (5, 4)
Average logit difference (circuit / full) %: 94.5391
removed: (5, 6)
Average logit difference (circuit / full) %: 91.3686
removed: (5, 8)
Average logit difference (circuit / full) %: 96.0476
removed: (6, 6)
Average logit difference (circuit / full) %: 95.3563
removed: (6, 10)
Average logit difference (circuit / full) %: 89.6346
removed: (7, 2)
Average logit difference (circuit / full) %: 95.5758
removed: (7, 6)
Average logit di

In [None]:
dict(sorted(lh_scores.items(), key=lambda item: item[1]))

{(9, 1): 66.3958969116211,
 (4, 4): 68.25231170654297,
 (7, 11): 76.1490249633789,
 (10, 2): 82.1600341796875,
 (0, 1): 85.30104064941406,
 (8, 11): 86.6012954711914,
 (8, 8): 88.16691589355469,
 (6, 10): 89.63456726074219,
 (5, 6): 91.36856842041016,
 (1, 5): 92.41793823242188,
 (8, 6): 92.7655258178711,
 (9, 5): 94.21306610107422,
 (7, 10): 94.44884490966797,
 (5, 4): 94.53907775878906,
 (7, 6): 94.88253021240234,
 (8, 0): 95.13815307617188,
 (6, 6): 95.35630798339844,
 (4, 10): 95.46166229248047,
 (7, 2): 95.5757827758789,
 (5, 8): 96.04763793945312,
 (0, 7): 96.22527313232422,
 (4, 7): 96.27782440185547,
 (7, 7): 96.50817108154297,
 (7, 8): 96.73036193847656,
 (0, 8): 96.7433090209961,
 (4, 9): 96.98094940185547}

### iter fwd backw, threshold 20

In [None]:
threshold = 20
curr_circuit = []
prev_score = 100
new_score = 0
iter = 1
while prev_score != new_score:
    print('\nfwd prune, iter ', str(iter))
    # track changes in circuit as for some reason it doesn't work with scores
    old_circuit = curr_circuit.copy() # save old before finding new one
    curr_circuit, new_score = find_circuit_forw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    print('\nbackw prune, iter ', str(iter))
    # prev_score = new_score # save old score before finding new one
    old_circuit = curr_circuit.copy() # save old before finding new one
    curr_circuit, new_score = find_circuit_backw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    iter += 1


fwd prune, iter  1

Removed: (0, 0)
99.6773910522461

Removed: (0, 1)
87.98929595947266

Removed: (0, 2)
88.6512451171875

Removed: (0, 3)
86.90736389160156

Removed: (0, 4)
87.87069702148438

Removed: (0, 5)
85.47981262207031

Removed: (0, 6)
85.468994140625

Removed: (0, 7)
85.32069396972656

Removed: (0, 8)
84.54505920410156

Removed: (0, 9)
84.93225860595703

Removed: (0, 10)
87.507568359375

Removed: (0, 11)
87.77931213378906

Removed: (1, 0)
87.68569946289062

Removed: (1, 1)
86.57756805419922

Removed: (1, 2)
86.78600311279297

Removed: (1, 3)
86.860107421875

Removed: (1, 4)
86.6047592163086

Removed: (1, 5)
81.6650619506836

Removed: (1, 6)
81.76591491699219

Removed: (1, 7)
82.36595916748047

Removed: (1, 8)
83.85513305664062

Removed: (1, 9)
83.27083587646484

Removed: (1, 10)
83.19534301757812

Removed: (1, 11)
82.02559661865234

Removed: (2, 0)
82.97765350341797

Removed: (2, 1)
84.57157135009766

Removed: (2, 2)
83.71525573730469

Removed: (2, 3)
82.80513000488281

Remov

In [None]:
fb_20 = curr_circuit.copy()
fb_20

[(3, 2),
 (4, 4),
 (4, 8),
 (4, 10),
 (4, 11),
 (5, 5),
 (5, 6),
 (5, 7),
 (5, 8),
 (6, 1),
 (6, 7),
 (6, 9),
 (6, 10),
 (7, 0),
 (7, 2),
 (7, 5),
 (7, 6),
 (7, 7),
 (7, 8),
 (7, 10),
 (7, 11),
 (8, 0),
 (8, 1),
 (8, 6),
 (8, 8),
 (8, 11),
 (9, 1),
 (10, 2)]

In [None]:
mean_ablate_by_lst(fb_20, model, print_output=True)

Average logit difference (circuit / full) %: 80.0140


tensor(80.0140, device='cuda:0')

In [None]:
len(fb_20)

28

#### loop rmv and check for most impt heads

In [None]:
circ = fb_20
mean_ablate_by_lst(circ, model, print_output=True).item()

Average logit difference (circuit / full) %: 80.0140


80.01403045654297

In [None]:
lh_scores = {}
for lh in circ:
    copy_circuit = circ.copy()
    copy_circuit.remove(lh)
    print("removed: " + str(lh))
    new_score = mean_ablate_by_lst(copy_circuit, model, print_output=True).item()
    lh_scores[lh] = new_score

removed: (3, 2)
Average logit difference (circuit / full) %: 78.5359
removed: (4, 4)
Average logit difference (circuit / full) %: 57.5247
removed: (4, 8)
Average logit difference (circuit / full) %: 79.5581
removed: (4, 10)
Average logit difference (circuit / full) %: 77.7767
removed: (4, 11)
Average logit difference (circuit / full) %: 78.8688
removed: (5, 5)
Average logit difference (circuit / full) %: 79.1987
removed: (5, 6)
Average logit difference (circuit / full) %: 74.9998
removed: (5, 7)
Average logit difference (circuit / full) %: 79.7443
removed: (5, 8)
Average logit difference (circuit / full) %: 77.3542
removed: (6, 1)
Average logit difference (circuit / full) %: 78.4455
removed: (6, 7)
Average logit difference (circuit / full) %: 79.5844
removed: (6, 9)
Average logit difference (circuit / full) %: 79.8339
removed: (6, 10)
Average logit difference (circuit / full) %: 75.6201
removed: (7, 0)
Average logit difference (circuit / full) %: 79.8278
removed: (7, 2)
Average logit d

In [None]:
dict(sorted(lh_scores.items(), key=lambda item: item[1]))

{(9, 1): 52.848934173583984,
 (4, 4): 57.52471923828125,
 (7, 11): 65.4083023071289,
 (10, 2): 67.60511016845703,
 (8, 8): 71.76107788085938,
 (8, 11): 73.92449188232422,
 (5, 6): 74.99983978271484,
 (6, 10): 75.62007904052734,
 (7, 10): 77.34323120117188,
 (5, 8): 77.35415649414062,
 (8, 6): 77.66063690185547,
 (4, 10): 77.77672576904297,
 (7, 6): 78.29383087158203,
 (6, 1): 78.44552612304688,
 (3, 2): 78.53585815429688,
 (8, 0): 78.65070343017578,
 (7, 2): 78.7663345336914,
 (8, 1): 78.78251647949219,
 (4, 11): 78.86881256103516,
 (5, 5): 79.1987075805664,
 (7, 7): 79.32286071777344,
 (4, 8): 79.5581283569336,
 (6, 7): 79.58438873291016,
 (7, 5): 79.66378021240234,
 (5, 7): 79.74430847167969,
 (7, 8): 79.74847412109375,
 (7, 0): 79.8277816772461,
 (6, 9): 79.83391571044922}

# Prune backwds-fwds iteratively

### threshold 3 (OLD, DO NOT RUN)

In [None]:
threshold = 3
curr_circuit = []
prev_score = 100
new_score = 0
iter = 1
while prev_score != new_score:
    print('\nbackw prune, iter ', str(iter))
    # prev_score = new_score # save old score before finding new one
    old_circuit = curr_circuit.copy() # save old before finding new one
    curr_circuit, new_score = find_circuit_backw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    print('\nfwd prune, iter ', str(iter))
    # track changes in circuit as for some reason it doesn't work with scores
    old_circuit = curr_circuit.copy() # save old before finding new one
    curr_circuit, new_score = find_circuit_forw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    iter += 1


backw prune, iter  1

Removed: (11, 0)
99.36744689941406

Removed: (11, 1)
98.15276336669922

Removed: (11, 2)
98.11139678955078

Removed: (11, 3)
97.8514404296875

Removed: (11, 4)
98.47762298583984

Removed: (11, 5)
98.56700897216797

Removed: (11, 6)
98.57281494140625

Removed: (11, 7)
98.51939392089844

Removed: (11, 9)
98.4183349609375

Removed: (11, 10)
97.51922607421875

Removed: (11, 11)
97.42621612548828

Removed: (10, 0)
97.4111099243164

Removed: (10, 1)
97.25757598876953

Removed: (10, 3)
97.45303344726562

Removed: (10, 4)
97.4413070678711

Removed: (10, 5)
97.26394653320312

Removed: (10, 6)
97.17695617675781

Removed: (10, 7)
100.68805694580078

Removed: (10, 8)
99.4012680053711

Removed: (10, 9)
99.0429916381836

Removed: (10, 10)
99.18038177490234

Removed: (10, 11)
98.84198760986328

Removed: (9, 0)
98.87261962890625

Removed: (9, 2)
98.37891387939453

Removed: (9, 3)
97.85379791259766

Removed: (9, 4)
98.12966918945312

Removed: (9, 6)
98.41926574707031

Removed: (9

In [None]:
bf_3 = curr_circuit.copy()
bf_3

[(0, 1),
 (1, 0),
 (1, 5),
 (3, 2),
 (4, 4),
 (4, 8),
 (4, 10),
 (5, 4),
 (5, 6),
 (5, 8),
 (6, 9),
 (6, 10),
 (7, 7),
 (7, 8),
 (7, 10),
 (7, 11),
 (8, 0),
 (8, 1),
 (8, 6),
 (8, 8),
 (8, 11),
 (9, 1),
 (9, 5),
 (10, 2),
 (11, 8)]

#### loop rmv and check for most impt heads

In [None]:
circ = bf_3
mean_ablate_by_lst(circ, model, print_output=True).item()

Average logit difference (circuit / full) %: 97.4123


97.41226959228516

In [None]:
lh_scores = {}
for lh in circ:
    copy_circuit = circ.copy()
    copy_circuit.remove(lh)
    print("removed: " + str(lh))
    new_score = mean_ablate_by_lst(copy_circuit, model, print_output=True).item()
    lh_scores[lh] = new_score

removed: (0, 1)
Average logit difference (circuit / full) %: 85.8583
removed: (1, 0)
Average logit difference (circuit / full) %: 94.9495
removed: (1, 5)
Average logit difference (circuit / full) %: 90.7815
removed: (3, 2)
Average logit difference (circuit / full) %: 94.8821
removed: (4, 4)
Average logit difference (circuit / full) %: 64.8832
removed: (4, 8)
Average logit difference (circuit / full) %: 96.5330
removed: (4, 10)
Average logit difference (circuit / full) %: 96.0251
removed: (5, 4)
Average logit difference (circuit / full) %: 94.7940
removed: (5, 6)
Average logit difference (circuit / full) %: 91.8178
removed: (5, 8)
Average logit difference (circuit / full) %: 95.4617
removed: (6, 9)
Average logit difference (circuit / full) %: 96.9442
removed: (6, 10)
Average logit difference (circuit / full) %: 88.7927
removed: (7, 7)
Average logit difference (circuit / full) %: 96.6616
removed: (7, 8)
Average logit difference (circuit / full) %: 96.9658
removed: (7, 10)
Average logit d

In [None]:
dict(sorted(lh_scores.items(), key=lambda item: item[1]))

{(4, 4): 64.88319396972656,
 (9, 1): 66.968994140625,
 (7, 11): 73.26703643798828,
 (10, 2): 83.12308502197266,
 (0, 1): 85.8582534790039,
 (8, 11): 87.15882873535156,
 (8, 8): 88.12217712402344,
 (6, 10): 88.79270935058594,
 (1, 5): 90.781494140625,
 (5, 6): 91.81783294677734,
 (8, 6): 92.28190612792969,
 (9, 5): 93.96353912353516,
 (8, 0): 94.70207214355469,
 (11, 8): 94.79283142089844,
 (5, 4): 94.7939682006836,
 (3, 2): 94.88208770751953,
 (1, 0): 94.94954681396484,
 (5, 8): 95.46170806884766,
 (7, 10): 95.58817291259766,
 (4, 10): 96.02513122558594,
 (8, 1): 96.072509765625,
 (4, 8): 96.53302001953125,
 (7, 7): 96.66156768798828,
 (6, 9): 96.94415283203125,
 (7, 8): 96.96578979492188}

#### compare fb and bf

In [None]:
print(len(fb_3))
print(len(bf_3))

26
25


In [None]:
set(fb_3) - set(bf_3)

{(0, 7), (0, 8), (4, 7), (4, 9), (6, 6), (7, 2), (7, 6)}