<a target="_blank" href="https://colab.research.google.com/github/ai-safety-foundation/sparse_autoencoder/blob/main/docs/content/demo.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

# Quick Start Training Demo

This is a quick start demo for training a sparse autoencoder (SAE) right away. All you need to do is choose a few
hyperparameters (such as the model to train on) and then start the run.
By default it replicates Neel Nanda's
[comment on the Anthropic dictionary learning
paper](https://transformer-circuits.pub/2023/monosemantic-features/index.html#comment-nanda).
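As a point of reference, here is a minimal sketch of the kind of objective such an SAE is trained on. It assumes the standard setup from the Anthropic paper: a ReLU encoder, a linear decoder, and a loss equal to the mean squared reconstruction error plus `l1_coefficient` times the L1 norm of the learned features. It is written in plain Python for readability. The function names and toy sizes are hypothetical, and the actual model in `sparse_autoencoder` may differ (for example, in how it handles biases or normalizes the loss).

```python
import random

# Hypothetical toy sizes for illustration; the demo itself trains on
# 512-dimensional MLP outputs (hook_dimension=512).
D_INPUT, D_LEARNED = 4, 8


def relu(x: float) -> float:
    return max(0.0, x)


def sae_forward(x, w_enc, b_enc, w_dec, b_dec):
    """Encode x into non-negative learned features, then decode back to input space.

    w_enc has one row of length len(x) per learned feature; w_dec has one row
    (a decoder direction of length len(x)) per learned feature.
    """
    features = [
        relu(sum(w * xi for w, xi in zip(row, x)) + b)
        for row, b in zip(w_enc, b_enc)
    ]
    reconstruction = [
        sum(w_dec[j][i] * f for j, f in enumerate(features)) + b_dec[i]
        for i in range(len(x))
    ]
    return features, reconstruction


def sae_loss(x, features, reconstruction, l1_coefficient):
    """Mean squared reconstruction error plus an L1 sparsity penalty on the features."""
    mse = sum((r - xi) ** 2 for r, xi in zip(reconstruction, x)) / len(x)
    l1 = sum(abs(f) for f in features)
    return mse + l1_coefficient * l1


# Usage: random weights and a random input vector.
rng = random.Random(0)
w_enc = [[rng.gauss(0, 0.1) for _ in range(D_INPUT)] for _ in range(D_LEARNED)]
w_dec = [[rng.gauss(0, 0.1) for _ in range(D_INPUT)] for _ in range(D_LEARNED)]
b_enc, b_dec = [0.0] * D_LEARNED, [0.0] * D_INPUT
x = [rng.gauss(0, 1) for _ in range(D_INPUT)]
features, reconstruction = sae_forward(x, w_enc, b_enc, w_dec, b_dec)
print(sae_loss(x, features, reconstruction, l1_coefficient=4e-3))
```

A larger `l1_coefficient` pushes more features to exactly zero at the cost of reconstruction accuracy. This is the trade-off that the demo's sweep over the L1 coefficient explores.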

## Setup

### Imports

In [1]:
# Check if we're in Colab
try:
    import google.colab  # noqa: F401 # type: ignore

    in_colab = True
except ImportError:
    in_colab = False

# Install if in Colab
if in_colab:
    %pip install sparse_autoencoder transformer_lens transformers wandb

# Otherwise enable hot reloading in dev mode
if not in_colab:
    %load_ext autoreload
    %autoreload 2

In [2]:
import os

from sparse_autoencoder import (
    ActivationResamplerHyperparameters,
    Hyperparameters,
    LossHyperparameters,
    Method,
    OptimizerHyperparameters,
    Parameter,
    SourceDataHyperparameters,
    SourceModelHyperparameters,
    sweep,
    SweepConfig,
)


os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["WANDB_NOTEBOOK_NAME"] = "demo.ipynb"

### Hyperparameters

Customize any hyperparameters you want below (by default, the sweep varies the L1 coefficient and the
learning rate):

In [3]:
sweep_config = SweepConfig(
    parameters=Hyperparameters(
        activation_resampler=ActivationResamplerHyperparameters(
            threshold_is_dead_portion_fires=Parameter(1e-6),
        ),
        loss=LossHyperparameters(
            l1_coefficient=Parameter(max=1e-2, min=4e-3),
        ),
        optimizer=OptimizerHyperparameters(
            lr=Parameter(max=1e-3, min=1e-5),
        ),
        source_model=SourceModelHyperparameters(
            name=Parameter("gelu-2l"),
            cache_names=Parameter(["blocks.0.hook_mlp_out", "blocks.1.hook_mlp_out"]),
            hook_dimension=Parameter(512),
        ),
        source_data=SourceDataHyperparameters(
            dataset_path=Parameter("NeelNanda/c4-code-tokenized-2b"),
        ),
    ),
    method=Method.RANDOM,
)
sweep_config

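The `threshold_is_dead_portion_fires=Parameter(1e-6)` setting above controls which learned features the activation resampler treats as "dead". Below is a minimal sketch of that check. It assumes that a feature counts as dead when the fraction of inputs on which it fires (has a positive activation) is at or below the threshold. The function name and the exact comparison are assumptions, not the library's API, and the resampling step itself (re-initializing the dead features) is omitted.

```python
def find_dead_features(feature_activations, threshold_is_dead_portion_fires=1e-6):
    """Return indices of learned features that fire too rarely.

    feature_activations: a list with one entry per input, each a list of
    learned-feature activations. A feature "fires" on an input when its
    activation is strictly positive (assumption).
    """
    n_inputs = len(feature_activations)
    n_features = len(feature_activations[0])
    fire_counts = [0] * n_features
    for activations in feature_activations:
        for j, activation in enumerate(activations):
            if activation > 0:
                fire_counts[j] += 1
    # Assumption: a firing rate at or below the threshold counts as dead.
    return [
        j
        for j, count in enumerate(fire_counts)
        if count / n_inputs <= threshold_is_dead_portion_fires
    ]


# Usage: feature 0 never fires across three inputs, so it is reported as dead.
batch = [[0.0, 1.0, 0.0], [0.0, 0.5, 0.2], [0.0, 0.0, 0.0]]
print(find_dead_features(batch))  # [0]
```

With a threshold as small as 1e-6, in practice only features that essentially never activate over the resampling window are flagged.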

### Run the sweep

In [None]:
sweep(sweep_config=sweep_config)
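Running the sweep trains one SAE per sampled configuration. If `Method.RANDOM` maps onto a Weights & Biases random search (an assumption here), each run draws every swept hyperparameter independently from its range. In W&B, a float range given only as `min`/`max` is sampled uniformly unless a distribution is specified. The plain-Python sketch below shows why this matters for the learning rate. It is not the library's code, and `sample_config` and the other helpers are hypothetical. When sampling uniformly over [1e-5, 1e-3], only about 9% of draws (9e-5 / 9.9e-4) fall below 1e-4, whereas a log-uniform sampler puts half of them there.

```python
import math
import random

# Ranges copied from the sweep config above.
RANGES = {
    "l1_coefficient": (4e-3, 1e-2),
    "lr": (1e-5, 1e-3),
}


def sample_uniform(rng, low, high):
    """Every value in [low, high] is equally likely."""
    return rng.uniform(low, high)


def sample_log_uniform(rng, low, high):
    """Every order of magnitude in [low, high] is equally likely."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))


def sample_config(rng, sampler):
    """Draw one hyperparameter configuration, independently per parameter."""
    return {name: sampler(rng, low, high) for name, (low, high) in RANGES.items()}


def fraction_below(rng, sampler, low, high, cutoff, n=100_000):
    """Estimate how often a sampler draws a value below `cutoff`."""
    return sum(sampler(rng, low, high) < cutoff for _ in range(n)) / n


rng = random.Random(0)
for _ in range(3):
    print(sample_config(rng, sample_uniform))

low, high = RANGES["lr"]
# Expected values: 9e-5 / 9.9e-4 (roughly 0.09) for uniform, 0.5 for log-uniform.
print("uniform:", fraction_below(rng, sample_uniform, low, high, 1e-4))
print("log-uniform:", fraction_below(rng, sample_log_uniform, low, high, 1e-4))
```

If the sweep should explore small learning rates as thoroughly as large ones, a log-scaled distribution is the usual choice for ranges that span several orders of magnitude.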