In [None]:
import os
import sys
sys.path.append('../Automatic-Circuit-Discovery/')
sys.path.append('..')
import re

import IPython
ipython = IPython.get_ipython()
# get_ipython() returns None when no IPython shell is active (e.g. the notebook
# was exported and is being run as a plain script), so only touch magics when
# a shell actually exists.
if ipython is not None:
    # `InteractiveShell.magic(...)` is deprecated; `run_line_magic(name, args)`
    # is the supported way to invoke a line magic programmatically.
    ipython.run_line_magic('load_ext', 'autoreload')
    ipython.run_line_magic('autoreload', '2')
import acdc
from acdc.TLACDCExperiment import TLACDCExperiment
from acdc.acdc_utils import TorchIndex, EdgeType
import numpy as np
import torch as t
from torch import Tensor
import einops
import itertools

from transformer_lens import HookedTransformer, ActivationCache

import tqdm.notebook as tqdm
import plotly
from rich import print as rprint
from rich.table import Table

from jaxtyping import Float, Bool
from typing import Callable, Tuple, Union, Dict, Optional

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
_device_name = 'cuda' if t.cuda.is_available() else 'cpu'
device = t.device(_device_name)
print(f'Device: {device}')

# Model Setup

In [None]:
# Load GPT-2 small on `device`. All three weight-preprocessing options
# (center_writing_weights, center_unembed, fold_ln) are explicitly disabled.
model = HookedTransformer.from_pretrained(
    'gpt2-small',
    center_writing_weights=False,
    center_unembed=False,
    fold_ln=False,
    device=device,
)
# Switch on the optional hook points. The experiment log printed further down
# lists hook names such as `hook_mlp_in`, `hook_q_input` / `hook_k_input` /
# `hook_v_input` and `attn.hook_result`; presumably those only exist once these
# flags are enabled — TODO confirm against the TransformerLens docs.
model.set_use_hook_mlp_in(True)
model.set_use_split_qkv_input(True)
model.set_use_attn_result(True)

# Dataset Setup

In [None]:
from ioi_dataset import IOIDataset, format_prompt, make_table
# Number of prompts requested from IOIDataset (passed as `N` below).
N = 25
# Clean dataset: built with the model's own tokenizer, no BOS token prepended,
# and a fixed seed so the generated prompts are the same on every run.
clean_dataset = IOIDataset(
    prompt_type='mixed',
    N=N,
    tokenizer=model.tokenizer,
    prepend_bos=False,
    seed=1,
    device=device
)
# Corrupted counterpart derived from the clean prompts. The flip specification
# string is interpreted by IOIDataset.gen_flipped_prompts; presumably it swaps
# the names for unrelated ones (the "ABC" distribution named in the table
# below) — TODO confirm against ioi_dataset.py.
corr_dataset = clean_dataset.gen_flipped_prompts('ABC->XYZ, BAB->XYZ')

# Show each clean IOI prompt next to its subject / indirect-object tokens and
# the corresponding corrupted ("ABC") prompt.
make_table(
  colnames = ["IOI prompt", "IOI subj", "IOI indirect obj", "ABC prompt"],
  cols = [
    map(format_prompt, clean_dataset.sentences),
    model.to_string(clean_dataset.s_tokenIDs).split(),
    model.to_string(clean_dataset.io_tokenIDs).split(),
    # The "ABC prompt" column must come from the corrupted dataset; it
    # previously repeated clean_dataset.sentences, duplicating column 1.
    map(format_prompt, corr_dataset.sentences),
  ],
  title = "Sentences from IOI vs ABC distribution",
)

# Metric Setup

In [None]:
def ave_logit_diff(
    logits: Float[Tensor, 'batch seq d_vocab'],
    ioi_dataset: IOIDataset,
    per_prompt: bool = False
):
    '''
    Logit difference between the indirect-object token and the subject token.

    For every prompt in the batch, the logits are read at the position stored in
    `ioi_dataset.word_idx['end']` and the subject-token logit is subtracted from
    the indirect-object-token logit.

    Args:
        logits: model output logits for the batch.
        ioi_dataset: dataset supplying `word_idx['end']`, `io_tokenIDs` and
            `s_tokenIDs` for the same batch.
        per_prompt: when True, return one difference per prompt; otherwise
            return the mean over the batch.
    '''
    batch_rows = range(logits.size(0))
    end_positions = ioi_dataset.word_idx['end']

    # Logit assigned to the indirect object (the correct answer) ...
    correct_logits = logits[batch_rows, end_positions, ioi_dataset.io_tokenIDs]
    # ... and to the subject (the incorrect answer), at the same position.
    incorrect_logits = logits[batch_rows, end_positions, ioi_dataset.s_tokenIDs]

    differences = correct_logits - incorrect_logits
    if per_prompt:
        return differences
    return differences.mean()

# Baseline forward passes (no gradients needed): record the average logit
# difference of the unpatched model on the clean and on the corrupted prompts.
with t.no_grad():
    clean_logits = model(clean_dataset.toks)
    clean_logit_diff = ave_logit_diff(clean_logits, clean_dataset).item()

    corrupt_logits = model(corr_dataset.toks)
    corrupt_logit_diff = ave_logit_diff(corrupt_logits, corr_dataset).item()

def ioi_metric(
    logits: Float[Tensor, "batch seq_len d_vocab"],
    corrupted_logit_diff: float = corrupt_logit_diff,
    clean_logit_diff: float = clean_logit_diff,
    ioi_dataset: IOIDataset = clean_dataset
):
    '''
    Average logit difference rescaled by the clean and corrupted baselines.

    The result is 0 when the logit difference equals `corrupted_logit_diff`
    and 1 when it equals `clean_logit_diff`.

    Note: the default baselines and dataset are bound when this function is
    defined, so re-run this cell after recomputing them.
    '''
    baseline_gap = clean_logit_diff - corrupted_logit_diff
    gain_over_corrupted = ave_logit_diff(logits, ioi_dataset) - corrupted_logit_diff
    return gain_over_corrupted / baseline_gap

def negative_ioi_metric(logits: Float[Tensor, "batch seq_len d_vocab"]):
    '''Sign-flipped `ioi_metric`, evaluated with its default baselines and dataset.'''
    metric_value = ioi_metric(logits)
    return -metric_value
    
# Sanity check: evaluate the normalised metric on the unpatched clean and
# corrupted runs, each against its own dataset.
with t.no_grad():
    clean_metric = ioi_metric(
        clean_logits,
        corrupted_logit_diff=corrupt_logit_diff,
        clean_logit_diff=clean_logit_diff,
        ioi_dataset=clean_dataset,
    )
    corrupt_metric = ioi_metric(
        corrupt_logits,
        corrupted_logit_diff=corrupt_logit_diff,
        clean_logit_diff=clean_logit_diff,
        ioi_dataset=corr_dataset,
    )

print(f'Clean direction: {clean_logit_diff}, Corrupt direction: {corrupt_logit_diff}')
print(f'Clean metric: {clean_metric}, Corrupt metric: {corrupt_metric}')

# Run Experiment

In [7]:
from ACDCPPExperiment import ACDCPPExperiment
# Candidate threshold grid for a sweep.
# NOTE(review): THRESHOLDS is not used in this cell; a single hard-coded
# threshold [0.04] is passed to the experiment instead.
THRESHOLDS = np.arange(0.005, 0.155, 0.005)
# NOTE(review): the run name says 'noabs_value' but attr_absolute_val=True is
# passed below — confirm which of the two is intended.
RUN_NAME = 'noabs_value'
# Remove any hooks left on the model before the experiment installs its own.
model.reset_hooks()
# Positional arguments, in order: model, clean tokens, corrupted tokens, the
# metric, its negated variant, and the list of thresholds. How ACDCPPExperiment
# interprets each of them (and the keyword options) is defined in
# ACDCPPExperiment.py, which is not visible here — TODO confirm.
acdcpp_exp = ACDCPPExperiment(model,
                              clean_dataset.toks,
                              corr_dataset.toks,
                              ioi_metric,
                              negative_ioi_metric,
                              [0.04], # THRESHOLDS[:1],
                              run_name=RUN_NAME,
                              verbose=True,
                              attr_absolute_val=True,
                              save_graphs_after=0.07,
                              pruning_mode = "edge",
                              no_pruned_nodes_attr=1, # TODO add warni[ng?] -- original comment is truncated
                            #   positions = list(range(clean_dataset.toks.shape[1])),
                             )
# Run the experiment; with verbose=True it prints the per-edge
# "NOT PRUNING ... with attribution ..." log shown in this cell's output.
pruned_heads, num_passes, pruned_attrs = acdcpp_exp.run()

Set up model hooks


  0%|          | 0/1 [00:00<?, ?it/s]

100%|██████████| 12/12 [00:00<00:00, 28.06it/s]


dict_keys(['blocks.11.hook_resid_post', 'blocks.11.hook_mlp_out', 'blocks.11.hook_mlp_in', 'blocks.11.attn.hook_result', 'blocks.11.attn.hook_q', 'blocks.11.hook_q_input', 'blocks.11.attn.hook_k', 'blocks.11.hook_k_input', 'blocks.11.attn.hook_v', 'blocks.11.hook_v_input', 'blocks.10.hook_mlp_out', 'blocks.10.hook_mlp_in', 'blocks.10.attn.hook_result', 'blocks.10.attn.hook_q', 'blocks.10.hook_q_input', 'blocks.10.attn.hook_k', 'blocks.10.hook_k_input', 'blocks.10.attn.hook_v', 'blocks.10.hook_v_input', 'blocks.9.hook_mlp_out', 'blocks.9.hook_mlp_in', 'blocks.9.attn.hook_result', 'blocks.9.attn.hook_q', 'blocks.9.hook_q_input', 'blocks.9.attn.hook_k', 'blocks.9.hook_k_input', 'blocks.9.attn.hook_v', 'blocks.9.hook_v_input', 'blocks.8.hook_mlp_out', 'blocks.8.hook_mlp_in', 'blocks.8.attn.hook_result', 'blocks.8.attn.hook_q', 'blocks.8.hook_q_input', 'blocks.8.attn.hook_k', 'blocks.8.hook_k_input', 'blocks.8.attn.hook_v', 'blocks.8.hook_v_input', 'blocks.7.hook_mlp_out', 'blocks.7.hook_ml



NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.11.attn.hook_result', index=[:, :, 9]) downstream_component=ModelComponent(hook_point_name='blocks.11.hook_resid_post', index=[:]) with attribution 0.07791779935359955
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.11.attn.hook_result', index=[:, :, 2]) downstream_component=ModelComponent(hook_point_name='blocks.11.hook_resid_post', index=[:]) with attribution 0.20973780751228333
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.11.attn.hook_result', index=[:, :, 1]) downstream_component=ModelComponent(hook_point_name='blocks.11.hook_resid_post', index=[:]) with attribution 0.06629768759012222
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.10.attn.hook_result', index=[:, :, 7]) downstream_component=ModelComponent(hook_point_name='blocks.11.hook_resid_post', index=[:]) with attribution 1.4627970457077026
NOT PRUNING upstream_component=ModelComponent(hoo



NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.7.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.11.hook_v_input', index=[:, :, 10]) with attribution 0.05485761538147926
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.7.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.11.hook_v_input', index=[:, :, 10]) with attribution 0.05485761538147926
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.7.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.11.hook_v_input', index=[:, :, 10]) with attribution 0.05485761538147926
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.7.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.11.hook_v_input', index=[:, :, 10]) with attribution 0.05485761538147926
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.6.hook



NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.0.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.11.hook_v_input', index=[:, :, 2]) with attribution 0.12173828482627869
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.0.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.11.hook_v_input', index=[:, :, 2]) with attribution 0.12173828482627869
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.0.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.11.hook_v_input', index=[:, :, 2]) with attribution 0.12173828482627869
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.0.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.11.hook_v_input', index=[:, :, 2]) with attribution 0.12173828482627869
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.0.hook_mlp



NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.8.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.11.hook_k_input', index=[:, :, 10]) with attribution 0.12421857565641403
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.8.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.11.hook_k_input', index=[:, :, 10]) with attribution 0.12421857565641403
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.7.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.11.hook_k_input', index=[:, :, 10]) with attribution 0.10832288861274719
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.7.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.11.hook_k_input', index=[:, :, 10]) with attribution 0.10832288861274719
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.7.hook



NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.9.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.11.hook_k_input', index=[:, :, 2]) with attribution 0.07110200822353363
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.5.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.11.hook_k_input', index=[:, :, 2]) with attribution 0.050725057721138
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.5.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.11.hook_k_input', index=[:, :, 2]) with attribution 0.050725057721138
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.5.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.11.hook_k_input', index=[:, :, 2]) with attribution 0.050725057721138
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.5.hook_mlp_out',



NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.10.attn.hook_result', index=[:, :, 10]) downstream_component=ModelComponent(hook_point_name='blocks.11.hook_q_input', index=[:, :, 10]) with attribution 0.24477745592594147
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.10.attn.hook_result', index=[:, :, 7]) downstream_component=ModelComponent(hook_point_name='blocks.11.hook_q_input', index=[:, :, 10]) with attribution 0.8063588738441467
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.10.attn.hook_result', index=[:, :, 6]) downstream_component=ModelComponent(hook_point_name='blocks.11.hook_q_input', index=[:, :, 10]) with attribution 0.06538743525743484
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.10.attn.hook_result', index=[:, :, 1]) downstream_component=ModelComponent(hook_point_name='blocks.11.hook_q_input', index=[:, :, 10]) with attribution 0.123191237449646
NOT PRUNING upstream_component=Mod



NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.9.attn.hook_result', index=[:, :, 9]) downstream_component=ModelComponent(hook_point_name='blocks.11.hook_q_input', index=[:, :, 6]) with attribution 0.050604160875082016
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.9.attn.hook_result', index=[:, :, 6]) downstream_component=ModelComponent(hook_point_name='blocks.11.hook_q_input', index=[:, :, 6]) with attribution 0.05749812349677086
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.9.attn.hook_result', index=[:, :, 6]) downstream_component=ModelComponent(hook_point_name='blocks.11.hook_q_input', index=[:, :, 6]) with attribution 0.05749812349677086
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.8.attn.hook_result', index=[:, :, 10]) downstream_component=ModelComponent(hook_point_name='blocks.11.hook_q_input', index=[:, :, 6]) with attribution 0.058339521288871765
NOT PRUNING upstream_component=ModelC



NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.10.attn.hook_result', index=[:, :, 0]) downstream_component=ModelComponent(hook_point_name='blocks.11.hook_q_input', index=[:, :, 2]) with attribution 0.08677536249160767
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.9.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.11.hook_q_input', index=[:, :, 2]) with attribution 0.11440907418727875
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.9.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.11.hook_q_input', index=[:, :, 2]) with attribution 0.11440907418727875
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.9.attn.hook_result', index=[:, :, 9]) downstream_component=ModelComponent(hook_point_name='blocks.11.hook_q_input', index=[:, :, 2]) with attribution 0.11296018213033676
NOT PRUNING upstream_component=ModelComponent(hook_point_na



NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.10.attn.hook_result', index=[:, :, 7]) downstream_component=ModelComponent(hook_point_name='blocks.10.hook_mlp_in', index=[:]) with attribution 0.045011505484580994
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.9.attn.hook_result', index=[:, :, 9]) downstream_component=ModelComponent(hook_point_name='blocks.10.hook_mlp_in', index=[:]) with attribution 0.05171195790171623
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.7.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.10.hook_mlp_in', index=[:]) with attribution 0.05159616097807884
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.0.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.10.hook_mlp_in', index=[:]) with attribution 0.16856688261032104




NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.6.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.10.hook_v_input', index=[:, :, 7]) with attribution 0.04414111748337746
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.6.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.10.hook_v_input', index=[:, :, 7]) with attribution 0.04414111748337746
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.6.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.10.hook_v_input', index=[:, :, 7]) with attribution 0.04414111748337746
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.6.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.10.hook_v_input', index=[:, :, 7]) with attribution 0.04414111748337746
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.5.hook_mlp



NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.0.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.10.hook_v_input', index=[:, :, 2]) with attribution 0.07325278222560883
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.0.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.10.hook_v_input', index=[:, :, 2]) with attribution 0.07325278222560883
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.0.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.10.hook_v_input', index=[:, :, 2]) with attribution 0.07325278222560883
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.0.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.10.hook_v_input', index=[:, :, 2]) with attribution 0.07325278222560883
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.0.hook_mlp



NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.8.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.10.hook_k_input', index=[:, :, 7]) with attribution 0.07518402487039566
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.6.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.10.hook_k_input', index=[:, :, 7]) with attribution 0.0947979986667633
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.6.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.10.hook_k_input', index=[:, :, 7]) with attribution 0.0947979986667633
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.6.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.10.hook_k_input', index=[:, :, 7]) with attribution 0.0947979986667633
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.6.hook_mlp_ou



NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.9.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.10.hook_k_input', index=[:, :, 2]) with attribution 0.05885838344693184
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.0.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.10.hook_k_input', index=[:, :, 2]) with attribution 0.09489040076732635
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.0.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.10.hook_k_input', index=[:, :, 2]) with attribution 0.09489040076732635
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.0.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.10.hook_k_input', index=[:, :, 2]) with attribution 0.09489040076732635
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.0.hook_mlp



NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.9.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.10.hook_q_input', index=[:, :, 10]) with attribution 0.040135160088539124
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.9.attn.hook_result', index=[:, :, 9]) downstream_component=ModelComponent(hook_point_name='blocks.10.hook_q_input', index=[:, :, 10]) with attribution 0.04792698100209236
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.9.attn.hook_result', index=[:, :, 6]) downstream_component=ModelComponent(hook_point_name='blocks.10.hook_q_input', index=[:, :, 10]) with attribution 0.05384305119514465
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.8.attn.hook_result', index=[:, :, 6]) downstream_component=ModelComponent(hook_point_name='blocks.10.hook_q_input', index=[:, :, 10]) with attribution 0.08201702684164047
NOT PRUNING upstream_component=ModelComponent



NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.9.attn.hook_result', index=[:, :, 6]) downstream_component=ModelComponent(hook_point_name='blocks.10.hook_q_input', index=[:, :, 6]) with attribution 0.09389758110046387
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.8.attn.hook_result', index=[:, :, 10]) downstream_component=ModelComponent(hook_point_name='blocks.10.hook_q_input', index=[:, :, 6]) with attribution 0.10453678667545319
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.8.attn.hook_result', index=[:, :, 10]) downstream_component=ModelComponent(hook_point_name='blocks.10.hook_q_input', index=[:, :, 6]) with attribution 0.10453678667545319
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.8.attn.hook_result', index=[:, :, 6]) downstream_component=ModelComponent(hook_point_name='blocks.10.hook_q_input', index=[:, :, 6]) with attribution 0.11353877186775208
NOT PRUNING upstream_component=ModelCo



NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.9.attn.hook_result', index=[:, :, 6]) downstream_component=ModelComponent(hook_point_name='blocks.10.hook_q_input', index=[:, :, 2]) with attribution 0.07104332745075226
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.8.attn.hook_result', index=[:, :, 10]) downstream_component=ModelComponent(hook_point_name='blocks.10.hook_q_input', index=[:, :, 2]) with attribution 0.1123955249786377
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.8.attn.hook_result', index=[:, :, 10]) downstream_component=ModelComponent(hook_point_name='blocks.10.hook_q_input', index=[:, :, 2]) with attribution 0.1123955249786377
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.8.attn.hook_result', index=[:, :, 6]) downstream_component=ModelComponent(hook_point_name='blocks.10.hook_q_input', index=[:, :, 2]) with attribution 0.04692140594124794
NOT PRUNING upstream_component=ModelComp



NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.0.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.9.hook_v_input', index=[:, :, 9]) with attribution 0.13993321359157562
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.0.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.9.hook_v_input', index=[:, :, 9]) with attribution 0.13993321359157562
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.0.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.9.hook_v_input', index=[:, :, 9]) with attribution 0.13993321359157562
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.0.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.9.hook_v_input', index=[:, :, 9]) with attribution 0.13993321359157562
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.0.hook_mlp_out



NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.8.attn.hook_result', index=[:, :, 10]) downstream_component=ModelComponent(hook_point_name='blocks.9.hook_v_input', index=[:, :, 3]) with attribution 0.04508884251117706




NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.6.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.9.hook_k_input', index=[:, :, 9]) with attribution 0.05806328356266022
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.6.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.9.hook_k_input', index=[:, :, 9]) with attribution 0.05806328356266022
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.0.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.9.hook_k_input', index=[:, :, 9]) with attribution 0.04722271114587784
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.0.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.9.hook_k_input', index=[:, :, 9]) with attribution 0.04722271114587784
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.0.hook_mlp_out



NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.8.attn.hook_result', index=[:, :, 10]) downstream_component=ModelComponent(hook_point_name='blocks.9.hook_q_input', index=[:, :, 9]) with attribution 0.14157964289188385
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.7.attn.hook_result', index=[:, :, 9]) downstream_component=ModelComponent(hook_point_name='blocks.9.hook_q_input', index=[:, :, 9]) with attribution 0.16346758604049683
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.7.attn.hook_result', index=[:, :, 9]) downstream_component=ModelComponent(hook_point_name='blocks.9.hook_q_input', index=[:, :, 9]) with attribution 0.16346758604049683
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.7.attn.hook_result', index=[:, :, 3]) downstream_component=ModelComponent(hook_point_name='blocks.9.hook_q_input', index=[:, :, 9]) with attribution 0.07695141434669495
NOT PRUNING upstream_component=ModelCompone



NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.8.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.9.hook_q_input', index=[:, :, 4]) with attribution 0.06745767593383789
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.7.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.9.hook_q_input', index=[:, :, 4]) with attribution 0.045243315398693085
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.7.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.9.hook_q_input', index=[:, :, 4]) with attribution 0.045243315398693085
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.7.attn.hook_result', index=[:, :, 3]) downstream_component=ModelComponent(hook_point_name='blocks.9.hook_q_input', index=[:, :, 4]) with attribution 0.04067771136760712
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.7.



NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.5.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.8.hook_mlp_in', index=[:]) with attribution 0.051792435348033905
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.6.attn.hook_result', index=[:, :, 9]) downstream_component=ModelComponent(hook_point_name='blocks.8.hook_v_input', index=[:, :, 10]) with attribution 0.063717320561409
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.5.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.8.hook_v_input', index=[:, :, 10]) with attribution 0.07570300251245499
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.5.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.8.hook_v_input', index=[:, :, 10]) with attribution 0.07570300251245499
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.5.hook_ml



NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.3.attn.hook_result', index=[:, :, 0]) downstream_component=ModelComponent(hook_point_name='blocks.8.hook_v_input', index=[:, :, 6]) with attribution 0.09059502929449081
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.3.attn.hook_result', index=[:, :, 0]) downstream_component=ModelComponent(hook_point_name='blocks.8.hook_v_input', index=[:, :, 6]) with attribution 0.09059502929449081
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.3.attn.hook_result', index=[:, :, 0]) downstream_component=ModelComponent(hook_point_name='blocks.8.hook_v_input', index=[:, :, 6]) with attribution 0.09059502929449081
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.3.attn.hook_result', index=[:, :, 0]) downstream_component=ModelComponent(hook_point_name='blocks.8.hook_v_input', index=[:, :, 6]) with attribution 0.09059502929449081
NOT PRUNING upstream_component=ModelComponen



NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.6.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.8.hook_k_input', index=[:, :, 10]) with attribution 0.11518697440624237
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.6.attn.hook_result', index=[:, :, 6]) downstream_component=ModelComponent(hook_point_name='blocks.8.hook_k_input', index=[:, :, 10]) with attribution 0.04899803549051285
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.6.attn.hook_result', index=[:, :, 6]) downstream_component=ModelComponent(hook_point_name='blocks.8.hook_k_input', index=[:, :, 10]) with attribution 0.04899803549051285
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.5.attn.hook_result', index=[:, :, 9]) downstream_component=ModelComponent(hook_point_name='blocks.8.hook_k_input', index=[:, :, 10]) with attribution 0.08211716264486313
NOT PRUNING upstream_component=ModelComponent(hook



NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.7.attn.hook_result', index=[:, :, 9]) downstream_component=ModelComponent(hook_point_name='blocks.7.hook_mlp_in', index=[:]) with attribution 0.13920047879219055
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.6.attn.hook_result', index=[:, :, 9]) downstream_component=ModelComponent(hook_point_name='blocks.7.hook_mlp_in', index=[:]) with attribution 0.07414314895868301
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.6.attn.hook_result', index=[:, :, 6]) downstream_component=ModelComponent(hook_point_name='blocks.7.hook_mlp_in', index=[:]) with attribution 0.11392048746347427
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.5.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.7.hook_mlp_in', index=[:]) with attribution 0.10119417309761047
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.5.attn.hook_



NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.6.attn.hook_result', index=[:, :, 0]) downstream_component=ModelComponent(hook_point_name='blocks.7.hook_v_input', index=[:, :, 3]) with attribution 0.16227322816848755
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.5.attn.hook_result', index=[:, :, 5]) downstream_component=ModelComponent(hook_point_name='blocks.7.hook_v_input', index=[:, :, 3]) with attribution 0.06290838122367859
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.5.attn.hook_result', index=[:, :, 5]) downstream_component=ModelComponent(hook_point_name='blocks.7.hook_v_input', index=[:, :, 3]) with attribution 0.06290838122367859
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.4.attn.hook_result', index=[:, :, 3]) downstream_component=ModelComponent(hook_point_name='blocks.7.hook_v_input', index=[:, :, 3]) with attribution 0.08481753617525101
NOT PRUNING upstream_component=ModelComponen



NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.0.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.7.hook_k_input', index=[:, :, 3]) with attribution 0.040338583290576935
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.0.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.7.hook_k_input', index=[:, :, 3]) with attribution 0.040338583290576935
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.0.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.7.hook_k_input', index=[:, :, 3]) with attribution 0.040338583290576935
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.0.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.7.hook_k_input', index=[:, :, 3]) with attribution 0.040338583290576935
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.0.hook_mlp



NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.6.attn.hook_result', index=[:, :, 6]) downstream_component=ModelComponent(hook_point_name='blocks.7.hook_q_input', index=[:, :, 1]) with attribution 0.08352181315422058
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.5.attn.hook_result', index=[:, :, 9]) downstream_component=ModelComponent(hook_point_name='blocks.7.hook_q_input', index=[:, :, 1]) with attribution 0.0961487665772438
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.5.attn.hook_result', index=[:, :, 9]) downstream_component=ModelComponent(hook_point_name='blocks.7.hook_q_input', index=[:, :, 1]) with attribution 0.0961487665772438
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.5.attn.hook_result', index=[:, :, 5]) downstream_component=ModelComponent(hook_point_name='blocks.7.hook_q_input', index=[:, :, 1]) with attribution 0.18271636962890625
NOT PRUNING upstream_component=ModelComponent(



NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.4.attn.hook_result', index=[:, :, 3]) downstream_component=ModelComponent(hook_point_name='blocks.6.hook_v_input', index=[:, :, 0]) with attribution 0.04288163781166077
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.3.attn.hook_result', index=[:, :, 0]) downstream_component=ModelComponent(hook_point_name='blocks.6.hook_v_input', index=[:, :, 0]) with attribution 0.05708121880888939
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.3.attn.hook_result', index=[:, :, 0]) downstream_component=ModelComponent(hook_point_name='blocks.6.hook_v_input', index=[:, :, 0]) with attribution 0.05708121880888939
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.3.attn.hook_result', index=[:, :, 0]) downstream_component=ModelComponent(hook_point_name='blocks.6.hook_v_input', index=[:, :, 0]) with attribution 0.05708121880888939
NOT PRUNING upstream_component=ModelComponen



NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.5.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.6.hook_q_input', index=[:, :, 9]) with attribution 0.17691349983215332
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.5.attn.hook_result', index=[:, :, 5]) downstream_component=ModelComponent(hook_point_name='blocks.6.hook_q_input', index=[:, :, 9]) with attribution 0.4667830467224121
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.4.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.6.hook_q_input', index=[:, :, 9]) with attribution 0.2313012331724167
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.4.attn.hook_result', index=[:, :, 4]) downstream_component=ModelComponent(hook_point_name='blocks.6.hook_q_input', index=[:, :, 9]) with attribution 0.08709018677473068
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blo



NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.5.attn.hook_result', index=[:, :, 9]) downstream_component=ModelComponent(hook_point_name='blocks.5.hook_mlp_in', index=[:]) with attribution 0.13820289075374603
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.5.attn.hook_result', index=[:, :, 8]) downstream_component=ModelComponent(hook_point_name='blocks.5.hook_mlp_in', index=[:]) with attribution 0.11275483667850494
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.5.attn.hook_result', index=[:, :, 5]) downstream_component=ModelComponent(hook_point_name='blocks.5.hook_mlp_in', index=[:]) with attribution 0.6042169332504272
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.4.attn.hook_result', index=[:, :, 11]) downstream_component=ModelComponent(hook_point_name='blocks.5.hook_mlp_in', index=[:]) with attribution 0.33682382106781006
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.4.



NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.0.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.5.hook_k_input', index=[:, :, 9]) with attribution 0.04324762150645256
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.0.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.5.hook_k_input', index=[:, :, 9]) with attribution 0.04324762150645256
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.0.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.5.hook_k_input', index=[:, :, 9]) with attribution 0.04324762150645256
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.0.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.5.hook_k_input', index=[:, :, 9]) with attribution 0.04324762150645256
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.0.hook_mlp_out



NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.4.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.5.hook_q_input', index=[:, :, 5]) with attribution 0.07133670151233673
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.3.attn.hook_result', index=[:, :, 0]) downstream_component=ModelComponent(hook_point_name='blocks.5.hook_q_input', index=[:, :, 5]) with attribution 0.08435379713773727
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.3.attn.hook_result', index=[:, :, 0]) downstream_component=ModelComponent(hook_point_name='blocks.5.hook_q_input', index=[:, :, 5]) with attribution 0.08435379713773727
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.1.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.5.hook_q_input', index=[:, :, 5]) with attribution 0.051651667803525925
NOT PRUNING upstream_component=ModelComponent(hook_point_name='



NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.0.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.4.hook_q_input', index=[:, :, 4]) with attribution 0.05104396492242813
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.0.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.4.hook_q_input', index=[:, :, 4]) with attribution 0.05104396492242813
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.0.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.4.hook_q_input', index=[:, :, 4]) with attribution 0.05104396492242813
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.0.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.4.hook_q_input', index=[:, :, 4]) with attribution 0.05104396492242813
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.2.hook_mlp_out

100%|██████████| 445/445 [00:10<00:00, 41.59it/s] 


NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.0.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.3.hook_q_input', index=[:, :, 1]) with attribution 0.06044726073741913
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.0.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.3.hook_q_input', index=[:, :, 1]) with attribution 0.06044726073741913
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.2.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.3.hook_q_input', index=[:, :, 0]) with attribution 0.09276579320430756
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.1.hook_mlp_out', index=[:]) downstream_component=ModelComponent(hook_point_name='blocks.3.hook_q_input', index=[:, :, 0]) with attribution 0.10897935181856155
NOT PRUNING upstream_component=ModelComponent(hook_point_name='blocks.1.hook_mlp_out

  0%|          | 0/1 [00:28<?, ?it/s]

No edge 3366
No edge 3366
No edge 3366
No edge 3366





KeyboardInterrupt: 

# Save Data

In [None]:
import json

# --- Make the results JSON-friendly -----------------------------------------
# Slots 0 and 1 of every per-threshold entry are converted to plain lists,
# in place (presumably they are sets, which json cannot encode -- TODO confirm).
for head_entry in pruned_heads.values():
    for slot in (0, 1):
        head_entry[slot] = list(head_entry[slot])

# The attribution tables are keyed by (layer, head) pairs; json only accepts
# scalar object keys, so each pair is flattened into an 'L<layer>H<head>' label.
cleaned_attrs = {
    thresh: {
        f'L{layer}H{head}': attr
        for (layer, head), attr in attrs_by_head.items()
    }
    for thresh, attrs_by_head in pruned_attrs.items()
}

# --- Write one JSON file per result, all prefixed with RUN_NAME -------------
with open(f'{RUN_NAME}_pruned_heads.json', 'w') as heads_file:
    json.dump(pruned_heads, heads_file)

with open(f'{RUN_NAME}_num_passes.json', 'w') as passes_file:
    json.dump(num_passes, passes_file)

with open(f'{RUN_NAME}_pruned_attrs.json', 'w') as attrs_file:
    json.dump(cleaned_attrs, attrs_file)