### Setup

In [1]:
import sys
import os
import transformer_lens as tl
from torch.utils.data import Dataset
import torch as t
from tqdm import tqdm
from torch.utils.data import DataLoader
import numpy as np
import wandb
from typing import List, Dict, Any, Optional

from circuits_benchmark.benchmark.benchmark_case import BenchmarkCase, CaseDataset
from circuits_benchmark.transformers.hooked_tracr_transformer import HookedTracrTransformer
import iit.model_pairs as mp
import iit.utils.index as index
from iit_utils.dataset import create_dataset, TracrDataset, TracrIITDataset
import iit_utils.correspondence as correspondence
from circuits_benchmark.utils.get_cases import get_cases
from circuits_benchmark.commands.build_main_parser import build_main_parser
from iit_utils.iit_hl_model import make_iit_hl_model

# Prefer a GPU when one is available.
DEVICE = t.device("cuda" if t.cuda.is_available() else "cpu")
# Weights & Biases entity; override with the WANDB_ENTITY environment variable.
WANDB_ENTITY = os.environ.get("WANDB_ENTITY", "cybershiptrooper")

  from .autonotebook import tqdm as notebook_tqdm


### Train Model

In [2]:
# LL index selecting attention heads 0 and 1 (dim 2 is the head dim, cf. the node listing
# further down); mapped onto the HL attention node in tracr_ll_corr.
attn_idx = index.Ix[:, :, :2, :]
# Absolute tolerance for regression-output comparisons; passed to training and also read
# as a global by the ablation helpers below.
atol = 5e-2
# Which loss terms to use -- forwarded unchanged as training_args["losses"].
losses = "all"
# Model-pair class that defines the IIT training procedure.
tracr_model_class = mp.StopGradModelPair
# Benchmark case id, passed to the `compile` CLI parser in the next cell.
case_num = 3

training_args = {
    "lr" : 1e-2,
    "losses" : losses,
    "atol" : atol,
    "batch_size" : 512,
    "use_single_loss": False,
    "iit_weight": 1.0,
    "behavior_weight": 1.0,
    "strict_weight": 1.0,
}
tracr_model_class.__name__

'StopGradModelPair'

In [3]:
np.random.seed(0)
t.manual_seed(0)

# Parse CLI-style arguments so get_cases() returns the benchmark case `case_num`
# (presumably only that case -- we take cases[0]).
args, _ = build_main_parser().parse_known_args(["compile",
                                                f"-i={case_num}",
                                                "-f",])
cases = get_cases(args)
case = cases[0]

# Compile the tracr program and its TransformerLens (HL) counterpart.
tracr_output = case.build_tracr_model()
hl_model = case.build_transformer_lens_model()
# this is the graph node -> hl node correspondence
tracr_hl_corr = correspondence.TracrCorrespondence.from_output(tracr_output)

In [4]:
# Sanity check: decode the HL model's output for one hand-written input sequence
# (the output appears to be the running fraction of 'x' tokens).
hl_model([['BOS', 'x', 'b', 'a', 'a']], return_type='decoded')

[['BOS', 1.0, 0.5, 0.3333333432674408, 0.25]]

In [5]:
# Re-seed torch, numpy and Python's `random` before sampling data, for reproducibility.
t.manual_seed(0)
np.random.seed(0)
# NOTE(review): this import belongs in the imports cell at the top of the notebook.
import random
random.seed(0)

In [6]:
# Sample raw (input, correct output) pairs from the benchmark case.
data = case.get_clean_data(count=15000)
inputs = data.get_inputs().to_numpy()
outputs = data.get_correct_outputs().to_numpy()

# 12000 / 3000 split of the raw samples.
# NOTE(review): only test_inputs / test_outputs are used later (to build the unique
# eval set); train_inputs / train_outputs are never used. Training uses train_set
# from create_dataset(), which does not receive these arrays -- confirm the eval
# inputs cannot overlap the data the model was trained on.
train_inputs = inputs[:12000]
test_inputs = inputs[12000:]
train_outputs = outputs[:12000]
test_outputs = outputs[12000:]

train_set, test_set = create_dataset(case, hl_model)

In [7]:
from transformer_lens import HookedTransformer, HookedTransformerConfig

cfg_dict = {
    "n_layers": 2, 
    "n_heads": 4, 
    "d_head": 4,
    "d_model": 8,
    "d_mlp": 16,
    "act_fn": "gelu",
}
ll_cfg = hl_model.cfg.to_dict().copy()
ll_cfg.update(cfg_dict)


print(ll_cfg)
ll_cfg = HookedTransformerConfig.from_dict(ll_cfg)
model = HookedTransformer(ll_cfg)

tracr_ll_corr = {
    ('is_x_3', None): {(0, 'mlp', index.Ix[[None]])},
    ('frac_prevs_1', None): {(1, 'attn', attn_idx)},
}

{'n_layers': 2, 'd_model': 8, 'n_ctx': 5, 'd_head': 4, 'model_name': 'custom', 'n_heads': 4, 'd_mlp': 16, 'act_fn': 'gelu', 'd_vocab': 6, 'eps': 1e-05, 'use_attn_result': True, 'use_attn_scale': True, 'use_split_qkv_input': True, 'use_hook_mlp_in': True, 'use_attn_in': False, 'use_local_attn': False, 'original_architecture': None, 'from_checkpoint': False, 'checkpoint_index': None, 'checkpoint_label_type': None, 'checkpoint_value': None, 'tokenizer_name': None, 'window_size': None, 'attn_types': None, 'init_mode': 'gpt2', 'normalization_type': None, 'device': device(type='mps'), 'n_devices': 1, 'attention_dir': 'causal', 'attn_only': False, 'seed': None, 'initializer_range': 0.22188007849009167, 'init_weights': True, 'scale_attn_by_inverse_layer_idx': False, 'positional_embedding_type': 'standard', 'final_rms': False, 'd_vocab_out': 1, 'parallel_attn_mlp': False, 'rotary_dim': None, 'n_params': 676, 'use_hook_tokens': False, 'gated_mlp': False, 'default_prepend_bos': True, 'dtype': tor

In [8]:
import iit_utils.correspondence as correspondence
# Build the HL-node -> LL-node correspondence from the tracr->HL and tracr->LL maps.
# NOTE(review): `correspondence` is already imported in the first cell, so the import above is redundant.
hl_ll_corr = correspondence.make_hl_ll_corr(tracr_hl_corr=tracr_hl_corr, tracr_ll_corr=tracr_ll_corr)

In [9]:
# Pair the (wrapped) HL model with the untrained LL model under hl_ll_corr;
# training_args configures the IIT training run.
model_pair = tracr_model_class(
    hl_model = make_iit_hl_model(hl_model),
    ll_model = model,
    corr = hl_ll_corr,
    training_args=training_args,
)

{'hook_embed': HookPoint(), 'hook_pos_embed': HookPoint(), 'blocks.0.attn.hook_k': HookPoint(), 'blocks.0.attn.hook_q': HookPoint(), 'blocks.0.attn.hook_v': HookPoint(), 'blocks.0.attn.hook_z': HookPoint(), 'blocks.0.attn.hook_attn_scores': HookPoint(), 'blocks.0.attn.hook_pattern': HookPoint(), 'blocks.0.attn.hook_result': HookPoint(), 'blocks.0.mlp.hook_pre': HookPoint(), 'blocks.0.mlp.hook_post': HookPoint(), 'blocks.0.hook_attn_in': HookPoint(), 'blocks.0.hook_q_input': HookPoint(), 'blocks.0.hook_k_input': HookPoint(), 'blocks.0.hook_v_input': HookPoint(), 'blocks.0.hook_mlp_in': HookPoint(), 'blocks.0.hook_attn_out': HookPoint(), 'blocks.0.hook_mlp_out': HookPoint(), 'blocks.0.hook_resid_pre': HookPoint(), 'blocks.0.hook_resid_mid': HookPoint(), 'blocks.0.hook_resid_post': HookPoint(), 'blocks.1.attn.hook_k': HookPoint(), 'blocks.1.attn.hook_q': HookPoint(), 'blocks.1.attn.hook_v': HookPoint(), 'blocks.1.attn.hook_z': HookPoint(), 'blocks.1.attn.hook_attn_scores': HookPoint(), 'b

In [10]:
# Train for up to 50 epochs. The echoed training_args show early_stop=True,
# and the run below stops after epoch 10.
model_pair.train(
    train_set,
    test_set,
    epochs=50,
    use_wandb=False,
)

training_args={'batch_size': 512, 'lr': 0.01, 'num_workers': 0, 'early_stop': True, 'atol': 0.05, 'use_single_loss': False, 'iit_weight': 1.0, 'behavior_weight': 1.0, 'losses': 'all', 'strict_weight': 1.0}


100%|██████████| 24/24 [00:01<00:00, 13.40it/s]
  2%|▏         | 1/50 [00:02<01:41,  2.08s/it]


Epoch 0: train/iit_loss: 0.0527, train/behavior_loss: 0.0275, val/iit_loss: 0.0126, val/IIA: 58.24%, val/accuracy: 66.99%, 


100%|██████████| 24/24 [00:01<00:00, 19.27it/s]
  4%|▍         | 2/50 [00:03<01:20,  1.68s/it]


Epoch 1: train/iit_loss: 0.0132, train/behavior_loss: 0.0111, val/iit_loss: 0.0025, val/IIA: 70.80%, val/accuracy: 75.28%, 


100%|██████████| 24/24 [00:01<00:00, 18.90it/s]
  6%|▌         | 3/50 [00:04<01:13,  1.57s/it]


Epoch 2: train/iit_loss: 0.0015, train/behavior_loss: 0.0005, val/iit_loss: 0.0007, val/IIA: 93.35%, val/accuracy: 99.88%, 


100%|██████████| 24/24 [00:01<00:00, 19.49it/s]
  8%|▊         | 4/50 [00:06<01:09,  1.50s/it]


Epoch 3: train/iit_loss: 0.0007, train/behavior_loss: 0.0001, val/iit_loss: 0.0004, val/IIA: 96.70%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:01<00:00, 19.26it/s]
 10%|█         | 5/50 [00:07<01:06,  1.47s/it]


Epoch 4: train/iit_loss: 0.0004, train/behavior_loss: 0.0000, val/iit_loss: 0.0006, val/IIA: 94.76%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:01<00:00, 18.84it/s]
 12%|█▏        | 6/50 [00:09<01:04,  1.46s/it]


Epoch 5: train/iit_loss: 0.0004, train/behavior_loss: 0.0000, val/iit_loss: 0.0003, val/IIA: 97.75%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:01<00:00, 19.02it/s]
 14%|█▍        | 7/50 [00:10<01:02,  1.45s/it]


Epoch 6: train/iit_loss: 0.0003, train/behavior_loss: 0.0000, val/iit_loss: 0.0002, val/IIA: 98.73%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:01<00:00, 19.49it/s]
 16%|█▌        | 8/50 [00:11<01:00,  1.43s/it]


Epoch 7: train/iit_loss: 0.0003, train/behavior_loss: 0.0008, val/iit_loss: 0.0003, val/IIA: 97.78%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:01<00:00, 19.44it/s]
 18%|█▊        | 9/50 [00:13<00:58,  1.42s/it]


Epoch 8: train/iit_loss: 0.0002, train/behavior_loss: 0.0001, val/iit_loss: 0.0002, val/IIA: 98.64%, val/accuracy: 100.00%, 


100%|██████████| 24/24 [00:01<00:00, 19.41it/s]
 20%|██        | 10/50 [00:14<00:56,  1.42s/it]


Epoch 9: train/iit_loss: 0.0032, train/behavior_loss: 0.0059, val/iit_loss: 0.0013, val/IIA: 83.56%, val/accuracy: 91.76%, 


100%|██████████| 24/24 [00:01<00:00, 19.50it/s]
 20%|██        | 10/50 [00:16<01:04,  1.62s/it]


Epoch 10: train/iit_loss: 0.0005, train/behavior_loss: 0.0002, val/iit_loss: 0.0002, val/IIA: 99.02%, val/accuracy: 100.00%, 





### Setup Eval

In [11]:
"""Create a new test set with unique inputs"""

# Key each test sequence by joining its tokens, so np.unique compares whole sequences.
test_input_keys = [", ".join(i) for i in np.array(test_inputs)]
# return_index yields, for each unique key (in sorted order), the index of its first
# occurrence in test_inputs. This replaces a per-key np.where scan, which was O(n_unique * n).
arr, first_occurences = np.unique(test_input_keys, return_index=True)

unique_test_inputs = test_inputs[first_occurences]
unique_test_outputs = test_outputs[first_occurences]
assert len(unique_test_inputs) == len(unique_test_outputs)
assert len(unique_test_inputs) == len(arr)
assert len(np.unique([", ".join(i) for i in np.array(unique_test_inputs)])) == len(unique_test_inputs)

# Eval set over (base, ablation) pairs of the unique inputs; every_combination=True
# presumably yields all pairs. NOTE: this rebinds `test_set` from create_dataset().
unique_test_data = TracrDataset(unique_test_inputs, unique_test_outputs)
test_set = TracrIITDataset(unique_test_data, unique_test_data, hl_model, every_combination=True)
test_loader = test_set.make_loader(batch_size=512, num_workers=0)

In [12]:
def tokenise_data(batch, model: HookedTracrTransformer) -> t.Tensor:
    """Transpose `batch` with zip(*batch) and encode it with the model's tracr->TL mapping.

    Presumably `batch` is position-major (as produced by torch's default collate of
    token lists), so the transpose recovers one token list per example -- TODO confirm.
    """
    per_example = [list(column) for column in zip(*batch)]
    return model.map_tracr_input_to_tl_input(per_example)

In [13]:
# Materialise the whole (base, ablation) evaluation loader as stacked tensors.
tensorised_base_data = []
tensorised_ablation_data = []
base_answer_tokens = []
for base_in, ablation_in in test_loader:
    # each side is an (x, y, _) triple; the third element is ignored here
    base_x, base_y,  _ = base_in
    ablation_x, ablation_y, _ = ablation_in

    tensorised_base_data.append((base_x))
    tensorised_ablation_data.append((ablation_x))
    base_answer_tokens.append(base_y)

# Concatenate all batches along dim 0.
base_tensor = t.cat(tensorised_base_data, dim=0)
ablation_tensor = t.cat(tensorised_ablation_data, dim=0)
# NOTE(review): rebinds the list name to the concatenated tensor.
base_answer_tokens = t.cat(base_answer_tokens, dim=0)

In [14]:
# Freeze both models and switch them to eval mode for the analysis below.
model.requires_grad_(False)
model.eval()
hl_model.requires_grad_(False)
hl_model.eval()
# Ending with print() stops the cell from echoing hl_model.eval()'s return value;
# a trailing `;` on the previous line would be the idiomatic alternative.
print()




In [15]:
# One forward pass (with full activation cache) over the entire eval tensor --
# 65536 rows per the shapes shown below, so this may be memory-heavy.
original_logits, cache = model.run_with_cache(base_tensor)

In [16]:
hl_answers = hl_model(base_tensor)
# NOTE(review): hl_answers has a trailing singleton dim ([batch, pos, 1] in the output),
# whereas base_answer_tokens[0] is 1-D over positions; the values shown agree.
# TODO confirm this shape mismatch is the surprise that was being flagged here.
(hl_answers.shape), hl_answers[0], base_answer_tokens[0]

(torch.Size([65536, 5, 1]),
 tensor([[0.0000],
         [0.0000],
         [0.0000],
         [0.3333],
         [0.2500]], device='mps:0'),
 tensor([0.0000, 0.0000, 0.0000, 0.3333, 0.2500], device='mps:0'))

In [17]:
# Compare LL and HL outputs for example 3: the LL values are close to, but not exactly, the HL zeros.
original_logits.shape, original_logits[3], hl_answers.shape, hl_answers[3]

(torch.Size([65536, 5, 1]),
 tensor([[0.0002],
         [0.0029],
         [0.0032],
         [0.0033],
         [0.0038]], device='mps:0'),
 torch.Size([65536, 5, 1]),
 tensor([[0.],
         [0.],
         [0.],
         [0.],
         [0.]], device='mps:0'))

### Patch Nodes (Attention Heads and MLPs) to Measure Causal Effect

In [18]:
from iit.utils.node_picker import get_nodes_not_in_circuit, get_nodes_in_circuit, get_all_nodes

# LL nodes that no HL node is mapped onto, shown next to the LL nodes that are mapped.
nodes_not_in_circuit = get_nodes_not_in_circuit(model_pair.ll_model, hl_ll_corr)
nodes_not_in_circuit, "---", list(hl_ll_corr.values())

([LLNode(name='blocks.0.attn.hook_result', index=[:, :, 0, :], subspace=None),
  LLNode(name='blocks.0.attn.hook_result', index=[:, :, 1, :], subspace=None),
  LLNode(name='blocks.0.attn.hook_result', index=[:, :, 2, :], subspace=None),
  LLNode(name='blocks.0.attn.hook_result', index=[:, :, 3, :], subspace=None),
  LLNode(name='blocks.1.attn.hook_result', index=[:, :, 2, :], subspace=None),
  LLNode(name='blocks.1.attn.hook_result', index=[:, :, 3, :], subspace=None),
  LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)],
 '---',
 [{LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)},
  {LLNode(name='blocks.1.attn.hook_result', index=[:, :, :2, :], subspace=None)}])

In [19]:
def do_intervention(model, base_input, ablation_input, node:mp.LLNode, hooker: callable):
    """Run `model` on `base_input` with `hooker` attached at `node.name`.

    Activations from a forward pass on `ablation_input` are cached first and stored on
    the module-level ``model_pair`` as ``ll_cache``; presumably the hook built by
    ``model_pair.make_ll_ablation_hook`` reads its patched values from there (TODO confirm).
    The same `model` is used for both the cached run and the patched run, so the
    patched activations always come from the model being intervened on.

    Returns:
        The output of the hooked forward pass on `base_input`.
    """
    _, ablation_cache = model.run_with_cache(ablation_input)
    model_pair.ll_cache = ablation_cache  # TODO: make this better when converting to script
    return model.run_with_hooks(base_input, fwd_hooks=[(node.name, hooker)])

In [20]:
def resample_ablate_node(model_pair: mp.IITModelPair,
                        base_in: tuple[t.Tensor, t.Tensor, t.Tensor],
                        ablation_in: tuple[t.Tensor, t.Tensor, t.Tensor],
                        node: mp.LLNode,
                        results: Dict[str, float],
                        hooker: callable,
                        verbose=False
):
    """Resample-ablate `node` on one batch and add its causal effect to ``results[node]``.

    The LL model is run on the base input with `node` patched from the ablation input.
    Among positions where the base and ablation labels differ, the fraction where the
    patched LL output no longer matches the HL output on the base input (within the
    global `atol`) is added to ``results[node]``. A batch in which no label changes
    contributes 0 rather than dividing by zero.

    Raises:
        NotImplementedError: if the HL model is categorical.
    """
    base_x, base_y, _ = base_in
    ablation_x, ablation_y, _ = ablation_in
    ll_out = do_intervention(model_pair.ll_model, base_x, ablation_x, node, hooker)
    if verbose:
        print(node)

    if model_pair.hl_model.is_categorical():
        raise NotImplementedError("Categorical models not supported yet.")

    base_hl_out = model_pair.hl_model(base_in).squeeze()
    # Positions where swapping in the ablation input leaves the label unchanged.
    label_unchanged = (base_y == ablation_y)
    # Positions where the patched LL output still matches the unpatched HL output.
    ll_unchanged = t.isclose(ll_out.float().squeeze(), base_hl_out.float().to(ll_out.device), atol=atol)
    changed_result = ((~label_unchanged).cpu().float() * (~ll_unchanged).cpu().float())
    n_label_changed = (~label_unchanged).float().sum().item()
    # Guard: a batch where every label is unchanged would otherwise raise ZeroDivisionError.
    if n_label_changed > 0:
        results[node] += changed_result.sum().item() / n_label_changed

    if verbose:
        print("\nlabel changed:", (~label_unchanged).float().mean(),
                "\nouts_changed:", (~ll_unchanged).float().mean(),
                "\ndot product:", changed_result.mean(),
                "\ndifference:", (ll_out.float().squeeze() - base_y.float().to(ll_out.device)).mean(),
                "\nfinal:", results[node])

In [21]:
from tqdm import tqdm

def check_causal_effect(model_pair: mp.BaseModelPair, dataset: TracrIITDataset, 
                        batch_size: int = 256, node_type: str = "a",
                        verbose: bool = False):
    """Resample-ablation causal effect of each selected LL node, averaged over batches.

    Args:
        model_pair: model pair providing ``ll_model``, ``corr`` and ``make_ll_ablation_hook``.
        dataset: IIT dataset whose loader yields ``(base_in, ablation_in)`` pairs.
        batch_size: loader batch size.
        node_type: ``"a"`` all LL nodes, ``"c"`` nodes in the circuit, ``"n"`` nodes not in it.
        verbose: forwarded to ``resample_ablate_node``.

    Returns:
        Dict mapping each node to its summed per-batch effect divided by the number of batches.
    """
    assert node_type in ["a", "c", "n"], "type must be one of 'a', 'c', or 'n'"
    if node_type == "n":
        selected_nodes = get_nodes_not_in_circuit(model_pair.ll_model, model_pair.corr)
    elif node_type == "a":
        selected_nodes = get_all_nodes(model_pair.ll_model)
    else:
        selected_nodes = get_nodes_in_circuit(model_pair.corr)

    hookers = {node: model_pair.make_ll_ablation_hook(node) for node in selected_nodes}
    totals = dict.fromkeys(hookers, 0)

    loader = dataset.make_loader(batch_size=batch_size, num_workers=0)
    for base_in, ablation_in in tqdm(loader):
        for node, hooker in hookers.items():
            resample_ablate_node(model_pair, base_in, ablation_in, node, totals, hooker, verbose=verbose)

    n_batches = len(loader)
    return {node: total / n_batches for node, total in totals.items()}

In [22]:
np.random.seed(0)
t.manual_seed(0)
# Resample-ablation causal effect for LL nodes outside vs. inside the circuit.
result_not_in_circuit = check_causal_effect(model_pair, test_set, node_type="n", verbose=False)
result_in_circuit = check_causal_effect(model_pair, test_set, node_type="c", verbose=False)

100%|██████████| 256/256 [00:18<00:00, 13.61it/s]
100%|██████████| 256/256 [00:05<00:00, 46.58it/s]


In [23]:
# plot a table of results
import pandas as pd
def make_dataframe_of_results(result_not_in_circuit, result_in_circuit):
    """Tabulate per-node causal effects, with not-in-circuit rows listed first.

    Args:
        result_not_in_circuit: mapping node -> effect for nodes outside the circuit.
        result_in_circuit: mapping node -> effect for nodes inside the circuit.

    Returns:
        DataFrame with columns ``node``, ``status`` and ``causal effect``. Within each
        status group, rows keep the input dicts' order.
    """
    def create_name(node) -> str:
        if "mlp" in node.name:
            return node.name
        # str(node.index) looks like "[:, :, 0, :]"; the second-to-last comma field is the
        # head selector. strip() drops the leading space left by the split.
        head = str(node.index).split(",")[-2].strip()
        return f"{node.name}, head {head}"

    nodes = list(result_not_in_circuit) + list(result_in_circuit)
    df = pd.DataFrame({
        "node": [create_name(node) for node in nodes],
        "status": ["not_in_circuit"] * len(result_not_in_circuit) + ["in_circuit"] * len(result_in_circuit),
        "causal effect": list(result_not_in_circuit.values()) + list(result_in_circuit.values()),
    })
    # Stable sort, so the original node order is preserved within each status group.
    return df.sort_values("status", ascending=False, kind="stable")

def color_table(val):
    """Return a CSS background colour for a ``status`` cell.

    Red for ``"not_in_circuit"``, green for anything else. Intended for element-wise
    use, e.g. ``df.style.map(color_table, subset=["status"])``.
    """
    color = 'red' if val == "not_in_circuit" else 'green'
    return f'background-color: {color}'


In [24]:
df = make_dataframe_of_results(result_not_in_circuit, result_in_circuit)
# NOTE(review): the Styler returned here is discarded, and pandas applies Styler
# functions lazily at render time, so no colouring is ever shown -- the cell displays
# the plain `df` below. Styler.apply also passes whole columns to color_table, which
# compares a scalar; Styler.map is the element-wise call.
df.style.apply(color_table, subset=["status"])
print(attn_idx, training_args, tracr_model_class)
df

[:, :, :2, :] {'lr': 0.01, 'losses': 'all', 'atol': 0.05, 'batch_size': 512, 'use_single_loss': False, 'iit_weight': 1.0, 'behavior_weight': 1.0, 'strict_weight': 1.0} <class 'iit.model_pairs.stop_grad_pair.StopGradModelPair'>


Unnamed: 0,node,status,causal effect
0,"blocks.0.attn.hook_result, head 0",not_in_circuit,0.001067
1,"blocks.0.attn.hook_result, head 1",not_in_circuit,0.003468
2,"blocks.0.attn.hook_result, head 2",not_in_circuit,0.000559
3,"blocks.0.attn.hook_result, head 3",not_in_circuit,4.1e-05
4,"blocks.1.attn.hook_result, head 2",not_in_circuit,0.0
5,"blocks.1.attn.hook_result, head 3",not_in_circuit,0.042254
6,blocks.1.mlp.hook_post,not_in_circuit,0.719165
7,"blocks.1.attn.hook_result, head :2",in_circuit,1.0
8,blocks.0.mlp.hook_post,in_circuit,1.0


In [25]:
from iit.utils.metric import MetricStore
def print_metrics(metrics: list[MetricStore]):
    """Print each metric as ``"<name>: <value>"``, one per line, in the given order."""
    for name, value in ((metric.get_name(), metric.get_value()) for metric in metrics):
        print(f"{name}: {value}")

# Run the model pair's (private) evaluation loop on the unique-input test loader.
metric_collection = model_pair._run_eval_epoch(test_loader, model_pair.loss_fn)

In [26]:
# Print each collected validation metric (IIT loss, IIA, accuracy in the output below).
print_metrics(metric_collection.metrics)

val/iit_loss: 0.0001665809666064888
val/IIA: 99.27490344271064
val/accuracy: 100.0


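### Do the same with zero or mean ablations

Each node's activation is replaced either by zeros (`use_mean_cache=False`) or by its mean over the dataset (`use_mean_cache=True`, the setting used below).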

In [27]:
class TracrUniqueDataset(TracrIITDataset):
    """IIT dataset variant that yields single base examples instead of (base, ablation) pairs.

    Indexing and length delegate to ``self.base_data`` (set up by the parent
    constructor), and ``collate_fn`` encodes a batch of base examples only.
    """

    def __init__(self, *args, **kwargs):
        # No extra state: construction is handled entirely by TracrIITDataset.
        super().__init__(*args, **kwargs)

    def __getitem__(self, index):
        return self.base_data[index]

    def __len__(self):
        return len(self.base_data)

    @staticmethod
    def collate_fn(batch, hl_model, device=DEVICE):
        """Encode a list of ``(x, y)`` examples into an ``(x, y, None)`` triple on `device`.

        Inputs go through ``hl_model.map_tracr_input_to_tl_input``. For categorical
        models each label gets class index 0 at position 0, followed by the encoded
        remaining positions, and is then one-hot encoded. Otherwise labels become a
        float32 tensor whose position-0 entries are set to 0.
        """
        def encode_examples(examples):
            raw_inputs, raw_labels = zip(*examples)
            encoded_inputs = hl_model.map_tracr_input_to_tl_input(raw_inputs)

            if hl_model.is_categorical():
                label_rows = [[0] + hl_model.tracr_output_encoder.encode(label[1:]) for label in raw_labels]
                by_position = [list(column) for column in zip(*label_rows)]
                labels = t.tensor(by_position, dtype=t.long).transpose(0, 1)
                num_classes = len(hl_model.tracr_output_encoder.encoding_map.keys())
                labels = t.nn.functional.one_hot(labels, num_classes=num_classes).float()
            else:
                by_position = [list(column) for column in zip(*raw_labels)]
                # Zero the first position's label for every example in the batch.
                by_position[0] = list(np.zeros(len(by_position[0])))
                labels = t.tensor(by_position, dtype=t.float32).transpose(0, 1)

            return encoded_inputs.to(device), labels.to(device), None

        return encode_examples(batch)

In [28]:
def get_mean_cache(model_pair, dataset):
    """Mean LL activation of every cached hook over the whole of `dataset`.

    The dataset is loaded as a single batch of ``len(dataset)`` examples, and the LL
    model is run once with caching. Each activation is averaged over dim 0 (the batch)
    while keeping that dim with size 1, so the result can broadcast against a batch.

    Returns:
        Dict mapping hook name -> mean activation tensor with a leading dim of 1.
    """
    whole_dataset_loader = dataset.make_loader(batch_size=len(dataset), num_workers=0)
    inputs = next(iter(whole_dataset_loader))[0]
    _, activation_cache = model_pair.ll_model.run_with_cache(inputs)
    return {
        hook_name: activations.mean(dim=0, keepdim=True)
        for hook_name, activations in activation_cache.items()
    }

In [29]:
from transformer_lens.hook_points import HookPoint

def make_ablation_hook(node: mp.LLNode, mean_cache: dict[str, t.Tensor], use_mean_cache: bool = True) -> callable:
    """Build a forward hook that ablates ``node.index`` of the hooked activation in place.

    With `use_mean_cache`, the selected slice is overwritten with the same slice of
    ``mean_cache[node.name]``; otherwise it is set to zero. The cache is looked up only
    when the hook runs.

    Raises:
        NotImplementedError: if `node` has a subspace.
    """
    if node.subspace is not None:
        raise NotImplementedError("Subspace not supported yet.")

    if not use_mean_cache:
        def zero_hook(hook_point_out: t.Tensor, hook: HookPoint) -> t.Tensor:
            hook_point_out[node.index.as_index] = 0
            return hook_point_out
        return zero_hook

    def mean_hook(hook_point_out: t.Tensor, hook: HookPoint) -> t.Tensor:
        replacement = mean_cache[node.name][node.index.as_index]
        hook_point_out[node.index.as_index] = replacement
        return hook_point_out
    return mean_hook

def ablate_node(model_pair: mp.IITModelPair,
                        base_in: tuple[t.Tensor, t.Tensor, t.Tensor],
                        node: mp.LLNode,
                        results: Dict[str, float],
                        hook: callable,
                        verbose=False
):
    """Ablate `node` on one batch and add its causal effect to ``results[node]``.

    Among positions where the unablated LL output matches the HL output (within the
    global `atol`), the fraction where the ablated LL output no longer matches is added
    to ``results[node]``. A batch with no accurate positions contributes 0 rather than
    dividing by zero.

    Raises:
        NotImplementedError: if the HL model is categorical.
    """
    base_x, _, _ = base_in
    ll_out = model_pair.ll_model.run_with_hooks(base_x, fwd_hooks=[(node.name, hook)])

    if model_pair.hl_model.is_categorical():
        raise NotImplementedError("Categorical models not supported yet.")

    base_hl_out = model_pair.hl_model(base_in).squeeze()
    base_ll_out = model_pair.ll_model(base_x).squeeze()
    # Positions where the ablated LL output still matches the HL output.
    ll_unchanged = t.isclose(ll_out.float().squeeze(), base_hl_out.float().to(ll_out.device), atol=atol)
    # Positions where the unablated LL model is correct. HL output is moved to the LL
    # device, as in the line above, so the comparison cannot hit a device mismatch.
    accuracy = t.isclose(base_ll_out.float(), base_hl_out.float().to(base_ll_out.device), atol=atol).cpu().float()
    changed_result = (~ll_unchanged).cpu().float() * accuracy
    n_accurate = accuracy.sum().item()
    if n_accurate > 0:
        results[node] += changed_result.sum().item() / n_accurate

    if verbose:
        print(node, "\naccurate positions:", n_accurate, "\nfinal:", results[node])


In [30]:
def check_causal_effect_on_ablation(
        model_pair: mp.BaseModelPair, dataset: TracrUniqueDataset, 
        batch_size: int = 256, node_type: str = "a",
        use_mean_cache: bool = False,
        verbose: bool = False):
    """Zero- or mean-ablation causal effect of each selected LL node, averaged over batches.

    Args:
        model_pair: model pair providing ``ll_model``, ``hl_model`` and ``corr``.
        dataset: unique-input dataset whose loader yields single ``(x, y, _)`` batches.
        batch_size: loader batch size.
        node_type: ``"a"`` all LL nodes, ``"c"`` nodes in the circuit, ``"n"`` nodes not in it.
        use_mean_cache: if True, replace activations with their mean over `dataset`
            (via ``get_mean_cache``); otherwise set them to zero.
        verbose: forwarded to ``ablate_node``.

    Returns:
        Dict mapping each node to its summed per-batch effect divided by the number of batches.
    """
    assert node_type in ["a", "c", "n"], "type must be one of 'a', 'c', or 'n'"
    # The zero-ablation hook never reads the cache, but the name must be bound in both
    # modes. Only pay for the full-dataset forward pass when mean ablation is requested.
    mean_cache = get_mean_cache(model_pair, dataset) if use_mean_cache else {}

    if node_type == "n":
        all_nodes = get_nodes_not_in_circuit(model_pair.ll_model, model_pair.corr)
    elif node_type == "a":
        all_nodes = get_all_nodes(model_pair.ll_model)
    else:
        all_nodes = get_nodes_in_circuit(model_pair.corr)

    hookers = {}
    results = {}
    for node in all_nodes:
        hookers[node] = make_ablation_hook(node, mean_cache, use_mean_cache)
        results[node] = 0

    loader = dataset.make_loader(batch_size=batch_size, num_workers=0)
    for base_in in tqdm(loader):
        for node, hooker in hookers.items():
            ablate_node(model_pair, base_in, node, results, hooker, verbose=verbose)

    # Unweighted mean over batches: a smaller final batch, if any, counts as much as a full one.
    for node, result in results.items():
        results[node] = result / len(loader)
    return results

In [31]:
# Unique-input dataset yielding base examples only (no ablation partner).
# NOTE(review): TracrUniqueDataset.__getitem__/__len__ use only base_data, so
# every_combination=True presumably has no effect on what is iterated -- confirm.
uni_test_set = TracrUniqueDataset(unique_test_data, unique_test_data, hl_model, every_combination=True)


In [32]:
np.random.seed(0)
t.manual_seed(0)
# True: replace node activations with their dataset mean; False: zero them.
# NOTE(review): the `za_` ("zero ablation") prefix below is kept even though mean ablation is used here.
use_mean_cache = True
za_result_not_in_circuit = check_causal_effect_on_ablation(model_pair, uni_test_set, node_type="n", verbose=False,  use_mean_cache=use_mean_cache)
za_result_in_circuit = check_causal_effect_on_ablation(model_pair, uni_test_set, node_type="c", verbose=False,  use_mean_cache=use_mean_cache)

  0%|          | 0/1 [00:00<?, ?it/s]

100%|██████████| 1/1 [00:00<00:00, 12.29it/s]
100%|██████████| 1/1 [00:00<00:00, 44.72it/s]


In [33]:
df = make_dataframe_of_results(za_result_not_in_circuit, za_result_in_circuit)
# NOTE(review): the Styler is discarded, so colouring is never rendered; the plain
# `df` below is what gets displayed. Returning the Styler as the last expression would show it.
df.style.map(color_table, subset=["status"])
print(attn_idx, training_args, tracr_model_class)
df

[:, :, :2, :] {'lr': 0.01, 'losses': 'all', 'atol': 0.05, 'batch_size': 512, 'use_single_loss': False, 'iit_weight': 1.0, 'behavior_weight': 1.0, 'strict_weight': 1.0} <class 'iit.model_pairs.stop_grad_pair.StopGradModelPair'>


Unnamed: 0,node,status,causal effect
0,"blocks.0.attn.hook_result, head 0",not_in_circuit,0.0
1,"blocks.0.attn.hook_result, head 1",not_in_circuit,0.0
2,"blocks.0.attn.hook_result, head 2",not_in_circuit,0.0
3,"blocks.0.attn.hook_result, head 3",not_in_circuit,0.0
4,"blocks.1.attn.hook_result, head 2",not_in_circuit,0.0
5,"blocks.1.attn.hook_result, head 3",not_in_circuit,0.003125
6,blocks.1.mlp.hook_post,not_in_circuit,0.471875
7,"blocks.1.attn.hook_result, head :2",in_circuit,0.715625
8,blocks.0.mlp.hook_post,in_circuit,0.8


### Combined table

In [34]:
def make_combined_dataframe_of_results(result_not_in_circuit, result_in_circuit, za_result_not_in_circuit, za_result_in_circuit, use_mean_cache: bool = False):
    """Combine resample-ablation and zero/mean-ablation effects into one table.

    Both inputs go through ``make_dataframe_of_results``. The ablation effect column is
    joined on the DataFrame index and named ``mean_ablate_effect`` when
    `use_mean_cache` is set, ``zero_ablate_effect`` otherwise.

    Returns:
        DataFrame with ``node``, ``status``, ``resample_ablate_effect`` and the ablation column.
    """
    combined = make_dataframe_of_results(result_not_in_circuit, result_in_circuit)
    ablation_table = make_dataframe_of_results(za_result_not_in_circuit, za_result_in_circuit)
    ablation_column = "mean_ablate_effect" if use_mean_cache else "zero_ablate_effect"

    combined["resample_ablate_effect"] = combined.pop("causal effect")
    combined[ablation_column] = ablation_table.pop("causal effect")
    return combined
df = make_combined_dataframe_of_results(result_not_in_circuit, result_in_circuit, za_result_not_in_circuit, za_result_in_circuit, use_mean_cache=use_mean_cache)
# NOTE(review): Styler.apply has no `method` parameter -- extra kwargs are forwarded to
# color_table, which would fail if this Styler were ever rendered. It is discarded
# anyway, so the plain `df` below is what gets displayed.
df.style.apply(color_table, subset=["status"], method = "map")   
print(attn_idx, training_args, tracr_model_class)
df

[:, :, :2, :] {'lr': 0.01, 'losses': 'all', 'atol': 0.05, 'batch_size': 512, 'use_single_loss': False, 'iit_weight': 1.0, 'behavior_weight': 1.0, 'strict_weight': 1.0} <class 'iit.model_pairs.stop_grad_pair.StopGradModelPair'>


Unnamed: 0,node,status,resample_ablate_effect,mean_ablate_effect
0,"blocks.0.attn.hook_result, head 0",not_in_circuit,0.001067,0.0
1,"blocks.0.attn.hook_result, head 1",not_in_circuit,0.003468,0.0
2,"blocks.0.attn.hook_result, head 2",not_in_circuit,0.000559,0.0
3,"blocks.0.attn.hook_result, head 3",not_in_circuit,4.1e-05,0.0
4,"blocks.1.attn.hook_result, head 2",not_in_circuit,0.0,0.0
5,"blocks.1.attn.hook_result, head 3",not_in_circuit,0.042254,0.003125
6,blocks.1.mlp.hook_post,not_in_circuit,0.719165,0.471875
7,"blocks.1.attn.hook_result, head :2",in_circuit,1.0,0.715625
8,blocks.0.mlp.hook_post,in_circuit,1.0,0.8


In [35]:
# save the results
import time
import json
import dataframe_image as dfi
# Results are written under results/<model pair class>/<timestamp>/.
# NOTE(review): '%d-%H-%M-%S' has no month or year, so runs at the same day-of-month and
# time in different months map to the same directory; with exist_ok=True the later run
# overwrites the files. Consider '%Y-%m-%d-%H-%M-%S'.
save_dir = f"results/{tracr_model_class.__name__}/{time.strftime('%d-%H-%M-%S')}"
os.makedirs(save_dir, exist_ok=True)
dfi.export(df, f"{save_dir}/results.png")
df.to_csv(f"{save_dir}/results.csv")
# Only training_args is recorded; case_num and attn_idx are not saved alongside the results.
with open(f"{save_dir}/meta.json", "w") as f:
    json.dump(training_args, f)

### Rough

In [36]:
# def get_all_bad_examples(model_pair, loader, atol=5e-2):
#     model_pair.ll_model.eval()
#     model_pair.hl_model.eval()
#     bad_io_examples = []
#     bad_ii_examples = []

#     for base_in, ablation_in in tqdm(loader):
#         base_in = model_pair.get_encoded_input_from_torch_input(base_in)
#         ablation_in = model_pair.get_encoded_input_from_torch_input(ablation_in)
#         for node in model_pair.corr.keys():
#             hl_node = node.name
#             ll_out, hl_out = model_pair.do_intervention(base_in, ablation_in, hl_node)
#             if model_pair.hl_model.is_categorical():
#                 top1 = t.argmax(ll_out, dim=1)
#                 correct = (top1 == hl_out).float()
#             else:
#                 correct = ((ll_out - hl_out).abs() < atol).float()
            
#             for i, c in enumerate(correct):
#                 print(c)
#                 if c == 0:
#                     bad_ii_examples.append((base_in[i], ablation_in[i]))
#         base_x, base_y = base_in
#         ll_out = model_pair.ll_model(base_x)
#         if model_pair.hl_model.is_categorical():
#             top1 = t.argmax(ll_out, dim=1)
#             correct = (top1 == base_y).float()
#         else:
#             correct = ((ll_out - base_y).abs() < atol).float()
        
#         for i, c in enumerate(correct):
#             if c == 0:
#                 if base_x[i] not in bad_io_examples:
#                     bad_io_examples.append((base_x[i]))

#     return bad_io_examples, bad_ii_examples

# bad_io_examples, bad_ii_examples = get_all_bad_examples(model_pair, test_loader, atol)

# bad_io_examples, bad_ii_examples

In [37]:
# np.random.seed(0)
# t.manual_seed(0)
# test_loader = DataLoader(test_set, batch_size=2, shuffle=True)
# base_in, ablation_in = next(iter(test_loader))

# hooker = model_pair.make_ll_ablation_hook(nodes_not_in_circuit[2])
# base_x, base_y = model_pair.get_encoded_input_from_torch_input(base_in)
# ablation_x, ablation_y = model_pair.get_encoded_input_from_torch_input(ablation_in)
# ll_out = do_intervention(model_pair.ll_model, base_x, ablation_x, nodes_not_in_circuit[2], hooker)
# ll_base_out, ll_base_cache = model_pair.ll_model.run_with_cache(base_x)
# ll_ablation_out, ll_ablation_cache = model_pair.ll_model.run_with_cache(ablation_x)
# for i in range(2):
#     print(
#         "---",
#         f"example {i}", 
#         "base_y:", base_y[i],
#         "ll_base_out:", ll_base_out[i].T,
#         "",
#         "ablation_y:", ablation_y[i],
#         "ll_ablation_out:", ll_ablation_out[i].T,
#         "",
#         "ll_out:", ll_out[i].T,
#         sep="\n"
#     )