### Setup

In [1]:
import sys
import os
import transformer_lens as tl
from torch.utils.data import Dataset
import torch as t
from tqdm import tqdm
from torch.utils.data import DataLoader
import numpy as np
import wandb
from typing import List, Dict, Any, Optional

from circuits_benchmark.benchmark.benchmark_case import BenchmarkCase, CaseDataset
from circuits_benchmark.transformers.hooked_tracr_transformer import HookedTracrTransformer
import iit.model_pairs as mp
import iit.utils.index as index
from iit_utils.dataset import create_dataset, TracrDataset, TracrIITDataset
import iit_utils.correspondence as correspondence
from circuits_benchmark.utils.get_cases import get_cases
from circuits_benchmark.commands.build_main_parser import build_main_parser
from iit_utils.iit_hl_model import make_iit_hl_model

DEVICE = t.device("cuda" if t.cuda.is_available() else "cpu")
WANDB_ENTITY = "cybershiptrooper" # TODO make this an env var

  from .autonotebook import tqdm as notebook_tqdm


### Train Model

In [2]:
attn_idx = None
atol = 5e-2
losses = "all"
tracr_model_class = mp.StrictIITModelPair
case_num = 3

training_args = {
    "lr" : 1e-2,
    "losses" : losses,
    "atol" : atol,
    "batch_size" : 512,
    "use_single_loss": False,
    "iit_weight": 1.0,
    "behavior_weight": 1.0,
    "strict_weight": 1.0,
}
tracr_model_class.__name__

'StrictIITModelPair'

In [3]:
np.random.seed(0)
t.manual_seed(0)

args, _ = build_main_parser().parse_known_args(["compile",
                                                f"-i={case_num}",
                                                "-f",])
cases = get_cases(args)
case = cases[0]

tracr_output = case.build_tracr_model()
hl_model = case.build_transformer_lens_model()
# this is the graph node -> hl node correspondence
hl_ll_corr = correspondence.TracrCorrespondence.from_output(case, tracr_output)

In [4]:
hl_model([['BOS', 'x', 'b', 'a', 'a']], return_type='decoded')

[['BOS', 1.0, 0.5, 0.3333333432674408, 0.25]]

In [5]:
# seed everything
t.manual_seed(0)
np.random.seed(0)
import random
random.seed(0)

In [6]:
data = case.get_clean_data(count=15000)
inputs = data.get_inputs().to_numpy()
outputs = data.get_correct_outputs().to_numpy()

train_inputs = inputs[:12000]
test_inputs = inputs[12000:]
train_outputs = outputs[:12000]
test_outputs = outputs[12000:]

train_set, test_set = create_dataset(case, hl_model)

In [7]:
from transformer_lens import HookedTransformer, HookedTransformerConfig

cfg_dict = {
    "n_layers": 2, 
    "n_heads": 4, 
    "d_head": 4,
    "d_model": 8,
    "d_mlp": 16,
    "act_fn": "gelu",
}
ll_cfg = hl_model.cfg.to_dict().copy()
ll_cfg.update(cfg_dict)


print(ll_cfg)
ll_cfg = HookedTransformerConfig.from_dict(ll_cfg)
model = HookedTransformer(ll_cfg)


{'n_layers': 2, 'd_model': 8, 'n_ctx': 5, 'd_head': 4, 'model_name': 'custom', 'n_heads': 4, 'd_mlp': 16, 'act_fn': 'gelu', 'd_vocab': 6, 'eps': 1e-05, 'use_attn_result': True, 'use_attn_scale': True, 'use_split_qkv_input': True, 'use_hook_mlp_in': True, 'use_attn_in': False, 'use_local_attn': False, 'original_architecture': None, 'from_checkpoint': False, 'checkpoint_index': None, 'checkpoint_label_type': None, 'checkpoint_value': None, 'tokenizer_name': None, 'window_size': None, 'attn_types': None, 'init_mode': 'gpt2', 'normalization_type': None, 'device': device(type='mps'), 'n_devices': 1, 'attention_dir': 'causal', 'attn_only': False, 'seed': None, 'initializer_range': 0.22188007849009167, 'init_weights': True, 'scale_attn_by_inverse_layer_idx': False, 'positional_embedding_type': 'standard', 'final_rms': False, 'd_vocab_out': 1, 'parallel_attn_mlp': False, 'rotary_dim': None, 'n_params': 676, 'use_hook_tokens': False, 'gated_mlp': False, 'default_prepend_bos': True, 'dtype': tor

In [8]:
model_pair = tracr_model_class(
    hl_model = make_iit_hl_model(hl_model),
    ll_model = model,
    corr = hl_ll_corr,
    training_args=training_args,
)

{'hook_embed': HookPoint(), 'hook_pos_embed': HookPoint(), 'blocks.0.attn.hook_k': HookPoint(), 'blocks.0.attn.hook_q': HookPoint(), 'blocks.0.attn.hook_v': HookPoint(), 'blocks.0.attn.hook_z': HookPoint(), 'blocks.0.attn.hook_attn_scores': HookPoint(), 'blocks.0.attn.hook_pattern': HookPoint(), 'blocks.0.attn.hook_result': HookPoint(), 'blocks.0.mlp.hook_pre': HookPoint(), 'blocks.0.mlp.hook_post': HookPoint(), 'blocks.0.hook_attn_in': HookPoint(), 'blocks.0.hook_q_input': HookPoint(), 'blocks.0.hook_k_input': HookPoint(), 'blocks.0.hook_v_input': HookPoint(), 'blocks.0.hook_mlp_in': HookPoint(), 'blocks.0.hook_attn_out': HookPoint(), 'blocks.0.hook_mlp_out': HookPoint(), 'blocks.0.hook_resid_pre': HookPoint(), 'blocks.0.hook_resid_mid': HookPoint(), 'blocks.0.hook_resid_post': HookPoint(), 'blocks.1.attn.hook_k': HookPoint(), 'blocks.1.attn.hook_q': HookPoint(), 'blocks.1.attn.hook_v': HookPoint(), 'blocks.1.attn.hook_z': HookPoint(), 'blocks.1.attn.hook_attn_scores': HookPoint(), 'b

In [9]:

ll_model = HookedTracrTransformer(
    ll_cfg, hl_model.tracr_input_encoder, hl_model.tracr_output_encoder, hl_model.residual_stream_labels
)
ll_model.load_weights_from_file(f"ll_models/{case_num}/ll_model_510.pth")
model_pair.ll_model = ll_model

In [10]:
# model_pair.train(
#     train_set,
#     test_set,
#     epochs=50,
#     use_wandb=False,
# )

### Setup Eval

In [11]:
"""Create a new test set with unique inputs"""

arr, idxs = np.unique([", ".join(i) for i in np.array(test_inputs)], return_inverse=True)
# create indices that point to the first unique input
all_possible_inputs = np.arange(arr.shape[0])
# find the first occurence of all_possible_inputs in idxs
first_occurences = [np.where(idxs == i)[0][0] for i in all_possible_inputs]

unique_test_inputs = test_inputs[first_occurences]
unique_test_outputs = test_outputs[first_occurences]
assert len(unique_test_inputs) == len(unique_test_outputs)
assert len(unique_test_inputs) == len(np.unique([", ".join(i) for i in np.array(test_inputs)]))
assert len(np.unique([", ".join(i) for i in np.array(unique_test_inputs)])) == len(unique_test_inputs)

unique_test_data = TracrDataset(unique_test_inputs, unique_test_outputs)
test_set = TracrIITDataset(unique_test_data, unique_test_data, hl_model, every_combination=True)
test_loader = test_set.make_loader(batch_size=512, num_workers=0)

In [12]:
def tokenise_data(batch, model: HookedTracrTransformer) -> t.Tensor:
    x = list(map(list, zip(*batch)))
    encoded_x = model.map_tracr_input_to_tl_input(x)
    return encoded_x

In [13]:
tensorised_base_data = []
tensorised_ablation_data = []
base_answer_tokens = []
for base_in, ablation_in in test_loader:
    base_x, base_y,  _ = base_in
    ablation_x, ablation_y, _ = ablation_in

    tensorised_base_data.append((base_x))
    tensorised_ablation_data.append((ablation_x))
    base_answer_tokens.append(base_y)

base_tensor = t.cat(tensorised_base_data, dim=0)
ablation_tensor = t.cat(tensorised_ablation_data, dim=0)
base_answer_tokens = t.cat(base_answer_tokens, dim=0)

In [14]:
model.requires_grad_(False)
model.eval()
hl_model.requires_grad_(False)
hl_model.eval()
print()




In [15]:
original_logits, cache = model.run_with_cache(base_tensor)

In [16]:
hl_answers = hl_model(base_tensor)
(hl_answers.shape), hl_answers[0], base_answer_tokens[0] # Wtf???

(torch.Size([65536, 5, 1]),
 tensor([[0.0000],
         [0.0000],
         [0.0000],
         [0.0000],
         [0.2500]], device='mps:0'),
 tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.2500], device='mps:0'))

In [17]:
original_logits.shape, original_logits[3], hl_answers.shape, hl_answers[3]

(torch.Size([65536, 5, 1]),
 tensor([[ 0.3080],
         [ 0.3513],
         [ 0.2932],
         [ 0.1778],
         [-0.0688]], device='mps:0'),
 torch.Size([65536, 5, 1]),
 tensor([[0.0000],
         [0.0000],
         [0.0000],
         [0.3333],
         [0.2500]], device='mps:0'))

### Patch Attention Heads to see Causal Effect

In [18]:
from iit.utils.node_picker import get_nodes_not_in_circuit, get_nodes_in_circuit, get_all_nodes

nodes_not_in_circuit = get_nodes_not_in_circuit(model_pair.ll_model, hl_ll_corr)
nodes_not_in_circuit, "---", list(hl_ll_corr.values())

([LLNode(name='blocks.0.attn.hook_result', index=[:, :, 0, :], subspace=None),
  LLNode(name='blocks.0.attn.hook_result', index=[:, :, 1, :], subspace=None),
  LLNode(name='blocks.0.attn.hook_result', index=[:, :, 2, :], subspace=None),
  LLNode(name='blocks.0.attn.hook_result', index=[:, :, 3, :], subspace=None),
  LLNode(name='blocks.1.attn.hook_result', index=[:, :, 0, :], subspace=None),
  LLNode(name='blocks.1.attn.hook_result', index=[:, :, 3, :], subspace=None),
  LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)],
 '---',
 [{LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)},
  {LLNode(name='blocks.1.attn.hook_result', index=[:, :, 1:3, :], subspace=None)}])

In [19]:
from iit_utils.evals import check_causal_effect, make_dataframe_of_results

In [20]:
np.random.seed(0)
t.manual_seed(0)
result_not_in_circuit = check_causal_effect(model_pair, test_set, node_type="n", verbose=False)
result_in_circuit = check_causal_effect(model_pair, test_set, node_type="c", verbose=False)

100%|██████████| 256/256 [00:19<00:00, 13.19it/s]
100%|██████████| 256/256 [00:05<00:00, 45.89it/s]


In [21]:
df = make_dataframe_of_results(result_not_in_circuit, result_in_circuit)
# df.style.apply(color_table, subset=["status"])
print(attn_idx, training_args, tracr_model_class)
df

None {'lr': 0.01, 'losses': 'all', 'atol': 0.05, 'batch_size': 512, 'use_single_loss': False, 'iit_weight': 1.0, 'behavior_weight': 1.0, 'strict_weight': 1.0} <class 'iit.model_pairs.strict_iit_model_pair.StrictIITModelPair'>


Unnamed: 0,node,status,causal effect
0,"blocks.0.attn.hook_result, head 0",not_in_circuit,1.3e-05
1,"blocks.0.attn.hook_result, head 1",not_in_circuit,0.002159
2,"blocks.0.attn.hook_result, head 2",not_in_circuit,0.0
3,"blocks.0.attn.hook_result, head 3",not_in_circuit,0.0
4,"blocks.1.attn.hook_result, head 0",not_in_circuit,1.0
5,"blocks.1.attn.hook_result, head 3",not_in_circuit,0.010526
6,blocks.1.mlp.hook_post,not_in_circuit,0.207744
7,blocks.0.mlp.hook_post,in_circuit,1.0
8,"blocks.1.attn.hook_result, head 1:3",in_circuit,1.0


In [22]:
from iit.utils.metric import MetricStore
def print_metrics(metrics: list[MetricStore]):
    for metric in metrics:
        print(f"{metric.get_name()}: {metric.get_value()}")

metric_collection = model_pair._run_eval_epoch(test_loader, model_pair.loss_fn)

In [23]:
print_metrics(metric_collection.metrics)

val/iit_loss: 0.035202935819370396
val/IIA: 74.02771129272878
val/accuracy: 100.0


### Do the same with zero ablations

In [24]:
from iit_utils.evals import check_causal_effect_on_ablation
from iit_utils.dataset import TracrUniqueDataset

In [25]:
uni_test_set = TracrUniqueDataset(unique_test_data, unique_test_data, hl_model, every_combination=True)

In [26]:
np.random.seed(0)
t.manual_seed(0)
use_mean_cache = True
za_result_not_in_circuit = check_causal_effect_on_ablation(model_pair, uni_test_set, node_type="n", verbose=False,  use_mean_cache=use_mean_cache)
za_result_in_circuit = check_causal_effect_on_ablation(model_pair, uni_test_set, node_type="c", verbose=False,  use_mean_cache=use_mean_cache)

  0%|          | 0/1 [00:00<?, ?it/s]

100%|██████████| 1/1 [00:00<00:00, 12.58it/s]
100%|██████████| 1/1 [00:00<00:00, 40.51it/s]


In [27]:
df = make_dataframe_of_results(za_result_not_in_circuit, za_result_in_circuit)
# df.style.map(color_table, subset=["status"])
print(attn_idx, training_args, tracr_model_class)
df

None {'lr': 0.01, 'losses': 'all', 'atol': 0.05, 'batch_size': 512, 'use_single_loss': False, 'iit_weight': 1.0, 'behavior_weight': 1.0, 'strict_weight': 1.0} <class 'iit.model_pairs.strict_iit_model_pair.StrictIITModelPair'>


Unnamed: 0,node,status,causal effect
0,"blocks.0.attn.hook_result, head 0",not_in_circuit,0.0
1,"blocks.0.attn.hook_result, head 1",not_in_circuit,0.0
2,"blocks.0.attn.hook_result, head 2",not_in_circuit,0.0
3,"blocks.0.attn.hook_result, head 3",not_in_circuit,0.0
4,"blocks.1.attn.hook_result, head 0",not_in_circuit,0.715625
5,"blocks.1.attn.hook_result, head 3",not_in_circuit,0.0
6,blocks.1.mlp.hook_post,not_in_circuit,0.0
7,blocks.0.mlp.hook_post,in_circuit,0.778906
8,"blocks.1.attn.hook_result, head 1:3",in_circuit,0.63125


### Combined table

In [28]:
from iit_utils.evals import make_combined_dataframe_of_results
df = make_combined_dataframe_of_results(result_not_in_circuit, result_in_circuit, za_result_not_in_circuit, za_result_in_circuit, use_mean_cache=use_mean_cache)
# df.style.apply(color_table, subset=["status"], method = "map")   
print(attn_idx, training_args, tracr_model_class)
df

None {'lr': 0.01, 'losses': 'all', 'atol': 0.05, 'batch_size': 512, 'use_single_loss': False, 'iit_weight': 1.0, 'behavior_weight': 1.0, 'strict_weight': 1.0} <class 'iit.model_pairs.strict_iit_model_pair.StrictIITModelPair'>


Unnamed: 0,node,status,resample_ablate_effect,mean_ablate_effect
0,"blocks.0.attn.hook_result, head 0",not_in_circuit,1.3e-05,0.0
1,"blocks.0.attn.hook_result, head 1",not_in_circuit,0.002159,0.0
2,"blocks.0.attn.hook_result, head 2",not_in_circuit,0.0,0.0
3,"blocks.0.attn.hook_result, head 3",not_in_circuit,0.0,0.0
4,"blocks.1.attn.hook_result, head 0",not_in_circuit,1.0,0.715625
5,"blocks.1.attn.hook_result, head 3",not_in_circuit,0.010526,0.0
6,blocks.1.mlp.hook_post,not_in_circuit,0.207744,0.0
7,blocks.0.mlp.hook_post,in_circuit,1.0,0.778906
8,"blocks.1.attn.hook_result, head 1:3",in_circuit,1.0,0.63125


In [29]:
# save the results
import time
save_dir = f"results/{tracr_model_class.__name__}/{time.strftime('%d-%H-%M-%S')}"
from iit_utils.evals import save_result
save_result(df, save_dir, model_pair)

### Rough

In [30]:
# def get_all_bad_examples(model_pair, loader, atol=5e-2):
#     model_pair.ll_model.eval()
#     model_pair.hl_model.eval()
#     bad_io_examples = []
#     bad_ii_examples = []

#     for base_in, ablation_in in tqdm(loader):
#         base_in = model_pair.get_encoded_input_from_torch_input(base_in)
#         ablation_in = model_pair.get_encoded_input_from_torch_input(ablation_in)
#         for node in model_pair.corr.keys():
#             hl_node = node.name
#             ll_out, hl_out = model_pair.do_intervention(base_in, ablation_in, hl_node)
#             if model_pair.hl_model.is_categorical():
#                 top1 = t.argmax(ll_out, dim=1)
#                 correct = (top1 == hl_out).float()
#             else:
#                 correct = ((ll_out - hl_out).abs() < atol).float()
            
#             for i, c in enumerate(correct):
#                 print(c)
#                 if c == 0:
#                     bad_ii_examples.append((base_in[i], ablation_in[i]))
#         base_x, base_y = base_in
#         ll_out = model_pair.ll_model(base_x)
#         if model_pair.hl_model.is_categorical():
#             top1 = t.argmax(ll_out, dim=1)
#             correct = (top1 == base_y).float()
#         else:
#             correct = ((ll_out - base_y).abs() < atol).float()
        
#         for i, c in enumerate(correct):
#             if c == 0:
#                 if base_x[i] not in bad_io_examples:
#                     bad_io_examples.append((base_x[i]))

#     return bad_io_examples, bad_ii_examples

# bad_io_examples, bad_ii_examples = get_all_bad_examples(model_pair, test_loader, atol)

# bad_io_examples, bad_ii_examples

In [31]:
# np.random.seed(0)
# t.manual_seed(0)
# test_loader = DataLoader(test_set, batch_size=2, shuffle=True)
# base_in, ablation_in = next(iter(test_loader))

# hooker = model_pair.make_ll_ablation_hook(nodes_not_in_circuit[2])
# base_x, base_y = model_pair.get_encoded_input_from_torch_input(base_in)
# ablation_x, ablation_y = model_pair.get_encoded_input_from_torch_input(ablation_in)
# ll_out = do_intervention(model_pair.ll_model, base_x, ablation_x, nodes_not_in_circuit[2], hooker)
# ll_base_out, ll_base_cache = model_pair.ll_model.run_with_cache(base_x)
# ll_ablation_out, ll_ablation_cache = model_pair.ll_model.run_with_cache(ablation_x)
# for i in range(2):
#     print(
#         "---",
#         f"example {i}", 
#         "base_y:", base_y[i],
#         "ll_base_out:", ll_base_out[i].T,
#         "",
#         "ablation_y:", ablation_y[i],
#         "ll_ablation_out:", ll_ablation_out[i].T,
#         "",
#         "ll_out:", ll_out[i].T,
#         sep="\n"
#     )