In [1]:
%load_ext autoreload
%autoreload 2

import torch
from circuits_benchmark.utils.get_cases import get_cases
import iit.model_pairs as mp
from circuits_benchmark.utils.ll_model_loader.ll_model_loader_factory import get_ll_model_loader
from circuits_benchmark.benchmark.benchmark_case import BenchmarkCase
import os

# ---- Experiment selection and output location ----
task_idx = '35'
out_dir = f'results/tuned_lens/{task_idx}'
# NOTE(review): out_dir is created here, but none of the cells visible in this
# notebook write into it (the figure in the last cell is saved relative to the CWD).
os.makedirs(out_dir, exist_ok=True)

# ---- Benchmark case ----
# A single index is requested; the first returned case is used as the task.
matching_cases = get_cases(indices=[task_idx])
task: BenchmarkCase = matching_cases[0]

# ---- Low-level model and its correspondence to the high-level model ----
ll_model_loader = get_ll_model_loader(task, interp_bench=True)
device = 'cuda' if torch.cuda.is_available() else 'cpu'  # GPU when available, CPU otherwise
hl_ll_corr, model = ll_model_loader.load_ll_model_and_correspondence(device=device)

# Analysis only: put the model in eval mode and switch off gradient tracking.
model.eval()
model.requires_grad_(False)

# ---- Pair the high-level model with the low-level model ----
hl_model = task.get_hl_model()
model_pair = mp.StrictIITModelPair(hl_model, model, hl_ll_corr)

{'hook_embed': HookPoint(), 'hook_pos_embed': HookPoint(), 'blocks.0.attn.hook_k': HookPoint(), 'blocks.0.attn.hook_q': HookPoint(), 'blocks.0.attn.hook_v': HookPoint(), 'blocks.0.attn.hook_z': HookPoint(), 'blocks.0.attn.hook_attn_scores': HookPoint(), 'blocks.0.attn.hook_pattern': HookPoint(), 'blocks.0.attn.hook_result': HookPoint(), 'blocks.0.mlp.hook_pre': HookPoint(), 'blocks.0.mlp.hook_post': HookPoint(), 'blocks.0.hook_attn_in': HookPoint(), 'blocks.0.hook_q_input': HookPoint(), 'blocks.0.hook_k_input': HookPoint(), 'blocks.0.hook_v_input': HookPoint(), 'blocks.0.hook_mlp_in': HookPoint(), 'blocks.0.hook_attn_out': HookPoint(), 'blocks.0.hook_mlp_out': HookPoint(), 'blocks.0.hook_resid_pre': HookPoint(), 'blocks.0.hook_resid_mid': HookPoint(), 'blocks.0.hook_resid_post': HookPoint()}
dict_keys([TracrHLNode(name: blocks.0.mlp.hook_post,
 label: map_1,
 classes: 0,
 index: [:]
)])


In [2]:
# Clean data for the norm statistics, requested with max_samples=max_len and
# unique_data=True. Note that `max_len` is reassigned to a different value in a
# later cell (the ablation evaluation).
max_len = 1000
unique_test_data = task.get_clean_data(
    max_samples=max_len,
    unique_data=True,
)

batch_size = 256
loader = torch.utils.data.DataLoader(
    unique_test_data,
    batch_size=batch_size,
    shuffle=False,    # keep the dataset order
    drop_last=False,  # keep the final batch even if it is smaller than batch_size
)

In [3]:
from interp_utils.node_stats import get_node_norm_stats

# Node norm statistics for the model pair over `loader`.
# The result is presumably a DataFrame: a later cell merges it on its "name" column.
node_norms = get_node_norm_stats(
    model_pair,
    loader,
    return_cache_dict=False,
)

  node_norms = pd.concat(


In [4]:
import circuits_benchmark.commands.evaluation.iit.iit_eval as eval_node_effect
# `max_len` is reused here with a smaller value than in the data-loading cell (1000 there).
max_len = 100
node_effects, eval_metrics = eval_node_effect.get_node_effects(
    case=task,
    model_pair=model_pair,
    use_mean_cache=False,
    max_len=max_len,
)

100%|██████████| 40/40 [00:05<00:00,  6.84it/s]
100%|██████████| 40/40 [00:00<00:00, 59.40it/s]
100%|██████████| 1/1 [00:00<00:00, 14.23it/s]
100%|██████████| 1/1 [00:00<00:00, 121.78it/s]


In [5]:
# Join the ablation effects with the activation-norm statistics, one row per node.
import pandas as pd

# Inner join: only nodes present in both tables survive (node_effects.node == node_norms.name).
# The join key duplicate "name" and the "in_circuit" column are dropped afterwards.
combined_df = pd.merge(
    node_effects,
    node_norms,
    left_on="node",
    right_on="name",
    how="inner",
).drop(columns=["name", "in_circuit"])
combined_df

Unnamed: 0,node,status,resample_ablate_effect,zero_ablate_effect,norm_cache,norm_std
0,"blocks.0.attn.hook_result, head 0",not_in_circuit,0.0,0.0,0.420833,0.052473
1,"blocks.0.attn.hook_result, head 1",not_in_circuit,0.0,0.0,3.20299,0.380507
2,"blocks.0.attn.hook_result, head 2",not_in_circuit,0.0,0.0,1.145408,0.218464
3,"blocks.0.attn.hook_result, head 3",not_in_circuit,0.0,0.0,4.230924,0.584415
4,"blocks.1.attn.hook_result, head 0",not_in_circuit,0.0,0.0,41.70228,2.478303
5,"blocks.1.attn.hook_result, head 1",not_in_circuit,0.0,0.0,46.41768,3.397992
6,"blocks.1.attn.hook_result, head 2",not_in_circuit,0.0,0.0,53.01903,3.905091
7,"blocks.1.attn.hook_result, head 3",not_in_circuit,0.0,0.0,19.437596,0.913765
8,blocks.1.mlp.hook_post,not_in_circuit,0.0,0.0,234.350295,16.563889
9,blocks.0.mlp.hook_post,in_circuit,1.0,0.493,39.616982,3.633166


In [6]:
import plotly.express as px

# Colour for each value of the "status" column.
status_colors = {
    "in_circuit": "green",
    "not_in_circuit": "orange",
}

# Human-readable names for the columns shown on axes, legend and hover box.
column_labels = {
    "zero_ablate_effect": "Zero Ablation Effect",
    "norm_cache": "Norm of Node Activation",
    "status": "",
    "resample_ablate_effect": "Resample Ablate Effect",
}

# Scatter of activation norm (with norm_std error bars) against zero-ablation effect.
fig = px.scatter(
    combined_df,
    x="zero_ablate_effect",
    y="norm_cache",
    error_y="norm_std",
    color="status",
    color_discrete_map=status_colors,
    labels=column_labels,
    hover_data=["node", "resample_ablate_effect"],
    template="plotly_white",  # plain white theme instead of the default background
)

# Tighter margins and a larger font, applied in a single layout update.
fig.update_layout(
    margin=dict(l=70, r=70, t=70, b=70),
    font=dict(size=16),
)
fig.show()

# Export the figure as a PDF; the path is relative to the current working directory.
fig.write_image(f"node_stats_{task.get_name()}.pdf")
