# Sparse Autoencoder Training Demo

## Setup

In [None]:
# Autoreload
# %load_ext autoreload
# %autoreload 2

from sparse_autoencoder import (
    SparseAutoencoder,
    TensorActivationStore,
    pipeline,
    create_src_dataloader,
)
from transformer_lens import HookedTransformer
from sparse_autoencoder.src_data.datasets.neel_c4_tokenized import (
    collate_neel_c4_tokenized,
)
import torch
import wandb
import os

### Source Dataset

In [None]:
# Build the dataloader that feeds prompts to the source model. The dataset is
# identified by name, batches are assembled by the matching collate function,
# and the shuffle buffer size and seed are pinned to fixed values.
src_dataset_name = "NeelNanda/c4-code-tokenized-2b"
src_dataloader_options = {
    "collate_fn": collate_neel_c4_tokenized,
    "shuffle_buffer_size": 10_000,
    "random_seed": 0,
}
src_dataloader = create_src_dataloader(src_dataset_name, **src_dataloader_options)

### Source Model

In [None]:
# Load the pretrained source model whose activations the autoencoder is
# trained on, and record its `d_mlp` config value, which is used below to
# size both the activation store and the autoencoder.
src_model_name = "solu-1l"
src_model = HookedTransformer.from_pretrained(src_model_name)
src_model_cfg = src_model.cfg
src_d_mlp = src_model_cfg.d_mlp

### Activation Store

In [None]:
# Item count passed as the store's first argument; the same value is reused
# later as `num_activations_before_training` in the pipeline call.
max_items = 10_000

# Choose the device for the store. The notebook originally hard-coded "mps",
# which raises on machines without Apple's Metal backend. MPS is still
# preferred when available (so behaviour on Macs is unchanged); otherwise fall
# back to CUDA, then CPU, so the demo runs anywhere.
if torch.backends.mps.is_available():
    store_device = torch.device("mps")
elif torch.cuda.is_available():
    store_device = torch.device("cuda")
else:
    store_device = torch.device("cpu")

store = TensorActivationStore(max_items, src_d_mlp, store_device)

### Autoencoder

In [None]:
# Construct the sparse autoencoder: the first argument matches the source
# model's `d_mlp`, the second is eight times that value, and the third is an
# all-zeros vector of length `d_mlp`.
# NOTE(review): the third argument looks like an initial bias / centring
# tensor -- confirm against the SparseAutoencoder signature.
autoencoder_width_multiplier = 8
autoencoder = SparseAutoencoder(
    src_d_mlp,
    src_d_mlp * autoencoder_width_multiplier,
    torch.zeros(src_d_mlp),
)

## Training

In [None]:
# Set TOKENIZERS_PARALLELISM to "false" for this process so that the
# TOKENIZERS_PARALLELISM warning is not emitted during training.
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})

In [None]:
# Run the training pipeline with the objects built in the cells above.
# Activations are read from the hook point "blocks.0.mlp.hook_post" (layer 0)
# of the source model, and `num_activations_before_training` is set to
# `max_items`, the same value the activation store was constructed with.
pipeline_arguments = dict(
    src_model=src_model,
    src_model_activation_hook_point="blocks.0.mlp.hook_post",
    src_model_activation_layer=0,
    src_dataloader=src_dataloader,
    activation_store=store,
    num_activations_before_training=max_items,
    autoencoder=autoencoder,
)
pipeline(**pipeline_arguments)