In [1]:
from subspace_partition.subspace_partition import (
    run_subspace_partition,
    SubspacePartitionConfig,
)

import transformer_lens
from pathlib import Path
import copy_transformer.data
import copy_transformer.tokenizer

In [2]:
# --- Model size -------------------------------------------------------------
# EMBEDDING_DIM is used as d_model; the per-head width is derived later as
# EMBEDDING_DIM // NUM_HEADS.
NUM_HEADS = 8
EMBEDDING_DIM = 64

# --- Data -------------------------------------------------------------------
# One token per upper-case Latin letter (special tokens are added by the
# tokenizer, not listed here).
VOCABULARY = list("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
# Shared by the model (n_ctx) and the dataset (context_length).
CONTEXT_LENGTH = 32
# Forwarded to the dataset as max_pattern_length.
MAX_PATTERN_LENGTH = 16

# --- Training ---------------------------------------------------------------
# Step budget handed to the subspace-partition run (applied per activation
# site in the recorded output).
MAX_STEPS = 10_000

# --- Output -----------------------------------------------------------------
# The run log reports saving under OUTPUT_DIR / EXPERIMENT_NAME.
OUTPUT_DIR = Path("out/subspace_partition")
EXPERIMENT_NAME = "copy_transformer"

In [3]:
# Character-level tokenizer over VOCABULARY. The four special tokens use
# characters outside A-Z, so they cannot collide with alphabet symbols.
tokenizer = copy_transformer.tokenizer.SingleCharTokenizer(
    alphabet=VOCABULARY,
    bos_token=">",
    eos_token="<",
    unk_token="?",
    pad_token="_",
    name_or_path="custom",
    add_bos_token=True,
)

# Depth of the model. Kept as a named constant because the list of activation
# sites below is derived from it (one residual-stream site per block).
NUM_LAYERS = 2

# Architecture of the model whose weights are loaded from
# model_state_dict_path. d_head is derived by integer division, so
# EMBEDDING_DIM should be a multiple of NUM_HEADS (64 / 8 = 8 here).
# NOTE(review): attn_only=True presumably means no MLP sub-layers — confirm
# against the TransformerLens documentation.
model_config = transformer_lens.HookedTransformerConfig(
    d_model=EMBEDDING_DIM,
    d_head=EMBEDDING_DIM // NUM_HEADS,
    n_layers=NUM_LAYERS,
    n_ctx=CONTEXT_LENGTH,
    n_heads=NUM_HEADS,
    d_vocab=tokenizer.vocab_size,  # includes whatever special tokens the tokenizer counts
    attn_only=True,
)
model_state_dict_path = Path("out/copy_transformer.pt")

# Data source for the run; built from the same vocabulary and context length
# as the tokenizer and model above.
dataset = copy_transformer.data.InfinitePureRepeatingPatternDataset(
    vocabulary=VOCABULARY,
    context_length=CONTEXT_LENGTH,
    max_pattern_length=MAX_PATTERN_LENGTH,
)

# Everything run_subspace_partition needs, bundled in one config object.
# act_sites is generated from NUM_LAYERS so it always lists the post-block
# residual stream of every layer; for NUM_LAYERS = 2 this is exactly
# ["blocks.0.hook_resid_post", "blocks.1.hook_resid_post"].
subspace_partition_config = SubspacePartitionConfig(
    exp_name=EXPERIMENT_NAME,
    output_dir=OUTPUT_DIR,
    model_config=model_config,
    model_weights_path=model_state_dict_path,
    act_sites=[f"blocks.{layer}.hook_resid_post" for layer in range(NUM_LAYERS)],
    tokenizer=tokenizer,
    dataset=dataset,
    max_steps=MAX_STEPS,
)

In [4]:
# Launch the subspace-partition run described by subspace_partition_config.
# In the recorded output below, the sites in act_sites are trained one after
# another, each for MAX_STEPS steps (roughly 25 minutes per site), followed by
# a 25-step evaluation.
# NOTE(review): both sites log the same save path
# ("out/subspace_partition/copy_transformer") — confirm the second site does
# not overwrite the first site's artefacts.
run_subspace_partition(subspace_partition_config)

training for blocks.0.hook_resid_post


  2%|▏         | 201/10000 [00:30<25:19,  6.45it/s]

{'R_grad_norm': 0.018802380366250874, 'training_loss': 0.19663480170071124}


  4%|▍         | 401/10000 [01:00<23:45,  6.73it/s]

{'R_grad_norm': 0.021874440442770718, 'training_loss': 0.1862760491669178}


  6%|▌         | 601/10000 [01:31<24:03,  6.51it/s]

{'R_grad_norm': 0.022338520642369984, 'training_loss': 0.1779511746019125}


  8%|▊         | 801/10000 [02:01<23:18,  6.58it/s]

{'R_grad_norm': 0.022894547060132026, 'training_loss': 0.17405482999980448}


 10%|█         | 1001/10000 [02:31<22:40,  6.61it/s]

{'R_grad_norm': 0.023545913696289063, 'training_loss': 0.17120515838265418}


 12%|█▏        | 1201/10000 [03:02<22:16,  6.59it/s]

{'R_grad_norm': 0.024020167412236334, 'training_loss': 0.16835944898426533}


 14%|█▍        | 1401/10000 [03:32<21:15,  6.74it/s]

{'R_grad_norm': 0.024342360682785513, 'training_loss': 0.16630152583122254}


 16%|█▌        | 1601/10000 [04:02<21:07,  6.63it/s]

{'R_grad_norm': 0.025006641047075392, 'training_loss': 0.16351116202771665}


 18%|█▊        | 1801/10000 [04:32<20:30,  6.66it/s]

{'R_grad_norm': 0.025677233152091503, 'training_loss': 0.16203404664993287}


 20%|██        | 2001/10000 [05:03<19:54,  6.69it/s]

{'R_grad_norm': 0.025884738583117725, 'training_loss': 0.1599464028328657}


 22%|██▏       | 2201/10000 [05:33<19:31,  6.66it/s]

{'R_grad_norm': 0.025860695289447903, 'training_loss': 0.15925950340926648}


 24%|██▍       | 2401/10000 [06:04<19:07,  6.62it/s]

{'R_grad_norm': 0.02608717138879001, 'training_loss': 0.15914075180888176}


 26%|██▌       | 2601/10000 [06:34<18:06,  6.81it/s]

{'R_grad_norm': 0.02610908613540232, 'training_loss': 0.15813943557441235}


 28%|██▊       | 2801/10000 [07:04<17:50,  6.72it/s]

{'R_grad_norm': 0.025894283996894956, 'training_loss': 0.15794371157884599}


 30%|███       | 3001/10000 [07:35<16:57,  6.88it/s]

{'R_grad_norm': 0.0260620238725096, 'training_loss': 0.1586367290467024}


 32%|███▏      | 3201/10000 [08:05<17:24,  6.51it/s]

{'R_grad_norm': 0.026135165421292187, 'training_loss': 0.15768157429993152}


 34%|███▍      | 3401/10000 [08:35<16:23,  6.71it/s]

{'R_grad_norm': 0.02646482759155333, 'training_loss': 0.15860639862716197}


 36%|███▌      | 3601/10000 [09:05<15:30,  6.87it/s]

{'R_grad_norm': 0.026456142300739883, 'training_loss': 0.1576370644569397}


 38%|███▊      | 3801/10000 [09:36<15:28,  6.67it/s]

{'R_grad_norm': 0.02664949604310095, 'training_loss': 0.15711097911000252}


 40%|████      | 4001/10000 [10:06<14:52,  6.72it/s]

{'R_grad_norm': 0.027193088522180914, 'training_loss': 0.15714989624917508}


 42%|████▏     | 4201/10000 [10:36<14:43,  6.56it/s]

{'R_grad_norm': 0.027494651898741722, 'training_loss': 0.1574119359254837}


 44%|████▍     | 4401/10000 [11:06<16:04,  5.81it/s]

{'R_grad_norm': 0.027812709053978323, 'training_loss': 0.15731367535889149}


 46%|████▌     | 4601/10000 [11:36<13:32,  6.65it/s]

{'R_grad_norm': 0.027572576133534313, 'training_loss': 0.15653642922639846}


 48%|████▊     | 4801/10000 [12:06<13:51,  6.25it/s]

{'R_grad_norm': 0.027629250809550285, 'training_loss': 0.15607360139489174}


 50%|█████     | 5001/10000 [12:37<12:33,  6.64it/s]

{'R_grad_norm': 0.027883620420470833, 'training_loss': 0.15647092431783677}


 52%|█████▏    | 5201/10000 [13:07<12:06,  6.61it/s]

{'R_grad_norm': 0.027792724948376417, 'training_loss': 0.15608837306499482}


 54%|█████▍    | 5401/10000 [13:36<11:34,  6.63it/s]

{'R_grad_norm': 0.027829403579235076, 'training_loss': 0.15690813809633256}


 56%|█████▌    | 5601/10000 [14:07<10:52,  6.75it/s]

{'R_grad_norm': 0.02798255373723805, 'training_loss': 0.15627516843378544}


 58%|█████▊    | 5801/10000 [14:37<10:34,  6.61it/s]

{'R_grad_norm': 0.028174521885812282, 'training_loss': 0.15733237087726593}


 60%|██████    | 6001/10000 [15:06<09:56,  6.70it/s]

{'R_grad_norm': 0.02822502374649048, 'training_loss': 0.1566131266951561}


 62%|██████▏   | 6201/10000 [15:36<09:31,  6.65it/s]

{'R_grad_norm': 0.028296121573075653, 'training_loss': 0.15644734218716622}


 64%|██████▍   | 6401/10000 [16:07<09:07,  6.57it/s]

{'R_grad_norm': 0.02848504704423249, 'training_loss': 0.15591424003243445}


 66%|██████▌   | 6601/10000 [16:37<08:22,  6.76it/s]

{'R_grad_norm': 0.028570834565907716, 'training_loss': 0.15587188348174094}


 68%|██████▊   | 6801/10000 [17:07<07:58,  6.69it/s]

{'R_grad_norm': 0.028748065372928976, 'training_loss': 0.15624612666666507}


 70%|███████   | 7001/10000 [17:36<07:27,  6.70it/s]

{'R_grad_norm': 0.028714062608778478, 'training_loss': 0.1569139563292265}


 72%|███████▏  | 7201/10000 [18:07<06:58,  6.69it/s]

{'R_grad_norm': 0.028902928363531828, 'training_loss': 0.15564521118998528}


 74%|███████▍  | 7401/10000 [18:37<06:38,  6.53it/s]

{'R_grad_norm': 0.028828633120283484, 'training_loss': 0.1559470321983099}


 76%|███████▌  | 7601/10000 [19:07<05:54,  6.77it/s]

{'R_grad_norm': 0.02843492946587503, 'training_loss': 0.1551356775313616}


 78%|███████▊  | 7801/10000 [19:37<05:23,  6.80it/s]

{'R_grad_norm': 0.0286838654614985, 'training_loss': 0.15559938617050648}


 80%|████████  | 8001/10000 [20:07<04:59,  6.68it/s]

{'R_grad_norm': 0.028862501410767436, 'training_loss': 0.15654790438711644}


 82%|████████▏ | 8201/10000 [20:36<04:22,  6.86it/s]

{'R_grad_norm': 0.028681256296113133, 'training_loss': 0.1556330118328333}


 84%|████████▍ | 8401/10000 [21:07<04:17,  6.21it/s]

{'R_grad_norm': 0.028757950542494654, 'training_loss': 0.1552976781129837}


 86%|████████▌ | 8601/10000 [21:36<03:29,  6.68it/s]

{'R_grad_norm': 0.028828074717894198, 'training_loss': 0.15589692667126656}


 88%|████████▊ | 8801/10000 [22:06<02:59,  6.67it/s]

{'R_grad_norm': 0.028613632936030628, 'training_loss': 0.15628701210021972}


 90%|█████████ | 9001/10000 [22:36<02:30,  6.66it/s]

{'R_grad_norm': 0.029039305401965976, 'training_loss': 0.1560700134932995}


 92%|█████████▏| 9201/10000 [23:06<01:58,  6.74it/s]

{'R_grad_norm': 0.028893548185005783, 'training_loss': 0.15624442487955092}


 94%|█████████▍| 9401/10000 [23:38<01:30,  6.64it/s]

{'R_grad_norm': 0.028971242448315025, 'training_loss': 0.15572859689593316}


 96%|█████████▌| 9601/10000 [24:08<01:01,  6.49it/s]

{'R_grad_norm': 0.028640836905688048, 'training_loss': 0.15494498386979103}


 98%|█████████▊| 9801/10000 [24:38<00:31,  6.31it/s]

{'R_grad_norm': 0.028956703571602703, 'training_loss': 0.15591946467757226}


100%|██████████| 10000/10000 [25:08<00:00,  6.63it/s]


{'R_grad_norm': 0.02886861658655107, 'training_loss': 0.15546441055834292}
finish training (10000)
saving.. out/subspace_partition/copy_transformer
evaluating (25 steps)...
 ******* eval result *******
mean (weighted) 0.155416339635849
mean (unweighted) 0.155416339635849
tensor([0.1416, 0.1692])
training for blocks.1.hook_resid_post


  2%|▏         | 201/10000 [00:30<26:03,  6.27it/s]

{'R_grad_norm': 0.26566193252801895, 'training_loss': 2.444055632352829}


  4%|▍         | 401/10000 [01:02<23:49,  6.71it/s]

{'R_grad_norm': 0.6239995001256466, 'training_loss': 2.2400707161426543}


  6%|▌         | 601/10000 [01:32<22:59,  6.81it/s]

{'R_grad_norm': 0.834159831404686, 'training_loss': 2.1093051338195803}


  8%|▊         | 801/10000 [02:01<22:54,  6.69it/s]

{'R_grad_norm': 0.8627157308161258, 'training_loss': 2.017840737104416}


 10%|█         | 1001/10000 [02:31<21:56,  6.83it/s]

{'R_grad_norm': 0.8362627279758453, 'training_loss': 1.9317820942401887}


 12%|█▏        | 1201/10000 [03:01<21:42,  6.76it/s]

{'R_grad_norm': 0.8365358445048332, 'training_loss': 1.8668744301795959}


 14%|█▍        | 1401/10000 [03:31<22:27,  6.38it/s]

{'R_grad_norm': 0.8600969704985618, 'training_loss': 1.8192048007249833}


 16%|█▌        | 1601/10000 [04:00<20:45,  6.74it/s]

{'R_grad_norm': 0.8534293836355209, 'training_loss': 1.7924591213464738}


 18%|█▊        | 1801/10000 [04:30<20:15,  6.75it/s]

{'R_grad_norm': 0.8637300617992878, 'training_loss': 1.7683659797906877}


 20%|██        | 2001/10000 [05:00<19:21,  6.89it/s]

{'R_grad_norm': 0.8601392975449562, 'training_loss': 1.758798560500145}


 22%|██▏       | 2201/10000 [05:30<18:58,  6.85it/s]

{'R_grad_norm': 0.8596263030171394, 'training_loss': 1.7596831250190734}


 24%|██▍       | 2401/10000 [05:59<19:01,  6.66it/s]

{'R_grad_norm': 0.8444304981827736, 'training_loss': 1.7464054030179978}


 26%|██▌       | 2601/10000 [06:29<18:18,  6.73it/s]

{'R_grad_norm': 0.8422423270344734, 'training_loss': 1.726412494778633}


 28%|██▊       | 2801/10000 [06:59<17:41,  6.78it/s]

{'R_grad_norm': 0.8521186527609825, 'training_loss': 1.7219601315259934}


 30%|███       | 3001/10000 [07:29<17:12,  6.78it/s]

{'R_grad_norm': 0.854252196252346, 'training_loss': 1.721098888516426}


 32%|███▏      | 3201/10000 [07:59<16:39,  6.80it/s]

{'R_grad_norm': 0.8446025756001473, 'training_loss': 1.7114272451400756}


 34%|███▍      | 3401/10000 [08:30<16:37,  6.61it/s]

{'R_grad_norm': 0.8677602583169937, 'training_loss': 1.7253992569446563}


 36%|███▌      | 3601/10000 [09:01<15:59,  6.67it/s]

{'R_grad_norm': 0.8549583247303962, 'training_loss': 1.7181349325180053}


 38%|███▊      | 3801/10000 [09:31<15:05,  6.85it/s]

{'R_grad_norm': 0.8478782349824905, 'training_loss': 1.7141839784383774}


 40%|████      | 4001/10000 [10:03<15:47,  6.33it/s]

{'R_grad_norm': 0.86450954079628, 'training_loss': 1.7186213010549545}


 42%|████▏     | 4201/10000 [10:38<15:11,  6.37it/s]

{'R_grad_norm': 0.8516459342837334, 'training_loss': 1.7248185151815414}


 44%|████▍     | 4401/10000 [11:12<15:35,  5.98it/s]

{'R_grad_norm': 0.8784329009056091, 'training_loss': 1.714373410344124}


 46%|████▌     | 4601/10000 [11:44<14:08,  6.36it/s]

{'R_grad_norm': 0.8725066423416138, 'training_loss': 1.727242425084114}


 48%|████▊     | 4801/10000 [12:16<12:48,  6.76it/s]

{'R_grad_norm': 0.8620118543505668, 'training_loss': 1.7109360086917877}


 50%|█████     | 5001/10000 [12:47<12:47,  6.51it/s]

{'R_grad_norm': 0.8790587383508682, 'training_loss': 1.7151882821321487}


 52%|█████▏    | 5201/10000 [13:18<12:02,  6.64it/s]

{'R_grad_norm': 0.8757707533240319, 'training_loss': 1.7180528968572617}


 54%|█████▍    | 5401/10000 [13:48<11:51,  6.46it/s]

{'R_grad_norm': 0.8704691034555435, 'training_loss': 1.7218947041034698}


 56%|█████▌    | 5601/10000 [14:19<11:25,  6.41it/s]

{'R_grad_norm': 0.8688487243652344, 'training_loss': 1.709959573149681}


 58%|█████▊    | 5801/10000 [14:50<10:39,  6.57it/s]

{'R_grad_norm': 0.852636870443821, 'training_loss': 1.7144217717647552}


 60%|██████    | 6001/10000 [15:21<10:11,  6.54it/s]

{'R_grad_norm': 0.8656411778926849, 'training_loss': 1.72062313914299}


 62%|██████▏   | 6201/10000 [15:52<10:18,  6.14it/s]

{'R_grad_norm': 0.8429254227876664, 'training_loss': 1.714916695356369}


 64%|██████▍   | 6400/10000 [16:23<10:05,  5.94it/s]

{'R_grad_norm': 0.8621246588230133, 'training_loss': 1.7147205489873887}


 66%|██████▌   | 6601/10000 [16:54<08:37,  6.57it/s]

{'R_grad_norm': 0.8727171936631203, 'training_loss': 1.720009589791298}


 68%|██████▊   | 6801/10000 [17:24<08:00,  6.66it/s]

{'R_grad_norm': 0.8764675840735435, 'training_loss': 1.7226540178060532}


 70%|███████   | 7001/10000 [17:57<08:02,  6.21it/s]

{'R_grad_norm': 0.872649516761303, 'training_loss': 1.7129084855318069}


 72%|███████▏  | 7201/10000 [18:32<08:21,  5.58it/s]

{'R_grad_norm': 0.8681100276112557, 'training_loss': 1.7200444024801254}


 74%|███████▍  | 7401/10000 [19:06<06:47,  6.37it/s]

{'R_grad_norm': 0.8686911013722419, 'training_loss': 1.7089224362373352}


 76%|███████▌  | 7601/10000 [19:40<06:01,  6.64it/s]

{'R_grad_norm': 0.8653464916348458, 'training_loss': 1.7151049596071244}


 78%|███████▊  | 7801/10000 [20:13<06:24,  5.72it/s]

{'R_grad_norm': 0.8702481263875961, 'training_loss': 1.7109094780683518}


 80%|████████  | 8001/10000 [20:46<05:12,  6.40it/s]

{'R_grad_norm': 0.8723378145694732, 'training_loss': 1.7147845125198364}


 82%|████████▏ | 8201/10000 [21:18<04:22,  6.84it/s]

{'R_grad_norm': 0.8705295652151108, 'training_loss': 1.7059847044944763}


 84%|████████▍ | 8401/10000 [21:49<04:07,  6.46it/s]

{'R_grad_norm': 0.8634156930446625, 'training_loss': 1.71678406894207}


 86%|████████▌ | 8600/10000 [22:19<03:31,  6.61it/s]

{'R_grad_norm': 0.8579341465234757, 'training_loss': 1.7180260282754898}


 88%|████████▊ | 8801/10000 [22:50<02:58,  6.71it/s]

{'R_grad_norm': 0.8634241539239883, 'training_loss': 1.7106476348638535}


 90%|█████████ | 9001/10000 [23:21<02:44,  6.09it/s]

{'R_grad_norm': 0.8777515268325806, 'training_loss': 1.722537910938263}


 92%|█████████▏| 9201/10000 [23:55<02:10,  6.12it/s]

{'R_grad_norm': 0.8811356291174889, 'training_loss': 1.7170611697435378}


 94%|█████████▍| 9401/10000 [24:27<01:31,  6.54it/s]

{'R_grad_norm': 0.8681453493237495, 'training_loss': 1.7119470429420471}


 96%|█████████▌| 9601/10000 [24:57<00:59,  6.70it/s]

{'R_grad_norm': 0.8883127182722091, 'training_loss': 1.7186156529188157}


 98%|█████████▊| 9801/10000 [25:28<00:29,  6.67it/s]

{'R_grad_norm': 0.8799321511387825, 'training_loss': 1.7148295032978058}


100%|██████████| 10000/10000 [25:58<00:00,  6.42it/s]


{'R_grad_norm': 0.8736403515934944, 'training_loss': 1.711547617316246}
finish training (10000)
saving.. out/subspace_partition/copy_transformer
evaluating (25 steps)...
 ******* eval result *******
mean (weighted) 1.7082897424697876
mean (unweighted) 1.7082897424697876
tensor([1.4647, 1.9519])
