# InterpretabilityWorkbench Tutorial

This tutorial walks through the complete InterpretabilityWorkbench workflow:
1. Recording model activations
2. Training a sparse autoencoder (SAE)
3. Evaluating the SAE
4. Analyzing discovered features
5. Creating live LoRA patches
6. Exporting the results

**Expected runtime**: ~30 minutes (with GPU)
**Requirements**: GPU with 8GB+ VRAM

## Setup and Imports

In [None]:
# Install dependencies if needed
# !pip install -e .

import torch
import numpy as np
import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoModel, AutoTokenizer

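# Local imports (project modules). `trace` is also the name of a standard-library
# module, so the project directory must come first on sys.path for this to resolve.
from trace import ActivationRecorder, FeatureAnalyzer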
from sae_train import train_sae, SparseAutoencoder
from lora_patch import LoRAPatcher
from eval import SAEEvaluator

# Configure matplotlib
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## Step 1: Record Model Activations

We'll use a small model for this tutorial to keep runtime manageable.

In [None]:
# Configuration
MODEL_NAME = "microsoft/DialoGPT-small"  # Small model for tutorial
LAYER_IDX = 8  # One of DialoGPT-small's 12 transformer blocks (upper-middle of the stack)
MAX_SAMPLES = 1000  # Small dataset for tutorial
ACTIVATION_FILE = "tutorial_activations.parquet"

print(f"Recording activations from {MODEL_NAME}, layer {LAYER_IDX}")
print(f"Max samples: {MAX_SAMPLES}")

In [None]:
# Record activations
recorder = ActivationRecorder(
    model_name=MODEL_NAME,
    layer_idx=LAYER_IDX,
    output_path=ACTIVATION_FILE,
    max_samples=MAX_SAMPLES,
    max_length=256  # Shorter sequences for tutorial
)

# This will take a few minutes
recorder.record(dataset_name="wikitext")

print(f"\nActivations saved to {ACTIVATION_FILE}")
file_size_mb = Path(ACTIVATION_FILE).stat().st_size / (1024*1024)
print(f"File size: {file_size_mb:.1f} MB")
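
`ActivationRecorder`'s internals are not shown in this tutorial. As a rough sketch of the forward-hook technique commonly used for this kind of recording, the following captures one layer's output from a toy PyTorch model. The toy model, `save_output`, and `captured` are illustrative stand-ins, not part of the workbench.

```python
import torch
import torch.nn as nn

# Toy stand-in for a transformer (hypothetical, for illustration only).
toy_model = nn.Sequential(
    nn.Linear(16, 32),
    nn.ReLU(),
    nn.Linear(32, 32),
    nn.ReLU(),
    nn.Linear(32, 4),
)

captured = []  # one tensor per forward pass

def save_output(module, inputs, output):
    # Detach and move to CPU so stored activations hold neither the
    # autograd graph nor GPU memory.
    captured.append(output.detach().cpu())

# Hook the module whose output we want to record (index 2 = second linear layer).
handle = toy_model[2].register_forward_hook(save_output)

with torch.no_grad():
    toy_model(torch.randn(8, 16))  # batch of 8 inputs

handle.remove()  # remove the hook once recording is done

toy_activations = torch.cat(captured)
print(toy_activations.shape)  # torch.Size([8, 32])
```

Removing the hook afterwards matters in a notebook: a hook left in place keeps appending to `captured` on every later forward pass.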

### Explore the recorded data

In [None]:
# Load and examine the activation data
import pyarrow.parquet as pq

table = pq.read_table(ACTIVATION_FILE)
df = table.to_pandas()

print(f"Dataset shape: {df.shape}")
print(f"Columns: {list(df.columns)}")
print(f"\nFirst few rows:")
print(df.head())

# Check activation dimensions
first_activation = df.iloc[0]['activation']
print(f"\nActivation vector dimension: {len(first_activation)}")
print(f"Sample activation values: {first_activation[:10]}")

In [None]:
# Visualize activation statistics
activations = np.array([act for act in df['activation']])
print(f"Activations shape: {activations.shape}")

# Plot activation distribution
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Histogram of activation values
axes[0].hist(activations.flatten(), bins=50, alpha=0.7)
axes[0].set_title('Distribution of Activation Values')
axes[0].set_xlabel('Activation Value')
axes[0].set_ylabel('Frequency')

# Mean activation per dimension
mean_activations = np.mean(activations, axis=0)
axes[1].plot(mean_activations)
axes[1].set_title('Mean Activation per Dimension')
axes[1].set_xlabel('Dimension')
axes[1].set_ylabel('Mean Activation')

plt.tight_layout()
plt.show()

print(f"Activation statistics:")
print(f"  Mean: {np.mean(activations):.4f}")
print(f"  Std: {np.std(activations):.4f}")
print(f"  Min: {np.min(activations):.4f}")
print(f"  Max: {np.max(activations):.4f}")

## Step 2: Train Sparse Autoencoder

Now we'll train an SAE to discover interpretable features in the recorded activations.

In [None]:
# SAE training configuration
SAE_OUTPUT_DIR = "tutorial_sae"
LATENT_DIM = 2048  # Expansion factor of ~2.7x (768 -> 2048 for DialoGPT-small)
SPARSITY_COEF = 5e-4  # L1 penalty coefficient
MAX_EPOCHS = 20  # Fewer epochs for tutorial

print(f"Training SAE with {LATENT_DIM} latent dimensions")
print(f"Sparsity coefficient: {SPARSITY_COEF}")
print(f"Max epochs: {MAX_EPOCHS}")
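
The body of `train_sae` is not shown here. To make the role of `LATENT_DIM` and `SPARSITY_COEF` concrete, the sketch below assumes the common formulation: a ReLU encoder, a linear decoder, and a loss of reconstruction MSE plus an L1 penalty on the latent code. The function names and the random weights are hypothetical; the workbench's actual objective may differ.

```python
import numpy as np

def sae_forward(x, W_enc, b_enc, W_dec, b_dec):
    """ReLU encoder followed by a linear decoder."""
    z = np.maximum(0.0, x @ W_enc + b_enc)   # latent code, [batch, latent_dim]
    x_hat = z @ W_dec + b_dec                # reconstruction, [batch, input_dim]
    return z, x_hat

def sae_loss(x, z, x_hat, sparsity_coef):
    """Mean squared reconstruction error plus an L1 penalty on the latent code."""
    mse = np.mean((x - x_hat) ** 2)
    l1 = np.mean(np.sum(np.abs(z), axis=1))
    return mse + sparsity_coef * l1, mse, l1

rng = np.random.default_rng(0)
input_dim, latent_dim = 768, 2048
x = rng.standard_normal((4, input_dim))
W_enc = rng.standard_normal((input_dim, latent_dim)) * 0.01
W_dec = rng.standard_normal((latent_dim, input_dim)) * 0.01
z, x_hat = sae_forward(x, W_enc, np.zeros(latent_dim), W_dec, np.zeros(input_dim))
total, mse, l1 = sae_loss(x, z, x_hat, sparsity_coef=5e-4)

print(f"expansion factor: {latent_dim / input_dim:.2f}x")  # 2.67x
print(f"loss = {total:.4f} (mse={mse:.4f}, l1={l1:.4f})")
```

Under this formulation, a larger sparsity coefficient trades reconstruction accuracy for fewer active latents per input.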

In [None]:
# Train the SAE
sae_trainer = train_sae(
    activation_path=ACTIVATION_FILE,
    output_dir=SAE_OUTPUT_DIR,
    layer_idx=LAYER_IDX,
    latent_dim=LATENT_DIM,
    sparsity_coef=SPARSITY_COEF,
    max_epochs=MAX_EPOCHS,
    gpus=1 if torch.cuda.is_available() else 0
)

print("\nTraining completed!")
print(f"SAE saved to {SAE_OUTPUT_DIR}/")

In [None]:
# Load the trained SAE and inspect its structure
import safetensors.torch as safetensors
import json

sae_path = Path(SAE_OUTPUT_DIR) / f"sae_layer_{LAYER_IDX}.safetensors"
metadata_path = Path(SAE_OUTPUT_DIR) / f"sae_layer_{LAYER_IDX}_metadata.json"

# Load metadata
with open(metadata_path, 'r') as f:
    metadata = json.load(f)

print("SAE Metadata:")
for key, value in metadata.items():
    print(f"  {key}: {value}")

# Load SAE weights
sae_weights = safetensors.load_file(sae_path)
print(f"\nSAE Weight shapes:")
for name, tensor in sae_weights.items():
    print(f"  {name}: {tensor.shape}")

## Step 3: Evaluate SAE Performance

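Next, we measure how well the SAE reconstructs the original activations. The checks below compare the reconstruction loss (MSE) against this tutorial's target of 0.15 or lower and count dead features.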

In [None]:
# Evaluate the SAE
evaluator = SAEEvaluator(
    sae_path=str(sae_path),
    activation_path=ACTIVATION_FILE,
    layer_idx=LAYER_IDX
)

# Generate evaluation report
report = evaluator.generate_report("tutorial_eval_report.json")
evaluator.print_summary(report)
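
For reference, the two reconstruction metrics in the report can be computed as in the sketch below. It assumes the usual definitions (mean squared error, and explained variance as one minus the ratio of residual to total variance); `SAEEvaluator` may define them slightly differently, and `reconstruction_metrics` is a hypothetical helper.

```python
import numpy as np

def reconstruction_metrics(x, x_hat):
    """Return (mse, explained_variance) for originals x and reconstructions x_hat."""
    residual = x - x_hat
    mse = float(np.mean(residual ** 2))
    # Total variance is measured around the per-dimension mean over the dataset.
    total_var = float(np.mean((x - x.mean(axis=0)) ** 2))
    explained_variance = 1.0 - mse / total_var
    return mse, explained_variance

rng = np.random.default_rng(0)
x = rng.standard_normal((1000, 8))

print(reconstruction_metrics(x, x))  # perfect reconstruction: (0.0, 1.0)

# Predicting the dataset mean for every sample explains none of the variance.
baseline = np.broadcast_to(x.mean(axis=0), x.shape)
mse, ev = reconstruction_metrics(x, baseline)
print(f"mean-only baseline: explained variance = {ev:.3f}")  # 0.000
```

Explained variance is negative whenever the reconstruction is worse than this mean-only baseline, so a plot clipped to the range 0 to 1 can hide a poorly trained SAE.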

In [None]:
# Visualize SAE performance
metrics = report['reconstruction_metrics']
feature_analysis = report['feature_analysis']

fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# Reconstruction loss
axes[0, 0].bar(['Reconstruction Loss', 'Target (≤0.15)'], 
               [metrics['reconstruction_loss'], 0.15],
               color=['blue', 'red'])
axes[0, 0].set_title('Reconstruction Loss vs Target')
axes[0, 0].set_ylabel('MSE Loss')

# Explained variance
axes[0, 1].bar(['Explained Variance'], [metrics['explained_variance']], 
               color='green')
axes[0, 1].set_title('Explained Variance (R²)')
axes[0, 1].set_ylabel('R²')
axes[0, 1].set_ylim(0, 1)

# Feature activity
total_features = metrics['num_features']
dead_features = feature_analysis['total_dead_features']
active_features = total_features - dead_features

axes[1, 0].pie([active_features, dead_features], 
               labels=['Active Features', 'Dead Features'],
               autopct='%1.1f%%',
               colors=['lightgreen', 'lightcoral'])
axes[1, 0].set_title('Feature Activity Distribution')

# Top active features
if len(feature_analysis['most_active_features']) >= 5:
    top_features = feature_analysis['most_active_features'][:5]
    feature_indices = [f"F{f['feature_idx']}" for f in top_features]
    activation_freqs = [f['activation_frequency'] for f in top_features]
    
    axes[1, 1].bar(feature_indices, activation_freqs)
    axes[1, 1].set_title('Top 5 Most Active Features')
    axes[1, 1].set_ylabel('Activation Frequency')
    axes[1, 1].tick_params(axis='x', rotation=45)

plt.tight_layout()
plt.show()

# Check if we meet success criteria
success = metrics['reconstruction_loss'] <= 0.15
print(f"\n{'✅ SUCCESS' if success else '❌ NEEDS IMPROVEMENT'}: Reconstruction loss {'meets' if success else 'exceeds'} target threshold")
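
The feature-activity numbers in the report (dead features, activation frequency) can be derived from the matrix of SAE latent codes. The sketch below assumes a feature counts as active on a sample when its latent value exceeds a threshold of zero, and as dead when it is never active; the evaluator's exact threshold is not shown in this tutorial, and `feature_activity` is a hypothetical helper.

```python
import numpy as np

def feature_activity(latents, threshold=0.0):
    """latents: [n_samples, n_features] array of SAE codes."""
    fires = latents > threshold
    activation_frequency = fires.mean(axis=0)          # fraction of samples where each feature fires
    dead = np.flatnonzero(activation_frequency == 0)   # features that never fire
    most_active = np.argsort(-activation_frequency)    # indices, most frequent first
    return activation_frequency, dead, most_active

latents = np.array([
    [0.0, 1.2, 0.0, 0.3],
    [0.0, 0.7, 0.0, 0.0],
    [0.0, 0.1, 0.0, 0.0],
])
freq, dead, order = feature_activity(latents)
print(freq)       # feature 1 fires on every sample, feature 3 on one of three
print(dead)       # [0 2]
print(order[:2])  # [1 3]
```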

## Step 4: Analyze Discovered Features

Now let's examine what our SAE has learned by looking at which tokens activate each feature.

In [None]:
# Load the model and tokenizer for feature analysis
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)

# Load the trained SAE
sae = SparseAutoencoder(
    input_dim=metadata['input_dim'],
    latent_dim=metadata['latent_dim'],
    tied_weights=metadata['tied_weights']
)
sae.load_state_dict(sae_weights)
sae.eval()

print(f"Loaded SAE: {metadata['input_dim']} → {metadata['latent_dim']}")
print(f"Tokenizer vocab size: {tokenizer.vocab_size}")

In [None]:
# Create feature analyzer
analyzer = FeatureAnalyzer(
    sae_model=sae,
    tokenizer=tokenizer,
    activation_data_path=ACTIVATION_FILE,
    layer_idx=LAYER_IDX
)

# Analyze a few interesting features
interesting_features = [0, 1, 10, 50, 100]  # Sample some features

print("Analyzing sample features...")
for feature_idx in interesting_features:
    if feature_idx < sae.latent_dim:
        print(f"\n=== FEATURE {feature_idx} ===")
        
        try:
            # Get feature summary
            summary = analyzer.get_feature_summary(feature_idx)
            
            print(f"Weight norm: {summary['weight_norm']:.4f}")
            print(f"Sparsity: {summary['sparsity']:.4f}")
            print(f"Max activation: {summary['max_activation']:.4f}")
            print(f"Top tokens: {summary['top_tokens']}")
            
            # Show detailed token analysis
            if len(summary['top_token_details']) > 0:
                print("\nTop activating contexts:")
                for i, token_info in enumerate(summary['top_token_details'][:3]):
                    print(f"  {i+1}. '{token_info['token']}' (strength: {token_info['activation_strength']:.3f})")
                    print(f"     Context: ...{token_info['context_snippet']}...")
        
        except Exception as e:
            print(f"Error analyzing feature {feature_idx}: {e}")
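
The "top tokens" in a feature summary come down to sorting one column of the latent matrix. The sketch below shows that idea on toy data; `top_activating_tokens` and the sample arrays are hypothetical, and `FeatureAnalyzer` may additionally aggregate by token or attach context snippets.

```python
import numpy as np

def top_activating_tokens(latents, tokens, feature_idx, k=3):
    """Return up to k (token, activation) pairs with the largest positive activation."""
    column = latents[:, feature_idx]
    top = np.argsort(column)[::-1][:k]   # positions sorted by activation, descending
    return [(tokens[i], float(column[i])) for i in top if column[i] > 0]

tokens = ["The", " cat", " sat", " on", " the", " mat"]
latents = np.array([
    [0.0, 0.1],
    [2.5, 0.0],
    [0.0, 0.0],
    [0.0, 0.9],
    [0.0, 0.2],
    [1.5, 0.0],
])
print(top_activating_tokens(latents, tokens, feature_idx=0))
# [(' cat', 2.5), (' mat', 1.5)]
```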

### Visualize Feature Patterns

In [None]:
# Visualize encoder weight patterns for some features
encoder_weights = sae.encoder.weight.data.cpu().numpy()  # Shape: [latent_dim, input_dim]

# Select a few features to visualize
features_to_plot = [0, 1, 10, 50, 100]
features_to_plot = [f for f in features_to_plot if f < encoder_weights.shape[0]]

fig, axes = plt.subplots(len(features_to_plot), 1, figsize=(12, 3*len(features_to_plot)))
if len(features_to_plot) == 1:
    axes = [axes]

for i, feature_idx in enumerate(features_to_plot):
    weights = encoder_weights[feature_idx]
    
    # Plot the weight vector
    axes[i].plot(weights, alpha=0.7)
    axes[i].set_title(f'Feature {feature_idx} Encoder Weights')
    axes[i].set_xlabel('Input Dimension')
    axes[i].set_ylabel('Weight Value')
    axes[i].grid(True, alpha=0.3)
    
    # Add statistics
    norm = np.linalg.norm(weights)
    sparsity = (np.abs(weights) < 1e-6).mean()
    axes[i].text(0.02, 0.98, f'Norm: {norm:.3f}\nSparsity: {sparsity:.3f}', 
                transform=axes[i].transAxes, 
                verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))

plt.tight_layout()
plt.show()

## Step 5: Live LoRA Patching

Next, we create a LoRA patch for a single SAE feature and toggle it on and off on the loaded model.

In [None]:
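# Load the full model for patching.
# float32 is used because the model stays on the CPU in this tutorial, and
# float16 kernels are not available for every CPU op in all PyTorch versions.
model = AutoModel.from_pretrained(MODEL_NAME)
model.eval()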

# Create LoRA patcher
patcher = LoRAPatcher(model)

print(f"Model loaded: {MODEL_NAME}")
print(f"Available for patching: {len(patcher._get_target_modules(LAYER_IDX, 'mlp'))} MLP modules")
print(f"Available for patching: {len(patcher._get_target_modules(LAYER_IDX, 'attention'))} attention modules")

In [None]:
# Select an interesting feature to patch
feature_to_patch = 10  # Choose a feature that showed interesting tokens
feature_vector = torch.tensor(encoder_weights[feature_to_patch], dtype=torch.float32)

print(f"Creating patch for feature {feature_to_patch}")
print(f"Feature vector shape: {feature_vector.shape}")
print(f"Feature vector norm: {feature_vector.norm():.4f}")

# Create a suppression patch (negative strength)
patch_id = patcher.create_feature_patch(
    feature_id=f"feature_{feature_to_patch}",
    layer_idx=LAYER_IDX,
    feature_vector=feature_vector,
    strength=-1.0,  # Suppress this feature
    rank=8,
    target_type="mlp"
)

print(f"Created patch: {patch_id}")
print(f"Active patches: {list(patcher.patch_metadata.keys())}")
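
How `create_feature_patch` turns a feature vector into LoRA weights is not shown in this tutorial. One plausible construction, given purely as an assumption, is a rank-1 update that rescales the component of a layer's output along the feature direction. With `strength=-1.0` that component is removed entirely. The helper `feature_lora_factors` is hypothetical.

```python
import numpy as np

def feature_lora_factors(W, feature_vector, strength):
    """Rank-1 factors (B, A) such that W + B @ A rescales the output along the feature."""
    v = feature_vector / np.linalg.norm(feature_vector)  # unit direction, [d_out]
    B = strength * v[:, None]        # [d_out, 1]
    A = v[None, :] @ W               # [1, d_in]: how strongly each input writes along v
    return B, A

rng = np.random.default_rng(0)
d_out, d_in = 6, 4
W = rng.standard_normal((d_out, d_in))   # weight of a layer computing y = W @ x
v = rng.standard_normal(d_out)           # stand-in for an SAE feature direction

B, A = feature_lora_factors(W, v, strength=-1.0)
W_patched = W + B @ A                    # W itself is untouched; only the low-rank delta is added

x = rng.standard_normal(d_in)
v_hat = v / np.linalg.norm(v)
print(f"component along feature before: {v_hat @ (W @ x):+.4f}")
print(f"component along feature after:  {v_hat @ (W_patched @ x):+.4f}")  # ~0 for strength=-1
```

Because the base weights are never modified, disabling such a patch only requires skipping the `B @ A` term. A higher rank, such as the `rank=8` used above, leaves room for more than one direction.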

In [None]:
# Test the patch effect on some sample text
test_texts = [
    "The quick brown fox jumps over the lazy dog.",
    "Hello, how are you doing today?",
    "Machine learning is a fascinating field of study.",
    "The weather is beautiful outside."
]

print("Testing patch effects...\n")

for i, text in enumerate(test_texts):
    print(f"=== Test {i+1}: {text} ===")
    
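    # Tokenize one text at a time, so no padding is needed (GPT-2-style
    # tokenizers such as DialoGPT's may not define a pad token)
    inputs = tokenizer(text, return_tensors="pt", truncation=True)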
    
    with torch.no_grad():
        # Get original activations
        patcher.disable_patch(patch_id)
        original_outputs = model(**inputs)
        original_hidden = original_outputs.last_hidden_state
        
        # Get patched activations
        patcher.enable_patch(patch_id)
        patched_outputs = model(**inputs)
        patched_hidden = patched_outputs.last_hidden_state
        
        # Compare outputs
        diff = torch.norm(patched_hidden - original_hidden).item()
        original_norm = torch.norm(original_hidden).item()
        relative_change = diff / original_norm * 100
        
        print(f"  Original norm: {original_norm:.4f}")
        print(f"  Difference norm: {diff:.4f}")
        print(f"  Relative change: {relative_change:.2f}%")
        print()

### Measure patch latency (Success Criteria)

In [None]:
import time

# Test patch toggle latency (target: <400ms)
test_text = "The quick brown fox jumps over the lazy dog."
inputs = tokenizer(test_text, return_tensors="pt")

latencies = []
num_tests = 10

print(f"Measuring patch toggle latency ({num_tests} trials)...")

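for i in range(num_tests):
    # Start from the unpatched state so each trial measures a real toggle
    patcher.disable_patch(patch_id)
    
    # Start timing (perf_counter is a monotonic, high-resolution clock)
    start_time = time.perf_counter()
    
    # Toggle patch and run inference
    patcher.enable_patch(patch_id)
    with torch.no_grad():
        outputs = model(**inputs)
    
    # End timing
    end_time = time.perf_counter()
    latency_ms = (end_time - start_time) * 1000
    latencies.append(latency_ms)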

# Analyze latencies
mean_latency = np.mean(latencies)
std_latency = np.std(latencies)
max_latency = np.max(latencies)

print(f"\nLatency Results:")
print(f"  Mean: {mean_latency:.1f} ms")
print(f"  Std: {std_latency:.1f} ms")
print(f"  Max: {max_latency:.1f} ms")
print(f"  Target: <400 ms")

success = max_latency < 400
print(f"\n{'✅ SUCCESS' if success else '❌ NEEDS IMPROVEMENT'}: Latency {'meets' if success else 'exceeds'} target threshold")

# Plot latency distribution
plt.figure(figsize=(8, 4))
plt.hist(latencies, bins=max(5, num_tests//2), alpha=0.7, edgecolor='black')
plt.axvline(400, color='red', linestyle='--', label='Target (400ms)')
plt.axvline(mean_latency, color='green', linestyle='-', label=f'Mean ({mean_latency:.1f}ms)')
plt.xlabel('Latency (ms)')
plt.ylabel('Frequency')
plt.title('Patch Toggle + Inference Latency Distribution')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
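
Wall-clock measurements like the one above are sensitive to warm-up effects and, on a GPU, to asynchronous kernel launches. The sketch below shows one way to structure such a benchmark with untimed warm-up runs and an optional synchronization callback. `time_calls` and `toy_workload` are hypothetical names, and the workload is a stand-in for "toggle patch + forward pass".

```python
import statistics
import time

def time_calls(fn, num_trials=10, warmup=2, sync=None):
    """Return per-call latencies in milliseconds.

    sync: optional callable run before each clock read, for example
    torch.cuda.synchronize when fn launches asynchronous GPU work.
    """
    for _ in range(warmup):              # warm-up runs are not timed
        fn()
    latencies_ms = []
    for _ in range(num_trials):
        if sync is not None:
            sync()
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()
        latencies_ms.append((time.perf_counter() - start) * 1000)
    return latencies_ms

def toy_workload():
    # Hypothetical stand-in for the real workload.
    return sum(i * i for i in range(10_000))

samples = time_calls(toy_workload, num_trials=5)
print(f"mean {statistics.mean(samples):.3f} ms, max {max(samples):.3f} ms")
```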

## Step 6: Export Results

Finally, let's export our trained SAE and patches for future use.

In [None]:
# Export directory
export_dir = Path("tutorial_exports")
export_dir.mkdir(exist_ok=True)

print(f"Exporting results to {export_dir}/")

In [None]:
# Export SAE weights and metadata
import shutil

# Copy SAE files
sae_export_path = export_dir / "sae"
sae_export_path.mkdir(exist_ok=True)

shutil.copy2(sae_path, sae_export_path / f"sae_layer_{LAYER_IDX}.safetensors")
shutil.copy2(metadata_path, sae_export_path / f"sae_layer_{LAYER_IDX}_metadata.json")

print(f"✅ SAE exported to {sae_export_path}/")

In [None]:
# Export patches
patch_export_path = export_dir / "patches"
patcher.save_patches(str(patch_export_path))

print(f"✅ Patches exported to {patch_export_path}/")

In [None]:
# Export evaluation report
report_export_path = export_dir / "evaluation_report.json"
shutil.copy2("tutorial_eval_report.json", report_export_path)

print(f"✅ Evaluation report exported to {report_export_path}")

In [None]:
# Create summary report
# Values are cast to built-in types: NumPy scalars such as np.bool_
# (the result of `max_latency < 400`) are not JSON serializable.
summary_report = {
    "tutorial_info": {
        "model_name": MODEL_NAME,
        "layer_idx": LAYER_IDX,
        "max_samples": MAX_SAMPLES,
        "latent_dim": LATENT_DIM,
        "sparsity_coef": SPARSITY_COEF
    },
    "results": {
        "reconstruction_loss": float(metrics['reconstruction_loss']),
        "explained_variance": float(metrics['explained_variance']),
        "meets_loss_target": bool(metrics['reconstruction_loss'] <= 0.15),
        "mean_patch_latency_ms": float(mean_latency),
        "meets_latency_target": bool(max_latency < 400),
        "dead_features": int(feature_analysis['total_dead_features']),
        "total_features": int(metrics['num_features'])
    },
    "files_created": {
        "activations": ACTIVATION_FILE,
        "sae_weights": str(sae_path),
        "sae_metadata": str(metadata_path),
        "evaluation_report": "tutorial_eval_report.json",
        "exports": str(export_dir)
    }
}

summary_path = export_dir / "tutorial_summary.json"
with open(summary_path, 'w') as f:
    json.dump(summary_report, f, indent=2)

print(f"✅ Tutorial summary saved to {summary_path}")
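
Values computed with NumPy, such as a comparison on `np.max(...)`, are NumPy scalar types, and `json.dump` rejects some of them (for example `np.bool_` and `np.int64`). A small helper like the hypothetical `to_jsonable` below converts a nested report to built-in types before writing it. It is one possible approach, not part of the workbench.

```python
import json
import numpy as np

def to_jsonable(obj):
    """Recursively convert NumPy scalars and arrays to built-in Python types."""
    if isinstance(obj, dict):
        return {str(k): to_jsonable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_jsonable(v) for v in obj]
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):      # np.float32, np.int64, np.bool_, ...
        return obj.item()
    return obj

sample_latencies = [12.5, 14.0, 13.1]
sample_report = {
    "mean_latency_ms": np.mean(sample_latencies),
    "meets_latency_target": np.max(sample_latencies) < 400,  # a NumPy bool, not a Python bool
    "dead_features": np.int64(3),
}
print(json.dumps(to_jsonable(sample_report), indent=2))
```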

## Summary

🎉 **Congratulations!** You've successfully completed the InterpretabilityWorkbench tutorial.

### What you accomplished:

1. ✅ **Recorded activations** from a language model using forward hooks
2. ✅ **Trained a sparse autoencoder** to discover interpretable features
3. ✅ **Evaluated SAE performance** against success criteria
4. ✅ **Analyzed discovered features** to understand what tokens activate them
5. ✅ **Created live LoRA patches** to modify model behavior in real-time
6. ✅ **Measured latency** of patch toggle plus inference against the <400 ms target
7. ✅ **Exported all results** for future use

### Next steps:

- **Scale up**: Try with larger models (Llama-2-7B, GPT-2-large)
- **Explore features**: Use the web UI to interactively browse features
- **Advanced patching**: Create patches targeting specific behaviors
- **Multi-layer analysis**: Compare features across different layers
- **Production deployment**: Use exported SAEs in your applications

### Resources:

- 📚 **Documentation**: See README.md for detailed API reference
- 🌐 **Web UI**: Launch with `microscope ui` for interactive exploration
- 🔬 **Advanced features**: Check out provenance graph analysis
- 🚀 **Scale up**: Use the evaluation script for larger experiments

Happy interpretability research! 🔍✨