<a href="https://colab.research.google.com/github/wlg100/numseqcont_circuit_expms/blob/main/notebook_templates/minimal_circuit_template.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" align="left"/></a>&nbsp;or in a local notebook.

# Setup
(No need to change anything)

In [1]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    # Install another version of node that makes PySvelte work way faster
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    %pip install git+https://github.com/neelnanda-io/PySvelte.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

Running as a Colab notebook
Collecting git+https://github.com/neelnanda-io/TransformerLens.git
  Cloning https://github.com/neelnanda-io/TransformerLens.git to /tmp/pip-req-build-m7lt8bgf
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/TransformerLens.git /tmp/pip-req-build-m7lt8bgf
  Resolved https://github.com/neelnanda-io/TransformerLens.git to commit 218ebd6f491f47f5e2f64e4c4327548b60a093eb
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting datasets>=2.7.1 (from transformer-lens==0.0.0)
  Downloading datasets-2.13.1-py3-none-any.whl (486 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m486.2/486.2 kB[0m [31m23.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting einops>=0.6.0 (from transformer-lens==0.0.0)
  Downloading einops-0.6.1-py3-none-any.whl (42 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━


## Installing the NodeSource Node.js 16.x repo...


## Populating apt-get cache...

+ apt-get update
Get:1 https://cloud.r-project.org/bin/linux/ubuntu focal-cran40/ InRelease [3,622 B]
Get:2 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64  InRelease [1,581 B]
Get:3 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64  Packages [1,084 kB]
Get:4 http://security.ubuntu.com/ubuntu focal-security InRelease [114 kB]
Hit:5 http://ppa.launchpad.net/c2d4u.team/c2d4u4.0+/ubuntu focal InRelease
Hit:6 http://archive.ubuntu.com/ubuntu focal InRelease
Get:7 http://archive.ubuntu.com/ubuntu focal-updates InRelease [114 kB]
Hit:8 http://ppa.launchpad.net/cran/libgit2/ubuntu focal InRelease
Hit:9 http://ppa.launchpad.net/deadsnakes/ppa/ubuntu focal InRelease
Hit:10 http://ppa.launchpad.net/graphics-drivers/ppa/ubuntu focal InRelease
Get:11 http://archive.ubuntu.com/ubuntu focal-backports InRelease [108 kB]
Hit:12 http://ppa.launchpad.net/ubuntugis/ppa/

In [2]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

if IN_COLAB or not DEBUG_MODE:
    # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "png"

In [3]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [4]:
import pysvelte

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

We turn automatic differentiation off, to save GPU memory, as this notebook focuses on model inference not model training.

In [5]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fa759f59120>

Plotting helper functions:

In [6]:
def imshow(tensor, renderer=None, **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", **kwargs).show(renderer)

def line(tensor, renderer=None, **kwargs):
    px.line(y=utils.to_numpy(tensor), **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

# Load Model

Decide which model to use (eg. gpt2-small vs -medium)

In [7]:
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)

Downloading (…)lve/main/config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


# Import functions from repo

In [8]:
!git clone https://github.com/callummcdougall/ARENA_2.0.git

Cloning into 'ARENA_2.0'...
remote: Enumerating objects: 7955, done.[K
remote: Counting objects: 100% (1628/1628), done.[K
remote: Compressing objects: 100% (236/236), done.[K
remote: Total 7955 (delta 1471), reused 1492 (delta 1387), pack-reused 6327[K
Receiving objects: 100% (7955/7955), 129.76 MiB | 7.75 MiB/s, done.
Resolving deltas: 100% (4701/4701), done.


In [9]:
cd ARENA_2.0/chapter1_transformers/exercises/part3_indirect_object_identification

/content/ARENA_2.0/chapter1_transformers/exercises/part3_indirect_object_identification


In [10]:
import ioi_circuit_extraction as ioi_circuit_extraction

# Generate dataset with multiple prompts

In [23]:
#@title Names list
names = [
    "Michael",
    "Christopher",
    "Jessica",
    "Matthew",
    "Ashley",
    "Jennifer",
    "Joshua",
    "Amanda",
    "Daniel",
    "David",
    "James",
    "Robert",
    "John",
    "Joseph",
    "Andrew",
    "Ryan",
    "Brandon",
    "Jason",
    "Justin",
    "Sarah",
    "William",
    "Jonathan",
    "Stephanie",
    "Brian",
    "Nicole",
    "Nicholas",
    "Anthony",
    "Heather",
    "Eric",
    "Elizabeth",
    "Adam",
    "Megan",
    "Melissa",
    "Kevin",
    "Steven",
    "Thomas",
    "Timothy",
    "Christina",
    "Kyle",
    "Rachel",
    "Laura",
    "Lauren",
    "Amber",
    "Brittany",
    "Danielle",
    "Richard",
    "Kimberly",
    "Jeffrey",
    "Amy",
    "Crystal",
    "Michelle",
    "Tiffany",
    "Jeremy",
    "Benjamin",
    "Mark",
    "Emily",
    "Aaron",
    "Charles",
    "Rebecca",
    "Jacob",
    "Stephen",
    "Patrick",
    "Sean",
    "Erin",
    "Jamie",
    "Kelly",
    "Samantha",
    "Nathan",
    "Sara",
    "Dustin",
    "Paul",
    "Angela",
    "Tyler",
    "Scott",
    "Katherine",
    "Andrea",
    "Gregory",
    "Erica",
    "Mary",
    "Travis",
    "Lisa",
    "Kenneth",
    "Bryan",
    "Lindsey",
    "Kristen",
    "Jose",
    "Alexander",
    "Jesse",
    "Katie",
    "Lindsay",
    "Shannon",
    "Vanessa",
    "Courtney",
    "Christine",
    "Alicia",
    "Cody",
    "Allison",
    "Bradley",
    "Samuel",
]

In [24]:
def filter_names(names):
    return [name for name in names if len(model.tokenizer.tokenize(name)) == 1]
names = filter_names(names)

In [25]:
import random

def make_prompts_list(names, template, num_sentences, num_targ_tokens):
    sentences = []
    generated_set = set() # Ensure none of the generated sentences are the same
    while len(sentences) < num_sentences:
        unique_names = random.sample(names, k=num_targ_tokens)
        temp_template = template
        sentence_dict = {}
        for i, name in enumerate(unique_names, start=1):
            temp_template = temp_template.replace(f"[S{i}]", name)
            sentence_dict[f'S{i}'] = name
        sentence_dict['text'] = temp_template
        if sentence_dict['text'] not in generated_set:
            generated_set.add(sentence_dict['text'])
            sentences.append(sentence_dict)
    return sentences

In [45]:
class Dataset:
    def __init__(self, prompts, tokenizer, N, S1_is_first=False):
        self.prompts = prompts
        self.tokenizer = tokenizer
        self.N = N
        self.max_len = max(
            [
                len(self.tokenizer(prompt["text"]).input_ids)
                for prompt in self.prompts
            ]
        )
        # all_ids = [prompt["TEMPLATE_IDX"] for prompt in self.ioi_prompts]
        all_ids = [0 for prompt in self.prompts] # only 1 template
        all_ids_ar = np.array(all_ids)
        self.groups = []
        for id in list(set(all_ids)):
            self.groups.append(np.where(all_ids_ar == id)[0])

        texts = [ prompt["text"] for prompt in self.prompts ]
        self.toks = torch.Tensor(self.tokenizer(texts, padding=True).input_ids).type(
            torch.int
        )
        self.io_tokenIDs = [
            self.tokenizer.encode(" " + prompt["S1"])[0] for prompt in self.prompts
        ]
        self.s_tokenIDs = [
            self.tokenizer.encode(" " + prompt["S2"])[0] for prompt in self.prompts
        ]

        # word_idx: for every prompt, find the token index of each target token and "end"
        # word_idx is a tensor with an element for each prompt. The element is the targ token's ind at that prompt
        self.word_idx = {}
        for targ in [key for key in self.prompts[0].keys() if key != 'text']:
            targ_lst = []
            for prompt in self.prompts:
                input_text = prompt["text"]
                tokens = model.tokenizer.tokenize(input_text)
                if S1_is_first and targ == "S1":  # only use this if first token doesn't have space Ġ in front
                    target_token = prompt[targ]
                else:
                    target_token = "Ġ" + prompt[targ]
                target_index = tokens.index(target_token)
                targ_lst.append(target_index)
            self.word_idx[targ] = torch.tensor(targ_lst)

        targ_lst = []
        for prompt in self.prompts:
            input_text = prompt["text"]
            tokens = self.tokenizer.tokenize(input_text)
            end_token_index = len(tokens) - 1
            targ_lst.append(end_token_index)
        self.word_idx["end"] = torch.tensor(targ_lst)

    def __len__(self):
        return self.N

Repalce io_tokens with correct answer (next, which is '5') and s_tokens with incorrect (current, which repeats)

In [46]:
# template = "[S1] and [S2] went to the store. [S2] gave a drink to"
template = "Then, [S1] and [S2] went to the store, [S2] gave a drink to"
num_inputs = 30
num_targ_tokens = 2
prompts_list = make_prompts_list(names, template, num_inputs, num_targ_tokens)
dataset = Dataset(prompts_list, model.tokenizer, num_inputs, S1_is_first=False)

In [47]:
dataset_2 = Dataset(prompts_list, model.tokenizer, num_inputs, S1_is_first=False)

# Ablate the model and compare with original

In [11]:
# from ioi_dataset import NAMES, IOIDataset

# N = 25
# ioi_dataset = IOIDataset(
#     prompt_type="mixed",
#     N=N,
#     tokenizer=model.tokenizer,
#     prepend_bos=False,
#     seed=1,
#     # device=str(device)
# )
# abc_dataset = ioi_dataset.gen_flipped_prompts("ABB->XYZ, BAB->XYZ")

In [51]:
from torch import Tensor

def logits_to_ave_logit_diff_2(logits: Float[Tensor, "batch seq d_vocab"], dataset: Dataset, per_prompt=False):
    '''
    Returns logit difference between the correct and incorrect answer.

    If per_prompt=True, return the array of differences rather than the average.
    '''

    # Only the final logits are relevant for the answer
    # Get the logits corresponding to the indirect object / subject tokens respectively
    io_logits: Float[Tensor, "batch"] = logits[range(logits.size(0)), dataset.word_idx["end"], dataset.io_tokenIDs]
    s_logits: Float[Tensor, "batch"] = logits[range(logits.size(0)), dataset.word_idx["end"], dataset.s_tokenIDs]
    # Find logit difference
    answer_logit_diff = io_logits - s_logits
    return answer_logit_diff if per_prompt else answer_logit_diff.mean()

In [41]:
CIRCUIT = {
    "name mover": [(9, 9), (10, 0), (9, 6)],
    "backup name mover": [(10, 10), (10, 6), (10, 2), (10, 1), (11, 2), (9, 7), (9, 0), (11, 9)],
    "negative name mover": [(10, 7), (11, 10)],
    "s2 inhibition": [(7, 3), (7, 9), (8, 6), (8, 10)],
    # "induction": [(5, 5), (5, 8), (5, 9), (6, 9)],
    # "duplicate token": [(0, 1), (0, 10), (3, 0)],
    # "previous token": [(2, 2), (4, 11)],
}

SEQ_POS_TO_KEEP = {
    "name mover": "end",
    "backup name mover": "end",
    "negative name mover": "end",
    "s2 inhibition": "end",
    # "induction": "S2",
    # "duplicate token": "S2",
    # "previous token": "S1+1",
}

We can set the mean ablation to be over the same dataset, though with some issues

In [48]:
model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook

ioi_logits_original, ioi_cache = model.run_with_cache(dataset.toks)

model = ioi_circuit_extraction.add_mean_ablation_hook(model, means_dataset=dataset_2, circuit=CIRCUIT, seq_pos_to_keep=SEQ_POS_TO_KEEP)
ioi_logits_minimal = model(dataset.toks)

print(f"Average logit difference (IOI dataset, using entire model): {logits_to_ave_logit_diff_2(ioi_logits_original, dataset):.4f}")
print(f"Average logit difference (IOI dataset, only using circuit): {logits_to_ave_logit_diff_2(ioi_logits_minimal, dataset):.4f}")

Average logit difference (IOI dataset, using entire model): 2.5075
Average logit difference (IOI dataset, only using circuit): 2.2445


## Test circuit without main name movers

In [52]:
CIRCUIT = {
    # "name mover": [(9, 9), (10, 0), (9, 6)],
    "backup name mover": [(10, 10), (10, 6), (10, 2), (10, 1), (11, 2), (9, 7), (9, 0), (11, 9)],
    "negative name mover": [(10, 7), (11, 10)],
    "s2 inhibition": [(7, 3), (7, 9), (8, 6), (8, 10)],
    # "induction": [(5, 5), (5, 8), (5, 9), (6, 9)],
    # "duplicate token": [(0, 1), (0, 10), (3, 0)],
    # "previous token": [(2, 2), (4, 11)],
}

SEQ_POS_TO_KEEP = {
    # "name mover": "end",
    "backup name mover": "end",
    "negative name mover": "end",
    "s2 inhibition": "end",
    # "induction": "S2",
    # "duplicate token": "S2",
    # "previous token": "S1+1",
}

In [53]:
model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook

ioi_logits_original, ioi_cache = model.run_with_cache(dataset.toks)

model = ioi_circuit_extraction.add_mean_ablation_hook(model, means_dataset=dataset_2, circuit=CIRCUIT, seq_pos_to_keep=SEQ_POS_TO_KEEP)
ioi_logits_minimal = model(dataset.toks)

print(f"Average logit difference (IOI dataset, using entire model): {logits_to_ave_logit_diff_2(ioi_logits_original, dataset):.4f}")
print(f"Average logit difference (IOI dataset, only using circuit): {logits_to_ave_logit_diff_2(ioi_logits_minimal, dataset):.4f}")

Average logit difference (IOI dataset, using entire model): 2.5075
Average logit difference (IOI dataset, only using circuit): 1.9428


## Test circuit without backup name movers

In [54]:
CIRCUIT = {
    "name mover": [(9, 9), (10, 0), (9, 6)],
    # "backup name mover": [(10, 10), (10, 6), (10, 2), (10, 1), (11, 2), (9, 7), (9, 0), (11, 9)],
    "negative name mover": [(10, 7), (11, 10)],
    "s2 inhibition": [(7, 3), (7, 9), (8, 6), (8, 10)],
    # "induction": [(5, 5), (5, 8), (5, 9), (6, 9)],
    # "duplicate token": [(0, 1), (0, 10), (3, 0)],
    # "previous token": [(2, 2), (4, 11)],
}

SEQ_POS_TO_KEEP = {
    "name mover": "end",
    # "backup name mover": "end",
    "negative name mover": "end",
    "s2 inhibition": "end",
    # "induction": "S2",
    # "duplicate token": "S2",
    # "previous token": "S1+1",
}

In [55]:
model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook

ioi_logits_original, ioi_cache = model.run_with_cache(dataset.toks)

model = ioi_circuit_extraction.add_mean_ablation_hook(model, means_dataset=dataset_2, circuit=CIRCUIT, seq_pos_to_keep=SEQ_POS_TO_KEEP)
ioi_logits_minimal = model(dataset.toks)

print(f"Average logit difference (IOI dataset, using entire model): {logits_to_ave_logit_diff_2(ioi_logits_original, dataset):.4f}")
print(f"Average logit difference (IOI dataset, only using circuit): {logits_to_ave_logit_diff_2(ioi_logits_minimal, dataset):.4f}")

Average logit difference (IOI dataset, using entire model): 2.5075
Average logit difference (IOI dataset, only using circuit): 2.0672


## Test circuit without all name movers

In [49]:
CIRCUIT = {
    # "name mover": [(9, 9), (10, 0), (9, 6)],
    # "backup name mover": [(10, 10), (10, 6), (10, 2), (10, 1), (11, 2), (9, 7), (9, 0), (11, 9)],
    # "negative name mover": [(10, 7), (11, 10)],
    "s2 inhibition": [(7, 3), (7, 9), (8, 6), (8, 10)],
    # "induction": [(5, 5), (5, 8), (5, 9), (6, 9)],
    # "duplicate token": [(0, 1), (0, 10), (3, 0)],
    # "previous token": [(2, 2), (4, 11)],
}

SEQ_POS_TO_KEEP = {
    # "name mover": "end",
    # "backup name mover": "end",
    # "negative name mover": "end",
    "s2 inhibition": "end",
    # "induction": "S2",
    # "duplicate token": "S2",
    # "previous token": "S1+1",
}

In [50]:
model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook

ioi_logits_original, ioi_cache = model.run_with_cache(dataset.toks)

model = ioi_circuit_extraction.add_mean_ablation_hook(model, means_dataset=dataset_2, circuit=CIRCUIT, seq_pos_to_keep=SEQ_POS_TO_KEEP)
ioi_logits_minimal = model(dataset.toks)

print(f"Average logit difference (IOI dataset, using entire model): {logits_to_ave_logit_diff_2(ioi_logits_original, dataset):.4f}")
print(f"Average logit difference (IOI dataset, only using circuit): {logits_to_ave_logit_diff_2(ioi_logits_minimal, dataset):.4f}")

Average logit difference (IOI dataset, using entire model): 2.5075
Average logit difference (IOI dataset, only using circuit): 0.4925


This shows name movers are very important.

## Test a wrong circuit to see that the logit diff doesn't match original

We keep only 4 heads, so that means we knock out most of the circuit

In [15]:
WRONG_CIRCUIT = {
    "name mover": [(2,2)],
    "backup name mover": [(1, 1)],
    "negative name mover": [],
    "s2 inhibition": [],
    "induction": [(4,4)],
    "duplicate token": [],
    "previous token": [(8,3)],
}

In [16]:
model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook

ioi_logits_original, ioi_cache = model.run_with_cache(ioi_dataset.toks)

model = ioi_circuit_extraction.add_mean_ablation_hook(model, means_dataset=abc_dataset, circuit=WRONG_CIRCUIT, seq_pos_to_keep=SEQ_POS_TO_KEEP)
ioi_logits_minimal = model(ioi_dataset.toks)

print(f"Average logit difference (IOI dataset, using entire model): {logits_to_ave_logit_diff_2(ioi_logits_original):.4f}")
print(f"Average logit difference (IOI dataset, only using circuit): {logits_to_ave_logit_diff_2(ioi_logits_minimal):.4f}")

Average logit difference (IOI dataset, using entire model): 2.8052
Average logit difference (IOI dataset, only using circuit): -0.7422


# Plot minimality scores

In [20]:
%%capture
!pip install circuitsvis

Write a function that converts into this:

In [18]:
K_FOR_EACH_COMPONENT = {
    (9, 9): set(),
    (10, 0): {(9, 9)},
    (9, 6): {(9, 9), (10, 0)},
    (10, 7): {(11, 10)},
    (11, 10): {(10, 7)},
    (8, 10): {(7, 9), (8, 6), (7, 3)},
    (7, 9): {(8, 10), (8, 6), (7, 3)},
    (8, 6): {(7, 9), (8, 10), (7, 3)},
    (7, 3): {(7, 9), (8, 10), (8, 6)},
    (5, 5): {(5, 9), (6, 9), (5, 8)},
    (5, 9): {(11, 10), (10, 7)},
    (6, 9): {(5, 9), (5, 5), (5, 8)},
    (5, 8): {(11, 10), (10, 7)},
    (0, 1): {(0, 10), (3, 0)},
    (0, 10): {(0, 1), (3, 0)},
    (3, 0): {(0, 1), (0, 10)},
    (4, 11): {(2, 2)},
    (2, 2): {(4, 11)},
    (11, 2): {(9, 9), (10, 0), (9, 6)},
    (10, 6): {(9, 9), (10, 0), (9, 6), (11, 2)},
    (10, 10): {(9, 9), (10, 0), (9, 6), (11, 2), (10, 6)},
    (10, 2): {(9, 9), (10, 0), (9, 6), (11, 2), (10, 6), (10, 10)},
    (9, 7): {(9, 9), (10, 0), (9, 6), (11, 2), (10, 6), (10, 10), (10, 2)},
    (10, 1): {(9, 9), (10, 0), (9, 6), (11, 2), (10, 6), (10, 10), (10, 2), (9, 7)},
    (11, 9): {(9, 9), (10, 0), (9, 6), (9, 0)},
    (9, 0): {(9, 9), (10, 0), (9, 6), (11, 9)},
}

In [21]:
from solutions import *

def get_score(
	model: HookedTransformer,
	ioi_dataset: IOIDataset,
	abc_dataset: IOIDataset,
	K: Set[Tuple[int, int]],
	C: Dict[str, List[Tuple[int, int]]],
) -> float:
	'''
	Returns the value F(C \ K), where F is the logit diff, C is the
	core circuit, and K is the set of circuit components to remove.
	'''
	C_excl_K = {k: [head for head in v if head not in K] for k, v in C.items()}
	model = ioi_circuit_extraction.add_mean_ablation_hook(model, abc_dataset, C_excl_K, SEQ_POS_TO_KEEP)
	logits = model(ioi_dataset.toks)
	score = logits_to_ave_logit_diff_2(logits, ioi_dataset).item()

	return score

def get_minimality_score(
		model: HookedTransformer,
		ioi_dataset: IOIDataset,
		abc_dataset: IOIDataset,
		v: Tuple[int, int],
		K: Set[Tuple[int, int]],
		C: Dict[str, List[Tuple[int, int]]] = CIRCUIT,
	) -> float:
		'''
		Returns the value | F(C \ K_union_v) - F(C | K) |, where F is
		the logit diff, C is the core circuit, K is the set of circuit
		components to remove, and v is a head (not in K).
		'''
		assert v not in K
		K_union_v = K | {v}
		C_excl_K_score = get_score(model, ioi_dataset, abc_dataset, K, C)
		C_excl_Kv_score = get_score(model, ioi_dataset, abc_dataset, K_union_v, C)

		return abs(C_excl_K_score - C_excl_Kv_score)

def get_all_minimality_scores(
		model: HookedTransformer,
		ioi_dataset: IOIDataset = ioi_dataset,
		abc_dataset: IOIDataset = abc_dataset,
		k_for_each_component: Dict = K_FOR_EACH_COMPONENT
	) -> Dict[Tuple[int, int], float]:
		'''
		Returns dict of minimality scores for every head in the model (as
		a fraction of F(M), the logit diff of the full model).

		Warning - this resets all hooks at the end (including permanent).
		'''
		# Get full circuit score F(M), to divide minimality scores by
		model.reset_hooks(including_permanent=True)
		logits = model(ioi_dataset.toks)
		full_circuit_score = logits_to_ave_logit_diff_2(logits, ioi_dataset).item()

		# Get all minimality scores, using the `get_minimality_score` function
		minimality_scores = {}
		for v, K in tqdm(k_for_each_component.items()):
			score = get_minimality_score(model, ioi_dataset, abc_dataset, v, K)
			minimality_scores[v] = score / full_circuit_score

		model.reset_hooks(including_permanent=True)

		return minimality_scores

In [None]:
minimality_scores = get_all_minimality_scores(model)

In [None]:
def plot_minimal_set_results(minimality_scores: Dict[Tuple[int, int], float]):
    '''
    Plots the minimality results, in a way resembling figure 7 in the paper.

    minimality_scores:
        Dict with elements like (9, 9): minimality score for head 9.9 (as described
        in section 4.2 of the paper)
    '''

    CIRCUIT_reversed = {head: k for k, v in CIRCUIT.items() for head in v}
    colors = [CIRCUIT_reversed[head].capitalize() + " head" for head in minimality_scores.keys()]
    color_sequence = [px.colors.qualitative.Dark2[i] for i in [0, 1, 2, 5, 3, 6]] + ["#BAEA84"]

    bar(
        list(minimality_scores.values()),
        x=list(map(str, minimality_scores.keys())),
        labels={"x": "Attention head", "y": "Change in logit diff", "color": "Head type"},
        color=colors,
        template="ggplot2",
        color_discrete_sequence=color_sequence,
        bargap=0.02,
        yaxis_tickformat=".0%",
        legend_title_text="",
        title="Plot of minimality scores (as percentages of full model logit diff)",
        width=800,
        hovermode="x unified"
    )

plot_minimal_set_results(minimality_scores)