In [None]:
try:
    # import google.colab # type: ignore
    # from google.colab import output
    %pip install sae-lens transformer-lens
except:
    from IPython import get_ipython  # type: ignore

    ipython = get_ipython()
    assert ipython is not None
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

In [None]:
import torch
import os

from sae_lens import (
    SAE,
    ActivationsStore,
    HookedSAETransformer,
    LanguageModelSAERunnerConfig,
    SAEConfig,
    SAETrainingRunner,
    upload_saes_to_huggingface,
)

# Pick the preferred compute backend: CUDA first, then Apple MPS, else CPU.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

print("Using device:", device)
# Turn off Hugging Face tokenizers' internal parallelism for this process.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [None]:
import os

from kaggle_secrets import UserSecretsClient
import wandb
import huggingface_hub

# Credentials come from outside the notebook; never paste tokens into a cell.
user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret("wandb_api_key")
wandb.login(key=secret_value_0)

# The Hugging Face token is only needed when uploading SAEs (train(..., upload=True)).
# Read it from the HF_TOKEN environment variable instead of hard-coding it: an
# empty string is not a valid token and makes huggingface_hub.login() fail.
huggingface_api = os.environ.get("HF_TOKEN", "")
if huggingface_api:
    huggingface_hub.login(token=huggingface_api)
else:
    print("HF_TOKEN not set; skipping Hugging Face login (uploads will not be authenticated).")

In [None]:
import torch
import os
from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner, upload_saes_to_huggingface

# Re-derive the device so this training cell can run on its own.
# Default to CPU and upgrade to an accelerator if one is present (CUDA preferred).
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"

print("Using device:", device)
# Turn off Hugging Face tokenizers' internal parallelism for this process.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

def train(layer, k, lr, upload=True, hf_repo_id=None):
    """Train one TopK SAE on GPT-2 small's MLP output at ``layer``.

    Args:
        layer: Index of the transformer block whose ``hook_mlp_out`` is trained on.
        k: Number of active latents kept by the TopK activation.
        lr: Adam learning rate.
        upload: If True, upload the newest saved checkpoint to the Hugging Face Hub.
        hf_repo_id: Target Hub repository. Defaults to
            ``anhtu77/sae-topk-{k}-gpt2-small`` so SAEs trained with different ``k``
            land in separate repos instead of all going to the ``topk-32`` repo.

    Raises:
        FileNotFoundError: If ``upload`` is True but no checkpoint was written under
            ``cfg.checkpoint_path``.
    """
    import gc

    total_training_steps = 10_000
    batch_size = 4096
    total_training_tokens = total_training_steps * batch_size

    lr_warm_up_steps = 0
    # 20% of the training steps are allotted to learning-rate decay.
    lr_decay_steps = total_training_steps // 5

    cfg = LanguageModelSAERunnerConfig(
        architecture="topk",
        model_name="gpt2-small",
        hook_name=f"blocks.{layer}.hook_mlp_out",
        hook_layer=layer,
        d_in=768,  # hidden width of GPT-2 small
        dataset_path="apollo-research/Skylion007-openwebtext-tokenizer-gpt2",
        is_dataset_tokenized=True,
        streaming=True,
        mse_loss_normalization=None,
        expansion_factor=16,
        b_dec_init_method="zeros",
        apply_b_dec_to_input=False,
        normalize_sae_decoder=False,
        scale_sparsity_penalty_by_decoder_norm=True,
        decoder_heuristic_init=True,
        init_encoder_as_decoder_transpose=True,
        normalize_activations="expected_average_only_in",
        lr=lr,
        adam_beta1=0.9,
        adam_beta2=0.999,
        lr_scheduler_name="constant",
        lr_warm_up_steps=lr_warm_up_steps,
        lr_decay_steps=lr_decay_steps,
        activation_fn_kwargs={"k": k},
        lp_norm=1.0,
        train_batch_size_tokens=batch_size,
        context_size=256,
        n_batches_in_buffer=64,
        training_tokens=total_training_tokens,
        store_batch_size_prompts=16,
        use_ghost_grads=False,
        feature_sampling_window=1000,
        dead_feature_window=1000,
        dead_feature_threshold=1e-4,
        log_to_wandb=True,
        wandb_project="gpt2small-mlp-out-saes",
        wandb_log_frequency=30,
        eval_every_n_wandb_logs=20,
        device=device,
        seed=42,
        n_checkpoints=0,
        checkpoint_path="checkpoints",
        dtype="float32",
    )

    sparse_autoencoder = SAETrainingRunner(cfg).run()

    if upload:
        # NOTE(review): assumes the runner writes its checkpoint directories under
        # cfg.checkpoint_path. os.listdir order is arbitrary, so choose the most
        # recently modified entry instead of whichever happens to be listed first.
        checkpoint_dirs = [
            os.path.join(cfg.checkpoint_path, name)
            for name in os.listdir(cfg.checkpoint_path)
        ]
        if not checkpoint_dirs:
            raise FileNotFoundError(
                f"No checkpoint found under {cfg.checkpoint_path!r}; nothing to upload."
            )
        path = max(checkpoint_dirs, key=os.path.getmtime)

        repo_id = hf_repo_id or f"anhtu77/sae-topk-{k}-gpt2-small"
        upload_saes_to_huggingface(
            saes_dict={cfg.hook_name: path},
            hf_repo_id=repo_id,
        )
        print(f"Uploaded SAE for layer {layer} to Hugging Face Hub.")

    # Drop the reference and release cached GPU memory so the next run of a sweep
    # does not start with the previous run's allocations still held.
    del sparse_autoencoder
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

def main():
    """Run the training sweep: one SAE per (k, lr, layer) combination.

    Uploading is disabled for every run in this sweep (``upload=False``).
    """
    layers = [7]
    ks = [16, 64]
    lrs = [5e-5]

    # Enumerate combinations in nested-loop order: k outermost, layer innermost.
    sweep = [(k, lr, layer) for k in ks for lr in lrs for layer in layers]
    for k, lr, layer in sweep:
        print(f"Training layer {layer}...")
        train(layer, k, lr, upload=False)
        print(f"Finished training layer {layer}.")


main()