In [1]:
%load_ext autoreload
%autoreload 2

import torch
from circuits_benchmark.utils.get_cases import get_cases
import iit.model_pairs as mp
from circuits_benchmark.utils.ll_model_loader.ll_model_loader_factory import get_ll_model_loader
from circuits_benchmark.benchmark.benchmark_case import BenchmarkCase
import os

# Benchmark case analysed by this notebook.
task_idx = '3'
task: BenchmarkCase = get_cases(indices=[task_idx])[0]

# Directory for every plot this notebook saves; created up front so the
# plotting cells can write into it.
out_dir = f"./interp_results/case_{task.get_name()}/resample_ablate_results"
os.makedirs(out_dir, exist_ok=True)

# Load the low-level (trained) model together with its high-level <-> low-level correspondence.
ll_model_loader = get_ll_model_loader(task, interp_bench=True)
hl_ll_corr, model = ll_model_loader.load_ll_model_and_correspondence(device='cuda' if torch.cuda.is_available() else 'cpu')
# eval() puts the model in inference mode; requires_grad_(False) is what actually
# freezes the parameters so no gradients are tracked during the ablation runs.
model.eval()
model.requires_grad_(False)

# Pair the high-level reference model with the low-level model for the ablation experiments.
hl_model = task.get_hl_model()
model_pair = mp.StrictIITModelPair(hl_model, model, hl_ll_corr)

{'hook_embed': HookPoint(), 'hook_pos_embed': HookPoint(), 'blocks.0.attn.hook_k': HookPoint(), 'blocks.0.attn.hook_q': HookPoint(), 'blocks.0.attn.hook_v': HookPoint(), 'blocks.0.attn.hook_z': HookPoint(), 'blocks.0.attn.hook_attn_scores': HookPoint(), 'blocks.0.attn.hook_pattern': HookPoint(), 'blocks.0.attn.hook_result': HookPoint(), 'blocks.0.mlp.hook_pre': HookPoint(), 'blocks.0.mlp.hook_post': HookPoint(), 'blocks.0.hook_attn_in': HookPoint(), 'blocks.0.hook_q_input': HookPoint(), 'blocks.0.hook_k_input': HookPoint(), 'blocks.0.hook_v_input': HookPoint(), 'blocks.0.hook_mlp_in': HookPoint(), 'blocks.0.hook_attn_out': HookPoint(), 'blocks.0.hook_mlp_out': HookPoint(), 'blocks.0.hook_resid_pre': HookPoint(), 'blocks.0.hook_resid_mid': HookPoint(), 'blocks.0.hook_resid_post': HookPoint(), 'blocks.1.attn.hook_k': HookPoint(), 'blocks.1.attn.hook_q': HookPoint(), 'blocks.1.attn.hook_v': HookPoint(), 'blocks.1.attn.hook_z': HookPoint(), 'blocks.1.attn.hook_attn_scores': HookPoint(), 'b

In [2]:
%%capture
max_len = 100
unique_test_data = task.get_clean_data(max_samples=max_len, unique_data=True)

loader = torch.utils.data.DataLoader(unique_test_data, batch_size=256, shuffle=False, drop_last=False)

In [4]:
from interp_utils.resample_ablate.hook_maker import make_scaled_ablation_hook
from interp_utils.resample_ablate.get_ablation_effect import get_ablation_effects_for_scales

# Ablation scales to sweep. 5.0 is included because the recorded outputs of this
# cell (and the results table / plots built from them) were produced with it;
# keep this list in sync with what the saved notebook shows.
scales = [0.0, 0.1, 0.2, 0.5, 0.7, 0.8, 1.0, 1.2, 1.4, 2.0, 5.0]

# One causal-effect column per scale, one row per low-level node.
combined_scales_df = get_ablation_effects_for_scales(
    model_pair,
    unique_test_data,
    hook_maker=make_scaled_ablation_hook,
    scales=scales,
)

Running scale 0.0



100%|██████████| 40/40 [00:06<00:00,  5.84it/s]
100%|██████████| 40/40 [00:01<00:00, 23.86it/s]


Running scale 0.1



100%|██████████| 40/40 [00:06<00:00,  6.32it/s]
100%|██████████| 40/40 [00:01<00:00, 24.40it/s]


Running scale 0.2



100%|██████████| 40/40 [00:06<00:00,  6.30it/s]
100%|██████████| 40/40 [00:01<00:00, 24.55it/s]


Running scale 0.5



100%|██████████| 40/40 [00:06<00:00,  6.31it/s]
100%|██████████| 40/40 [00:01<00:00, 24.65it/s]


Running scale 0.7



100%|██████████| 40/40 [00:06<00:00,  6.25it/s]
100%|██████████| 40/40 [00:01<00:00, 24.74it/s]


Running scale 0.8



100%|██████████| 40/40 [00:06<00:00,  6.13it/s]
100%|██████████| 40/40 [00:01<00:00, 24.51it/s]


Running scale 1.0



100%|██████████| 40/40 [00:06<00:00,  5.82it/s]
100%|██████████| 40/40 [00:01<00:00, 24.15it/s]


Running scale 1.2



100%|██████████| 40/40 [00:06<00:00,  6.22it/s]
100%|██████████| 40/40 [00:01<00:00, 24.12it/s]


Running scale 1.4



100%|██████████| 40/40 [00:06<00:00,  6.30it/s]
100%|██████████| 40/40 [00:01<00:00, 24.34it/s]


Running scale 2.0



100%|██████████| 40/40 [00:06<00:00,  6.30it/s]
100%|██████████| 40/40 [00:01<00:00, 24.70it/s]


Running scale 5.0



100%|██████████| 40/40 [00:06<00:00,  6.16it/s]
100%|██████████| 40/40 [00:01<00:00, 23.89it/s]


In [5]:
# Normalise the "scale 0.0_y" column name, then list not_in_circuit nodes before
# in_circuit ones (descending string order on `status`). kind="stable" keeps the
# original node order inside each status group; the default quicksort does not
# guarantee that for a single-key sort.
combined_scales_df = (
    combined_scales_df
    .rename(columns={"scale 0.0_y": "scale 0.0"})
    .sort_values(by=["status"], ascending=False, kind="stable")
)
combined_scales_df

Unnamed: 0,node,status,scale 0.0,scale 0.1,scale 0.2,scale 0.5,scale 0.7,scale 0.8,scale 1.0,scale 1.2,scale 1.4,scale 2.0,scale 5.0
0,"blocks.0.attn.hook_result, head 0",not_in_circuit,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.003995,0.175552
1,"blocks.0.attn.hook_result, head 1",not_in_circuit,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,"blocks.0.attn.hook_result, head 2",not_in_circuit,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,"blocks.0.attn.hook_result, head 3",not_in_circuit,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,5.8e-05
6,"blocks.1.attn.hook_result, head 1",not_in_circuit,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.349343
7,"blocks.1.attn.hook_result, head 2",not_in_circuit,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.069386
8,"blocks.1.attn.hook_result, head 3",not_in_circuit,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
9,blocks.1.mlp.hook_post,not_in_circuit,0.146051,0.08022,0.033787,0.001958,8.8e-05,0.0,0.0,0.0,0.0,0.055231,0.658647
4,blocks.0.mlp.hook_post,in_circuit,0.743146,0.805823,0.894021,1.0,1.0,1.0,1.0,1.0,1.0,1.0,0.997797
5,"blocks.1.attn.hook_result, head 0",in_circuit,0.754735,0.76823,0.852649,0.997337,1.0,1.0,1.0,1.0,1.0,1.0,1.0


In [16]:
from interp_utils.resample_ablate.plot_utils import plot_causal_effect

# Plot the per-node causal effect across every ablation scale and save it
# under this case's resample-ablation results folder.
out_dir = f"./interp_results/case_{task.get_name()}/resample_ablate_results"
plot_causal_effect(
    combined_scales_df,
    scales,
    image_name=f"causal_effect_scale_{task.get_name()}",
    out_dir=out_dir,
)


plotly.graph_objs.Line is deprecated.
Please replace it with one of the following more specific types
  - plotly.graph_objs.scatter.Line
  - plotly.graph_objs.layout.shape.Line
  - etc.




In [20]:
from interp_utils.resample_ablate.collect_cache import collect_activations, collect_pca_directions

# A single larger-batch, ordered pass over the evaluation data to cache activations
# (the cache is keyed by LLNode, as the displayed keys show).
loader = torch.utils.data.DataLoader(
    unique_test_data,
    batch_size=1024,
    shuffle=False,
    drop_last=False,
)
activation_cache = collect_activations(model_pair, loader=loader)

# NOTE(review): `pca_dirs` is not referenced by any later visible cell — confirm
# whether make_hook is meant to consume it.
pca_dirs = collect_pca_directions(activation_cache, num_pca_components=2)

activation_cache.keys()

dict_keys([LLNode(name='blocks.0.attn.hook_result', index=[:, :, 0, :], subspace=None), LLNode(name='blocks.0.attn.hook_result', index=[:, :, 1, :], subspace=None), LLNode(name='blocks.0.attn.hook_result', index=[:, :, 2, :], subspace=None), LLNode(name='blocks.0.attn.hook_result', index=[:, :, 3, :], subspace=None), LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None), LLNode(name='blocks.1.attn.hook_result', index=[:, :, 0, :], subspace=None), LLNode(name='blocks.1.attn.hook_result', index=[:, :, 1, :], subspace=None), LLNode(name='blocks.1.attn.hook_result', index=[:, :, 2, :], subspace=None), LLNode(name='blocks.1.attn.hook_result', index=[:, :, 3, :], subspace=None), LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)])

In [23]:
from interp_utils.resample_ablate.hook_maker import make_hook

# Results of the scale sweep for each (self_patch, ablate_high_variance) hook variant.
combined_scales_df_orthogonal = {}

variants = [(True, True), (True, False), (False, True), (False, False)]
for self_patch, ablate_high_variance in variants:
    hook_maker = make_hook(self_patch, ablate_high_variance)
    variant_df = get_ablation_effects_for_scales(
        model_pair,
        unique_test_data,
        hook_maker=hook_maker,
        scales=scales,
    )
    combined_scales_df_orthogonal[(self_patch, ablate_high_variance)] = variant_df

Running scale 0.0



100%|██████████| 40/40 [00:07<00:00,  5.70it/s]
100%|██████████| 40/40 [00:01<00:00, 21.52it/s]


Running scale 0.1



100%|██████████| 40/40 [00:07<00:00,  5.23it/s]
100%|██████████| 40/40 [00:02<00:00, 18.64it/s]


Running scale 0.2



100%|██████████| 40/40 [00:06<00:00,  5.83it/s]
100%|██████████| 40/40 [00:01<00:00, 22.26it/s]


Running scale 0.5



100%|██████████| 40/40 [00:06<00:00,  5.94it/s]
100%|██████████| 40/40 [00:01<00:00, 21.63it/s]


Running scale 0.7



100%|██████████| 40/40 [00:06<00:00,  5.93it/s]
100%|██████████| 40/40 [00:01<00:00, 23.11it/s]


Running scale 0.8



100%|██████████| 40/40 [00:06<00:00,  6.49it/s]
100%|██████████| 40/40 [00:01<00:00, 25.27it/s]


Running scale 1.0



100%|██████████| 40/40 [00:07<00:00,  5.31it/s]
100%|██████████| 40/40 [00:01<00:00, 20.60it/s]


Running scale 1.2



100%|██████████| 40/40 [00:06<00:00,  6.41it/s]
100%|██████████| 40/40 [00:01<00:00, 24.61it/s]


Running scale 1.4



100%|██████████| 40/40 [00:06<00:00,  6.48it/s]
100%|██████████| 40/40 [00:01<00:00, 22.38it/s]


Running scale 2.0



100%|██████████| 40/40 [00:07<00:00,  5.57it/s]
100%|██████████| 40/40 [00:01<00:00, 21.87it/s]


Running scale 5.0



100%|██████████| 40/40 [00:06<00:00,  5.83it/s]
100%|██████████| 40/40 [00:01<00:00, 22.32it/s]


Running scale 0.0



100%|██████████| 40/40 [00:06<00:00,  5.74it/s]
100%|██████████| 40/40 [00:01<00:00, 22.02it/s]


Running scale 0.1



100%|██████████| 40/40 [00:07<00:00,  5.45it/s]
100%|██████████| 40/40 [00:01<00:00, 21.30it/s]


Running scale 0.2



100%|██████████| 40/40 [00:07<00:00,  5.39it/s]
100%|██████████| 40/40 [00:01<00:00, 21.76it/s]


Running scale 0.5



100%|██████████| 40/40 [00:07<00:00,  5.38it/s]
100%|██████████| 40/40 [00:01<00:00, 22.36it/s]


Running scale 0.7



100%|██████████| 40/40 [00:07<00:00,  5.45it/s]
100%|██████████| 40/40 [00:01<00:00, 20.99it/s]


Running scale 0.8



100%|██████████| 40/40 [00:07<00:00,  5.48it/s]
100%|██████████| 40/40 [00:01<00:00, 20.52it/s]


Running scale 1.0



100%|██████████| 40/40 [00:09<00:00,  4.28it/s]
100%|██████████| 40/40 [00:02<00:00, 17.77it/s]


Running scale 1.2



100%|██████████| 40/40 [00:07<00:00,  5.49it/s]
100%|██████████| 40/40 [00:01<00:00, 21.50it/s]


Running scale 1.4



100%|██████████| 40/40 [00:07<00:00,  5.10it/s]
100%|██████████| 40/40 [00:01<00:00, 21.95it/s]


Running scale 2.0



100%|██████████| 40/40 [00:07<00:00,  5.46it/s]
100%|██████████| 40/40 [00:02<00:00, 19.63it/s]


Running scale 5.0



100%|██████████| 40/40 [00:07<00:00,  5.35it/s]
100%|██████████| 40/40 [00:02<00:00, 19.96it/s]


Running scale 0.0



100%|██████████| 40/40 [00:07<00:00,  5.45it/s]
100%|██████████| 40/40 [00:01<00:00, 20.80it/s]


Running scale 0.1



100%|██████████| 40/40 [00:07<00:00,  5.31it/s]
100%|██████████| 40/40 [00:01<00:00, 20.73it/s]


Running scale 0.2



100%|██████████| 40/40 [00:07<00:00,  5.43it/s]
100%|██████████| 40/40 [00:01<00:00, 20.44it/s]


Running scale 0.5



100%|██████████| 40/40 [00:07<00:00,  5.08it/s]
100%|██████████| 40/40 [00:01<00:00, 22.60it/s]


Running scale 0.7



100%|██████████| 40/40 [00:07<00:00,  5.24it/s]
100%|██████████| 40/40 [00:01<00:00, 21.36it/s]


Running scale 0.8



100%|██████████| 40/40 [00:07<00:00,  5.27it/s]
100%|██████████| 40/40 [00:02<00:00, 19.73it/s]


Running scale 1.0



100%|██████████| 40/40 [00:07<00:00,  5.41it/s]
100%|██████████| 40/40 [00:02<00:00, 17.50it/s]


Running scale 1.2



100%|██████████| 40/40 [00:08<00:00,  4.88it/s]
100%|██████████| 40/40 [00:02<00:00, 19.70it/s]


Running scale 1.4



100%|██████████| 40/40 [00:07<00:00,  5.29it/s]
100%|██████████| 40/40 [00:01<00:00, 22.08it/s]


Running scale 2.0



100%|██████████| 40/40 [00:07<00:00,  5.39it/s]
100%|██████████| 40/40 [00:01<00:00, 20.05it/s]


Running scale 5.0



100%|██████████| 40/40 [00:07<00:00,  5.41it/s]
100%|██████████| 40/40 [00:01<00:00, 21.70it/s]


Running scale 0.0



100%|██████████| 40/40 [00:07<00:00,  5.68it/s]
100%|██████████| 40/40 [00:01<00:00, 23.89it/s]


Running scale 0.1



100%|██████████| 40/40 [00:06<00:00,  6.20it/s]
100%|██████████| 40/40 [00:01<00:00, 25.07it/s]


Running scale 0.2



100%|██████████| 40/40 [00:06<00:00,  6.10it/s]
100%|██████████| 40/40 [00:01<00:00, 22.01it/s]


Running scale 0.5



100%|██████████| 40/40 [00:07<00:00,  5.60it/s]
100%|██████████| 40/40 [00:01<00:00, 22.35it/s]


Running scale 0.7



100%|██████████| 40/40 [00:07<00:00,  5.61it/s]
100%|██████████| 40/40 [00:01<00:00, 22.38it/s]


Running scale 0.8



100%|██████████| 40/40 [00:06<00:00,  5.86it/s]
100%|██████████| 40/40 [00:01<00:00, 24.32it/s]


Running scale 1.0



100%|██████████| 40/40 [00:07<00:00,  5.70it/s]
100%|██████████| 40/40 [00:01<00:00, 21.47it/s]


Running scale 1.2



100%|██████████| 40/40 [00:06<00:00,  5.93it/s]
100%|██████████| 40/40 [00:01<00:00, 23.01it/s]


Running scale 1.4



100%|██████████| 40/40 [00:06<00:00,  6.08it/s]
100%|██████████| 40/40 [00:01<00:00, 22.57it/s]


Running scale 2.0



100%|██████████| 40/40 [00:06<00:00,  5.78it/s]
100%|██████████| 40/40 [00:01<00:00, 22.75it/s]


Running scale 5.0



100%|██████████| 40/40 [00:07<00:00,  5.70it/s]
100%|██████████| 40/40 [00:01<00:00, 23.73it/s]


In [25]:
# Tidy every variant frame the same way as the main results table: normalise the
# "scale 0.0_y" column name and list not_in_circuit nodes before in_circuit ones.
# kind="stable" keeps node order within each status group deterministic.
# (Re-assigning existing keys while iterating does not change the dict's size, so it is safe.)
for key, df in combined_scales_df_orthogonal.items():
    combined_scales_df_orthogonal[key] = (
        df
        .rename(columns={"scale 0.0_y": "scale 0.0"})
        .sort_values(by=["status"], ascending=False, kind="stable")
    )

# Saved-image suffix for each (self_patch, ablate_high_variance) variant.
variant_suffixes = {
    (True, True): "self_patch_ablate_high_variance",
    (True, False): "self_patch_ablate_mean",
    (False, True): "other_patched_ablate_high_variance",
    (False, False): "other_patched_ablate_mean",
}
for key, suffix in variant_suffixes.items():
    plot_causal_effect(
        combined_scales_df_orthogonal[key],
        scales,
        f"causal_effect_scale_{task.get_name()}_{suffix}",
        out_dir,
    )


plotly.graph_objs.Line is deprecated.
Please replace it with one of the following more specific types
  - plotly.graph_objs.scatter.Line
  - plotly.graph_objs.layout.shape.Line
  - etc.


