# Lab C.4: Feature Extraction with SAEs

**Module:** C - Mechanistic Interpretability  
**Time:** 2.5 hours  
**Difficulty:** ⭐⭐⭐⭐⭐ (Expert)

---

## Learning Objectives

By the end of this notebook, you will:
- [ ] Understand why sparse autoencoders (SAEs) are powerful for interpretability
- [ ] Implement a simple SAE architecture
- [ ] Train an SAE on model activations
- [ ] Extract and interpret learned features
- [ ] Attempt activation steering using discovered features

---

## Prerequisites

- Completed: Labs C.1-C.3
- Knowledge of: Attention patterns, residual stream, activation patching

---

## Real-World Context

**The Superposition Problem**: Neural networks appear to represent far more concepts than they have neurons. This is called *superposition*, and its visible symptom is *polysemanticity*: a single neuron might respond to "dogs", "wheels", and "the color blue".

**Sparse Autoencoders (SAEs)** are a technique for "unpacking" these superposed representations into individual, more interpretable features. Anthropic's work on [Towards Monosemanticity](https://transformer-circuits.pub/2023/monosemantic-features/index.html) showed SAEs can recover thousands of interpretable features from a small one-layer transformer, and the follow-up [Scaling Monosemanticity](https://transformer-circuits.pub/2024/scaling-monosemanticity/index.html) found millions of features in Claude 3 Sonnet.

**Applications:**
- Finding safety-relevant features (deception, manipulation)
- Understanding model reasoning
- Steering model behavior precisely

---

## ELI5: What are Sparse Autoencoders?

> **Imagine a radio signal** that's a mix of many songs playing at once. Each song is a "feature", and they're all superposed (overlapping) in the signal.
>
> **A regular listener** hears a jumbled mess - the raw signal is hard to interpret.
>
> **A Sparse Autoencoder** is like a magical music separator:
> 1. It takes the mixed signal
> 2. Expands it into many "channels" (more than the original dimensions)
> 3. Forces most channels to be silent (sparsity)
> 4. Each active channel plays ONE song clearly!
>
> **For neural networks:**
> - Input: 768-dimensional residual stream (mixed signals)
> - Expand to: 768 × 4 = 3072 features (many channels)
> - Sparsity: Only ~10-50 features active at once
> - Output: Each feature represents ONE concept ("Python", "negative sentiment", "question", etc.)

```
Superposed Activations      SAE Features
      [Mixed]         →    [Dog][Code][French][ ][ ][Happy]
     (768-dim)              (3072-dim, mostly zeros)
```
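
To make the picture above concrete, here is a small NumPy sketch of superposition. It is a toy, not the model used in this lab: the sizes (256 dimensions, 1024 features), the random directions, and the three active indices are all illustrative assumptions. It suggests why sparsity matters: when only a few features are on, many more directions than dimensions can share one space and still be read out with a simple projection.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, n_features = 256, 1024        # 4x more features than dimensions

# Hypothetical feature directions: random unit vectors in R^d_model
directions = rng.standard_normal((n_features, d_model))
directions /= np.linalg.norm(directions, axis=1, keepdims=True)

# Only 3 of the 1024 features are active (sparse input)
active = [5, 77, 200]
x = directions[active].sum(axis=0)      # superposed activation, shape (d_model,)

# Naive readout: project the activation onto every feature direction
readout = directions @ x                # shape (n_features,)
recovered = sorted(np.argsort(readout)[-3:].tolist())

# Interference: the largest readout among features that were NOT active
inactive = np.setdiff1d(np.arange(n_features), active)
interference = np.abs(readout[inactive]).max()

print("recovered features:", recovered)
print(f"weakest active readout: {readout[active].min():.2f}")
print(f"strongest interference: {interference:.2f}")
```

If you raise the number of active features, the interference grows and the readout starts to confuse features, which is the regime where a learned, sparse decomposition becomes useful.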

---

## ELI5: Why Sparsity Matters

> **Without sparsity**: With more features than input dimensions, the SAE can reconstruct its input almost perfectly using dense codes. Every feature fires a little, giving us nothing interpretable.
>
> **With sparsity**: The SAE is forced to "choose" which few features to activate. This encourages each feature to specialize in one specific concept!
>
> **The math**: We add an L1 penalty: `loss = reconstruction_error + λ * ||features||₁`
>
> L1 pushes features toward zero, so only the most important ones survive!
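
The claim that L1 drives small features to exactly zero can be checked on a one-dimensional toy problem. The sketch below is an illustration under simplifying assumptions, not the SAE's actual training dynamics: for a single feature value `z`, it compares the minimizer of a squared error plus an L1 penalty (soft thresholding) with the minimizer under an L2 penalty. The helper names `soft_threshold` and `ridge_shrink` are made up for this example.

```python
import numpy as np

def soft_threshold(z: np.ndarray, lam: float) -> np.ndarray:
    """Per-entry minimizer of 0.5 * (f - z)**2 + lam * |f| (L1 penalty)."""
    return np.sign(z) * np.maximum(np.abs(z) - lam, 0.0)

def ridge_shrink(z: np.ndarray, lam: float) -> np.ndarray:
    """Per-entry minimizer of 0.5 * (f - z)**2 + 0.5 * lam * f**2 (L2 penalty)."""
    return z / (1.0 + lam)

z = np.array([0.05, 0.2, 1.5, 0.01, 3.0])   # unpenalized feature values
lam = 0.25

l1_result = soft_threshold(z, lam)
l2_result = ridge_shrink(z, lam)

print("L1:", l1_result, "non-zero:", np.count_nonzero(l1_result))
print("L2:", l2_result, "non-zero:", np.count_nonzero(l2_result))
```

The L1 version zeroes every entry smaller than `lam` and keeps the large ones (slightly reduced), while the L2 version only rescales everything and leaves no exact zeros.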

---

## Part 1: Setup

In [None]:
# Core imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Dict, Optional, Tuple
from dataclasses import dataclass
from tqdm.auto import tqdm
import gc

# TransformerLens
from transformer_lens import HookedTransformer, utils

# Visualization
import plotly.express as px
import plotly.graph_objects as go

# Settings
%matplotlib inline
plt.style.use('seaborn-v0_8-whitegrid')

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

In [None]:
# Load model
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    device=device
)
model.eval()

print(f"Model loaded: {model.cfg.n_layers} layers, d_model={model.cfg.d_model}")

---

## Part 2: Implementing a Sparse Autoencoder

Let's build an SAE from scratch!

In [None]:
class SparseAutoencoder(nn.Module):
    """
    Sparse Autoencoder for extracting interpretable features.
    
    Architecture:
    - Input: residual stream activations [batch, d_model]
    - Encoder: Linear + ReLU → sparse feature activations [batch, n_features]
    - Decoder: Linear → reconstructed activations [batch, d_model]
    
    The encoder expansion (n_features > d_model) allows representing more
    features than dimensions. Sparsity ensures only a few activate at once.
    """
    
    def __init__(
        self,
        d_model: int,
        n_features: int,
        tied_weights: bool = True
    ):
        """
        Args:
            d_model: Dimension of input activations (e.g., 768 for GPT-2 Small)
            n_features: Number of sparse features (typically 2-8x d_model)
            tied_weights: If True, decoder weights = encoder weights transposed
        """
        super().__init__()
        
        self.d_model = d_model
        self.n_features = n_features
        self.tied_weights = tied_weights
        
        # Encoder: d_model → n_features
        self.encoder = nn.Linear(d_model, n_features, bias=True)
        
        # Decoder: n_features → d_model
        if tied_weights:
            # Decoder shares weights with encoder (transposed)
            self.decoder_bias = nn.Parameter(torch.zeros(d_model))
        else:
            self.decoder = nn.Linear(n_features, d_model, bias=True)
        
        # Initialize weights
        self._init_weights()
    
    def _init_weights(self):
        """Initialize weights using Kaiming initialization."""
        nn.init.kaiming_uniform_(self.encoder.weight, nonlinearity='relu')
        nn.init.zeros_(self.encoder.bias)
        
        if not self.tied_weights:
            nn.init.kaiming_uniform_(self.decoder.weight, nonlinearity='linear')
            nn.init.zeros_(self.decoder.bias)
    
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode input to sparse features.
        
        Args:
            x: Input activations [batch, d_model]
        
        Returns:
            Sparse feature activations [batch, n_features]
        """
        return F.relu(self.encoder(x))  # ReLU ensures non-negative, sparse
    
    def decode(self, features: torch.Tensor) -> torch.Tensor:
        """
        Decode features back to original space.
        
        Args:
            features: Sparse features [batch, n_features]
        
        Returns:
            Reconstructed activations [batch, d_model]
        """
        if self.tied_weights:
            # Use encoder weight transposed
            return F.linear(features, self.encoder.weight.T, self.decoder_bias)
        else:
            return self.decoder(features)
    
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Full forward pass.
        
        Returns:
            (reconstructed, features) tuple
        """
        features = self.encode(x)
        reconstructed = self.decode(features)
        return reconstructed, features
    
    def compute_loss(
        self,
        x: torch.Tensor,
        sparsity_coef: float = 1e-3
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Compute SAE loss with sparsity penalty.
        
        Loss = MSE(x, reconstructed) + sparsity_coef * L1(features)
        
        Args:
            x: Input activations
            sparsity_coef: Weight for L1 sparsity penalty
        
        Returns:
            (total_loss, loss_dict) tuple
        """
        reconstructed, features = self.forward(x)
        
        # Reconstruction loss (MSE)
        recon_loss = F.mse_loss(reconstructed, x)
        
        # Sparsity loss: L1 norm of the features, averaged over batch and
        # features (so sparsity_coef multiplies a mean, not a per-sample sum)
        sparsity_loss = features.abs().mean()
        
        # Total loss
        total_loss = recon_loss + sparsity_coef * sparsity_loss
        
        # Compute sparsity statistics
        with torch.no_grad():
            frac_nonzero = (features > 0).float().mean().item()  # Fraction of non-zero entries
            n_active = (features > 0).sum(dim=-1).float().mean().item()  # Avg active features per input
        
        loss_dict = {
            'total': total_loss.item(),
            'recon': recon_loss.item(),
            'sparsity': sparsity_loss.item(),
            'frac_nonzero': frac_nonzero,
            'n_active': n_active
        }
        
        return total_loss, loss_dict

# Test SAE
sae = SparseAutoencoder(
    d_model=model.cfg.d_model,  # 768
    n_features=model.cfg.d_model * 4,  # 3072 features
    tied_weights=True
).to(device)

print(f"SAE created:")
print(f"  Input dim: {sae.d_model}")
print(f"  Features: {sae.n_features}")
print(f"  Expansion: {sae.n_features / sae.d_model}x")
print(f"  Parameters: {sum(p.numel() for p in sae.parameters()):,}")

In [None]:
# Test forward pass
test_input = torch.randn(32, model.cfg.d_model, device=device)
reconstructed, features = sae(test_input)

print(f"Input shape: {test_input.shape}")
print(f"Features shape: {features.shape}")
print(f"Reconstructed shape: {reconstructed.shape}")
print(f"\nFeature statistics:")
print(f"  Min: {features.min().item():.4f}")
print(f"  Max: {features.max().item():.4f}")
print(f"  Fraction > 0: {(features > 0).float().mean().item():.2%}")
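
The parameter count printed above can be checked by hand. The helper below is a small sketch (the name `sae_param_count` is made up for this example) that mirrors the layout of the `SparseAutoencoder` class: an encoder weight and bias, plus either a lone decoder bias (tied weights) or a full decoder layer (untied).

```python
def sae_param_count(d_model: int, n_features: int, tied_weights: bool) -> int:
    """Count trainable parameters for the SAE layout used in this lab."""
    encoder = d_model * n_features + n_features       # encoder weight + bias
    if tied_weights:
        decoder = d_model                             # only the decoder bias is new
    else:
        decoder = n_features * d_model + d_model      # separate decoder weight + bias
    return encoder + decoder

tied = sae_param_count(768, 3072, tied_weights=True)
untied = sae_param_count(768, 3072, tied_weights=False)

print(f"tied:   {tied:,}")     # 2,363,136
print(f"untied: {untied:,}")   # 4,722,432
```

Tying the weights roughly halves the parameter count, at the cost of forcing each feature's read direction (encoder) and write direction (decoder) to be identical.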

---

## Part 3: Collecting Training Data

We need diverse activations from the model to train our SAE.

In [None]:
# Sample prompts for diverse activations
SAMPLE_PROMPTS = [
    # Factual
    "The capital of France is Paris.",
    "Water boils at 100 degrees Celsius.",
    "The Earth orbits around the Sun.",
    
    # Code
    "def fibonacci(n):\n    if n <= 1:",
    "import numpy as np\narray = np.zeros((10, 10))",
    "for i in range(10):\n    print(i)",
    
    # Questions
    "What is the meaning of life?",
    "How do neural networks learn?",
    "Why is the sky blue?",
    
    # Stories
    "Once upon a time, there was a brave knight who",
    "In a galaxy far, far away,",
    "The detective examined the mysterious letter carefully.",
    
    # Emotions
    "I am so happy today!",
    "This is absolutely terrible.",
    "I feel uncertain about the future.",
    
    # Instructions
    "Please summarize the following text:",
    "Translate this to French:",
    "Explain like I'm five:",
    
    # Math
    "2 + 2 = 4, 3 + 3 = 6, 4 + 4 =",
    "The derivative of x^2 is 2x.",
    
    # More diverse content
    "The quick brown fox jumps over the lazy dog.",
    "Breaking news: Scientists discover new species",
    "Recipe: Mix flour, eggs, and sugar together.",
    "Hello! How are you today?",
    "ERROR: Connection refused on port 8080",
    "According to the latest research,"
]

In [None]:
def collect_activations(
    model: HookedTransformer,
    prompts: List[str],
    layer: int,
    position: str = "all"  # "all", "last", or "random"
) -> torch.Tensor:
    """
    Collect residual stream activations from model.
    
    Args:
        model: The transformer model
        prompts: List of text prompts
        layer: Which layer to extract from
        position: Which positions to keep
    
    Returns:
        Tensor of activations [n_samples, d_model]
    """
    all_activations = []
    
    with torch.no_grad():
        for prompt in tqdm(prompts, desc="Collecting activations"):
            tokens = model.to_tokens(prompt)
            _, cache = model.run_with_cache(tokens)
            
            # Get residual stream after this layer
            resid = cache["resid_post", layer][0]  # [seq, d_model]
            
            if position == "all":
                all_activations.append(resid)
            elif position == "last":
                all_activations.append(resid[-1:])  # Slice keeps the position dim: [1, d_model]
            elif position == "random":
                idx = torch.randint(0, resid.shape[0], (1,))
                all_activations.append(resid[idx])
            
            del cache
    
    return torch.cat(all_activations, dim=0)

# Collect activations from middle layer
layer = 6  # Middle layer of GPT-2 Small
activations = collect_activations(model, SAMPLE_PROMPTS, layer=layer, position="all")

print(f"Collected activations: {activations.shape}")
print(f"Activation stats:")
print(f"  Mean: {activations.mean().item():.4f}")
print(f"  Std: {activations.std().item():.4f}")

In [None]:
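# Create more training data from random token sequences.
# Note: random tokens are out-of-distribution for the model, so treat this as a
# cheap stand-in for real text rather than a substitute for it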
def generate_random_activations(
    model: HookedTransformer,
    n_batches: int = 50,
    seq_len: int = 20,
    layer: int = 6
) -> torch.Tensor:
    """Generate activations from random token sequences."""
    all_activations = []
    
    with torch.no_grad():
        for _ in tqdm(range(n_batches), desc="Generating random activations"):
            # Random tokens (avoiding special tokens)
            tokens = torch.randint(1000, 40000, (1, seq_len), device=device)
            _, cache = model.run_with_cache(tokens)
            
            resid = cache["resid_post", layer][0]
            all_activations.append(resid)
            
            del cache
    
    return torch.cat(all_activations, dim=0)

# Generate additional random activations
random_activations = generate_random_activations(model, n_batches=50, layer=layer)
print(f"Random activations: {random_activations.shape}")

# Combine all activations
all_activations = torch.cat([activations, random_activations], dim=0)
print(f"\nTotal activations: {all_activations.shape}")
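
One thing worth checking before training is the scale of the collected activations. The reconstruction term grows with the square of the activation magnitude while the L1 term grows only linearly, so the same `sparsity_coef` can behave quite differently from layer to layer. A common remedy, shown here as an optional sketch and not as part of this notebook's pipeline, is to rescale the dataset by one constant so the average row norm equals `sqrt(d_model)`. The function name `scale_to_sqrt_d` and the random stand-in data are assumptions for illustration.

```python
import numpy as np
from typing import Tuple

def scale_to_sqrt_d(acts: np.ndarray) -> Tuple[np.ndarray, float]:
    """Rescale activations so the mean row norm equals sqrt(d_model).

    Returns the scaled array and the scale factor (divide by it to undo).
    """
    d_model = acts.shape[1]
    mean_norm = np.linalg.norm(acts, axis=1).mean()
    scale = float(np.sqrt(d_model) / mean_norm)
    return acts * scale, scale

# Stand-in for real activations: random rows with an arbitrary scale
rng = np.random.default_rng(0)
fake_acts = 7.3 * rng.standard_normal((1000, 768))

scaled, scale = scale_to_sqrt_d(fake_acts)
print(f"mean norm before: {np.linalg.norm(fake_acts, axis=1).mean():.2f}")
print(f"mean norm after:  {np.linalg.norm(scaled, axis=1).mean():.2f}")
```

With torch tensors the same arithmetic applies, for example using `all_activations.norm(dim=-1).mean()` for the mean norm. Remember to divide reconstructions by `scale` before feeding them back into the model.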

---

## Part 4: Training the SAE

In [None]:
def train_sae(
    sae: SparseAutoencoder,
    activations: torch.Tensor,
    n_epochs: int = 100,
    batch_size: int = 256,
    lr: float = 1e-3,
    sparsity_coef: float = 1e-3
) -> Dict[str, List[float]]:
    """
    Train the SAE on collected activations.
    
    Returns:
        Dictionary of loss histories
    """
    # Create dataloader
    dataset = TensorDataset(activations)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    
    # Optimizer
    optimizer = torch.optim.Adam(sae.parameters(), lr=lr)
    
    # History
    history = {
        'total': [], 'recon': [], 'sparsity': [],
        'frac_nonzero': [], 'n_active': []
    }
    
    # Training loop
    for epoch in tqdm(range(n_epochs), desc="Training SAE"):
        epoch_losses = {k: [] for k in history.keys()}
        
        for batch in dataloader:
            x = batch[0]  # [batch, d_model]
            
            optimizer.zero_grad()
            loss, loss_dict = sae.compute_loss(x, sparsity_coef=sparsity_coef)
            loss.backward()
            optimizer.step()
            
            for k, v in loss_dict.items():
                epoch_losses[k].append(v)
        
        # Record epoch averages
        for k in history.keys():
            history[k].append(np.mean(epoch_losses[k]))
        
        # Print progress every 20 epochs
        if (epoch + 1) % 20 == 0:
            print(f"Epoch {epoch+1}: recon={history['recon'][-1]:.4f}, "
                  f"frac_active={history['frac_nonzero'][-1]:.2%}, "
                  f"n_active={history['n_active'][-1]:.1f}")
    
    return history

In [None]:
# Reinitialize SAE for fresh training
sae = SparseAutoencoder(
    d_model=model.cfg.d_model,
    n_features=model.cfg.d_model * 4,  # 4x expansion
    tied_weights=True
).to(device)

# Train
history = train_sae(
    sae,
    all_activations,
    n_epochs=100,
    batch_size=256,
    lr=1e-3,
    sparsity_coef=5e-4  # Tune this for desired sparsity
)

In [None]:
# Plot training curves
fig, axes = plt.subplots(2, 2, figsize=(12, 8))

axes[0, 0].plot(history['recon'])
axes[0, 0].set_title('Reconstruction Loss')
axes[0, 0].set_xlabel('Epoch')

axes[0, 1].plot(history['sparsity'])
axes[0, 1].set_title('Sparsity Loss (L1)')
axes[0, 1].set_xlabel('Epoch')

axes[1, 0].plot(history['frac_nonzero'])
axes[1, 0].set_title('Fraction of Features Active')
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylim(0, 0.5)

axes[1, 1].plot(history['n_active'])
axes[1, 1].set_title('Avg Number of Active Features')
axes[1, 1].set_xlabel('Epoch')

plt.tight_layout()
plt.show()

print(f"\nFinal statistics:")
print(f"  Reconstruction loss: {history['recon'][-1]:.4f}")
print(f"  Avg active features per input: {history['n_active'][-1]:.1f} / {sae.n_features}")
print(f"  Fraction of features active: {history['frac_nonzero'][-1]:.2%}")
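
Raw MSE is hard to judge without knowing the scale of the activations, so two extra diagnostics are often useful: the fraction of variance explained by the reconstruction, and the share of "dead" features that never fire. The sketch below uses NumPy and small synthetic arrays; the function names are made up for this example, and you would pass in `.cpu().numpy()` copies of your own activations, reconstructions, and features.

```python
import numpy as np

def fraction_variance_explained(x: np.ndarray, x_hat: np.ndarray) -> float:
    """1.0 means perfect reconstruction; 0.0 is no better than predicting the mean."""
    residual = ((x - x_hat) ** 2).sum()
    total = ((x - x.mean(axis=0)) ** 2).sum()
    return float(1.0 - residual / total)

def dead_feature_fraction(features: np.ndarray) -> float:
    """Share of features that never fire on any sample."""
    return float((features.max(axis=0) <= 0).mean())

# Synthetic check: a perfect reconstruction and a mean-only reconstruction
rng = np.random.default_rng(0)
x = rng.standard_normal((200, 16))
mean_only = np.broadcast_to(x.mean(axis=0), x.shape)

fve_perfect = fraction_variance_explained(x, x)
fve_mean = fraction_variance_explained(x, mean_only)

# Features 0 and 2 never fire, so two of three are dead
feats = np.array([[0.0, 1.0, 0.0],
                  [0.0, 2.0, 0.0]])
dead = dead_feature_fraction(feats)

print(fve_perfect, fve_mean, dead)
```

A high dead-feature fraction usually means the sparsity penalty is too strong for the amount of data, or that the dataset is too small for the chosen number of features.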

---

## Part 5: Analyzing Learned Features

Now let's explore what features the SAE learned!

In [None]:
def find_top_activating_examples(
    sae: SparseAutoencoder,
    model: HookedTransformer,
    prompts: List[str],
    layer: int,
    feature_idx: int,
    top_k: int = 5
) -> List[Tuple[str, str, int, float]]:
    """
    Find examples that most activate a specific feature.
    
    Returns:
        List of (prompt, token_string, token_position, activation_value) tuples
    """
    results = []
    
    with torch.no_grad():
        for prompt in prompts:
            tokens = model.to_tokens(prompt)
            _, cache = model.run_with_cache(tokens)
            
            resid = cache["resid_post", layer][0]  # [seq, d_model]
            features = sae.encode(resid)  # [seq, n_features]
            
            # Get activations for this feature
            feature_acts = features[:, feature_idx].cpu().numpy()
            
            # Get token strings
            token_strs = model.to_str_tokens(tokens)
            
            for pos, act in enumerate(feature_acts):
                if act > 0:
                    results.append((prompt, token_strs[pos], pos, act))
            
            del cache
    
    # Sort by activation and return top k
    results.sort(key=lambda x: -x[3])
    return results[:top_k]

In [None]:
# Find most active features overall
def find_most_active_features(
    sae: SparseAutoencoder,
    activations: torch.Tensor,
    top_k: int = 20
) -> List[Tuple[int, float, float]]:
    """
    Find features that activate most frequently and strongly.
    
    Returns:
        List of (feature_idx, mean_activation, frequency) tuples
    """
    with torch.no_grad():
        features = sae.encode(activations)  # [n_samples, n_features]
        
        # Mean activation when active
        mean_acts = (features * (features > 0).float()).sum(dim=0) / (features > 0).float().sum(dim=0).clamp(min=1)
        
        # Frequency of activation
        frequency = (features > 0).float().mean(dim=0)
        
        # Combined score
        scores = mean_acts * frequency
        
        top_indices = torch.argsort(scores, descending=True)[:top_k]
        
        results = []
        for idx in top_indices:
            results.append((
                idx.item(),
                mean_acts[idx].item(),
                frequency[idx].item()
            ))
        
        return results

# Find most active features
top_features = find_most_active_features(sae, all_activations, top_k=20)

print("Most Active Features:")
print("=" * 50)
for idx, mean_act, freq in top_features:
    print(f"Feature {idx:4d}: mean_act={mean_act:.3f}, freq={freq:.2%}")

In [None]:
# Analyze specific features
def analyze_feature(feature_idx: int):
    """Analyze what a specific feature responds to."""
    print(f"\n{'='*60}")
    print(f"FEATURE {feature_idx}")
    print(f"{'='*60}")
    
    # Find top activating examples
    top_examples = find_top_activating_examples(
        sae, model, SAMPLE_PROMPTS, layer=layer,
        feature_idx=feature_idx, top_k=10
    )
    
    if top_examples:
        print("\nTop activating tokens:")
        for prompt, token, pos, act in top_examples:
            print(f"  [{act:.3f}] '{token}' in \"{prompt[:40]}...\"")
    else:
        print("  (No strong activations found in sample prompts)")

# Analyze a few top features
for idx, _, _ in top_features[:5]:
    analyze_feature(idx)

In [None]:
# Visualize feature activations on a specific prompt
def visualize_feature_activations(
    prompt: str,
    top_k_features: int = 10
):
    """Visualize which features activate for each token in a prompt."""
    tokens = model.to_tokens(prompt)
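    # Prefix each token with its position so repeated tokens get distinct axis labels
    token_strs = [f"{i}: {tok}" for i, tok in enumerate(model.to_str_tokens(tokens))]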
    
    with torch.no_grad():
        _, cache = model.run_with_cache(tokens)
        resid = cache["resid_post", layer][0]
        features = sae.encode(resid).cpu().numpy()  # [seq, n_features]
    
    # Find features with highest max activation
    max_acts = features.max(axis=0)
    top_feature_indices = np.argsort(max_acts)[-top_k_features:][::-1]
    
    # Create heatmap
    selected_features = features[:, top_feature_indices].T  # [k_features, seq]
    
    fig = px.imshow(
        selected_features,
        labels={"x": "Token", "y": "Feature", "color": "Activation"},
        x=token_strs,
        y=[f"F{i}" for i in top_feature_indices],
        color_continuous_scale="YlOrRd",
        title=f"Feature Activations for: \"{prompt}\""
    )
    fig.update_layout(width=800, height=400)
    fig.show()

# Visualize for different prompts
visualize_feature_activations("def fibonacci(n):\n    if n <= 1:")
visualize_feature_activations("The capital of France is Paris.")
visualize_feature_activations("I am so happy today!")

---

## Part 6: Activation Steering

Can we modify model behavior by adding/subtracting feature directions?

### Text Generation with TransformerLens

Before we steer, let's understand how to generate text with TransformerLens:

```python
# model.generate() - Autoregressive text generation
# Key parameters:
#   - input: Token tensor [batch, seq] to start from
#   - max_new_tokens: How many new tokens to generate
#   - do_sample: If False, use greedy decoding (pick highest probability)
#                If True, sample from distribution (more creative)
#   - temperature: Controls randomness when sampling (higher = more random)

tokens = model.to_tokens("Hello world")
output = model.generate(tokens, max_new_tokens=10, do_sample=False)
text = model.tokenizer.decode(output[0], skip_special_tokens=True)  # drop the prepended BOS token
```

This is the baseline we'll compare steering against!
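
As a rough illustration of what `do_sample` and `temperature` control, the sketch below applies greedy selection and temperature-scaled softmax to a made-up logit vector. It is not TransformerLens's internal code; `next_token_probs` and the three logits are assumptions chosen to keep the arithmetic easy to follow.

```python
import numpy as np

def next_token_probs(logits, temperature: float = 1.0) -> np.ndarray:
    """Softmax over logits divided by a temperature."""
    z = np.asarray(logits, dtype=float) / temperature
    z -= z.max()                      # subtract the max for numerical stability
    p = np.exp(z)
    return p / p.sum()

logits = [2.0, 1.0, 0.0]              # toy scores for a 3-token vocabulary

# Greedy decoding (do_sample=False): always take the highest-scoring token
greedy_choice = int(np.argmax(logits))

# Sampling (do_sample=True): draw from the softmax distribution
cool = next_token_probs(logits, temperature=0.5)   # sharper, closer to greedy
warm = next_token_probs(logits, temperature=2.0)   # flatter, more random

rng = np.random.default_rng(0)
sampled = int(rng.choice(len(logits), p=warm))

print("greedy:", greedy_choice)
print("T=0.5:", cool.round(3))
print("T=2.0:", warm.round(3))
```

Greedy decoding is deterministic, which makes it the easier baseline for comparing steered and unsteered generations.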

In [None]:
def steer_with_feature(
    model: HookedTransformer,
    sae: SparseAutoencoder,
    prompt: str,
    feature_idx: int,
    strength: float = 5.0,
    layer: int = 6
) -> str:
    """
    Generate text while adding a feature direction.
    
    Args:
        model: The transformer model
        sae: Trained SAE
        prompt: Input prompt
        feature_idx: Which feature to add
        strength: How much to add
        layer: Which layer to intervene
    
    Returns:
        Generated text
    """
    # Get the feature direction from the decoder weights
    with torch.no_grad():
        # Read the decoder weight for this feature directly. Decoding a one-hot
        # vector would also add the decoder bias, which is not part of the
        # feature direction.
        if sae.tied_weights:
            direction = sae.encoder.weight[feature_idx]  # [d_model]
        else:
            direction = sae.decoder.weight[:, feature_idx]  # [d_model]
        direction = direction / direction.norm() * strength
    
    # Define steering hook
    def steering_hook(activation, hook):
        # Add direction to all positions
        return activation + direction
    
    # Generate with steering
    tokens = model.to_tokens(prompt)
    
    generated = []
    for _ in range(20):  # Generate 20 tokens
        with torch.no_grad():
            logits = model.run_with_hooks(
                tokens,
                fwd_hooks=[(f"blocks.{layer}.hook_resid_post", steering_hook)]
            )
        
        # Greedy decoding: pick the highest-probability next token
        next_token = logits[0, -1, :].argmax().unsqueeze(0).unsqueeze(0)
        tokens = torch.cat([tokens, next_token], dim=1)
        generated.append(model.tokenizer.decode(next_token[0].item()))
    
    return prompt + "".join(generated)

In [None]:
# Compare generation with and without steering
prompt = "The weather today is"

# Normal generation
tokens = model.to_tokens(prompt)
with torch.no_grad():
    normal_output = model.generate(
        tokens, max_new_tokens=20, do_sample=False
    )
normal_text = model.tokenizer.decode(normal_output[0], skip_special_tokens=True)  # drop the prepended BOS token

print("Normal generation:")
print(f"  {normal_text}")
print()

# Try steering with a few different features
for feature_idx in [top_features[0][0], top_features[5][0], top_features[10][0]]:
    steered_text = steer_with_feature(
        model, sae, prompt, feature_idx,
        strength=10.0, layer=layer
    )
    print(f"Steered with feature {feature_idx}:")
    print(f"  {steered_text}")
    print()

### Interpreting Steering Results

If a feature represents a coherent concept, steering should produce semantically consistent changes. However:

- Not all features are interpretable (some are polysemantic)
- Steering strength needs tuning (too much = incoherent)
- This is a simplified SAE trained on a very small dataset; real SAEs are much larger and trained on far more data
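
One detail to keep in mind when extracting a steering direction: a feature's direction is its decoder vector, whereas decoding a one-hot feature vector returns that vector plus the decoder bias. The two differ whenever the bias has moved away from zero during training. The NumPy sketch below shows the difference on a toy decoder; the shapes, the random weights, and the names `W_dec` and `b_dec` are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, n_features = 8, 32

W_dec = rng.standard_normal((n_features, d_model))   # row i = direction of feature i
b_dec = rng.standard_normal(d_model)                 # decoder bias (non-zero after training)

def decode(features: np.ndarray) -> np.ndarray:
    """Toy decoder: weighted sum of feature directions plus a bias."""
    return features @ W_dec + b_dec

feature_idx = 3
one_hot = np.zeros(n_features)
one_hot[feature_idx] = 1.0

via_decode = decode(one_hot)          # feature direction + bias
direction = W_dec[feature_idx]        # feature direction only

# Subtracting the bias (or the decoding of all-zeros) recovers the direction
bias_removed = via_decode - decode(np.zeros(n_features))

# Steering: add a unit-norm direction scaled by a chosen strength
strength = 5.0
unit = direction / np.linalg.norm(direction)
resid = rng.standard_normal(d_model)  # stand-in for one residual stream vector
steered = resid + strength * unit

print("bias-free match:", np.allclose(bias_removed, direction))
print("shift norm:", np.linalg.norm(steered - resid))
```

Normalizing before scaling keeps `strength` comparable across features, since trained decoder vectors can have very different norms.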

---

## Try It Yourself

### Exercise 1: Different Sparsity Levels
Train SAEs with different sparsity coefficients (1e-2, 1e-3, 1e-4). How does this affect:
- Reconstruction quality?
- Number of active features?
- Interpretability?

<details>
<summary>Hint</summary>

Higher sparsity coefficient = fewer active features = potentially more monosemantic but higher reconstruction error.
</details>

In [None]:
# Exercise 1: Your code here



### Exercise 2: Feature Similarity
Find pairs of features that often co-activate. What might this tell us about their relationship?

<details>
<summary>Hint</summary>

Compute the correlation matrix of feature activations across all samples. High correlation = features that activate together.

```python
# Get all feature activations
with torch.no_grad():
    features = sae.encode(all_activations).cpu().numpy()  # [n_samples, n_features]

# Option 1: Using numpy's corrcoef
# np.corrcoef computes Pearson correlation coefficients
# It takes a 2D array where each row is a variable, columns are observations
# Note: features that never fire have zero variance, so their rows come out as NaN
correlation_matrix = np.corrcoef(features.T)  # [n_features, n_features]

# Option 2: Using pandas (same Pearson correlation with labeled output; not faster or lighter)
import pandas as pd
df_features = pd.DataFrame(features)
correlation_matrix = df_features.corr().values

# Find highly correlated pairs (upper triangle only, so no self-correlation)
# NaN entries compare as False, so dead features are skipped automatically
threshold = 0.5
rows, cols = np.triu_indices_from(correlation_matrix, k=1)
mask = np.abs(correlation_matrix[rows, cols]) > threshold
high_corr_pairs = [
    (int(i), int(j), float(correlation_matrix[i, j]))
    for i, j in zip(rows[mask], cols[mask])
]
```
</details>

In [None]:
# Exercise 2: Your code here



### Exercise 3: Different Layers
Train SAEs on early (layer 2) and late (layer 10) layers. Are features more abstract in later layers?

<details>
<summary>Hint</summary>

Early layers often capture syntactic features (punctuation, parts of speech). Later layers capture semantic features (topics, concepts).
</details>

In [None]:
# Exercise 3: Your code here



---

## Common Mistakes

### Mistake 1: Insufficient Training Data
```python
# Wrong: Training on a few prompts
activations = collect_activations(model, ["Hello world"], layer=6)

# Correct: Diverse, large-scale data
activations = collect_activations(model, many_diverse_prompts, layer=6)
```
**Why:** SAEs need diverse data to learn general features, not prompt-specific patterns.

### Mistake 2: Wrong Sparsity Balance
```python
# Wrong: Sparsity penalty dominates the loss
loss = recon + 1.0 * sparsity  # Almost no features activate!

# Better: A small coefficient, tuned by watching n_active and recon loss
loss = recon + 1e-4 * sparsity  # Right value depends on layer and activation scale
```
**Why:** Need enough active features to reconstruct, but sparse enough to be interpretable.

### Mistake 3: Confusing Features with Neurons
```python
# SAE features ≠ model neurons!
# Features are learned directions in activation space
# They may not correspond to any single neuron
```
**Why:** SAEs extract *directions*, which are linear combinations of neurons.
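
One way to see the difference between a neuron and a direction is to ask how many coordinates a vector effectively uses. The sketch below computes a participation ratio, which is 1 for a single-neuron axis and `d` for a vector spread evenly over all `d` neurons. The function name and the example vectors are assumptions for illustration; with a trained SAE you would pass in one decoder vector instead of the random stand-in.

```python
import numpy as np

def participation_ratio(v) -> float:
    """Effective number of coordinates used by a direction.

    Equals 1 for a one-hot vector and d for a vector with d equal entries.
    """
    v = np.asarray(v, dtype=float)
    v = v / np.linalg.norm(v)
    return float(1.0 / np.sum(v ** 4))

d = 768
neuron_axis = np.zeros(d)
neuron_axis[42] = 1.0                 # a single neuron
evenly_spread = np.ones(d)            # every neuron contributes equally

rng = np.random.default_rng(0)
random_dir = rng.standard_normal(d)   # stand-in for a learned feature direction

pr_neuron = participation_ratio(neuron_axis)
pr_spread = participation_ratio(evenly_spread)
pr_random = participation_ratio(random_dir)

print(f"single neuron: {pr_neuron:.1f}")
print(f"evenly spread: {pr_spread:.1f}")
print(f"random direction: {pr_random:.1f}")
```

A feature direction with a participation ratio in the hundreds cannot be understood by looking at any one neuron in isolation.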

---

## Checkpoint

You've learned:
- Why sparse autoencoders are useful for interpretability
- How to implement and train an SAE
- How to collect diverse activations for training
- How to analyze and interpret learned features
- How to steer model behavior using feature directions

---

## Challenge (Optional)

**Replicate Anthropic's Feature Analysis**

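Anthropic's "Towards Monosemanticity" paper trained SAEs with up to around a hundred thousand features on a small one-layer transformer, and the follow-up "Scaling Monosemanticity" scaled the approach to millions of features on Claude 3 Sonnet. Try: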

1. Increase to 8x or 16x expansion
2. Train on much more data (Wikipedia, code, etc.)
3. Find specific interpretable features:
   - "Python code"
   - "Questions"
   - "Negative sentiment"
   - "First-person narrative"

This requires more compute but is very rewarding!

In [None]:
# Challenge: Large-scale SAE training
# Your code here



---

## Further Reading

- [Towards Monosemanticity](https://transformer-circuits.pub/2023/monosemantic-features/index.html) - Anthropic's landmark paper
- [Scaling Monosemanticity](https://transformer-circuits.pub/2024/scaling-monosemanticity/index.html) - SAEs on Claude
- [Neuronpedia](https://www.neuronpedia.org/) - Interactive feature explorer
- [SAE Lens](https://github.com/jbloomAus/SAELens) - Production SAE training library

---

## Cleanup

In [None]:
# Clear GPU memory
del sae, all_activations, activations, random_activations
gc.collect()
torch.cuda.empty_cache()

print("Memory cleared!")

---

## Congratulations!

You've completed Module C: Mechanistic Interpretability! You now have hands-on experience with:

1. **TransformerLens** - The core interpretability toolkit
2. **Activation Patching** - Finding causal mechanisms
3. **Induction Heads** - A fundamental circuit for in-context learning
4. **Sparse Autoencoders** - Extracting interpretable features

These are the same techniques used at top AI labs like Anthropic, DeepMind, and OpenAI!

### Where to Go From Here

1. **Contribute to open research**: Check out [200 Concrete Problems](https://www.alignmentforum.org/posts/LbrPTJ4fmABEdEnLf/200-concrete-open-problems-in-mechanistic-interpretability)
2. **Explore Neuronpedia**: Analyze pre-trained SAE features
3. **Join the community**: MATS, ARENA, and AI Safety Camp programs
4. **Apply to your projects**: Use these tools to understand your own models

Good luck on your interpretability journey!