In [None]:
import shutil
from pathlib import Path

import torch
import transformer_lens

import copy_transformer.data
import copy_transformer.tokenizer
from subspace_partition.preimage.cache_act import run_cache_act

In [2]:
# Model width and head count; the per-head width is EMBEDDING_DIM // NUM_HEADS.
EMBEDDING_DIM = 64
NUM_HEADS = 8
# The head split must be exact, otherwise integer division silently truncates
# d_head and the heads no longer cover the full embedding.
assert EMBEDDING_DIM % NUM_HEADS == 0, (
    f"EMBEDDING_DIM ({EMBEDDING_DIM}) must be divisible by NUM_HEADS ({NUM_HEADS})"
)

# Alphabet for the single-character tokenizer and the dataset: the letters A-Z.
VOCABULARY = list("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
# Sequence length; used as the model's n_ctx and the dataset's context_length.
CONTEXT_LENGTH = 32

# Passed to the dataset as max_pattern_length (presumably the longest
# repeating pattern it generates — TODO confirm in copy_transformer.data).
MAX_PATTERN_LENGTH = 16

# Where cached activations are written; the caching cell deletes it first.
OUTPUT_DIR = Path("out/preimage")

# TransformerLens hook points to cache: the residual stream after each block.
ACT_SITES = ["blocks.0.hook_resid_post", "blocks.1.hook_resid_post"]

# Number of dataset sequences; NUM_SAMPLES * CONTEXT_LENGTH positions are cached
# (the "total count 32000" reported by the caching cell).
NUM_SAMPLES = 1_000

In [3]:
# Character-level tokenizer over VOCABULARY plus the special tokens below.
# add_bos_token=True presumably prepends ">" to each encoded sequence — TODO confirm.
tokenizer = copy_transformer.tokenizer.SingleCharTokenizer(
    alphabet=VOCABULARY,
    bos_token=">",
    eos_token="<",
    unk_token="?",
    pad_token="_",
    name_or_path="custom",
    add_bos_token=True,
)

# Two-layer, attention-only transformer. This config must match the checkpoint
# exactly: load_state_dict (strict by default) raises on any shape/key mismatch.
model_config = transformer_lens.HookedTransformerConfig(
    d_model=EMBEDDING_DIM,
    d_head=EMBEDDING_DIM // NUM_HEADS,
    n_layers=2,
    n_ctx=CONTEXT_LENGTH,
    n_heads=NUM_HEADS,
    d_vocab=tokenizer.vocab_size,
    attn_only=True,
)

model_state_dict_path = Path("out/copy_transformer.pt")

model = transformer_lens.HookedTransformer(model_config)
# map_location: the checkpoint may have been saved from a GPU run; remap it onto
# the device this model lives on so loading also works on CPU-only machines.
# weights_only: a state dict is plain tensors, so refuse to unpickle arbitrary objects.
model.load_state_dict(
    torch.load(
        model_state_dict_path,
        map_location=model.cfg.device,
        weights_only=True,
    )
)
model.tokenizer = tokenizer


# Synthetic dataset of NUM_SAMPLES sequences of length CONTEXT_LENGTH built from
# VOCABULARY (presumably each sequence repeats a pattern of at most
# MAX_PATTERN_LENGTH tokens — TODO confirm in copy_transformer.data).
dataset = copy_transformer.data.IterablePureRepeatingPatternDataset(
    num_samples=NUM_SAMPLES,
    vocabulary=VOCABULARY,
    context_length=CONTEXT_LENGTH,
    max_pattern_length=MAX_PATTERN_LENGTH,
)

In [4]:
# Start from an empty output directory so activations from a previous run are
# never merged with this one. Errors propagate deliberately: with
# ignore_errors=True a failed delete would silently leave stale files behind.
if OUTPUT_DIR.exists():
    shutil.rmtree(OUTPUT_DIR)

# Cache the ACT_SITES activations for every dataset sample under OUTPUT_DIR.
# max_in_memory presumably bounds how much is buffered before each
# "saving batch to disk" flush — TODO confirm against run_cache_act.
run_cache_act(
    model=model,
    dataset=dataset,
    act_sites=ACT_SITES,
    output_dir=OUTPUT_DIR,
    max_in_memory=100,
)

caching activations: 1it [00:00,  6.60it/s]

saving batch to disk...
saving batch to disk...


caching activations: 3it [00:00,  8.69it/s]

saving batch to disk...
saving batch to disk...
saving batch to disk...


caching activations: 7it [00:00,  9.72it/s]

saving batch to disk...
saving batch to disk...
saving batch to disk...


caching activations: 9it [00:00,  9.87it/s]

saving batch to disk...
saving batch to disk...
saving batch to disk...


caching activations: 12it [00:01,  9.92it/s]

saving batch to disk...
saving batch to disk...
saving batch to disk...


caching activations: 15it [00:01,  9.96it/s]

saving batch to disk...
saving batch to disk...
saving batch to disk...


caching activations: 19it [00:01,  9.95it/s]

saving batch to disk...
saving batch to disk...
saving batch to disk...


caching activations: 22it [00:02,  9.98it/s]

saving batch to disk...
saving batch to disk...


caching activations: 23it [00:02,  9.95it/s]

saving batch to disk...
saving batch to disk...
saving batch to disk...


caching activations: 27it [00:02, 10.00it/s]

saving batch to disk...
saving batch to disk...
saving batch to disk...


caching activations: 29it [00:02,  9.98it/s]

saving batch to disk...
saving batch to disk...
saving batch to disk...


caching activations: 32it [00:03,  9.84it/s]


saving batch to disk...
loading and merging cached act...
total count 32000


100%|██████████| 2/2 [00:00<00:00, 91.83it/s]

saving input...



