# Setup Imports


Code from: https://arena-ch1-transformers.streamlit.app/[1.3]_Indirect_Object_Identification

In [2]:
try:
    import google.colab # type: ignore
    IN_COLAB = True
except:
    IN_COLAB = False

import os, sys

if IN_COLAB:
    # Install packages
    %pip install einops
    %pip install jaxtyping
    %pip install transformer_lens
    # %pip install git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python

    # Code to download the necessary files (e.g. solutions, test funcs)
    import os, sys
    if not os.path.exists("chapter1_transformers"):
        !curl -o /content/main.zip https://codeload.github.com/callummcdougall/ARENA_2.0/zip/refs/heads/main
        !unzip /content/main.zip 'ARENA_2.0-main/chapter1_transformers/exercises/*'
        sys.path.append("/content/ARENA_2.0-main/chapter1_transformers/exercises")
        os.remove("/content/main.zip")
        os.rename("ARENA_2.0-main/chapter1_transformers", "chapter1_transformers")
        os.rmdir("ARENA_2.0-main")
        os.chdir("chapter1_transformers/exercises")
else:
    from IPython import get_ipython
    ipython = get_ipython()
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

Collecting einops
  Downloading einops-0.7.0-py3-none-any.whl (44 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/44.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m1.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.7.0
Collecting jaxtyping
  Downloading jaxtyping-0.2.24-py3-none-any.whl (38 kB)
Collecting typeguard<3,>=2.13.3 (from jaxtyping)
  Downloading typeguard-2.13.3-py3-none-any.whl (17 kB)
Installing collected packages: typeguard, jaxtyping
Successfully installed jaxtyping-0.2.24 typeguard-2.13.3
Collecting transformer_lens
  Downloading transformer_lens-1.12.0-py3-none-any.whl (118 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m119.0/119.0 kB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting accelerate>=0.23.0 (from transformer_lens)
  Downloading accelerate-0.25.0-py3-none-any.whl (2

In [3]:
import os; os.environ["ACCELERATE_DISABLE_RICH"] = "1"
import sys
from pathlib import Path
import torch as t
from torch import Tensor
import numpy as np
import einops
from tqdm.notebook import tqdm
import plotly.express as px
import webbrowser
import re
import itertools
from jaxtyping import Float, Int, Bool
from typing import List, Optional, Callable, Tuple, Dict, Literal, Set
from functools import partial
from IPython.display import display, HTML
from rich.table import Table, Column
from rich import print as rprint
# import circuitsvis as cv
from pathlib import Path
from transformer_lens.hook_points import HookPoint
from transformer_lens import utils, HookedTransformer, ActivationCache
from transformer_lens.components import Embed, Unembed, LayerNorm, MLP

t.set_grad_enabled(False)

# Make sure exercises are in the path
chapter = r"chapter1_transformers"
exercises_dir = Path(f"{os.getcwd().split(chapter)[0]}/{chapter}/exercises").resolve()
section_dir = (exercises_dir / "part3_indirect_object_identification").resolve()
if str(exercises_dir) not in sys.path: sys.path.append(str(exercises_dir))

# from plotly_utils import imshow, line, scatter, bar
# import part3_indirect_object_identification.tests as tests

device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")

MAIN = __name__ == "__main__"



In [4]:
from part3_indirect_object_identification.ioi_dataset import NAMES, IOIDataset

## Load Model

In [5]:
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Loaded pretrained model gpt2-small into HookedTransformer


# Change Inputs Here

In [6]:
task = "numwords"

heads_not_ablate = [(0, 1), (1, 5), (4, 4), (4, 10), (5, 8), (6, 1), (6, 6), (6, 10), (7, 2), (7, 6), (7, 11), (8, 1), (8, 6), (8, 8), (8, 9), (8, 11), (9, 1), (9, 5), (9, 7)]
mlps_not_ablate = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]

# Generate dataset with multiple prompts

In [7]:
import torch

In [8]:
class Dataset:
    def __init__(self, prompts, pos_dict, tokenizer, S1_is_first=False):
        self.prompts = prompts
        self.tokenizer = tokenizer
        self.N = len(prompts)
        self.max_len = max(
            [
                len(self.tokenizer(prompt["text"]).input_ids)
                for prompt in self.prompts
            ]
        )
        # all_ids = [prompt["TEMPLATE_IDX"] for prompt in self.ioi_prompts]
        all_ids = [0 for prompt in self.prompts] # only 1 template
        all_ids_ar = np.array(all_ids)
        self.groups = []
        for id in list(set(all_ids)):
            self.groups.append(np.where(all_ids_ar == id)[0])

        texts = [ prompt["text"] for prompt in self.prompts ]
        self.toks = torch.Tensor(self.tokenizer(texts, padding=True).input_ids).type(
            torch.int
        )
        self.corr_tokenIDs = [
            # self.tokenizer.encode(" " + prompt["corr"])[0] for prompt in self.prompts
            self.tokenizer.encode(prompt["corr"])[0] for prompt in self.prompts
        ]
        self.incorr_tokenIDs = [
            # self.tokenizer.encode(" " + prompt["incorr"])[0] for prompt in self.prompts
            self.tokenizer.encode(prompt["incorr"])[0] for prompt in self.prompts
        ]

        # word_idx: for every prompt, find the token index of each target token and "end"
        # word_idx is a tensor with an element for each prompt. The element is the targ token's ind at that prompt
        self.word_idx = {}
        # for targ in [key for key in self.prompts[0].keys() if (key != 'text' and key != 'corr' and key != 'incorr')]:
        for targ in [key for key in pos_dict]:
            targ_lst = []
            for prompt in self.prompts:
                input_text = prompt["text"]
                tokens = model.tokenizer.tokenize(input_text)
                # if S1_is_first and targ == "S1":  # only use this if first token doesn't have space Ġ in front
                #     target_token = prompt[targ]
                # else:
                #     target_token = "Ġ" + prompt[targ]
                # target_index = tokens.index(target_token)
                target_index = pos_dict[targ]
                targ_lst.append(target_index)
            self.word_idx[targ] = torch.tensor(targ_lst)

        targ_lst = []
        for prompt in self.prompts:
            input_text = prompt["text"]
            tokens = self.tokenizer.tokenize(input_text)
            end_token_index = len(tokens) - 1
            targ_lst.append(end_token_index)
        self.word_idx["end"] = torch.tensor(targ_lst)

    def __len__(self):
        return self.N

In [9]:
import pickle

prompts_list = []

temps = ['done', 'lost', 'names']
nw ='nw'

for i in temps:
    file_name = f'/content/{nw}_prompts_{i}.pkl'
    with open(file_name, 'rb') as file:
        filelist = pickle.load(file)

    print(filelist[0]['text'])
    prompts_list += filelist [:512] #768 512

len(prompts_list)

Van done in one. Hat done in two. Ring done in three. Desk done in four. Sun done in
Oil lost in one. Apple lost in two. Tree lost in three. Snow lost in four. Apple lost in
Marcus born in one. Victoria born in two. George born in three. Brandon born in four. Jamie born in


1536

In [10]:
# pos_dict = {
#     'S1': 4,
#     'S2': 10,
#     'S3': 16,
#     'S4': 22,
# }

pos_dict = {}
for i in range(len(model.tokenizer.tokenize(prompts_list[0]['text']))):
    pos_dict['S'+str(i)] = i

# pos_dict

In [11]:
import pickle
file_name = f'/content/randDS_{task}.pkl'
with open(file_name, 'rb') as file:
    prompts_list_2 = pickle.load(file)

In [12]:
# prompts_list = prompts_list[:500]
# prompts_list_2 = prompts_list_2[:500]

In [13]:
dataset = Dataset(prompts_list, pos_dict, model.tokenizer, S1_is_first=True)

In [14]:
dataset_1 = dataset

In [15]:
dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer, S1_is_first=True)

# Path patching fns

## Performance Metrics

In [16]:
import gc

# del(ioi_cache)
# del(ioi_logits_original)

torch.cuda.empty_cache()
gc.collect()

18

In [17]:
def logits_to_ave_logit_diff_2(logits: Float[Tensor, "batch seq d_vocab"], dataset_1: IOIDataset = dataset_1, per_prompt=False):
    '''
    Returns logit difference between the correct and incorrect answer.

    If per_prompt=True, return the array of differences rather than the average.
    '''

    # Only the final logits are relevant for the answer
    # Get the logits corresponding to the indirect object / subject tokens respectively
    corr_logits: Float[Tensor, "batch"] = logits[range(logits.size(0)), dataset_1.word_idx["end"], dataset_1.corr_tokenIDs]
    incorr_logits: Float[Tensor, "batch"] = logits[range(logits.size(0)), dataset_1.word_idx["end"], dataset_1.incorr_tokenIDs]
    # Find logit difference
    answer_logit_diff = corr_logits - incorr_logits
    return answer_logit_diff if per_prompt else answer_logit_diff.mean()

model.reset_hooks(including_permanent=True)

ioi_logits_original = model(dataset_1.toks)
abc_logits_original = model(dataset_2.toks)

ioi_average_logit_diff = logits_to_ave_logit_diff_2(ioi_logits_original).item()
abc_average_logit_diff = logits_to_ave_logit_diff_2(abc_logits_original).item()
orig_score = ioi_average_logit_diff

In [18]:
def ioi_metric_3(
    logits: Float[Tensor, "batch seq d_vocab"],
    clean_logit_diff: float = ioi_average_logit_diff,
    corrupted_logit_diff: float = abc_average_logit_diff,
    dataset_1: IOIDataset = dataset_1,
) -> float:
    patched_logit_diff = logits_to_ave_logit_diff_2(logits, dataset_1)
    return (patched_logit_diff / clean_logit_diff)

print(f"IOI metric (IOI dataset): {ioi_metric_3(ioi_logits_original):.4f}")
print(f"IOI metric (ABC dataset): {ioi_metric_3(abc_logits_original):.4f}")

IOI metric (IOI dataset): 1.0000
IOI metric (ABC dataset): -0.0505


In [19]:
del(ioi_logits_original)
del(abc_logits_original)

## patching fns

In [20]:
def patch_or_freeze_head_vectors(
    orig_head_vector: Float[Tensor, "batch pos head_index d_head"],
    hook: HookPoint,
    new_cache: ActivationCache,
    orig_cache: ActivationCache,
    head_to_patch: Tuple[int, int],
) -> Float[Tensor, "batch pos head_index d_head"]:
    '''
    This helps implement step 2 of path patching. We freeze all head outputs (i.e. set them
    to their values in orig_cache), except for head_to_patch (if it's in this layer) which
    we patch with the value from new_cache.

    head_to_patch: tuple of (layer, head)
        we can use hook.layer() to check if the head to patch is in this layer
    '''
    # Setting using ..., otherwise changing orig_head_vector will edit cache value too
    orig_head_vector[...] = orig_cache[hook.name][...]
    if head_to_patch[0] == hook.layer():
        orig_head_vector[:, :, head_to_patch[1]] = new_cache[hook.name][:, :, head_to_patch[1]]
    return orig_head_vector

def patch_head_input(
    orig_activation: Float[Tensor, "batch pos head_idx d_head"],
    hook: HookPoint,
    patched_cache: ActivationCache,
    head_list: List[Tuple[int, int]],
) -> Float[Tensor, "batch pos head_idx d_head"]:
    '''
    Function which can patch any combination of heads in layers,
    according to the heads in head_list.
    '''
    heads_to_patch = [head for layer, head in head_list if layer == hook.layer()]
    orig_activation[:, :, heads_to_patch] = patched_cache[hook.name][:, :, heads_to_patch]
    return orig_activation

In [21]:
# def patch_or_freeze_head_vectors(
def patch_or_freeze_mlp_vectors(
    # orig_head_vector: Float[Tensor, "batch pos head_index d_head"],
    orig_MLP_vector: Float[Tensor, "batch pos d_model"],
    hook: HookPoint,
    new_cache: ActivationCache,
    orig_cache: ActivationCache,
    # head_to_patch: Tuple[int, int],
    layer_to_patch: int,
# ) -> Float[Tensor, "batch pos head_index d_head"]:
) -> Float[Tensor, "batch pos d_model"]:
    '''
    This helps implement step 2 of path patching. We freeze all head outputs (i.e. set them
    to their values in orig_cache), except for head_to_patch (if it's in this layer) which
    we patch with the value from new_cache.

    head_to_patch: tuple of (layer, head)
        we can use hook.layer() to check if the head to patch is in this layer
    '''
    # the layer is hook.layer(), and orig_head_vector is ALREADY an MLP at a layer, so we don't get it by layer
    # we just have to patch in each neuron of the MLP (d_model) that's why dims are "batch pos d_model"

    # Setting using ..., otherwise changing orig_head_vector will edit cache value too
    # this keeps everything the same
    # we NEED this to prevent change by ref!
    # orig_head_vector[...] = orig_cache[hook.name][...]
    # orig_MLP_vector[...] = orig_cache[hook.name][...]

    # this change the one MLP layer
    # if head_to_patch[0] == hook.layer():
    #     orig_head_vector[:, :, head_to_patch[1]] = new_cache[hook.name][:, :, head_to_patch[1]]

    # set the entire sender head as new (corr) cache output actvs
    if layer_to_patch == hook.layer():
        orig_MLP_vector[:, :, :] = new_cache[hook.name][:, :, :]
    return orig_MLP_vector

def patch_mlp_input(
# def patch_head_input(
    # orig_activation: Float[Tensor, "batch pos head_idx d_head"],
    orig_activation: Float[Tensor, "batch pos d_model"],
    hook: HookPoint,
    patched_cache: ActivationCache,
    # head_list: List[Tuple[int, int]],
    layer_list: List[int],
) -> Float[Tensor, "batch pos head_idx d_head"]:
    '''
    Function which can patch any combination of heads in layers,
    according to the heads in head_list.
    '''
    # heads_to_patch = [head for layer, head in head_list if layer == hook.layer()]  # we dont need list, just layer int
    # orig_activation[:, :, heads_to_patch] = patched_cache[hook.name][:, :, heads_to_patch] # heads_to_patch should now be an int, layer?

    # exact same thing as before, since we don't need to freeze other heads, so uinlike
    # w/ attn heads, this is the same fn. we can just re-use the prev one instead of this one
    # we also don't have a head list, but an MLP list
    # ACCTU this is diff; there's al ist of layer ints, and only if layer in list do we patch

    # orig_activation[...] = orig_cache[hook.name][...]  # we dont need this? but by ref? postpone thinking this

    if hook.layer() in layer_list:
        # pdb.set_trace()
        orig_activation[:, :, :] = patched_cache[hook.name][:, :, :]
    return orig_activation

# MLP and Head ablation fns

## MLP ablation fns

In [22]:
from torch import Tensor
from typing import Dict, Tuple, List
from jaxtyping import Float, Bool
import torch as t

def logits_to_ave_logit_diff(logits: Float[Tensor, "batch seq d_vocab"], dataset: Dataset, per_prompt=False):
    '''
    Returns logit difference between the correct and incorrect answer.

    If per_prompt=True, return the array of differences rather than the average.
    '''

    # Only the final logits are relevant for the answer
    # Get the logits corresponding to the indirect object / subject tokens respectively
    corr_logits: Float[Tensor, "batch"] = logits[range(logits.size(0)), dataset.word_idx["end"], dataset.corr_tokenIDs]
    incorr_logits: Float[Tensor, "batch"] = logits[range(logits.size(0)), dataset.word_idx["end"], dataset.incorr_tokenIDs]
    # Find logit difference
    answer_logit_diff = corr_logits - incorr_logits
    return answer_logit_diff if per_prompt else answer_logit_diff.mean()

In [23]:
def compute_means_by_template_MLP(
    means_dataset: Dataset,
    model: HookedTransformer
) -> Float[Tensor, "layer batch seq head_idx d_head"]:
    '''
    Returns the mean of each head's output over the means dataset. This mean is
    computed separately for each group of prompts with the same template (these
    are given by means_dataset.groups).
    '''
    # Cache the outputs of every head
    _, means_cache = model.run_with_cache(
        means_dataset.toks.long(),
        return_type=None,
        names_filter=lambda name: name.endswith("mlp_out"),
    )
    # Create tensor to store means
    n_layers, d_model = model.cfg.n_layers, model.cfg.d_model
    batch, seq_len = len(means_dataset), means_dataset.max_len
    means = t.zeros(size=(n_layers, batch, seq_len, d_model), device=model.cfg.device)

    # Get set of different templates for this data
    for layer in range(n_layers):
        mlp_output_for_this_layer: Float[Tensor, "batch seq d_model"] = means_cache[utils.get_act_name("mlp_out", layer)]
        for template_group in means_dataset.groups:  # here, we only have one group
            mlp_output_for_this_template = mlp_output_for_this_layer[template_group]
            # aggregate all batches
            mlp_output_means_for_this_template = einops.reduce(mlp_output_for_this_template, "batch seq d_model -> seq d_model", "mean")
            means[layer, template_group] = mlp_output_means_for_this_template
            # at layer, each batch ind is tempalte group (a tensor of size seq d_model)
            # is assigned the SAME mean, "mlp_output_means_for_this_template"

    del(means_cache)

    return means

In [24]:
def get_mlp_outputs_and_posns_to_keep(
    means_dataset: Dataset,
    model: HookedTransformer,
    circuit: Dict[str, List[int]],  # Adjusted to hold list of layers instead of (layer, head) tuples
    seq_pos_to_keep: Dict[str, str],
) -> Dict[int, Bool[Tensor, "batch seq"]]:  # Adjusted the return type to "batch seq"
    '''
    Returns a dictionary mapping layers to a boolean mask giving the indices of the
    MLP output which *shouldn't* be mean-ablated.

    The output of this function will be used for the hook function that does ablation.
    '''
    mlp_outputs_and_posns_to_keep = {}
    batch, seq = len(means_dataset), means_dataset.max_len

    for layer in range(model.cfg.n_layers):
        mask = t.zeros(size=(batch, seq))

        for (mlp_type, layer_list) in circuit.items():
            seq_pos = seq_pos_to_keep[mlp_type]
            indices = means_dataset.word_idx[seq_pos]
            if layer in layer_list:  # Check if the current layer is in the layer list for this mlp_type
                mask[:, indices] = 1

        mlp_outputs_and_posns_to_keep[layer] = mask.bool()

    return mlp_outputs_and_posns_to_keep

In [25]:
def hook_fn_mask_mlp_out(
    mlp_out: Float[Tensor, "batch seq d_mlp"],
    hook: HookPoint,
    mlp_outputs_and_posns_to_keep: Dict[int, Bool[Tensor, "batch seq"]],
    means: Float[Tensor, "layer batch seq d_mlp"],
) -> Float[Tensor, "batch seq d_mlp"]:
    '''
    Hook function which masks the MLP output of a transformer layer.

    mlp_outputs_and_posns_to_keep
        Dict created with the get_mlp_outputs_and_posns_to_keep function. This tells
        us where to mask.

    means
        Tensor of mean MLP output values of the means_dataset over each group of prompts
        with the same template. This tells us what values to mask with.
    '''
    # Get the mask for this layer, adapted for MLP output structure
    mask_for_this_layer = mlp_outputs_and_posns_to_keep[hook.layer()].unsqueeze(-1).to(mlp_out.device)

    # Set MLP output values to the mean where necessary
    mlp_out = t.where(mask_for_this_layer, mlp_out, means[hook.layer()])

    return mlp_out

In [26]:
CIRCUIT = {}
SEQ_POS_TO_KEEP = {}
def add_mean_ablation_hook_MLP(
    model: HookedTransformer,
    means_dataset: Dataset,
    circuit: Dict[str, List[Tuple[int, int]]] = CIRCUIT,
    seq_pos_to_keep: Dict[str, str] = SEQ_POS_TO_KEEP,
    is_permanent: bool = True,
) -> HookedTransformer:
    '''
    Adds a permanent hook to the model, which ablates according to the circuit and
    seq_pos_to_keep dictionaries.

    In other words, when the model is run on ioi_dataset, every head's output will
    be replaced with the mean over means_dataset for sequences with the same template,
    except for a subset of heads and sequence positions as specified by the circuit
    and seq_pos_to_keep dicts.
    '''

    model.reset_hooks(including_permanent=True)

    # Compute the mean of each head's output on the ABC dataset, grouped by template
    means = compute_means_by_template_MLP(means_dataset, model)

    # Convert this into a boolean map
    mlp_outputs_and_posns_to_keep = get_mlp_outputs_and_posns_to_keep(means_dataset, model, circuit, seq_pos_to_keep)

    # Get a hook function which will patch in the mean z values for each head, at
    # all positions which aren't important for the circuit
    hook_fn = partial(
        hook_fn_mask_mlp_out,
        mlp_outputs_and_posns_to_keep=mlp_outputs_and_posns_to_keep,
        means=means
    )

    # Apply hook
    model.add_hook(lambda name: name.endswith("mlp_out"), hook_fn, is_permanent=True)

    return model

In [27]:
def mean_ablate_by_lst_MLP(lst, model, orig_score, print_output=True):
    CIRCUIT = {}
    SEQ_POS_TO_KEEP = {}
    for i in range(len(model.tokenizer.tokenize(prompts_list_2[0]['text']))):
        CIRCUIT['S'+str(i)] = lst
        if i == len(model.tokenizer.tokenize(prompts_list_2[0]['text'])) - 1:
            SEQ_POS_TO_KEEP['S'+str(i)] = 'end'
        else:
            SEQ_POS_TO_KEEP['S'+str(i)] = 'S'+str(i)

    model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook

    # ioi_logits_original, ioi_cache = model.run_with_cache(dataset.toks)

    model = add_mean_ablation_hook_MLP(model, means_dataset=dataset_2, circuit=CIRCUIT, seq_pos_to_keep=SEQ_POS_TO_KEEP)
    new_logits = model(dataset.toks)

    # orig_score = logits_to_ave_logit_diff_2(ioi_logits_original, dataset)
    new_score = logits_to_ave_logit_diff(new_logits, dataset)
    del(new_logits)
    if print_output:
        # print(f"Average logit difference (IOI dataset, using entire model): {orig_score:.4f}")
        # print(f"Average logit difference (IOI dataset, only using circuit): {new_score:.4f}")
        print(f"Average logit difference (circuit / full) %: {100 * new_score / orig_score:.4f}")
    # return new_score
    return 100 * new_score / orig_score

## head fns

In [28]:
def get_heads_and_posns_to_keep(
    means_dataset: Dataset,
    model: HookedTransformer,
    circuit: Dict[str, List[Tuple[int, int]]],
    seq_pos_to_keep: Dict[str, str],
) -> Dict[int, Bool[Tensor, "batch seq head"]]:
    '''
    Returns a dictionary mapping layers to a boolean mask giving the indices of the
    z output which *shouldn't* be mean-ablated.

    The output of this function will be used for the hook function that does ablation.
    '''
    heads_and_posns_to_keep = {}
    batch, seq, n_heads = len(means_dataset), means_dataset.max_len, model.cfg.n_heads

    for layer in range(model.cfg.n_layers):

        mask = t.zeros(size=(batch, seq, n_heads))

        for (head_type, head_list) in circuit.items():
            seq_pos = seq_pos_to_keep[head_type]
            indices = means_dataset.word_idx[seq_pos] # modify this for key vs query pos. curr, this is query
            for (layer_idx, head_idx) in head_list:
                if layer_idx == layer:
                    mask[:, indices, head_idx] = 1

        heads_and_posns_to_keep[layer] = mask.bool()

    return heads_and_posns_to_keep

def hook_fn_mask_z(
    z: Float[Tensor, "batch seq head d_head"],
    hook: HookPoint,
    heads_and_posns_to_keep: Dict[int, Bool[Tensor, "batch seq head"]],
    means: Float[Tensor, "layer batch seq head d_head"],
) -> Float[Tensor, "batch seq head d_head"]:
    '''
    Hook function which masks the z output of a transformer head.

    heads_and_posns_to_keep
        Dict created with the get_heads_and_posns_to_keep function. This tells
        us where to mask.

    means
        Tensor of mean z values of the means_dataset over each group of prompts
        with the same template. This tells us what values to mask with.
    '''
    # Get the mask for this layer, and add d_head=1 dimension so it broadcasts correctly
    mask_for_this_layer = heads_and_posns_to_keep[hook.layer()].unsqueeze(-1).to(z.device)

    # Set z values to the mean
    z = t.where(mask_for_this_layer, z, means[hook.layer()])

    return z

def compute_means_by_template(
    means_dataset: Dataset,
    model: HookedTransformer
) -> Float[Tensor, "layer batch seq head_idx d_head"]:
    '''
    Returns the mean of each head's output over the means dataset. This mean is
    computed separately for each group of prompts with the same template (these
    are given by means_dataset.groups).
    '''
    # Cache the outputs of every head
    _, means_cache = model.run_with_cache(
        means_dataset.toks.long(),
        return_type=None,
        names_filter=lambda name: name.endswith("z"),
    )
    # Create tensor to store means
    n_layers, n_heads, d_head = model.cfg.n_layers, model.cfg.n_heads, model.cfg.d_head
    batch, seq_len = len(means_dataset), means_dataset.max_len
    means = t.zeros(size=(n_layers, batch, seq_len, n_heads, d_head), device=model.cfg.device)

    # Get set of different templates for this data
    for layer in range(model.cfg.n_layers):
        z_for_this_layer: Float[Tensor, "batch seq head d_head"] = means_cache[utils.get_act_name("z", layer)]
        for template_group in means_dataset.groups:
            z_for_this_template = z_for_this_layer[template_group]
            z_means_for_this_template = einops.reduce(z_for_this_template, "batch seq head d_head -> seq head d_head", "mean")
            means[layer, template_group] = z_means_for_this_template

    del(means_cache)

    return means

def add_mean_ablation_hook(
    model: HookedTransformer,
    means_dataset: Dataset,
    circuit: Dict[str, List[Tuple[int, int]]] = CIRCUIT,
    seq_pos_to_keep: Dict[str, str] = SEQ_POS_TO_KEEP,
    is_permanent: bool = True,
) -> HookedTransformer:
    '''
    Adds a permanent hook to the model, which ablates according to the circuit and
    seq_pos_to_keep dictionaries.

    In other words, when the model is run on ioi_dataset, every head's output will
    be replaced with the mean over means_dataset for sequences with the same template,
    except for a subset of heads and sequence positions as specified by the circuit
    and seq_pos_to_keep dicts.
    '''

    model.reset_hooks(including_permanent=True)

    # Compute the mean of each head's output on the ABC dataset, grouped by template
    means = compute_means_by_template(means_dataset, model)

    # Convert this into a boolean map
    heads_and_posns_to_keep = get_heads_and_posns_to_keep(means_dataset, model, circuit, seq_pos_to_keep)

    # Get a hook function which will patch in the mean z values for each head, at
    # all positions which aren't important for the circuit
    hook_fn = partial(
        hook_fn_mask_z,
        heads_and_posns_to_keep=heads_and_posns_to_keep,
        means=means
    )

    # Apply hook
    model.add_hook(lambda name: name.endswith("z"), hook_fn, is_permanent=is_permanent)

    return model

## both

In [29]:
def add_mean_ablation_hook_MLP_head(
    model: HookedTransformer,
    means_dataset: Dataset,
    heads_lst, mlp_lst,
    is_permanent: bool = True,
) -> HookedTransformer:
    CIRCUIT = {}
    SEQ_POS_TO_KEEP = {}
    for i in range(len(model.tokenizer.tokenize(prompts_list_2[0]['text']))):
        CIRCUIT['S'+str(i)] = heads_lst
        if i == len(model.tokenizer.tokenize(prompts_list_2[0]['text'])) - 1:
            SEQ_POS_TO_KEEP['S'+str(i)] = 'end'
        else:
            SEQ_POS_TO_KEEP['S'+str(i)] = 'S'+str(i)

    model.reset_hooks(including_permanent=True)

    # Compute the mean of each head's output on the ABC dataset, grouped by template
    means = compute_means_by_template(means_dataset, model)

    # Convert this into a boolean map
    heads_and_posns_to_keep = get_heads_and_posns_to_keep(means_dataset, model, CIRCUIT, SEQ_POS_TO_KEEP)

    # Get a hook function which will patch in the mean z values for each head, at
    # all positions which aren't important for the circuit
    hook_fn = partial(
        hook_fn_mask_z,
        heads_and_posns_to_keep=heads_and_posns_to_keep,
        means=means
    )

    # Apply hook
    model.add_hook(lambda name: name.endswith("z"), hook_fn, is_permanent=is_permanent)

    ########################
    CIRCUIT = {}
    SEQ_POS_TO_KEEP = {}
    for i in range(len(model.tokenizer.tokenize(prompts_list_2[0]['text']))):
        CIRCUIT['S'+str(i)] = mlp_lst
        if i == len(model.tokenizer.tokenize(prompts_list_2[0]['text'])) - 1:
            SEQ_POS_TO_KEEP['S'+str(i)] = 'end'
        else:
            SEQ_POS_TO_KEEP['S'+str(i)] = 'S'+str(i)

    # Compute the mean of each head's output on the ABC dataset, grouped by template
    means = compute_means_by_template_MLP(means_dataset, model)

    # Convert this into a boolean map
    mlp_outputs_and_posns_to_keep = get_mlp_outputs_and_posns_to_keep(means_dataset, model, CIRCUIT, SEQ_POS_TO_KEEP)

    # Get a hook function which will patch in the mean z values for each head, at
    # all positions which aren't important for the circuit
    hook_fn = partial(
        hook_fn_mask_mlp_out,
        mlp_outputs_and_posns_to_keep=mlp_outputs_and_posns_to_keep,
        means=means
    )

    # Apply hook
    model.add_hook(lambda name: name.endswith("mlp_out"), hook_fn, is_permanent=True)

    return model

# only loop thru sender/mlp nodes of circuit

## head to head

In [30]:
def circ_path_patch_head_to_heads(
    circuit: List[Tuple[int, int]],
    receiver_heads: List[Tuple[int, int]],
    receiver_input: str,
    model: HookedTransformer,
    patching_metric: Callable,
    new_dataset: IOIDataset = dataset_2,
    orig_dataset: IOIDataset = dataset_1,
    new_cache: Optional[ActivationCache] = None,
    orig_cache: Optional[ActivationCache] = None,
) -> Float[Tensor, "layer head"]:
    '''
    Performs path patching (see algorithm in appendix B of IOI paper), with:

        sender head = (each head, looped through, one at a time)
        receiver node = input to a later head (or set of heads)

    The receiver node is specified by receiver_heads and receiver_input.
    Example (for S-inhibition path patching the queries):
        receiver_heads = [(8, 6), (8, 10), (7, 9), (7, 3)],
        receiver_input = "v"

    Returns:
        tensor of metric values for every possible sender head
    '''
    # SOLUTION
    # model.reset_hooks()

    assert receiver_input in ("k", "q", "v", "z")
    receiver_layers = set(next(zip(*receiver_heads)))  # a set of all layers of receiver heads
    receiver_hook_names = [utils.get_act_name(receiver_input, layer) for layer in receiver_layers]
    receiver_hook_names_filter = lambda name: name in receiver_hook_names

    results = t.zeros(max(receiver_layers), model.cfg.n_heads, device="cuda", dtype=t.float32)

    # ========== Step 1 ==========
    # Gather activations on x_orig and x_new

    # Note the use of names_filter for the run_with_cache function. Using it means we
    # only cache the things we need (in this case, just attn head outputs).
    z_name_filter = lambda name: name.endswith("z")
    if new_cache is None:
        _, new_cache = model.run_with_cache(
            new_dataset.toks,
            names_filter=z_name_filter,
            return_type=None
        )
    if orig_cache is None:
        _, orig_cache = model.run_with_cache(
            orig_dataset.toks,
            names_filter=z_name_filter,
            return_type=None
        )

    # only consider circuit nodes that
    # however, as indices repr L H, results must still be as big as entire circuit
    # before the receiver head
    senders = [tup for tup in circuit if tup[0] < receiver_heads[0][0]]

    for (sender_layer, sender_head) in tqdm(senders):

    # Note, the sender layer will always be before the final receiver layer, otherwise there will
    # be no causal effect from sender -> receiver. So we only need to loop this far.
    # for (sender_layer, sender_head) in tqdm(list(itertools.product(
    #     range(max(receiver_layers)),
    #     range(model.cfg.n_heads)
    # ))):

        # ========== Step 2 ==========
        # Run on x_orig, with sender head patched from x_new, every other head frozen

        hook_fn = partial(
            patch_or_freeze_head_vectors,
            new_cache=new_cache,
            orig_cache=orig_cache,
            head_to_patch=(sender_layer, sender_head),
        )
        model.add_hook(z_name_filter, hook_fn, level=1)

        _, patched_cache = model.run_with_cache(
            orig_dataset.toks,
            names_filter=receiver_hook_names_filter,
            return_type=None
        )
        # model.reset_hooks(including_permanent=True)
        assert set(patched_cache.keys()) == set(receiver_hook_names)

        # ========== Step 3 ==========
        # Run on x_orig, patching in the receiver node(s) from the previously cached value

        hook_fn = partial(
            patch_head_input,
            patched_cache=patched_cache,
            head_list=receiver_heads,
        )
        patched_logits = model.run_with_hooks(
            orig_dataset.toks,
            fwd_hooks = [(receiver_hook_names_filter, hook_fn)],
            return_type="logits"
        )

        # Save the results
        results[sender_layer, sender_head] = patching_metric(patched_logits)

    return results

## mlp to mlp

In [31]:
def circ_path_patch_MLPs_to_MLPs(
    mlp_circuit: List[int],
    receiver_layers: List[int],
    # receiver_input: str,
    model: HookedTransformer,
    patching_metric: Callable,
    new_dataset: IOIDataset = dataset_2,
    orig_dataset: IOIDataset = dataset_1,
    new_cache: Optional[ActivationCache] = None,
    orig_cache: Optional[ActivationCache] = None,
) -> Float[Tensor, "layer head"]:
    '''
    Performs path patching (see algorithm in appendix B of IOI paper), with:

        sender head = (each head, looped through, one at a time)
        receiver node = input to a later head (or set of heads)

    The receiver node is specified by receiver_heads and receiver_input.
    Example (for S-inhibition path patching the queries):
        receiver_heads = [(8, 6), (8, 10), (7, 9), (7, 3)],
        receiver_input = "v"

    Returns:
        tensor of metric values for every possible sender head
    '''
    # model.reset_hooks()

    # assert receiver_input in ("k", "q", "v")  # we can run get_path_patch_head_to_heads() 3 times for k, q, v!
    # receiver_layers = set(next(zip(*receiver_heads)))
    # receiver_hook_names = [utils.get_act_name(receiver_input, layer) for layer in receiver_layers]
    receiver_hook_names = [utils.get_act_name('mlp_out', layer) for layer in receiver_layers]  # modify for mlp_out
    receiver_hook_names_filter = lambda name: name in receiver_hook_names

    # results = t.zeros(max(receiver_layers), model.cfg.n_heads, device="cuda", dtype=t.float32)
    results = t.zeros(max(receiver_layers), device="cuda", dtype=t.float32)

    # ========== Step 1 ==========
    # z_name_filter = lambda name: name.endswith("z")
    z_name_filter = lambda name: name.endswith("mlp_out")

    if new_cache is None:
        _, new_cache = model.run_with_cache(
            new_dataset.toks,
            names_filter=z_name_filter,
            return_type=None
        )
    if orig_cache is None:
        _, orig_cache = model.run_with_cache(
            orig_dataset.toks,
            names_filter=z_name_filter,
            return_type=None
        )

    # Note, the sender layer will always be before the final receiver layer, otherwise there will
    # be no causal effect from sender -> receiver. So we only need to loop this far.

    # for (sender_layer, sender_head) in tqdm(list(itertools.product(
    #     range(max(receiver_layers)),  # all the layers from 0 to highest receiver layer (in circuit)
    #     range(model.cfg.n_heads)  # all heads from 0 to 12
    # ))):

    # for (sender_layer) in range(max(receiver_layers)):  # all the layers from 0 to highest receiver layer (in circuit)
    sender_mlp_list = [L for L in mlp_circuit if L < max(receiver_layers)]
    for (sender_layer) in sender_mlp_list:
        # ========== Step 2 ==========
        # Run on x_orig, with sender head patched from x_new, every other head frozen

        hook_fn = partial(
            # patch_or_freeze_head_vectors,
            patch_or_freeze_mlp_vectors,
            new_cache=new_cache,
            orig_cache=orig_cache,
            # head_to_patch=(sender_layer, sender_head),
            layer_to_patch = sender_layer # an int
        )

        model.add_hook(z_name_filter, hook_fn, level=1)

        _, patched_cache = model.run_with_cache(
            orig_dataset.toks,
            names_filter=receiver_hook_names_filter,
            return_type=None
        )
        # model.reset_hooks(including_permanent=True)
        # assert set(patched_cache.keys()) == set(receiver_hook_names)

        # ========== Step 3 ==========
        # Run on x_orig, patching in the receiver node(s) from the previously cached value

        hook_fn = partial(
            # patch_head_input,
            patch_mlp_input,
            patched_cache=patched_cache,
            # head_list=receiver_heads, # list of layer ints
            layer_list=receiver_layers,
        )
        patched_logits = model.run_with_hooks(
            orig_dataset.toks,
            fwd_hooks = [(receiver_hook_names_filter, hook_fn)],
            return_type="logits"
        )

        # Save the results
        # results[sender_layer, sender_head] = patching_metric(patched_logits)
        results[sender_layer] = patching_metric(patched_logits)

    # the result is which sender layers affect ALL the inputted nodes. this is why we just
    # want to pass one node at a time- to see which layers affect just IT.
    # if we want a 'group of nodes under a common type', we'd pass a set of nodes
    return results

## head to MLP

head senders to MLP receiver (circ nodes)

In [32]:
def circ_path_patch_head_to_mlp(
    circuit: List[Tuple[int, int]],
    receiver_layers: List[int],
    model: HookedTransformer,
    patching_metric: Callable,
    new_dataset: IOIDataset = dataset_2,
    orig_dataset: IOIDataset = dataset_1,
    new_cache: Optional[ActivationCache] = None,
    orig_cache: Optional[ActivationCache] = None,
) -> Float[Tensor, "layer head"]:
    '''
    Performs path patching (see algorithm in appendix B of IOI paper), with:

        sender head = (each head, looped through, one at a time)
        receiver node = input to a later head (or set of heads)

    The receiver node is specified by receiver_heads and receiver_input.
    Example (for S-inhibition path patching the queries):
        receiver_heads = [(8, 6), (8, 10), (7, 9), (7, 3)],
        receiver_input = "v"

    Returns:
        tensor of metric values for every possible sender head
    '''
    # model.reset_hooks()

    # assert receiver_input in ("k", "q", "v")  # we can run get_path_patch_head_to_heads() 3 times for k, q, v!
    # receiver_layers = set(next(zip(*receiver_heads)))
    # receiver_hook_names = [utils.get_act_name(receiver_input, layer) for layer in receiver_layers]
    receiver_hook_names = [utils.get_act_name('mlp_out', layer) for layer in receiver_layers]  # modify for mlp_out
    receiver_hook_names_filter = lambda name: name in receiver_hook_names

    results = t.zeros(max(receiver_layers), model.cfg.n_heads, device="cuda", dtype=t.float32)
    # results = t.zeros(max(receiver_layers), device="cuda", dtype=t.float32)

    # ========== Step 1 ==========
    # z_name_filter = lambda name: name.endswith("z")
    z_name_filter = lambda name: name.endswith(("z", "mlpout"))  # gets same value as just z

    # z_name_filter = lambda name: name.endswith("mlp_out")

    if new_cache is None:
        _, new_cache = model.run_with_cache(
            new_dataset.toks,
            names_filter=z_name_filter,
            return_type=None
        )
    if orig_cache is None:
        _, orig_cache = model.run_with_cache(
            orig_dataset.toks,
            names_filter=z_name_filter,
            return_type=None
        )

    # Note, the sender layer will always be before the final receiver layer, otherwise there will
    # be no causal effect from sender -> receiver. So we only need to loop this far.

    # for (sender_layer, sender_head) in tqdm(list(itertools.product(
    #     range(max(receiver_layers)),  # all the layers from 0 to highest receiver layer (in circuit)
    #     range(model.cfg.n_heads)  # all heads from 0 to 12
    # ))):
    senders = [tup for tup in circuit if tup[0] < receiver_layers[0]]

    for (sender_layer, sender_head) in tqdm(senders):

    # have a separate loop for both MLPs AND heads as senders
    # for (sender_layer) in range(max(receiver_layers)):  # all the layers from 0 to highest receiver layer (in circuit)

        # ========== Step 2 ==========
        # Run on x_orig, with sender head patched from x_new, every other head frozen

        hook_fn = partial(
            patch_or_freeze_head_vectors,
            # patch_or_freeze_mlp_vectors,
            new_cache=new_cache,
            orig_cache=orig_cache,
            head_to_patch=(sender_layer, sender_head),
            # layer_to_patch = sender_layer # an int
        )

        model.add_hook(z_name_filter, hook_fn, level=1)

        _, patched_cache = model.run_with_cache(
            orig_dataset.toks,
            names_filter=receiver_hook_names_filter,
            return_type=None
        )
        # model.reset_hooks(including_permanent=True)
        # assert set(patched_cache.keys()) == set(receiver_hook_names)

        # ========== Step 3 ==========
        # Run on x_orig, patching in the receiver node(s) from the previously cached value

        hook_fn = partial(
            # patch_head_input,
            patch_mlp_input,
            patched_cache=patched_cache,
            # head_list=receiver_heads, # list of layer ints
            layer_list=receiver_layers,
        )
        patched_logits = model.run_with_hooks(
            orig_dataset.toks,
            fwd_hooks = [(receiver_hook_names_filter, hook_fn)],
            return_type="logits"
        )

        # Save the results
        results[sender_layer, sender_head] = patching_metric(patched_logits)
        # results[sender_layer] = patching_metric(patched_logits)

    # the result is which sender layers affect ALL the inputted nodes. this is why we just
    # want to pass one node at a time- to see which layers affect just IT.
    # if we want a 'group of nodes under a common type', we'd pass a set of nodes
    return results

## MLP to head

MLP senders to head receiver (circ nodes)

In [33]:
def circ_path_patch_mlp_to_head(
    mlp_circuit: List[int],
    receiver_heads: List[Tuple[int, int]],
    # receiver_layers: List[int],
    receiver_input: str,
    model: HookedTransformer,
    patching_metric: Callable,
    new_dataset: IOIDataset = dataset_2,
    orig_dataset: IOIDataset = dataset_1,
    new_cache: Optional[ActivationCache] = None,
    orig_cache: Optional[ActivationCache] = None,
) -> Float[Tensor, "layer head"]:
    '''
    Performs path patching (see algorithm in appendix B of IOI paper), with:

        sender head = (each head, looped through, one at a time)
        receiver node = input to a later head (or set of heads)

    The receiver node is specified by receiver_heads and receiver_input.
    Example (for S-inhibition path patching the queries):
        receiver_heads = [(8, 6), (8, 10), (7, 9), (7, 3)],
        receiver_input = "v"

    Returns:
        tensor of metric values for every possible sender head
    '''
    # model.reset_hooks() # doesn't make diff if comment out or not

    assert receiver_input in ("k", "q", "v", "z")  # we can run get_path_patch_head_to_heads() 3 times for k, q, v!
    receiver_layers = set(next(zip(*receiver_heads)))
    receiver_hook_names = [utils.get_act_name(receiver_input, layer) for layer in receiver_layers]
    # receiver_hook_names = [utils.get_act_name('mlp_out', layer) for layer in receiver_layers]  # modify for mlp_out
    receiver_hook_names_filter = lambda name: name in receiver_hook_names

    # results = t.zeros(max(receiver_layers), model.cfg.n_heads, device="cuda", dtype=t.float32)
    results = t.zeros(max(receiver_layers), device="cuda", dtype=t.float32)

    # ========== Step 1 ==========
    # z_name_filter = lambda name: name.endswith("z")  # this is for sender? actually no; orig cache uses it too
    # z_name_filter = lambda name: name.endswith("mlp_out")
    z_name_filter = lambda name: name.endswith(("z", "mlpout"))  # gets same value as just mlp out

    if new_cache is None:
        _, new_cache = model.run_with_cache(
            new_dataset.toks,
            names_filter=z_name_filter,
            return_type=None
        )
    if orig_cache is None:
        _, orig_cache = model.run_with_cache(
            orig_dataset.toks,
            names_filter=z_name_filter,
            return_type=None
        )

    # Note, the sender layer will always be before the final receiver layer, otherwise there will
    # be no causal effect from sender -> receiver. So we only need to loop this far.

    # for (sender_layer, sender_head) in tqdm(list(itertools.product(
    #     range(max(receiver_layers)),  # all the layers from 0 to highest receiver layer (in circuit)
    #     range(model.cfg.n_heads)  # all heads from 0 to 12
    # ))):

    # have a separate loop for both MLPs AND heads as senders
    # for (sender_layer) in range(max(receiver_layers)):  # all the layers from 0 to highest receiver layer (in circuit)
    sender_mlp_list = [L for L in mlp_circuit if L < receiver_heads[0][0]]
    for (sender_layer) in sender_mlp_list:

        # ========== Step 2 ==========
        # Run on x_orig, with sender head patched from x_new, every other head frozen

        hook_fn = partial(
            # patch_or_freeze_head_vectors,
            patch_or_freeze_mlp_vectors,
            new_cache=new_cache,
            orig_cache=orig_cache,
            # head_to_patch=(sender_layer, sender_head),
            layer_to_patch = sender_layer # an int
        )

        model.add_hook(z_name_filter, hook_fn, level=1)

        _, patched_cache = model.run_with_cache(
            orig_dataset.toks,
            names_filter=receiver_hook_names_filter,
            return_type=None
        )
        # model.reset_hooks(including_permanent=True)
        assert set(patched_cache.keys()) == set(receiver_hook_names)

        # ========== Step 3 ==========
        # Run on x_orig, patching in the receiver node(s) from the previously cached value

        hook_fn = partial(
            patch_head_input,
            # patch_mlp_input,
            patched_cache=patched_cache,
            head_list=receiver_heads, # list of layer ints
            # layer_list=receiver_layers,
        )
        patched_logits = model.run_with_hooks(
            orig_dataset.toks,
            fwd_hooks = [(receiver_hook_names_filter, hook_fn)],
            return_type="logits"
        )

        # Save the results
        # results[sender_layer, sender_head] = patching_metric(patched_logits)
        results[sender_layer] = patching_metric(patched_logits)

    # the result is which sender layers affect ALL the inputted nodes. this is why we just
    # want to pass one node at a time- to see which layers affect just IT.
    # if we want a 'group of nodes under a common type', we'd pass a set of nodes
    return results

# loop backw

## get circuit

In [34]:
orig_score = ioi_average_logit_diff

model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook
abl_model = add_mean_ablation_hook_MLP_head(model, dataset_2, heads_not_ablate, mlps_not_ablate)

new_logits = model(dataset.toks)
new_score = logits_to_ave_logit_diff_2(new_logits, dataset)
circ_score = (100 * new_score / orig_score).item()
print(f"(cand circuit / full) %: {circ_score:.4f}")
del(new_logits)

(cand circuit / full) %: 81.1096


## head to head

In [35]:
qkv_to_HH = {} # qkv to dict

for head_type in ["q", "k", "v"]:
    head_to_head_results = {}
    for head in heads_not_ablate:
        print(head_type, head)
        model.reset_hooks()
        model = add_mean_ablation_hook_MLP_head(model, dataset_2, heads_not_ablate, mlps_not_ablate)

        result = circ_path_patch_head_to_heads(
            circuit = heads_not_ablate,
            receiver_heads = [head],
            receiver_input = head_type,
            model = model,
            patching_metric = ioi_metric_3
        )
        head_to_head_results[head] = result
    qkv_to_HH[head_type] = head_to_head_results

q (0, 1)


0it [00:00, ?it/s]

q (1, 5)


  0%|          | 0/1 [00:00<?, ?it/s]

q (4, 4)


  0%|          | 0/2 [00:00<?, ?it/s]

q (4, 10)


  0%|          | 0/2 [00:00<?, ?it/s]

q (5, 8)


  0%|          | 0/4 [00:00<?, ?it/s]

q (6, 1)


  0%|          | 0/5 [00:00<?, ?it/s]

q (6, 6)


  0%|          | 0/5 [00:00<?, ?it/s]

q (6, 10)


  0%|          | 0/5 [00:00<?, ?it/s]

q (7, 2)


  0%|          | 0/8 [00:00<?, ?it/s]

q (7, 6)


  0%|          | 0/8 [00:00<?, ?it/s]

q (7, 11)


  0%|          | 0/8 [00:00<?, ?it/s]

q (8, 1)


  0%|          | 0/11 [00:00<?, ?it/s]

q (8, 6)


  0%|          | 0/11 [00:00<?, ?it/s]

q (8, 8)


  0%|          | 0/11 [00:00<?, ?it/s]

q (8, 9)


  0%|          | 0/11 [00:00<?, ?it/s]

q (8, 11)


  0%|          | 0/11 [00:00<?, ?it/s]

q (9, 1)


  0%|          | 0/16 [00:00<?, ?it/s]

q (9, 5)


  0%|          | 0/16 [00:00<?, ?it/s]

q (9, 7)


  0%|          | 0/16 [00:00<?, ?it/s]

k (0, 1)


0it [00:00, ?it/s]

k (1, 5)


  0%|          | 0/1 [00:00<?, ?it/s]

k (4, 4)


  0%|          | 0/2 [00:00<?, ?it/s]

k (4, 10)


  0%|          | 0/2 [00:00<?, ?it/s]

k (5, 8)


  0%|          | 0/4 [00:00<?, ?it/s]

k (6, 1)


  0%|          | 0/5 [00:00<?, ?it/s]

k (6, 6)


  0%|          | 0/5 [00:00<?, ?it/s]

k (6, 10)


  0%|          | 0/5 [00:00<?, ?it/s]

k (7, 2)


  0%|          | 0/8 [00:00<?, ?it/s]

k (7, 6)


  0%|          | 0/8 [00:00<?, ?it/s]

k (7, 11)


  0%|          | 0/8 [00:00<?, ?it/s]

k (8, 1)


  0%|          | 0/11 [00:00<?, ?it/s]

k (8, 6)


  0%|          | 0/11 [00:00<?, ?it/s]

k (8, 8)


  0%|          | 0/11 [00:00<?, ?it/s]

k (8, 9)


  0%|          | 0/11 [00:00<?, ?it/s]

k (8, 11)


  0%|          | 0/11 [00:00<?, ?it/s]

k (9, 1)


  0%|          | 0/16 [00:00<?, ?it/s]

k (9, 5)


  0%|          | 0/16 [00:00<?, ?it/s]

k (9, 7)


  0%|          | 0/16 [00:00<?, ?it/s]

v (0, 1)


0it [00:00, ?it/s]

v (1, 5)


  0%|          | 0/1 [00:00<?, ?it/s]

v (4, 4)


  0%|          | 0/2 [00:00<?, ?it/s]

v (4, 10)


  0%|          | 0/2 [00:00<?, ?it/s]

v (5, 8)


  0%|          | 0/4 [00:00<?, ?it/s]

v (6, 1)


  0%|          | 0/5 [00:00<?, ?it/s]

v (6, 6)


  0%|          | 0/5 [00:00<?, ?it/s]

v (6, 10)


  0%|          | 0/5 [00:00<?, ?it/s]

v (7, 2)


  0%|          | 0/8 [00:00<?, ?it/s]

v (7, 6)


  0%|          | 0/8 [00:00<?, ?it/s]

v (7, 11)


  0%|          | 0/8 [00:00<?, ?it/s]

v (8, 1)


  0%|          | 0/11 [00:00<?, ?it/s]

v (8, 6)


  0%|          | 0/11 [00:00<?, ?it/s]

v (8, 8)


  0%|          | 0/11 [00:00<?, ?it/s]

v (8, 9)


  0%|          | 0/11 [00:00<?, ?it/s]

v (8, 11)


  0%|          | 0/11 [00:00<?, ?it/s]

v (9, 1)


  0%|          | 0/16 [00:00<?, ?it/s]

v (9, 5)


  0%|          | 0/16 [00:00<?, ?it/s]

v (9, 7)


  0%|          | 0/16 [00:00<?, ?it/s]

In [36]:
head_to_head_adjList = {}
for head_type in ["q", "k", "v"]:
    for head in heads_not_ablate:
        result = qkv_to_HH[head_type][head]
        filtered_indices = (result < 0.8) & (result != 0.0)
        rows, cols = filtered_indices.nonzero(as_tuple=True)
        sender_nodes = list(zip(rows.tolist(), cols.tolist()))
        head_with_type = head + (head_type,)
        head_to_head_adjList[head_with_type] = sender_nodes

## mlp to mlp

In [37]:
mlp_to_mlp_results = {}

# for layer in range(11, 0, -1):
for layer in reversed(mlps_not_ablate):
    print(layer)
    model.reset_hooks()
    model = add_mean_ablation_hook_MLP_head(model, dataset_2, heads_not_ablate, mlps_not_ablate)
    result = circ_path_patch_MLPs_to_MLPs(
        mlp_circuit = mlps_not_ablate,
        receiver_layers = [layer],
        model = model,
        patching_metric = ioi_metric_3
    )
    mlp_to_mlp_results[layer] = result

11
10
9
8
7
6
5
4
3
2
1
0


In [38]:
mlp_to_mlp_adjList = {}
for mlp in mlps_not_ablate:
    result = mlp_to_mlp_results[mlp]
    filtered_indices = (result < 0.80) & (result != 0.0)
    filtered_indices = filtered_indices.nonzero(as_tuple=True)[0]
    mlp_to_mlp_adjList[mlp] = filtered_indices.tolist()

## head to mlp

In [39]:
head_to_mlp_results = {}

# for layer in range(11, 0, -1):
for layer in reversed(mlps_not_ablate):
    print(layer)
    model.reset_hooks()
    model = add_mean_ablation_hook_MLP_head(model, dataset_2, heads_not_ablate, mlps_not_ablate)
    result = circ_path_patch_head_to_mlp(
        circuit = heads_not_ablate,
        receiver_layers = [layer],
        model = model,
        patching_metric = ioi_metric_3
    )
    head_to_mlp_results[layer] = result

11


  0%|          | 0/19 [00:00<?, ?it/s]

10


  0%|          | 0/19 [00:00<?, ?it/s]

9


  0%|          | 0/16 [00:00<?, ?it/s]

8


  0%|          | 0/11 [00:00<?, ?it/s]

7


  0%|          | 0/8 [00:00<?, ?it/s]

6


  0%|          | 0/5 [00:00<?, ?it/s]

5


  0%|          | 0/4 [00:00<?, ?it/s]

4


  0%|          | 0/2 [00:00<?, ?it/s]

3


  0%|          | 0/2 [00:00<?, ?it/s]

2


  0%|          | 0/2 [00:00<?, ?it/s]

1


  0%|          | 0/1 [00:00<?, ?it/s]

0


0it [00:00, ?it/s]

In [40]:
head_to_mlp_adjList = {}
for layer in mlps_not_ablate:
    result = head_to_mlp_results[layer]
    filtered_indices = (result < 0.8) & (result != 0.0)
    rows, cols = filtered_indices.nonzero(as_tuple=True)
    sender_nodes = list(zip(rows.tolist(), cols.tolist()))
    head_to_mlp_adjList[layer] = sender_nodes

## mlp to head

In [41]:
qkv_mlp_to_HH = {} # qkv to dict

for head_type in ["q", "k", "v"]:
    mlp_to_head_results = {}
    for head in heads_not_ablate:
        print(head_type, head)
        model.reset_hooks()
        model = add_mean_ablation_hook_MLP_head(model, dataset_2, heads_not_ablate, mlps_not_ablate)

        result = circ_path_patch_mlp_to_head(
            mlp_circuit = mlps_not_ablate,
            receiver_heads = [head],
            receiver_input = head_type,
            model = model,
            patching_metric = ioi_metric_3
        )
        mlp_to_head_results[head] = result
    qkv_mlp_to_HH[head_type] = mlp_to_head_results

q (0, 1)
q (1, 5)
q (4, 4)
q (4, 10)
q (5, 8)
q (6, 1)
q (6, 6)
q (6, 10)
q (7, 2)
q (7, 6)
q (7, 11)
q (8, 1)
q (8, 6)
q (8, 8)
q (8, 9)
q (8, 11)
q (9, 1)
q (9, 5)
q (9, 7)
k (0, 1)
k (1, 5)
k (4, 4)
k (4, 10)
k (5, 8)
k (6, 1)
k (6, 6)
k (6, 10)
k (7, 2)
k (7, 6)
k (7, 11)
k (8, 1)
k (8, 6)
k (8, 8)
k (8, 9)
k (8, 11)
k (9, 1)
k (9, 5)
k (9, 7)
v (0, 1)
v (1, 5)
v (4, 4)
v (4, 10)
v (5, 8)
v (6, 1)
v (6, 6)
v (6, 10)
v (7, 2)
v (7, 6)
v (7, 11)
v (8, 1)
v (8, 6)
v (8, 8)
v (8, 9)
v (8, 11)
v (9, 1)
v (9, 5)
v (9, 7)


In [42]:
mlp_to_head_adjList = {}
for head_type in ["q", "k", "v"]:
    for head in heads_not_ablate:
        result = qkv_mlp_to_HH[head_type][head]
        filtered_indices = (result < 0.8) & (result != 0.0)
        filtered_indices = filtered_indices.nonzero(as_tuple=True)[0]
        head_with_type = head + (head_type,)
        mlp_to_head_adjList[head_with_type] = filtered_indices.tolist()

# save graph files to free up memory

In [43]:
import pickle
from google.colab import files

In [44]:
with open(task + "_head_to_head_results.pkl", "wb") as file:
    pickle.dump(head_to_head_results, file)
files.download(task + "_head_to_head_results.pkl")

with open(task + "_mlp_to_mlp_results.pkl", "wb") as file:
    pickle.dump(mlp_to_mlp_results, file)
files.download(task + "_mlp_to_mlp_results.pkl")

with open(task + "_head_to_mlp_results.pkl", "wb") as file:
    pickle.dump(head_to_mlp_results, file)
files.download(task + "_head_to_mlp_results.pkl")

with open(task + "_mlp_to_head_results.pkl", "wb") as file:
    pickle.dump(mlp_to_head_results, file)
files.download(task + "_mlp_to_head_results.pkl")

del(head_to_head_results)
del(mlp_to_mlp_results)
del(head_to_mlp_results)
del(mlp_to_head_results)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

# resid post

## head to resid

In [45]:
def get_path_patch_head_to_final_resid_post(
    circuit: List[Tuple[int, int]],
    model: HookedTransformer,
    patching_metric: Callable,
    new_dataset: IOIDataset = dataset_2,
    orig_dataset: IOIDataset = dataset_1,
) -> Float[Tensor, "layer head"]:
    '''
    Performs path patching (see algorithm in appendix B of IOI paper), with:

        sender head = (each head, looped through, one at a time)
        receiver node = final value of residual stream

    Returns:
        tensor of metric values for every possible sender head
    '''
    # model.reset_hooks()
    results = t.zeros(model.cfg.n_layers, model.cfg.n_heads, device="cuda", dtype=t.float32)

    resid_post_hook_name = utils.get_act_name("resid_post", model.cfg.n_layers - 1)
    resid_post_name_filter = lambda name: name == resid_post_hook_name


    # ========== Step 1 ==========
    # Gather activations on x_orig and x_new

    # Note the use of names_filter for the run_with_cache function. Using it means we
    # only cache the things we need (in this case, just attn head outputs).
    z_name_filter = lambda name: name.endswith("z")

    _, new_cache = model.run_with_cache(
        new_dataset.toks,
        names_filter=z_name_filter,
        return_type=None
    )

    _, orig_cache = model.run_with_cache(
        orig_dataset.toks,
        names_filter=z_name_filter,
        return_type=None
    )

    # Looping over every possible sender head (the receiver is always the final resid_post)
    # Note use of itertools (gives us a smoother progress bar)
    # for (sender_layer, sender_head) in tqdm(list(itertools.product(range(model.cfg.n_layers), range(model.cfg.n_heads)))):

    for (sender_layer, sender_head) in tqdm(circuit):

        # ========== Step 2 ==========
        # Run on x_orig, with sender head patched from x_new, every other head frozen

        hook_fn = partial(
            patch_or_freeze_head_vectors,
            new_cache=new_cache,
            orig_cache=orig_cache,
            head_to_patch=(sender_layer, sender_head),
        )
        model.add_hook(z_name_filter, hook_fn)

        _, patched_cache = model.run_with_cache(
            orig_dataset.toks,
            names_filter=resid_post_name_filter,
            return_type=None
        )

        assert set(patched_cache.keys()) == {resid_post_hook_name}

        # ========== Step 3 ==========
        # Unembed the final residual stream value, to get our patched logits

        patched_logits = model.unembed(model.ln_final(patched_cache[resid_post_hook_name]))

        # Save the results
        results[sender_layer, sender_head] = patching_metric(patched_logits)

    return results

In [46]:
model.reset_hooks()
model = add_mean_ablation_hook_MLP_head(model, dataset_2, heads_not_ablate, mlps_not_ablate)

path_patch_head_to_final_resid_post = get_path_patch_head_to_final_resid_post(heads_not_ablate, model, ioi_metric_3)

  0%|          | 0/19 [00:00<?, ?it/s]

In [47]:
path_patch_head_to_final_resid_post.size()

torch.Size([12, 12])

In [48]:
heads_to_resid = {}
result = path_patch_head_to_final_resid_post
filtered_indices = (result < 0.8) & (result != 0.0)
rows, cols = filtered_indices.nonzero(as_tuple=True)
heads_to_resid['resid'] = list(zip(rows.tolist(), cols.tolist()))

## mlp to resid

In [49]:
def get_path_patch_mlp_to_final_resid_post(
    mlp_circuit: List[int],
    model: HookedTransformer,
    patching_metric: Callable,
    new_dataset: IOIDataset = dataset_2,
    orig_dataset: IOIDataset = dataset_1,
) -> Float[Tensor, "layer head"]:
    '''
    Performs path patching (see algorithm in appendix B of IOI paper), with:

        sender head = (each head, looped through, one at a time)
        receiver node = final value of residual stream

    Returns:
        tensor of metric values for every possible sender head
    '''
    # model.reset_hooks()
    results = t.zeros(model.cfg.n_layers, device="cuda", dtype=t.float32) #model.cfg.n_heads,

    resid_post_hook_name = utils.get_act_name("resid_post", model.cfg.n_layers - 1)
    resid_post_name_filter = lambda name: name == resid_post_hook_name


    # ========== Step 1 ==========
    # Gather activations on x_orig and x_new

    # Note the use of names_filter for the run_with_cache function. Using it means we
    # only cache the things we need (in this case, just attn head outputs).
    z_name_filter = lambda name: name.endswith(("z", "mlp_out"))

    _, new_cache = model.run_with_cache(
        new_dataset.toks,
        names_filter=z_name_filter,
        return_type=None
    )

    _, orig_cache = model.run_with_cache(
        orig_dataset.toks,
        names_filter=z_name_filter,
        return_type=None
    )

    # Looping over every possible sender head (the receiver is always the final resid_post)
    # Note use of itertools (gives us a smoother progress bar)
    # for (sender_layer, sender_head) in tqdm(list(itertools.product(range(model.cfg.n_layers), range(model.cfg.n_heads)))):

    # for (sender_layer, sender_head) in tqdm(circuit):
    for sender_layer in mlp_circuit:

        # ========== Step 2 ==========
        # Run on x_orig, with sender head patched from x_new, every other head frozen

        hook_fn = partial(
            # patch_or_freeze_head_vectors,
            patch_or_freeze_mlp_vectors,
            new_cache=new_cache,
            orig_cache=orig_cache,
            # head_to_patch=(sender_layer, sender_head),
            layer_to_patch = sender_layer # an int
        )
        model.add_hook(z_name_filter, hook_fn)

        _, patched_cache = model.run_with_cache(
            orig_dataset.toks,
            names_filter=resid_post_name_filter,
            return_type=None
        )

        assert set(patched_cache.keys()) == {resid_post_hook_name}

        # ========== Step 3 ==========
        # Unembed the final residual stream value, to get our patched logits

        patched_logits = model.unembed(model.ln_final(patched_cache[resid_post_hook_name]))

        # Save the results
        # results[sender_layer, sender_head] = patching_metric(patched_logits)
        results[sender_layer] = patching_metric(patched_logits)

    return results

In [50]:
model.reset_hooks()
model = add_mean_ablation_hook_MLP_head(model, dataset_2, heads_not_ablate, mlps_not_ablate)

path_patch_mlp_to_final_resid_post = get_path_patch_mlp_to_final_resid_post(mlps_not_ablate, model, ioi_metric_3)

In [51]:
path_patch_mlp_to_final_resid_post.size()

torch.Size([12])

In [52]:
mlps_to_resid = {}
result = path_patch_mlp_to_final_resid_post
filtered_indices = (result < 0.8) & (result != 0.0)
filtered_indices = filtered_indices.nonzero(as_tuple=True)[0]
mlps_to_resid['resid'] = filtered_indices.tolist()

# filter out nodes with no ingoing edges

In [53]:
head_to_head_adjList = {node: neighbors for node, neighbors in head_to_head_adjList.items() if neighbors}

In [54]:
mlp_to_head_adjList = {node: neighbors for node, neighbors in mlp_to_head_adjList.items() if neighbors}

# save graph files

In [55]:
# import pickle

# task = "numerals"

with open(task + "_head_to_head_adjList.pkl", "wb") as file:
    pickle.dump(head_to_head_adjList, file)
files.download(task + "_head_to_head_adjList.pkl")

with open(task + "_mlp_to_mlp_adjList.pkl", "wb") as file:
    pickle.dump(mlp_to_mlp_adjList, file)
files.download(task + "_mlp_to_mlp_adjList.pkl")

with open(task + "_head_to_mlp_adjList.pkl", "wb") as file:
    pickle.dump(head_to_mlp_adjList, file)
files.download(task + "_head_to_mlp_adjList.pkl")

with open(task + "_mlp_to_head_adjList.pkl", "wb") as file:
    pickle.dump(mlp_to_head_adjList, file)
files.download(task + "_mlp_to_head_adjList.pkl")

with open(task + "_heads_to_resid.pkl", "wb") as file:
    pickle.dump(heads_to_resid, file)
files.download(task + "_heads_to_resid.pkl")

with open(task + "_mlps_to_resid.pkl", "wb") as file:
    pickle.dump(mlps_to_resid, file)
files.download(task + "_mlps_to_resid.pkl")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [56]:
with open(task + "_heads_to_resid_results.pkl", "wb") as file:
    pickle.dump(path_patch_head_to_final_resid_post, file)
files.download(task + "_heads_to_resid_results.pkl")

with open(task + "_mlps_to_resid_results.pkl", "wb") as file:
    pickle.dump(path_patch_mlp_to_final_resid_post, file)
files.download(task + "_mlps_to_resid_results.pkl")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

# graph plot

## plot qkv

In [63]:
from graphviz import Digraph, Source
from IPython.display import display
from google.colab import files

def plot_graph_adjacency_qkv(head_to_head_adjList, mlp_to_mlp_adjList, head_to_mlp_adjList,
                             mlp_to_head_adjList, heads_to_resid, mlps_to_resid,
                             filename="circuit_graph", highlighted_nodes=None):
    dot = Digraph()
    # dot.attr(ranksep='0.45', nodesep='0.11')  # vert height- ranksep, nodesep- w
    dot.attr(ranksep='0.3', nodesep='0.05')  # vert height- ranksep, nodesep- w

    dot.node('resid_post', color="#ADD8E6", style='filled')

    for node in mlp_to_mlp_adjList.keys():
        sender_name = "MLP " + str(node)
        dot.node(sender_name, color="#ADD8E6", style='filled')

    for node in head_to_head_adjList.keys():
        sender_name = f"{node[0]} , {node[1]} {node[2]}"
        dot.node(sender_name, color="#ADD8E6", style='filled')
        sender_name = f"{node[0]} , {node[1]}"
        dot.node(sender_name, color="#ADD8E6", style='filled')

    edges_added = []
    # for every q k v node, plot an edge to output node
    for node in head_to_head_adjList.keys():
        sender_name = f"{node[0]} , {node[1]} {node[2]}"
        receiver_name = f"{node[0]} , {node[1]}"
        dot.edge(sender_name, receiver_name, color = 'blue')
        edges_added.append((sender_name, receiver_name))

    for node in mlp_to_head_adjList.keys():
        sender_name = f"{node[0]} , {node[1]} {node[2]}"
        dot.node(sender_name, color="#ADD8E6", style='filled')
        sender_name = f"{node[0]} , {node[1]}"
        dot.node(sender_name, color="#ADD8E6", style='filled')

    # for every q k v node, plot an edge to output node
    for node in mlp_to_head_adjList.keys():
        sender_name = f"{node[0]} , {node[1]} {node[2]}"
        receiver_name = f"{node[0]} , {node[1]}"
        if (sender_name, receiver_name) not in edges_added:
            dot.edge(sender_name, receiver_name, color = 'blue')

    def loop_adjList(adjList):
        for end_node, start_nodes_list in adjList.items():
            if isinstance(end_node, int):
                receiver_name = "MLP " + str(end_node)
            elif isinstance(end_node, tuple):
                if len(end_node) == 3:
                    receiver_name = f"{end_node[0]} , {end_node[1]} {end_node[2]}"
                elif len(end_node) == 2:
                    receiver_name = f"{end_node[0]} , {end_node[1]}"
            else:
                receiver_name = 'resid_post'
            for start in start_nodes_list:
                if isinstance(start, int):
                    sender_name = "MLP " + str(start)
                elif isinstance(start, tuple):
                    if len(start) == 3:
                        sender_name = f"{start[0]} , {start[1]} {start[2]}"
                    elif len(start) == 2:
                        sender_name = f"{start[0]} , {start[1]}"
                dot.node(sender_name, color="#ADD8E6", style='filled')
                dot.node(receiver_name, color="#ADD8E6", style='filled')
                dot.edge(sender_name, receiver_name, color = 'blue')

    loop_adjList(head_to_head_adjList)
    loop_adjList(mlp_to_mlp_adjList)
    loop_adjList(head_to_mlp_adjList)
    loop_adjList(mlp_to_head_adjList)
    loop_adjList(heads_to_resid)
    loop_adjList(mlps_to_resid)

    # Display the graph in Colab
    # display(Source(dot.source))

    # Save the graph to a file
    dot.format = 'png'  # You can change this to 'pdf', 'png', etc. based on your needs
    dot.render(filename)
    files.download(filename + ".png")

In [64]:
plot_graph_adjacency_qkv(head_to_head_adjList, mlp_to_mlp_adjList, head_to_mlp_adjList,
                         mlp_to_head_adjList, heads_to_resid, mlps_to_resid, filename= task+"_qkv_circ")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

## rewrite no qkv fn

In [61]:
def plot_graph_adjacency(head_to_head_adjList, mlp_to_mlp_adjList, head_to_mlp_adjList,
                             mlp_to_head_adjList, heads_to_resid, mlps_to_resid,
                             filename="circuit_graph", highlighted_nodes=None):
    dot = Digraph()
    dot.attr(ranksep='0.3', nodesep='0.05')  # vert height- ranksep, nodesep- w

    edges_added = [] # do this bc when no qkv, multiple edges
    def loop_adjList(adjList):
        for end_node, start_nodes_list in adjList.items():
            if isinstance(end_node, int):
                receiver_name = "MLP " + str(end_node)
            elif isinstance(end_node, tuple):
                receiver_name = f"{end_node[0]} , {end_node[1]}"
            else:
                receiver_name = 'resid_post'
            for start in start_nodes_list:
                if isinstance(start, int):
                    sender_name = "MLP " + str(start)
                elif isinstance(start, tuple):
                    sender_name = f"{start[0]} , {start[1]}"
                dot.node(sender_name, color="#ADD8E6", style='filled')
                dot.node(receiver_name, color="#ADD8E6", style='filled')
                if (sender_name, receiver_name) not in edges_added:
                    dot.edge(sender_name, receiver_name, color = 'blue')
                    edges_added.append((sender_name, receiver_name))

    loop_adjList(head_to_head_adjList)
    loop_adjList(mlp_to_mlp_adjList)
    loop_adjList(head_to_mlp_adjList)
    loop_adjList(mlp_to_head_adjList)
    loop_adjList(heads_to_resid)
    loop_adjList(mlps_to_resid)

    # Save the graph to a file
    dot.format = 'png'  # You can change this to 'pdf', 'png', etc. based on your needs
    dot.render(filename)
    files.download(filename + ".png")

In [62]:
plot_graph_adjacency(head_to_head_adjList, mlp_to_mlp_adjList, head_to_mlp_adjList,
                         mlp_to_head_adjList, heads_to_resid, mlps_to_resid, filename=task+"_no_qkv_circ")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>