### Code to setup InterpBench

In [1]:
!apt-get update -q && apt-get install -y --no-install-recommends libgl1-mesa-glx graphviz graphviz-dev

Get:1 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ InRelease [3,626 B]
Hit:2 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease
Hit:3 http://archive.ubuntu.com/ubuntu jammy InRelease
Get:4 http://archive.ubuntu.com/ubuntu jammy-updates InRelease [128 kB]
Get:5 http://security.ubuntu.com/ubuntu jammy-security InRelease [129 kB]
Hit:6 http://archive.ubuntu.com/ubuntu jammy-backports InRelease
Get:7 http://archive.ubuntu.com/ubuntu jammy-updates/restricted amd64 Packages [2,537 kB]
Get:8 http://archive.ubuntu.com/ubuntu jammy-updates/universe amd64 Packages [1,391 kB]
Hit:9 https://ppa.launchpadcontent.net/c2d4u.team/c2d4u4.0+/ubuntu jammy InRelease
Get:10 http://archive.ubuntu.com/ubuntu jammy-updates/main amd64 Packages [2,179 kB]
Hit:11 https://ppa.launchpadcontent.net/deadsnakes/ppa/ubuntu jammy InRelease
Hit:12 https://ppa.launchpadcontent.net/graphics-drivers/ppa/ubuntu jammy InRelease
Hit:13 https://ppa.launchpadcontent.net/ubuntugis

In [4]:
!git clone --recurse-submodules https://github.com/FlyingPumba/circuits-benchmark.git
%cd circuits-benchmark
!pip install -e .

Cloning into 'circuits-benchmark'...
remote: Enumerating objects: 4234, done.[K
remote: Counting objects: 100% (405/405), done.[K
remote: Compressing objects: 100% (182/182), done.[K
remote: Total 4234 (delta 277), reused 276 (delta 221), pack-reused 3829[K
Receiving objects: 100% (4234/4234), 16.91 MiB | 23.79 MiB/s, done.
Resolving deltas: 100% (2969/2969), done.
Submodule 'submodules/Automatic-Circuit-Discovery' (https://github.com/FlyingPumba/Automatic-Circuit-Discovery.git) registered for path 'submodules/Automatic-Circuit-Discovery'
Submodule 'submodules/iit' (https://github.com/cybershiptrooper/iit.git) registered for path 'submodules/iit'
Submodule 'submodules/tracr' (https://github.com/FlyingPumba/tracr.git) registered for path 'submodules/tracr'
Cloning into '/content/circuits-benchmark/submodules/Automatic-Circuit-Discovery'...
remote: Enumerating objects: 9802, done.        
remote: Counting objects: 100% (4014/4014), done.        
remote: Compressing objects: 100% (931

### Load model from InterpBench

In [1]:
!git lfs install
!git clone https://huggingface.co/cybershiptrooper/InterpBench

Updated Git hooks.
Git LFS initialized.
Cloning into 'InterpBench'...
remote: Enumerating objects: 215, done.[K
remote: Counting objects: 100% (211/211), done.[K
remote: Compressing objects: 100% (195/195), done.[K
remote: Total 215 (delta 67), reused 0 (delta 0), pack-reused 4 (from 1)[K
Receiving objects: 100% (215/215), 383.70 KiB | 7.67 MiB/s, done.
Resolving deltas: 100% (67/67), done.
Filtering content: 100% (55/55), 82.13 MiB | 24.25 MiB/s, done.


In [1]:
import pickle
import torch
from transformer_lens import HookedTransformerConfig, HookedTransformer
from transformer_lens import HookedTransformer

In [3]:
import circuits_benchmark.benchmark.cases.case_3 as case_3

task = case_3.Case3()
task_idx = task.get_index()

In [4]:
dir_name = f"InterpBench/{task_idx}"
cfg_dict = pickle.load(open(f"{dir_name}/ll_model_cfg.pkl", "rb"))
cfg = HookedTransformerConfig.from_dict(cfg_dict)
model = HookedTransformer(cfg)
weights = torch.load(f"{dir_name}/ll_model.pth")
model.load_state_dict(weights)

<All keys matched successfully>

In [5]:
# turn off grads
model.eval()
model.requires_grad_(False)
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x2cdc87750>

### Run evaluations

In [6]:
# load high level model
from circuits_benchmark.utils.iit import make_iit_hl_model
import circuits_benchmark.utils.iit.correspondence as correspondence
import iit.model_pairs as mp

def make_model_pair(benchmark_case):
    hl_model = benchmark_case.build_transformer_lens_model()
    hl_model = make_iit_hl_model(hl_model, eval_mode=True)
    tracr_output = benchmark_case.get_tracr_output()
    hl_ll_corr = correspondence.TracrCorrespondence.from_output(
            case=benchmark_case, tracr_output=tracr_output
        )
    model_pair = mp.StrictIITModelPair(hl_model, model, hl_ll_corr)
    return model_pair

In [7]:
# evaluate models
import circuits_benchmark.commands.evaluation.iit.iit_eval as eval_node_effect

# model_pair = make_model_pair(task)
args = eval_node_effect.setup_args_parser(None, True)
args

Namespace(indices=None, output_dir='/Users/cybershiptrooper/src/interpretability/MATS/circuits-benchmark/results', device='cpu', seed=1234, weights='510', mean=True, save_to_wandb=False, batch_size=512, categorical_metric='accuracy', load_from_wandb=False, max_len=1000)

In [8]:
args.max_len = 100 # make this smaller or larger to include more or less datapoints 

In [12]:
model_pair = make_model_pair(task)
node_effects, eval_metrics = eval_node_effect.get_node_effects(case=task, model_pair=model_pair, args=args, use_mean_cache=True)

Moving model to device:  cpu
{'hook_embed': HookPoint(), 'hook_pos_embed': HookPoint(), 'blocks.0.attn.hook_k': HookPoint(), 'blocks.0.attn.hook_q': HookPoint(), 'blocks.0.attn.hook_v': HookPoint(), 'blocks.0.attn.hook_z': HookPoint(), 'blocks.0.attn.hook_attn_scores': HookPoint(), 'blocks.0.attn.hook_pattern': HookPoint(), 'blocks.0.attn.hook_result': HookPoint(), 'blocks.0.mlp.hook_pre': HookPoint(), 'blocks.0.mlp.hook_post': HookPoint(), 'blocks.0.hook_attn_in': HookPoint(), 'blocks.0.hook_q_input': HookPoint(), 'blocks.0.hook_k_input': HookPoint(), 'blocks.0.hook_v_input': HookPoint(), 'blocks.0.hook_mlp_in': HookPoint(), 'blocks.0.hook_attn_out': HookPoint(), 'blocks.0.hook_mlp_out': HookPoint(), 'blocks.0.hook_resid_pre': HookPoint(), 'blocks.0.hook_resid_mid': HookPoint(), 'blocks.0.hook_resid_post': HookPoint(), 'blocks.1.attn.hook_k': HookPoint(), 'blocks.1.attn.hook_q': HookPoint(), 'blocks.1.attn.hook_v': HookPoint(), 'blocks.1.attn.hook_z': HookPoint(), 'blocks.1.attn.hook_

100%|██████████| 40/40 [00:03<00:00, 10.40it/s]
100%|██████████| 40/40 [00:01<00:00, 34.26it/s]
100%|██████████| 1/1 [00:00<00:00, 42.97it/s]
100%|██████████| 5/5 [00:00<00:00, 15.12it/s]
100%|██████████| 5/5 [00:00<00:00, 47.98it/s]


In [13]:
node_effects

Unnamed: 0,node,status,resample_ablate_effect,mean_ablate_effect
0,"blocks.0.attn.hook_result, head 0",not_in_circuit,0.005972,0.0
1,"blocks.0.attn.hook_result, head 1",not_in_circuit,0.0,0.0
2,"blocks.0.attn.hook_result, head 2",not_in_circuit,0.0,0.0
3,"blocks.0.attn.hook_result, head 3",not_in_circuit,0.000174,0.0
4,"blocks.1.attn.hook_result, head 2",not_in_circuit,0.0,0.0
5,"blocks.1.attn.hook_result, head 3",not_in_circuit,0.076713,0.0
6,blocks.1.mlp.hook_post,not_in_circuit,0.029705,0.0
7,"blocks.1.attn.hook_result, head :2",in_circuit,1.0,0.71375
8,blocks.0.mlp.hook_post,in_circuit,1.0,0.71375


In [14]:
print(eval_metrics)

val/iit_loss: 0.0001
val/IIA: 99.99%
val/accuracy: 100.00%
val/strict_accuracy: 100.00%


### Run ACDC

In [15]:
from circuits_benchmark.commands.evaluation.iit.acdc_utils import ACDCRunner
acdc_runner = ACDCRunner.make_default_runner(task=task_idx)
acdc_runner.args

Namespace(weights='100_100_40', output_dir='./results/acdc_3/weight_100_100_40/threshold_0.025', device='cpu', threshold=0.025, data_size=1000, using_wandb=False, load_from_wandb=False, include_mlp=False, next_token=False, use_pos_embed=False, first_cache_cpu='True', second_cache_cpu='True', zero_ablation=False, wandb_entity_name='remix_school-of-rock', wandb_group_name='default', wandb_project_name='acdc', wandb_run_name=None, wandb_dir='/tmp/wandb', wandb_mode='online', indices_mode='normal', names_mode='normal', torch_num_threads=0, max_num_epochs=100000, single_step=False, abs_value_threshold=False, images_output_dir='./results/acdc_3/weight_100_100_40/threshold_0.025/images')

In [18]:
from circuits_benchmark.transformers.hooked_tracr_transformer import (
    HookedTracrTransformer,
)
tl_model = HookedTracrTransformer(
    model.cfg,
    model_pair.hl_model.tracr_input_encoder,
    model_pair.hl_model.tracr_output_encoder,
    model_pair.hl_model.residual_stream_labels,
    remove_extra_tensor_cloning=False,
)

In [21]:
metric_name = "kl" if model_pair.hl_model.is_categorical() else "l2"
data_size = acdc_runner.args.data_size
validation_metric = task.get_validation_metric(metric_name, tl_model, data_size=data_size)
toks_int_values = task.get_clean_data(count=data_size).get_inputs()
toks_int_values_other = task.get_corrupted_data(count=data_size).get_inputs()
acdc_circuit, acdc_experiment = acdc_runner.run_acdc(
    tl_model=tl_model,
    clean_dataset=toks_int_values,
    corrupt_dataset=toks_int_values_other,
    validation_metric=validation_metric,
)



Moving model to device:  cpu
dict_keys(['blocks.1.hook_resid_post', 'blocks.1.hook_mlp_out', 'blocks.1.hook_mlp_in', 'blocks.1.attn.hook_result', 'blocks.1.attn.hook_q', 'blocks.1.hook_q_input', 'blocks.1.attn.hook_k', 'blocks.1.hook_k_input', 'blocks.1.attn.hook_v', 'blocks.1.hook_v_input', 'blocks.0.hook_mlp_out', 'blocks.0.hook_mlp_in', 'blocks.0.attn.hook_result', 'blocks.0.attn.hook_q', 'blocks.0.hook_q_input', 'blocks.0.attn.hook_k', 'blocks.0.hook_k_input', 'blocks.0.attn.hook_v', 'blocks.0.hook_v_input', 'blocks.0.hook_resid_pre'])
blocks.1.hook_resid_post
blocks.1.hook_mlp_out
blocks.1.mlp.hook_post
blocks.1.mlp.hook_pre
blocks.1.hook_mlp_in
blocks.1.hook_resid_mid
blocks.1.hook_attn_out
blocks.1.attn.hook_result
blocks.1.attn.hook_z
blocks.1.attn.hook_pattern
blocks.1.attn.hook_attn_scores
blocks.1.attn.hook_v
blocks.1.attn.hook_k
blocks.1.attn.hook_q
blocks.1.hook_v_input
blocks.1.hook_k_input
blocks.1.hook_q_input
blocks.1.hook_resid_pre
blocks.0.hook_resid_post
blocks.0.ho



In [30]:
# load edges from InterpBench
from circuits_benchmark.utils.circuits_comparison import calculate_fpr_and_tpr, Circuit
from acdc.TLACDCCorrespondence import TLACDCCorrespondence
from circuits_benchmark.transformers.acdc_circuit_builder import build_acdc_circuit


def get_tpr_fpr_for_acdc_circuit(acdc_circuit):
    edges = pickle.load(open(f"{dir_name}/edges.pkl", "rb"))
    gt_circuit = Circuit()
    for edge in edges:
        gt_circuit.add_edge(edge[0], edge[1])

    full_corr = TLACDCCorrespondence.setup_from_model(
                tl_model, use_pos_embed=True
            )
    full_circuit = build_acdc_circuit(full_corr)
    return calculate_fpr_and_tpr(acdc_circuit, gt_circuit, full_circuit)

get_tpr_fpr_for_acdc_circuit(acdc_circuit)


Summary:
 - Nodes TP rate: 0.5
 - Nodes FP rate: 0.14285714285714285
 - Edges TP rate: 0.0
 - Edges FP rate: 0.0


{'nodes': {'true_positive': {blocks.0.hook_mlp_out, blocks.1.hook_resid_post},
  'false_positive': {blocks.1.hook_mlp_out},
  'false_negative': {blocks.1.attn.hook_result[0],
   blocks.1.attn.hook_result[1]},
  'true_negative': {blocks.0.attn.hook_result[0],
   blocks.0.attn.hook_result[1],
   blocks.0.attn.hook_result[2],
   blocks.0.attn.hook_result[3],
   blocks.1.attn.hook_result[2],
   blocks.1.attn.hook_result[3]},
  'tpr': 0.5,
  'fpr': 0.14285714285714285},
 'edges': {'true_positive': set(),
  'false_positive': set(),
  'false_negative': {(blocks.0.hook_mlp_out, blocks.1.attn.hook_result[0]),
   (blocks.0.hook_mlp_out, blocks.1.attn.hook_result[1]),
   (hook_embed, blocks.0.hook_mlp_out),
   (hook_pos_embed, blocks.1.attn.hook_result[0]),
   (hook_pos_embed, blocks.1.attn.hook_result[1])},
  'true_negative': {(blocks.0.attn.hook_result[0], blocks.0.hook_mlp_out),
   (blocks.0.attn.hook_result[0], blocks.1.attn.hook_result[0]),
   (blocks.0.attn.hook_result[0], blocks.1.attn.hoo