<a target="_blank" href="https://colab.research.google.com/github/ai-safety-foundation/sparse_autoencoder/blob/main/docs/content/demo.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

# Quick Start Training Demo

This is a quick start demo to get you training a sparse autoencoder (SAE) right away. All you need
to do is choose a few hyperparameters (like the model to train on), and then set it off. By default
it trains SAEs on all MLP layers from GPT-2 small.

## Setup

### Imports

In [1]:
# Check if we're in Colab
try:
    import google.colab  # noqa: F401 # type: ignore

    in_colab = True
except ImportError:
    in_colab = False

#  Install if in Colab
if in_colab:
    %pip install sparse_autoencoder transformer_lens transformers wandb

# Otherwise enable hot reloading in dev mode
if not in_colab:
    %load_ext autoreload
    %autoreload 2

In [2]:
import os

from sparse_autoencoder import (
    ActivationResamplerHyperparameters,
    Hyperparameters,
    LossHyperparameters,
    Method,
    OptimizerHyperparameters,
    Parameter,
    PipelineHyperparameters,
    SourceDataHyperparameters,
    SourceModelHyperparameters,
    sweep,
    SweepConfig,
)


# Process-wide environment settings for the run.
# NOTE(review): presumably read by the tokenizers and wandb libraries respectively - confirm.
os.environ.update(
    {
        "TOKENIZERS_PARALLELISM": "false",
        "WANDB_NOTEBOOK_NAME": "demo.ipynb",
    }
)

### Hyperparameters

Customize any of the hyperparameters below (by default we sweep over the L1 coefficient and the
learning rate):

In [5]:
n_layers_gpt2_small = 12

# Each hyperparameter group is built on its own first, then assembled into the sweep config.

# Activation resampler settings (all fixed values)
resampler_settings = ActivationResamplerHyperparameters(
    resample_interval=Parameter(200_000_000),
    n_activations_activity_collate=Parameter(100_000_000),
    threshold_is_dead_portion_fires=Parameter(1e-6),
    max_n_resamples=Parameter(4),
    resample_dataset_size=Parameter(200_000),
)

# Swept: the L1 coefficient is given as a min/max range rather than a single value
loss_settings = LossHyperparameters(
    l1_coefficient=Parameter(max=1e-2, min=4e-3),
)

# Swept: the learning rate is given as a min/max range rather than a single value
optimizer_settings = OptimizerHyperparameters(
    lr=Parameter(max=1e-3, min=1e-5),
)

source_model_settings = SourceModelHyperparameters(
    name=Parameter("gpt2-small"),
    # Train in parallel on all MLP layers: one hook name per layer index
    cache_names=Parameter(
        [f"blocks.{layer}.hook_mlp_out" for layer in range(n_layers_gpt2_small)]
    ),
    hook_dimension=Parameter(768),
)

source_data_settings = SourceDataHyperparameters(
    dataset_path=Parameter("alancooney/sae-monology-pile-uncopyrighted-tokenizer-gpt2"),
    context_size=Parameter(128),
    pre_tokenized=Parameter(value=True),
)

pipeline_settings = PipelineHyperparameters(
    max_activations=Parameter(1_000_000_000),
    checkpoint_frequency=Parameter(100_000_000),
    validation_frequency=Parameter(100_000_000),
    train_batch_size=Parameter(1024),
    max_store_size=Parameter(300_000),
)

sweep_config = SweepConfig(
    parameters=Hyperparameters(
        activation_resampler=resampler_settings,
        loss=loss_settings,
        optimizer=optimizer_settings,
        source_model=source_model_settings,
        source_data=source_data_settings,
        pipeline=pipeline_settings,
    ),
    method=Method.RANDOM,
)

# Display the assembled configuration as the cell output
sweep_config

SweepConfig(parameters=Hyperparameters(
    source_data=SourceDataHyperparameters(dataset_path=Parameter(value=alancooney/sae-monology-pile-uncopyrighted-tokenizer-gpt2), context_size=Parameter(value=128), dataset_column_name=Parameter(value=input_ids), dataset_dir=None, dataset_files=None, pre_download=Parameter(value=False), pre_tokenized=Parameter(value=True), tokenizer_name=None)
    source_model=SourceModelHyperparameters(name=Parameter(value=gpt2-small), cache_names=Parameter(value=['blocks.0.hook_mlp_out', 'blocks.1.hook_mlp_out', 'blocks.2.hook_mlp_out', 'blocks.3.hook_mlp_out', 'blocks.4.hook_mlp_out', 'blocks.5.hook_mlp_out', 'blocks.6.hook_mlp_out', 'blocks.7.hook_mlp_out', 'blocks.8.hook_mlp_out', 'blocks.9.hook_mlp_out', 'blocks.10.hook_mlp_out', 'blocks.11.hook_mlp_out']), hook_dimension=Parameter(value=768), dtype=Parameter(value=float32))
    activation_resampler=ActivationResamplerHyperparameters(resample_interval=Parameter(value=200000000), max_n_resamples=Parameter(v

### Run the sweep

In [None]:
# Start the hyperparameter sweep with the `sweep_config` defined in the Hyperparameters section.
# NOTE(review): presumably launches Weights & Biases runs and may take a long time - confirm.
sweep(sweep_config=sweep_config)