# Sparse Autoencoder Training Demo

To train a sparse autoencoder, we need:
1. A model and a layer of that model on which to train the autoencoder.
2. A dataset that we can use to generate the activations.
3. An autoencoder to be trained.

To demonstrate, this notebook trains a sparse autoencoder on the
[Tiny-Stories-1M model](https://huggingface.co/roneneldan/TinyStories-1M).

To do so, we use the [Tiny Stories dataset](https://huggingface.co/datasets/roneneldan/TinyStories).

To see other models that can be loaded with `HookedTransformer`, see this [page](https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html) in the TransformerLens docs.




## Setup

### Imports

In [1]:
import os  # noqa: D100

import torch
from transformer_lens import HookedTransformer
from transformer_lens.utils import get_device
from transformers import PreTrainedTokenizerBase

from sparse_autoencoder import SparseAutoencoder, TensorActivationStore, pipeline
from sparse_autoencoder.source_data.text_dataset import GenericTextDataset
from sparse_autoencoder.train.sweep_config import SweepParametersRuntime


# Autoreload
%load_ext autoreload
%autoreload 2



os.environ["TOKENIZERS_PARALLELISM"] = "false"

device = get_device()
print(f"Using device: {device}") # You will need a GPU


### Source Model and AutoEncoder

In [2]:
src_model_name = "tiny-stories-1M"
src_model = HookedTransformer.from_pretrained(src_model_name, dtype="float32")
src_d_mlp: int = src_model.cfg.d_mlp  # type: ignore

# We choose to find features in the output of the first MLP layer.
src_model_activation_hook_point = "blocks.0.mlp.hook_post"
# Index of the layer we are hooking into. This could possibly be detected by default.
src_model_activation_layer = 0

print(f"Source model name: {src_model_name}")
print(f"Source model activation hook point: {src_model_activation_hook_point}")
print(f"Source model d_mlp: {src_d_mlp}") # We need the dimension of the activations we are autoencoding. 

# We can then instantiate the autoencoder
expansion_ratio = 8
autoencoder = SparseAutoencoder(
    n_input_features=src_d_mlp,  # Size of the activations we are autoencoding
    n_learned_features=src_d_mlp * expansion_ratio,  # Size of the SAE
    geometric_median_dataset=torch.zeros(src_d_mlp),  # Used to initialize the tied bias
)


Loaded pretrained model tiny-stories-1M into HookedTransformer
Source model name: tiny-stories-1M
Source model activation hook point: blocks.0.mlp.hook_post
Source model d_mlp: 256
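
The sizes used above follow from simple arithmetic. The sketch below recomputes them for the `d_mlp` of 256 reported in the output and the expansion ratio of 8. The weight count assumes one dense encoder matrix and one dense decoder matrix with no weight sharing and ignores biases; this is an assumption about the architecture, not something confirmed for the `SparseAutoencoder` class.

```python
# Sizes from the cell above: d_mlp is taken from the printed output.
src_d_mlp = 256
expansion_ratio = 8

# Width of the autoencoder's hidden (learned feature) layer.
n_learned_features = src_d_mlp * expansion_ratio

# Assumption: one dense encoder and one dense decoder matrix, untied, biases ignored.
encoder_weights = src_d_mlp * n_learned_features
decoder_weights = n_learned_features * src_d_mlp

print(f"Learned features: {n_learned_features}")
print(f"Encoder + decoder weights: {encoder_weights + decoder_weights:,}")
```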


### Source Dataset and Activation Store

In [None]:
tokenizer: PreTrainedTokenizerBase = src_model.tokenizer  # type: ignore

# We've implemented a generic wrapper around Hugging Face datasets.
# We'll use the training data for the Tiny Stories model.
source_data = GenericTextDataset(tokenizer=tokenizer, dataset_path="roneneldan/TinyStories")

# In practice, we load and shuffle data from the dataset.
# This ensures the data is well mixed and helps prevent overfitting.
# The optimal (or feasible) value of max_items will depend on your GPU memory.
max_items = 1_000_000
total_training_tokens = 10_000_000
store = TensorActivationStore(max_items, src_d_mlp, device)
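
To get a feel for how `max_items` relates to GPU memory, the following back-of-the-envelope sketch estimates the size of a full store. It assumes the store holds one float32 vector of width `d_mlp` (256 for this model) per item with no overhead, and one stored vector per token; how `TensorActivationStore` actually allocates memory is not confirmed here.

```python
max_items = 1_000_000
total_training_tokens = 10_000_000
src_d_mlp = 256  # Activation width reported for tiny-stories-1M
bytes_per_float32 = 4  # Assumption: activations are stored as float32

# Memory needed to hold a completely full store.
store_bytes = max_items * src_d_mlp * bytes_per_float32
store_gib = store_bytes / 2**30

# Number of times the store would be filled to reach the token budget,
# assuming one stored activation vector per token.
n_refills = total_training_tokens // max_items

print(f"Full store: {store_gib:.2f} GiB")
print(f"Store fills needed: {n_refills}")
```

Halving `max_items` halves this estimate, at the cost of shuffling over a smaller pool of activations.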

## Training Hyperparameters 

In [None]:
from sparse_autoencoder.train.sweep_config import SweepParametersRuntime


# Some of the training hyperparameters are passed through in the sweep parameters.
# The important thing is to set the L1 coefficient high enough to get sparsity (eventually),
# without compromising the reconstruction loss too much.
# Having a large batch size is important too.
training_hyperparameters = SweepParametersRuntime(
    lr=0.001,  # Learning rate
    l1_coefficient=0.001,  # Coefficient for the L1 regularization
    batch_size=4096,  # It is important that this be quite large.

    # Adam parameters (you don't usually need to change these)
    adam_beta_1=0.9,
    adam_beta_2=0.999,
    adam_epsilon=1e-8,
    adam_weight_decay=0.0,
)
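
To make the trade-off between sparsity and reconstruction concrete, here is a minimal sketch of a common sparse autoencoder objective: mean squared reconstruction error plus `l1_coefficient` times the L1 norm of the learned activations. The `TinySAE` class and `sae_loss` function are hypothetical illustrations, not the `sparse_autoencoder` library's implementation, whose exact architecture and loss may differ. The input batch is random noise standing in for real MLP activations.

```python
import torch
from torch import nn


class TinySAE(nn.Module):
    """Hypothetical one-hidden-layer sparse autoencoder (illustration only)."""

    def __init__(self, n_input_features: int, n_learned_features: int) -> None:
        super().__init__()
        self.encoder = nn.Linear(n_input_features, n_learned_features)
        self.decoder = nn.Linear(n_learned_features, n_input_features)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        learned = torch.relu(self.encoder(x))  # Non-negative feature activations
        return learned, self.decoder(learned)


def sae_loss(
    x: torch.Tensor,
    learned: torch.Tensor,
    reconstructed: torch.Tensor,
    l1_coefficient: float,
) -> torch.Tensor:
    # Reconstruction term: how well the input is recovered.
    reconstruction = (reconstructed - x).pow(2).mean()
    # Sparsity term: mean L1 norm of the learned activations per item.
    sparsity = learned.abs().sum(dim=-1).mean()
    return reconstruction + l1_coefficient * sparsity


torch.manual_seed(0)
sae = TinySAE(n_input_features=256, n_learned_features=2048)
batch = torch.randn(64, 256)  # Random stand-in for a (small) batch of activations
learned, reconstructed = sae(batch)
loss = sae_loss(batch, learned, reconstructed, l1_coefficient=0.001)
loss.backward()
print(f"loss: {loss.item():.4f}")
```

In this formulation, raising `l1_coefficient` puts more weight on the sparsity term, which pushes more activations towards zero at the cost of a higher reconstruction error.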


## Training

If you initialise [wandb](https://wandb.ai/site), the pipeline will automatically log all metrics to
wandb. However, we should also pass in a dictionary with all of our hyperparameters so that they
are recorded on wandb.

We strongly encourage users to make use of wandb in order to understand the training process.

In [None]:
config = {
    # Data params
    "model_name": src_model_name,
    "hook_point": src_model_activation_hook_point,
    "src_model_activation_hook_point": src_model_activation_hook_point,
    "src_model_activation_layer": src_model_activation_layer,

    # SAE params
    "activation_width": src_d_mlp,
    "expansion_ratio": expansion_ratio,

    # Training params
    "max_items": max_items,
    "training_tokens": total_training_tokens,

    # Other
    "device": device,
}

# Add the training hyperparameters to the config
config = config | training_hyperparameters.__dict__
config
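
The `|` operator used to merge the two dictionaries requires Python 3.9 or later. The sketch below shows the same merge on a small example, together with an equivalent spelling for older Python 3 versions. `FakeHyperparameters` is a hypothetical stand-in for the real hyperparameter object, and the values are illustrative only.

```python
from dataclasses import dataclass


@dataclass
class FakeHyperparameters:
    """Hypothetical stand-in for the real hyperparameter object."""

    lr: float = 0.001
    batch_size: int = 4096


base_config = {"model_name": "tiny-stories-1M", "batch_size": 1}
hyperparameters = FakeHyperparameters()

# Python 3.9+: dict union. On key collisions the right-hand side wins.
merged = base_config | vars(hyperparameters)

# Equivalent spelling that also works on older Python 3 versions.
merged_compat = {**base_config, **vars(hyperparameters)}

print(merged)
```

Because the right-hand side wins, any key that appears in both dictionaries ends up with the hyperparameter object's value.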

In [None]:
# Optional: skip this cell if you don't want to log to wandb.
import wandb
wandb.init(
    project="sparse-autoencoder",
    dir=".cache/wandb",
    name="demo",
    config=config,
)

In [None]:
pipeline(
    src_model=src_model,
    src_model_activation_hook_point=src_model_activation_hook_point,
    src_model_activation_layer=src_model_activation_layer,
    source_dataset=source_data,
    activation_store=store,
    num_activations_before_training=max_items,
    autoencoder=autoencoder,
    device=device,
    max_activations=total_training_tokens,
    sweep_parameters=training_hyperparameters,
)
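
As a rough guide to the length of the run above, the sketch below estimates the number of optimizer steps from the values set earlier. It assumes one optimizer step per full batch of `batch_size` activations and a single pass over `total_training_tokens` activations; this is an assumption about how the pipeline consumes data, not documented behaviour.

```python
total_training_tokens = 10_000_000
batch_size = 4096

# Number of complete batches, and the activations left over at the end.
full_batches = total_training_tokens // batch_size
leftover = total_training_tokens % batch_size

print(f"Full batches (approx. optimizer steps): {full_batches}")
print(f"Activations left over: {leftover}")
```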

In [None]:
wandb.finish()

## Training Advice

-- Unfinished --

- Check that the recovery loss is low while the L1 of the learned activations is low as well (usually < 20).
- You can't be sure the features are useful until you dig into them more.
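
One way to check the points above is to compute summary statistics on a batch of learned activations. The sketch below uses a hypothetical helper, `sparsity_stats`, and random stand-in data. It reports the mean L1 norm per item, the mean number of active features per item (L0), and the fraction of features that never fire in the batch. The metrics that the pipeline logs to wandb may be named and defined differently.

```python
import torch


def sparsity_stats(learned: torch.Tensor) -> dict[str, float]:
    """Summarise a (batch, n_learned_features) tensor of activations."""
    active = learned > 0
    return {
        # Mean over items of the summed absolute activations.
        "mean_l1": learned.abs().sum(dim=-1).mean().item(),
        # Mean over items of the number of non-zero features.
        "mean_l0": active.sum(dim=-1).float().mean().item(),
        # Fraction of features that are zero for every item in the batch.
        "dead_fraction": (~active.any(dim=0)).float().mean().item(),
    }


torch.manual_seed(0)
# Stand-in activations: ReLU of shifted noise, so most entries are zero.
learned = torch.relu(torch.randn(512, 2048) - 2.0)
# Force one feature to be dead in this batch for illustration.
learned[:, 0] = 0.0
stats = sparsity_stats(learned)
print(stats)
```

A feature that is dead on one batch may still fire on others, so statistics like these are more informative when aggregated over many batches.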

# Analysis

-- Unfinished --