# Train a Sparse Autoencoder on Llama 3 8B (single 4090)

This notebook mirrors the SAELens training flow but targets `meta-llama/Meta-Llama-3-8B` on a single 24GB 4090. The defaults aim to be friendly to limited VRAM: mid-layer residual-stream activations, moderate context length, and a mid-size SAE. Comments are written for ML beginners, with notes on how to downscale if you see OOMs.


In [1]:
# Quick device pick: CUDA if available, else Apple MPS, else CPU (the code still runs on CPU, just very slowly).
import torch

device = (
    "cuda" if torch.cuda.is_available() else
    "mps" if torch.backends.mps.is_available() else
    "cpu"
)
print(f"Using device: {device}")

# First run only: install dependencies and authenticate with Hugging Face.
# Llama 3 is a gated model, so accept its license on the model page before downloading.
# !pip install -U sae-lens "transformers>=4.40" datasets accelerate bitsandbytes
# !huggingface-cli login


Using device: cuda


In [3]:
# High-level knobs. Defaults shrunk to ease VRAM on a 24GB 4090.
# If you hit OOM, lower d_sae first, then train_batch_size_tokens, then context_size.
hook_layer = 15  # middle layer of Llama 3 8B (32 layers total)
# TransformerLens hook name: the residual stream entering block `hook_layer` (before its attention/MLP updates)
hook_name = f"blocks.{hook_layer}.hook_resid_pre"

# Model dims for Llama 3 8B
hidden_size = 4096  # d_model
mlp_width = 14336   # intermediate size

# SAE size: d_sae controls feature count.
d_in = hidden_size  # residual stream width

d_sae = 16384  # safer default; try 32768 if VRAM allows, 65536 if roomy

# Training shape controls memory directly.
context_size = 256          # shorter contexts are cheaper; try 512 if you have headroom
train_batch_size_tokens = 1024  # tokens per step; reduce to 512 if still OOM, raise to 2048 if comfy
total_training_tokens = 10_000_000  # scale down for quicker smoke test (e.g., 2-3M)

l1_coeff = 2.5  # sparsity strength; try 1.0 (denser) or 5.0 (sparser) to explore
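
The token budget translates into optimizer steps, which is the unit the 500-step warm-ups in the runner config (`lr_warm_up_steps`, `l1_warm_up_steps`) are measured in. A minimal sketch, assuming one optimizer step per `train_batch_size_tokens` activations; `schedule_summary` is a hypothetical helper, not a SAELens function:

```python
# Schedule arithmetic for the knobs in the cell above.
# Assumption: one optimizer step consumes train_batch_size_tokens activations,
# so the step count is the token budget divided (rounding down) by the batch size.
def schedule_summary(total_tokens: int, batch_tokens: int, warm_up_steps: int = 500):
    """Return (optimizer steps, fraction of those steps spent in warm-up)."""
    steps = total_tokens // batch_tokens
    return steps, warm_up_steps / steps

# Compare the default budget with a smoke-test budget at the same batch size.
for budget in (10_000_000, 2_000_000):
    steps, warm = schedule_summary(budget, batch_tokens=1024)
    print(f"{budget:>12,} tokens -> {steps:,} steps, {warm:.1%} in warm-up")
```

At 10M tokens the warm-ups take about 5% of the run, but a 2M-token smoke test spends about a quarter of its steps warming up, so shortening both warm-ups for quick runs may be worth considering.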


### Hook placement and naming
- TransformerLens (HookedTransformer): residual hooks (`blocks.L.hook_resid_pre` / `hook_resid_post`), MLP output (`blocks.L.hook_mlp_out`), attention output (`blocks.L.hook_attn_out`, or per-head `blocks.L.attn.hook_result`, which only fires when `use_attn_result` is enabled).
- Hugging Face models (AutoModelForCausalLM) have no named hook points; hook module outputs by path instead, such as `model.layers.L.mlp.down_proj` (the MLP output, width 4096), or translate to TL hook names when switching.
- Pick the hook that carries the signal you want and set `d_in` to that hook's width (4096 for the Llama 3 8B residual stream).

This notebook trains on the TL residual hook `blocks.<L>.hook_resid_pre`.
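
Choosing a hook site and choosing `d_in` go together. A small lookup sketch for Llama 3 8B; `hook_spec` is a hypothetical helper, and the names follow TransformerLens conventions as we understand them (verify against `model.hook_dict` on the loaded model before relying on them):

```python
# Hypothetical helper: map (layer, site) to a TransformerLens hook name and the
# activation width you would pass as d_in for Llama 3 8B.
HIDDEN_SIZE = 4096   # d_model (residual stream width)
MLP_WIDTH = 14336    # gated-MLP intermediate size

# Site -> (hook-name template, width).
_SITES = {
    "resid_pre": ("blocks.{layer}.hook_resid_pre", HIDDEN_SIZE),
    "resid_post": ("blocks.{layer}.hook_resid_post", HIDDEN_SIZE),
    "attn_out": ("blocks.{layer}.hook_attn_out", HIDDEN_SIZE),
    "mlp_out": ("blocks.{layer}.hook_mlp_out", HIDDEN_SIZE),
    "mlp_post": ("blocks.{layer}.mlp.hook_post", MLP_WIDTH),  # MLP hidden activations
}

def hook_spec(layer: int, site: str, n_layers: int = 32) -> tuple[str, int]:
    """Return (hook_name, d_in) for a block index and site key."""
    if not 0 <= layer < n_layers:
        raise ValueError(f"layer must be in [0, {n_layers}), got {layer}")
    if site not in _SITES:
        raise ValueError(f"unknown site {site!r}; choose from {sorted(_SITES)}")
    template, width = _SITES[site]
    return template.format(layer=layer), width

hook_name, d_in = hook_spec(15, "resid_pre")
print(hook_name, d_in)
```

Keeping the name and width in one table makes it harder to switch to an MLP-hidden hook while leaving `d_in` at the residual width.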


In [4]:
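# Optional 4-bit quantization to cut weight memory. Set use_4bit=True to enable.
# Note: TransformerLens' HookedTransformer may reject 4-bit loading kwargs; 4-bit is most likely to work with
# model_class_name="AutoModelForCausalLM" (see the tips at the end of the notebook).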
use_4bit = False

hf_load_kwargs = {
    "torch_dtype": torch.bfloat16 if device != "cpu" else torch.float32,
    "device_map": "auto",  # let HF place modules across available devices
}

if use_4bit:
    try:
        import bitsandbytes as bnb  # noqa: F401
        hf_load_kwargs.update(
            {
                "load_in_4bit": True,
                "bnb_4bit_quant_type": "nf4",
                "bnb_4bit_compute_dtype": torch.bfloat16,
                "bnb_4bit_use_double_quant": True,
            }
        )
        print("4-bit enabled via bitsandbytes.")
    except ImportError:
        print(
            "bitsandbytes not installed; falling back to full precision. "
            "Install with `pip install bitsandbytes` if you want 4-bit."
        )
        use_4bit = False
else:
    print("use_4bit is False; loading in full precision.")


use_4bit is False; loading in full precision.
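
Why precision matters on a 24 GB card: weight memory follows directly from the parameter count. A sketch that counts Llama 3 8B parameters from its architecture (vocab 128,256, 32 layers, 32 query heads, 8 key/value heads, untied embeddings; these dims come from the model's published config, not from this notebook) and converts to GiB. The nf4 row is a lower bound that ignores quantization constants and any layers left unquantized; `llama_param_count` is a hypothetical helper.

```python
# Parameter count for Llama 3 8B from its architecture dims, and the resulting
# weight memory. Weights only: no activations, KV cache, buffers, or optimizer state.
GIB = 2**30

def llama_param_count(vocab=128_256, d_model=4096, d_mlp=14_336, n_layers=32,
                      n_heads=32, n_kv_heads=8):
    head_dim = d_model // n_heads                         # 128
    kv_dim = n_kv_heads * head_dim                        # 1024: grouped-query attention
    attn = 2 * d_model * d_model + 2 * d_model * kv_dim   # q/o and k/v projections, no biases
    mlp = 3 * d_model * d_mlp                             # gate_proj, up_proj, down_proj
    norms = 2 * d_model                                   # two RMSNorm weight vectors per block
    embed_and_head = 2 * vocab * d_model                  # input embedding + untied lm_head
    final_norm = d_model
    return n_layers * (attn + mlp + norms) + embed_and_head + final_norm

BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2, "nf4 (lower bound)": 0.5}

n = llama_param_count()
print(f"parameters: {n:,}")
for name, nbytes in BYTES_PER_PARAM.items():
    print(f"{name:>18}: {n * nbytes / GIB:5.2f} GiB")
```

The bf16 row, about 15 GiB, is already well over half of the 23.52 GiB capacity reported in the OOM error from the training cell, before any activations, buffers, or SAE state are allocated.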


In [5]:
# Build the SAELens runner config. LanguageModelSAETrainingRunner is the maintained entrypoint.
from sae_lens import (
    LanguageModelSAERunnerConfig,
    LanguageModelSAETrainingRunner,
    StandardTrainingSAEConfig,
    LoggingConfig,
)

from datetime import datetime

# Tip: bfloat16 roughly halves memory versus float32 and is more numerically stable than fp16
# (it keeps float32's exponent range).
dtype = 'bfloat16' if device != 'cpu' else 'float32'
run_timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
checkpoint_path = f'checkpoints/{run_timestamp}_llama3_8b'
output_path = f'runs/{run_timestamp}_llama3_8b'

cfg = LanguageModelSAERunnerConfig(
    # Data + model
    model_name='meta-llama/Meta-Llama-3-8B',
    model_class_name='HookedTransformer',  # use TL model with hook points
    hook_name=hook_name,  # TL hook (residual before block updates)
    dataset_path='monology/pile-uncopyrighted',
    is_dataset_tokenized=False,
    streaming=True,
    dataset_trust_remote_code=True,
    context_size=context_size,
    prepend_bos=True,
    # SAE hyperparameters
    sae=StandardTrainingSAEConfig(
        d_in=d_in,
        d_sae=d_sae,
        apply_b_dec_to_input=False,
        normalize_activations='expected_average_only_in',
        l1_coefficient=l1_coeff,
        l1_warm_up_steps=500,
    ),
    # Training schedule
    train_batch_size_tokens=train_batch_size_tokens,
    training_tokens=total_training_tokens,
    n_batches_in_buffer=16,
    feature_sampling_window=1000,
    dead_feature_window=2000,
    dead_feature_threshold=1e-4,
    lr=3e-4,
    lr_scheduler_name='cosineannealing',
    lr_warm_up_steps=500,
    lr_decay_steps=0,
    adam_beta1=0.9,
    adam_beta2=0.98,
    # Logging/checkpoints
    logger=LoggingConfig(
        log_to_wandb=False,
        wandb_project='sae_lens_llama3_8b',
        wandb_log_frequency=20,
        eval_every_n_wandb_logs=50,
    ),
    n_checkpoints=3,
    checkpoint_path=checkpoint_path,
    output_path=output_path,
    save_final_checkpoint=True,
    # Compute + dtype
    device=device,
    act_store_device='with_model',
    dtype=dtype,
    autocast=(dtype != 'float32'),
    # Model loading hints
    model_from_pretrained_kwargs=hf_load_kwargs,
    seed=42,
)

cfg




LanguageModelSAERunnerConfig(sae=StandardTrainingSAEConfig(d_in=4096, d_sae=16384, dtype='float32', device='cpu', apply_b_dec_to_input=False, normalize_activations='expected_average_only_in', reshape_activations='none', metadata=SAEMetadata({'sae_lens_version': '6.22.3', 'sae_lens_training_version': '6.22.3'}), decoder_init_norm=0.1, l1_coefficient=2.5, lp_norm=1.0, l1_warm_up_steps=500), model_name='meta-llama/Meta-Llama-3-8B', model_class_name='HookedTransformer', hook_name='blocks.15.hook_resid_pre', hook_eval='NOT_IN_USE', hook_head_index=None, dataset_path='monology/pile-uncopyrighted', dataset_trust_remote_code=True, streaming=True, is_dataset_tokenized=False, context_size=256, use_cached_activations=False, cached_activations_path=None, from_pretrained_path=None, n_batches_in_buffer=16, training_tokens=10000000, store_batch_size_prompts=32, seqpos_slice=(None,), disable_concat_sequences=False, sequence_separator_token='bos', device='cuda', act_store_device='cuda', seed=42, dtype=
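
The printed config (`n_batches_in_buffer=16`, `store_batch_size_prompts=32`, `context_size=256`, `act_store_device='cuda'`) also fixes two GPU allocations besides the model. A back-of-envelope sketch, assuming the activation store holds one buffer of `n_batches_in_buffer × store_batch_size_prompts × context_size` activation vectors (the real store may hold more than one) and the SAE is trained with Adam, which keeps four copies of every parameter. `buffer_gib` and `sae_training_gib` are hypothetical helpers, not SAELens functions.

```python
# Back-of-envelope GPU memory for two config-driven allocations (not the model itself).
# Assumptions (not taken from SAELens internals): the activation store holds one buffer of
# n_batches_in_buffer * store_batch_size_prompts * context_size vectors of width d_in,
# and the SAE is a standard (W_enc, W_dec, b_enc, b_dec) autoencoder trained with Adam.
GIB = 2**30
BYTES = {"float32": 4, "bfloat16": 2}

def buffer_gib(n_batches_in_buffer, store_batch_size_prompts, context_size, d_in, dtype="float32"):
    """Size of one activation buffer in GiB."""
    n_vectors = n_batches_in_buffer * store_batch_size_prompts * context_size
    return n_vectors * d_in * BYTES[dtype] / GIB

def sae_training_gib(d_in, d_sae, dtype="float32"):
    """Weights + gradients + Adam first/second moments = 4 copies of every parameter."""
    n_params = 2 * d_in * d_sae + d_sae + d_in  # W_enc, W_dec, b_enc, b_dec
    return 4 * n_params * BYTES[dtype] / GIB

for dt in ("float32", "bfloat16"):
    print(f"activation buffer ({dt}): {buffer_gib(16, 32, 256, 4096, dt):.2f} GiB")
for d_sae in (8192, 16384, 32768):
    print(f"SAE d_sae={d_sae}: {sae_training_gib(4096, d_sae):.2f} GiB (float32 weights+grads+Adam)")
```

Added to roughly 15 GiB of bf16 weights, these leave only a few GiB for forward-pass activations, which is consistent with the OOM in the training cell; shrinking `d_sae` or the store knobs frees memory roughly in proportion.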

In [6]:
# Kick off training.
# The first run downloads the model weights (about 16 GB in bf16) and streams the dataset. Expect long runtimes;
# lower total_training_tokens (e.g., 2-3M) to get a quick sanity-check SAE in under an hour on a 4090.

runner = LanguageModelSAETrainingRunner(cfg)

# Make the runner config JSON-safe before saving checkpoints (torch.dtype is not JSON serializable)
import torch
runner.cfg.model_from_pretrained_kwargs = {
    k: (str(v) if isinstance(v, torch.dtype) else v) for k, v in hf_load_kwargs.items()
}

sparse_autoencoder = runner.run()

# After training, persist the SAE for reuse.
from pathlib import Path
save_dir = Path(cfg.output_path) / "final_sae"
save_dir.mkdir(parents=True, exist_ok=True)
_ = sparse_autoencoder.save_model(save_dir)
print(f"Saved SAE weights + cfg to {save_dir}")


`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|██████████| 4/4 [00:02<00:00,  1.67it/s]


Loaded pretrained model meta-llama/Meta-Llama-3-8B into HookedTransformer


Training SAE:   0%|          | 0/10000000 [00:00<?, ?it/s]
Estimating norm scaling factor: 100%|██████████| 1000/1000 [01:34<00:00, 10.62it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 256.00 MiB. GPU 0 has a total capacity of 23.52 GiB of which 182.44 MiB is free. Process 2963 has 488.00 MiB memory in use. Including non-PyTorch memory, this process has 22.81 GiB memory in use. Of the allocated memory 21.74 GiB is allocated by PyTorch, and 629.18 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

## Tips for adjusting to your hardware

- If you OOM:
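  - The defaults above (`d_sae=16384`, `train_batch_size_tokens=1024`, `context_size=256`) still ran out of memory in the run shown here. Drop `d_sae` further (e.g., 8192), then lower `train_batch_size_tokens` (e.g., 512) or `context_size` (e.g., 128).
  - Shrink the activation store: lower `n_batches_in_buffer` (16 here) and/or `store_batch_size_prompts` (32 in the printed config). The buffer holds roughly their product times `context_size` activation vectors of width `d_in`.
  - Set `model_class_name="AutoModelForCausalLM"` and add `"device_map": "auto", "load_in_4bit": True` inside `model_from_pretrained_kwargs` to quantize the weights with bitsandbytes and let Accelerate offload layers if needed (slower but lighter). `hook_name` must then be a Hugging Face module path rather than a TL hook name.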
- If training is too slow, lower `total_training_tokens` for a quick-and-dirty SAE, then rerun longer later.
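- To target a different layer, change `hook_layer`; to target a different site, change `hook_name`. Either way, set `d_in` to that hook's width: residual-stream hooks and `hook_mlp_out` are `hidden_size` (4096), while the MLP hidden activations (`blocks.L.mlp.hook_post`) use the intermediate size (14336).
- For sparser codes (fewer active features per token), raise `l1_coeff`; for denser codes, lower it. To give the SAE more distinct features to work with, increase `d_sae` if memory allows.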
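
The `l1_coeff` tips follow from the shape of the training objective. Below is a common standard-SAE formulation, written with `apply_b_dec_to_input=False` as configured above; treat the decoder-norm weighting of the sparsity term and the warm-up schedule as assumptions to check against SAELens' `StandardTrainingSAE` source:

$$
\begin{aligned}
f(x) &= \mathrm{ReLU}\left(W_{\text{enc}}\, x + b_{\text{enc}}\right) \in \mathbb{R}^{d_{\text{sae}}} \\
\hat{x} &= W_{\text{dec}}\, f(x) + b_{\text{dec}} \in \mathbb{R}^{d_{\text{in}}} \\
\mathcal{L}(x) &= \lVert x - \hat{x} \rVert_2^2 \;+\; \lambda_t \sum_{i=1}^{d_{\text{sae}}} f_i(x)\, \lVert W_{\text{dec},:,i} \rVert_2
\end{aligned}
$$

Here $x$ is the (normalized) activation at `hook_name`, $W_{\text{enc}} \in \mathbb{R}^{d_{\text{sae}} \times d_{\text{in}}}$, column $i$ of $W_{\text{dec}} \in \mathbb{R}^{d_{\text{in}} \times d_{\text{sae}}}$ is feature $i$'s decoder direction, and $\lambda_t$ ramps from 0 up to `l1_coeff` over `l1_warm_up_steps`. Raising $\lambda$ makes every active feature cost more, so fewer features fire per token (lower L0) at the price of higher reconstruction error; lowering it does the reverse.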


### References (latest SAELens docs)
- API docs: <https://decoderesearch.github.io/SAELens/latest/api/#sae_lens.LanguageModelSAETrainingRunner>
- Training guide overview: <https://decoderesearch.github.io/SAELens/latest/training_saes/>
- Runner basics: <https://decoderesearch.github.io/SAELens/latest/training_saes/#basic-training-setup>
- Checkpointing: <https://decoderesearch.github.io/SAELens/latest/training_saes/#checkpoints>

These pages (v6.22.3) confirm the current entrypoint is `LanguageModelSAETrainingRunner` and describe logger, scheduler, and checkpoint fields. If the docs update, search the site index (same domain) for "training_saes" or "LanguageModelSAETrainingRunner" to grab the newest recommendations.


**Dataset note:** `togethercomputer/RedPajama-Data-1T` requires a config name (e.g., `default`, `book`, `c4`), and the current SAELens runner doesn't expose a config field, so this notebook uses `monology/pile-uncopyrighted`, which loads without one. To use RedPajama instead, edit `dataset_path` and supply the config name (e.g., `name="default"`) through a custom dataset loader, or pre-download the data and pass it as `override_dataset` when constructing the runner.