# SAE Analysis: Cognitive Patterns → Features → Neuronpedia

This notebook demonstrates how to analyze your pre-computed activations using Sparse Autoencoders (SAEs) and explore the results through Neuronpedia.

## Overview
1. **Load SAE**: Find and load appropriate SAE for your model/layer
2. **Process Activations**: Run your activations through the SAE
3. **Find Top Features**: Identify most active features
4. **Neuronpedia Integration**: View feature interpretations
5. **Advanced Analysis**: Steering, ablation, comparisons

In [None]:
# Install required packages (run once)
%pip install sae-lens transformer-lens plotly

In [None]:
# Imports
import torch
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from IPython.display import IFrame, display, HTML
import json
from pathlib import Path

# Import our SAE analysis module
from sae_analysis import SAEActivationAnalyzer, SAEAnalysisConfig

print("Imports successful!")
print(f"PyTorch version: {torch.__version__}")
print(f"Device available: {'MPS' if torch.backends.mps.is_available() else 'CUDA' if torch.cuda.is_available() else 'CPU'}")

## Step 1: Configuration and Setup

In [None]:
# Configure analysis
config = SAEAnalysisConfig(
    device="mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu",
    sae_release="gpt2-small-res-jb",  # Start with GPT-2 SAEs
    sae_id="blocks.7.hook_resid_pre",   # Layer 7 residual stream
    top_k_features=20,
    analysis_output_dir="sae_analysis_results"
)

# Initialize analyzer
analyzer = SAEActivationAnalyzer(config)

print(f"Configuration:")
print(f"  Device: {config.device}")
print(f"  SAE: {config.sae_release}/{config.sae_id}")
print(f"  Output directory: {config.analysis_output_dir}")

## Step 2: Discover Available SAEs

In [None]:
# Discover all available SAEs
available_saes = analyzer.discover_available_saes()

# Show available models
model_counts = available_saes['model_name'].value_counts()
print("Available SAE models:")
for model, count in model_counts.head(10).items():
    print(f"  {model}: {count} SAEs")

In [None]:
# Find SAEs for specific models
gpt2_saes = analyzer.find_matching_saes("gpt2", "resid")
display(gpt2_saes)

## Step 3: Load SAE

In [None]:
# Load the SAE
sae = analyzer.load_sae()

print(f"SAE Configuration:")
print(f"  Hook name: {sae.cfg.hook_name}")
print(f"  Input dims (d_in): {sae.cfg.d_in}")
print(f"  SAE dims (d_sae): {sae.cfg.d_sae}")
print(f"  Expansion factor: {sae.cfg.d_sae / sae.cfg.d_in:.1f}x")

## Step 4: Load Your Activations

In [None]:
# Find your activation files
activation_files = sorted(Path("activations").glob("*.pt"))  # sorted so indices are stable across runs
print(f"Found {len(activation_files)} activation files:")
for i, file_path in enumerate(activation_files):
    print(f"  {i}: {file_path.name}")

In [None]:
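# Choose which activation file to analyze
file_index = 0  # Change this to select different files
if not activation_files:
    raise FileNotFoundError("No .pt files found in the 'activations' directory")
activation_path = activation_files[file_index]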

print(f"Loading: {activation_path}")
activations = analyzer.load_activations(str(activation_path))

print(f"Activation statistics:")
print(f"  Shape: {activations.shape}")
print(f"  Device: {activations.device}")
print(f"  Dtype: {activations.dtype}")
print(f"  Min/Max: {activations.min():.4f} / {activations.max():.4f}")
print(f"  Mean/Std: {activations.mean():.4f} ± {activations.std():.4f}")

In [None]:
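# Handle dimension compatibility.
# NOTE: a mismatch usually means the SAE was trained on a different model or hook point.
# Truncating or zero-padding only makes the shapes line up; the resulting feature
# activations are unlikely to be meaningful, so prefer selecting a matching SAE.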
original_shape = activations.shape
actual_dim = activations.shape[-1]
expected_dim = sae.cfg.d_in

print(f"Dimension compatibility check:")
print(f"  Activation dims: {actual_dim}")
print(f"  SAE expects: {expected_dim}")

if actual_dim != expected_dim:
    print(f"  ⚠️  Dimension mismatch detected!")
    
    if actual_dim > expected_dim:
        print(f"  Truncating to first {expected_dim} dimensions")
        activations = activations[..., :expected_dim]
    else:
        print(f"  Padding to {expected_dim} dimensions with zeros")
        padding_shape = list(activations.shape)
        padding_shape[-1] = expected_dim - actual_dim
        padding = torch.zeros(padding_shape, device=activations.device, dtype=activations.dtype)
        activations = torch.cat([activations, padding], dim=-1)
        
    print(f"  New shape: {activations.shape}")
else:
    print(f"  ✅ Dimensions compatible!")

## Step 5: Process Through SAE

In [None]:
# Process activations through SAE
results = analyzer.process_activations(activations)

print("Processing results:")
for key, tensor in results.items():
    if isinstance(tensor, torch.Tensor):
        print(f"  {key}: {tensor.shape} - {tensor.dtype}")
    else:
        print(f"  {key}: {type(tensor)}")
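
For intuition about what the processing step computes, the sketch below runs a toy SAE in plain PyTorch. It is a minimal illustration under stated assumptions, not the `sae_analysis` implementation: the weights are random rather than trained, the sizes are tiny, and the names `toy_encode` and `toy_decode` are hypothetical. It assumes the standard ReLU encoder/linear decoder form and shows how a per-token L0 (number of active features) and a reconstruction MSE could be derived; the `l0_norm` and `reconstruction_error` entries in `results` presumably hold analogous quantities.

```python
import torch

torch.manual_seed(0)

d_in, d_sae = 8, 32  # toy sizes; GPT-2 small has a 768-dimensional residual stream

# Random weights stand in for a trained SAE (hypothetical values).
W_enc = torch.randn(d_in, d_sae) * 0.1
b_enc = torch.zeros(d_sae)
W_dec = torch.randn(d_sae, d_in) * 0.1
b_dec = torch.zeros(d_in)


def toy_encode(x: torch.Tensor) -> torch.Tensor:
    """Map activations [..., d_in] to non-negative features [..., d_sae]."""
    return torch.relu((x - b_dec) @ W_enc + b_enc)


def toy_decode(f: torch.Tensor) -> torch.Tensor:
    """Map features [..., d_sae] back to activation space [..., d_in]."""
    return f @ W_dec + b_dec


x = torch.randn(2, 5, d_in)           # [batch, seq, d_in]
features = toy_encode(x)              # [batch, seq, d_sae]
recon = toy_decode(features)          # [batch, seq, d_in]

l0 = (features > 0).sum(dim=-1).float()   # active features per token
mse = ((x - recon) ** 2).mean(dim=-1)     # reconstruction error per token

print("features:", tuple(features.shape))
print("avg L0:", l0.mean().item())
print("avg MSE:", mse.mean().item())
```

With a trained SAE the average L0 is typically a small fraction of `d_sae`; the random weights here will not show that sparsity.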

## Step 6: Find Top Features

In [None]:
# Find top activating features
values, indices = analyzer.find_top_features(results['feature_activations'], 
                                           position=-1,  # Last token position
                                           top_k=config.top_k_features)

# Create a dataframe for easier viewing
top_features_df = pd.DataFrame({
    'rank': range(1, len(values) + 1),
    'feature_idx': indices.cpu().numpy(),
    'activation_value': values.cpu().numpy()
})

display(top_features_df)
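
Selecting the top features at one token position comes down to a `torch.topk` call on a single feature vector. The helper below is a hypothetical sketch (the name `top_features_at_position` and the `batch_index` argument are not part of `sae_analysis`) and assumes feature activations shaped `[batch, seq, d_sae]`.

```python
import torch


def top_features_at_position(feature_acts, position=-1, top_k=5, batch_index=0):
    """Return (values, indices) of the top_k features at one token position.

    feature_acts is assumed to have shape [batch, seq, d_sae].
    """
    vec = feature_acts[batch_index, position]   # [d_sae]
    k = min(top_k, vec.numel())                 # never ask for more than d_sae
    return torch.topk(vec, k)                   # sorted in descending order


# Tiny hand-made example: 1 sequence, 3 tokens, 6 features.
acts = torch.zeros(1, 3, 6)
acts[0, -1] = torch.tensor([0.0, 2.5, 0.0, 0.7, 4.1, 0.0])

values, indices = top_features_at_position(acts, top_k=3)
print(indices.tolist())  # [4, 1, 3]
```

Changing `position` lets you inspect other tokens; aggregating over positions (for example a mean or max over the sequence axis) before `topk` is an alternative when no single token is of special interest.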

## Step 7: Visualize Feature Activations

In [None]:
# Create comprehensive visualization
fig = analyzer.visualize_feature_activations(
    results['feature_activations'], 
    position=-1,
    title=f"Feature Analysis: {activation_path.name}"
)

fig.show()

In [None]:
# Additional visualization: Top features bar chart
fig_bar = go.Figure(data=[
    go.Bar(
        x=[f"F{idx}" for idx in indices[:10].cpu()],
        y=values[:10].cpu(),
        text=[f"{val:.3f}" for val in values[:10].cpu()],
        textposition='outside'
    )
])

fig_bar.update_layout(
    title="Top 10 Feature Activations",
    xaxis_title="Feature Index",
    yaxis_title="Activation Value",
    showlegend=False
)

fig_bar.show()

## Step 8: Neuronpedia Integration

In [None]:
# Generate Neuronpedia URLs for top features
print("Neuronpedia Dashboard URLs:")
print("="*50)

for i, (val, idx) in enumerate(zip(values[:10], indices[:10])):
    url = analyzer.get_neuronpedia_dashboard_url(int(idx))
    print(f"{i+1:2d}. Feature {int(idx):4d} (activation: {float(val):6.4f})")
    print(f"    {url}")
    print()
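
If you want to build these links without the analyzer, the sketch below shows one way to do it. The URL layout (`neuronpedia.org/<model>/<sae-set>/<feature>` plus `embed` query parameters) is an assumption based on how Neuronpedia links commonly look, and `neuronpedia_url` is a hypothetical helper, not the method used above. The identifiers `gpt2-small` and `7-res-jb` are the ones this notebook passes to `download_feature_explanations`.

```python
from urllib.parse import urlencode


def neuronpedia_url(model_id, sae_set, feature_idx, embed=False):
    """Build a feature-page URL (assumed layout; verify against the live site)."""
    base = f"https://neuronpedia.org/{model_id}/{sae_set}/{feature_idx}"
    if not embed:
        return base
    # Query parameters assumed to control the embeddable dashboard view.
    query = urlencode({
        "embed": "true",
        "embedexplanation": "true",
        "embedplots": "true",
        "embedtest": "true",
        "height": 300,
    })
    return f"{base}?{query}"


print(neuronpedia_url("gpt2-small", "7-res-jb", 123))
print(neuronpedia_url("gpt2-small", "7-res-jb", 123, embed=True))
```

The embed form is the one you would pass to `IPython.display.IFrame`; the plain form is better for links that open in a new tab.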

In [None]:
# Display interactive Neuronpedia dashboards for top 3 features
print("Interactive Neuronpedia Dashboards:")

for i in range(min(3, len(indices))):
    feature_idx = int(indices[i])
    activation_val = float(values[i])
    
    print(f"\n{'='*60}")
    print(f"Feature {feature_idx} - Activation: {activation_val:.4f}")
    print(f"{'='*60}")
    
    # Create and display iframe
    dashboard = analyzer.display_feature_dashboard(feature_idx, width=1200, height=400)
    if dashboard:
        display(dashboard)
    else:
        url = analyzer.get_neuronpedia_dashboard_url(feature_idx)
        display(HTML(f'<a href="{url}" target="_blank">Open Feature {feature_idx} in new tab</a>'))

## Step 9: Download and Search Feature Explanations

In [None]:
# Download feature explanations
try:
    explanations = analyzer.download_feature_explanations("gpt2-small", "7-res-jb")
    if explanations is not None:
        print(f"Successfully downloaded {len(explanations)} feature explanations")
        print("\nSample explanations:")
        display(explanations[['feature', 'description']].head())
    else:
        print("Could not download explanations")
except Exception as e:
    print(f"Error downloading explanations: {e}")

In [None]:
# Search for features related to cognitive patterns
if analyzer.feature_explanations is not None:
    search_terms = [
        "thought", "thinking", "cognitive", "mental", "mind", 
        "emotion", "feeling", "pattern", "behavior", "psychology",
        "negative", "positive", "anxiety", "depression", "stress"
    ]
    
    print("Searching for cognitively relevant features:")
    print("="*50)
    
    all_matches = set()  # set gives fast membership checks and removes duplicates
    for term in search_terms:
        matches = analyzer.search_features_by_description(term)
        if matches is not None and len(matches) > 0:
            print(f"\n'{term}': {len(matches)} matches")
            example = matches.iloc[0]
            # str() guards against missing (NaN) descriptions
            print(f"  Example: Feature {example['feature']} - {str(example['description'])[:100]}...")
            all_matches.update(matches['feature'].tolist())
    
    # Check if any of your top features match cognitive patterns
    top_feature_indices = [int(idx) for idx in indices[:10]]
    cognitive_matches = [f for f in top_feature_indices if f in all_matches]
    
    if cognitive_matches:
        print(f"\n🎯 Cognitive pattern matches in your top features:")
        for feature_idx in cognitive_matches:
            explanation_row = analyzer.feature_explanations[
                analyzer.feature_explanations['feature'] == feature_idx
            ]
            if not explanation_row.empty:
                desc = explanation_row.iloc[0]['description']
                rank = top_feature_indices.index(feature_idx) + 1
                print(f"  Rank {rank} - Feature {feature_idx}: {desc}")
else:
    print("No explanations available for searching")
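
A keyword search over explanations can be sketched with pandas alone. The example below is illustrative only: the DataFrame is made-up data with the same `feature` and `description` columns used above, and `search_descriptions` is a hypothetical stand-in for `search_features_by_description`, whose actual matching rules may differ.

```python
import pandas as pd

# Made-up explanations, including one missing description.
explanations = pd.DataFrame({
    "feature": [10, 11, 12, 13],
    "description": [
        "words about negative emotion",
        "Python import statements",
        None,
        "Thinking and reasoning verbs",
    ],
})


def search_descriptions(df: pd.DataFrame, term: str) -> pd.DataFrame:
    """Case-insensitive literal substring match; missing descriptions never match."""
    mask = df["description"].str.contains(term, case=False, na=False, regex=False)
    return df[mask]


print(search_descriptions(explanations, "thinking")["feature"].tolist())  # [13]
print(search_descriptions(explanations, "emotion")["feature"].tolist())   # [10]
```

Substring matching is coarse: a term such as "mind" also matches "reminder", so treat hits as candidates to inspect rather than confirmed cognitive features.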

## Step 10: Advanced Analysis

In [None]:
# Feature Ablation Analysis
print("Feature Ablation Analysis")
print("="*30)

# Ablate top 3 features
top_3_features = [int(idx) for idx in indices[:3]]
print(f"Ablating features: {top_3_features}")

ablated_reconstructions = analyzer.perform_feature_ablation(
    results['feature_activations'], 
    top_3_features
)

# Compare reconstruction quality
original_error = torch.nn.functional.mse_loss(activations, results['reconstructions'])
ablated_error = torch.nn.functional.mse_loss(activations, ablated_reconstructions)
error_increase = ablated_error - original_error

print(f"\nReconstruction Analysis:")
print(f"  Original MSE: {original_error:.6f}")
print(f"  Ablated MSE:  {ablated_error:.6f}")
print(f"  Error increase: {error_increase:.6f} ({error_increase/original_error*100:.1f}% worse)")

if error_increase > 0.001:  # Heuristic absolute threshold; tune it to the scale of your activations
    print(f"  🔍 These features appear important for reconstruction!")
else:
    print(f"  💭 These features may be less critical for reconstruction")
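
Conceptually, ablating a feature means zeroing its activation before decoding. The toy sketch below illustrates that idea with a random decoder; `ablate_and_decode` is a hypothetical helper and may not match what `perform_feature_ablation` does internally. It also checks that the change in the reconstruction equals the summed contribution of the ablated features, which is what makes ablation a direct measure of how much each feature writes into the output.

```python
import torch

torch.manual_seed(0)

d_in, d_sae = 8, 32
W_dec = torch.randn(d_sae, d_in) * 0.1   # random stand-in for trained decoder weights
b_dec = torch.zeros(d_in)


def ablate_and_decode(feature_acts, feature_ids):
    """Zero the given features (on a copy) and decode back to activation space."""
    ablated = feature_acts.clone()
    ablated[..., feature_ids] = 0.0
    return ablated @ W_dec + b_dec


feature_acts = torch.relu(torch.randn(1, 4, d_sae))   # [batch, seq, d_sae]
full_recon = feature_acts @ W_dec + b_dec

top3 = torch.topk(feature_acts[0, -1], 3).indices.tolist()
ablated_recon = ablate_and_decode(feature_acts, top3)

# The difference is exactly what the ablated features contributed.
contribution = feature_acts[..., top3] @ W_dec[top3]
print(torch.allclose(full_recon - ablated_recon, contribution, atol=1e-5))
```

Because the effect scales with activation magnitude, comparing relative error increases across features is usually more informative than a fixed absolute threshold.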

In [None]:
# Feature Steering Analysis
print("Feature Steering Analysis")
print("="*30)

# Create steering interventions
steering_configs = [
    (int(indices[0]), 2.0),   # Amplify top feature
    (int(indices[1]), -1.0),  # Suppress second feature
    (int(indices[2]), 1.5)    # Moderately amplify third feature
]

print(f"Steering interventions:")
for feature_idx, strength in steering_configs:
    direction = "Amplify" if strength > 0 else "Suppress"
    print(f"  Feature {feature_idx}: {direction} by {abs(strength):.1f}x")

# Apply steering
steered_activations = analyzer.apply_steering(activations, steering_configs)
steered_results = analyzer.process_activations(steered_activations)

# Compare before/after
print(f"\nBefore/After Steering Comparison:")
print(f"  Original avg L0:  {results['l0_norm'].mean():.2f}")
print(f"  Steered avg L0:   {steered_results['l0_norm'].mean():.2f}")
print(f"  Change in sparsity: {steered_results['l0_norm'].mean() - results['l0_norm'].mean():.2f}")

# Find top features after steering (same token position as the original ranking,
# so the before/after ranks are comparable)
steered_values, steered_indices = analyzer.find_top_features(
    steered_results['feature_activations'], position=-1, top_k=10
)

original_top10 = [int(idx) for idx in indices[:10]]
print("\nTop features after steering:")
for i, (val, idx) in enumerate(zip(steered_values[:5], steered_indices[:5])):
    feature_idx = int(idx)
    original_rank = f"#{original_top10.index(feature_idx) + 1}" if feature_idx in original_top10 else "NEW"
    print(f"  {i+1}. Feature {feature_idx} ({original_rank}): {float(val):.4f}")
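
One common way to steer with an SAE feature is to add a multiple of that feature's decoder direction to the activations. The sketch below shows this additive form on toy data. It is an assumption that `apply_steering` works this way; it could instead rescale feature activations before decoding. The `steer` helper and the random unit-norm decoder are hypothetical.

```python
import torch

torch.manual_seed(0)

d_in, d_sae = 8, 32
# Random unit-norm rows stand in for trained decoder directions.
W_dec = torch.nn.functional.normalize(torch.randn(d_sae, d_in), dim=-1)


def steer(activations, interventions):
    """Add strength * decoder_direction for each (feature_idx, strength) pair.

    Positive strengths push activations toward a feature, negative ones away.
    """
    steered = activations.clone()
    for feature_idx, strength in interventions:
        steered = steered + strength * W_dec[feature_idx]  # broadcast over batch and seq
    return steered


x = torch.randn(1, 4, d_in)                       # [batch, seq, d_in]
steered = steer(x, [(3, 2.0), (7, -1.0)])

expected_shift = 2.0 * W_dec[3] - 1.0 * W_dec[7]
print(torch.allclose(steered - x, expected_shift.expand_as(x), atol=1e-5))
```

Under this additive reading, a strength is an offset along a direction rather than a multiplier, so labels such as "2.0x" are only accurate if the analyzer scales activations instead.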

In [None]:
# Comparative visualization
fig = analyzer.compare_activations(
    results['feature_activations'][0, -1, :], 
    steered_results['feature_activations'][0, -1, :],
    labels=("Original", "Steered")
)

fig.show()

## Step 11: Save Results and Generate Report

In [None]:
# Save comprehensive results
timestamp = pd.Timestamp.now().strftime("%Y%m%d_%H%M%S")
file_stem = activation_path.stem

# Save analysis results
results_filename = f"analysis_results_{file_stem}_{timestamp}.pt"
results_path = analyzer.save_analysis_results(results, results_filename)

# Generate report
report_filename = f"analysis_report_{file_stem}_{timestamp}.md"
report_path = analyzer.generate_analysis_report(results, (values, indices), report_filename)

# Create interactive summary
summary = {
    "metadata": {
        "activation_file": str(activation_path),
        "original_shape": list(original_shape),
        "processed_shape": list(activations.shape),
        "analysis_timestamp": timestamp,
        "sae_config": {
            "release": config.sae_release,
            "sae_id": config.sae_id,
            "d_in": sae.cfg.d_in,
            "d_sae": sae.cfg.d_sae,
            "hook_name": sae.cfg.hook_name
        }
    },
    "statistics": {
        "avg_sparsity": float(results['l0_norm'].mean()),
        "avg_reconstruction_error": float(results['reconstruction_error'].mean()),
        "activation_range": {
            "min": float(activations.min()),
            "max": float(activations.max()),
            "mean": float(activations.mean()),
            "std": float(activations.std())
        }
    },
    "top_features": [
        {
            "rank": i + 1,
            "feature_idx": int(idx),
            "activation_value": float(val),
            "neuronpedia_url": analyzer.get_neuronpedia_dashboard_url(int(idx))
        }
        for i, (val, idx) in enumerate(zip(values, indices))
    ]
}

# Add explanations if available
if analyzer.feature_explanations is not None:
    for feature_data in summary["top_features"]:
        feature_idx = feature_data["feature_idx"]
        explanation_row = analyzer.feature_explanations[
            analyzer.feature_explanations['feature'] == feature_idx
        ]
        if not explanation_row.empty:
            feature_data["description"] = explanation_row.iloc[0]['description']

# Save summary
summary_path = Path(config.analysis_output_dir) / f"interactive_summary_{file_stem}_{timestamp}.json"
summary_path.parent.mkdir(parents=True, exist_ok=True)  # make sure the output directory exists
with open(summary_path, 'w') as f:
    json.dump(summary, f, indent=2)

print(f"Analysis Complete! 🎉")
print(f"\nGenerated files:")
print(f"  📊 Results: {results_path.name}")
print(f"  📝 Report: {report_path.name}")
print(f"  📋 Summary: {summary_path.name}")
print(f"\nAll files saved in: {config.analysis_output_dir}/")

## Summary and Next Steps

### What We Accomplished:
1. ✅ Loaded a pre-trained SAE matching your model architecture
2. ✅ Processed your pre-computed activations through the SAE  
3. ✅ Identified the most active features for your cognitive patterns
4. ✅ Generated Neuronpedia dashboard links for feature interpretation
5. ✅ Performed advanced analysis (ablation, steering)
6. ✅ Saved comprehensive results and generated reports

### Key Insights:
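- The SAE decomposed your activations into sparse feature activations; check the average L0 and reconstruction error in the saved summary to judge how well this SAE fits them
- The top features at the last token position are candidates for features related to cognitive patterns, subject to their Neuronpedia explanations
- The visualizations and Neuronpedia links give a starting point for interpreting each feature
- Ablation and steering showed how to test a feature's influence by removing or amplifying it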

### Next Steps:
1. **Explore Neuronpedia**: Click the dashboard URLs to understand what each feature represents
2. **Compare Patterns**: Run this analysis on different activation files to compare cognitive patterns
3. **Feature Analysis**: Investigate which features consistently activate across similar cognitive patterns
4. **Model Understanding**: Use steering/ablation to understand which features are most important
5. **Research Applications**: Use these insights to form hypotheses about how the model represents cognitive patterns

### Troubleshooting:
- If Neuronpedia dashboards don't load, try opening the URLs directly in a new browser tab
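- For dimension mismatches, select an SAE trained on the same model and hook point as your activations; truncation or zero-padding only aligns shapes and does not produce meaningful features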
- If explanations fail to download, the Neuronpedia API might be temporarily unavailable