<a target="_blank" href="https://colab.research.google.com/github/ai-safety-foundation/sparse_autoencoder/blob/main/docs/content/flexible_demo.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

# Flexible Training Demo

This demo shows you how to train a sparse autoencoder (SAE) on a
[TransformerLens](https://github.com/neelnanda-io/TransformerLens) model. It replicates Neel Nanda's
[comment on the Anthropic dictionary learning
paper](https://transformer-circuits.pub/2023/monosemantic-features/index.html#comment-nanda).

## Introduction

This library provides all the components necessary to train a sparse
autoencoder. For the most part, these are just standard PyTorch modules. For example, `AdamWithReset` is
just an extension of `torch.optim.Adam`, with a few extra bells and whistles that are needed for training an SAE
(e.g. a method to reset the optimizer state when you are also resampling dead neurons).

This is very flexible: it's easy to extend or change just one component, just as you would with a
standard PyTorch model. It also makes it easy to see what is going on under the hood. However, to
get you started, the following demo sets up a default SAE that uses the same implementation Neel
Nanda used in his comment above.

### Approach

The approach is pretty simple - we run a training pipeline that alternates between generating
activations from a *source model*, and training the *sparse autoencoder* model on these generated
activations.
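The alternation described above can be sketched in plain Python. This is a minimal sketch under stated assumptions, not the library's `Pipeline`: `generate_batch` and `train_step` are hypothetical stand-ins for running the source model and taking one SAE optimizer step. Shuffling the store before training is also an assumption about how the real pipeline behaves.

```python
import random
from collections.abc import Callable


def run_alternating_pipeline(
    generate_batch: Callable[[], list[list[float]]],
    train_step: Callable[[list[list[float]]], None],
    max_store_size: int,
    train_batch_size: int,
    max_activations: int,
) -> int:
    """Alternate between filling an activation store and training on it.

    Returns the total number of activation vectors used for training.
    """
    total_trained = 0
    while total_trained < max_activations:
        # Phase 1: run the (stand-in) source model until the store is full.
        store: list[list[float]] = []
        while len(store) < max_store_size:
            store.extend(generate_batch())
        store = store[:max_store_size]

        # Assumption: shuffle so each batch mixes activations from many prompts.
        random.shuffle(store)

        # Phase 2: train the (stand-in) SAE on the store in fixed-size batches.
        for start in range(0, len(store), train_batch_size):
            batch = store[start : start + train_batch_size]
            train_step(batch)
            total_trained += len(batch)
            if total_trained >= max_activations:
                break
    return total_trained


# Usage with stand-in functions: each "source batch" yields 10 fake 4-dim activations.
trained_batch_sizes: list[int] = []
n_trained = run_alternating_pipeline(
    generate_batch=lambda: [[0.0] * 4 for _ in range(10)],
    train_step=lambda batch: trained_batch_sizes.append(len(batch)),
    max_store_size=32,
    train_batch_size=8,
    max_activations=100,
)
print(n_trained, len(trained_batch_sizes))  # 104 13
```

Generating into a store and then training on it (rather than training on each source batch directly) lets training batches draw on many prompts at once.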

## Setup

### Imports

In [1]:
# Check if we're in Colab
try:
    import google.colab  # noqa: F401 # type: ignore

    in_colab = True
except ImportError:
    in_colab = False

#  Install if in Colab
if in_colab:
    %pip install sparse_autoencoder transformer_lens transformers wandb

# Otherwise enable hot reloading in dev mode
if not in_colab:
    from IPython import get_ipython  # type: ignore

    ip = get_ipython()
    if ip is not None and ip.extension_manager is not None and not ip.extension_manager.loaded:
        ip.extension_manager.load("autoreload")  # type: ignore
        %autoreload 2

In [2]:
import os
from pathlib import Path

import torch
from transformer_lens import HookedTransformer
from transformer_lens.utils import get_device

from sparse_autoencoder import (
    ActivationResampler,
    AdamWithReset,
    L2ReconstructionLoss,
    LearnedActivationsL1Loss,
    LossReducer,
    Pipeline,
    PreTokenizedDataset,
    SparseAutoencoder,
)
import wandb


os.environ["TOKENIZERS_PARALLELISM"] = "false"

device = get_device()
print(f"Using device: {device}")  # You will need a GPU

Using device: mps


### Hyperparameters

This library lets you define your own hyperparameters and then set up the underlying components
with them. This is extremely flexible, but to help you get started we've included some common ones
below, along with sensible defaults. You can also easily sweep over multiple hyperparameters with
`wandb.sweep`.
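A `wandb.sweep` configuration is a plain dictionary. The sketch below is one possible configuration, assuming you want to sweep only `l1_coefficient` and `lr` while holding `expansion_factor` fixed; the metric name `"train/loss"` and the candidate values are hypothetical placeholders, not values recommended by this library.

```python
# Minimal sweep configuration in the dictionary format accepted by wandb.sweep.
# "train/loss" is a hypothetical metric name: use a metric your runs actually log.
sweep_config = {
    "method": "random",  # wandb also supports "grid" and "bayes"
    "metric": {"name": "train/loss", "goal": "minimize"},
    "parameters": {
        "l1_coefficient": {"values": [1e-4, 3e-4, 1e-3]},
        "lr": {"values": [1e-4, 3e-4]},
        "expansion_factor": {"value": 4},  # held fixed across the sweep
    },
}
```

You would then register it with `sweep_id = wandb.sweep(sweep=sweep_config, project="sparse-autoencoder")` and launch runs with `wandb.agent(sweep_id, function=train)`, where `train` is a hypothetical function that calls `wandb.init()` and builds the components from `wandb.config`.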

In [3]:
torch.random.manual_seed(49)

hyperparameters = {
    # Expansion factor is the number of features in the sparse representation, relative to the
    # number of features in the original MLP layer. The original paper experimented with 1x to 256x,
    # and we have found that 4x is a good starting point.
    "expansion_factor": 4,
    # L1 coefficient is the coefficient of the L1 regularization term (used to encourage sparsity).
    "l1_coefficient": 3e-4,
    # Adam parameters (betas, epsilon and weight decay match PyTorch's defaults; its default lr is 1e-3)
    "lr": 1e-4,
    "adam_beta_1": 0.9,
    "adam_beta_2": 0.999,
    "adam_epsilon": 1e-8,
    "adam_weight_decay": 0.0,
    # Batch sizes
    "train_batch_size": 4096,
    "context_size": 128,
    # Source model hook point
    "source_model_name": "gelu-2l",
    "source_model_dtype": "float32",
    "source_model_hook_point": "blocks.0.hook_mlp_out",
    "source_model_hook_point_layer": 0,
    # Train pipeline parameters
    "max_store_size": 384 * 4096 * 2,
    "max_activations": 2_000_000_000,
    "resample_frequency": 122_880_000,
    "checkpoint_frequency": 100_000_000,
    "validation_frequency": 384 * 4096 * 2 * 100,  # Every 100 generations
}
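To get a feel for what these numbers mean, the arithmetic below derives a few quantities from them. It assumes that one activation vector is produced per token position at the hook point and that the `*_frequency` and `max_*` values are counted in activation vectors; `source_data_batch_size=6` is the value passed to `Pipeline` later in this demo.

```python
import math

# Values copied from the hyperparameters cell, plus source_data_batch_size from the Pipeline cell.
train_batch_size = 4096
context_size = 128
source_data_batch_size = 6
max_store_size = 384 * 4096 * 2  # 3,145,728 activation vectors
max_activations = 2_000_000_000
checkpoint_frequency = 100_000_000
validation_frequency = max_store_size * 100

# Assumption: one activation vector per token position at the hook point.
activations_per_source_batch = source_data_batch_size * context_size  # 768
source_batches_per_store = max_store_size // activations_per_source_batch  # 4096
train_steps_per_store = max_store_size // train_batch_size  # 768
store_fills = math.ceil(max_activations / max_store_size)  # 636
total_train_steps = max_activations // train_batch_size  # 488281
steps_between_checkpoints = checkpoint_frequency / train_batch_size  # 24414.0625
n_validations = max_activations // validation_frequency  # 6
```

Under these assumptions, each store fill takes 4096 source-model forward passes and is consumed by 768 SAE training steps, and validation runs roughly six times over the whole training run.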

### Source Model

The source model is just a [TransformerLens](https://github.com/neelnanda-io/TransformerLens) model
(see [here](https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html)
for a full list of supported models).

In this example we're training a sparse autoencoder on the activations from the first MLP layer, so
we'll also get some details about that hook point.

In [4]:
# Source model setup with TransformerLens
src_model = HookedTransformer.from_pretrained(
    str(hyperparameters["source_model_name"]), dtype=str(hyperparameters["source_model_dtype"])
)

# Details about the activations we'll train the sparse autoencoder on
autoencoder_input_dim: int = src_model.cfg.d_model  # type: ignore (TransformerLens typing is currently broken)

(
    f"Source: {hyperparameters['source_model_name']}, "
    f"Hook: {hyperparameters['source_model_hook_point']}, "
    f"Features: {autoencoder_input_dim}"
)

Loaded pretrained model gelu-2l into HookedTransformer


'Source: gelu-2l, Hook: blocks.0.hook_mlp_out, Features: 512'
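The hook point `blocks.0.hook_mlp_out` yields activations shaped `(batch, pos, d_model)`, while the SAE consumes individual `d_model`-sized vectors. The sketch below, using plain nested lists as a stand-in for tensors, shows the assumed flattening step: every token position becomes one training example. How the library's pipeline performs this internally is an assumption here.

```python
def flatten_hook_activations(acts: list[list[list[float]]]) -> list[list[float]]:
    """Turn a (batch, pos, d_model) nested list into (batch * pos, d_model) rows."""
    d_model = len(acts[0][0])
    rows = [vector for sequence in acts for vector in sequence]
    assert all(len(row) == d_model for row in rows), "every vector must have d_model features"
    return rows


# 6 prompts x 128 tokens x 512 features, matching this demo's settings.
example = [[[0.0] * 512 for _ in range(128)] for _ in range(6)]
rows = flatten_hook_activations(example)
print(len(rows), len(rows[0]))  # 768 512
```

With a real PyTorch tensor the equivalent is `acts.reshape(-1, d_model)`.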

### Sparse Autoencoder

We can then set up the sparse autoencoder. The default model (`SparseAutoencoder`) is set up as per
the original Anthropic paper [Towards Monosemanticity: Decomposing Language Models With Dictionary
Learning](https://transformer-circuits.pub/2023/monosemantic-features/index.html).

However, it's just a standard PyTorch model, so you can create your own model instead if you want to
use a different architecture. To do this you just need to extend `AbstractAutoencoder`, and
optionally the underlying `AbstractEncoder`, `AbstractDecoder` and `AbstractOuterBias`. See these
classes (which are fully documented) for more details.

In [5]:
expansion_factor = hyperparameters["expansion_factor"]
autoencoder = SparseAutoencoder(
    n_input_features=autoencoder_input_dim,  # size of the activations we are autoencoding
    n_learned_features=int(autoencoder_input_dim * expansion_factor),  # size of SAE
).to(device)
autoencoder

SparseAutoencoder(
  (_pre_encoder_bias): TiedBias(position=pre_encoder)
  (_encoder): LinearEncoder(
    in_features=512, out_features=2048
    (activation_function): ReLU()
  )
  (_decoder): UnitNormDecoder(in_features=2048, out_features=512)
  (_post_decoder_bias): TiedBias(position=post_decoder)
)
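The printed modules suggest the forward pass from the Anthropic paper: subtract a tied bias, apply a linear encoder with ReLU, decode with a matrix whose columns (one per learned feature) have unit norm, then add the same tied bias back. The plain-Python sketch below illustrates that computation under those assumptions; it is not the library's implementation, and the library may enforce the unit-norm constraint differently (for example after each optimizer step).

```python
import math
import random

Vector = list[float]
Matrix = list[list[float]]  # shape (out_features, in_features)


def matvec(m: Matrix, v: Vector) -> Vector:
    return [sum(w * x for w, x in zip(row, v)) for row in m]


def normalize_columns(m: Matrix) -> Matrix:
    """Rescale each column to unit L2 norm (one column per learned feature)."""
    n_rows, n_cols = len(m), len(m[0])
    norms = [math.sqrt(sum(m[r][c] ** 2 for r in range(n_rows))) for c in range(n_cols)]
    return [[m[r][c] / norms[c] for c in range(n_cols)] for r in range(n_rows)]


def sae_forward(x: Vector, tied_bias: Vector, w_enc: Matrix, b_enc: Vector, w_dec: Matrix):
    x_centered = [xi - bi for xi, bi in zip(x, tied_bias)]  # pre-encoder tied bias
    pre_acts = [a + b for a, b in zip(matvec(w_enc, x_centered), b_enc)]
    learned = [max(0.0, a) for a in pre_acts]  # ReLU gives the sparse learned activations
    decoded = matvec(w_dec, learned)  # (n_input, n_learned) @ (n_learned,)
    reconstruction = [d + bi for d, bi in zip(decoded, tied_bias)]  # post-decoder tied bias
    return learned, reconstruction


rng = random.Random(0)
n_input, n_learned = 4, 16  # 4x expansion, like 512 -> 2048 in this demo
w_enc = [[rng.gauss(0, 1) for _ in range(n_input)] for _ in range(n_learned)]
b_enc = [0.0] * n_learned
w_dec = normalize_columns([[rng.gauss(0, 1) for _ in range(n_learned)] for _ in range(n_input)])
tied_bias = [0.0] * n_input
learned, reconstruction = sae_forward([1.0, -2.0, 0.5, 3.0], tied_bias, w_enc, b_enc, w_dec)
```

Tying the pre-encoder and post-decoder bias means the same vector is subtracted before encoding and added back after decoding, so an input equal to that bias reconstructs to itself when the encoder bias is zero.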

We'll also want to set up an optimizer and a loss function. Here we again use the standard approach
from the original Anthropic paper. However, you can create your own loss functions and optimizers by
extending `AbstractLoss` and `AbstractOptimizerWithReset` respectively.

In [6]:
# We use a loss reducer, which simply adds up the losses from the underlying loss functions.
loss = LossReducer(
    LearnedActivationsL1Loss(
        l1_coefficient=float(hyperparameters["l1_coefficient"]),
    ),
    L2ReconstructionLoss(),
)
loss

LossReducer(
  (0): LearnedActivationsL1Loss(l1_coefficient=0.0003)
  (1): L2ReconstructionLoss()
)
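Conceptually, the reducer adds an L2 reconstruction term to an L1 sparsity penalty on the learned activations. The sketch below shows that sum in plain Python. It assumes squared L2 distance per example, an L1 term scaled by `l1_coefficient`, and a mean over the batch; the library's exact reductions (sum versus mean over features) may differ.

```python
def l2_reconstruction_loss(x: list[float], x_hat: list[float]) -> float:
    # Squared L2 distance between the input activation and its reconstruction.
    return sum((a - b) ** 2 for a, b in zip(x, x_hat))


def l1_activation_loss(learned: list[float], l1_coefficient: float) -> float:
    # Scaled L1 norm of the learned activations; pushes most features to zero.
    return l1_coefficient * sum(abs(f) for f in learned)


def total_loss(batch: list[tuple[list[float], list[float], list[float]]], l1_coefficient: float = 3e-4) -> float:
    """Add both terms per example, then average over the batch of (x, x_hat, learned) triples."""
    per_example = [
        l2_reconstruction_loss(x, x_hat) + l1_activation_loss(learned, l1_coefficient)
        for x, x_hat, learned in batch
    ]
    return sum(per_example) / len(per_example)


batch = [
    ([1.0, 2.0], [1.0, 1.0], [0.0, 4.0, 0.0]),  # L2 = 1.0, L1 term = 3e-4 * 4
    ([0.0, 0.0], [0.0, 0.0], [0.0, 0.0, 0.0]),  # perfect reconstruction, no active features
]
print(total_loss(batch))
```

The `l1_coefficient` controls the trade-off: larger values give sparser learned activations at the cost of worse reconstructions.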

In [7]:
optimizer = AdamWithReset(
    params=autoencoder.parameters(),
    named_parameters=autoencoder.named_parameters(),
    lr=float(hyperparameters["lr"]),
    betas=(float(hyperparameters["adam_beta_1"]), float(hyperparameters["adam_beta_2"])),
    eps=float(hyperparameters["adam_epsilon"]),
    weight_decay=float(hyperparameters["adam_weight_decay"]),
)
optimizer

AdamWithReset (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.0001
    maximize: False
    weight_decay: 0.0
)
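Why does the optimizer need a reset? Adam keeps running moment estimates per parameter, so when a dead neuron's weights are re-initialised, its stale history would still influence the next updates. The class below is a hypothetical, scalar-only illustration of that idea (a `reset_state` method that zeroes the moments for chosen parameter indices); it is not the API of the library's `AdamWithReset`.

```python
import math


class TinyAdamWithReset:
    """Adam over a flat list of floats, with a hypothetical per-index state reset."""

    def __init__(self, params: list[float], lr: float = 1e-4,
                 betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8):
        self.params = params  # updated in place
        self.lr, (self.beta1, self.beta2), self.eps = lr, betas, eps
        self.m = [0.0] * len(params)  # first-moment (mean) estimates
        self.v = [0.0] * len(params)  # second-moment (uncentred variance) estimates
        self.t = 0  # shared step counter, used for bias correction

    def step(self, grads: list[float]) -> None:
        self.t += 1
        for i, g in enumerate(grads):
            self.m[i] = self.beta1 * self.m[i] + (1 - self.beta1) * g
            self.v[i] = self.beta2 * self.v[i] + (1 - self.beta2) * g * g
            m_hat = self.m[i] / (1 - self.beta1**self.t)
            v_hat = self.v[i] / (1 - self.beta2**self.t)
            self.params[i] -= self.lr * m_hat / (math.sqrt(v_hat) + self.eps)

    def reset_state(self, indices: list[int]) -> None:
        # Forget the moment history of re-initialised parameters. The shared step
        # counter is not restarted, so bias correction for them is not restarted either.
        for i in indices:
            self.m[i] = 0.0
            self.v[i] = 0.0


weights = [1.0, 1.0, 1.0]
opt = TinyAdamWithReset(weights, lr=0.1)
opt.step([1.0, 1.0, 1.0])
opt.reset_state([2])  # pretend parameter 2 belongs to a resampled neuron
```

Passing `named_parameters` to the real optimizer plausibly serves the same purpose: it lets the optimizer locate the state belonging to specific parameters when resetting.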

Finally, we'll initialise an activation resampler, which periodically resamples dead neurons
(learned features that have stopped activating).

In [8]:
activation_resampler = ActivationResampler(
    resample_interval=10_000, n_steps_collate=10_000, max_resamples=5
)
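As a rough illustration, the first two steps of the resampling procedure described in the Anthropic paper are finding features that never fired over a collection window, and sampling inputs with probability proportional to their squared loss. The sketch below implements only those two steps in plain Python; the function names are hypothetical, and how `ActivationResampler` uses `n_steps_collate` and `resample_interval` internally is an assumption here.

```python
import random


def count_fires(learned_batches: list[list[list[float]]]) -> list[int]:
    """Count, per learned feature, how many inputs gave it a positive activation."""
    counts = [0] * len(learned_batches[0][0])
    for batch in learned_batches:
        for learned in batch:
            for i, activation in enumerate(learned):
                if activation > 0:
                    counts[i] += 1
    return counts


def find_dead_neurons(fire_counts: list[int]) -> list[int]:
    """Indices of learned features that never fired during the collection window."""
    return [i for i, count in enumerate(fire_counts) if count == 0]


def sample_inputs_for_resampling(losses: list[float], n_samples: int, rng: random.Random) -> list[int]:
    """Pick input indices with probability proportional to the squared loss."""
    weights = [loss**2 for loss in losses]
    return rng.choices(range(len(losses)), weights=weights, k=n_samples)


fire_counts = count_fires([[[0.0, 1.2, 0.0], [0.0, 0.0, 0.3]]])
dead = find_dead_neurons(fire_counts)  # feature 0 never fired
picked = sample_inputs_for_resampling([0.0, 0.0, 5.0, 0.0], n_samples=3, rng=random.Random(0))
```

In the paper, the sampled inputs are then used to re-initialise the dead features' weights, and the optimizer state for those weights is reset.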

### Source dataset

This is just a dataset of tokenized prompts, to be used in generating activations (which are in turn
used to train the SAE).

In [None]:
source_data = PreTokenizedDataset(
    dataset_path="NeelNanda/c4-code-tokenized-2b", context_size=int(hyperparameters["context_size"])
)
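`context_size` fixes how many tokens each prompt contributes to a source-model forward pass. The sketch below shows one common way to produce fixed-length inputs from a token stream: non-overlapping windows, with any incomplete remainder dropped. Whether `PreTokenizedDataset` chunks, truncates, or pads its rows this way is an assumption.

```python
def chunk_tokens(tokens: list[int], context_size: int) -> list[list[int]]:
    """Split a token stream into non-overlapping windows of exactly `context_size` tokens.

    Trailing tokens that do not fill a whole window are dropped.
    """
    n_windows = len(tokens) // context_size
    return [tokens[i * context_size : (i + 1) * context_size] for i in range(n_windows)]


windows = chunk_tokens(list(range(300)), context_size=128)
print(len(windows), [len(w) for w in windows])  # 2 [128, 128]
```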

## Training

If you initialise [wandb](https://wandb.ai/site), the pipeline will automatically log all metrics to
wandb. You should also pass in a dictionary of all your hyperparameters (via `config`) so that they
are recorded alongside the run.

We strongly encourage you to use wandb to understand the training process.

In [None]:
checkpoint_path = Path("../../.checkpoints")
checkpoint_path.mkdir(parents=True, exist_ok=True)  # also create any missing parent directories

In [None]:
Path(".cache/").mkdir(exist_ok=True)
wandb.init(
    project="sparse-autoencoder",
    dir=".cache",
    config=hyperparameters,
)

In [None]:
pipeline = Pipeline(
    activation_resampler=activation_resampler,
    autoencoder=autoencoder,
    cache_name=str(hyperparameters["source_model_hook_point"]),
    checkpoint_directory=checkpoint_path,
    layer=int(hyperparameters["source_model_hook_point_layer"]),
    loss=loss,
    optimizer=optimizer,
    source_data_batch_size=6,
    source_dataset=source_data,
    source_model=src_model,
)

pipeline.run_pipeline(
    train_batch_size=int(hyperparameters["train_batch_size"]),
    max_store_size=int(hyperparameters["max_store_size"]),
    max_activations=int(hyperparameters["max_activations"]),
    checkpoint_frequency=int(hyperparameters["checkpoint_frequency"]),
    validate_frequency=int(hyperparameters["validation_frequency"]),
)

In [None]:
wandb.finish()