In [10]:
%load_ext autoreload
%autoreload 2

import torch
from circuits_benchmark.utils.get_cases import get_cases
import iit.model_pairs as mp
from circuits_benchmark.utils.ll_model_loader.ll_model_loader_factory import get_ll_model_loader
from circuits_benchmark.benchmark.benchmark_case import BenchmarkCase
import os
from iit.model_pairs.ll_model import LLModel

# Benchmark case to analyse and an output directory for it (created here).
task_idx = '11'
out_dir = f'results/tuned_lens/{task_idx}'
os.makedirs(out_dir, exist_ok=True)

task: BenchmarkCase = get_cases(indices=[task_idx])[0]

# Load the low-level model plus its HL<->LL correspondence.
# NOTE(review): `interp_bench=True` selects which LL checkpoint is loaded — confirm against the loader factory.
ll_model_loader = get_ll_model_loader(task, interp_bench=True)
hl_ll_corr, model = ll_model_loader.load_ll_model_and_correspondence(device='cuda' if torch.cuda.is_available() else 'cpu')
# Enable gradients on the LL weights so later cells can back-propagate through them.
model.requires_grad_(True)
hl_model = task.get_hl_model()
# The pair is built around the *unwrapped* LL model, before `model` is rebound below.
model_pair = mp.StrictIITModelPair(hl_model, model, hl_ll_corr, training_args={
    "detach_while_caching" : False,
})
# Rebind `model` to an LLModel wrapper. detach_while_caching=False keeps cached activations
# attached to the autograd graph — presumably so their .grad is available after backward (verify in iit).
model = LLModel(model, detach_while_caching=False)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
{'hook_embed': HookPoint(), 'hook_pos_embed': HookPoint(), 'blocks.0.attn.hook_k': HookPoint(), 'blocks.0.attn.hook_q': HookPoint(), 'blocks.0.attn.hook_v': HookPoint(), 'blocks.0.attn.hook_z': HookPoint(), 'blocks.0.attn.hook_attn_scores': HookPoint(), 'blocks.0.attn.hook_pattern': HookPoint(), 'blocks.0.attn.hook_result': HookPoint(), 'blocks.0.mlp.hook_pre': HookPoint(), 'blocks.0.mlp.hook_post': HookPoint(), 'blocks.0.hook_attn_in': HookPoint(), 'blocks.0.hook_q_input': HookPoint(), 'blocks.0.hook_k_input': HookPoint(), 'blocks.0.hook_v_input': HookPoint(), 'blocks.0.hook_mlp_in': HookPoint(), 'blocks.0.hook_attn_out': HookPoint(), 'blocks.0.hook_mlp_out': HookPoint(), 'blocks.0.hook_resid_pre': HookPoint(), 'blocks.0.hook_resid_mid': HookPoint(), 'blocks.0.hook_resid_post': HookPoint(), 'blocks.1.attn.hook_k': HookPoint(), 'blocks.1.attn.hook_q': HookPoint(), 'blocks.1.attn.hook_v': HookPoint()

In [2]:
max_len = 1000  # upper bound on the number of unique clean samples requested
unique_test_data = task.get_clean_data(
    max_samples=max_len,
    unique_data=True,
)

# Fixed order, and the final partial batch is kept, so each sample is visited exactly once.
loader = torch.utils.data.DataLoader(
    unique_test_data,
    batch_size=256,
    drop_last=False,
    shuffle=False,
)

In [18]:
# Load the "natural" LL model for the same task (load_from_wandb=True); its correspondence is discarded.
natural_ll_model_loader = get_ll_model_loader(
    task,
    load_from_wandb=True,
    natural=True,
)
_, natural_model = natural_ll_model_loader.load_ll_model_and_correspondence(
    device='cuda' if torch.cuda.is_available() else 'cpu'
)
natural_model = LLModel(natural_model, detach_while_caching=False)

In [29]:
import copy

# Work on a copy of the HL model. A plain `tracr_model = hl_model` only aliases the object,
# so requires_grad_(True) would also switch on gradients for the HL model held by `model_pair`.
# Copying also makes this cell idempotent with respect to `hl_model`.
tracr_model = copy.deepcopy(hl_model)
tracr_model.requires_grad_(True)
tracr_model = LLModel(tracr_model, detach_while_caching=False)

In [30]:
from tqdm import tqdm
import numpy as np

def get_grad_norms(model, loader, loss_fn):
    """Average gradients of cached activations and parameters over a loader, then take L2 norms.

    For every batch ``(x, y)`` in ``loader`` the model is run with its activation cache,
    ``loss_fn(logits, y)`` is back-propagated, and the gradients are summed:

    * activations: each cached activation's gradient is first averaged over dim 0
      (the batch dimension), keeping the remaining per-position/feature shape;
    * parameters: ``param.grad`` is used as-is. It is already the gradient of the batch
      loss and has no batch axis, so averaging over dim 0 would average over the
      parameter's own first axis instead.

    The sums are divided by the number of batches (every batch weighs the same, including
    a smaller final batch). The L2 norm of each averaged gradient is then reported. This is
    the norm of the *mean* gradient: opposite-signed gradients from different samples
    cancel, so it is not the mean of per-sample gradient norms.

    Args:
        model: object exposing ``run_with_cache(x) -> (logits, cache)``, ``zero_grad()``
            and ``named_parameters()``. Cached activations are assumed to have ``.grad``
            populated after ``backward`` (e.g. cached without detaching) — entries whose
            ``.grad`` is ``None`` are skipped.
        loader: iterable of ``(x, y)`` batches.
        loss_fn: callable ``(logits, y) -> scalar tensor``.

    Returns:
        dict with
        ``'grad_norms'`` (hook name -> float),
        ``'param_grad_norms'`` (parameter name -> float) and
        ``'loss'`` (mean of the per-batch losses; ``nan`` for an empty loader).
        Names that never received a gradient are omitted.
    """
    act_grad_sums = {}    # hook name -> running sum of batch-averaged activation gradients
    param_grad_sums = {}  # parameter name -> running sum of parameter gradients
    losses = []
    n_batches = 0

    for x, y in tqdm(loader):
        logits, cache = model.run_with_cache(x)
        loss = loss_fn(logits, y)
        model.zero_grad()
        loss.backward()
        losses.append(loss.item())
        n_batches += 1

        for k, v in cache.items():
            if v.grad is None:
                # Not on the loss's autograd path, or its gradient was not retained.
                continue
            batch_mean = v.grad.mean(dim=0)  # average over the batch dimension
            prev = act_grad_sums.get(k)
            act_grad_sums[k] = batch_mean if prev is None else prev + batch_mean

        for k, v in model.named_parameters():
            if v.grad is None:
                # Unused or frozen parameter: skip it rather than crash.
                continue
            # Clone so the running sum never aliases v.grad, which zero_grad may reset in place.
            grad = v.grad.detach().clone()
            prev = param_grad_sums.get(k)
            param_grad_sums[k] = grad if prev is None else prev + grad

    # Both dicts are empty when n_batches == 0, so there is no division by zero here.
    grad_norms = {k: (v / n_batches).norm().item() for k, v in act_grad_sums.items()}
    param_grad_norms = {k: (v / n_batches).norm().item() for k, v in param_grad_sums.items()}

    return {
        'grad_norms': grad_norms,
        'param_grad_norms': param_grad_norms,
        'loss': np.mean(losses) if losses else float('nan'),
    }

# Gradient statistics on the same data and loss for the tuned LL model, then the natural one.
loss_info = get_grad_norms(
    model=model,
    loader=loader,
    loss_fn=model_pair.loss_fn,
)
natural_loss_info = get_grad_norms(
    model=natural_model,
    loader=loader,
    loss_fn=model_pair.loss_fn,
)

100%|██████████| 4/4 [00:00<00:00, 60.75it/s]
100%|██████████| 4/4 [00:00<00:00, 65.70it/s]
100%|██████████| 4/4 [00:00<00:00, 66.55it/s]


In [33]:
import pandas as pd

# One row per hook / parameter name, one column per model.
activation_grads = pd.DataFrame(
    [loss_info['grad_norms'], natural_loss_info['grad_norms']],
    index=['tuned', 'natural'],
).transpose()
param_grads = pd.DataFrame(
    [loss_info['param_grad_norms'], natural_loss_info['param_grad_norms']],
    index=['tuned', 'natural'],
).transpose()
# Mean per-batch loss of each model over `loader`.
losses = pd.DataFrame(
    {'loss': [loss_info['loss'], natural_loss_info['loss']]},
    index=['tuned', 'natural'],
)

In [35]:
# Per-hook L2 norm of the averaged activation gradient: tuned vs natural model.
activation_grads

Unnamed: 0,tuned,natural
hook_embed,3.59635e-13,5.776981e-08
hook_pos_embed,3.59635e-13,5.776981e-08
blocks.0.hook_resid_pre,3.59635e-13,5.776981e-08
blocks.0.hook_q_input,2.449378e-14,3.410433e-09
blocks.0.hook_k_input,1.124774e-14,8.548683e-10
blocks.0.hook_v_input,3.881581e-14,7.945035e-09
blocks.0.attn.hook_q,1.654241e-14,1.947697e-09
blocks.0.attn.hook_k,1.08053e-14,6.690914e-10
blocks.0.attn.hook_v,6.721299e-14,1.260907e-08
blocks.0.attn.hook_attn_scores,1.396365e-14,2.515759e-09
