In [1]:
%load_ext autoreload
%autoreload 2

import torch
from circuits_benchmark.utils.get_cases import get_cases
import iit.model_pairs as mp
from circuits_benchmark.utils.ll_model_loader.ll_model_loader_factory import get_ll_model_loader
from circuits_benchmark.benchmark.benchmark_case import BenchmarkCase
import os
from iit.model_pairs.ll_model import LLModel

task_idx = '19'
out_dir = f'results/tuned_lens/{task_idx}'
os.makedirs(out_dir, exist_ok=True)

task: BenchmarkCase = get_cases(indices=[task_idx])[0]

ll_model_loader = get_ll_model_loader(task, interp_bench=True)
hl_ll_corr, model = ll_model_loader.load_ll_model_and_correspondence(device='cuda' if torch.cuda.is_available() else 'cpu')
model.requires_grad_(True)
hl_model = task.get_hl_model()
model_pair = mp.StrictIITModelPair(hl_model, model, hl_ll_corr, training_args={
    "detach_while_caching" : False,
})
model = LLModel(model, detach_while_caching=False)



{'hook_embed': HookPoint(), 'hook_pos_embed': HookPoint(), 'blocks.0.attn.hook_k': HookPoint(), 'blocks.0.attn.hook_q': HookPoint(), 'blocks.0.attn.hook_v': HookPoint(), 'blocks.0.attn.hook_z': HookPoint(), 'blocks.0.attn.hook_attn_scores': HookPoint(), 'blocks.0.attn.hook_pattern': HookPoint(), 'blocks.0.attn.hook_result': HookPoint(), 'blocks.0.mlp.hook_pre': HookPoint(), 'blocks.0.mlp.hook_post': HookPoint(), 'blocks.0.hook_attn_in': HookPoint(), 'blocks.0.hook_q_input': HookPoint(), 'blocks.0.hook_k_input': HookPoint(), 'blocks.0.hook_v_input': HookPoint(), 'blocks.0.hook_mlp_in': HookPoint(), 'blocks.0.hook_attn_out': HookPoint(), 'blocks.0.hook_mlp_out': HookPoint(), 'blocks.0.hook_resid_pre': HookPoint(), 'blocks.0.hook_resid_mid': HookPoint(), 'blocks.0.hook_resid_post': HookPoint()}
dict_keys([TracrHLNode(name: blocks.0.attn.hook_result,
 label: shift_by(1)_2,
 classes: 0,
 index: [:, :, 0, :]
), TracrHLNode(name: blocks.0.mlp.hook_post,
 label: sequential_duplicate_removal_1,

In [2]:
# `max_len` is passed as max_samples, i.e. it caps the number of unique
# clean samples drawn -- it is not a sequence length.
max_len = 1000
unique_test_data = task.get_clean_data(
    max_samples=max_len,
    unique_data=True,
)

# Fixed order and no dropped tail, so every sample is seen exactly once.
loader = torch.utils.data.DataLoader(
    unique_test_data,
    batch_size=256,
    shuffle=False,
    drop_last=False,
)



In [3]:
# Baseline for comparison: the naturally-trained LL model for the same task
# (loader created with natural=True, load_from_wandb=True). Only the model is
# kept; the correspondence is discarded.
natural_ll_model_loader = get_ll_model_loader(
    task, natural=True, load_from_wandb=True
)
_, natural_model = natural_ll_model_loader.load_ll_model_and_correspondence(
    device='cuda' if torch.cuda.is_available() else 'cpu'
)
# Wrapped in LLModel the same way as the tuned `model`.
natural_model = LLModel(natural_model, detach_while_caching=False)

Created temporary directory at /var/folders/_k/_46xyqdj165bdcyw79k6758w0000gn/T/tmprx09o1_h


In [4]:
# Enable gradients on the high-level (tracr) model and wrap it in LLModel like
# the LL models. NOTE: no copy is made -- this mutates the same `hl_model`
# object that model_pair was constructed with.
hl_model.requires_grad_(True)
tracr_model = LLModel(hl_model, detach_while_caching=False)

In [17]:
from interp_utils.node_stats import get_grad_norms

# Gradient-norm statistics of model_pair's loss for each model, computed over
# the same loader (calls run in this order: tuned, natural, tracr).
loss_info, natural_loss_info, tracr_loss_info = (
    get_grad_norms(wrapped, loader, model_pair.loss_fn)
    for wrapped in (model, natural_model, tracr_model)
)

100%|██████████| 4/4 [00:00<00:00, 18.13it/s]
100%|██████████| 4/4 [00:00<00:00, 19.07it/s]
100%|██████████| 4/4 [00:00<00:00, 63.30it/s]


In [19]:
import pandas as pd

# Tuned-vs-natural comparison tables (the tracr model is shown separately).
# Grad-norm frames are built one row per model, then transposed so that each
# model becomes a column and each hook/parameter a row.
_runs = {'tuned': loss_info, 'natural': natural_loss_info}

activation_grads = pd.DataFrame(
    [run['grad_norms'] for run in _runs.values()],
    index=list(_runs),
).T
param_grads = pd.DataFrame(
    [run['param_grad_norms'] for run in _runs.values()],
    index=list(_runs),
).T
losses = pd.DataFrame(
    [run['loss'] for run in _runs.values()],
    index=list(_runs),
    columns=['loss'],
)

In [7]:
# Loss reported by get_grad_norms for each model (rows: tuned, natural).
# NOTE(review): this cell's execution count (7) precedes the cell that builds
# `losses` (19), so the saved output may come from an earlier definition --
# confirm with Restart & Run All.
losses

Unnamed: 0,loss
tuned,0.0
natural,7e-06


In [24]:
# Per-hook activation-gradient norms for the tracr (high-level) model, which is
# not included in the tuned/natural comparison frames.
pd.Series(data=tracr_loss_info['grad_norms'])

hook_embed                        4.423489e-04
hook_pos_embed                    4.423489e-04
blocks.0.hook_resid_pre           4.423489e-04
blocks.0.hook_q_input             2.826291e-08
blocks.0.hook_k_input             5.957184e-08
blocks.0.hook_v_input             1.579227e-04
blocks.0.attn.hook_q              3.584334e-10
blocks.0.attn.hook_k              5.957184e-08
blocks.0.attn.hook_v              1.579227e-04
blocks.0.attn.hook_attn_scores    1.433734e-09
blocks.0.attn.hook_pattern        4.180968e-04
blocks.0.attn.hook_z              1.579232e-04
blocks.0.attn.hook_result         3.530687e-04
blocks.0.hook_attn_out            3.530687e-04
blocks.0.hook_resid_mid           3.530687e-04
blocks.0.hook_mlp_in              3.528860e-04
blocks.0.mlp.hook_pre             1.119208e-04
blocks.0.mlp.hook_post            1.606131e-05
blocks.0.hook_mlp_out             1.135706e-05
blocks.0.hook_resid_post          1.135706e-05
dtype: float64

In [16]:
# Per-hook activation-gradient norms, tuned vs. natural model (columns).
# NOTE(review): this cell's execution count (16) precedes the cell that builds
# `activation_grads` (19); confirm the output with Restart & Run All.
activation_grads

Unnamed: 0,tuned,natural
hook_embed,2.5338869999999998e-19,4.208421e-07
hook_pos_embed,2.5338869999999998e-19,4.208421e-07
blocks.0.hook_resid_pre,2.5338869999999998e-19,4.208421e-07
blocks.0.hook_q_input,8.589705e-20,3.429198e-08
blocks.0.hook_k_input,8.818481e-20,4.27478e-08
blocks.0.hook_v_input,2.7128709999999997e-19,1.809127e-07
blocks.0.attn.hook_q,4.28922e-20,2.616654e-08
blocks.0.attn.hook_k,3.816347e-20,3.70591e-08
blocks.0.attn.hook_v,1.433786e-19,1.682733e-07
blocks.0.attn.hook_attn_scores,4.10472e-20,6.282845e-08
