# Week 7 Exercise: Unsupervised Feature Discovery and Superposition

In this exercise, you'll gain hands-on experience with:
- Building toy models of superposition
- Training sparse autoencoders
- Interpreting discovered features
- Validating feature quality
- Comparing SAE features with other methods
- Understanding feature splitting and capacity

## Setup

Install required libraries:

In [None]:
!pip install transformers torch numpy matplotlib einops -q

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM, AutoTokenizer
import warnings
warnings.filterwarnings('ignore')

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

## Part 1: Toy Model of Superposition

Build a simple model to understand how features can exist in superposition.

In [None]:
class ToyModel(nn.Module):
    """Simple autoencoder to demonstrate superposition."""
    
    def __init__(self, n_features=5, n_hidden=2):
        super().__init__()
        self.n_features = n_features
        self.n_hidden = n_hidden
        
        # Encoder and decoder
        self.W_enc = nn.Parameter(torch.randn(n_hidden, n_features) * 0.1)
        self.W_dec = nn.Parameter(torch.randn(n_features, n_hidden) * 0.1)
        self.b_enc = nn.Parameter(torch.zeros(n_hidden))
        self.b_dec = nn.Parameter(torch.zeros(n_features))
    
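    def forward(self, x):
        # Encode to the bottleneck (linear: no nonlinearity in the hidden layer)
        hidden = torch.matmul(x, self.W_enc.t()) + self.b_enc
        
        # Decode back and apply the ReLU on the output, as in the ReLU-output toy
        # model of Anthropic's "Toy Models of Superposition". A ReLU on the 2-D
        # hidden layer would confine it to one quadrant, which likely makes
        # superposition much harder to observe.
        reconstructed = torch.relu(torch.matmul(hidden, self.W_dec.t()) + self.b_dec)
        
        return reconstructed, hidden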


def generate_sparse_data(n_features, n_samples, sparsity=0.1):
    """
    Generate sparse data where most features are zero.
    
    Args:
        n_features: Number of features
        n_samples: Number of samples
        sparsity: Probability that each feature is non-zero
    """
    # Random binary mask for sparsity
    mask = (torch.rand(n_samples, n_features) < sparsity).float()
    
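    # Feature values when active: uniform in [0, 1). Non-negative magnitudes
    # follow the usual toy-superposition setup and avoid values that a
    # ReLU-output model could never reconstruct.
    values = torch.rand(n_samples, n_features)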
    
    # Sparse data
    data = mask * values
    
    return data


# Create toy model with more features than dimensions
n_features = 5
n_hidden = 2
n_samples = 1000
sparsity = 0.1

print(f"Toy Model Setup:")
print(f"  Features: {n_features}")
print(f"  Hidden dimensions: {n_hidden}")
print(f"  Compression ratio: {n_features/n_hidden:.1f}x")
print(f"  Feature density (P(active)): {sparsity:.1%}")
print(f"\nThis creates conditions for superposition:")
print(f"  - More features than dimensions")
print(f"  - Sparse activations (low interference)")

In [None]:
# Train toy model
model = ToyModel(n_features, n_hidden).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Training loop
losses = []
for epoch in range(500):
    data = generate_sparse_data(n_features, n_samples, sparsity).to(device)
    
    reconstructed, hidden = model(data)
    loss = ((reconstructed - data) ** 2).mean()
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    losses.append(loss.item())
    
    if epoch % 100 == 0:
        print(f"Epoch {epoch}: Loss = {loss.item():.4f}")

# Plot training curve
plt.figure(figsize=(10, 4))
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Reconstruction Loss')
plt.title('Toy Model Training: Learning Superposition')
plt.grid(True, alpha=0.3)
plt.show()

In [None]:
# Visualize learned representations
# Each row of W_dec (shape n_features x n_hidden) is one feature's vector in 2D space
W_dec_np = model.W_dec.detach().cpu().numpy()

plt.figure(figsize=(8, 8))
plt.axhline(y=0, color='k', linewidth=0.5)
plt.axvline(x=0, color='k', linewidth=0.5)

# Plot each feature as a vector
for i in range(n_features):
    vector = W_dec_np[i, :]
    plt.arrow(0, 0, vector[0], vector[1], 
             head_width=0.05, head_length=0.05, 
             fc=f'C{i}', ec=f'C{i}',
             label=f'Feature {i}')

plt.xlabel('Hidden Dimension 1')
plt.ylabel('Hidden Dimension 2')
plt.title(f'Feature Representations in 2D Space\n({n_features} features in {n_hidden} dimensions)')
plt.legend()
plt.grid(True, alpha=0.3)
plt.axis('equal')
plt.show()

print("\nInterpretation:")
print("  - Each arrow is one feature's decoder vector in the 2D hidden space")
print("  - If more than 2 arrows have non-negligible length, the model stores more features than dimensions (superposition)")
print("  - Features interfere when several are active at once")
print("  - With sparse inputs such co-activation is rare, so interference is rare too")
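To put a number on "interference", you can compare your learned `W_dec` rows against an idealized layout. The sketch below assumes that layout is a regular pentagon (5 unit vectors spaced 72° apart), which is one arrangement a well-trained 5-feature, 2-dimension model may approach at high sparsity; it is a reference point, not something this notebook's training is guaranteed to produce. Interference between two features is measured here as the cosine similarity of their vectors.

```python
import numpy as np

def pentagon_features(n_features=5):
    # Hypothetical ideal layout: unit vectors spaced evenly around the circle
    angles = 2 * np.pi * np.arange(n_features) / n_features
    return np.stack([np.cos(angles), np.sin(angles)], axis=1)  # (n_features, 2)

def interference_matrix(W):
    # W: (n_features, n_hidden), one row per feature (same layout as W_dec)
    norms = np.linalg.norm(W, axis=1, keepdims=True)
    W_unit = W / np.clip(norms, 1e-8, None)
    return W_unit @ W_unit.T  # pairwise cosine similarity between feature directions

W_ideal = pentagon_features(5)
G = interference_matrix(W_ideal)
off_diag = G[~np.eye(len(G), dtype=bool)]
print(f"Max positive interference:  {off_diag.max():.3f}")
print(f"Most negative interference: {off_diag.min():.3f}")
```

To inspect the trained toy model instead, call `interference_matrix(model.W_dec.detach().cpu().numpy())` while `model` still refers to the toy model (Part 2 reassigns it to GPT-2).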

In [None]:
# Test interference
# When only one feature active → good reconstruction
# When multiple features active → interference

# Single feature active
single_input = torch.zeros(1, n_features).to(device)
single_input[0, 0] = 1.0
single_recon, _ = model(single_input)

print("Single feature active:")
print(f"  Input:  {single_input[0].cpu().numpy()}")
print(f"  Recon:  {single_recon[0].detach().cpu().numpy()}")
print(f"  Error:  {((single_recon - single_input)**2).sum().item():.4f}")

# Multiple features active (interference)
multi_input = torch.zeros(1, n_features).to(device)
multi_input[0, 0] = 1.0
multi_input[0, 1] = 1.0
multi_input[0, 2] = 1.0
multi_recon, _ = model(multi_input)

print("\nMultiple features active:")
print(f"  Input:  {multi_input[0].cpu().numpy()}")
print(f"  Recon:  {multi_recon[0].detach().cpu().numpy()}")
print(f"  Error:  {((multi_recon - multi_input)**2).sum().item():.4f}")

print("\nConclusion: compare the two errors. Overlapping feature directions usually make the error larger when several features are active at once.")
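Why does sparsity keep interference rare? `generate_sparse_data` draws each feature's mask independently with probability `sparsity`, so the chance that an active feature shares the input with at least one other active feature is `1 - (1 - p)^(n - 1)`. A small sketch of that arithmetic (the helper name is ours, not part of any library):

```python
def prob_coactive(n_features, p_active):
    # Probability that, given one feature is active, at least one of the other
    # n_features - 1 independently sampled features is also active
    return 1 - (1 - p_active) ** (n_features - 1)

for p in [0.5, 0.1, 0.01]:
    print(f"p_active={p}: P(co-active | feature active) = {prob_coactive(5, p):.3f}")
```

With the notebook's settings (5 features, `p_active=0.1`) this is about one third, so interference is not negligible; lowering `p_active` makes it much rarer.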

## Part 2: Training a Sparse Autoencoder

Now let's train an SAE to decompose real model activations.

In [None]:
# Load model
model_name = "gpt2"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
model = model.to(device)
model.eval()

print(f"Model: {model_name}")
print(f"Hidden size: {model.config.n_embd}")

In [None]:
class SparseAutoencoder(nn.Module):
    """Sparse autoencoder for decomposing model activations."""
    
    def __init__(self, d_model, d_hidden, l1_coeff=1e-3):
        super().__init__()
        self.d_model = d_model
        self.d_hidden = d_hidden
        self.l1_coeff = l1_coeff
        
        # Encoder and decoder
        self.W_enc = nn.Parameter(torch.randn(d_hidden, d_model) / np.sqrt(d_model))
        self.W_dec = nn.Parameter(torch.randn(d_model, d_hidden) / np.sqrt(d_hidden))
        self.b_enc = nn.Parameter(torch.zeros(d_hidden))
        self.b_dec = nn.Parameter(torch.zeros(d_model))
    
    def forward(self, x):
        # Encode
        pre_activation = torch.matmul(x, self.W_enc.t()) + self.b_enc
        features = torch.relu(pre_activation)
        
        # Decode
        reconstructed = torch.matmul(features, self.W_dec.t()) + self.b_dec
        
        return reconstructed, features
    
    def loss(self, x):
        reconstructed, features = self.forward(x)
        
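        # Reconstruction loss: mean squared error over batch and dimensions
        recon_loss = ((reconstructed - x) ** 2).mean()
        
        # Sparsity penalty (L1), averaged over batch *and* features.
        # Because it averages over d_hidden features, the effective penalty per
        # feature is l1_coeff / d_hidden; many SAE implementations instead sum
        # over features and average over the batch.
        sparsity_loss = self.l1_coeff * features.abs().mean()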
        
        return recon_loss + sparsity_loss, recon_loss, sparsity_loss


# Create SAE
d_model = model.config.n_embd  # 768 for GPT-2 small
d_hidden = d_model * 4  # 4x overcomplete (3072 features)

sae = SparseAutoencoder(d_model, d_hidden, l1_coeff=1e-3).to(device)

print(f"SAE Architecture:")
print(f"  Input dimensions: {d_model}")
print(f"  Hidden features: {d_hidden}")
print(f"  Expansion factor: {d_hidden/d_model:.1f}x")
print(f"  L1 coefficient: {sae.l1_coeff}")
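One caveat with the `SparseAutoencoder` above: the L1 penalty can be lowered without any real sparsity by shrinking feature activations and growing the decoder weights to compensate, since reconstruction only sees their product. A common remedy is to keep each decoder column at unit norm and re-normalize after every optimizer step. The sketch below is a hypothetical variant (`NormalizedSAE` is our name) that does this and also sums the L1 term over features; because the loss scale changes, `l1_coeff` would need retuning.

```python
import torch
import torch.nn as nn

class NormalizedSAE(nn.Module):
    """Hypothetical SAE variant: unit-norm decoder columns, L1 summed over features."""

    def __init__(self, d_model, d_hidden, l1_coeff=1e-3):
        super().__init__()
        self.l1_coeff = l1_coeff
        self.W_enc = nn.Parameter(torch.randn(d_hidden, d_model) / d_model ** 0.5)
        self.W_dec = nn.Parameter(torch.randn(d_model, d_hidden))
        self.b_enc = nn.Parameter(torch.zeros(d_hidden))
        self.b_dec = nn.Parameter(torch.zeros(d_model))
        self.normalize_decoder()

    @torch.no_grad()
    def normalize_decoder(self):
        # Each column of W_dec is one feature's direction in activation space
        self.W_dec.div_(self.W_dec.norm(dim=0, keepdim=True).clamp_min(1e-8))

    def forward(self, x):
        features = torch.relu(x @ self.W_enc.T + self.b_enc)
        return features @ self.W_dec.T + self.b_dec, features

    def loss(self, x):
        recon, features = self(x)
        # Sum over dimensions/features, mean over the batch
        recon_loss = ((recon - x) ** 2).sum(dim=-1).mean()
        sparsity_loss = self.l1_coeff * features.abs().sum(dim=-1).mean()
        return recon_loss + sparsity_loss, recon_loss, sparsity_loss
```

In a training loop, call `sae.normalize_decoder()` right after each `optimizer.step()`.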

In [None]:
# Collect activation data
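def collect_activations(texts, layer_idx=-1):
    """Extract per-token residual-stream activations from one layer.

    With output_hidden_states=True, GPT-2 returns n_layer + 1 hidden states:
    hidden_states[0] is the embedding output, hidden_states[i] is the output of
    block i - 1, and hidden_states[-1] is the final output after the last LayerNorm.
    """
    activations = []
    
    for text in texts:
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128).to(device)
        
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
            hidden_states = outputs.hidden_states[layer_idx]
            
            # Keep every token position of this single-sequence batch
            activations.append(hidden_states[0].cpu())
    
    # Shape: (total_tokens, d_model)
    return torch.cat(activations, dim=0)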


# Example texts
texts = [
    "The quick brown fox jumps over the lazy dog.",
    "Machine learning is a subset of artificial intelligence.",
    "Python is a popular programming language for data science.",
    "The weather today is sunny and warm.",
    "Neural networks consist of layers of interconnected nodes.",
    "Coffee is a popular beverage consumed worldwide.",
    "The capital of France is Paris.",
    "Quantum computing could revolutionize computation.",
    "The human brain contains billions of neurons.",
    "Mathematics is the language of science."
]

print("Collecting activations...")
activations = collect_activations(texts, layer_idx=6)
print(f"Collected {activations.shape[0]} activation vectors")
print(f"Shape: {activations.shape}")

In [None]:
# Train SAE
optimizer = optim.Adam(sae.parameters(), lr=1e-3)

batch_size = 32
n_epochs = 100

train_losses = []
recon_losses = []
sparsity_losses = []
l0_sparsities = []

print("Training SAE...\n")

for epoch in range(n_epochs):
    # Shuffle data
    indices = torch.randperm(activations.shape[0])
    
    epoch_loss = 0
    epoch_recon = 0
    epoch_sparsity = 0
    epoch_l0 = 0
    n_batches = 0
    
    for i in range(0, activations.shape[0], batch_size):
        batch_indices = indices[i:i+batch_size]
        batch = activations[batch_indices].to(device)
        
        loss, recon_loss, sparsity_loss = sae.loss(batch)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Track metrics
        with torch.no_grad():
            _, features = sae(batch)
            l0 = (features > 0).float().sum(dim=1).mean().item()
        
        epoch_loss += loss.item()
        epoch_recon += recon_loss.item()
        epoch_sparsity += sparsity_loss.item()
        epoch_l0 += l0
        n_batches += 1
    
    # Average over batches
    train_losses.append(epoch_loss / n_batches)
    recon_losses.append(epoch_recon / n_batches)
    sparsity_losses.append(epoch_sparsity / n_batches)
    l0_sparsities.append(epoch_l0 / n_batches)
    
    if epoch % 10 == 0:
        print(f"Epoch {epoch}:")
        print(f"  Total loss: {train_losses[-1]:.4f}")
        print(f"  Recon loss: {recon_losses[-1]:.4f}")
        print(f"  L0 sparsity: {l0_sparsities[-1]:.1f}/{d_hidden} ({100*l0_sparsities[-1]/d_hidden:.1f}%)")

print("\nTraining complete!")

In [None]:
# Visualize training
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

axes[0].plot(recon_losses)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Reconstruction Loss')
axes[0].set_title('Reconstruction Quality')
axes[0].grid(True, alpha=0.3)

axes[1].plot(sparsity_losses)
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('L1 Penalty')
axes[1].set_title('Sparsity Penalty')
axes[1].grid(True, alpha=0.3)

axes[2].plot(l0_sparsities)
axes[2].axhline(y=d_hidden, color='r', linestyle='--', label='All features')
axes[2].set_xlabel('Epoch')
axes[2].set_ylabel('Active Features (L0)')
axes[2].set_title('Feature Sparsity')
axes[2].legend()
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

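print(f"\nFinal metrics (averaged over the last training epoch):")
print(f"  Reconstruction loss: {recon_losses[-1]:.4f}")
print(f"  Active features: {l0_sparsities[-1]:.1f}/{d_hidden} ({100*l0_sparsities[-1]/d_hidden:.1f}%)")

# Explained variance on the full dataset with the final weights, measured
# against each dimension's own mean (activations.var() uses one global mean)
with torch.no_grad():
    acts = activations.to(device)
    recon, _ = sae(acts)
    fvu = ((recon - acts) ** 2).sum() / ((acts - acts.mean(dim=0)) ** 2).sum()
print(f"  Explained variance: {1 - fvu.item():.1%}")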

## Part 3: Feature Interpretation

Interpret what features the SAE discovered.

In [None]:
# Find most active features
with torch.no_grad():
    _, all_features = sae(activations.to(device))
    
    # Feature activation frequencies
    feature_counts = (all_features > 0).float().sum(dim=0).cpu()
    
    # Mean activation per feature, averaged over all tokens (zeros included)
    feature_magnitudes = all_features.mean(dim=0).cpu()

# Find top features
top_k = 10
top_indices = torch.argsort(feature_counts, descending=True)[:top_k]

print(f"Top {top_k} most frequently active features:\n")
for i, idx in enumerate(top_indices):
    freq = feature_counts[idx] / activations.shape[0]
    mag = feature_magnitudes[idx]
    print(f"{i+1}. Feature {idx.item()}:")
    print(f"   Activation frequency: {freq:.1%}")
    print(f"   Mean activation (all tokens): {mag:.4f}")
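Frequency is also a quick health check on the SAE as a whole. Features that never fire ("dead") and features that fire on most tokens ("dense") are common failure modes. The sketch below counts both; the `dense_threshold=0.5` cutoff is an arbitrary assumption, and with only about a hundred tokens in this demo dataset many features may look dead simply because the data is tiny.

```python
import numpy as np

def feature_density_report(feature_acts, dense_threshold=0.5):
    """feature_acts: (n_tokens, n_features) array of SAE activations."""
    density = (feature_acts > 0).mean(axis=0)  # fraction of tokens each feature fires on
    return {
        "dead": int((density == 0).sum()),                # never fired on this data
        "dense": int((density > dense_threshold).sum()),  # fire on most tokens
        "median_density": float(np.median(density)),
    }
```

Usage with the notebook's tensors: `feature_density_report(all_features.cpu().numpy())`.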

In [None]:
# Find max-activating examples for a feature
def find_max_activating_examples(feature_idx, activations, texts, tokenizer, k=5):
    """
    Find texts/tokens where a feature activates most strongly.
    """
    with torch.no_grad():
        _, features = sae(activations.to(device))
        feature_activations = features[:, feature_idx].cpu()
    
    # Find top k activations
    top_k_values, top_k_indices = torch.topk(feature_activations, k)
    
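    # Map flat row indices back to (text, token). This lines up with
    # collect_activations because GPT-2's tokenizer adds no special tokens and
    # no example text exceeds max_length=128.
    examples = []
    token_idx = 0
    for text in texts:
        tokens = tokenizer.tokenize(text)
        for token in tokens:
            if token_idx in top_k_indices:
                activation = feature_activations[token_idx].item()
                examples.append((text, token, activation))
            token_idx += 1
    
    # Strongest activation first
    examples.sort(key=lambda ex: ex[2], reverse=True)
    return examples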


# Analyze a specific feature
feature_to_analyze = top_indices[0].item()
examples = find_max_activating_examples(feature_to_analyze, activations, texts, tokenizer, k=5)

print(f"\nMax-activating examples for Feature {feature_to_analyze}:\n")
for i, (text, token, activation) in enumerate(examples[:5]):
    print(f"{i+1}. Token: '{token}' (activation: {activation:.4f})")
    print(f"   Context: {text}")
    print()
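The token-index bookkeeping above silently breaks if the tokenization used for lookup ever differs from the one used during collection. A more defensive sketch (hypothetical helper, pure Python) takes the token lists explicitly, checks that their total length matches the number of activation rows, and returns the top-k triples already sorted:

```python
def top_activating_tokens(feature_acts, token_lists, texts, k=5):
    """Return the k (text, token, activation) triples with the largest activation.

    feature_acts: one activation per token, in the same order tokens were collected.
    token_lists:  list of token-string lists, one list per text.
    """
    positions = [(t_idx, tok) for t_idx, toks in enumerate(token_lists) for tok in toks]
    if len(positions) != len(feature_acts):
        raise ValueError("token_lists does not match the number of activation rows")
    order = sorted(range(len(feature_acts)), key=lambda i: feature_acts[i], reverse=True)[:k]
    return [(texts[positions[i][0]], positions[i][1], float(feature_acts[i])) for i in order]
```

In the notebook this could be called with `token_lists = [tokenizer.tokenize(t) for t in texts]` and `feature_acts = features[:, idx].cpu().tolist()`.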

In [None]:
# Simple automated interpretation
# In practice, you would send the max-activating examples to an LLM (e.g., GPT-4);
# here we only sketch the idea with a few hand-written heuristics

def interpret_feature_simple(examples):
    """
    Heuristic interpretation based on max-activating examples.
    In practice, use an LLM for this.
    """
    tokens = [ex[1] for ex in examples]
    if not tokens:
        return "No activating examples found"
    
    # 'Ġ' is GPT-2's marker for a token that starts with a space
    stripped = [t.lstrip('Ġ') for t in tokens]
    
    # Simple heuristics
    if all(t.startswith('Ġ') for t in tokens):
        return "Likely: word-initial tokens (preceded by a space)"
    elif all(t.isupper() for t in stripped):
        return "Likely: uppercase/acronyms"
    elif all(t.isdigit() for t in stripped):
        return "Likely: numbers"
    else:
        return f"Tokens: {', '.join(tokens[:3])}..."


interpretation = interpret_feature_simple(examples)
print(f"Simple interpretation: {interpretation}")
print("\nNote: in practice, automate this step with an LLM such as GPT-4:")
print("send it the max-activating examples and ask 'What concept do these share?'")
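The LLM step usually amounts to formatting the examples into a prompt. The sketch below builds such a prompt string; the wording and the `<< >>` marker are our own assumptions rather than a standard format, and sending it to a model (via whichever API you use) is left out.

```python
def build_interpretation_prompt(feature_idx, examples, max_examples=10):
    """examples: list of (text, token, activation) triples, strongest first."""
    lines = [
        f"Below are text snippets where feature {feature_idx} of a sparse autoencoder activates strongly.",
        "The token where it fires is marked with << >>, followed by its activation.",
        "",
    ]
    for i, (text, token, act) in enumerate(examples[:max_examples], start=1):
        shown = token.replace("Ġ", " ")  # GPT-2 marks a leading space with 'Ġ'
        lines.append(f"{i}. <<{shown}>> ({act:.2f}) in: {text}")
    lines.append("")
    lines.append("In one short phrase, what concept do these examples share?")
    return "\n".join(lines)
```

For example, `print(build_interpretation_prompt(feature_to_analyze, examples))` shows the text you would send.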

## Part 4: Feature Validation

Test feature quality: monosemanticity and causality.

In [None]:
# Test monosemanticity: diversity of activating examples
def test_monosemanticity(feature_idx, activations, texts, tokenizer, threshold=0.1):
    """
    Print a sample of contexts where the feature fires above `threshold`,
    so you can judge by eye whether they share one coherent concept.
    Returns the number of token positions where the feature is active.
    """
    with torch.no_grad():
        _, features = sae(activations.to(device))
        feature_acts = features[:, feature_idx].cpu()
    
    # Get all activations above threshold
    active_indices = (feature_acts > threshold).nonzero(as_tuple=True)[0]
    
    # Random sample of up to 5 activating positions
    n_show = min(5, len(active_indices))
    sample_indices = active_indices[torch.randperm(len(active_indices))[:n_show]]
    
    print(f"Feature {feature_idx} activates on {len(active_indices)} token positions")
    print(f"\nSample of activating contexts:")
    
    token_idx = 0
    shown = 0
    for text in texts:
        tokens = tokenizer.tokenize(text)
        for token in tokens:
            if token_idx in sample_indices and shown < 5:
                print(f"  '{token}' in: {text}")
                shown += 1
            token_idx += 1
    
    return len(active_indices)


# Test a feature
n_active = test_monosemanticity(feature_to_analyze, activations, texts, tokenizer)
print(f"\nMonosemanticity check: Do these contexts share a common concept?")
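A rough quantitative companion to the eyeball check is token concentration: the fraction of a feature's activations that land on its single most common token. High concentration suggests a token-level feature; low concentration does not by itself prove the feature is polysemantic, since one concept can span many tokens. A minimal sketch (hypothetical helper):

```python
from collections import Counter

def token_concentration(active_tokens):
    """Fraction of activations on the most common token, and that token.

    active_tokens: token strings at the positions where the feature fired.
    """
    if not active_tokens:
        return 0.0, None
    token, count = Counter(active_tokens).most_common(1)[0]
    return count / len(active_tokens), token
```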

In [None]:
# Test causality: steering with SAE features
def steer_with_sae_feature(model, sae, text, feature_idx, layer_idx, alpha=5.0):
    """
    Test if amplifying an SAE feature affects model output.
    This is a simplified version - full implementation needs hooks.
    """
    inputs = tokenizer(text, return_tensors="pt").to(device)
    
    # Baseline
    with torch.no_grad():
        outputs = model(**inputs)
        baseline_logits = outputs.logits[0, -1, :]
        baseline_probs = torch.softmax(baseline_logits, dim=-1)
        baseline_top = torch.topk(baseline_probs, 5)
    
    print(f"Baseline predictions for '{text}':")
    for prob, idx in zip(baseline_top.values, baseline_top.indices):
        token = tokenizer.decode([idx])
        print(f"  {token}: {prob:.4f}")
    
    # In practice, you would:
    # 1. Get activations at layer_idx
    # 2. Pass through SAE encoder to get features
    # 3. Multiply feature_idx by alpha
    # 4. Decode back through SAE
    # 5. Continue model forward pass
    # 6. Compare outputs
    
    print("\nNote: Full steering implementation requires activation hooks.")
    print("See Week 2 exercises for complete steering code.")


# Test steering
test_text = "The capital of France is"
steer_with_sae_feature(model, sae, test_text, feature_to_analyze, layer_idx=6)
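The steps listed in `steer_with_sae_feature` can be implemented with a PyTorch forward hook. The sketch below is one way to do it, assuming the `SparseAutoencoder` attribute names used above (`W_enc`, `b_enc`, `W_dec`). Rather than replacing the activations with the SAE reconstruction (which would also throw away the SAE's reconstruction error), it adds only the change in the decoded output, `(alpha - 1) * f_i(x) * W_dec[:, i]`. Note that if the feature is inactive at a position, nothing changes there; adding `alpha * W_dec[:, i]` directly is a common alternative.

```python
import torch
import torch.nn as nn

def make_sae_steering_hook(sae, feature_idx, alpha):
    """Forward hook that rescales one SAE feature by `alpha` in a module's output."""
    def hook(module, inputs, output):
        # GPT-2 blocks may return a tuple whose first element is the hidden state
        hidden = output[0] if isinstance(output, tuple) else output
        # Activation of the chosen feature at every position, shape (...,)
        f = torch.relu(hidden @ sae.W_enc[feature_idx] + sae.b_enc[feature_idx])
        delta = (alpha - 1.0) * f.unsqueeze(-1) * sae.W_dec[:, feature_idx]
        hidden = hidden + delta.to(hidden.dtype)
        # Returning a value from a forward hook replaces the module's output
        if isinstance(output, tuple):
            return (hidden,) + tuple(output[1:])
        return hidden
    return hook

# Hypothetical usage with the notebook's GPT-2 and SAE (not executed here).
# hidden_states[6] is the output of block index 5, so hook model.transformer.h[5]:
# handle = model.transformer.h[5].register_forward_hook(
#     make_sae_steering_hook(sae, feature_to_analyze, alpha=5.0))
# with torch.no_grad():
#     steered = model(**tokenizer(test_text, return_tensors="pt").to(device)).logits[0, -1]
# handle.remove()  # always remove the hook afterwards
```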

## Part 5: Comparing with Other Methods

Compare SAE features with steering vectors and probes.

In [None]:
# Compare SAE feature with steering vector (from Week 2)
# For a concept, compute steering vector and SAE features

def compare_sae_with_steering(sae, steering_vector, activations):
    """
    Compare SAE decomposition with steering vector.
    """
    with torch.no_grad():
        # Pass the steering vector through the SAE: returns (reconstruction, features)
        steering_recon, steering_features = sae(steering_vector.unsqueeze(0).to(device))
        
        # Find which SAE features activate
        active_features = (steering_features[0] > 0.1).nonzero(as_tuple=True)[0]
        
        print(f"Steering vector decomposition:")
        print(f"  Active SAE features: {len(active_features)}")
        
        # Top contributing features
        top_k = min(5, len(active_features))
        top_values, top_indices = torch.topk(steering_features[0], top_k)
        
        print(f"\n  Top {top_k} features:")
        for i, (val, idx) in enumerate(zip(top_values, top_indices)):
            print(f"    {i+1}. Feature {idx.item()}: {val.item():.4f}")
        
        # Reconstruction quality
        recon_error = ((steering_recon[0] - steering_vector.to(device))**2).sum().item()
        print(f"\n  Reconstruction error: {recon_error:.4f}")


# Example: create a simple "steering vector" (random for demo)
demo_steering_vector = torch.randn(d_model)
compare_sae_with_steering(sae, demo_steering_vector, activations)

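print("\nInterpretation:")
print("  - A random demo vector has no reason to align with a few SAE features")
print("  - A real steering vector may be dominated by a small number of SAE features; check the top values above")
print("  - If so, those features suggest which 'concepts' make up the steering direction")
print("  - SAE features can then be used for more targeted steering")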
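Steering vectors are usually differences of activations, so feeding them through the SAE encoder (whose biases were fit to raw activations) can be misleading. An alternative, sketched below under that assumption, compares directions only: rank SAE features by the cosine similarity between the steering vector and each decoder column.

```python
import numpy as np

def top_decoder_matches(vector, W_dec, k=5):
    """Rank SAE features by cosine similarity to `vector`.

    vector: (d_model,); W_dec: (d_model, d_hidden), one column per feature.
    """
    cols = W_dec / np.clip(np.linalg.norm(W_dec, axis=0, keepdims=True), 1e-8, None)
    v = vector / max(np.linalg.norm(vector), 1e-8)
    sims = v @ cols  # (d_hidden,)
    top = np.argsort(-sims)[:k]
    return [(int(i), float(sims[i])) for i in top]
```

Usage: `top_decoder_matches(demo_steering_vector.numpy(), sae.W_dec.detach().cpu().numpy())`. For a random vector in 768 dimensions, expect all similarities to be small.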

## Part 6: Feature Splitting

Explore how features split as SAE capacity increases.

In [None]:
# Train SAEs with different capacities
capacities = [d_model * 2, d_model * 4, d_model * 8]
capacity_results = []

print("Training SAEs with different capacities...\n")

for capacity in capacities:
    print(f"Training {capacity}-feature SAE...")
    
    # Create and train SAE
    sae_temp = SparseAutoencoder(d_model, capacity, l1_coeff=1e-3).to(device)
    optimizer_temp = optim.Adam(sae_temp.parameters(), lr=1e-3)
    
    # Quick training (fewer epochs for demo)
    for epoch in range(20):
        indices = torch.randperm(activations.shape[0])
        for i in range(0, activations.shape[0], batch_size):
            batch_indices = indices[i:i+batch_size]
            batch = activations[batch_indices].to(device)
            
            loss, _, _ = sae_temp.loss(batch)
            optimizer_temp.zero_grad()
            loss.backward()
            optimizer_temp.step()
    
    # Evaluate
    with torch.no_grad():
        acts = activations.to(device)
        recon, features = sae_temp(acts)  # single forward pass
        recon_loss = ((recon - acts) ** 2).mean().item()
        l0 = (features > 0).float().sum(dim=1).mean().item()
        
        capacity_results.append({
            'capacity': capacity,
            'recon_loss': recon_loss,
            'l0': l0,
            'sparsity': l0 / capacity
        })
    
    print(f"  Recon loss: {recon_loss:.4f}")
    print(f"  Active features: {l0:.1f}/{capacity} ({100*l0/capacity:.1f}%)")
    print()

print("Training complete!")

In [None]:
# Visualize capacity-performance tradeoff
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Reconstruction vs capacity
axes[0].plot([r['capacity'] for r in capacity_results],
            [r['recon_loss'] for r in capacity_results],
            marker='o')
axes[0].set_xlabel('SAE Capacity (# features)')
axes[0].set_ylabel('Reconstruction Loss')
axes[0].set_title('Reconstruction Quality vs Capacity')
axes[0].grid(True, alpha=0.3)

# Active features vs capacity
axes[1].plot([r['capacity'] for r in capacity_results],
            [r['l0'] for r in capacity_results],
            marker='o', label='Active features')
axes[1].plot([r['capacity'] for r in capacity_results],
            [r['capacity'] for r in capacity_results],
            'r--', label='Total capacity')
axes[1].set_xlabel('SAE Capacity (# features)')
axes[1].set_ylabel('Number of Features')
axes[1].set_title('Active Features vs Capacity')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\nInterpretation:")
print("  - Higher capacity → typically lower reconstruction loss (less information lost)")
print("  - Higher capacity → broad features may split into finer subfeatures (not measured by these plots)")
print("  - Trade-off: interpretability (fewer, broader features) vs completeness (more, finer features)")
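The plots above do not measure splitting directly. One way to look for it, sketched here as an assumption rather than a standard procedure, is to match decoder directions across SAEs: a feature of the smaller SAE whose direction is close to two or more features of the larger SAE is a candidate for having split. Using it would require keeping each `sae_temp` (e.g., in a list) during the training loop, which the notebook does not currently do; the `threshold=0.7` cutoff is arbitrary.

```python
import numpy as np

def find_split_children(W_dec_small, W_dec_large, threshold=0.7):
    """Map each small-SAE feature to large-SAE features with similar decoder directions.

    W_dec_small: (d_model, n_small); W_dec_large: (d_model, n_large), one column per feature.
    """
    def unit_cols(W):
        return W / np.clip(np.linalg.norm(W, axis=0, keepdims=True), 1e-8, None)
    sims = unit_cols(W_dec_small).T @ unit_cols(W_dec_large)  # (n_small, n_large)
    return {i: np.flatnonzero(sims[i] > threshold).tolist() for i in range(sims.shape[0])}
```

Features with two or more children are the ones to inspect with max-activating examples.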

## Part 7: Your Project Template

Apply SAEs to discover features for your concept.

In [None]:
print("Week 7 Project Template: SAE Feature Discovery for Your Concept\n")

print("1. Train or load SAEs")
# Option A: Train your own SAEs on relevant layers
# Option B: Use pre-trained SAEs from Neuronpedia

print("\n2. Extract features on your dataset")
# Run your concept-relevant texts through SAE
# Identify which features activate

print("\n3. Interpret features")
# Use automated interpretation (GPT-4) on max-activating examples
# Identify 10-20 features most relevant to your concept

print("\n4. Validate features")
# Monosemanticity: coherent concept?
# Consistency: activates reliably?
# Causality: steering tests (Week 2 methods)

print("\n5. Compare with other methods")
# How do SAE features relate to:
#   - Steering vectors (Week 2)
#   - Probe directions (Week 6)
#   - Circuit components (Week 5)

print("\n6. Analyze completeness")
# Do SAE features capture all aspects of your concept?
# What's missing?
# Are multiple features needed?

print("\n7. Document findings")
# Create feature catalog with interpretations
# Validation results
# Comparison with other methods

## Summary

In this exercise, you've learned:
- How superposition allows networks to represent more features than dimensions
- How SAEs use sparse decomposition to reverse superposition
- How to train and evaluate SAEs
- How to interpret features using automated methods
- How to validate feature quality (monosemanticity, causality)
- How SAE features relate to other discovery methods
- How feature splitting occurs as capacity increases

Key takeaways:
- **Superposition explains polysemanticity** - features interfere in limited dimensions
- **SAEs discover features unsupervised** - no labels needed
- **Validation is critical** - feature presence ≠ causal use
- **Comparison reveals structure** - how features relate to steering/probes
- **Capacity creates trade-offs** - completeness vs interpretability

For your project:
1. Use SAEs to discover features related to your concept
2. Validate quality with multiple methods
3. Compare with previous findings (steering, probes, circuits)
4. Identify gaps in feature coverage
5. Use validated features for steering and analysis