# Notebook with Example Config for Different Models / Hooks

# Warning: This notebook is a work in progress and may not reflect currently valid or optimal hyperparameters.
# We hope to provide more thorough training examples and advice soon.

## Setup

In [None]:
import torch
import os
import sys

sys.path.append("..")

from sae_lens.training.config import LanguageModelSAERunnerConfig
from sae_lens.training.lm_runner import language_model_sae_runner

# Choose the compute backend in order of preference: CUDA, then Apple MPS, then CPU.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

print("Using device:", device)
# Turn off Hugging Face tokenizers' internal parallelism (commonly done to avoid
# fork-related warnings when other workers/processes are spawned).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Tiny Stories - 1L

## MLP Out

In [None]:
# Train an SAE on the MLP output of the single-layer Tiny Stories model.
cfg = LanguageModelSAERunnerConfig(
    # Data Generating Function (Model + Training Distribution)
    model_name="tiny-stories-1L-21M",  # our model (more options here: https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html)
    hook_point="blocks.0.hook_mlp_out",  # A valid hook point (see more details here: https://neelnanda-io.github.io/TransformerLens/generated/demos/Main_Demo.html#Hook-Points)
    hook_point_layer=0,  # Only one layer in the model, so the hook is in layer 0.
    d_in=1024,  # the width of the MLP output.
    dataset_path="apollo-research/roneneldan-TinyStories-tokenizer-gpt2",  # a pre-tokenized Hugging Face dataset of the Tiny Stories corpus.
    is_dataset_tokenized=True,  # the dataset above is already tokenized.
    # SAE Parameters
    mse_loss_normalization=None,  # We won't normalize the MSE loss.
    expansion_factor=16,  # the width of the SAE as a multiple of d_in. Wider can give better stats but trains more slowly.
    b_dec_init_method="geometric_median",  # initialize the decoder bias (b_dec) with the geometric median.
    apply_b_dec_to_input=False,  # don't apply the decoder bias (b_dec) to the SAE input.
    # Training Parameters
    lr=0.0008,  # fairly high to speed up the tutorial; lower values may train more stably but more slowly.
    lr_scheduler_name="constant",  # constant learning rate after warm-up. Better schedules may exist.
    # NOTE(review): 25M training tokens / 4096 per batch is only ~6.1k batches, fewer than
    # lr_warm_up_steps below; confirm whether the LR is meant to still be warming up at the end.
    lr_warm_up_steps=10000,  # warm-up can help avoid too many dead features early on.
    l1_coefficient=0.0015,  # controls how sparse the feature activations are.
    lp_norm=1.0,  # use an L1 penalty (p = 1) rather than an Lp penalty with p < 1.
    train_batch_size=4096,
    context_size=128,  # controls the length of the prompts we feed to the model. Longer is better but slower.
    # Activation Store Parameters
    n_batches_in_buffer=64,  # controls how many activations we store / shuffle.
    training_tokens=1_000_000 * 25,  # 25 million tokens: a short run suitable for a tutorial.
    finetuning_method="decoder",
    finetuning_tokens=1_000_000 * 25,  # 25 million fine-tuning tokens.
    store_batch_size=32,
    # Resampling protocol
    use_ghost_grads=False,  # ghost grads are disabled for this run.
    feature_sampling_window=1000,  # controls our reporting of feature sparsity stats.
    dead_feature_window=1000,  # would affect resampling or ghost grads if we were using them.
    dead_feature_threshold=1e-4,  # would affect resampling or ghost grads if we were using them.
    # WANDB
    log_to_wandb=True,  # always use wandb unless you are just testing code.
    wandb_project="sae_lens_tutorial",
    wandb_log_frequency=10,
    # Misc
    device=device,  # selected in the setup cell.
    seed=42,
    n_checkpoints=0,
    checkpoint_path="checkpoints",
    dtype=torch.float32,
)

# Run training; this can take a while. Metrics are logged to wandb (log_to_wandb=True).
sparse_autoencoder_dictionary = language_model_sae_runner(cfg)

# GPT2 - Small

### Residual Stream

In [None]:
# Train an SAE on the residual stream entering block 8 of GPT-2 small.
cfg = LanguageModelSAERunnerConfig(
    # Data Generating Function (Model + Training Distribution)
    model_name="gpt2-small",
    hook_point="blocks.8.hook_resid_pre",  # residual stream before block 8.
    hook_point_layer=8,  # must agree with the layer in hook_point.
    d_in=768,  # the width of the hooked residual-stream activation.
    dataset_path="apollo-research/Skylion007-openwebtext-tokenizer-gpt2",  # pre-tokenized OpenWebText.
    is_dataset_tokenized=True,  # the dataset above is already tokenized.
    prepend_bos=True,  # should experiment with turning this off.
    # SAE Parameters
    expansion_factor=32,  # the width of the SAE as a multiple of d_in.
    b_dec_init_method="geometric_median",  # initialize the decoder bias (b_dec) with the geometric median; slower to start than simpler inits.
    apply_b_dec_to_input=False,  # don't apply the decoder bias (b_dec) to the SAE input.
    # Training Parameters
    adam_beta1=0,  # no momentum in Adam's first-moment estimate.
    adam_beta2=0.999,
    lr=0.0004,
    l1_coefficient=0.008,  # controls how sparse the feature activations are.
    lr_scheduler_name="constant",  # constant learning rate after warm-up.
    train_batch_size=4096,
    context_size=256,  # length of the prompts fed to the model.
    lr_warm_up_steps=5000,
    # Activation Store Parameters
    n_batches_in_buffer=128,  # controls how many activations we store / shuffle.
    training_tokens=1_000_000 * 200,  # 200M tokens seems doable overnight.
    finetuning_method="decoder",
    finetuning_tokens=1_000_000 * 100,  # 100 million fine-tuning tokens.
    store_batch_size=32,
    # Resampling protocol
    use_ghost_grads=False,  # ghost grads are disabled for this run.
    feature_sampling_window=2500,  # controls our reporting of feature sparsity stats.
    dead_feature_window=5000,  # would affect resampling or ghost grads if we were using them.
    dead_feature_threshold=1e-8,  # would affect resampling or ghost grads if we were using them.
    # WANDB
    log_to_wandb=True,
    wandb_project="gpt2_small_experiments_april",
    wandb_entity=None,  # presumably falls back to the default wandb entity — confirm.
    wandb_log_frequency=100,
    # Misc
    device=device,  # selected in the setup cell.
    seed=42,
    n_checkpoints=5,
    checkpoint_path="checkpoints",
    dtype=torch.float32,
)

# Run training (long: 200M training tokens plus fine-tuning).
sparse_autoencoder = language_model_sae_runner(cfg)