In [1]:
import shutil
from pathlib import Path

import torch
import transformer_lens

import copy_transformer.data
import copy_transformer.tokenizer
from subspace_partition.preimage.build_index import run_build_index

In [2]:
# --- Model / data hyperparameters ------------------------------------------
# These are fed into the transformer config and the dataset constructor in
# the setup cell below.
EMBEDDING_DIM = 64      # passed as d_model
NUM_HEADS = 8           # passed as n_heads; d_head is EMBEDDING_DIM // NUM_HEADS
CONTEXT_LENGTH = 32     # passed as n_ctx and as the dataset's context_length
VOCABULARY = list("ABCDEFGHIJKLMNOPQRSTUVWXYZ")  # one single-character token per letter

MAX_PATTERN_LENGTH = 16  # passed as the dataset's max_pattern_length

# --- Filesystem layout (all relative to the working directory) -------------
TRAINED_RS_DIR = Path("out/subspace_partition/copy_transformer_sensible_sp_parameters")
CACHED_ACTS_DIR = Path("out/preimage")
OUTPUT_DIR = Path("out/index")

# Hook-point names for the residual stream after block 0 and block 1.
# NOTE(review): not referenced by the cells visible here — confirm it is still needed.
ACT_SITES = ["blocks.0.hook_resid_post", "blocks.1.hook_resid_post"]

NUM_SAMPLES = 1_000  # passed as the dataset's num_samples

In [3]:
# Single-character tokenizer over VOCABULARY plus the four special tokens.
tokenizer = copy_transformer.tokenizer.SingleCharTokenizer(
    alphabet=VOCABULARY,
    bos_token=">",
    eos_token="<",
    unk_token="?",
    pad_token="_",
    name_or_path="custom",
    add_bos_token=True,
)

# Two-layer, attention-only transformer; the vocabulary size comes from the
# tokenizer so it stays consistent with the special tokens defined above.
model_config = transformer_lens.HookedTransformerConfig(
    d_model=EMBEDDING_DIM,
    d_head=EMBEDDING_DIM // NUM_HEADS,
    n_layers=2,
    n_ctx=CONTEXT_LENGTH,
    n_heads=NUM_HEADS,
    d_vocab=tokenizer.vocab_size,
    attn_only=True,
)

model_state_dict_path = Path("out/copy_transformer.pt")

model = transformer_lens.HookedTransformer(model_config)
# Deserialize onto the CPU first: without map_location, a checkpoint saved
# from a CUDA device cannot be loaded on a machine without that device.
# load_state_dict then copies the values into the model's own parameters,
# whichever device they live on.
model_state_dict = torch.load(model_state_dict_path, map_location="cpu")
model.load_state_dict(model_state_dict)
model.tokenizer = tokenizer


# Dataset built from the same vocabulary and context length as the model.
dataset = copy_transformer.data.IterablePureRepeatingPatternDataset(
    num_samples=NUM_SAMPLES,
    vocabulary=VOCABULARY,
    context_length=CONTEXT_LENGTH,
    max_pattern_length=MAX_PATTERN_LENGTH,
)

In [4]:
# Start from a clean output directory so files from a previous run cannot
# mix with the freshly built index. Only the "directory does not exist yet"
# case is tolerated; any real deletion failure (e.g. permissions) is raised
# instead of being silently ignored and leaving a half-deleted directory.
if OUTPUT_DIR.exists():
    shutil.rmtree(OUTPUT_DIR)

run_build_index(
    trained_Rs_dir=TRAINED_RS_DIR,
    cached_acts_dir=CACHED_ACTS_DIR,
    output_dir=OUTPUT_DIR,
    experiment_name="copy_transformer_sensible_sp_parameters",
)

build index using R from out/subspace_partition/copy_transformer_sensible_sp_parameters/R-custom-x1.post.pt


100%|██████████| 32/32 [00:00<00:00, 1778.90it/s]


saving indices...
build index using R from out/subspace_partition/copy_transformer_sensible_sp_parameters/R-custom-x0.post.pt


100%|██████████| 32/32 [00:00<00:00, 1930.69it/s]

saving indices...



