<a target="_blank" href="https://colab.research.google.com/github/ai-safety-foundation/sparse_autoencoder/blob/main/docs/content/demo.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

# Quick Start Training Demo

This quick start demo gets you training a sparse autoencoder (SAE) right away. All you need to do
is choose a few hyperparameters (such as the model to train on) and then start the sweep.
By default it replicates Neel Nanda's
[comment on the Anthropic dictionary learning
paper](https://transformer-circuits.pub/2023/monosemantic-features/index.html#comment-nanda).

## Setup

### Imports

In [1]:
# Check if we're in Colab
try:
    import google.colab  # noqa: F401 # type: ignore

    in_colab = True
except ImportError:
    in_colab = False

# Install if in Colab
if in_colab:
    %pip install sparse_autoencoder transformer_lens transformers wandb

# Otherwise enable hot reloading in dev mode
if not in_colab:
    %load_ext autoreload
    %autoreload 2

In [2]:
import os

from sparse_autoencoder import (
    sweep,
    SweepConfig,
    Hyperparameters,
    SourceModelHyperparameters,
    Parameter,
    SourceDataHyperparameters,
    Method,
    LossHyperparameters,
    OptimizerHyperparameters,
)
import wandb


os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["WANDB_NOTEBOOK_NAME"] = "demo.ipynb"

### Hyperparameters

Customize any of the hyperparameters below. By default, the sweep covers the L1 coefficient and the
learning rate, with three candidate values each:

In [3]:
sweep_config = SweepConfig(
    parameters=Hyperparameters(
        loss=LossHyperparameters(
            l1_coefficient=Parameter(values=[1e-3, 1e-4, 1e-5]),
        ),
        optimizer=OptimizerHyperparameters(
            lr=Parameter(values=[1e-3, 1e-4, 1e-5]),
        ),
        source_model=SourceModelHyperparameters(
            name=Parameter("gelu-2l"),
            hook_site=Parameter("mlp_out"),
            hook_layer=Parameter(0),
            hook_dimension=Parameter(512),
        ),
        source_data=SourceDataHyperparameters(
            dataset_path=Parameter("NeelNanda/c4-code-tokenized-2b"),
        ),
    ),
    method=Method.RANDOM,
)
sweep_config

SweepConfig(parameters=Hyperparameters(
    source_data=SourceDataHyperparameters(dataset_path=Parameter(value=NeelNanda/c4-code-tokenized-2b), context_size=Parameter(value=128))
    source_model=SourceModelHyperparameters(name=Parameter(value=gelu-2l), hook_site=Parameter(value=mlp_out), hook_layer=Parameter(value=0), hook_dimension=Parameter(value=512), dtype=Parameter(value=float32))
    activation_resampler=ActivationResamplerHyperparameters(resample_interval=Parameter(value=200000000), max_resamples=Parameter(value=4), n_steps_collate=Parameter(value=100000000), resample_dataset_size=Parameter(value=819200), dead_neuron_threshold=Parameter(value=0.0))
    autoencoder=AutoencoderHyperparameters(expansion_factor=Parameter(value=4))
    loss=LossHyperparameters(l1_coefficient=Parameter(values=[0.001, 0.0001, 1e-05]))
    optimizer=OptimizerHyperparameters(lr=Parameter(values=[0.001, 0.0001, 1e-05]), adam_beta_1=Parameter(value=0.9), adam_beta_2=Parameter(value=0.99), adam_weight_deca
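
To get a feel for the size of the search space defined above, here is a minimal standalone sketch. It does not use `sparse_autoencoder` or `wandb`. It only enumerates the value lists from the config and mimics what a random search strategy is assumed to do, which is to draw each swept parameter independently for every run. The helper `sample_run_config` is hypothetical and exists only for illustration.

```python
import itertools
import random

# Candidate values copied from the sweep config above.
l1_coefficients = [1e-3, 1e-4, 1e-5]
learning_rates = [1e-3, 1e-4, 1e-5]

# Every distinct (l1_coefficient, lr) pair the sweep can visit.
search_space = list(itertools.product(l1_coefficients, learning_rates))
print(f"{len(search_space)} distinct combinations")


def sample_run_config(rng: random.Random) -> dict[str, float]:
    """Hypothetical helper: draw one run's swept values independently.

    This is an assumption about how a random search picks values,
    not the library's actual implementation.
    """
    return {
        "l1_coefficient": rng.choice(l1_coefficients),
        "lr": rng.choice(learning_rates),
    }


# Draw a few example run configurations with a seeded generator.
rng = random.Random(0)
for _ in range(3):
    print(sample_run_config(rng))
```

Because the values are sampled rather than enumerated, a random search may repeat a combination before it has covered all nine, so the number of runs you allow determines how much of the space is actually explored.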

### Run the sweep

In [None]:
sweep(sweep_config=sweep_config)

wandb.finish()