### Setup

In [1]:
import sys
import os
import transformer_lens as tl
from torch.utils.data import Dataset
import torch as t
from tqdm import tqdm
from torch.utils.data import DataLoader
import numpy as np
import wandb
from typing import List, Dict, Any, Optional

from circuits_benchmark.benchmark.benchmark_case import BenchmarkCase, CaseDataset
from circuits_benchmark.transformers.hooked_tracr_transformer import HookedTracrTransformer
import iit.model_pairs as mp
import iit.utils.index as index
from iit_utils.dataset import create_dataset, TracrDataset, TracrIITDataset
import iit_utils.correspondence as correspondence
from circuits_benchmark.utils.get_cases import get_cases
from circuits_benchmark.commands.build_main_parser import build_main_parser
from iit_utils.iit_hl_model import make_iit_hl_model

DEVICE = t.device("cuda" if t.cuda.is_available() else "cpu")
# W&B entity to log runs under; override with the WANDB_ENTITY environment variable.
WANDB_ENTITY = os.environ.get("WANDB_ENTITY", "cybershiptrooper")

  from .autonotebook import tqdm as notebook_tqdm


### Train Model

In [2]:
# Run configuration.
attn_idx = None
atol = 5e-2  # absolute tolerance, forwarded to the model pair via training_args
losses = "all"
tracr_model_class = mp.StopGradModelPair
case_num = 3
train_model = True  # False -> load pre-trained low-level weights instead of training

training_args = dict(
    lr=1e-3,
    losses=losses,
    atol=atol,
    batch_size=512,
    use_single_loss=False,
    iit_weight=0.0,
    behavior_weight=1.0,
    strict_weight=0.0,
    scale=1.0,
    use_ln_hooks=True,
    clip_grad_norm=1,
    # lr_scheduler=None,
)
tracr_model_class.__name__

'StopGradModelPair'

In [3]:
np.random.seed(0)
t.manual_seed(0)

# Parse arguments as if invoking the `compile` command for `case_num`.
cli_args = ["compile", f"-i={case_num}", "-f"]
args, _ = build_main_parser().parse_known_args(cli_args)
cases = get_cases(args)
case = cases[0]

tracr_output = case.build_tracr_model()
hl_model = case.build_transformer_lens_model()
# Maps each high-level (tracr graph) node to the set of low-level model nodes
# that should implement it (see the `hl_ll_corr.values()` output further down).
hl_ll_corr = correspondence.TracrCorrespondence.from_output(case, tracr_output)

In [4]:
# Sanity check: decode the tracr model's output on one hand-written example.
example_batch = [["BOS", "x", "b", "a", "a"]]
hl_model(example_batch, return_type="decoded")

[['BOS', 1.0, 0.5, 0.3333333432674408, 0.25]]

In [5]:
# seed everything
t.manual_seed(1)
np.random.seed(1)
import random
random.seed(1)

In [6]:
# Sample a pool of clean examples and hold out the last 20% as a test split.
# Only `test_inputs`/`test_outputs` are used further down (to build the
# de-duplicated evaluation set); training data comes from `create_dataset`.
# NOTE(review): the held-out pool is sampled independently of `create_dataset`,
# so it may overlap the training data -- confirm whether that matters here.
N_SAMPLES = 15000

data = case.get_clean_data(count=N_SAMPLES)
inputs = data.get_inputs().to_numpy()
outputs = data.get_correct_outputs().to_numpy()
assert len(inputs) == len(outputs), "inputs and outputs must be aligned"

# Split on the number of rows actually returned (80/20, integer arithmetic)
# instead of hard-coded indices, so a smaller pool still yields a test split.
n_train = (len(inputs) * 4) // 5
train_inputs, test_inputs = inputs[:n_train], inputs[n_train:]
train_outputs, test_outputs = outputs[:n_train], outputs[n_train:]
assert len(test_inputs) > 0, "held-out test split is empty"

train_set, test_set = create_dataset(case, hl_model)

In [7]:
from transformer_lens import HookedTransformer, HookedTransformerConfig

# Low-level (trainable) model: inherit the tracr model's config (vocab size,
# context length, ...) and override only the architecture size.
cfg_dict = dict(
    n_layers=2,
    n_heads=4,
    d_head=4,
    d_model=8,
    d_mlp=16,
    act_fn="gelu",
)
ll_cfg_dict = {**hl_model.cfg.to_dict(), **cfg_dict}

print(ll_cfg_dict)
ll_cfg = HookedTransformerConfig.from_dict(ll_cfg_dict)
model = HookedTransformer(ll_cfg)


{'n_layers': 2, 'd_model': 8, 'n_ctx': 5, 'd_head': 4, 'model_name': 'custom', 'n_heads': 4, 'd_mlp': 16, 'act_fn': 'gelu', 'd_vocab': 6, 'eps': 1e-05, 'use_attn_result': True, 'use_attn_scale': True, 'use_split_qkv_input': True, 'use_hook_mlp_in': True, 'use_attn_in': False, 'use_local_attn': False, 'original_architecture': None, 'from_checkpoint': False, 'checkpoint_index': None, 'checkpoint_label_type': None, 'checkpoint_value': None, 'tokenizer_name': None, 'window_size': None, 'attn_types': None, 'init_mode': 'gpt2', 'normalization_type': None, 'device': device(type='mps'), 'n_devices': 1, 'attention_dir': 'causal', 'attn_only': False, 'seed': None, 'initializer_range': 0.22188007849009167, 'init_weights': True, 'scale_attn_by_inverse_layer_idx': False, 'positional_embedding_type': 'standard', 'final_rms': False, 'd_vocab_out': 1, 'parallel_attn_mlp': False, 'rotary_dim': None, 'n_params': 676, 'use_hook_tokens': False, 'gated_mlp': False, 'default_prepend_bos': True, 'dtype': tor

In [8]:
# Pair the wrapped tracr model with the freshly initialised low-level model;
# `hl_ll_corr` tells the pair which low-level nodes implement each high-level node.
iit_hl_model = make_iit_hl_model(hl_model)
model_pair = tracr_model_class(
    hl_model=iit_hl_model,
    ll_model=model,
    corr=hl_ll_corr,
    training_args=training_args,
)

{'hook_embed': HookPoint(), 'hook_pos_embed': HookPoint(), 'blocks.0.attn.hook_k': HookPoint(), 'blocks.0.attn.hook_q': HookPoint(), 'blocks.0.attn.hook_v': HookPoint(), 'blocks.0.attn.hook_z': HookPoint(), 'blocks.0.attn.hook_attn_scores': HookPoint(), 'blocks.0.attn.hook_pattern': HookPoint(), 'blocks.0.attn.hook_result': HookPoint(), 'blocks.0.mlp.hook_pre': HookPoint(), 'blocks.0.mlp.hook_post': HookPoint(), 'blocks.0.hook_attn_in': HookPoint(), 'blocks.0.hook_q_input': HookPoint(), 'blocks.0.hook_k_input': HookPoint(), 'blocks.0.hook_v_input': HookPoint(), 'blocks.0.hook_mlp_in': HookPoint(), 'blocks.0.hook_attn_out': HookPoint(), 'blocks.0.hook_mlp_out': HookPoint(), 'blocks.0.hook_resid_pre': HookPoint(), 'blocks.0.hook_resid_mid': HookPoint(), 'blocks.0.hook_resid_post': HookPoint(), 'blocks.1.attn.hook_k': HookPoint(), 'blocks.1.attn.hook_q': HookPoint(), 'blocks.1.attn.hook_v': HookPoint(), 'blocks.1.attn.hook_z': HookPoint(), 'blocks.1.attn.hook_attn_scores': HookPoint(), 'b

In [9]:

# Tracr-aware model with the same low-level config; only used below when
# loading pre-trained weights instead of training.
ll_model = HookedTracrTransformer(
    ll_cfg,
    hl_model.tracr_input_encoder,
    hl_model.tracr_output_encoder,
    hl_model.residual_stream_labels,
)


In [10]:
# Either train the low-level model from scratch or load saved weights into it.
if train_model:
    model_pair.train(
        train_set,
        test_set,
        epochs=1000,
        use_wandb=True,
    )
else:
    ll_model.load_weights_from_file(f"ll_models/{case_num}/ll_model_510.pth")
    model_pair.ll_model = ll_model
    # Later cells evaluate `model` directly (eval mode, run_with_cache, ...);
    # point it at the loaded weights so they don't silently use the untrained
    # network built earlier.
    model = ll_model

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


training_args={'batch_size': 512, 'lr': 0.001, 'num_workers': 0, 'early_stop': True, 'lr_scheduler': <class 'torch.optim.lr_scheduler.ReduceLROnPlateau'>, 'scheduler_val_metric': 'val/accuracy', 'scheduler_mode': 'max', 'clip_grad_norm': 1, 'atol': 0.05, 'use_single_loss': False, 'iit_weight': 0.0, 'behavior_weight': 1.0, 'scale': 1.0, 'use_ln_hooks': True, 'losses': 'all', 'strict_weight': 0.0}


[34m[1mwandb[0m: Currently logged in as: [33mcybershiptrooper[0m. Use [1m`wandb login --relogin`[0m to force relogin


100%|██████████| 24/24 [00:02<00:00, 11.07it/s]
  0%|          | 1/1000 [00:02<41:04,  2.47s/it]


Epoch 0: train/iit_loss: 0.0000, train/behavior_loss: 0.0630, val/iit_loss: 0.0902, val/IIA: 6.24%, val/accuracy: 11.31%, 


100%|██████████| 24/24 [00:01<00:00, 14.18it/s]
  0%|          | 2/1000 [00:04<35:25,  2.13s/it]


Epoch 1: train/iit_loss: 0.0000, train/behavior_loss: 0.0324, val/iit_loss: 0.1079, val/IIA: 25.29%, val/accuracy: 30.45%, 


100%|██████████| 24/24 [00:01<00:00, 14.49it/s]
  0%|          | 3/1000 [00:06<33:12,  2.00s/it]


Epoch 2: train/iit_loss: 0.0000, train/behavior_loss: 0.0212, val/iit_loss: 0.1009, val/IIA: 27.61%, val/accuracy: 30.59%, 


100%|██████████| 24/24 [00:01<00:00, 14.56it/s]
  0%|          | 4/1000 [00:08<31:59,  1.93s/it]


Epoch 3: train/iit_loss: 0.0000, train/behavior_loss: 0.0149, val/iit_loss: 0.1011, val/IIA: 29.09%, val/accuracy: 38.98%, 


100%|██████████| 24/24 [00:01<00:00, 14.16it/s]
  0%|          | 5/1000 [00:09<31:34,  1.90s/it]


Epoch 4: train/iit_loss: 0.0000, train/behavior_loss: 0.0095, val/iit_loss: 0.0533, val/IIA: 34.51%, val/accuracy: 47.89%, 


100%|██████████| 24/24 [00:01<00:00, 14.67it/s]
  1%|          | 6/1000 [00:11<31:07,  1.88s/it]


Epoch 5: train/iit_loss: 0.0000, train/behavior_loss: 0.0052, val/iit_loss: 0.0866, val/IIA: 39.01%, val/accuracy: 62.78%, 


100%|██████████| 24/24 [00:01<00:00, 14.19it/s]
  1%|          | 7/1000 [00:13<31:01,  1.87s/it]


Epoch 6: train/iit_loss: 0.0000, train/behavior_loss: 0.0025, val/iit_loss: 0.1073, val/IIA: 43.64%, val/accuracy: 73.20%, 


100%|██████████| 24/24 [00:01<00:00, 14.00it/s]
  1%|          | 8/1000 [00:15<31:02,  1.88s/it]


Epoch 7: train/iit_loss: 0.0000, train/behavior_loss: 0.0012, val/iit_loss: 0.0688, val/IIA: 46.37%, val/accuracy: 90.82%, 


100%|██████████| 24/24 [00:01<00:00, 14.37it/s]
  1%|          | 9/1000 [00:17<30:54,  1.87s/it]


Epoch 8: train/iit_loss: 0.0000, train/behavior_loss: 0.0007, val/iit_loss: 0.0855, val/IIA: 49.38%, val/accuracy: 93.89%, 


100%|██████████| 24/24 [00:01<00:00, 14.17it/s]
  1%|          | 10/1000 [00:19<30:50,  1.87s/it]


Epoch 9: train/iit_loss: 0.0000, train/behavior_loss: 0.0005, val/iit_loss: 0.1256, val/IIA: 53.26%, val/accuracy: 96.22%, 


100%|██████████| 24/24 [00:01<00:00, 14.16it/s]
  1%|          | 11/1000 [00:21<30:48,  1.87s/it]


Epoch 10: train/iit_loss: 0.0000, train/behavior_loss: 0.0005, val/iit_loss: 0.0866, val/IIA: 50.54%, val/accuracy: 96.12%, 


100%|██████████| 24/24 [00:01<00:00, 14.75it/s]
  1%|          | 12/1000 [00:22<30:25,  1.85s/it]


Epoch 11: train/iit_loss: 0.0000, train/behavior_loss: 0.0004, val/iit_loss: 0.1421, val/IIA: 55.34%, val/accuracy: 97.60%, 


100%|██████████| 24/24 [00:01<00:00, 14.31it/s]
  1%|▏         | 13/1000 [00:24<30:23,  1.85s/it]


Epoch 12: train/iit_loss: 0.0000, train/behavior_loss: 0.0004, val/iit_loss: 0.1014, val/IIA: 52.81%, val/accuracy: 99.38%, 


100%|██████████| 24/24 [00:01<00:00, 14.32it/s]
  1%|▏         | 14/1000 [00:26<30:24,  1.85s/it]


Epoch 13: train/iit_loss: 0.0000, train/behavior_loss: 0.0004, val/iit_loss: 0.0855, val/IIA: 50.96%, val/accuracy: 99.11%, 


100%|██████████| 24/24 [00:01<00:00, 14.69it/s]
  2%|▏         | 15/1000 [00:28<30:08,  1.84s/it]


Epoch 14: train/iit_loss: 0.0000, train/behavior_loss: 0.0003, val/iit_loss: 0.0642, val/IIA: 49.75%, val/accuracy: 99.09%, 


100%|██████████| 24/24 [00:01<00:00, 13.85it/s]
  2%|▏         | 16/1000 [00:30<30:25,  1.86s/it]


Epoch 15: train/iit_loss: 0.0000, train/behavior_loss: 0.0003, val/iit_loss: 0.1222, val/IIA: 54.17%, val/accuracy: 99.11%, 


100%|██████████| 24/24 [00:01<00:00, 13.00it/s]
  2%|▏         | 17/1000 [00:32<31:14,  1.91s/it]


Epoch 16: train/iit_loss: 0.0000, train/behavior_loss: 0.0003, val/iit_loss: 0.0831, val/IIA: 51.24%, val/accuracy: 99.17%, 


100%|██████████| 24/24 [00:01<00:00, 14.40it/s]
  2%|▏         | 18/1000 [00:34<30:55,  1.89s/it]


Epoch 17: train/iit_loss: 0.0000, train/behavior_loss: 0.0003, val/iit_loss: 0.0632, val/IIA: 49.37%, val/accuracy: 99.60%, 


100%|██████████| 24/24 [00:01<00:00, 13.85it/s]
  2%|▏         | 19/1000 [00:36<30:56,  1.89s/it]


Epoch 18: train/iit_loss: 0.0000, train/behavior_loss: 0.0003, val/iit_loss: 0.0817, val/IIA: 51.03%, val/accuracy: 99.61%, 


100%|██████████| 24/24 [00:01<00:00, 13.97it/s]
  2%|▏         | 20/1000 [00:37<30:58,  1.90s/it]


Epoch 19: train/iit_loss: 0.0000, train/behavior_loss: 0.0003, val/iit_loss: 0.0826, val/IIA: 50.62%, val/accuracy: 99.61%, 


100%|██████████| 24/24 [00:01<00:00, 13.77it/s]
  2%|▏         | 21/1000 [00:39<31:10,  1.91s/it]


Epoch 20: train/iit_loss: 0.0000, train/behavior_loss: 0.0003, val/iit_loss: 0.1214, val/IIA: 53.93%, val/accuracy: 99.66%, 


100%|██████████| 24/24 [00:01<00:00, 13.59it/s]
  2%|▏         | 22/1000 [00:41<31:19,  1.92s/it]


Epoch 21: train/iit_loss: 0.0000, train/behavior_loss: 0.0003, val/iit_loss: 0.0834, val/IIA: 50.53%, val/accuracy: 99.65%, 


100%|██████████| 24/24 [00:01<00:00, 14.20it/s]
  2%|▏         | 23/1000 [00:43<31:01,  1.91s/it]


Epoch 22: train/iit_loss: 0.0000, train/behavior_loss: 0.0003, val/iit_loss: 0.1218, val/IIA: 54.04%, val/accuracy: 99.66%, 


100%|██████████| 24/24 [00:01<00:00, 14.82it/s]
  2%|▏         | 24/1000 [00:45<30:27,  1.87s/it]


Epoch 23: train/iit_loss: 0.0000, train/behavior_loss: 0.0002, val/iit_loss: 0.0418, val/IIA: 45.18%, val/accuracy: 99.91%, 


100%|██████████| 24/24 [00:01<00:00, 14.19it/s]
  2%|▎         | 25/1000 [00:47<30:22,  1.87s/it]


Epoch 24: train/iit_loss: 0.0000, train/behavior_loss: 0.0002, val/iit_loss: 0.1018, val/IIA: 51.32%, val/accuracy: 99.91%, 


100%|██████████| 24/24 [00:01<00:00, 14.47it/s]
  3%|▎         | 26/1000 [00:49<30:14,  1.86s/it]


Epoch 25: train/iit_loss: 0.0000, train/behavior_loss: 0.0002, val/iit_loss: 0.0806, val/IIA: 49.80%, val/accuracy: 99.91%, 


100%|██████████| 24/24 [00:01<00:00, 13.52it/s]
  3%|▎         | 27/1000 [00:51<30:48,  1.90s/it]


Epoch 26: train/iit_loss: 0.0000, train/behavior_loss: 0.0002, val/iit_loss: 0.0625, val/IIA: 46.50%, val/accuracy: 99.86%, 


100%|██████████| 24/24 [00:01<00:00, 13.97it/s]
  3%|▎         | 28/1000 [00:53<30:41,  1.89s/it]


Epoch 27: train/iit_loss: 0.0000, train/behavior_loss: 0.0002, val/iit_loss: 0.0617, val/IIA: 46.29%, val/accuracy: 99.94%, 


100%|██████████| 24/24 [00:01<00:00, 14.66it/s]
  3%|▎         | 29/1000 [00:54<30:25,  1.88s/it]


Epoch 28: train/iit_loss: 0.0000, train/behavior_loss: 0.0002, val/iit_loss: 0.1199, val/IIA: 53.84%, val/accuracy: 99.91%, 


100%|██████████| 24/24 [00:01<00:00, 14.34it/s]
  3%|▎         | 30/1000 [00:56<30:13,  1.87s/it]


Epoch 29: train/iit_loss: 0.0000, train/behavior_loss: 0.0002, val/iit_loss: 0.0617, val/IIA: 46.66%, val/accuracy: 99.91%, 


100%|██████████| 24/24 [00:01<00:00, 14.13it/s]
  3%|▎         | 31/1000 [00:58<30:12,  1.87s/it]


Epoch 30: train/iit_loss: 0.0000, train/behavior_loss: 0.0002, val/iit_loss: 0.0819, val/IIA: 48.98%, val/accuracy: 99.94%, 


100%|██████████| 24/24 [00:01<00:00, 14.71it/s]
  3%|▎         | 32/1000 [01:00<29:54,  1.85s/it]


Epoch 31: train/iit_loss: 0.0000, train/behavior_loss: 0.0002, val/iit_loss: 0.0421, val/IIA: 44.16%, val/accuracy: 99.91%, 


100%|██████████| 24/24 [00:01<00:00, 14.40it/s]
  3%|▎         | 33/1000 [01:02<29:48,  1.85s/it]


Epoch 32: train/iit_loss: 0.0000, train/behavior_loss: 0.0002, val/iit_loss: 0.1209, val/IIA: 53.61%, val/accuracy: 99.94%, 


100%|██████████| 24/24 [00:01<00:00, 13.68it/s]
  3%|▎         | 34/1000 [01:04<30:08,  1.87s/it]


Epoch 33: train/iit_loss: 0.0000, train/behavior_loss: 0.0002, val/iit_loss: 0.1193, val/IIA: 53.74%, val/accuracy: 99.90%, 


100%|██████████| 24/24 [00:01<00:00, 14.78it/s]
  4%|▎         | 35/1000 [01:06<29:45,  1.85s/it]


Epoch 34: train/iit_loss: 0.0000, train/behavior_loss: 0.0002, val/iit_loss: 0.0409, val/IIA: 44.50%, val/accuracy: 99.94%, 


100%|██████████| 24/24 [00:01<00:00, 14.10it/s]
  4%|▎         | 36/1000 [01:07<29:52,  1.86s/it]


Epoch 35: train/iit_loss: 0.0000, train/behavior_loss: 0.0002, val/iit_loss: 0.0616, val/IIA: 46.84%, val/accuracy: 99.94%, 


100%|██████████| 24/24 [00:01<00:00, 13.76it/s]
  4%|▎         | 37/1000 [01:09<30:10,  1.88s/it]


Epoch 36: train/iit_loss: 0.0000, train/behavior_loss: 0.0002, val/iit_loss: 0.0605, val/IIA: 47.05%, val/accuracy: 99.94%, 


100%|██████████| 24/24 [00:01<00:00, 14.17it/s]
  4%|▍         | 38/1000 [01:11<30:06,  1.88s/it]


Epoch 37: train/iit_loss: 0.0000, train/behavior_loss: 0.0002, val/iit_loss: 0.0627, val/IIA: 47.31%, val/accuracy: 99.94%, 


100%|██████████| 24/24 [00:01<00:00, 14.14it/s]
  4%|▍         | 39/1000 [01:13<30:00,  1.87s/it]


Epoch 38: train/iit_loss: 0.0000, train/behavior_loss: 0.0002, val/iit_loss: 0.0830, val/IIA: 49.50%, val/accuracy: 99.94%, 


100%|██████████| 24/24 [00:01<00:00, 14.18it/s]
  4%|▍         | 40/1000 [01:15<29:56,  1.87s/it]


Epoch 39: train/iit_loss: 0.0000, train/behavior_loss: 0.0002, val/iit_loss: 0.0615, val/IIA: 47.34%, val/accuracy: 99.94%, 


100%|██████████| 24/24 [00:01<00:00, 14.85it/s]
  4%|▍         | 41/1000 [01:17<29:29,  1.85s/it]


Epoch 40: train/iit_loss: 0.0000, train/behavior_loss: 0.0002, val/iit_loss: 0.0811, val/IIA: 49.82%, val/accuracy: 99.94%, 


100%|██████████| 24/24 [00:01<00:00, 13.91it/s]
  4%|▍         | 42/1000 [01:19<29:42,  1.86s/it]


Epoch 41: train/iit_loss: 0.0000, train/behavior_loss: 0.0002, val/iit_loss: 0.1025, val/IIA: 51.72%, val/accuracy: 99.94%, 


100%|██████████| 24/24 [00:01<00:00, 14.26it/s]
  4%|▍         | 43/1000 [01:20<29:41,  1.86s/it]


Epoch 42: train/iit_loss: 0.0000, train/behavior_loss: 0.0002, val/iit_loss: 0.0825, val/IIA: 49.36%, val/accuracy: 99.94%, 


100%|██████████| 24/24 [00:01<00:00, 14.34it/s]
  4%|▍         | 44/1000 [01:22<29:35,  1.86s/it]


Epoch 43: train/iit_loss: 0.0000, train/behavior_loss: 0.0002, val/iit_loss: 0.0423, val/IIA: 45.22%, val/accuracy: 99.94%, 


100%|██████████| 24/24 [00:01<00:00, 13.89it/s]
  4%|▍         | 45/1000 [01:24<29:46,  1.87s/it]


Epoch 44: train/iit_loss: 0.0000, train/behavior_loss: 0.0002, val/iit_loss: 0.1027, val/IIA: 51.66%, val/accuracy: 99.94%, 


100%|██████████| 24/24 [00:01<00:00, 14.25it/s]
  5%|▍         | 46/1000 [01:26<29:45,  1.87s/it]


Epoch 45: train/iit_loss: 0.0000, train/behavior_loss: 0.0002, val/iit_loss: 0.1003, val/IIA: 51.84%, val/accuracy: 99.94%, 


100%|██████████| 24/24 [00:01<00:00, 14.18it/s]
  5%|▍         | 47/1000 [01:28<29:41,  1.87s/it]


Epoch 46: train/iit_loss: 0.0000, train/behavior_loss: 0.0002, val/iit_loss: 0.0809, val/IIA: 50.24%, val/accuracy: 99.94%, 


100%|██████████| 24/24 [00:01<00:00, 13.91it/s]
  5%|▍         | 48/1000 [01:30<29:46,  1.88s/it]


Epoch 47: train/iit_loss: 0.0000, train/behavior_loss: 0.0002, val/iit_loss: 0.1003, val/IIA: 52.29%, val/accuracy: 99.94%, 


100%|██████████| 24/24 [00:01<00:00, 14.36it/s]
  5%|▍         | 49/1000 [01:32<29:37,  1.87s/it]


Epoch 48: train/iit_loss: 0.0000, train/behavior_loss: 0.0002, val/iit_loss: 0.1009, val/IIA: 52.16%, val/accuracy: 99.94%, 


100%|██████████| 24/24 [00:01<00:00, 14.53it/s]
  5%|▌         | 50/1000 [01:34<29:22,  1.86s/it]


Epoch 49: train/iit_loss: 0.0000, train/behavior_loss: 0.0002, val/iit_loss: 0.0613, val/IIA: 48.50%, val/accuracy: 99.94%, 


100%|██████████| 24/24 [00:01<00:00, 14.05it/s]
  5%|▌         | 51/1000 [01:35<29:27,  1.86s/it]


Epoch 50: train/iit_loss: 0.0000, train/behavior_loss: 0.0002, val/iit_loss: 0.0800, val/IIA: 50.93%, val/accuracy: 99.94%, 


100%|██████████| 24/24 [00:01<00:00, 14.08it/s]
  5%|▌         | 52/1000 [01:37<29:35,  1.87s/it]


Epoch 51: train/iit_loss: 0.0000, train/behavior_loss: 0.0002, val/iit_loss: 0.0410, val/IIA: 47.58%, val/accuracy: 99.94%, 


100%|██████████| 24/24 [00:01<00:00, 14.03it/s]
  5%|▌         | 53/1000 [01:39<29:34,  1.87s/it]


Epoch 52: train/iit_loss: 0.0000, train/behavior_loss: 0.0002, val/iit_loss: 0.1018, val/IIA: 52.72%, val/accuracy: 99.94%, 


100%|██████████| 24/24 [00:01<00:00, 14.11it/s]
  5%|▌         | 54/1000 [01:41<29:31,  1.87s/it]


Epoch 53: train/iit_loss: 0.0000, train/behavior_loss: 0.0002, val/iit_loss: 0.0597, val/IIA: 49.28%, val/accuracy: 99.94%, 


100%|██████████| 24/24 [00:01<00:00, 14.66it/s]
  6%|▌         | 55/1000 [01:43<29:14,  1.86s/it]


Epoch 54: train/iit_loss: 0.0000, train/behavior_loss: 0.0002, val/iit_loss: 0.0797, val/IIA: 51.07%, val/accuracy: 99.94%, 


100%|██████████| 24/24 [00:01<00:00, 14.23it/s]
  6%|▌         | 56/1000 [01:45<29:13,  1.86s/it]


Epoch 55: train/iit_loss: 0.0000, train/behavior_loss: 0.0002, val/iit_loss: 0.0820, val/IIA: 50.89%, val/accuracy: 99.94%, 


100%|██████████| 24/24 [00:01<00:00, 14.07it/s]
  6%|▌         | 57/1000 [01:47<29:19,  1.87s/it]


Epoch 56: train/iit_loss: 0.0000, train/behavior_loss: 0.0002, val/iit_loss: 0.0821, val/IIA: 50.81%, val/accuracy: 99.94%, 


100%|██████████| 24/24 [00:01<00:00, 14.63it/s]
  6%|▌         | 58/1000 [01:48<29:03,  1.85s/it]


Epoch 57: train/iit_loss: 0.0000, train/behavior_loss: 0.0002, val/iit_loss: 0.1013, val/IIA: 52.71%, val/accuracy: 99.94%, 


100%|██████████| 24/24 [00:01<00:00, 13.70it/s]
  6%|▌         | 59/1000 [01:50<29:21,  1.87s/it]


Epoch 58: train/iit_loss: 0.0000, train/behavior_loss: 0.0002, val/iit_loss: 0.1407, val/IIA: 56.10%, val/accuracy: 99.94%, 


100%|██████████| 24/24 [00:01<00:00, 14.16it/s]
  6%|▌         | 60/1000 [01:52<29:20,  1.87s/it]


Epoch 59: train/iit_loss: 0.0000, train/behavior_loss: 0.0002, val/iit_loss: 0.1000, val/IIA: 52.77%, val/accuracy: 99.94%, 


100%|██████████| 24/24 [00:01<00:00, 14.30it/s]
  6%|▌         | 61/1000 [01:54<29:11,  1.87s/it]


Epoch 60: train/iit_loss: 0.0000, train/behavior_loss: 0.0002, val/iit_loss: 0.1024, val/IIA: 52.73%, val/accuracy: 99.94%, 


100%|██████████| 24/24 [00:01<00:00, 13.98it/s]
  6%|▌         | 62/1000 [01:56<29:15,  1.87s/it]


Epoch 61: train/iit_loss: 0.0000, train/behavior_loss: 0.0002, val/iit_loss: 0.0998, val/IIA: 52.96%, val/accuracy: 99.94%, 


100%|██████████| 24/24 [00:01<00:00, 14.24it/s]
  6%|▋         | 63/1000 [01:58<29:12,  1.87s/it]


Epoch 62: train/iit_loss: 0.0000, train/behavior_loss: 0.0002, val/iit_loss: 0.1011, val/IIA: 52.82%, val/accuracy: 99.94%, 


100%|██████████| 24/24 [00:01<00:00, 14.75it/s]
  6%|▋         | 64/1000 [02:00<28:59,  1.86s/it]


Epoch 63: train/iit_loss: 0.0000, train/behavior_loss: 0.0002, val/iit_loss: 0.0807, val/IIA: 51.24%, val/accuracy: 99.94%, 


100%|██████████| 24/24 [00:01<00:00, 13.90it/s]
  6%|▋         | 65/1000 [02:02<29:09,  1.87s/it]


Epoch 64: train/iit_loss: 0.0000, train/behavior_loss: 0.0002, val/iit_loss: 0.1197, val/IIA: 54.41%, val/accuracy: 99.94%, 


100%|██████████| 24/24 [00:01<00:00, 14.15it/s]
  7%|▋         | 66/1000 [02:03<29:10,  1.87s/it]


Epoch 65: train/iit_loss: 0.0000, train/behavior_loss: 0.0002, val/iit_loss: 0.1036, val/IIA: 52.58%, val/accuracy: 99.94%, 


100%|██████████| 24/24 [00:01<00:00, 14.56it/s]
  7%|▋         | 67/1000 [02:05<28:54,  1.86s/it]


Epoch 66: train/iit_loss: 0.0000, train/behavior_loss: 0.0002, val/iit_loss: 0.0839, val/IIA: 50.76%, val/accuracy: 99.94%, 


100%|██████████| 24/24 [00:01<00:00, 14.05it/s]
  7%|▋         | 68/1000 [02:07<28:57,  1.86s/it]


Epoch 67: train/iit_loss: 0.0000, train/behavior_loss: 0.0002, val/iit_loss: 0.0817, val/IIA: 50.81%, val/accuracy: 99.94%, 


100%|██████████| 24/24 [00:01<00:00, 14.52it/s]
  7%|▋         | 69/1000 [02:09<28:48,  1.86s/it]


Epoch 68: train/iit_loss: 0.0000, train/behavior_loss: 0.0002, val/iit_loss: 0.1020, val/IIA: 52.76%, val/accuracy: 99.94%, 


100%|██████████| 24/24 [00:01<00:00, 14.52it/s]
  7%|▋         | 70/1000 [02:11<28:38,  1.85s/it]


Epoch 69: train/iit_loss: 0.0000, train/behavior_loss: 0.0002, val/iit_loss: 0.0421, val/IIA: 47.52%, val/accuracy: 99.94%, 




### Setup Eval

In [None]:
"""Create a new test set with unique inputs"""

# Key each input sequence by its joined tokens so np.unique can compare rows.
input_keys = [", ".join(seq) for seq in test_inputs]
# return_index yields the position of the first occurrence of every unique key
# (ordered like the sorted unique keys), replacing a per-key np.where search.
arr, first_occurences = np.unique(input_keys, return_index=True)

unique_test_inputs = test_inputs[first_occurences]
unique_test_outputs = test_outputs[first_occurences]
assert len(unique_test_inputs) == len(unique_test_outputs)
assert len(unique_test_inputs) == len(arr)
assert len(np.unique([", ".join(seq) for seq in unique_test_inputs])) == len(unique_test_inputs)

unique_test_data = TracrDataset(unique_test_inputs, unique_test_outputs)
test_set = TracrIITDataset(unique_test_data, unique_test_data, hl_model, every_combination=True)
test_loader = test_set.make_loader(batch_size=512, num_workers=0)

In [None]:
def tokenise_data(batch, model: HookedTracrTransformer) -> t.Tensor:
    """Transpose a collated batch into per-example token lists and encode them.

    `batch` is zipped position-wise, so each resulting row holds one example's
    tokens (assumes the loader collates string tokens column-major -- TODO
    confirm). The rows are passed to `model.map_tracr_input_to_tl_input` and its
    result is returned unchanged.
    """
    per_example = [list(row) for row in zip(*batch)]
    return model.map_tracr_input_to_tl_input(per_example)

In [None]:
# Concatenate the whole evaluation loader into three tensors: base inputs,
# ablation (source) inputs, and the base examples' expected outputs.
tensorised_base_data = []
tensorised_ablation_data = []
base_answer_tokens = []
for (base_x, base_y, _), (ablation_x, ablation_y, _) in test_loader:
    tensorised_base_data.append(base_x)
    tensorised_ablation_data.append(ablation_x)
    base_answer_tokens.append(base_y)

base_tensor = t.cat(tensorised_base_data, dim=0)
ablation_tensor = t.cat(tensorised_ablation_data, dim=0)
base_answer_tokens = t.cat(base_answer_tokens, dim=0)

In [None]:
# Freeze both models and switch them to eval mode for the analysis below.
model.requires_grad_(False)
model.eval()
hl_model.requires_grad_(False)
# Trailing ';' suppresses the module repr that `.eval()` would otherwise display
# (previously done with a stray `print()`).
hl_model.eval();




In [None]:
# One forward pass of the low-level model over the whole evaluation set,
# caching its activations.
(original_logits, cache) = model.run_with_cache(base_tensor)

In [None]:
hl_answers = hl_model(base_tensor)
# The tracr model returns (batch, pos, 1) -- d_vocab_out is 1 -- while the
# stored labels are (batch, pos); the trailing singleton dim explains the
# differing shapes, and the values themselves agree.
hl_answers.shape, hl_answers[0], base_answer_tokens[0]

(torch.Size([65536, 5, 1]),
 tensor([[0.0000],
         [1.0000],
         [0.5000],
         [0.3333],
         [0.2500]], device='mps:0'),
 tensor([0.0000, 1.0000, 0.5000, 0.3333, 0.2500], device='mps:0'))

In [None]:
# Side-by-side: low-level vs high-level output for one example.
(
    original_logits.shape,
    original_logits[3],
    hl_answers.shape,
    hl_answers[3],
)

(torch.Size([65536, 5, 1]),
 tensor([[-3.6373e-04],
         [-3.6645e-02],
         [ 5.0384e-01],
         [ 7.0400e-01],
         [ 4.5902e-01]], device='mps:0'),
 torch.Size([65536, 5, 1]),
 tensor([[0.0000],
         [0.0000],
         [0.5000],
         [0.6667],
         [0.5000]], device='mps:0'))

### Resample-ablate attention heads and MLPs to measure their causal effect

In [None]:
from iit.utils.node_picker import get_nodes_not_in_circuit, get_nodes_in_circuit, get_all_nodes

nodes_not_in_circuit = get_nodes_not_in_circuit(model_pair.ll_model, hl_ll_corr)
# Low-level nodes outside the circuit vs. the nodes the correspondence maps onto.
circuit_nodes = list(hl_ll_corr.values())
nodes_not_in_circuit, "---", circuit_nodes

([LLNode(name='blocks.0.attn.hook_result', index=[:, :, 0, :], subspace=None),
  LLNode(name='blocks.0.attn.hook_result', index=[:, :, 1, :], subspace=None),
  LLNode(name='blocks.0.attn.hook_result', index=[:, :, 2, :], subspace=None),
  LLNode(name='blocks.0.attn.hook_result', index=[:, :, 3, :], subspace=None),
  LLNode(name='blocks.1.attn.hook_result', index=[:, :, 2, :], subspace=None),
  LLNode(name='blocks.1.attn.hook_result', index=[:, :, 3, :], subspace=None),
  LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)],
 '---',
 [{LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)},
  {LLNode(name='blocks.1.attn.hook_result', index=[:, :, :2, :], subspace=None)}])

In [None]:
from iit_utils.evals import check_causal_effect, make_dataframe_of_results

In [None]:
# Seed so the ablation inputs drawn during the checks are reproducible.
np.random.seed(0)
t.manual_seed(0)
# "n" = nodes not in the circuit, "c" = nodes in the circuit (evaluated in that order).
result_not_in_circuit, result_in_circuit = (
    check_causal_effect(model_pair, test_set, node_type=node_type, verbose=False)
    for node_type in ("n", "c")
)

  0%|          | 0/256 [00:00<?, ?it/s]

100%|██████████| 256/256 [00:57<00:00,  4.42it/s]
100%|██████████| 256/256 [00:16<00:00, 15.46it/s]


In [None]:
df = make_dataframe_of_results(result_not_in_circuit, result_in_circuit)
# Log the configuration that produced this table next to it.
print(f"{attn_idx} {training_args} {tracr_model_class}")
df

None {'lr': 0.0001, 'losses': 'all', 'atol': 0.05, 'batch_size': 512, 'use_single_loss': False, 'iit_weight': 1.0, 'behavior_weight': 1.0, 'strict_weight': 0.0, 'scale': 10.0, 'use_ln_hooks': False} <class 'iit.model_pairs.stop_grad_pair.StopGradModelPair'>


Unnamed: 0,node,status,causal effect
0,"blocks.0.attn.hook_result, head 0",not_in_circuit,0.823506
1,"blocks.0.attn.hook_result, head 1",not_in_circuit,0.71186
2,"blocks.0.attn.hook_result, head 2",not_in_circuit,0.866308
3,"blocks.0.attn.hook_result, head 3",not_in_circuit,0.787053
4,"blocks.1.attn.hook_result, head 2",not_in_circuit,0.818214
5,"blocks.1.attn.hook_result, head 3",not_in_circuit,0.7742
6,blocks.1.mlp.hook_post,not_in_circuit,0.904683
7,"blocks.1.attn.hook_result, head :2",in_circuit,0.999641
8,blocks.0.mlp.hook_post,in_circuit,0.999891


In [None]:
from iit.utils.metric import MetricStore
def print_metrics(metrics: list[MetricStore]):
    """Print one `name: value` line per metric, in the order given."""
    for store in metrics:
        name, value = store.get_name(), store.get_value()
        print(f"{name}: {value}")

# One evaluation pass over the de-duplicated test loader.
# NOTE(review): relies on the model pair's private `_run_eval_epoch`.
eval_loss_fn = model_pair.loss_fn
metric_collection = model_pair._run_eval_epoch(test_loader, eval_loss_fn)

In [None]:
print_metrics(metrics=metric_collection.metrics)

val/iit_loss: 0.0029948651190352393
val/IIA: 68.0581680033356
val/accuracy: 71.48437616415322


### Do the same with mean/zero ablations (`use_mean_cache = True` selects mean ablation)

In [None]:
from iit_utils.evals import check_causal_effect_on_ablation
from iit_utils.dataset import TracrUniqueDataset

In [None]:
uni_test_set = TracrUniqueDataset(
    unique_test_data,
    unique_test_data,
    hl_model,
    every_combination=True,
)

In [None]:
np.random.seed(0)
t.manual_seed(0)
# Presumably True -> mean ablation rather than zero ablation (the combined table
# below labels this column mean_ablate_effect) -- confirm in iit_utils.evals.
use_mean_cache = True
za_result_not_in_circuit, za_result_in_circuit = (
    check_causal_effect_on_ablation(
        model_pair, uni_test_set, node_type=node_type, verbose=False, use_mean_cache=use_mean_cache
    )
    for node_type in ("n", "c")
)

  0%|          | 0/1 [00:00<?, ?it/s]

100%|██████████| 1/1 [00:00<00:00,  5.47it/s]
100%|██████████| 1/1 [00:00<00:00, 18.11it/s]


In [None]:
df = make_dataframe_of_results(za_result_not_in_circuit, za_result_in_circuit)
# Log the configuration that produced this table next to it.
print(f"{attn_idx} {training_args} {tracr_model_class}")
df

None {'lr': 0.0001, 'losses': 'all', 'atol': 0.05, 'batch_size': 512, 'use_single_loss': False, 'iit_weight': 1.0, 'behavior_weight': 1.0, 'strict_weight': 0.0, 'scale': 10.0, 'use_ln_hooks': False} <class 'iit.model_pairs.stop_grad_pair.StopGradModelPair'>


Unnamed: 0,node,status,causal effect
0,"blocks.0.attn.hook_result, head 0",not_in_circuit,0.593443
1,"blocks.0.attn.hook_result, head 1",not_in_circuit,0.504918
2,"blocks.0.attn.hook_result, head 2",not_in_circuit,0.580328
3,"blocks.0.attn.hook_result, head 3",not_in_circuit,0.502732
4,"blocks.1.attn.hook_result, head 2",not_in_circuit,0.31694
5,"blocks.1.attn.hook_result, head 3",not_in_circuit,0.29071
6,blocks.1.mlp.hook_post,not_in_circuit,0.568306
7,"blocks.1.attn.hook_result, head :2",in_circuit,0.660109
8,blocks.0.mlp.hook_post,in_circuit,0.660109


### Combined table

In [None]:
from iit_utils.evals import make_combined_dataframe_of_results
df = make_combined_dataframe_of_results(
    result_not_in_circuit,
    result_in_circuit,
    za_result_not_in_circuit,
    za_result_in_circuit,
    use_mean_cache=use_mean_cache,
)
# Log the configuration that produced this table next to it.
print(f"{attn_idx} {training_args} {tracr_model_class}")
df

None {'lr': 0.0001, 'losses': 'all', 'atol': 0.05, 'batch_size': 512, 'use_single_loss': False, 'iit_weight': 1.0, 'behavior_weight': 1.0, 'strict_weight': 0.0, 'scale': 10.0, 'use_ln_hooks': False} <class 'iit.model_pairs.stop_grad_pair.StopGradModelPair'>


Unnamed: 0,node,status,resample_ablate_effect,mean_ablate_effect
0,"blocks.0.attn.hook_result, head 0",not_in_circuit,0.823506,0.593443
1,"blocks.0.attn.hook_result, head 1",not_in_circuit,0.71186,0.504918
2,"blocks.0.attn.hook_result, head 2",not_in_circuit,0.866308,0.580328
3,"blocks.0.attn.hook_result, head 3",not_in_circuit,0.787053,0.502732
4,"blocks.1.attn.hook_result, head 2",not_in_circuit,0.818214,0.31694
5,"blocks.1.attn.hook_result, head 3",not_in_circuit,0.7742,0.29071
6,blocks.1.mlp.hook_post,not_in_circuit,0.904683,0.568306
7,"blocks.1.attn.hook_result, head :2",in_circuit,0.999641,0.660109
8,blocks.0.mlp.hook_post,in_circuit,0.999891,0.660109


In [None]:
# save the results
import time
save_dir = f"results/{tracr_model_class.__name__}/{time.strftime('%d-%H-%M-%S')}"
from iit_utils.evals import save_result
save_result(df, save_dir, model_pair)

In [None]:
# batch = next(iter(test_loader))
# base_in, ablation_in = batch
# base_x, base_y, _ = base_in
# ablation_x, ablation_y, _ = ablation_in
# out, cache = model_pair.ll_model.run_with_cache(base_x)
# base_x.shape, ablation_x.shape

In [None]:
# get_nodes_not_in_circuit(model_pair.ll_model, hl_ll_corr)

In [None]:
# cache['blocks.0.hook_attn_out'].shape, cache['blocks.0.attn.hook_result'].shape, model_pair.ll_model.cfg.n_heads

In [None]:
# np.linalg.norm(cache['blocks.0.hook_attn_out'].cpu().detach().numpy(), axis=2)

### Rough (scratch cells, not part of the analysis)

In [None]:
# def get_all_bad_examples(model_pair, loader, atol=5e-2):
#     model_pair.ll_model.eval()
#     model_pair.hl_model.eval()
#     bad_io_examples = []
#     bad_ii_examples = []

#     for base_in, ablation_in in tqdm(loader):
#         base_in = model_pair.get_encoded_input_from_torch_input(base_in)
#         ablation_in = model_pair.get_encoded_input_from_torch_input(ablation_in)
#         for node in model_pair.corr.keys():
#             hl_node = node.name
#             ll_out, hl_out = model_pair.do_intervention(base_in, ablation_in, hl_node)
#             if model_pair.hl_model.is_categorical():
#                 top1 = t.argmax(ll_out, dim=1)
#                 correct = (top1 == hl_out).float()
#             else:
#                 correct = ((ll_out - hl_out).abs() < atol).float()
            
#             for i, c in enumerate(correct):
#                 print(c)
#                 if c == 0:
#                     bad_ii_examples.append((base_in[i], ablation_in[i]))
#         base_x, base_y = base_in
#         ll_out = model_pair.ll_model(base_x)
#         if model_pair.hl_model.is_categorical():
#             top1 = t.argmax(ll_out, dim=1)
#             correct = (top1 == base_y).float()
#         else:
#             correct = ((ll_out - base_y).abs() < atol).float()
        
#         for i, c in enumerate(correct):
#             if c == 0:
#                 if base_x[i] not in bad_io_examples:
#                     bad_io_examples.append((base_x[i]))

#     return bad_io_examples, bad_ii_examples

# bad_io_examples, bad_ii_examples = get_all_bad_examples(model_pair, test_loader, atol)

# bad_io_examples, bad_ii_examples

In [None]:
# np.random.seed(0)
# t.manual_seed(0)
# test_loader = DataLoader(test_set, batch_size=2, shuffle=True)
# base_in, ablation_in = next(iter(test_loader))

# hooker = model_pair.make_ll_ablation_hook(nodes_not_in_circuit[2])
# base_x, base_y = model_pair.get_encoded_input_from_torch_input(base_in)
# ablation_x, ablation_y = model_pair.get_encoded_input_from_torch_input(ablation_in)
# ll_out = do_intervention(model_pair.ll_model, base_x, ablation_x, nodes_not_in_circuit[2], hooker)
# ll_base_out, ll_base_cache = model_pair.ll_model.run_with_cache(base_x)
# ll_ablation_out, ll_ablation_cache = model_pair.ll_model.run_with_cache(ablation_x)
# for i in range(2):
#     print(
#         "---",
#         f"example {i}", 
#         "base_y:", base_y[i],
#         "ll_base_out:", ll_base_out[i].T,
#         "",
#         "ablation_y:", ablation_y[i],
#         "ll_ablation_out:", ll_ablation_out[i].T,
#         "",
#         "ll_out:", ll_out[i].T,
#         sep="\n"
#     )