In [1]:
import torch as t
import numpy as np
import wandb
from transformer_lens import HookedTransformer
import iit.model_pairs as mp
import iit.utils.index as index
from circuits_benchmark.utils.get_cases import get_cases
from circuits_benchmark.commands.build_main_parser import build_main_parser
from iit_utils import make_iit_hl_model, create_dataset
import iit_utils.correspondence as correspondence
import random
from iit_utils.tracr_ll_corrs import get_tracr_ll_corr
import argparse
import os
import json

# Seed every RNG used here (torch, numpy, stdlib random) from one constant so
# the seed can be changed in a single place.
SEED = 0
t.manual_seed(SEED)
np.random.seed(SEED)
# t.use_deterministic_algorithms(True)
random.seed(SEED)
# DEVICE = t.device("cuda" if t.cuda.is_available() else "cpu")
# W&B entity: taken from the WANDB_ENTITY environment variable when it is set,
# otherwise falls back to the previous hard-coded value.
WANDB_ENTITY = os.environ.get("WANDB_ENTITY", "cybershiptrooper")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Case indices to load; each one is passed to get_case_outputs() as the `-i` argument.
tasks = [3, 21]

In [3]:
def get_case_outputs(task):
    """Load a single benchmark case and build the objects derived from it.

    The case is selected by running the project's CLI parser on the
    arguments ``compile -i=<task> -f`` and taking the first entry returned by
    ``get_cases``.

    Args:
        task: Case index. It is interpolated into the ``-i`` argument and
            compared, as a string, against ``case.get_index()``.

    Returns:
        A ``(case, hl_model, tracr_hl_corr)`` tuple: the case object, the
        result of ``case.build_transformer_lens_model()``, and the
        ``TracrCorrespondence`` built from ``case.build_tracr_model()``.

    Raises:
        AssertionError: If the loaded case's index is not ``str(task)``.
        NotImplementedError: If the case does not support causal masking.
    """
    cli_argv = ["compile", f"-i={task}", "-f"]
    parsed_args, _unknown = build_main_parser().parse_known_args(cli_argv)

    case = get_cases(parsed_args)[0]
    assert case.get_index() == str(task), f"Expected case {task}, got {case.get_index()}"

    if not case.supports_causal_masking():
        raise NotImplementedError(f"Case {case.get_index()} does not support causal masking")

    # Same build order as before: tracr model first, then the transformer-lens model.
    tracr_output = case.build_tracr_model()
    hl_model = case.build_transformer_lens_model()

    # The correspondence maps graph nodes to high-level (hl) nodes.
    return case, hl_model, correspondence.TracrCorrespondence.from_output(tracr_output)

In [4]:
# Per-task results, kept as parallel lists in the same order as `tasks`.
cases = []
hl_models = []
tracr_hl_corrs = []

# The loop variable must not be named `t`: that name is the torch alias
# (`import torch as t`), and a top-level loop would rebind it to an int,
# breaking every later `t.<...>` call in the notebook.
for task in tasks:
    case, hl_model, tracr_hl_corr = get_case_outputs(task)
    cases.append(case)
    hl_models.append(hl_model)
    tracr_hl_corrs.append(tracr_hl_corr)

# get the tracr_ll_corrs
tracr_ll_corrs = [get_tracr_ll_corr(case) for case in cases]

# get train and test sets
# create_dataset is unpacked as a (train, test) pair per case — TODO confirm its return contract.
train_sets, test_sets = zip(*[create_dataset(case, hl_model) for case, hl_model in zip(cases, hl_models)])

In [5]:
# Show the maximum sequence length reported by each loaded case.
for loaded_case in cases:
    print(loaded_case.get_max_seq_len())

10
10
