# Training

In [1]:
# wandb MA-sae-train: 03ca085d27243e9a7876d511e0c7402d861df34c

In [14]:
from sae_lens import (
    SAETrainingRunner, LanguageModelSAERunnerConfig,
    upload_saes_to_huggingface
)

In [3]:
total_steps        = 75_000
batch_size_tokens  = 4_096
# dataset_path       = "wikitext/wikitext-103-raw-v1"
dataset_path  = "apollo-research/monology-pile-uncopyrighted-tokenizer-EleutherAI-gpt-neox-20b"
log_to_wandb       = True
wandb_project      = "MA-sae-train"
n_checkpoints      = 5
checkpoint_path    = "./checkpoints"

use_autocast       = True
use_autocast_lm    = True
use_compile_sae    = True
use_compile_llm    = True
sae_dtype          = "float16"  # or "bfloat16" if you want to match the model

In [5]:
total_training_tokens = total_steps * batch_size_tokens
lr_warmup_steps       = int(0.10 * total_steps)
lr_decay_steps        = total_steps
l1_warmup_steps       = int(0.20 * total_steps)

1) Using `wikitext-103` for the full run: is its domain coverage sufficient,
   or would a subset of the Pile/OpenWebText better capture diverse contexts? - 
--> !Check
2) Expansion factor=16 on small data can kill features. For a quick run,
   we might try 8 or even 4 to see denser utilization.
3) dtype="float16" + compile_sae can speed up, but have you compared
   reconstruction quality vs float32?

In [6]:
cfg = LanguageModelSAERunnerConfig(
    # — data & model hooks —
    model_name                    = "EleutherAI/pythia-410m-deduped",
    hook_name                     = "blocks.4.hook_mlp_out",
    hook_layer                    = 4,
    d_in                          = 1024,
    dataset_path                  = dataset_path,
    is_dataset_tokenized          = True,#False,
    streaming                     = True,

    # — SAE architecture & sparsity —
    architecture                  = "standard",
    expansion_factor              = 16,
    l1_coefficient                = 2.0,
    l1_warm_up_steps              = l1_warmup_steps,
    normalize_activations         = "expected_average_only_in",
    mse_loss_normalization        = "layer",

    # — init & symmetry —
    b_dec_init_method             = "zeros",
    init_encoder_as_decoder_transpose = True,
    decoder_heuristic_init        = False,

    # — optimization & scheduling —
    lr                            = 5e-5,
    adam_beta1                    = 0.9,
    adam_beta2                    = 0.999,
    lr_scheduler_name             = "cosineannealing",
    lr_warm_up_steps              = lr_warmup_steps,
    lr_decay_steps                = lr_decay_steps,

    # — context & batch sizing —
    context_size                  = 2048, #512,
    train_batch_size_tokens       = batch_size_tokens,

    # — logging, checkpoints & precision —
    training_tokens               = total_training_tokens,
    feature_sampling_window       = 1_000,
    log_to_wandb                  = log_to_wandb,
    wandb_project                 = wandb_project,
    wandb_log_frequency           = 100,
    n_checkpoints                 = n_checkpoints,
    checkpoint_path               = checkpoint_path,
    compile_sae                   = use_compile_sae,
    compile_llm                   = use_compile_llm,
    autocast                      = use_autocast,
    autocast_lm                   = use_autocast_lm,
    device                        = "cuda:0",
    seed                          = 42,
    dtype                         = sae_dtype,
)

In [7]:
print(f"{'TEST_MODE' if TEST_MODE else 'FULL_RUN'} ➞ steps={total_steps}, dataset={dataset_path}")

FULL_RUN ➞ steps=75000, dataset=apollo-research/monology-pile-uncopyrighted-tokenizer-EleutherAI-gpt-neox-20b


In [8]:
# — run training —
sparse_autoencoder = SAETrainingRunner(cfg).run()

Loaded pretrained model EleutherAI/pythia-410m-deduped into HookedTransformer


Downloading readme:   0%|          | 0.00/303 [00:00<?, ?B/s]

Resolving data files:   0%|          | 0/338 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/338 [00:00<?, ?it/s]

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mtilmankerl[0m ([33mtilmankerl-technical-university-of-vienna[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Training SAE:   0%|          | 0/307200000 [00:00<?, ?it/s]
Estimating norm scaling factor:   0%|          | 0/1000 [00:00<?, ?it/s][A

Refilling buffer:   0%|          | 0/10 [00:00<?, ?it/s][A[A

Refilling buffer:  10%|█         | 1/10 [00:12<01:49, 12.18s/it][A[A

Refilling buffer:  20%|██        | 2/10 [00:12<00:43,  5.45s/it][A[A

Refilling buffer:  30%|███       | 3/10 [00:13<00:23,  3.30s/it][A[A

Refilling buffer:  40%|████      | 4/10 [00:14<00:13,  2.29s/it][A[A

Refilling buffer:  50%|█████     | 5/10 [00:15<00:08,  1.74s/it][A[A

Refilling buffer:  60%|██████    | 6/10 [00:15<00:05,  1.41s/it][A[A

Refilling buffer:  70%|███████   | 7/10 [00:16<00:03,  1.20s/it][A[A

Refilling buffer:  80%|████████  | 8/10 [00:17<00:02,  1.05s/it][A[A

Refilling buffer:  90%|█████████ | 9/10 [00:18<00:00,  1.05it/s][A[A

Refilling buffer: 100%|██████████| 10/10 [00:18<00:00,  1.12it/s][A[A

                                                                 [A[A

Refil

OutOfMemoryError: CUDA out of memory. Tried to allocate 8.00 GiB. GPU 0 has a total capacity of 39.50 GiB of which 6.28 GiB is free. Including non-PyTorch memory, this process has 33.21 GiB memory in use. Of the allocated memory 27.81 GiB is allocated by PyTorch, and 4.90 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [17]:
import os
import torch
import torch.distributed as dist

from datasets import load_dataset
from sae_lens import (
    SAETrainingRunner,
    LanguageModelSAERunnerConfig,
    upload_saes_to_huggingface
)

from sae_lens.saes import TopKTrainingSAEConfig
from sae_lens.config import LoggingConfig

# — config parameters —
# projected steps for convergence (gao et al.)
TOTAL_STEPS           = 10 # try different strategies
GLOBAL_BATCH_TOKENS   = 512
TRAIN_TOKENS         = TOTAL_STEPS * GLOBAL_BATCH_TOKENS
LR_WARMUP            = int(0.05 * TOTAL_STEPS)

sae_config = TopKTrainingSAEConfig(
    k                 = 1_024,      # active latents        
    d_sae               = 512 * 8, # # latent dimension (expansion_factor × d_in)
    d_in = 512,
    # expansion_factor  = 8,
)  # no l1_

cfg = LanguageModelSAERunnerConfig(
    # model & hook
    model_name = "HuggingFaceTB/SmolLM2-135M",
    hook_name = "blocks.7.hook_resid_post",
    hook_layer = 7,
                    
    # init_encoder_as_decoder_transpose = True,
    context_size      = 64,

    sae = sae_config,

    # data
    dataset_path = "MisterXY89/lmsys-chat-1m-english-tokenized-smollm-135M",
    is_dataset_tokenized = True,
    streaming = True,

    lr                    = 1e-4,
    adam_beta1            = 0.9,
    adam_beta2            = 0.999,
    lr_scheduler_name     = "cosineannealing",
    lr_warm_up_steps      = LR_WARMUP,
    lr_decay_steps        = TOTAL_STEPS,

    train_batch_size_tokens = GLOBAL_BATCH_TOKENS,
    training_tokens       = TRAIN_TOKENS,


    logger = LoggingConfig(
        log_to_wandb          = True,
        wandb_project         = "SAE-SmolLM-135M",
        wandb_log_frequency   = 100,
    ),
    n_checkpoints         = 5,
    checkpoint_path       = "./checkpoints",

    compile_sae           = False,
    compile_llm           = False,
    autocast              = True,
    autocast_lm           = True,
    dtype                 = "float32",

    device                = f"cuda:0",
    seed                  = 42,

    # use_sparse_kernels    = True,
    # ema_decay             = 0.999,
)
# — data loader —
train_ds = load_dataset(
    cfg.dataset_path,
    split="train",
    streaming=True,
    trust_remote_code=True,
)


sae_model = SAETrainingRunner(cfg, override_dataset=train_ds).run()



Downloading readme:   0%|          | 0.00/296 [00:00<?, ?B/s]

Resolving data files:   0%|          | 0/64 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/64 [00:00<?, ?it/s]

You just passed in a dataset which will override the one specified in your configuration: MisterXY89/lmsys-chat-1m-english-tokenized-smollm-135M. As a consequence this run will not be reproducible via configuration alone.


ValueError: HuggingFaceTB/SmolLM2-135M not found. Valid official model names (excl aliases): ['gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl', 'distilgpt2', 'facebook/opt-125m', 'facebook/opt-1.3b', 'facebook/opt-2.7b', 'facebook/opt-6.7b', 'facebook/opt-13b', 'facebook/opt-30b', 'facebook/opt-66b', 'EleutherAI/gpt-neo-125M', 'EleutherAI/gpt-neo-1.3B', 'EleutherAI/gpt-neo-2.7B', 'EleutherAI/gpt-j-6B', 'EleutherAI/gpt-neox-20b', 'stanford-crfm/alias-gpt2-small-x21', 'stanford-crfm/battlestar-gpt2-small-x49', 'stanford-crfm/caprica-gpt2-small-x81', 'stanford-crfm/darkmatter-gpt2-small-x343', 'stanford-crfm/expanse-gpt2-small-x777', 'stanford-crfm/arwen-gpt2-medium-x21', 'stanford-crfm/beren-gpt2-medium-x49', 'stanford-crfm/celebrimbor-gpt2-medium-x81', 'stanford-crfm/durin-gpt2-medium-x343', 'stanford-crfm/eowyn-gpt2-medium-x777', 'EleutherAI/pythia-14m', 'EleutherAI/pythia-31m', 'EleutherAI/pythia-70m', 'EleutherAI/pythia-160m', 'EleutherAI/pythia-410m', 'EleutherAI/pythia-1b', 'EleutherAI/pythia-1.4b', 'EleutherAI/pythia-2.8b', 'EleutherAI/pythia-6.9b', 'EleutherAI/pythia-12b', 'EleutherAI/pythia-70m-deduped', 'EleutherAI/pythia-160m-deduped', 'EleutherAI/pythia-410m-deduped', 'EleutherAI/pythia-1b-deduped', 'EleutherAI/pythia-1.4b-deduped', 'EleutherAI/pythia-2.8b-deduped', 'EleutherAI/pythia-6.9b-deduped', 'EleutherAI/pythia-12b-deduped', 'EleutherAI/pythia-70m-v0', 'EleutherAI/pythia-160m-v0', 'EleutherAI/pythia-410m-v0', 'EleutherAI/pythia-1b-v0', 'EleutherAI/pythia-1.4b-v0', 'EleutherAI/pythia-2.8b-v0', 'EleutherAI/pythia-6.9b-v0', 'EleutherAI/pythia-12b-v0', 'EleutherAI/pythia-70m-deduped-v0', 'EleutherAI/pythia-160m-deduped-v0', 'EleutherAI/pythia-410m-deduped-v0', 'EleutherAI/pythia-1b-deduped-v0', 'EleutherAI/pythia-1.4b-deduped-v0', 'EleutherAI/pythia-2.8b-deduped-v0', 'EleutherAI/pythia-6.9b-deduped-v0', 'EleutherAI/pythia-12b-deduped-v0', 'EleutherAI/pythia-160m-seed1', 'EleutherAI/pythia-160m-seed2', 'EleutherAI/pythia-160m-seed3', 'NeelNanda/SoLU_1L_v9_old', 'NeelNanda/SoLU_2L_v10_old', 'NeelNanda/SoLU_4L_v11_old', 'NeelNanda/SoLU_6L_v13_old', 'NeelNanda/SoLU_8L_v21_old', 'NeelNanda/SoLU_10L_v22_old', 'NeelNanda/SoLU_12L_v23_old', 'NeelNanda/SoLU_1L512W_C4_Code', 'NeelNanda/SoLU_2L512W_C4_Code', 'NeelNanda/SoLU_3L512W_C4_Code', 'NeelNanda/SoLU_4L512W_C4_Code', 'NeelNanda/SoLU_6L768W_C4_Code', 'NeelNanda/SoLU_8L1024W_C4_Code', 'NeelNanda/SoLU_10L1280W_C4_Code', 'NeelNanda/SoLU_12L1536W_C4_Code', 'NeelNanda/GELU_1L512W_C4_Code', 'NeelNanda/GELU_2L512W_C4_Code', 'NeelNanda/GELU_3L512W_C4_Code', 'NeelNanda/GELU_4L512W_C4_Code', 'NeelNanda/Attn_Only_1L512W_C4_Code', 'NeelNanda/Attn_Only_2L512W_C4_Code', 'NeelNanda/Attn_Only_3L512W_C4_Code', 'NeelNanda/Attn_Only_4L512W_C4_Code', 'NeelNanda/Attn-Only-2L512W-Shortformer-6B-big-lr', 'NeelNanda/SoLU_1L512W_Wiki_Finetune', 'NeelNanda/SoLU_4L512W_Wiki_Finetune', 'ArthurConmy/redwood_attn_2l', 'llama-7b-hf', 'llama-13b-hf', 'llama-30b-hf', 'llama-65b-hf', 'meta-llama/Llama-2-7b-hf', 'meta-llama/Llama-2-7b-chat-hf', 'meta-llama/Llama-2-13b-hf', 'meta-llama/Llama-2-13b-chat-hf', 'meta-llama/Llama-2-70b-chat-hf', 'codellama/CodeLlama-7b-hf', 'codellama/CodeLlama-7b-Python-hf', 'codellama/CodeLlama-7b-Instruct-hf', 'meta-llama/Meta-Llama-3-8B', 'meta-llama/Meta-Llama-3-8B-Instruct', 'meta-llama/Meta-Llama-3-70B', 'meta-llama/Meta-Llama-3-70B-Instruct', 'meta-llama/Llama-3.1-70B', 'meta-llama/Llama-3.1-8B', 'meta-llama/Llama-3.1-8B-Instruct', 'meta-llama/Llama-3.1-70B-Instruct', 'meta-llama/Llama-3.2-1B', 'meta-llama/Llama-3.2-3B', 'meta-llama/Llama-3.2-1B-Instruct', 'meta-llama/Llama-3.2-3B-Instruct', 'meta-llama/Llama-3.3-70B-Instruct', 'Baidicoot/Othello-GPT-Transformer-Lens', 'google-bert/bert-base-cased', 'google-bert/bert-base-uncased', 'google-bert/bert-large-cased', 'google-bert/bert-large-uncased', 'roneneldan/TinyStories-1M', 'roneneldan/TinyStories-3M', 'roneneldan/TinyStories-8M', 'roneneldan/TinyStories-28M', 'roneneldan/TinyStories-33M', 'roneneldan/TinyStories-Instruct-1M', 'roneneldan/TinyStories-Instruct-3M', 'roneneldan/TinyStories-Instruct-8M', 'roneneldan/TinyStories-Instruct-28M', 'roneneldan/TinyStories-Instruct-33M', 'roneneldan/TinyStories-1Layer-21M', 'roneneldan/TinyStories-2Layers-33M', 'roneneldan/TinyStories-Instuct-1Layer-21M', 'roneneldan/TinyStories-Instruct-2Layers-33M', 'stabilityai/stablelm-base-alpha-3b', 'stabilityai/stablelm-base-alpha-7b', 'stabilityai/stablelm-tuned-alpha-3b', 'stabilityai/stablelm-tuned-alpha-7b', 'mistralai/Mistral-7B-v0.1', 'mistralai/Mistral-7B-Instruct-v0.1', 'mistralai/Mistral-Small-24B-Base-2501', 'mistralai/Mistral-Nemo-Base-2407', 'mistralai/Mixtral-8x7B-v0.1', 'mistralai/Mixtral-8x7B-Instruct-v0.1', 'bigscience/bloom-560m', 'bigscience/bloom-1b1', 'bigscience/bloom-1b7', 'bigscience/bloom-3b', 'bigscience/bloom-7b1', 'bigcode/santacoder', 'Qwen/Qwen-1_8B', 'Qwen/Qwen-7B', 'Qwen/Qwen-14B', 'Qwen/Qwen-1_8B-Chat', 'Qwen/Qwen-7B-Chat', 'Qwen/Qwen-14B-Chat', 'Qwen/Qwen1.5-0.5B', 'Qwen/Qwen1.5-0.5B-Chat', 'Qwen/Qwen1.5-1.8B', 'Qwen/Qwen1.5-1.8B-Chat', 'Qwen/Qwen1.5-4B', 'Qwen/Qwen1.5-4B-Chat', 'Qwen/Qwen1.5-7B', 'Qwen/Qwen1.5-7B-Chat', 'Qwen/Qwen1.5-14B', 'Qwen/Qwen1.5-14B-Chat', 'Qwen/Qwen2-0.5B', 'Qwen/Qwen2-0.5B-Instruct', 'Qwen/Qwen2-1.5B', 'Qwen/Qwen2-1.5B-Instruct', 'Qwen/Qwen2-7B', 'Qwen/Qwen2-7B-Instruct', 'Qwen/Qwen2.5-0.5B', 'Qwen/Qwen2.5-0.5B-Instruct', 'Qwen/Qwen2.5-1.5B', 'Qwen/Qwen2.5-1.5B-Instruct', 'Qwen/Qwen2.5-3B', 'Qwen/Qwen2.5-3B-Instruct', 'Qwen/Qwen2.5-7B', 'Qwen/Qwen2.5-7B-Instruct', 'Qwen/Qwen2.5-14B', 'Qwen/Qwen2.5-14B-Instruct', 'Qwen/Qwen2.5-32B', 'Qwen/Qwen2.5-32B-Instruct', 'Qwen/Qwen2.5-72B', 'Qwen/Qwen2.5-72B-Instruct', 'Qwen/QwQ-32B-Preview', 'microsoft/phi-1', 'microsoft/phi-1_5', 'microsoft/phi-2', 'microsoft/Phi-3-mini-4k-instruct', 'microsoft/phi-4', 'google/gemma-2b', 'google/gemma-7b', 'google/gemma-2b-it', 'google/gemma-7b-it', 'google/gemma-2-2b', 'google/gemma-2-2b-it', 'google/gemma-2-9b', 'google/gemma-2-9b-it', 'google/gemma-2-27b', 'google/gemma-2-27b-it', '01-ai/Yi-6B', '01-ai/Yi-34B', '01-ai/Yi-6B-Chat', '01-ai/Yi-34B-Chat', 'google-t5/t5-small', 'google-t5/t5-base', 'google-t5/t5-large', 'ai-forever/mGPT']