In [2]:
import torch
import os

from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner

# Choose the compute device in priority order: CUDA, then Apple MPS, then CPU.
# The MPS check is only evaluated when CUDA is unavailable.
device = (
    "cuda"
    if torch.cuda.is_available()
    else ("mps" if torch.backends.mps.is_available() else "cpu")
)
print("Using device:", device)

# Set TOKENIZERS_PARALLELISM to "false" for any library that reads it.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

Using device: mps


In [4]:
from transformer_lens import HookedTransformer

# Load the pretrained "tiny-stories-1L-21M" checkpoint as a HookedTransformer.
model = HookedTransformer.from_pretrained("tiny-stories-1L-21M")



Loaded pretrained model tiny-stories-1L-21M into HookedTransformer


In [5]:
# Sample 5 completions of the same prompt with model.generate at temperature 1
# and display each one.
for sample_idx in range(5):
    completion = model.generate(
        "Once upon a time",
        stop_at_eos=False,  # avoids a bug on MPS
        temperature=1,
        verbose=False,
        max_new_tokens=50,
    )
    display(completion)

'Once upon a time, there was a swan. The swan was very big and beautiful, white, with fluffy feathers. But one day, someone was very frustrated because the swan could not fly. \nThe next day, it met a man who was'

'Once upon a time there was a new, attractive bird in a bright blue sky. He had seen many very many birds, but none of them knew how to fly.\n\nSo he went to the navy blue bird very carefully to follow it. He followed the white'

'Once upon a time, there was a little girl named Lily. Lily was feeling very sorry and so she went to her room to become more careful. She knocked on the door and the big door opened. \n\nIn the dark attic, Lily saw many things that'

'Once upon a time, there was a girl who lived in a big house. She was three years old and loved to play outside. One day she wanted to go outside to play. So she took off her shirt and twirled around until it was very high. '

'Once upon a time, there was a bird. He had soft fur and he loved to fly and sing. One day, he found an folder by the river. He opened it and saw a treasure chest. But it was too heavy for him.\n\nThe hipp'

In [6]:
from transformer_lens.utils import test_prompt

# Test the model with a prompt: report how the answer token " Lily" ranks
# as the continuation of the story prefix.
story_prefix = "Once upon a time, there was a little girl named Lily. She lived in a big, happy little girl. On her big adventure,"
expected_answer = " Lily"  # already carries its leading space
test_prompt(story_prefix, expected_answer, model, prepend_space_to_answer=False)

Tokenized prompt: ['<|endoftext|>', 'Once', ' upon', ' a', ' time', ',', ' there', ' was', ' a', ' little', ' girl', ' named', ' Lily', '.', ' She', ' lived', ' in', ' a', ' big', ',', ' happy', ' little', ' girl', '.', ' On', ' her', ' big', ' adventure', ',']
Tokenized answer: [' Lily']


Top 0th token. Logit: 20.48 Prob: 71.06% Token: | she|
Top 1th token. Logit: 18.81 Prob: 13.46% Token: | Lily|
Top 2th token. Logit: 17.35 Prob:  3.11% Token: | the|
Top 3th token. Logit: 17.26 Prob:  2.86% Token: | her|
Top 4th token. Logit: 16.74 Prob:  1.70% Token: | there|
Top 5th token. Logit: 16.43 Prob:  1.25% Token: | they|
Top 6th token. Logit: 15.80 Prob:  0.66% Token: | all|
Top 7th token. Logit: 15.64 Prob:  0.56% Token: | things|
Top 8th token. Logit: 15.28 Prob:  0.39% Token: | one|
Top 9th token. Logit: 15.24 Prob:  0.38% Token: | lived|


In [7]:
import circuitsvis as cv  # optional dep, install with pip install circuitsvis

# Let's make a longer prompt and see the log probabilities of the tokens
example_prompt = """Hi, how are you doing this? I'm really enjoying your posts"""

# One forward pass gives both the logits and the activation cache.
logits, cache = model.run_with_cache(example_prompt)

cv.logits.token_log_probs(
    model.to_tokens(example_prompt),
    # Reuse the logits computed above rather than running the model a second
    # time; [0] picks the first batch entry (a single prompt was passed).
    logits[0].log_softmax(dim=-1),
    model.to_string,
)
# hover on the output to see the result.

In [8]:
# Generate a longer story to use as the prompt for the visualisation.
example_prompt = model.generate(
    "Once upon a time",
    stop_at_eos=False,  # avoids a bug on MPS
    temperature=1,
    verbose=True,
    max_new_tokens=200,
)

# One forward pass gives both the logits and the activation cache.
logits, cache = model.run_with_cache(example_prompt)

cv.logits.token_log_probs(
    model.to_tokens(example_prompt),
    # Reuse the logits computed above rather than running the model a second
    # time; [0] picks the first batch entry (a single prompt was passed).
    logits[0].log_softmax(dim=-1),
    model.to_string,
)

100%|█████████████████████████████████████████| 200/200 [00:36<00:00,  5.55it/s]


In [9]:
# Training budget: total tokens = number of training steps * tokens per training batch.
total_training_steps = 30_000  # probably we should do more
batch_size = 4096
total_training_tokens = total_training_steps * batch_size  # 30_000 * 4096 = 122_880_000

# Schedule lengths, expressed in training steps.
lr_warm_up_steps = 0  # no learning-rate warm-up
lr_decay_steps = total_training_steps // 5  # 20% of training
l1_warm_up_steps = total_training_steps // 20  # 5% of training

cfg = LanguageModelSAERunnerConfig(
    # Data Generating Function (Model + Training Distribution)
    model_name="tiny-stories-1L-21M",  # our model (more options here: https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html)
    hook_name="blocks.0.hook_mlp_out",  # A valid hook point (see more details here: https://neelnanda-io.github.io/TransformerLens/generated/demos/Main_Demo.html#Hook-Points)
    hook_layer=0,  # Only one layer in the model.
    d_in=1024,  # the width of the mlp output.
    dataset_path="apollo-research/roneneldan-TinyStories-tokenizer-gpt2",  # this is a tokenized language dataset on Huggingface for the Tiny Stories corpus.
    is_dataset_tokenized=True,
    streaming=True,  # we could pre-download the token dataset if it was small.
    # SAE Parameters
    mse_loss_normalization=None,  # We won't normalize the mse loss,
    expansion_factor=16,  # sets the width of the SAE (presumably d_in * 16 = 16384 features, matching the recorded run name). Larger will result in better stats but slower training.
    b_dec_init_method="zeros",  # b_dec is initialised with zeros here; the geometric median is an alternative initialisation.
    apply_b_dec_to_input=False,  # We won't apply b_dec to the input.
    normalize_sae_decoder=False,
    scale_sparsity_penalty_by_decoder_norm=True,
    decoder_heuristic_init=True,
    init_encoder_as_decoder_transpose=True,
    normalize_activations="expected_average_only_in",
    # Training Parameters
    lr=5e-5,  # lower the better, we'll go fairly high to speed up the tutorial.
    adam_beta1=0.9,  # adam params (default, but once upon a time we experimented with these.)
    adam_beta2=0.999,
    lr_scheduler_name="constant",  # constant learning rate. Could be better schedules out there.
    lr_warm_up_steps=lr_warm_up_steps,  # warm-up can help avoid too many dead features initially; it is 0 (disabled) in this run.
    lr_decay_steps=lr_decay_steps,  # this will help us avoid overfitting.
    l1_coefficient=5,  # will control how sparse the feature activations are
    l1_warm_up_steps=l1_warm_up_steps,  # this can help avoid too many dead features initially.
    lp_norm=1.0,  # the L1 penalty (and not a Lp for p < 1)
    train_batch_size_tokens=batch_size,
    context_size=512,  # will control the length of the prompts we feed to the model. Larger is better but slower.
    # Activation Store Parameters
    n_batches_in_buffer=64,  # controls how many activations we store / shuffle.
    training_tokens=total_training_tokens,  # about 123 million tokens (30_000 steps * 4096) is quite a few, but we want to see good stats. Get a coffee, come back.
    store_batch_size_prompts=16,
    # Resampling protocol
    use_ghost_grads=False,  # we don't use ghost grads anymore.
    feature_sampling_window=1000,  # this controls our reporting of feature sparsity stats
    dead_feature_window=1000,  # would affect resampling or ghost grads if we were using it.
    dead_feature_threshold=1e-4,  # would affect resampling or ghost grads if we were using it.
    # WANDB
    log_to_wandb=True,  # always use wandb unless you are just testing code.
    wandb_project="sae_lens_tutorial",
    wandb_log_frequency=30,
    eval_every_n_wandb_logs=20,
    # Misc
    device=device,
    seed=42,
    n_checkpoints=0,
    checkpoint_path="checkpoints",
    dtype="float32"
)
# NOTE(review): the output recorded for this cell (device "mps") ends in
# `RuntimeError: Invalid buffer size: 8.00 GB` while estimating the norm scaling
# factor. Presumably shrinking the activation-store settings (n_batches_in_buffer,
# store_batch_size_prompts, context_size) or using a different device avoids it
# -- TODO confirm which allocation hits the limit.
# look at the next cell to see some instruction for what to do while this is running.
sparse_autoencoder = SAETrainingRunner(cfg).run()

Run name: 16384-L1-5-LR-5e-05-Tokens-1.229e+08
n_tokens_per_buffer (millions): 0.524288
Lower bound: n_contexts_per_buffer (millions): 0.001024
Total training steps: 30000
Total wandb updates: 1000
n_tokens_per_feature_sampling_window (millions): 2097.152
n_tokens_per_dead_feature_window (millions): 2097.152
We will reset the sparsity calculation 30 times.
Number tokens in sparsity calculation window: 4.10e+06
Loaded pretrained model tiny-stories-1L-21M into HookedTransformer


Downloading readme: 100%|██████████████████████| 415/415 [00:00<00:00, 7.39kB/s]
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

  ········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /Users/shiyang/.netrc


  self.scaler = torch.cuda.amp.GradScaler(enabled=self.cfg.autocast)
Training SAE:   0%|                               | 0/122880000 [00:00<?, ?it/s]
Estimating norm scaling factor:   0%|                  | 0/1000 [01:15<?, ?it/s][A


RuntimeError: Invalid buffer size: 8.00 GB

In [None]:
# Upload trained SAEs to Huggingface
from sae_lens import SAE

# Load some SAEs from the gpt2-small-res-jb release
# In practice you'll want to train your own SAEs to upload to huggingface
# Both loads share the same release and target device; only the SAE id differs.
# Only the first element of what from_pretrained returns is kept.
pretrained_kwargs = {"release": "gpt2-small-res-jb", "device": "cpu"}
layer_0_sae = SAE.from_pretrained(
    sae_id="blocks.0.hook_resid_pre", **pretrained_kwargs
)[0]
layer_1_sae = SAE.from_pretrained(
    sae_id="blocks.1.hook_resid_pre", **pretrained_kwargs
)[0]

In [None]:
from sae_lens import upload_saes_to_huggingface

# Save the layer-1 SAE to a local path so it can be uploaded by path.
layer_1_sae_path = "layer_1_sae"
layer_1_sae.save_model(layer_1_sae_path)

# Map each SAE id to the thing to upload under that id.
saes_dict = {}
saes_dict["blocks.0.hook_resid_pre"] = layer_0_sae  # values can be an SAE object
saes_dict["blocks.1.hook_resid_pre"] = layer_1_sae_path  # or a path to a saved SAE

# change hf_repo_id to your own huggingface username and repo
upload_saes_to_huggingface(saes_dict, hf_repo_id="chanind/sae-lens-upload-demo")

In [None]:
# Loading our uploaded SAEs
# Both loads read from the same uploaded repo onto the CPU; only the SAE id
# differs, and only the first element of the returned value is kept.
uploaded_kwargs = {"release": "chanind/sae-lens-upload-demo", "device": "cpu"}

uploaded_sae_layer_0 = SAE.from_pretrained(
    sae_id="blocks.0.hook_resid_pre", **uploaded_kwargs
)[0]

uploaded_sae_layer_1 = SAE.from_pretrained(
    sae_id="blocks.1.hook_resid_pre", **uploaded_kwargs
)[0]