# SAE Lens Experiments / Tutorial

In [1]:
import torch
import os
import sys

from sae_lens.training.config import LanguageModelSAERunnerConfig
from sae_lens.training.lm_runner import language_model_sae_runner


if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print("Using device:", device)
os.environ["TOKENIZERS_PARALLELISM"] = "false"

Using device: cuda


This notebook contains experiments that use the [SAE Lens](https://github.com/jbloomAus/SAELens/tree/main) library to train a sparse autoencoder (SAE).

It closely follows the [`training_a_sparse_autoencoder`](https://github.com/jbloomAus/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb) tutorial from the SAE Lens repository.

## Model Selection and Evaluation

Here we test the outputs of a small model and create some visualisations.
We use the `tiny-stories-1L-21M` model.
Available models are listed in the [TransformerLens](https://github.com/neelnanda-io/TransformerLens?tab=readme-ov-file) [model properties table](https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html).

In [2]:
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("tiny-stories-1L-21M")

Loaded pretrained model tiny-stories-1L-21M into HookedTransformer


First, generate some text from the model:

In [3]:
for i in range(5):
    display(
        model.generate(
            "Once upon a time",
            stop_at_eos=False,
            temperature=1,
            verbose=False,
            max_new_tokens=50
        )
    )

'Once upon a time, there was a judge. He was very sick and he was performing without any friends. He felt very proud of himself for finding the toy, but he also thought he deserved to win the prize for nothing. He decided to go look for it when'

'Once upon a time, there was a little girl named Lucy. Missy loved to ring the door. Every day she would put on her pink dress and go outside to play.\n\nOne day, Missy saw Lucy wearing a pretty pink dress. She felt so'

'Once upon a time, a chicken went exploring. Lily wanted to explore more and find out what it was. It was hard to see, but a bad man sitting outside the barn. The chicken was determined to solve the secret meat, but it made Lily very naughty with'

'Once upon a time, there was a happy little girl. She loved to play in the sunshine. One day, the sun was shining so brightly, and the little girl enjoyed the warmth and decided to soar in the sky. However, as the sun started to set,'

"Once upon a time, there was a home. He was a happy and reliable worker. He didn't have to work harder than the grown ups and the passengers. \n\nOne day he scattered candy around the park and left them out in the sun. He said"

The samples above are meant to show whether the model can consistently reuse the name of a story's character. Below, we use `test_prompt` to check the probability and rank the model assigns to the expected character name as the next token after a prompt.

In [4]:
from transformer_lens.utils import test_prompt

# Test the model with a prompt
test_prompt(
    "Once upon a time, there was a little girl named Lily. She lived in a big, happy little girl. On her big adventure,",
    " Lily",
    model,
    prepend_space_to_answer=False
)

Tokenized prompt: ['<|endoftext|>', 'Once', ' upon', ' a', ' time', ',', ' there', ' was', ' a', ' little', ' girl', ' named', ' Lily', '.', ' She', ' lived', ' in', ' a', ' big', ',', ' happy', ' little', ' girl', '.', ' On', ' her', ' big', ' adventure', ',']
Tokenized answer: [' Lily']


Top 0th token. Logit: 20.48 Prob: 71.06% Token: | she|
Top 1th token. Logit: 18.81 Prob: 13.46% Token: | Lily|
Top 2th token. Logit: 17.35 Prob:  3.11% Token: | the|
Top 3th token. Logit: 17.26 Prob:  2.86% Token: | her|
Top 4th token. Logit: 16.74 Prob:  1.70% Token: | there|
Top 5th token. Logit: 16.43 Prob:  1.25% Token: | they|
Top 6th token. Logit: 15.80 Prob:  0.66% Token: | all|
Top 7th token. Logit: 15.64 Prob:  0.56% Token: | things|
Top 8th token. Logit: 15.28 Prob:  0.39% Token: | one|
Top 9th token. Logit: 15.24 Prob:  0.38% Token: | lived|
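
To make the "Prob" and rank columns above concrete, here is a minimal plain-Python sketch of how a probability and a 0-based rank can be derived from a logit vector. The helper `token_prob_and_rank` and the toy 4-token vocabulary are hypothetical and are not part of TransformerLens; because only four logits are used, the resulting probabilities differ from the full-vocabulary values printed by `test_prompt`.

```python
import math


def token_prob_and_rank(logits, token_id):
    """Return (softmax probability, 0-based rank) of `token_id` in a list of logits."""
    # Subtract the max logit before exponentiating for numerical stability.
    max_logit = max(logits)
    exps = [math.exp(x - max_logit) for x in logits]
    prob = exps[token_id] / sum(exps)
    # Rank = number of tokens with a strictly larger logit.
    rank = sum(1 for x in logits if x > logits[token_id])
    return prob, rank


# Hypothetical toy vocabulary reusing the four largest logits printed above.
toy_logits = [20.48, 18.81, 17.35, 17.26]
prob, rank = token_prob_and_rank(toy_logits, token_id=1)
print(f"rank={rank}, prob={prob:.3f}")
```

A rank of 1 here corresponds to the "Top 1th token" row: exactly one other token has a higher logit.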


We can use `circuitsvis` to visualise the top 5 tokens by log probability.

In [5]:
import circuitsvis as cv

example_prompt = """Hi, how are you doing this? I'm really enjoying your posts"""
# A single forward pass returns both the logits and the activation cache
logits, cache = model.run_with_cache(example_prompt)
cv.logits.token_log_probs(
    model.to_tokens(example_prompt),
    logits[0].log_softmax(dim=-1),  # reuse the logits instead of running the model again
    model.to_string,
)
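
As a rough illustration of the quantity being visualised, the sketch below picks out the log probability assigned to each actual next token. Random tensors stand in for real model output, and the helper `next_token_log_probs` is hypothetical; it is not a `circuitsvis` or TransformerLens function.

```python
import torch


def next_token_log_probs(log_probs: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    """log_probs: [seq, vocab], tokens: [seq] -> [seq - 1] log probs of each next token."""
    # Position i predicts token i + 1, so drop the last prediction and the first token.
    index = tokens[1:].unsqueeze(-1)
    return log_probs[:-1].gather(dim=-1, index=index).squeeze(-1)


torch.manual_seed(0)
vocab_size, seq_len = 10, 6
toy_log_probs = torch.randn(seq_len, vocab_size).log_softmax(dim=-1)
toy_tokens = torch.randint(0, vocab_size, (seq_len,))

per_token = next_token_log_probs(toy_log_probs, toy_tokens)
print(per_token.shape)
```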

## Training a Sparse Autoencoder

To train an SAE, we create the runner config and pass it to the runner. That is all that is needed.
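
Before configuring the runner, it may help to see what a sparse autoencoder computes. The following is a minimal generic sketch, not the SAE Lens implementation: the class `TinySAE` and the function `sae_loss` are hypothetical, and details such as initialisation and loss normalisation are assumptions that may differ from the library. It shows an encoder that expands the input by `expansion_factor`, a ReLU, a decoder, and a loss that adds an L1 sparsity penalty scaled by `l1_coefficient` to the reconstruction error.

```python
import torch
import torch.nn as nn


class TinySAE(nn.Module):
    """Generic sparse autoencoder sketch (hypothetical, for illustration only)."""

    def __init__(self, d_in: int, expansion_factor: int):
        super().__init__()
        d_sae = d_in * expansion_factor  # number of dictionary features
        self.W_enc = nn.Parameter(torch.randn(d_in, d_sae) * 0.01)
        self.b_enc = nn.Parameter(torch.zeros(d_sae))
        self.W_dec = nn.Parameter(torch.randn(d_sae, d_in) * 0.01)
        self.b_dec = nn.Parameter(torch.zeros(d_in))

    def forward(self, x: torch.Tensor):
        # Encode: centre the input, project up, keep only positive activations.
        feats = torch.relu((x - self.b_dec) @ self.W_enc + self.b_enc)
        # Decode: project the sparse features back to the input space.
        recon = feats @ self.W_dec + self.b_dec
        return recon, feats


def sae_loss(x, recon, feats, l1_coefficient: float):
    mse = (recon - x).pow(2).mean()        # reconstruction error
    l1 = feats.abs().sum(dim=-1).mean()    # sparsity penalty on feature activations
    return mse + l1_coefficient * l1, mse, l1


torch.manual_seed(0)
sae = TinySAE(d_in=8, expansion_factor=4)
x = torch.randn(32, 8)  # stand-in for a batch of model activations
recon, feats = sae(x)
loss, mse, l1 = sae_loss(x, recon, feats, l1_coefficient=1e-3)
print(loss.item(), mse.item(), l1.item())
```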

In [2]:
cfg = LanguageModelSAERunnerConfig(
    # Data generating function
    model_name="tiny-stories-1L-21M",  # same model as loaded above
    hook_point="blocks.0.hook_mlp_out",  # the 1-layer model only has block 0
    hook_point_layer=0,
    d_in=1024,  # width of the hooked activation
    dataset_path="apollo-research/roneneldan-TinyStories-tokenizer-gpt2",
    is_dataset_tokenized=True,
    # SAE Parameters
    mse_loss_normalization=None,
    expansion_factor=16,
    b_dec_init_method="geometric_median",
    # Training Parameters
    lr=8e-4,
    lr_scheduler_name="constant",
    lr_warm_up_steps=10_000,
    l1_coefficient=1e-3,
    lp_norm=1.,
    train_batch_size=4096,
    context_size=512,
    # Activation Store Parameters
    n_batches_in_buffer=64,
    training_tokens=1_000_000 * 50,
    store_batch_size=16,
    # Resampling protocol
    use_ghost_grads=False,
    feature_sampling_window=1000,
    dead_feature_window=1000,  # Not used as use_ghost_grads=False
    dead_feature_threshold=1e-4,  # Not used as use_ghost_grads=False
    # Wandb
    log_to_wandb=True,
    wandb_project="sae-lens-tutorial",
    wandb_log_frequency=10,
    # Misc
    device=device,
    seed=42,
    n_checkpoints=0,
    checkpoint_path="checkpoints",
    dtype=torch.float32,
)

In [8]:
sparse_autoencoder_dictionary = language_model_sae_runner(cfg)

Loaded pretrained model tiny-stories-1L-21M into HookedTransformer
Moving model to device:  cuda
Run name: 16384-L1-0.001-LR-0.0008-Tokens-5.000e+07
n_tokens_per_buffer (millions): 0.524288
Lower bound: n_contexts_per_buffer (millions): 0.001024
Total training steps: 12207
Total wandb updates: 1220
n_tokens_per_feature_sampling_window (millions): 2097.152
n_tokens_per_dead_feature_window (millions): 2097.152
We will reset the sparsity calculation 12 times.
Number tokens in sparsity calculation window: 4.10e+06
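
Several of the numbers printed above can be reproduced from the config values with simple arithmetic. The sketch below is only a sanity check under that assumption; the variable names are illustrative and SAE Lens may compute these quantities differently internally.

```python
# Values copied from the config cell
d_in = 1024
expansion_factor = 16
train_batch_size = 4096
context_size = 512
n_batches_in_buffer = 64
store_batch_size = 16
training_tokens = 1_000_000 * 50
wandb_log_frequency = 10
feature_sampling_window = 1000

# Dictionary size; matches the "16384" prefix of the run name
d_sae = d_in * expansion_factor

# One training step consumes one batch of tokens
total_training_steps = training_tokens // train_batch_size
total_wandb_updates = total_training_steps // wandb_log_frequency

# Tokens held in the activation buffer at once
tokens_per_buffer = n_batches_in_buffer * store_batch_size * context_size

# Tokens seen between sparsity resets, and the number of resets over the run
sparsity_window_tokens = feature_sampling_window * train_batch_size
sparsity_resets = total_training_steps // feature_sampling_window

print(d_sae, total_training_steps, total_wandb_updates)
print(tokens_per_buffer / 1e6, f"{sparsity_window_tokens:.2e}", sparsity_resets)
```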



Run history:
details/current_learning_rate,▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇███
details/n_training_tokens,▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇███
losses/ghost_grad_loss,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/l1_loss,███████▇▇▇▇▇▇▆▆▆▆▆▅▅▅▅▅▄▄▄▄▃▃▃▃▂▂▂▁▁
losses/mse_loss,██▇▇▇▆▆▅▅▅▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁
losses/overall_loss,████▇▇▇▇▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▁▁▁
metrics/CE_loss_score,▁▆█
metrics/ce_loss_with_ablation,█▁▁
metrics/ce_loss_with_sae,█▄▁
metrics/ce_loss_without_sae,▁█▃

Run summary:
details/current_learning_rate,3e-05
details/n_training_tokens,1474560.0
losses/ghost_grad_loss,0.0
losses/l1_loss,525.8311
losses/mse_loss,0.09542
losses/overall_loss,0.62125
metrics/CE_loss_score,0.85667
metrics/ce_loss_with_ablation,8.20139
metrics/ce_loss_with_sae,2.80313
metrics/ce_loss_without_sae,1.90018




  0%|          | 0/100 [00:00<?, ?it/s]
Objective value: 3489832.0000:   0%|          | 0/100 [00:00<?, ?it/s]
Objective value: 3489829.7500:   0%|          | 0/100 [00:00<?, ?it/s]
Objective value: 3489830.0000:   3%|▎         | 3/100 [00:00<00:03, 27.30it/s]

Training SAE:   0%|          | 0/50000000 [00:00<?, ?it/s]
1| MSE Loss 0.354 | L1 0.956:   0%|          | 0/50000000 [00:01<?, ?it/s]
1| MSE Loss 0.354 | L1 0.956:   0%|          | 4096/50000000 [00:01<4:42:46, 2946.82it/s]
2| MSE Loss 0.361 | L1 0.964:   0%|          | 4096/50000000 [00:01<4:42:46, 2946.82it/s]
3| MSE Loss 0.365 | L1 0.966:   0%|          | 8192/50000000 [00:01<4:42:44, 2946.82it/s]
3| MSE Loss 0.365 | L1 0.966:   0%|          | 12288/50000000 [00:01<1:27:00, 9574.80it/s]
4| MSE Loss 0.356 | L1 0.960:   0%|          | 12288/50000000 [00:01<1:27:00, 9574.80it/s]
5| MSE Loss 0.358 | L1 0.960:   0%|          | 16384/5


Run history:
details/current_learning_rate,▁▁▁▂▂▂▂▂▃▃▃▃▄▄▄▄▄▅▅▅▅▆▆▆▆▆▇▇▇▇▇█████████
details/n_training_tokens,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
losses/ghost_grad_loss,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/l1_loss,█▆▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/mse_loss,█▄▅▄▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/overall_loss,█▆▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/CE_loss_score,▁▇▄▄▅▆▆▇▇▇▇▇▇███████████████████████████
metrics/ce_loss_with_ablation,▄▃▄▄▂▃▅▄▄▄▄▄▃▃▆▄▃▅▆▅▅▄▄█▆▄▁▄▅▅▄▂▁▁▄▃▄▃▇▄
metrics/ce_loss_with_sae,█▂▅▅▄▃▃▂▂▂▂▂▂▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/ce_loss_without_sae,▂▅▃▃▃▃▇▅▂▇▇▇▃▁▄█▃▅▄▃▆▅▃▃▅▆▆▆▄▅▇▃▆▂██▇▂▇▄

Run summary:
details/current_learning_rate,0.0008
details/n_training_tokens,49971200.0
losses/ghost_grad_loss,0.0
losses/l1_loss,21.71001
losses/mse_loss,0.04882
losses/overall_loss,0.07053
metrics/CE_loss_score,0.92894
metrics/ce_loss_with_ablation,8.24038
metrics/ce_loss_with_sae,2.36293
metrics/ce_loss_without_sae,1.91329
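
The final summary values above are consistent with two simple relations, checked in the sketch below. Both formulas are inferred from the logged numbers rather than taken from the SAE Lens source, so treat them as assumptions: the overall loss appears to be the MSE loss plus `l1_coefficient` times the L1 loss, and the CE loss score appears to be the fraction of the cross-entropy gap between ablating the activation and leaving the model untouched that the SAE reconstruction recovers.

```python
l1_coefficient = 1e-3  # from the config cell

# Final logged losses
mse_loss = 0.04882
l1_loss = 21.71001
overall_loss = mse_loss + l1_coefficient * l1_loss

# Final logged cross-entropy metrics
ce_with_ablation = 8.24038
ce_with_sae = 2.36293
ce_without_sae = 1.91329
ce_loss_score = (ce_with_ablation - ce_with_sae) / (ce_with_ablation - ce_without_sae)

print(f"overall_loss  ~ {overall_loss:.5f} (logged: 0.07053)")
print(f"CE_loss_score ~ {ce_loss_score:.5f} (logged: 0.92894)")
```

Under this reading, a score of about 0.93 means that splicing in the SAE reconstruction keeps roughly 93% of the model's advantage over ablating the hooked activation.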
