# Reusing the TinyStories SAE

This notebook loads the SAE trained in `tutorial-training-sae.ipynb` (saved under `output/`) and shows how to attach it to the model, peek at sparsity stats, and run a quick comparison with/without SAE reconstruction.

## How to read this notebook
- Load the TinyStories SAE and show where it hooks (`hook_name`, dims, dtype/device).
- Check saved sparsity and a quick prompt loss comparison with/without reconstruction.
- Inspect which SAE features fire on a prompt and how a few features tilt the logits.
- Use the reconstruction metrics cell to spot if a new checkpoint looks off (high MSE or low explained variance).


In [16]:
# Basic imports and a simple device picker.
# Feel free to switch to CPU if you don't have a GPU handy.
from pathlib import Path
import contextlib
import re
import torch
from sae_lens import SAE, HookedSAETransformer
from safetensors.torch import load_file

# Choose a device automatically so the rest of the notebook just works.
device = (
    "cuda" if torch.cuda.is_available() else 
    "mps" if torch.backends.mps.is_available() else 
    "cpu"
)
print(f"Using device: {device}")

sae_dir = Path("output")
assert sae_dir.exists(), "The SAE folder from training should live in ./output"

Using device: cuda


In [17]:
# Load the TinyStories model and the trained SAE from disk.
# HookedSAETransformer is a drop-in replacement for HookedTransformer that knows how to host SAEs.

model = HookedSAETransformer.from_pretrained("tiny-stories-1L-21M", device=device)
model.eval();

sae = SAE.load_from_disk(sae_dir, device=device)
sae.eval();

def tl_hook_name_from_hf(path: str) -> str:
    m = re.fullmatch(r"model\.layers\.(\d+)\.mlp\.down_proj", path)
    if m:
        return f"blocks.{m.group(1)}.hook_mlp_out"
    m = re.fullmatch(r"model\.layers\.(\d+)\.self_attn\.o_proj", path)
    if m:
        return f"blocks.{m.group(1)}.attn.hook_result"
    return path

sae.cfg.metadata.hook_name = tl_hook_name_from_hf(sae.cfg.metadata.hook_name)

print(f"SAE attaches to hook: {sae.cfg.metadata.hook_name}")
print(f"Latent dims: {sae.cfg.d_sae}")
print(f"SAE dtype/device: {sae.dtype} on {sae.device}")


Loaded pretrained model tiny-stories-1L-21M into HookedTransformer
SAE attaches to hook: blocks.0.hook_mlp_out
Latent dims: 16384
SAE dtype/device: torch.float32 on cuda


In [18]:
# Quick look at sparsity statistics that were saved after training.
# The tensor stores log sparsity per feature; exp(log_sparsity) is the probability a feature is active on a token.
log_sparsity = load_file(sae_dir / "sparsity.safetensors")["sparsity"]
est_l0 = torch.exp(log_sparsity).mean().item() * log_sparsity.numel()
print(f"Estimated active features per token (L0): {est_l0:.1f} out of {log_sparsity.numel()}")


Estimated active features per token (L0): 1256.2 out of 16384


In [19]:
# Helper utilities so we can flip the SAE on/off easily.
# The context manager model.saes(...) temporarily patches the model to route activations through the SAE.
def generate_text(prompt: str, with_sae: bool = False, **gen_kwargs) -> str:
    ctx = model.saes(saes=[sae]) if with_sae else contextlib.nullcontext()
    with torch.no_grad(), ctx:
        return model.generate(
            prompt,
            max_new_tokens=80,
            temperature=0.7,
            stop_at_eos=True,
            verbose=False,
            **gen_kwargs,
        )


def loss_on_prompt(prompt: str, with_sae: bool = False) -> float:
    ctx = model.saes(saes=[sae]) if with_sae else contextlib.nullcontext()
    with torch.no_grad(), ctx:
        tokens = model.to_tokens(prompt, prepend_bos=True)
        logits = model(tokens)
        # Compare predicted next tokens to the ground truth tokens.
        loss = torch.nn.functional.cross_entropy(
            logits[0, :-1].reshape(-1, logits.size(-1)),
            tokens[0, 1:].reshape(-1),
        )
    return loss.item()


In [20]:
# Compare generations with and without SAE reconstruction.
prompt = "Once upon a time, a curious robot learned to"

base_story = generate_text(prompt, with_sae=False)
sae_story = generate_text(prompt, with_sae=True)

print("--- Base model ---")
print(base_story)
print("--- With SAE reconstruction ---")
print(sae_story)

base_loss = loss_on_prompt(prompt, with_sae=False)
sae_loss = loss_on_prompt(prompt, with_sae=True)
print(f"Next-token loss without SAE: {base_loss:.3f}")
print(f"Next-token loss with SAE:    {sae_loss:.3f}")


--- Base model ---
Once upon a time, a curious robot learned to behave. He thanked the tank and drove away, but he was still very careful. He never forgot the lesson he learned that day that being careless can be dangerous.

--- With SAE reconstruction ---
Once upon a time, a curious robot learned to be careful and not touch it.

One day, a little girl named Lucy heard a loud noise. She was scared and wanted to go to the park to play. She ran to the swings and grabbed the handle. She pulled it out of the box, and the buttons were very happy.

She put the new toy in the box and went to the park. She had a lot
Next-token loss without SAE: 3.889
Next-token loss with SAE:    3.731


In [21]:
# Peek at which SAE features fire on the last token of the prompt.
# We grab the MLP activations, encode them with the SAE, and show the top activations.
with torch.no_grad():
    tokens = model.to_tokens(prompt, prepend_bos=True)
    _, cache = model.run_with_cache(tokens, names_filter=[sae.cfg.metadata.hook_name])
    mlp_out = cache[sae.cfg.metadata.hook_name]
    feature_acts = sae.encode(mlp_out)

last_token_acts = feature_acts[0, -1]
values, indices = torch.topk(last_token_acts, k=5)
print("Top activated features on the last token:")
for val, idx in zip(values.tolist(), indices.tolist()):
    print(f"  Feature {idx:5d} -> activation {val:.2f}")


Top activated features on the last token:
  Feature  5242 -> activation 5.26
  Feature  1195 -> activation 1.94
  Feature 16084 -> activation 1.73
  Feature  5574 -> activation 1.65
  Feature 14552 -> activation 1.53


In [None]:
# Reconstruction quality on the prompt: how close is SAE decode to the raw MLP activations?
with torch.no_grad():
    tokens = model.to_tokens(prompt, prepend_bos=True)
    _, cache = model.run_with_cache(tokens, names_filter=[sae.cfg.metadata.hook_name])
    mlp_out = cache[sae.cfg.metadata.hook_name]
    feature_acts = sae.encode(mlp_out)
    recon = sae.decode(feature_acts)

    mse = torch.nn.functional.mse_loss(recon, mlp_out).item()
    l2_orig = mlp_out.pow(2).mean().sqrt().item()
    l2_err = (recon - mlp_out).pow(2).mean().sqrt().item()
    explained = 1 - (l2_err**2 / (l2_orig**2 + 1e-9))
    density = (feature_acts.abs() > 1e-6).float().mean().item()

print(
    f"Recon MSE: {mse:.6f}
"
    f"Orig L2 (mean): {l2_orig:.4f}
"
    f"Error L2 (mean): {l2_err:.4f}
"
    f"Explained variance (approx): {explained:.4f}
"
    f"Mean activation density: {density:.4f}"
)


### Reading the metrics
- If `Recon MSE` or `Error L2` spikes compared to past runs, the SAE may be mis-specified or checkpoint is wrong.
- `Explained variance` near 1.0 means the reconstruction matches MLP activations closely; values near 0 indicate poor fit.
- `Mean activation density` captures sparsity; large jumps imply the L1 coefficient or normalization changed.
- Compare the base vs SAE next-token losses to see whether reconstruction is neutral or harmful for the prompt.


In [15]:
# Sample 10 random SAE features and see the logits they produce on the prompt.
# We keep only those features active (zeroing the rest), then swap that reconstruction into the model.
import torch

torch.manual_seed(0)  # make the random feature pick reproducible
num_features = 10

with torch.no_grad():
    tokens = model.to_tokens(prompt, prepend_bos=True)

    # Baseline logits for comparison
    base_logits = model(tokens)

    # Cache the MLP outputs, encode with the SAE, and pick 10 random features
    _, cache = model.run_with_cache(tokens, names_filter=[sae.cfg.metadata.hook_name])
    mlp_out = cache[sae.cfg.metadata.hook_name]
    feature_acts = sae.encode(mlp_out)

    rand_idx = torch.randperm(feature_acts.size(-1), device=device)[:num_features]
    kept_features = torch.zeros_like(feature_acts)
    kept_features[..., rand_idx] = feature_acts[..., rand_idx]

    # Decode just these features back into the MLP space
    recon_from_subset = sae.decode(kept_features)

    # Swap the original MLP output with the reconstruction from those 10 features
    def swap_mlp_out(acts, hook):
        return recon_from_subset

    hooked_logits = model.run_with_hooks(
        tokens, fwd_hooks=[(sae.cfg.metadata.hook_name, swap_mlp_out)]
    )


def show_top(logits, label, k=10):
    probs = logits.softmax(-1)
    vals, idxs = torch.topk(probs, k)
    print(label)
    for v, i in zip(vals.tolist(), idxs.tolist()):
        # tokenizer.decode gives a nice string for the token id
        print(f"  {model.tokenizer.decode([i]):12s} {v:.3f}")
    print()

base_last = base_logits[0, -1]
hooked_last = hooked_logits[0, -1]
print(f"Random feature IDs: {sorted(rand_idx.tolist())}")
show_top(base_last, "Base model last-token distribution:")
show_top(hooked_last, "Only these 10 features reconstructed:")


Random feature IDs: [85, 667, 1304, 1564, 6450, 7750, 9467, 10254, 11155, 15182]
Base model last-token distribution:
   be          0.147
   explore     0.101
   fly         0.083
   play        0.050
   work        0.022
   try         0.019
   spin        0.019
   use         0.017
   make        0.016
   do          0.016

Only these 10 features reconstructed:
   to          0.737
  .            0.096
   and         0.094
   the         0.041
  ,            0.014
   in          0.008
   be          0.002
   with        0.001
   a           0.001
   that        0.001

