### Setup

In [1]:
import sys
import os
import transformer_lens as tl
from torch.utils.data import Dataset
import torch as t
from tqdm import tqdm
from torch.utils.data import DataLoader
import numpy as np
import wandb
from typing import List, Dict, Any, Optional

from circuits_benchmark.benchmark.benchmark_case import BenchmarkCase, CaseDataset
from circuits_benchmark.transformers.hooked_tracr_transformer import HookedTracrTransformer
import iit.model_pairs as mp
import iit.utils.index as index
from iit_utils.dataset import create_dataset, TracrDataset, TracrIITDataset
import iit_utils.correspondence as correspondence
from circuits_benchmark.utils.get_cases import get_cases
from circuits_benchmark.commands.build_main_parser import build_main_parser
from iit_utils.iit_hl_model import make_iit_hl_model

DEVICE = t.device("cuda" if t.cuda.is_available() else "cpu")
WANDB_ENTITY = "cybershiptrooper" # TODO make this an env var

  from .autonotebook import tqdm as notebook_tqdm


### Train Model

In [2]:
attn_idx = None
atol = 5e-2
losses = "all"
tracr_model_class = mp.StrictIITModelPair
case_num = 3
train_model = True
training_args = {
    "lr" : 1e-3,
    "losses" : losses,
    "atol" : atol,
    "batch_size" : 512,
    "use_single_loss": False,
    "iit_weight": 1.0,
    "behavior_weight": 1.0,
    "strict_weight": 0.4,
    # "scale": 1.0,
    # "use_ln_hooks": True,
    "clip_grad_norm": 1,
    # "lr_scheduler": None
}
tracr_model_class.__name__

'StrictIITModelPair'

In [3]:
np.random.seed(0)
t.manual_seed(0)

args, _ = build_main_parser().parse_known_args(["compile",
                                                f"-i={case_num}",
                                                "-f",])
cases = get_cases(args)
case = cases[0]

tracr_output = case.build_tracr_model()
hl_model = case.build_transformer_lens_model()
# this is the graph node -> hl node correspondence
hl_ll_corr = correspondence.TracrCorrespondence.from_output(case, tracr_output)

In [4]:
hl_model([['BOS', 'x', 'b', 'a', 'a']], return_type='decoded')

[['BOS', 1.0, 0.5, 0.3333333432674408, 0.25]]

In [5]:
# seed everything
t.manual_seed(1)
np.random.seed(1)
import random
random.seed(1)

In [6]:
data = case.get_clean_data(count=15000)
inputs = data.get_inputs().to_numpy()
outputs = data.get_correct_outputs().to_numpy()

train_inputs = inputs[:12000]
test_inputs = inputs[12000:]
train_outputs = outputs[:12000]
test_outputs = outputs[12000:]

train_set, test_set = create_dataset(case, hl_model)

In [7]:
from transformer_lens import HookedTransformer, HookedTransformerConfig

cfg_dict = {
    "n_layers": 2, 
    "n_heads": 4, 
    "d_head": 4,
    "d_model": 8,
    "d_mlp": 16,
    "act_fn": "gelu",
}
ll_cfg = hl_model.cfg.to_dict().copy()
ll_cfg.update(cfg_dict)


print(ll_cfg)
ll_cfg = HookedTransformerConfig.from_dict(ll_cfg)
model = HookedTransformer(ll_cfg)


{'n_layers': 2, 'd_model': 8, 'n_ctx': 5, 'd_head': 4, 'model_name': 'custom', 'n_heads': 4, 'd_mlp': 16, 'act_fn': 'gelu', 'd_vocab': 6, 'eps': 1e-05, 'use_attn_result': True, 'use_attn_scale': True, 'use_split_qkv_input': True, 'use_hook_mlp_in': True, 'use_attn_in': False, 'use_local_attn': False, 'original_architecture': None, 'from_checkpoint': False, 'checkpoint_index': None, 'checkpoint_label_type': None, 'checkpoint_value': None, 'tokenizer_name': None, 'window_size': None, 'attn_types': None, 'init_mode': 'gpt2', 'normalization_type': None, 'device': device(type='mps'), 'n_devices': 1, 'attention_dir': 'causal', 'attn_only': False, 'seed': None, 'initializer_range': 0.22188007849009167, 'init_weights': True, 'scale_attn_by_inverse_layer_idx': False, 'positional_embedding_type': 'standard', 'final_rms': False, 'd_vocab_out': 1, 'parallel_attn_mlp': False, 'rotary_dim': None, 'n_params': 676, 'use_hook_tokens': False, 'gated_mlp': False, 'default_prepend_bos': True, 'dtype': tor

In [8]:
model_pair = tracr_model_class(
    hl_model = make_iit_hl_model(hl_model),
    ll_model = model,
    corr = hl_ll_corr,
    training_args=training_args,
)

{'hook_embed': HookPoint(), 'hook_pos_embed': HookPoint(), 'blocks.0.attn.hook_k': HookPoint(), 'blocks.0.attn.hook_q': HookPoint(), 'blocks.0.attn.hook_v': HookPoint(), 'blocks.0.attn.hook_z': HookPoint(), 'blocks.0.attn.hook_attn_scores': HookPoint(), 'blocks.0.attn.hook_pattern': HookPoint(), 'blocks.0.attn.hook_result': HookPoint(), 'blocks.0.mlp.hook_pre': HookPoint(), 'blocks.0.mlp.hook_post': HookPoint(), 'blocks.0.hook_attn_in': HookPoint(), 'blocks.0.hook_q_input': HookPoint(), 'blocks.0.hook_k_input': HookPoint(), 'blocks.0.hook_v_input': HookPoint(), 'blocks.0.hook_mlp_in': HookPoint(), 'blocks.0.hook_attn_out': HookPoint(), 'blocks.0.hook_mlp_out': HookPoint(), 'blocks.0.hook_resid_pre': HookPoint(), 'blocks.0.hook_resid_mid': HookPoint(), 'blocks.0.hook_resid_post': HookPoint(), 'blocks.1.attn.hook_k': HookPoint(), 'blocks.1.attn.hook_q': HookPoint(), 'blocks.1.attn.hook_v': HookPoint(), 'blocks.1.attn.hook_z': HookPoint(), 'blocks.1.attn.hook_attn_scores': HookPoint(), 'b

In [9]:

ll_model = HookedTracrTransformer(
    ll_cfg, hl_model.tracr_input_encoder, hl_model.tracr_output_encoder, hl_model.residual_stream_labels
)


In [10]:
if train_model:
    model_pair.train(
        train_set,
        test_set,
        epochs=1000,
        use_wandb=False,
    )
else:
    ll_model.load_weights_from_file(f"ll_models/{case_num}/ll_model_510.pth")
    model_pair.ll_model = ll_model

training_args={'batch_size': 512, 'lr': 0.001, 'num_workers': 0, 'early_stop': True, 'lr_scheduler': <class 'torch.optim.lr_scheduler.ReduceLROnPlateau'>, 'scheduler_val_metric': 'val/accuracy', 'scheduler_mode': 'max', 'clip_grad_norm': 1, 'atol': 0.05, 'use_single_loss': False, 'iit_weight': 1.0, 'behavior_weight': 1.0, 'strict_weight': 0.4, 'losses': 'all'}


  0%|          | 0/1000 [00:00<?, ?it/s]


100%|██████████| 24/24 [00:02<00:00,  8.76it/s]
  0%|          | 1/1000 [00:03<51:16,  3.08s/it]


Epoch 0: train/iit_loss: 0.0791, train/behavior_loss: 0.0350, train/strict_loss: 0.0161, val/iit_loss: 0.0680, val/IIA: 32.34%, val/accuracy: 46.23%, 


100%|██████████| 24/24 [00:02<00:00, 10.67it/s]
  0%|          | 2/1000 [00:05<45:03,  2.71s/it]


Epoch 1: train/iit_loss: 0.0513, train/behavior_loss: 0.0048, train/strict_loss: 0.0044, val/iit_loss: 0.0238, val/IIA: 43.14%, val/accuracy: 72.28%, 


100%|██████████| 24/24 [00:02<00:00, 10.74it/s]
  0%|          | 3/1000 [00:07<42:58,  2.59s/it]


Epoch 2: train/iit_loss: 0.0104, train/behavior_loss: 0.0017, train/strict_loss: 0.0022, val/iit_loss: 0.0032, val/IIA: 75.15%, val/accuracy: 86.68%, 


100%|██████████| 24/24 [00:02<00:00, 10.75it/s]
  0%|          | 4/1000 [00:10<41:56,  2.53s/it]


Epoch 3: train/iit_loss: 0.0033, train/behavior_loss: 0.0010, train/strict_loss: 0.0021, val/iit_loss: 0.0020, val/IIA: 80.45%, val/accuracy: 95.70%, 


100%|██████████| 24/24 [00:02<00:00, 10.72it/s]
  0%|          | 5/1000 [00:12<41:22,  2.49s/it]


Epoch 4: train/iit_loss: 0.0016, train/behavior_loss: 0.0006, train/strict_loss: 0.0011, val/iit_loss: 0.0013, val/IIA: 86.18%, val/accuracy: 97.17%, 


100%|██████████| 24/24 [00:02<00:00, 10.80it/s]
  1%|          | 6/1000 [00:15<40:54,  2.47s/it]


Epoch 5: train/iit_loss: 0.0011, train/behavior_loss: 0.0004, train/strict_loss: 0.0009, val/iit_loss: 0.0017, val/IIA: 79.18%, val/accuracy: 97.92%, 


100%|██████████| 24/24 [00:02<00:00, 10.63it/s]
  1%|          | 7/1000 [00:17<40:48,  2.47s/it]


Epoch 6: train/iit_loss: 0.0008, train/behavior_loss: 0.0003, train/strict_loss: 0.0009, val/iit_loss: 0.0005, val/IIA: 96.74%, val/accuracy: 99.68%, 


100%|██████████| 24/24 [00:02<00:00, 10.72it/s]
  1%|          | 8/1000 [00:20<40:36,  2.46s/it]


Epoch 7: train/iit_loss: 0.0005, train/behavior_loss: 0.0003, train/strict_loss: 0.0005, val/iit_loss: 0.0010, val/IIA: 88.80%, val/accuracy: 99.92%, 


100%|██████████| 24/24 [00:02<00:00, 10.60it/s]
  1%|          | 9/1000 [00:22<40:35,  2.46s/it]


Epoch 8: train/iit_loss: 0.0008, train/behavior_loss: 0.0005, train/strict_loss: 0.0005, val/iit_loss: 0.0007, val/IIA: 92.37%, val/accuracy: 99.92%, 


100%|██████████| 24/24 [00:02<00:00, 10.75it/s]
  1%|          | 10/1000 [00:25<40:25,  2.45s/it]


Epoch 9: train/iit_loss: 0.0005, train/behavior_loss: 0.0001, train/strict_loss: 0.0003, val/iit_loss: 0.0007, val/IIA: 91.73%, val/accuracy: 99.92%, 


100%|██████████| 24/24 [00:02<00:00, 10.59it/s]
  1%|          | 11/1000 [00:27<40:27,  2.45s/it]


Epoch 10: train/iit_loss: 0.0007, train/behavior_loss: 0.0003, train/strict_loss: 0.0003, val/iit_loss: 0.0005, val/IIA: 94.45%, val/accuracy: 99.92%, 


100%|██████████| 24/24 [00:02<00:00, 10.71it/s]
  1%|          | 12/1000 [00:29<40:20,  2.45s/it]


Epoch 11: train/iit_loss: 0.0005, train/behavior_loss: 0.0002, train/strict_loss: 0.0003, val/iit_loss: 0.0005, val/IIA: 95.37%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.65it/s]
  1%|▏         | 13/1000 [00:32<40:20,  2.45s/it]


Epoch 12: train/iit_loss: 0.0004, train/behavior_loss: 0.0001, train/strict_loss: 0.0003, val/iit_loss: 0.0002, val/IIA: 98.37%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.77it/s]
  1%|▏         | 14/1000 [00:34<40:09,  2.44s/it]


Epoch 13: train/iit_loss: 0.0004, train/behavior_loss: 0.0002, train/strict_loss: 0.0002, val/iit_loss: 0.0006, val/IIA: 94.17%, val/accuracy: 99.92%, 


100%|██████████| 24/24 [00:02<00:00, 10.74it/s]
  2%|▏         | 15/1000 [00:37<40:04,  2.44s/it]


Epoch 14: train/iit_loss: 0.0003, train/behavior_loss: 0.0001, train/strict_loss: 0.0002, val/iit_loss: 0.0004, val/IIA: 95.23%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.56it/s]
  2%|▏         | 16/1000 [00:39<40:10,  2.45s/it]


Epoch 15: train/iit_loss: 0.0003, train/behavior_loss: 0.0001, train/strict_loss: 0.0002, val/iit_loss: 0.0004, val/IIA: 96.71%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.62it/s]
  2%|▏         | 17/1000 [00:42<40:10,  2.45s/it]


Epoch 16: train/iit_loss: 0.0003, train/behavior_loss: 0.0001, train/strict_loss: 0.0002, val/iit_loss: 0.0003, val/IIA: 97.50%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.67it/s]
  2%|▏         | 18/1000 [00:44<40:07,  2.45s/it]


Epoch 17: train/iit_loss: 0.0003, train/behavior_loss: 0.0002, train/strict_loss: 0.0002, val/iit_loss: 0.0001, val/IIA: 99.32%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.75it/s]
  2%|▏         | 19/1000 [00:47<39:58,  2.45s/it]


Epoch 18: train/iit_loss: 0.0003, train/behavior_loss: 0.0001, train/strict_loss: 0.0002, val/iit_loss: 0.0002, val/IIA: 98.74%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.64it/s]
  2%|▏         | 20/1000 [00:49<39:59,  2.45s/it]


Epoch 19: train/iit_loss: 0.0002, train/behavior_loss: 0.0001, train/strict_loss: 0.0001, val/iit_loss: 0.0003, val/IIA: 97.83%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.68it/s]
  2%|▏         | 21/1000 [00:51<39:56,  2.45s/it]


Epoch 20: train/iit_loss: 0.0002, train/behavior_loss: 0.0001, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.99%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.65it/s]
  2%|▏         | 22/1000 [00:54<39:54,  2.45s/it]


Epoch 21: train/iit_loss: 0.0002, train/behavior_loss: 0.0001, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.81%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.59it/s]
  2%|▏         | 23/1000 [00:56<39:57,  2.45s/it]


Epoch 22: train/iit_loss: 0.0002, train/behavior_loss: 0.0001, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 99.00%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.73it/s]
  2%|▏         | 24/1000 [00:59<39:49,  2.45s/it]


Epoch 23: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.97%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.60it/s]
  2%|▎         | 25/1000 [01:01<39:51,  2.45s/it]


Epoch 24: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0001, val/IIA: 99.28%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.56it/s]
  3%|▎         | 26/1000 [01:04<39:55,  2.46s/it]


Epoch 25: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.72%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.70it/s]
  3%|▎         | 27/1000 [01:06<39:48,  2.45s/it]


Epoch 26: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0001, val/IIA: 99.68%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.47it/s]
  3%|▎         | 28/1000 [01:09<39:57,  2.47s/it]


Epoch 27: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0001, val/IIA: 99.65%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.34it/s]
  3%|▎         | 29/1000 [01:11<40:13,  2.49s/it]


Epoch 28: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.67%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.15it/s]
  3%|▎         | 30/1000 [01:14<40:36,  2.51s/it]


Epoch 29: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.69%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.13it/s]
  3%|▎         | 31/1000 [01:16<40:50,  2.53s/it]


Epoch 30: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.41%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.12it/s]
  3%|▎         | 32/1000 [01:19<41:01,  2.54s/it]


Epoch 31: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.68%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.45it/s]
  3%|▎         | 33/1000 [01:21<40:45,  2.53s/it]


Epoch 32: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0001, val/IIA: 99.28%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.70it/s]
  3%|▎         | 34/1000 [01:24<40:17,  2.50s/it]


Epoch 33: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 99.03%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.61it/s]
  4%|▎         | 35/1000 [01:26<40:02,  2.49s/it]


Epoch 34: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0001, val/IIA: 99.38%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.67it/s]
  4%|▎         | 36/1000 [01:29<39:47,  2.48s/it]


Epoch 35: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0001, val/IIA: 99.29%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.67it/s]
  4%|▎         | 37/1000 [01:31<39:36,  2.47s/it]


Epoch 36: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.42%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.74it/s]
  4%|▍         | 38/1000 [01:34<39:23,  2.46s/it]


Epoch 37: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 99.18%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.71it/s]
  4%|▍         | 39/1000 [01:36<39:14,  2.45s/it]


Epoch 38: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 99.14%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.81it/s]
  4%|▍         | 40/1000 [01:39<39:02,  2.44s/it]


Epoch 39: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0001, val/IIA: 99.34%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.73it/s]
  4%|▍         | 41/1000 [01:41<38:59,  2.44s/it]


Epoch 40: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 99.14%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.79it/s]
  4%|▍         | 42/1000 [01:43<38:52,  2.43s/it]


Epoch 41: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0001, val/IIA: 99.73%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.74it/s]
  4%|▍         | 43/1000 [01:46<38:49,  2.43s/it]


Epoch 42: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 99.11%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.81it/s]
  4%|▍         | 44/1000 [01:48<38:42,  2.43s/it]


Epoch 43: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 99.26%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.73it/s]
  4%|▍         | 45/1000 [01:51<38:41,  2.43s/it]


Epoch 44: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0001, val/IIA: 99.77%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.71it/s]
  5%|▍         | 46/1000 [01:53<38:40,  2.43s/it]


Epoch 45: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 99.21%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.99it/s]
  5%|▍         | 47/1000 [01:55<38:22,  2.42s/it]


Epoch 46: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 99.15%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.85it/s]
  5%|▍         | 48/1000 [01:58<38:18,  2.41s/it]


Epoch 47: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 99.09%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.90it/s]
  5%|▍         | 49/1000 [02:00<38:11,  2.41s/it]


Epoch 48: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.78%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 11.00it/s]
  5%|▌         | 50/1000 [02:03<37:59,  2.40s/it]


Epoch 49: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.87%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.95it/s]
  5%|▌         | 51/1000 [02:05<37:54,  2.40s/it]


Epoch 50: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 99.05%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.82it/s]
  5%|▌         | 52/1000 [02:07<37:55,  2.40s/it]


Epoch 51: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 99.06%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 11.00it/s]
  5%|▌         | 53/1000 [02:10<37:46,  2.39s/it]


Epoch 52: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.85%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 11.00it/s]
  5%|▌         | 54/1000 [02:12<37:38,  2.39s/it]


Epoch 53: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 99.03%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.93it/s]
  6%|▌         | 55/1000 [02:15<37:37,  2.39s/it]


Epoch 54: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.76%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.87it/s]
  6%|▌         | 56/1000 [02:17<37:40,  2.39s/it]


Epoch 55: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.85%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 11.04it/s]
  6%|▌         | 57/1000 [02:19<37:30,  2.39s/it]


Epoch 56: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.78%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.98it/s]
  6%|▌         | 58/1000 [02:22<37:26,  2.38s/it]


Epoch 57: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 99.17%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.85it/s]
  6%|▌         | 59/1000 [02:24<37:31,  2.39s/it]


Epoch 58: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0001, val/IIA: 99.77%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.98it/s]
  6%|▌         | 60/1000 [02:27<37:25,  2.39s/it]


Epoch 59: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.84%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 11.01it/s]
  6%|▌         | 61/1000 [02:29<37:18,  2.38s/it]


Epoch 60: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.84%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.87it/s]
  6%|▌         | 62/1000 [02:31<37:22,  2.39s/it]


Epoch 61: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 99.12%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.95it/s]
  6%|▋         | 63/1000 [02:34<37:19,  2.39s/it]


Epoch 62: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.84%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 11.02it/s]
  6%|▋         | 64/1000 [02:36<37:11,  2.38s/it]


Epoch 63: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 99.04%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.95it/s]
  6%|▋         | 65/1000 [02:38<37:10,  2.39s/it]


Epoch 64: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0001, val/IIA: 99.39%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.89it/s]
  7%|▋         | 66/1000 [02:41<37:11,  2.39s/it]


Epoch 65: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.91%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 11.01it/s]
  7%|▋         | 67/1000 [02:43<37:04,  2.38s/it]


Epoch 66: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.82%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.94it/s]
  7%|▋         | 68/1000 [02:46<37:03,  2.39s/it]


Epoch 67: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0003, val/IIA: 98.24%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.95it/s]
  7%|▋         | 69/1000 [02:48<37:02,  2.39s/it]


Epoch 68: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0003, val/IIA: 98.23%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 11.04it/s]
  7%|▋         | 70/1000 [02:50<36:55,  2.38s/it]


Epoch 69: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 99.16%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.97it/s]
  7%|▋         | 71/1000 [02:53<36:53,  2.38s/it]


Epoch 70: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.50%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.96it/s]
  7%|▋         | 72/1000 [02:55<36:51,  2.38s/it]


Epoch 71: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.83%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.85it/s]
  7%|▋         | 73/1000 [02:58<36:54,  2.39s/it]


Epoch 72: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.88%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.97it/s]
  7%|▋         | 74/1000 [03:00<36:51,  2.39s/it]


Epoch 73: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 99.15%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 11.00it/s]
  8%|▊         | 75/1000 [03:02<36:45,  2.38s/it]


Epoch 74: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 99.09%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.96it/s]
  8%|▊         | 76/1000 [03:05<36:43,  2.38s/it]


Epoch 75: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 99.02%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.45it/s]
  8%|▊         | 77/1000 [03:07<37:12,  2.42s/it]


Epoch 76: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.55%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.98it/s]
  8%|▊         | 78/1000 [03:10<36:59,  2.41s/it]


Epoch 77: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.71%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.92it/s]
  8%|▊         | 79/1000 [03:12<36:52,  2.40s/it]


Epoch 78: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 99.18%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 11.00it/s]
  8%|▊         | 80/1000 [03:14<36:44,  2.40s/it]


Epoch 79: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.57%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.97it/s]
  8%|▊         | 81/1000 [03:17<36:39,  2.39s/it]


Epoch 80: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.88%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.94it/s]
  8%|▊         | 82/1000 [03:19<36:36,  2.39s/it]


Epoch 81: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 99.09%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.95it/s]
  8%|▊         | 83/1000 [03:22<36:32,  2.39s/it]


Epoch 82: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.74%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.86it/s]
  8%|▊         | 84/1000 [03:24<36:34,  2.40s/it]


Epoch 83: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.83%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.97it/s]
  8%|▊         | 85/1000 [03:26<36:28,  2.39s/it]


Epoch 84: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 99.20%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.93it/s]
  9%|▊         | 86/1000 [03:29<36:27,  2.39s/it]


Epoch 85: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 99.14%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.59it/s]
  9%|▊         | 87/1000 [03:31<36:44,  2.41s/it]


Epoch 86: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0001, val/IIA: 99.79%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 11.09it/s]
  9%|▉         | 88/1000 [03:34<36:26,  2.40s/it]


Epoch 87: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.84%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [13:54<00:00, 34.77s/it]
  9%|▉         | 89/1000 [17:28<63:47:29, 252.09s/it]


Epoch 88: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0001, val/IIA: 99.41%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.96it/s]
  9%|▉         | 90/1000 [17:31<44:47:09, 177.17s/it]


Epoch 89: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0001, val/IIA: 99.38%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.94it/s]
  9%|▉         | 91/1000 [17:33<31:29:47, 124.74s/it]


Epoch 90: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 99.13%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.99it/s]
  9%|▉         | 92/1000 [17:35<22:12:13, 88.03s/it] 


Epoch 91: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 99.14%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.56it/s]
  9%|▉         | 93/1000 [17:38<15:42:45, 62.37s/it]


Epoch 92: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 99.10%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.29it/s]
  9%|▉         | 94/1000 [17:40<11:10:42, 44.42s/it]


Epoch 93: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 99.13%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.83it/s]
 10%|▉         | 95/1000 [17:43<7:59:53, 31.82s/it] 


Epoch 94: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 99.15%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.71it/s]
 10%|▉         | 96/1000 [17:45<5:46:35, 23.00s/it]


Epoch 95: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.89%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.91it/s]
 10%|▉         | 97/1000 [17:48<4:13:10, 16.82s/it]


Epoch 96: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 99.07%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.80it/s]
 10%|▉         | 98/1000 [17:50<3:07:56, 12.50s/it]


Epoch 97: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.96%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.93it/s]
 10%|▉         | 99/1000 [17:52<2:22:10,  9.47s/it]


Epoch 98: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0001, val/IIA: 99.42%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.63it/s]
 10%|█         | 100/1000 [17:55<1:50:27,  7.36s/it]


Epoch 99: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.84%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.95it/s]
 10%|█         | 101/1000 [17:57<1:27:57,  5.87s/it]


Epoch 100: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0001, val/IIA: 99.70%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.75it/s]
 10%|█         | 102/1000 [18:00<1:12:24,  4.84s/it]


Epoch 101: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0001, val/IIA: 99.33%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.78it/s]
 10%|█         | 103/1000 [18:02<1:01:31,  4.11s/it]


Epoch 102: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0003, val/IIA: 98.23%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.82it/s]
 10%|█         | 104/1000 [18:05<53:50,  3.61s/it]  


Epoch 103: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.76%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.77it/s]
 10%|█         | 105/1000 [18:07<48:30,  3.25s/it]


Epoch 104: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.84%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.92it/s]
 11%|█         | 106/1000 [18:09<44:36,  2.99s/it]


Epoch 105: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.82%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.93it/s]
 11%|█         | 107/1000 [18:12<41:51,  2.81s/it]


Epoch 106: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 99.21%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.93it/s]
 11%|█         | 108/1000 [18:14<39:55,  2.69s/it]


Epoch 107: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.83%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.54it/s]
 11%|█         | 109/1000 [18:17<38:56,  2.62s/it]


Epoch 108: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 99.15%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.55it/s]
 11%|█         | 110/1000 [18:19<38:14,  2.58s/it]


Epoch 109: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0001, val/IIA: 99.47%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.87it/s]
 11%|█         | 111/1000 [18:22<37:25,  2.53s/it]


Epoch 110: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0001, val/IIA: 99.37%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.85it/s]
 11%|█         | 112/1000 [18:24<36:51,  2.49s/it]


Epoch 111: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 99.17%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.89it/s]
 11%|█▏        | 113/1000 [18:26<36:25,  2.46s/it]


Epoch 112: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 99.07%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.88it/s]
 11%|█▏        | 114/1000 [18:29<36:05,  2.44s/it]


Epoch 113: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0001, val/IIA: 99.69%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.86it/s]
 12%|█▏        | 115/1000 [18:31<35:52,  2.43s/it]


Epoch 114: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.87%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.84it/s]
 12%|█▏        | 116/1000 [18:34<35:44,  2.43s/it]


Epoch 115: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0001, val/IIA: 99.49%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.31it/s]
 12%|█▏        | 117/1000 [18:36<36:08,  2.46s/it]


Epoch 116: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.89%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.24it/s]
 12%|█▏        | 118/1000 [18:39<36:31,  2.49s/it]


Epoch 117: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 99.16%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00,  9.43it/s]
 12%|█▏        | 119/1000 [18:41<37:40,  2.57s/it]


Epoch 118: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.77%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00,  9.21it/s]
 12%|█▏        | 120/1000 [18:44<38:45,  2.64s/it]


Epoch 119: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0001, val/IIA: 99.43%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00,  9.87it/s]
 12%|█▏        | 121/1000 [18:47<38:43,  2.64s/it]


Epoch 120: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.88%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00,  9.80it/s]
 12%|█▏        | 122/1000 [18:50<38:43,  2.65s/it]


Epoch 121: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.46%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.46it/s]
 12%|█▏        | 123/1000 [18:52<38:00,  2.60s/it]


Epoch 122: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.80%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00,  9.78it/s]
 12%|█▏        | 124/1000 [18:55<38:13,  2.62s/it]


Epoch 123: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 99.10%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.70it/s]
 12%|█▎        | 125/1000 [18:57<37:23,  2.56s/it]


Epoch 124: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.82%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.90it/s]
 13%|█▎        | 126/1000 [19:00<36:36,  2.51s/it]


Epoch 125: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0001, val/IIA: 99.47%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.89it/s]
 13%|█▎        | 127/1000 [19:02<36:04,  2.48s/it]


Epoch 126: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0001, val/IIA: 99.31%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.89it/s]
 13%|█▎        | 128/1000 [19:04<35:41,  2.46s/it]


Epoch 127: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0001, val/IIA: 99.69%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.93it/s]
 13%|█▎        | 129/1000 [19:07<35:21,  2.44s/it]


Epoch 128: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0001, val/IIA: 99.43%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.40it/s]
 13%|█▎        | 130/1000 [19:09<35:41,  2.46s/it]


Epoch 129: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.81%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.00it/s]
 13%|█▎        | 131/1000 [19:12<36:13,  2.50s/it]


Epoch 130: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.79%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.40it/s]
 13%|█▎        | 132/1000 [19:14<36:13,  2.50s/it]


Epoch 131: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 99.11%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.41it/s]
 13%|█▎        | 133/1000 [19:17<36:12,  2.51s/it]


Epoch 132: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 99.15%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.17it/s]
 13%|█▎        | 134/1000 [19:19<36:26,  2.52s/it]


Epoch 133: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 99.22%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.13it/s]
 14%|█▎        | 135/1000 [19:22<36:37,  2.54s/it]


Epoch 134: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0001, val/IIA: 99.43%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.30it/s]
 14%|█▎        | 136/1000 [19:25<36:35,  2.54s/it]


Epoch 135: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0001, val/IIA: 99.45%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.08it/s]
 14%|█▎        | 137/1000 [19:27<36:44,  2.55s/it]


Epoch 136: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 99.10%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.25it/s]
 14%|█▍        | 138/1000 [19:30<36:39,  2.55s/it]


Epoch 137: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.47%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.29it/s]
 14%|█▍        | 139/1000 [19:32<36:32,  2.55s/it]


Epoch 138: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0001, val/IIA: 99.32%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.35it/s]
 14%|█▍        | 140/1000 [19:35<36:22,  2.54s/it]


Epoch 139: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 99.15%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.39it/s]
 14%|█▍        | 141/1000 [19:37<36:13,  2.53s/it]


Epoch 140: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.80%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.46it/s]
 14%|█▍        | 142/1000 [19:40<36:02,  2.52s/it]


Epoch 141: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0001, val/IIA: 99.71%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.44it/s]
 14%|█▍        | 143/1000 [19:42<35:55,  2.52s/it]


Epoch 142: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 99.08%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.26it/s]
 14%|█▍        | 144/1000 [19:45<35:59,  2.52s/it]


Epoch 143: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.61%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.53it/s]
 14%|█▍        | 145/1000 [19:47<35:47,  2.51s/it]


Epoch 144: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0001, val/IIA: 99.40%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.29it/s]
 15%|█▍        | 146/1000 [19:50<35:51,  2.52s/it]


Epoch 145: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0001, val/IIA: 99.35%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.29it/s]
 15%|█▍        | 147/1000 [19:52<35:53,  2.52s/it]


Epoch 146: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0001, val/IIA: 99.43%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00,  9.77it/s]
 15%|█▍        | 148/1000 [19:55<36:31,  2.57s/it]


Epoch 147: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0001, val/IIA: 99.45%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00,  9.46it/s]
 15%|█▍        | 149/1000 [19:58<37:09,  2.62s/it]


Epoch 148: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.85%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.24it/s]
 15%|█▌        | 150/1000 [20:00<36:47,  2.60s/it]


Epoch 149: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 99.17%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.30it/s]
 15%|█▌        | 151/1000 [20:03<36:29,  2.58s/it]


Epoch 150: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0001, val/IIA: 99.38%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.38it/s]
 15%|█▌        | 152/1000 [20:05<36:10,  2.56s/it]


Epoch 151: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.86%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.20it/s]
 15%|█▌        | 153/1000 [20:08<36:07,  2.56s/it]


Epoch 152: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0001, val/IIA: 99.65%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00,  9.45it/s]
 15%|█▌        | 154/1000 [20:11<37:06,  2.63s/it]


Epoch 153: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0001, val/IIA: 99.73%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.15it/s]
 16%|█▌        | 155/1000 [20:13<36:48,  2.61s/it]


Epoch 154: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.79%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00,  9.39it/s]
 16%|█▌        | 156/1000 [20:16<37:40,  2.68s/it]


Epoch 155: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.55%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:03<00:00,  7.15it/s]
 16%|█▌        | 157/1000 [20:20<41:44,  2.97s/it]


Epoch 156: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.74%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:03<00:00,  7.86it/s]
 16%|█▌        | 158/1000 [20:23<42:53,  3.06s/it]


Epoch 157: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0001, val/IIA: 99.35%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.10it/s]
 16%|█▌        | 159/1000 [20:26<40:51,  2.91s/it]


Epoch 158: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 99.07%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.34it/s]
 16%|█▌        | 160/1000 [20:28<39:09,  2.80s/it]


Epoch 159: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 99.19%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.13it/s]
 16%|█▌        | 161/1000 [20:31<38:10,  2.73s/it]


Epoch 160: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0001, val/IIA: 99.36%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00,  9.76it/s]
 16%|█▌        | 162/1000 [20:33<37:49,  2.71s/it]


Epoch 161: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.51%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00,  9.88it/s]
 16%|█▋        | 163/1000 [20:36<37:26,  2.68s/it]


Epoch 162: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 99.19%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.15it/s]
 16%|█▋        | 164/1000 [20:39<36:52,  2.65s/it]


Epoch 163: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.58%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00,  9.70it/s]
 16%|█▋        | 165/1000 [20:41<37:00,  2.66s/it]


Epoch 164: train/iit_loss: 0.0001, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 99.15%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00,  9.87it/s]
 17%|█▋        | 166/1000 [20:44<36:50,  2.65s/it]


Epoch 165: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.87%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:02<00:00, 10.46it/s]
 17%|█▋        | 166/1000 [20:46<1:44:24,  7.51s/it]


Epoch 166: train/iit_loss: 0.0002, train/behavior_loss: 0.0000, train/strict_loss: 0.0001, val/iit_loss: 0.0000, val/IIA: 100.00%, val/accuracy: 100.00%, 





### Setup Eval

In [11]:
"""Create a new test set with unique inputs"""

arr, idxs = np.unique([", ".join(i) for i in np.array(test_inputs)], return_inverse=True)
# create indices that point to the first unique input
all_possible_inputs = np.arange(arr.shape[0])
# find the first occurence of all_possible_inputs in idxs
first_occurences = [np.where(idxs == i)[0][0] for i in all_possible_inputs]

unique_test_inputs = test_inputs[first_occurences]
unique_test_outputs = test_outputs[first_occurences]
assert len(unique_test_inputs) == len(unique_test_outputs)
assert len(unique_test_inputs) == len(np.unique([", ".join(i) for i in np.array(test_inputs)]))
assert len(np.unique([", ".join(i) for i in np.array(unique_test_inputs)])) == len(unique_test_inputs)

unique_test_data = TracrDataset(unique_test_inputs, unique_test_outputs)
test_set = TracrIITDataset(unique_test_data, unique_test_data, hl_model, every_combination=True)
test_loader = test_set.make_loader(batch_size=512, num_workers=0)

In [12]:
def tokenise_data(batch, model: HookedTracrTransformer) -> t.Tensor:
    x = list(map(list, zip(*batch)))
    encoded_x = model.map_tracr_input_to_tl_input(x)
    return encoded_x

In [13]:
tensorised_base_data = []
tensorised_ablation_data = []
base_answer_tokens = []
for base_in, ablation_in in test_loader:
    base_x, base_y,  _ = base_in
    ablation_x, ablation_y, _ = ablation_in

    tensorised_base_data.append((base_x))
    tensorised_ablation_data.append((ablation_x))
    base_answer_tokens.append(base_y)

base_tensor = t.cat(tensorised_base_data, dim=0)
ablation_tensor = t.cat(tensorised_ablation_data, dim=0)
base_answer_tokens = t.cat(base_answer_tokens, dim=0)

In [14]:
model.requires_grad_(False)
model.eval()
hl_model.requires_grad_(False)
hl_model.eval()
print()




In [15]:
original_logits, cache = model.run_with_cache(base_tensor)

In [16]:
hl_answers = hl_model(base_tensor)
(hl_answers.shape), hl_answers[0], base_answer_tokens[0] # Wtf???

(torch.Size([65536, 5, 1]),
 tensor([[0.0000],
         [0.0000],
         [0.0000],
         [0.3333],
         [0.2500]], device='mps:0'),
 tensor([0.0000, 0.0000, 0.0000, 0.3333, 0.2500], device='mps:0'))

In [17]:
original_logits.shape, original_logits[3], hl_answers.shape, hl_answers[3]

(torch.Size([65536, 5, 1]),
 tensor([[-1.0685e-05],
         [ 1.0009e+00],
         [ 5.0476e-01],
         [ 3.3514e-01],
         [ 2.4876e-01]], device='mps:0'),
 torch.Size([65536, 5, 1]),
 tensor([[0.0000],
         [1.0000],
         [0.5000],
         [0.3333],
         [0.2500]], device='mps:0'))

### Patch Attention Heads to see Causal Effect

In [18]:
from iit.utils.node_picker import get_nodes_not_in_circuit, get_nodes_in_circuit, get_all_nodes

nodes_not_in_circuit = get_nodes_not_in_circuit(model_pair.ll_model, hl_ll_corr)
nodes_not_in_circuit, "---", list(hl_ll_corr.values())

([LLNode(name='blocks.0.attn.hook_result', index=[:, :, 0, :], subspace=None),
  LLNode(name='blocks.0.attn.hook_result', index=[:, :, 1, :], subspace=None),
  LLNode(name='blocks.0.attn.hook_result', index=[:, :, 2, :], subspace=None),
  LLNode(name='blocks.0.attn.hook_result', index=[:, :, 3, :], subspace=None),
  LLNode(name='blocks.1.attn.hook_result', index=[:, :, 2, :], subspace=None),
  LLNode(name='blocks.1.attn.hook_result', index=[:, :, 3, :], subspace=None),
  LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)],
 '---',
 [{LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)},
  {LLNode(name='blocks.1.attn.hook_result', index=[:, :, :2, :], subspace=None)}])

In [19]:
from iit.utils.eval_ablations import check_causal_effect, make_dataframe_of_results

In [20]:
np.random.seed(0)
t.manual_seed(0)
result_not_in_circuit = check_causal_effect(model_pair, test_set, node_type="n", verbose=False)
result_in_circuit = check_causal_effect(model_pair, test_set, node_type="c", verbose=False)

100%|██████████| 256/256 [00:22<00:00, 11.14it/s]
100%|██████████| 256/256 [00:07<00:00, 35.81it/s]


In [21]:
df = make_dataframe_of_results(result_not_in_circuit, result_in_circuit)
# df.style.apply(color_table, subset=["status"])
print(attn_idx, training_args, tracr_model_class)
df

None {'lr': 0.001, 'losses': 'all', 'atol': 0.05, 'batch_size': 512, 'use_single_loss': False, 'iit_weight': 1.0, 'behavior_weight': 1.0, 'strict_weight': 0.4, 'clip_grad_norm': 1} <class 'iit.model_pairs.strict_iit_model_pair.StrictIITModelPair'>


Unnamed: 0,node,status,causal effect
0,"blocks.0.attn.hook_result, head 0",not_in_circuit,0.002144
1,"blocks.0.attn.hook_result, head 1",not_in_circuit,0.007174
2,"blocks.0.attn.hook_result, head 2",not_in_circuit,0.000595
3,"blocks.0.attn.hook_result, head 3",not_in_circuit,0.006195
4,"blocks.1.attn.hook_result, head 2",not_in_circuit,0.011373
5,"blocks.1.attn.hook_result, head 3",not_in_circuit,0.000119
6,blocks.1.mlp.hook_post,not_in_circuit,0.023031
7,"blocks.1.attn.hook_result, head :2",in_circuit,1.0
8,blocks.0.mlp.hook_post,in_circuit,1.0


In [22]:
from iit.utils.metric import MetricStore
def print_metrics(metrics: list[MetricStore]):
    for metric in metrics:
        print(f"{metric.get_name()}: {metric.get_value()}")

metric_collection = model_pair._run_eval_epoch(test_loader, model_pair.loss_fn)

In [23]:
print_metrics(metric_collection.metrics)

val/iit_loss: 0.00015295915579827124
val/IIA: 99.14306686259806
val/accuracy: 100.0


### Do the same with zero ablations

In [1]:
from iit.utils.eval_ablations import get_causal_effects_for_all_nodes
from iit_utils.dataset import TracrUniqueDataset

  from .autonotebook import tqdm as notebook_tqdm


In [25]:
uni_test_set = TracrUniqueDataset(unique_test_data, unique_test_data, hl_model, every_combination=True)

In [26]:
np.random.seed(0)
t.manual_seed(0)
use_mean_cache = True
za_result_not_in_circuit, za_result_in_circuit = get_causal_effects_for_all_nodes(model_pair, uni_test_set, use_mean_cache=use_mean_cache, batch_size=len(uni_test_set))

100%|██████████| 32/32 [00:00<00:00, 74.20it/s]
100%|██████████| 1/1 [00:00<00:00, 10.60it/s]
100%|██████████| 32/32 [00:00<00:00, 100.41it/s]
100%|██████████| 1/1 [00:00<00:00, 36.74it/s]


In [27]:
df = make_dataframe_of_results(za_result_not_in_circuit, za_result_in_circuit)
# df.style.map(color_table, subset=["status"])
print(attn_idx, training_args, tracr_model_class)
df

None {'lr': 0.001, 'losses': 'all', 'atol': 0.05, 'batch_size': 512, 'use_single_loss': False, 'iit_weight': 1.0, 'behavior_weight': 1.0, 'strict_weight': 0.4, 'clip_grad_norm': 1} <class 'iit.model_pairs.strict_iit_model_pair.StrictIITModelPair'>


Unnamed: 0,node,status,causal effect
0,"blocks.0.attn.hook_result, head 0",not_in_circuit,0.000781
1,"blocks.0.attn.hook_result, head 1",not_in_circuit,0.0
2,"blocks.0.attn.hook_result, head 2",not_in_circuit,0.0
3,"blocks.0.attn.hook_result, head 3",not_in_circuit,0.0
4,"blocks.1.attn.hook_result, head 2",not_in_circuit,0.0
5,"blocks.1.attn.hook_result, head 3",not_in_circuit,0.0
6,blocks.1.mlp.hook_post,not_in_circuit,0.000781
7,"blocks.1.attn.hook_result, head :2",in_circuit,0.715625
8,blocks.0.mlp.hook_post,in_circuit,0.715625


### Combined table

In [28]:
from iit.utils.eval_ablations import make_combined_dataframe_of_results
df = make_combined_dataframe_of_results(result_not_in_circuit, result_in_circuit, za_result_not_in_circuit, za_result_in_circuit, use_mean_cache=use_mean_cache)
# df.style.apply(color_table, subset=["status"], method = "map")   
print(attn_idx, training_args, tracr_model_class)
df

None {'lr': 0.001, 'losses': 'all', 'atol': 0.05, 'batch_size': 512, 'use_single_loss': False, 'iit_weight': 1.0, 'behavior_weight': 1.0, 'strict_weight': 0.4, 'clip_grad_norm': 1} <class 'iit.model_pairs.strict_iit_model_pair.StrictIITModelPair'>


Unnamed: 0,node,status,resample_ablate_effect,mean_ablate_effect
0,"blocks.0.attn.hook_result, head 0",not_in_circuit,0.002144,0.000781
1,"blocks.0.attn.hook_result, head 1",not_in_circuit,0.007174,0.0
2,"blocks.0.attn.hook_result, head 2",not_in_circuit,0.000595,0.0
3,"blocks.0.attn.hook_result, head 3",not_in_circuit,0.006195,0.0
4,"blocks.1.attn.hook_result, head 2",not_in_circuit,0.011373,0.0
5,"blocks.1.attn.hook_result, head 3",not_in_circuit,0.000119,0.0
6,blocks.1.mlp.hook_post,not_in_circuit,0.023031,0.000781
7,"blocks.1.attn.hook_result, head :2",in_circuit,1.0,0.715625
8,blocks.0.mlp.hook_post,in_circuit,1.0,0.715625


In [29]:
# save the results
# import time
# save_dir = f"results/{tracr_model_class.__name__}/{time.strftime('%d-%H-%M-%S')}"
# from iit.utils.eval_ablations import save_result
# save_result(df, save_dir, model_pair)

In [30]:
# batch = next(iter(test_loader))
# base_in, ablation_in = batch
# base_x, base_y, _ = base_in
# ablation_x, ablation_y, _ = ablation_in
# out, cache = model_pair.ll_model.run_with_cache(base_x)
# base_x.shape, ablation_x.shape

In [31]:
# get_nodes_not_in_circuit(model_pair.ll_model, hl_ll_corr)

In [32]:
# cache['blocks.0.hook_attn_out'].shape, cache['blocks.0.attn.hook_result'].shape, model_pair.ll_model.cfg.n_heads

In [33]:
# np.linalg.norm(cache['blocks.0.hook_attn_out'].cpu().detach().numpy(), axis=2)

### Rough

In [34]:
# def get_all_bad_examples(model_pair, loader, atol=5e-2):
#     model_pair.ll_model.eval()
#     model_pair.hl_model.eval()
#     bad_io_examples = []
#     bad_ii_examples = []

#     for base_in, ablation_in in tqdm(loader):
#         base_in = model_pair.get_encoded_input_from_torch_input(base_in)
#         ablation_in = model_pair.get_encoded_input_from_torch_input(ablation_in)
#         for node in model_pair.corr.keys():
#             hl_node = node.name
#             ll_out, hl_out = model_pair.do_intervention(base_in, ablation_in, hl_node)
#             if model_pair.hl_model.is_categorical():
#                 top1 = t.argmax(ll_out, dim=1)
#                 correct = (top1 == hl_out).float()
#             else:
#                 correct = ((ll_out - hl_out).abs() < atol).float()
            
#             for i, c in enumerate(correct):
#                 print(c)
#                 if c == 0:
#                     bad_ii_examples.append((base_in[i], ablation_in[i]))
#         base_x, base_y = base_in
#         ll_out = model_pair.ll_model(base_x)
#         if model_pair.hl_model.is_categorical():
#             top1 = t.argmax(ll_out, dim=1)
#             correct = (top1 == base_y).float()
#         else:
#             correct = ((ll_out - base_y).abs() < atol).float()
        
#         for i, c in enumerate(correct):
#             if c == 0:
#                 if base_x[i] not in bad_io_examples:
#                     bad_io_examples.append((base_x[i]))

#     return bad_io_examples, bad_ii_examples

# bad_io_examples, bad_ii_examples = get_all_bad_examples(model_pair, test_loader, atol)

# bad_io_examples, bad_ii_examples

In [35]:
# np.random.seed(0)
# t.manual_seed(0)
# test_loader = DataLoader(test_set, batch_size=2, shuffle=True)
# base_in, ablation_in = next(iter(test_loader))

# hooker = model_pair.make_ll_ablation_hook(nodes_not_in_circuit[2])
# base_x, base_y = model_pair.get_encoded_input_from_torch_input(base_in)
# ablation_x, ablation_y = model_pair.get_encoded_input_from_torch_input(ablation_in)
# ll_out = do_intervention(model_pair.ll_model, base_x, ablation_x, nodes_not_in_circuit[2], hooker)
# ll_base_out, ll_base_cache = model_pair.ll_model.run_with_cache(base_x)
# ll_ablation_out, ll_ablation_cache = model_pair.ll_model.run_with_cache(ablation_x)
# for i in range(2):
#     print(
#         "---",
#         f"example {i}", 
#         "base_y:", base_y[i],
#         "ll_base_out:", ll_base_out[i].T,
#         "",
#         "ablation_y:", ablation_y[i],
#         "ll_ablation_out:", ll_ablation_out[i].T,
#         "",
#         "ll_out:", ll_out[i].T,
#         sep="\n"
#     )