# Lab C.4: Feature Extraction with SAEs - SOLUTIONS

This notebook contains solutions to all exercises from Lab C.4.

---

In [None]:
# Setup
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt
import plotly.express as px
from transformer_lens import HookedTransformer
from tqdm.auto import tqdm
import gc

# Seed the global RNGs: SAE weight initialisation and DataLoader shuffling
# both draw from torch's default generator, so without a seed every
# Restart & Run All produces different numbers.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Use the GPU when one is available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = HookedTransformer.from_pretrained("gpt2-small", device=device)
model.eval()

In [None]:
# SAE class (from main notebook)
class SparseAutoencoder(nn.Module):
    """Sparse autoencoder whose decoder shares (transposed) weights with the encoder.

    Encoding is ``ReLU(Linear(x))``; decoding multiplies the feature
    activations by the encoder weight matrix and adds a separate learned bias.
    """

    def __init__(self, d_model, n_features):
        """Create an SAE mapping ``d_model`` inputs to ``n_features`` features.

        Args:
            d_model: Size of the last dimension of the inputs being reconstructed.
            n_features: Number of dictionary features (hidden units).
        """
        super().__init__()
        self.d_model = d_model
        self.n_features = n_features

        # Encoder: re-initialised with Kaiming-uniform right after construction.
        self.encoder = nn.Linear(d_model, n_features, bias=True)
        nn.init.kaiming_uniform_(self.encoder.weight, nonlinearity='relu')

        # Decoder has no weight of its own, only this zero-initialised bias.
        self.decoder_bias = nn.Parameter(torch.zeros(d_model))

    def encode(self, x):
        """Return the non-negative feature activations for ``x``."""
        pre_activation = self.encoder(x)
        return torch.relu(pre_activation)

    def decode(self, features):
        """Map feature activations back to ``d_model`` using the tied weights."""
        tied_weight = self.encoder.weight.t()
        return F.linear(features, tied_weight, self.decoder_bias)

    def forward(self, x):
        """Return ``(reconstruction, features)`` for ``x``."""
        codes = self.encode(x)
        recon = self.decode(codes)
        return recon, codes

    def compute_loss(self, x, sparsity_coef=1e-3):
        """Return the training loss and a dict of scalar diagnostics.

        Args:
            x: Inputs to reconstruct; last dimension is ``d_model``.
            sparsity_coef: Weight applied to the mean-absolute-activation penalty.

        Returns:
            ``(loss, stats)`` where ``loss`` is MSE plus the weighted sparsity
            term and ``stats`` holds ``'recon'``, ``'sparsity'`` and
            ``'n_active'`` (mean count of strictly positive features per row)
            as Python floats.
        """
        recon, codes = self.forward(x)

        mse = F.mse_loss(recon, x)
        l1 = codes.abs().mean()
        mean_active = codes.gt(0).sum(dim=-1).float().mean()

        metrics = {
            'recon': mse.item(),
            'sparsity': l1.item(),
            'n_active': mean_active.item(),
        }
        total = mse + sparsity_coef * l1
        return total, metrics

## Exercise 1: Different Sparsity Levels

In [None]:
# Solution: Compare different sparsity coefficients

# Collect activations
prompts = [
    "The capital of France is Paris.",
    "def fibonacci(n): return n if n <= 1 else",
    "What is the meaning of life?",
    "I am so happy today!",
    "The quick brown fox jumps over the lazy dog."
] * 10  # Repeat for more data

activations = []
layer = 6
# The prompt list repeats the same few strings, so run the model once per
# distinct prompt and reuse that result for every repeat. The concatenated
# tensor has the same rows in the same order as one forward pass per entry.
acts_by_prompt = {}
with torch.no_grad():
    for prompt in prompts:
        if prompt not in acts_by_prompt:
            tokens = model.to_tokens(prompt)
            _, cache = model.run_with_cache(tokens)
            # [0] selects the single batch element of the cached residual stream.
            acts_by_prompt[prompt] = cache["resid_post", layer][0]
            del cache
        activations.append(acts_by_prompt[prompt])

# Stack all per-prompt rows into one matrix along dim 0.
activations = torch.cat(activations, dim=0)
print(f"Activations: {activations.shape}")

# Sweep the sparsity coefficient: train one fresh SAE per value and keep the
# diagnostics measured on the full activation set.
sparsity_coefs = [1e-2, 1e-3, 1e-4]
results = {}

for coef in sparsity_coefs:
    print(f"\nTraining with sparsity_coef={coef}")
    sae = SparseAutoencoder(model.cfg.d_model, model.cfg.d_model * 4).to(device)
    optimizer = torch.optim.Adam(sae.parameters(), lr=1e-3)

    loader = DataLoader(TensorDataset(activations), batch_size=64, shuffle=True)

    for epoch in range(50):
        # TensorDataset yields 1-tuples, so unpack the single tensor directly.
        for (acts_batch,) in loader:
            optimizer.zero_grad()
            total_loss, _ = sae.compute_loss(acts_batch, sparsity_coef=coef)
            total_loss.backward()
            optimizer.step()

    # Evaluate on all activations without tracking gradients
    with torch.no_grad():
        stats = sae.compute_loss(activations, sparsity_coef=coef)[1]

    results[coef] = stats
    print(f"  Recon loss: {stats['recon']:.4f}")
    print(f"  Active features: {stats['n_active']:.1f}")

# Print the sweep as a fixed-width table, one row per coefficient
rule = "=" * 50
print("\n" + rule)
print("Summary:")
print(rule)

header = f"{'Sparsity Coef':<15} {'Recon Loss':<12} {'Active Features':<15}"
print(header)
for coef in results:
    row = results[coef]
    print(f"{coef:<15} {row['recon']:<12.4f} {row['n_active']:<15.1f}")

print("\nConclusion:")
for takeaway in (
    "- Higher sparsity = fewer active features but worse reconstruction",
    "- Lower sparsity = more active features, potentially less interpretable",
    "- Sweet spot is typically 1e-4 to 1e-3",
):
    print(takeaway)

## Exercise 2: Feature Similarity

In [None]:
# Solution: Find co-activating features

# Fit a new SAE on the collected activations (50 passes, batches of 64)
sae = SparseAutoencoder(model.cfg.d_model, model.cfg.d_model * 4).to(device)
optimizer = torch.optim.Adam(sae.parameters(), lr=1e-3)

for _ in range(50):
    # A new shuffled loader is built for every pass over the data.
    epoch_loader = DataLoader(TensorDataset(activations), batch_size=64, shuffle=True)
    for (acts_batch,) in epoch_loader:
        optimizer.zero_grad()
        loss = sae.compute_loss(acts_batch, sparsity_coef=5e-4)[0]
        loss.backward()
        optimizer.step()

# Get feature activations
with torch.no_grad():
    features = sae.encode(activations).cpu().numpy()  # [n_samples, n_features]

# Restrict the analysis to features that fire (> 0) on more than 1% of samples
feature_active = (features > 0).mean(axis=0) > 0.01
active_indices = np.where(feature_active)[0]
n_active_feats = len(active_indices)

# A correlation matrix needs at least two features: with a single one
# np.corrcoef returns a scalar and np.fill_diagonal would raise.
if n_active_feats >= 2:
    active_features = features[:, active_indices]
    correlation = np.corrcoef(active_features.T)

    # A zero-variance feature produces NaN entries. argsort places NaN last,
    # so left alone they would be reported as the "top" pairs; treat them as
    # uncorrelated instead.
    correlation = np.nan_to_num(correlation, nan=0.0)

    # Exclude self-correlation
    np.fill_diagonal(correlation, 0)

    # Rank every unordered pair exactly once via the strict upper triangle,
    # then keep the 10 largest correlations.
    rows, cols = np.triu_indices(n_active_feats, k=1)
    pair_corr = correlation[rows, cols]
    top_pairs = np.argsort(pair_corr)[::-1][:10]

    print("Top correlated feature pairs:")
    for k in top_pairs:
        # Map positions in the reduced matrix back to original feature ids.
        feat_i = active_indices[rows[k]]
        feat_j = active_indices[cols[k]]
        print(f"  Feature {feat_i} <-> Feature {feat_j}: corr = {pair_corr[k]:.3f}")

    # Visualize correlation matrix
    if n_active_feats <= 100:  # Only if manageable size
        fig = px.imshow(
            correlation,
            title="Feature Correlation Matrix",
            color_continuous_scale="RdBu_r",
            color_continuous_midpoint=0
        )
        fig.update_layout(width=600, height=600)
        fig.show()
else:
    print("Fewer than two features are active frequently enough for correlation analysis")

print("\nInterpretation:")
print("Features that co-activate might represent related concepts")
print("or be part of the same underlying 'circuit' in the model.")

## Exercise 3: Different Layers

In [None]:
# Solution: Compare SAEs on early vs late layers

layers_to_compare = [2, 10]  # Early and late
layer_saes = {}  # layer index -> SAE trained on that layer's activations

for layer in layers_to_compare:
    print(f"\nLayer {layer}:")
    print("="*40)

    # Collect activations. The prompt list contains repeats, so run the model
    # once per distinct prompt and reuse the result; the concatenated tensor
    # has the same rows in the same order as one forward pass per entry.
    acts_by_prompt = {}
    layer_acts = []
    with torch.no_grad():
        for prompt in prompts:
            if prompt not in acts_by_prompt:
                tokens = model.to_tokens(prompt)
                _, cache = model.run_with_cache(tokens)
                # [0] selects the single batch element of the cached residual stream.
                acts_by_prompt[prompt] = cache["resid_post", layer][0]
                del cache
            layer_acts.append(acts_by_prompt[prompt])
    layer_acts = torch.cat(layer_acts, dim=0)

    # Train SAE
    sae = SparseAutoencoder(model.cfg.d_model, model.cfg.d_model * 4).to(device)
    optimizer = torch.optim.Adam(sae.parameters(), lr=1e-3)

    for _ in range(50):
        for batch in DataLoader(TensorDataset(layer_acts), batch_size=64, shuffle=True):
            optimizer.zero_grad()
            loss, _ = sae.compute_loss(batch[0], sparsity_coef=5e-4)
            loss.backward()
            optimizer.step()

    layer_saes[layer] = sae

    # Analyze features
    with torch.no_grad():
        features = sae.encode(layer_acts)

    # Find most active features: rank by mean activation over all rows
    mean_acts = features.mean(dim=0)
    top_features = mean_acts.argsort(descending=True)[:5]

    print(f"Top 5 most active features: {top_features.tolist()}")
    print(f"Average active features: {(features > 0).sum(dim=-1).float().mean():.1f}")

# Compare what activates features in each layer
test_prompts = {
    "Code": "def hello_world():",
    "Question": "What is the capital?",
    "Emotion": "I am very happy!",
    "Factual": "The moon orbits Earth."
}

print("\n" + "="*60)
print("Feature activation comparison:")
print("="*60)

for name, prompt in test_prompts.items():
    print(f"\n'{prompt}' ({name}):")
    tokens = model.to_tokens(prompt)

    # Inference only: keep the forward pass inside no_grad as well, so the
    # cached activations (and `acts`, which outlives the loop) do not hold an
    # autograd graph.
    with torch.no_grad():
        _, cache = model.run_with_cache(tokens)

        for layer, sae in layer_saes.items():
            acts = cache["resid_post", layer][0, -1, :]  # Last token
            features = sae.encode(acts.unsqueeze(0))

            n_active = (features > 0).sum().item()
            # features has a single row, so the flat argmax is the feature index.
            max_feat = features.argmax().item()
            max_val = features.max().item()

            print(f"  Layer {layer}: {n_active} active features, strongest=F{max_feat} ({max_val:.2f})")

    del cache

print("\nConclusion:")
print("- Early layers often capture more syntactic/positional features")
print("- Later layers capture more semantic/conceptual features")
print("- This aligns with the 'hierarchy of abstraction' in deep networks")

## Cleanup

In [None]:
# Release every notebook-level name that still references SAE weights or
# activation tensors. Deleting `sae` alone is not enough: the optimizer keeps
# references to its parameters, and the per-layer activations, encoded
# features, last batch and last loss also stay bound after the cells above.
# pop(..., None) keeps this cell safe to re-run, where a bare `del` would
# raise NameError.
for _name in (
    "activations", "sae", "layer_saes",
    "optimizer", "layer_acts", "features", "acts", "batch", "loss",
):
    globals().pop(_name, None)
del _name

gc.collect()
torch.cuda.empty_cache()
print("Cleanup complete!")