# Training

In [1]:
# wandb MA-sae-train: 03ca085d27243e9a7876d511e0c7402d861df34c

In [1]:
from sae_lens import (
    SAETrainingRunner, LanguageModelSAERunnerConfig,
    upload_saes_to_huggingface
)

In [2]:
# ——— Toggle TEST_MODE ——————————————————————————————————————————————————
# If True → tiny run, no logging, no uploads. If False → full run + HuggingFace upload.
# Full run should be done via slurm job!
TEST_MODE = False

In [3]:
# —— Schedule & batch sizing —————————————————————————————————————
if TEST_MODE:
    total_steps        = 10
    batch_size_tokens  = 512
    dataset_path       = "karpathy/tiny_shakespeare"
    log_to_wandb       = False
    wandb_project      = None
    n_checkpoints      = 0
    checkpoint_path    = None

    # disable all precision tricks to avoid dtype mismatches
    use_autocast       = False
    use_autocast_lm    = False
    use_compile_sae    = False
    use_compile_llm    = False
    sae_dtype          = "float32"
else:
    total_steps        = 75_000
    batch_size_tokens  = 4_096
    # dataset_path       = "wikitext/wikitext-103-raw-v1"
    dataset_path  = "apollo-research/monology-pile-uncopyrighted-tokenizer-EleutherAI-gpt-neox-20b"
    log_to_wandb       = True
    wandb_project      = "MA-sae-train"
    n_checkpoints      = 5
    checkpoint_path    = "./checkpoints"

    use_autocast       = True
    use_autocast_lm    = True
    use_compile_sae    = True
    use_compile_llm    = True
    sae_dtype          = "float16"  # or "bfloat16" if you want to match the model

In [5]:
total_training_tokens = total_steps * batch_size_tokens
lr_warmup_steps       = int(0.10 * total_steps)
lr_decay_steps        = total_steps
l1_warmup_steps       = int(0.20 * total_steps)

1) Using `wikitext-103` for the full run: is its domain coverage sufficient,
   or would a subset of the Pile/OpenWebText better capture diverse contexts? - 
--> !Check
2) Expansion factor=16 on small data can kill features. For a quick run,
   we might try 8 or even 4 to see denser utilization.
3) dtype="float16" + compile_sae can speed up, but have you compared
   reconstruction quality vs float32?

In [6]:
cfg = LanguageModelSAERunnerConfig(
    # — data & model hooks —
    model_name                    = "EleutherAI/pythia-410m-deduped",
    hook_name                     = "blocks.4.hook_mlp_out",
    hook_layer                    = 4,
    d_in                          = 1024,
    dataset_path                  = dataset_path,
    is_dataset_tokenized          = True,#False,
    streaming                     = True,

    # — SAE architecture & sparsity —
    architecture                  = "standard",
    expansion_factor              = 16,
    l1_coefficient                = 2.0,
    l1_warm_up_steps              = l1_warmup_steps,
    normalize_activations         = "expected_average_only_in",
    mse_loss_normalization        = "layer",

    # — init & symmetry —
    b_dec_init_method             = "zeros",
    init_encoder_as_decoder_transpose = True,
    decoder_heuristic_init        = False,

    # — optimization & scheduling —
    lr                            = 5e-5,
    adam_beta1                    = 0.9,
    adam_beta2                    = 0.999,
    lr_scheduler_name             = "cosineannealing",
    lr_warm_up_steps              = lr_warmup_steps,
    lr_decay_steps                = lr_decay_steps,

    # — context & batch sizing —
    context_size                  = 2048, #512,
    train_batch_size_tokens       = batch_size_tokens,

    # — logging, checkpoints & precision —
    training_tokens               = total_training_tokens,
    feature_sampling_window       = 1_000,
    log_to_wandb                  = log_to_wandb,
    wandb_project                 = wandb_project,
    wandb_log_frequency           = 100,
    n_checkpoints                 = n_checkpoints,
    checkpoint_path               = checkpoint_path,
    compile_sae                   = use_compile_sae,
    compile_llm                   = use_compile_llm,
    autocast                      = use_autocast,
    autocast_lm                   = use_autocast_lm,
    device                        = "cuda:0",
    seed                          = 42,
    dtype                         = sae_dtype,
)

In [7]:
print(f"{'TEST_MODE' if TEST_MODE else 'FULL_RUN'} ➞ steps={total_steps}, dataset={dataset_path}")

FULL_RUN ➞ steps=75000, dataset=apollo-research/monology-pile-uncopyrighted-tokenizer-EleutherAI-gpt-neox-20b


In [8]:
# — run training —
sparse_autoencoder = SAETrainingRunner(cfg).run()

Loaded pretrained model EleutherAI/pythia-410m-deduped into HookedTransformer


Downloading readme:   0%|          | 0.00/303 [00:00<?, ?B/s]

Resolving data files:   0%|          | 0/338 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/338 [00:00<?, ?it/s]

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mtilmankerl[0m ([33mtilmankerl-technical-university-of-vienna[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Training SAE:   0%|          | 0/307200000 [00:00<?, ?it/s]
Estimating norm scaling factor:   0%|          | 0/1000 [00:00<?, ?it/s][A

Refilling buffer:   0%|          | 0/10 [00:00<?, ?it/s][A[A

Refilling buffer:  10%|█         | 1/10 [00:12<01:49, 12.18s/it][A[A

Refilling buffer:  20%|██        | 2/10 [00:12<00:43,  5.45s/it][A[A

Refilling buffer:  30%|███       | 3/10 [00:13<00:23,  3.30s/it][A[A

Refilling buffer:  40%|████      | 4/10 [00:14<00:13,  2.29s/it][A[A

Refilling buffer:  50%|█████     | 5/10 [00:15<00:08,  1.74s/it][A[A

Refilling buffer:  60%|██████    | 6/10 [00:15<00:05,  1.41s/it][A[A

Refilling buffer:  70%|███████   | 7/10 [00:16<00:03,  1.20s/it][A[A

Refilling buffer:  80%|████████  | 8/10 [00:17<00:02,  1.05s/it][A[A

Refilling buffer:  90%|█████████ | 9/10 [00:18<00:00,  1.05it/s][A[A

Refilling buffer: 100%|██████████| 10/10 [00:18<00:00,  1.12it/s][A[A

                                                                 [A[A

Refil

OutOfMemoryError: CUDA out of memory. Tried to allocate 8.00 GiB. GPU 0 has a total capacity of 39.50 GiB of which 6.28 GiB is free. Including non-PyTorch memory, this process has 33.21 GiB memory in use. Of the allocated memory 27.81 GiB is allocated by PyTorch, and 4.90 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)