In [1]:
%load_ext autoreload
%autoreload 2

import torch
from circuits_benchmark.utils.get_cases import get_cases
import iit.model_pairs as mp
from circuits_benchmark.utils.ll_model_loader.ll_model_loader_factory import get_ll_model_loader
from circuits_benchmark.benchmark.benchmark_case import BenchmarkCase
import os

# --- Benchmark case ---------------------------------------------------------
task_idx = '3'

# Per-case output directory, created up front (no error if it already exists).
# NOTE(review): the path says `tuned_lens` while the cells below call the logit
# lens with tuned_lens=False, and `out_dir` is not referenced again in the
# visible cells -- confirm this is intended.
out_dir = f'results/tuned_lens/{task_idx}'
os.makedirs(out_dir, exist_ok=True)

# A single index is requested, so the first returned case is the one we want.
task: BenchmarkCase = get_cases(indices=[task_idx])[0]

# --- Low-level model and its correspondence to the high-level model ---------
device = 'cuda' if torch.cuda.is_available() else 'cpu'
ll_model_loader = get_ll_model_loader(task, interp_bench=True)
hl_ll_corr, model = ll_model_loader.load_ll_model_and_correspondence(device=device)

# Analysis only: switch to eval mode and turn off grads.
model.eval()
model.requires_grad_(False)

# --- Pair the high-level model with the low-level model ---------------------
hl_model = task.get_hl_model()
model_pair = mp.StrictIITModelPair(hl_model, model, hl_ll_corr)

{'hook_embed': HookPoint(), 'hook_pos_embed': HookPoint(), 'blocks.0.attn.hook_k': HookPoint(), 'blocks.0.attn.hook_q': HookPoint(), 'blocks.0.attn.hook_v': HookPoint(), 'blocks.0.attn.hook_z': HookPoint(), 'blocks.0.attn.hook_attn_scores': HookPoint(), 'blocks.0.attn.hook_pattern': HookPoint(), 'blocks.0.attn.hook_result': HookPoint(), 'blocks.0.mlp.hook_pre': HookPoint(), 'blocks.0.mlp.hook_post': HookPoint(), 'blocks.0.hook_attn_in': HookPoint(), 'blocks.0.hook_q_input': HookPoint(), 'blocks.0.hook_k_input': HookPoint(), 'blocks.0.hook_v_input': HookPoint(), 'blocks.0.hook_mlp_in': HookPoint(), 'blocks.0.hook_attn_out': HookPoint(), 'blocks.0.hook_mlp_out': HookPoint(), 'blocks.0.hook_resid_pre': HookPoint(), 'blocks.0.hook_resid_mid': HookPoint(), 'blocks.0.hook_resid_post': HookPoint(), 'blocks.1.attn.hook_k': HookPoint(), 'blocks.1.attn.hook_q': HookPoint(), 'blocks.1.attn.hook_v': HookPoint(), 'blocks.1.attn.hook_z': HookPoint(), 'blocks.1.attn.hook_attn_scores': HookPoint(), 'b

In [2]:
%%capture
# Upper bound on the number of clean samples requested from the task.
max_len = 1000
unique_test_data = task.get_clean_data(max_samples=max_len, unique_data=True)

# Batches are served in dataset order and the last partial batch is kept.
loader = torch.utils.data.DataLoader(
    unique_test_data,
    batch_size=256,
    shuffle=False,
    drop_last=False,
)

In [3]:
# Weight preprocessing applied before running the logit lens.
#
# NOTE(review): every call below is handed a fresh `model.state_dict()` and its
# return value is discarded. If these helpers return the rewritten state dict
# instead of updating the model's parameters in place, `model` is left
# unchanged here -- confirm, and `model.load_state_dict(...)` the result if so.
if model_pair.hl_model.is_categorical():
    # preprocess model for logit lens
    model.center_writing_weights(state_dict=model.state_dict())
    model.center_unembed(state_dict=model.state_dict())
    model.refactor_factored_attn_matrices(state_dict=model.state_dict())

# Folding layer norm is best-effort and is attempted for every task type.
# Catch `Exception` rather than using a bare `except:` so that
# KeyboardInterrupt / SystemExit still propagate, and report the actual error
# so an unexpected failure is not mistaken for "model has no layer norm".
try:
    model.fold_layer_norm(state_dict=model.state_dict())
except Exception as err:
    print(f"No layer norm to fold ({type(err).__name__}: {err})")

No layer norm to fold


In [4]:
import interp_utils.lens.logit_lens as logit_lens

# Run the logit lens for `model_pair` over the batches in `loader`.
# `logit_lens_results` is iterated by key in the plotting cells below and
# `labels` is passed alongside it to the plotting helpers.
logit_lens_results, labels = logit_lens.do_logit_lens(model_pair, loader)

torch.Size([256, 4, 12]) torch.Size([4, 256, 4, 12])


In [5]:
from interp_utils.lens.plot_utils import get_formatted_node_names_in_circuit
# Formatted names of the nodes that are in the circuit for this model pair;
# the plotting cells use it for `key in nodes` membership checks.
nodes = get_formatted_node_names_in_circuit(model_pair)

In [12]:
from interp_utils.lens.plot_utils import plot_pearson, plot_combined_pearson
# These two arguments are the same for every node, so evaluate them once
# instead of on every loop iteration.
is_categorical = model_pair.hl_model.is_categorical()
case_name = task.get_name()

# One Pearson plot per lens-result key, called with show=False.
for node_key in logit_lens_results.keys():
    plot_pearson(
        key=node_key,
        lens_results=logit_lens_results,
        labels=labels,
        is_categorical=is_categorical,
        in_circuit=node_key in nodes,  # flag nodes that belong to the circuit
        tuned_lens=False,
        case_name=case_name,
        show=False,
    )


An input array is constant; the correlation coefficient is not defined.



In [13]:
# Node whose Pearson plot is produced with show=True.
# Alternative key tried previously: "L1H2"
k = "1_mlp_out"

# Kept as the cell's last expression so the returned value is displayed.
plot_pearson(
    key=k,
    lens_results=logit_lens_results,
    labels=labels,
    is_categorical=model_pair.hl_model.is_categorical(),
    in_circuit=k in nodes,
    tuned_lens=False,
    case_name=task.get_name(),
    show=True,
)

'./interp_results//3/logit_lens/1_mlp_out/pearson.png'

In [8]:
# Collect the inputs first, then make the combined Pearson plot over all
# lens results.
circuit_nodes = get_formatted_node_names_in_circuit(model_pair)
categorical = model_pair.hl_model.is_categorical()

# Kept as the cell's last expression so the returned value is displayed.
plot_combined_pearson(
    lens_results=logit_lens_results,
    labels=labels,
    nodes_in_circuit=circuit_nodes,
    is_categorical=categorical,
    tuned_lens=False,
    case_name=task.get_name(),
)


An input array is constant; the correlation coefficient is not defined.



'./interp_results//3/logit_lens/combined_pearson.png'

In [9]:
# Per-vocab-index lens results are only computed for categorical tasks.
# For other tasks both names stay undefined, which is why the cells that use
# them repeat this same guard.
if model_pair.hl_model.is_categorical():
    logit_lens_per_vocab, per_vocab_labels = logit_lens.do_logit_lens_per_vocab_idx(
        model_pair,
        loader,
    )

In [10]:
from interp_utils.lens.plot_utils import plot_pearson_at_vocab_idx
# Per-vocab-index Pearson plots for every node; only meaningful (and only
# defined) for categorical tasks.
if model_pair.hl_model.is_categorical():
    # Same for every plot, so look it up once.
    case_name = task.get_name()
    for node_key in logit_lens_per_vocab.keys():
        # Circuit membership depends only on the node, not on the vocab
        # index, so compute it once per node rather than once per plot.
        in_circuit = node_key in nodes
        for idx in logit_lens_per_vocab[node_key].keys():
            plot_pearson_at_vocab_idx(
                key=node_key,
                vocab_idx=idx,
                lens_results_per_vocab=logit_lens_per_vocab,
                per_vocab_labels=per_vocab_labels,
                in_circuit=in_circuit,
                tuned_lens=False,
                case_name=case_name,
                show=False,
            )

In [11]:
 # k = "L1H2"
# Node to inspect.
k = "0_mlp_out"

# Requested vocab index is 2, clamped to `d_vocab_out - 1`.
max_vocab_idx = model.cfg.d_vocab_out - 1
vocab_idx = min(2, max_vocab_idx)

# Per-vocab results only exist for categorical tasks (see the guard where
# `logit_lens_per_vocab` is computed).
if model_pair.hl_model.is_categorical():
    plot_pearson_at_vocab_idx(
        key=k,
        vocab_idx=vocab_idx,
        lens_results_per_vocab=logit_lens_per_vocab,
        per_vocab_labels=per_vocab_labels,
        in_circuit=k in nodes,
        tuned_lens=False,
        case_name=task.get_name(),
        show=True,
    )