# Training a basic SAE with SAELens

This notebook is derived from the SAELens [training tutorial](https://github.com/jbloomAus/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb). I ran it on an AWS g6.12xlarge instance, and training took about 11 hours.

To use SAELens, you need to pick a model that's supported by [TransformerLens](https://transformerlensorg.github.io/TransformerLens/generated/model_properties_table.html).
I wanted a model that didn't yet have a [published SAE](https://jbloomaus.github.io/SAELens/sae_table/), so I selected `Qwen/Qwen2.5-1.5B-Instruct`.

As part of the training process, you can pick any pretraining dataset from [Hugging Face](https://huggingface.co/datasets). The dataset does not need to contain any of your own data.

In [1]:
from IPython import get_ipython  # type: ignore

ipython = get_ipython()
assert ipython is not None
ipython.run_line_magic("load_ext", "autoreload")
ipython.run_line_magic("autoreload", "2")

In [2]:
import torch
import os

from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner

if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print("Using device:", device)
os.environ["TOKENIZERS_PARALLELISM"] = "false"

Using device: cuda


The next block comes directly from the SAELens tutorial, except for specifying the model name and disabling logging to [WandB](https://wandb.ai/site).

In [3]:
total_training_steps = 30000  # probably we should do more
batch_size = 4096
total_training_tokens = total_training_steps * batch_size

lr_warm_up_steps = 0
lr_decay_steps = total_training_steps // 5  # 20% of training
l1_warm_up_steps = total_training_steps // 20  # 5% of training

cfg = LanguageModelSAERunnerConfig(
    # Data Generating Function (Model + Training Distribution)
    model_name="Qwen/Qwen2.5-1.5B-Instruct",  # our model (more options here: https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html)
    hook_name="blocks.0.hook_mlp_out",  # A valid hook point (see more details here: https://neelnanda-io.github.io/TransformerLens/generated/demos/Main_Demo.html#Hook-Points)
    hook_layer=20,  # layer index of the hook point (the model has 28 blocks); normally this matches the block number in hook_name.
    d_in=1536,  # the width of the mlp output.
    dataset_path="Skylion007/openwebtext",  # the OpenWebText corpus on Hugging Face, stored as raw (untokenized) text.
    is_dataset_tokenized=False,
    streaming=True,  # we could pre-download the token dataset if it was small.
    # SAE Parameters
    mse_loss_normalization=None,  # We won't normalize the mse loss,
    expansion_factor=16,  # the width of the SAE. Larger will result in better stats but slower training.
    b_dec_init_method="zeros",  # start the decoder bias at zero (the geometric median of the activations is an alternative).
    apply_b_dec_to_input=False,  # We won't subtract the decoder bias from the input.
    normalize_sae_decoder=False,
    scale_sparsity_penalty_by_decoder_norm=True,
    decoder_heuristic_init=True,
    init_encoder_as_decoder_transpose=True,
    normalize_activations="expected_average_only_in",
    # Training Parameters
    lr=5e-5,  # lower the better, we'll go fairly high to speed up the tutorial.
    adam_beta1=0.9,  # adam params (default, but once upon a time we experimented with these.)
    adam_beta2=0.999,
    lr_scheduler_name="constant",  # constant learning rate with warmup. Could be better schedules out there.
    lr_warm_up_steps=lr_warm_up_steps,  # this can help avoid too many dead features initially.
    lr_decay_steps=lr_decay_steps,  # this will help us avoid overfitting.
    l1_coefficient=5,  # will control how sparse the feature activations are
    l1_warm_up_steps=l1_warm_up_steps,  # this can help avoid too many dead features initially.
    lp_norm=1.0,  # the L1 penalty (and not a Lp for p < 1)
    train_batch_size_tokens=batch_size,
    context_size=512,  # controls the length of the prompts we feed to the model. Larger is better but slower.
    # Activation Store Parameters
    n_batches_in_buffer=64,  # controls how many activations we store / shuffle.
    training_tokens=total_training_tokens,  # ~123 million tokens (30,000 steps x 4,096 tokens) is quite a few, but we want to see good stats. Get a coffee, come back.
    store_batch_size_prompts=16,
    # Resampling protocol
    use_ghost_grads=False,  # we don't use ghost grads anymore.
    feature_sampling_window=1000,  # this controls our reporting of feature sparsity stats
    dead_feature_window=1000,  # would affect resampling or ghost grads if we were using them.
    dead_feature_threshold=1e-4,  # would affect resampling or ghost grads if we were using them.
    # WANDB
    log_to_wandb=False,  # always use wandb unless you are just testing code.
    wandb_project="sae_lens_tutorial",
    wandb_log_frequency=30,
    eval_every_n_wandb_logs=20,
    # Misc
    device=device,
    act_store_device='cpu',
    seed=42,
    n_checkpoints=0,
    checkpoint_path="checkpoints",
    dtype="float32",
)


Run name: 24576-L1-5-LR-5e-05-Tokens-1.229e+08
n_tokens_per_buffer (millions): 0.524288
Lower bound: n_contexts_per_buffer (millions): 0.001024
Total training steps: 30000
Total wandb updates: 1000
n_tokens_per_feature_sampling_window (millions): 2097.152
n_tokens_per_dead_feature_window (millions): 2097.152
We will reset the sparsity calculation 30 times.
Number tokens in sparsity calculation window: 4.10e+06
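
Most of the numbers in this log can be reproduced from the config values with simple arithmetic, which is a useful sanity check before committing to a long run. The sketch below is plain Python and does not call SAELens. The variable names mirror the config, and the formulas are my reading of the log output rather than the library's own code.

```python
# Values copied from the config cell above.
total_training_steps = 30_000
batch_size = 4096                # train_batch_size_tokens
d_in = 1536
expansion_factor = 16
context_size = 512
n_batches_in_buffer = 64
store_batch_size_prompts = 16
wandb_log_frequency = 30
feature_sampling_window = 1000

# Quantities derived from them.
total_training_tokens = total_training_steps * batch_size
d_sae = d_in * expansion_factor                      # SAE width
lr_decay_steps = total_training_steps // 5           # last 20% of training
l1_warm_up_steps = total_training_steps // 20        # first 5% of training
contexts_per_buffer = n_batches_in_buffer * store_batch_size_prompts
tokens_per_buffer = contexts_per_buffer * context_size
wandb_updates = total_training_steps // wandb_log_frequency
sparsity_resets = total_training_steps // feature_sampling_window
tokens_per_sparsity_window = feature_sampling_window * batch_size

print(f"SAE width: {d_sae}")
print(f"Total training tokens: {total_training_tokens:.3e}")
print(f"LR decay steps: {lr_decay_steps}, L1 warm-up steps: {l1_warm_up_steps}")
print(f"n_tokens_per_buffer (millions): {tokens_per_buffer / 1e6}")
print(f"n_contexts_per_buffer (millions): {contexts_per_buffer / 1e6}")
print(f"Total wandb updates: {wandb_updates}")
print(f"Sparsity calculation resets: {sparsity_resets}")
print(f"Tokens in sparsity calculation window: {tokens_per_sparsity_window:.2e}")
```

The SAE width (24576) and the token count (1.229e+08) are the two values that appear in the run name.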


In [4]:
# look at the next cell to see some instruction for what to do while this is running.
sparse_autoencoder = SAETrainingRunner(cfg).run()

Loaded pretrained model Qwen/Qwen2.5-1.5B-Instruct into HookedTransformer


Training SAE:   0%|                                                                                                            | 0/122880000 [00:00<?, ?it/s]
Estimating norm scaling factor:   0%|                                                                                               | 0/1000 [00:00<?, ?it/s]
Estimating norm scaling factor:   0%|                                                                                   | 1/1000 [02:10<36:04:41, 130.01s/it]
Estimating norm scaling factor:   1%|▊                                                                                   | 10/1000 [02:10<2:34:59,  9.39s/it]
Estimating norm scaling factor:   2%|█▋                                                                                  | 20/1000 [02:10<1:02:09,  3.81s/it]
Estimating norm scaling factor:   3%|██▌                                                                                   | 30/1000 [02:10<33:19,  2.06s/it]

Estimating norm scaling factor:  37%|███████████████████████████████▎                                                     | 369/1000 [07:44<04:10,  2.52it/s]
Estimating norm scaling factor:  38%|████████████████████████████████▏                                                    | 379/1000 [07:44<02:51,  3.63it/s]
Estimating norm scaling factor:  38%|████████████████████████████████▏                                                    | 379/1000 [07:59<02:51,  3.63it/s]
Estimating norm scaling factor:  38%|████████████████████████████████▋                                                    | 385/1000 [08:49<25:44,  2.51s/it]
Estimating norm scaling factor:  39%|█████████████████████████████████▍                                                   | 394/1000 [08:49<17:27,  1.73s/it]
Estimating norm scaling factor:  40%|██████████████████████████████████▎                                                  | 404/1000 [08:49<11:30,  1.16s/it]

Estimating norm scaling factor:  76%|████████████████████████████████████████████████████████████████▊                    | 762/1000 [14:39<01:06,  3.59it/s]
Estimating norm scaling factor:  77%|█████████████████████████████████████████████████████████████████▎                   | 769/1000 [15:28<09:33,  2.48s/it]
Estimating norm scaling factor:  78%|██████████████████████████████████████████████████████████████████▏                  | 778/1000 [15:28<06:23,  1.73s/it]
Estimating norm scaling factor:  79%|██████████████████████████████████████████████████████████████████▉                  | 788/1000 [15:28<04:07,  1.17s/it]
Estimating norm scaling factor:  80%|███████████████████████████████████████████████████████████████████▊                 | 798/1000 [15:29<02:41,  1.25it/s]
Estimating norm scaling factor:  81%|████████████████████████████████████████████████████████████████████▋                | 808/1000 [15:29<01:46,  1.80it/s]
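
The `Estimating norm scaling factor` phase comes from `normalize_activations="expected_average_only_in"`: before training starts, the runner samples 1,000 batches of activations and derives a constant that rescales the SAE inputs so their average L2 norm is roughly `sqrt(d_in)`. Below is a minimal sketch of that idea in plain Python, assuming the factor is `sqrt(d_in) / mean_norm`. The function name and the toy data are made up for illustration, and this is not SAELens's own code.

```python
import math
import random


def estimate_norm_scaling_factor(batches, d_in):
    """Return sqrt(d_in) divided by the mean L2 norm, averaged over batches."""
    batch_means = []
    for batch in batches:
        norms = [math.sqrt(sum(v * v for v in vec)) for vec in batch]
        batch_means.append(sum(norms) / len(norms))
    mean_norm = sum(batch_means) / len(batch_means)
    return math.sqrt(d_in) / mean_norm


# Toy stand-in for model activations: 10 batches of 32 vectors of width 8.
random.seed(0)
d_in = 8
batches = [
    [[random.gauss(0.0, 5.0) for _ in range(d_in)] for _ in range(32)]
    for _ in range(10)
]

factor = estimate_norm_scaling_factor(batches, d_in)

# After rescaling, the average norm should sit at sqrt(d_in).
scaled_norms = [
    math.sqrt(sum((factor * v) ** 2 for v in vec))
    for batch in batches
    for vec in batch
]
average_scaled_norm = sum(scaled_norms) / len(scaled_norms)
print(f"scaling factor: {factor:.4f}")
print(f"average norm after scaling: {average_scaled_norm:.4f}")
print(f"sqrt(d_in): {math.sqrt(d_in):.4f}")
```

In the real run this estimate requires forward passes through the model for each batch, which is why the progress bar above takes several minutes before any SAE training step happens.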

In [5]:
sparse_autoencoder.save_model('./models')

After saving the trained SAE locally, you may want to upload it to an S3 bucket in your account for persistence. Be sure to follow [best practices](https://docs.aws.amazon.com/AmazonS3/latest/userguide/security-best-practices.html) for S3 bucket security if you save any data to your AWS account.
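
A minimal sketch of that upload with `boto3` is below. The bucket name and key prefix are placeholders, and the sketch assumes `boto3` is installed and AWS credentials are already configured in the environment. It uploads every file found under the local directory.

```python
from pathlib import Path


def plan_uploads(local_dir, prefix):
    """Map each file under local_dir to an S3 key under a non-empty prefix."""
    root = Path(local_dir)
    return [
        (path, f"{prefix.rstrip('/')}/{path.relative_to(root).as_posix()}")
        for path in sorted(root.rglob("*"))
        if path.is_file()
    ]


def upload_dir(local_dir, bucket, prefix):
    """Upload every file under local_dir to s3://bucket/prefix/."""
    import boto3  # imported lazily so plan_uploads works without boto3

    s3 = boto3.client("s3")
    for path, key in plan_uploads(local_dir, prefix):
        s3.upload_file(str(path), bucket, key)
        print(f"uploaded {path} to s3://{bucket}/{key}")


# Hypothetical usage; replace the bucket and prefix with your own:
# upload_dir("./models", "my-sae-bucket", "qwen2.5-1.5b-instruct/sae")
```

Splitting the key planning from the upload makes it easy to print the planned keys and check them before anything is written to the bucket.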