In [None]:
import torch
import numpy as np
import random

# Seed PyTorch, NumPy and the stdlib RNG (in that order) with one shared
# value so a fresh-kernel "Run All" is repeatable.
SEED = 42
for seed_rng in (torch.manual_seed, np.random.seed, random.seed):
    seed_rng(SEED)


In [3]:
from circuits_benchmark.utils import hf
from circuits_benchmark.utils.get_cases import get_cases
from circuits_benchmark.commands.build_main_parser import build_main_parser
from circuits_benchmark.transformers.hooked_tracr_transformer import HookedTracrTransformer
from argparse import Namespace

# Which benchmark task to load and which saved low-level checkpoint to use.
parser_args = Namespace(task=3, model="510")

# Build the argument namespace through the project's own CLI parser (the
# "compile" sub-command) so get_cases() receives what it would get from the
# command line. Unrecognised arguments are ignored by parse_known_args.
compile_argv = ["compile", f"-i={parser_args.task}", "-f"]
args, _ = build_main_parser().parse_known_args(compile_argv)

# Only the first returned case is used for the rest of the notebook.
cases = get_cases(args)
case = cases[0]
if not case.supports_causal_masking():
    raise NotImplementedError(f"Case {case.get_index()} does not support causal masking")

# High-level reference: the tracr build and its TransformerLens counterpart.
tracr_output = case.build_tracr_model()
hl_model = case.build_transformer_lens_model()

# Architecture overrides applied on top of the high-level model's config to
# describe the low-level model.
cfg_dict = dict(
    n_layers=2,
    n_heads=4,
    d_head=4,
    d_model=8,
    d_mlp=16,
    seed=0,
    act_fn="gelu",
)
# Merge into a fresh dict so the dict returned by to_dict() is left untouched;
# keys in cfg_dict win over the high-level values.
ll_cfg = {**hl_model.cfg.to_dict(), **cfg_dict}

# The low-level model reuses the high-level model's encoders and residual
# stream labels.
ll_model = HookedTracrTransformer(
    ll_cfg,
    hl_model.tracr_input_encoder,
    hl_model.tracr_output_encoder,
    hl_model.residual_stream_labels,
    remove_extra_tensor_cloning=True,
)
# NOTE: relative path — resolved against the kernel's working directory.
weights_path = f"ll_models/{case.get_index()}/ll_model_{parser_args.model}.pth"
ll_model.load_weights_from_file(weights_path)

# Wrapper object that is later handed to pyvene's IntervenableModel.
wrapped_ll_model = hf.make_hf_wrapper_from_tl_model(ll_model)

Moving model to device:  cpu
Moving model to device:  cpu


In [5]:
from circuits_benchmark.utils.iit import make_iit_hl_model, create_dataset
# Derive the IIT variant of the high-level model via make_iit_hl_model.
iit_hl_model = make_iit_hl_model(hl_model)
# create_dataset yields a pair that is unpacked as (train, test) datasets.
train_data, test_data = create_dataset(case, hl_model) 

(<circuits_benchmark.utils.iit.dataset.TracrIITDataset at 0x2b2e98d10>,
 <circuits_benchmark.utils.iit.dataset.TracrIITDataset at 0x2b1054210>)

In [6]:
# Loader settings shared by the train and test loaders.
batch_size = 256
num_workers = 0

# Build both loaders with identical settings (train first, then test).
loader, test_loader = (
    dataset.make_loader(batch_size, num_workers)
    for dataset in (train_data, test_data)
)

In [12]:
from pyvene import IntervenableModel, IntervenableConfig, RepresentationConfig, RotatedSpaceIntervention

In [13]:
# TODO: Why are two identical representations configured?
# NOTE(review): both entries carry intervention_link_key=0; presumably this
# links them so they share a single intervention — confirm against pyvene docs.
config = IntervenableConfig(
    model_type=type(wrapped_ll_model),
    representations=[
        RepresentationConfig(
            0,  # layer
            "block_output",  # intervention type
            "pos",  # intervention unit is aligned with tokens
            1,  # max number of units
            subspace_partition=None,  # no subspace partition is specified
            intervention_link_key=0,
        )
        for _ in range(2)  # two separate but identical representation configs
    ],
    intervention_types=RotatedSpaceIntervention,
)

In [14]:
# Wrap the low-level model with the intervention config built above.
intervenable = IntervenableModel(config, wrapped_ll_model, use_fast=True)
intervenable.set_device("cpu")
# Per its name, this turns off gradients for the wrapped model's own weights;
# the optimizer built later is handed only the intervention's rotate_layer parameters.
intervenable.disable_model_gradients()


In case multiple location tags are passed only the first one will be considered


block_output
n_embd
model.blocks[0]
block_output
n_embd
model.blocks[0]


In [None]:
# Training hyper-parameters.
epochs = 10
gradient_accumulation_steps = 1
# total_step = 0
# target_total_step = len(train_data) * epochs

# Optimise only the rotation layer of the first registered intervention; the
# slice keeps at most one entry, so an empty mapping yields an empty list.
optimizer_params = [
    {"params": intervention[0].rotate_layer.parameters()}
    for _key, intervention in list(intervenable.interventions.items())[:1]
]
optimizer = torch.optim.Adam(optimizer_params, lr=0.001)


def compute_metrics(eval_preds, eval_labels):
    """Compute the fraction of predictions that equal their labels.

    Args:
        eval_preds: Iterable of predictions.
        eval_labels: Iterable of labels, paired element-wise with
            ``eval_preds`` via ``zip`` (so the longer one is truncated).

    Returns:
        ``{"accuracy": float}``. When there is nothing to compare the
        accuracy is reported as ``0.0`` instead of raising
        ``ZeroDivisionError``.
    """
    total_count = 0
    correct_count = 0
    for eval_pred, eval_label in zip(eval_preds, eval_labels):
        total_count += 1
        # The comparison result (a bool or a scalar-like value) is summed as 0/1.
        correct_count += eval_pred == eval_label
    if total_count == 0:
        # No pairs were seen: avoid dividing by zero.
        return {"accuracy": 0.0}
    accuracy = float(correct_count) / float(total_count)
    return {"accuracy": accuracy}


def compute_loss(outputs, labels):
    """Return the cross-entropy loss between ``outputs`` and ``labels``.

    Uses the functional form with its default settings, which is what a
    default-constructed ``torch.nn.CrossEntropyLoss`` module applies.
    """
    return torch.nn.functional.cross_entropy(outputs, labels)


def batched_random_sampler(data):
    """Yield sample indices batch by batch, visiting batches in shuffled order.

    Indices inside one batch stay contiguous and ascending; only the order of
    the batches is randomised (with the stdlib ``random`` module). The batch
    size comes from the notebook-level ``batch_size`` variable. Only full
    batches are produced, so trailing samples that do not fill a batch are
    never yielded.
    """
    num_full_batches = int(len(data) / batch_size)
    batch_order = list(range(num_full_batches))
    random.shuffle(batch_order)
    for batch_idx in batch_order:
        first = batch_idx * batch_size
        yield from range(first, first + batch_size)