# Tutorial 1: SAE Basics - The Prism for Language Models

**Learning Objectives:**
- Understand what Sparse Autoencoders (SAEs) are and why they matter
- Load a pre-trained SAE and language model
- Extract feature activations from text
- Decode features to understand what they represent

**Estimated Time:** 10-15 minutes

---


## What is a Sparse Autoencoder?

Imagine a language model as a complex machine that processes text. Inside this machine, information flows through many layers as long vectors of numbers, and it is hard to tell *what* the model is representing at each step.

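A **Sparse Autoencoder (SAE)** acts like a **prism** that breaks the model's internal representations into interpretable features. Just as a prism splits white light into distinct colors, an SAE decomposes each activation vector into a weighted combination of learned feature directions. The SAE has far more features than the activation has dimensions, but only a few of them are active (non-zero) for any given token, which is what makes it *sparse*.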
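For readers who prefer equations, one way to write a JumpReLU SAE (the architecture Gemma Scope uses) is shown below. Here $x \in \mathbb{R}^{d_{\text{model}}}$ is a residual-stream activation ($d_{\text{model}} = 2304$ for Gemma-2-2b), $f \in \mathbb{R}^{d_{\text{sae}}}$ is the feature vector ($d_{\text{sae}} = 16{,}384$ in this tutorial), $W_{\text{enc}} \in \mathbb{R}^{d_{\text{sae}} \times d_{\text{model}}}$ is the encoder, $\theta$ is a learned per-feature threshold, and $\mathbf{w}_i \in \mathbb{R}^{d_{\text{model}}}$ is the decoder direction of feature $i$. Setting $\theta = 0$ recovers a plain ReLU SAE.

$$
f(x) = \mathrm{JumpReLU}_{\theta}\!\left(W_{\text{enc}}\, x + b_{\text{enc}}\right),
\qquad
\mathrm{JumpReLU}_{\theta}(z)_i = z_i \,\mathbb{1}\!\left[z_i > \theta_i\right]
$$

$$
\hat{x} = \sum_{i=1}^{d_{\text{sae}}} f_i(x)\, \mathbf{w}_i + b_{\text{dec}}
$$

Because most $f_i(x)$ are zero, the reconstruction $\hat{x}$ is built from only a handful of directions. "Decoding" a feature later in this tutorial means inspecting its $\mathbf{w}_i$, which `sae_lens` stores as the row `sae.W_dec[i]`.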

### The Prism Metaphor

```
Text Input → Language Model → Dense Activations → SAE → Sparse Features
   "Paris"        (Gemma-2)      [2304 numbers]    (Prism)   [16k features]
                                                              Feature #1234: "French cities"
                                                              Feature #5678: "European capitals"
                                                              Feature #9012: "Tourist destinations"
```

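Ideally, each feature corresponds to a specific concept or pattern the model has learned; the feature numbers and labels in the diagram above are illustrative. By checking which features activate on a token, we get a partial, human-readable view of what the model is representing at that point.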
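To make the encode/decode idea concrete, here is a minimal sketch of a JumpReLU SAE forward pass in plain PyTorch. It assumes toy sizes (8 dimensions, 32 features), a fixed threshold of 0.5 for every feature, and random, untrained weights, so the reconstruction itself is not meaningful. It only shows the mechanics: a feature fires only when its pre-activation exceeds its threshold, and the reconstruction is the sum of the active features' decoder directions. The function names are illustrative, not the `sae_lens` API.

```python
import torch

torch.manual_seed(0)

# Toy sizes (assumption): the real SAE maps 2304 dims to 16,384 features.
d_model, d_sae = 8, 32

# Random parameters stand in for trained weights. The layout is assumed to
# mirror sae_lens: W_enc is (d_model, d_sae) and row i of W_dec is feature i's direction.
W_enc = torch.randn(d_model, d_sae) / d_model**0.5
b_enc = torch.zeros(d_sae)
W_dec = torch.randn(d_sae, d_model) / d_sae**0.5
b_dec = torch.zeros(d_model)
threshold = torch.full((d_sae,), 0.5)  # hypothetical per-feature JumpReLU thresholds


def encode(x: torch.Tensor) -> torch.Tensor:
    """Dense activation -> sparse feature activations (JumpReLU)."""
    pre = x @ W_enc + b_enc
    # Keep a pre-activation only if it exceeds its feature's threshold; zero it otherwise.
    return torch.where(pre > threshold, pre, torch.zeros_like(pre))


def decode(f: torch.Tensor) -> torch.Tensor:
    """Sparse feature activations -> reconstructed dense activation."""
    return f @ W_dec + b_dec


x = torch.randn(d_model)  # one token's dense activation
f = encode(x)
x_hat = decode(f)

active = torch.nonzero(f, as_tuple=True)[0]
print(f"{len(active)} of {d_sae} features active: {active.tolist()}")

# Only active features contribute: decoding equals the sum of their directions.
manual = (f[active].unsqueeze(1) * W_dec[active]).sum(dim=0) + b_dec
print("matches sum of active directions:", torch.allclose(x_hat, manual, atol=1e-5))
```

With trained weights and the real sizes, these two functions describe roughly what the tutorial's SAE computes when it encodes and decodes, up to implementation details.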

---


## Setup: Import Libraries

We'll use:
- `torch`: For tensor operations
- `transformer_lens`: For loading language models with easy activation access
- `sae_lens`: For loading pre-trained SAEs
- `matplotlib`: For visualization


In [1]:
import torch
from transformer_lens import HookedTransformer
from sae_lens import SAE
import matplotlib.pyplot as plt


## Load the Model and SAE

We'll use:
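- **Model:** Gemma-2-2b (a roughly 2.6-billion-parameter language model from Google)
- **SAE:** the Gemma Scope residual-stream SAE for layer 5, with 16k (16,384) features, trained on Gemma-2-2b's activations

The first run downloads the model and SAE weights (~5GB); later runs reuse the cached copies. Gemma-2 is a gated model on Hugging Face, so you may need to accept its license on the model page and log in (for example with `huggingface-cli login`) before the download succeeds.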


In [2]:
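# Auto-detect device: CUDA, then Apple Silicon MPS, then CPU
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using device: {device}")

# Inference only: no gradients needed, which saves memory
torch.set_grad_enabled(False)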

# Load SAE
print("\nLoading SAE...")
sae = SAE.from_pretrained(
    release="gemma-scope-2b-pt-res-canonical",
    sae_id="layer_5/width_16k/canonical",
    device=device
)

# Load language model
print("Loading Gemma-2-2b model...")
model = HookedTransformer.from_pretrained("gemma-2-2b", device=device)

print("\n✓ Models loaded successfully!")


Using device: mps

Loading SAE...


Loading Gemma-2-2b model...


Loaded pretrained model gemma-2-2b into HookedTransformer

✓ Models loaded successfully!


## Extract Features from Simple Text

Let's analyze a simple sentence: **"The cat sat on the mat"**

We'll:
1. Run the text through the model
2. Extract the residual-stream activation after layer 5 (`blocks.5.hook_resid_post`) at the final token, ` mat`
3. Apply the SAE to get sparse features
4. See which features activate
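Step 4 boils down to "find the non-zero entries and rank them". The sketch below is a small helper with the hypothetical name `top_active_features` (it is not part of `sae_lens`). It uses `torch.nonzero(..., as_tuple=True)`, which always returns a 1-D index tensor, and `torch.topk`, which returns values sorted from largest to smallest. A toy feature vector stands in for real SAE output.

```python
import torch


def top_active_features(feature_acts: torch.Tensor, k: int = 3):
    """Return (indices, magnitudes) of the k strongest active features.

    `feature_acts` is a 1-D tensor of SAE feature activations for one token.
    Hypothetical helper for illustration only.
    """
    active_idx = torch.nonzero(feature_acts > 0, as_tuple=True)[0]  # 1-D even if one feature is active
    active_vals = feature_acts[active_idx]
    k = min(k, active_idx.numel())   # fewer than k features may be active
    top = torch.topk(active_vals, k)  # sorted in descending order
    return active_idx[top.indices], top.values


# Toy activations: features 1, 3, 4 and 6 are active.
acts = torch.tensor([0.0, 2.5, 0.0, 7.0, 0.1, 0.0, 3.2])
idx, vals = top_active_features(acts, k=3)
print(idx.tolist())  # indices → [3, 6, 1]
```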


In [3]:
# Simple example text
text = "The cat sat on the mat"

# Step 1: Convert text to tokens
tokens = model.to_tokens(text)
print(f"Text: '{text}'")
print(f"Tokens: {model.to_str_tokens(tokens[0])}")

# Step 2: Run model and get activations
_, cache = model.run_with_cache(tokens)
activations = cache["blocks.5.hook_resid_post"][0, -1, :]  # Last token's activation

print(f"\nActivation shape: {activations.shape}")
print(f"This is a dense vector with {activations.shape[0]} dimensions")

# Step 3: Apply SAE (the prism!)
activations_2d = activations.unsqueeze(0)  # Add batch dimension
feature_acts = sae.encode(activations_2d).squeeze()

print(f"\nFeature activations shape: {feature_acts.shape}")
print(f"The SAE has {feature_acts.shape[0]} features")

# Step 4: Filter for active features (activation > 0)
active_mask = feature_acts > 0
active_indices = torch.nonzero(active_mask, as_tuple=True)[0]  # always 1-D, even with a single active feature
active_magnitudes = feature_acts[active_indices]

print(f"\n✓ Found {len(active_indices)} active features")
# "Energy" here is the sum of active feature magnitudes (the L1 norm of the feature vector)
print(f"  Total activation energy: {active_magnitudes.sum().item():.3f}")


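# Decode top 3 features: project each feature's decoder direction onto the vocabulary
# through the unembedding matrix (this skips the final norm layer, so it is approximate)
def decode_feature(feature_id, top_k=5):
    feature_direction = sae.W_dec[feature_id]  # [d_model] decoder direction for this feature
    logits = feature_direction @ model.W_U     # [d_vocab] direct contribution to each token's logit
    top_token_ids = logits.topk(top_k).indices
    return model.to_str_tokens(top_token_ids)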

top_3_indices = active_indices[torch.argsort(active_magnitudes, descending=True)[:3]]
print("\nTop 3 Active Features:")
for i, feat_id in enumerate(top_3_indices.tolist(), 1):
    magnitude = feature_acts[feat_id].item()
    words = decode_feature(feat_id, top_k=5)
    print(f"{i}. Feature #{feat_id} (magnitude: {magnitude:.3f})")
    print(f"   Promotes: {', '.join(words)}")


Text: 'The cat sat on the mat'
Tokens: ['<bos>', 'The', ' cat', ' sat', ' on', ' the', ' mat']

Activation shape: torch.Size([2304])
This is a dense vector with 2304 dimensions

Feature activations shape: torch.Size([16384])
The SAE has 16384 features

✓ Found 91 active features
  Total activation energy: 406.055

Top 3 Active Features:
1. Feature #14203 (magnitude: 29.506)
   Promotes:  mat, mat,  Mat, Mat,  MAT
2. Feature #11477 (magnitude: 17.740)
   Promotes:  kneeling,  kneel,  knees,  knelt,  prostrate
3. Feature #697 (magnitude: 14.543)
   Promotes: ',  , ’,  […], 
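The decoding step in `decode_feature` multiplies a feature's decoder direction by the unembedding matrix and reads off the highest-scoring tokens. The sketch below reproduces that idea with a hypothetical 8-token vocabulary and a random unembedding matrix whose columns are made orthonormal (via `torch.linalg.qr`) so the expected answer is exact. The real Gemma-2 vocabulary is far larger, and, like `decode_feature`, this ignores the model's final normalization layer.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy vocabulary and unembedding matrix W_U with shape (d_model, vocab).
vocab = ["<bos>", " cat", " sat", " mat", " Mat", " dog", " the", "."]
d_model = 16
Q, _ = torch.linalg.qr(torch.randn(d_model, d_model))  # orthonormal columns
W_U = Q[:, : len(vocab)]

# A planted "feature direction" that mostly points at " mat", and somewhat at " Mat".
feature_direction = W_U[:, 3] + 0.8 * W_U[:, 4]


def decode_direction(direction: torch.Tensor, top_k: int = 2) -> list[str]:
    """Project a d_model direction onto the vocabulary and return the top tokens."""
    logits = direction @ W_U  # (vocab,) direct contribution to each token's logit
    top_ids = torch.topk(logits, top_k).indices
    return [vocab[i] for i in top_ids.tolist()]


print(decode_direction(feature_direction))  # → [' mat', ' Mat']
```

This mirrors why Feature #14203 above promotes variants of "mat": its decoder direction has a large dot product with those tokens' unembedding columns.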



## Key Takeaways

In this tutorial, you learned:

1. **SAEs as Prisms:** SAEs decompose dense model activations into interpretable sparse features
2. **Feature Extraction:** How to run text through a model and extract SAE features
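3. **Feature Decoding:** How to interpret features by projecting their decoder directions onto the vocabulary through the unembedding matrix
4. **Sparsity:** Only a small fraction of features activate for a given token: in this example, 91 of 16,384 features (about 0.6%) were active on the final token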

### What's Next?

In **Tutorial 2: Feature Extraction**, we'll:
- Use the reusable functions from the `hallucination_detector` package
- Compare features between different texts
- Identify unique and shared features
- Set the foundation for hallucination detection
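As a preview of comparing texts, here is a plain-PyTorch sketch that splits two feature-activation vectors into shared and unique active features and reports their Jaccard similarity (|A ∩ B| / |A ∪ B|). The helper name `compare_active_features` and the choice of Jaccard similarity are assumptions for illustration; this is not the `hallucination_detector` API, and the toy vectors stand in for real SAE outputs.

```python
import torch


def compare_active_features(acts_a: torch.Tensor, acts_b: torch.Tensor) -> dict:
    """Compare which SAE features fire for two texts (one 1-D activation vector each).

    Hypothetical helper for illustration only.
    """
    a = set(torch.nonzero(acts_a > 0, as_tuple=True)[0].tolist())
    b = set(torch.nonzero(acts_b > 0, as_tuple=True)[0].tolist())
    union = a | b
    return {
        "shared": a & b,
        "only_a": a - b,
        "only_b": b - a,
        # Jaccard similarity; defined as 0.0 when neither vector has active features
        "jaccard": len(a & b) / len(union) if union else 0.0,
    }


acts_a = torch.tensor([0.0, 1.2, 0.0, 3.4, 0.5])  # active: {1, 3, 4}
acts_b = torch.tensor([0.7, 0.9, 0.0, 0.0, 2.0])  # active: {0, 1, 4}
print(compare_active_features(acts_a, acts_b))
```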

---

### Further Exploration

Try setting the `text` variable above to different sentences:
- "Paris is the capital of France"
- "The dog barked loudly"
- "Machine learning is fascinating"

Observe how different texts activate different features!
