# SAE Analysis Exploration

This notebook explores Sparse Autoencoder (SAE) analysis for the PVA-SAE project.
We'll analyze how language models internally represent code correctness using GemmaScope SAEs.

## Methodology Overview
1. Load generated dataset (correct vs incorrect solutions)
2. Extract model activations using GemmaScope SAEs  
3. Compute separation scores for latent dimensions
4. Filter out general language patterns (>2% activation on Pile dataset)
5. Identify distinguishing latent directions for code correctness

## Setup and Imports

In [None]:
import pandas as pd
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import json
from typing import Dict, List, Tuple, Optional
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Set up plotting
plt.style.use('default')
sns.set_palette("husl")
plt.rcParams['figure.figsize'] = (12, 8)

print("📚 Imports completed successfully!")
print(f"🔧 PyTorch version: {torch.__version__}")
print(f"🐼 Pandas version: {pd.__version__}")

## Load Generated Dataset

Load the dataset generated in Phase 1 (code generation and testing).

In [None]:
# Find the latest dataset file
data_dir = Path("data/datasets")
dataset_files = list(data_dir.glob("dataset_*.parquet"))

if dataset_files:
    # Get the most recent dataset
    latest_dataset = max(dataset_files, key=lambda x: x.stat().st_mtime)
    print(f"📊 Loading dataset: {latest_dataset.name}")
    
    # Load dataset
    df = pd.read_parquet(latest_dataset)
    print(f"✅ Dataset loaded successfully!")
    print(f"📏 Dataset shape: {df.shape}")
    
else:
    print("❌ No dataset files found in data/datasets/")
    print("💡 Generate a test dataset first with:")
    print("   python3 run.py phase 1 --model google/gemma-2-9b --start 0 --end 9")
    df = None

In [None]:
# Explore dataset structure
if df is not None:
    print("📋 Dataset Overview:")
    print(f"   Total records: {len(df)}")
    print(f"   Columns: {list(df.columns)}")
    print("\n📊 Correctness Distribution:")
    correctness_counts = df['is_correct'].value_counts()
    print(f"   Correct solutions: {correctness_counts.get(True, 0)}")
    print(f"   Incorrect solutions: {correctness_counts.get(False, 0)}")
    print(f"   Success rate: {df['is_correct'].mean()*100:.1f}%")
    
    # Display first few rows
    print("\n🔍 Sample Data:")
    display(df[['task_id', 'is_correct', 'passed_tests', 'total_tests', 'generation_time']].head())

## Load Model and Tokenizer

Load the same model used for generation to extract activations

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

# Model configuration (should match the one used in dataset generation)
MODEL_NAME = "google/gemma-2-9b"  # Update this if using different model
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print(f"🤖 Loading model: {MODEL_NAME}")
print(f"🔧 Device: {DEVICE}")

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print("✅ Tokenizer loaded successfully!")

# Load model (we'll load this when we need activations)
model = None  # Load on demand to save memory
print("⏳ Model will be loaded on demand for activation extraction")

## GemmaScope SAE Setup

Set up GemmaScope Sparse Autoencoders for analyzing model activations
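
Until real weights are wired in, the JumpReLU encoder can be sketched in a few lines. This is a minimal illustration, not the GemmaScope loading code: the class name `ToyJumpReLUSAE`, the tiny dimensions, and the randomly initialised weights are all assumptions made for the example. A JumpReLU encoder keeps a pre-activation only where it exceeds a per-latent threshold.

```python
import torch
import torch.nn as nn


class ToyJumpReLUSAE(nn.Module):
    """Hypothetical stand-in for a GemmaScope SAE, using random weights."""

    def __init__(self, d_model: int, d_sae: int):
        super().__init__()
        self.W_enc = nn.Parameter(torch.randn(d_model, d_sae) * 0.1)
        self.b_enc = nn.Parameter(torch.zeros(d_sae))
        self.W_dec = nn.Parameter(torch.randn(d_sae, d_model) * 0.1)
        self.b_dec = nn.Parameter(torch.zeros(d_model))
        # Per-latent threshold; learned in a real SAE, fixed here for illustration
        self.threshold = nn.Parameter(torch.full((d_sae,), 0.5))

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        pre_acts = x @ self.W_enc + self.b_enc
        # JumpReLU: keep the pre-activation only where it exceeds the threshold
        return pre_acts * (pre_acts > self.threshold)

    def decode(self, acts: torch.Tensor) -> torch.Tensor:
        return acts @ self.W_dec + self.b_dec


torch.manual_seed(0)
toy_sae = ToyJumpReLUSAE(d_model=16, d_sae=64)
with torch.no_grad():
    toy_latents = toy_sae.encode(torch.randn(16))
print(toy_latents.shape, (toy_latents > 0).sum().item(), "active latents")
```

Because the threshold is applied inside `encode`, every non-zero latent in this toy model is already above 0.5. Any additional global cutoff would be a separate design choice layered on top.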

In [None]:
# GemmaScope SAE configuration
# Note: Update these URLs/paths based on actual GemmaScope availability

SAE_CONFIG = {
    "model_name": "gemma-2-9b",
    "layer": 41,  # Final layer for residual stream analysis
    "sae_type": "JumpReLU",  # GemmaScope uses JumpReLU architecture
    "width": 65536,  # SAE width
    "activation_threshold": 0.05,  # Minimum activation threshold
}

print("⚙️ SAE Configuration:")
for key, value in SAE_CONFIG.items():
    print(f"   {key}: {value}")

# Placeholder for SAE loading
# TODO: Replace with actual GemmaScope SAE loading code
print("\n📝 Note: SAE loading code needs to be implemented based on GemmaScope API")
print("🔗 Refer to: https://github.com/google-deepmind/gemma_scope")

sae_model = None  # Placeholder

## Extract Model Activations

Extract activations at the final token position for each generated solution
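
With a single unpadded sequence, the final token sits at index `-1`. If extraction is later batched, right-padded sequences place padding at index `-1` instead. The sketch below selects the last non-padding position from an attention mask. It assumes right padding and uses made-up tensors in place of model outputs; the helper name `last_token_states` is hypothetical.

```python
import torch


def last_token_states(hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Pick the hidden state of the last non-padding token in each row.

    Assumes right padding: hidden is (batch, seq_len, hidden_dim) and
    attention_mask is (batch, seq_len) with 1 for real tokens, 0 for padding.
    """
    last_idx = attention_mask.sum(dim=1) - 1   # (batch,) index of last real token
    batch_idx = torch.arange(hidden.shape[0])
    return hidden[batch_idx, last_idx]         # (batch, hidden_dim)


# Made-up example: 2 sequences, 4 positions, hidden size 3
hidden = torch.arange(24, dtype=torch.float32).reshape(2, 4, 3)
mask = torch.tensor([[1, 1, 1, 1],
                     [1, 1, 0, 0]])
picked = last_token_states(hidden, mask)
print(picked)
```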

In [None]:
def extract_activations(text: str, model, tokenizer, layer_idx: int = -1) -> torch.Tensor:
    """
    Extract activations from model at specified layer for final token position
    
    Args:
        text: Input text (prompt + generated code)
        model: Language model
        tokenizer: Model tokenizer
        layer_idx: Layer index (-1 for final layer)
        
    Returns:
        torch.Tensor: Activations at final token position
    """
    # Tokenize input
    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    
    # Extract activations with no gradient computation
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
        
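        # Get activations from the specified layer at the final token position.
        # Note: hidden_states[0] is the embedding output, so hidden_states[i] is the
        # output of decoder layer i - 1. In Hugging Face models the last entry is
        # typically taken after the final norm, so it may differ from the raw
        # residual stream that a layer-specific SAE expects.
        hidden_states = outputs.hidden_states[layer_idx]  # Shape: (batch, seq_len, hidden_dim)
        final_token_activations = hidden_states[0, -1, :]  # Shape: (hidden_dim,)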
        
    return final_token_activations.cpu()

# Test activation extraction (when model is loaded)
if model is not None and df is not None and len(df) > 0:
    sample_text = df.iloc[0]['prompt'] + df.iloc[0]['generated_code']
    print(f"🧪 Testing activation extraction on sample text...")
    activations = extract_activations(sample_text, model, tokenizer)
    print(f"✅ Extracted activations shape: {activations.shape}")
else:
    print("⏳ Activation extraction will be tested when model and data are loaded")

## SAE Activation Analysis

Apply SAEs to extract sparse representations and analyze patterns

In [None]:
def apply_sae(activations: torch.Tensor, sae_model) -> torch.Tensor:
    """
    Apply SAE to extract sparse representation
    
    Args:
        activations: Raw model activations
        sae_model: Trained SAE model
        
    Returns:
        torch.Tensor: Sparse SAE activations
    """
    # TODO: Implement based on GemmaScope SAE API
    # This is a placeholder implementation
    
    with torch.no_grad():
        # Apply SAE encoder
        sparse_activations = sae_model.encode(activations)
        
        # Apply activation threshold
        threshold = SAE_CONFIG["activation_threshold"]
        sparse_activations = torch.where(
            sparse_activations > threshold, 
            sparse_activations, 
            torch.zeros_like(sparse_activations)
        )
        
    return sparse_activations

# Placeholder for SAE activation extraction
print("📝 SAE activation extraction function defined")
print("🔧 Implementation pending GemmaScope SAE integration")

## Separation Score Calculation

Calculate separation scores to identify latent dimensions that distinguish correct vs incorrect solutions
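
The score is the absolute difference between the group means divided by a pooled standard deviation (the absolute value of Cohen's d). As a worked example for a single latent, the sketch below uses made-up activation values; none of these numbers are measured results.

```python
from math import sqrt
from statistics import mean, variance

correct = [0.9, 1.1, 1.0, 1.2]   # made-up activations on correct solutions
incorrect = [0.2, 0.4, 0.3]      # made-up activations on incorrect solutions

n_c, n_i = len(correct), len(incorrect)
# statistics.variance is the sample variance (n - 1 denominator),
# which matches the default behaviour of torch.std
pooled_var = ((n_c - 1) * variance(correct) + (n_i - 1) * variance(incorrect)) / (n_c + n_i - 2)
cohens_d = abs(mean(correct) - mean(incorrect)) / sqrt(pooled_var)
print(f"|d| = {cohens_d:.2f}")
```

Each group needs at least two samples here, since a sample variance is undefined for a single observation.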

In [None]:
def calculate_separation_scores(correct_activations: List[torch.Tensor], 
                              incorrect_activations: List[torch.Tensor]) -> torch.Tensor:
    """
    Calculate separation scores for each SAE dimension
    
    Args:
        correct_activations: List of activations for correct solutions
        incorrect_activations: List of activations for incorrect solutions
        
    Returns:
        torch.Tensor: Separation scores for each dimension
    """
    n_correct = len(correct_activations)
    n_incorrect = len(incorrect_activations)
    if n_correct < 2 or n_incorrect < 2:
        # A sample standard deviation is undefined for a single observation (torch returns NaN)
        raise ValueError("Need at least 2 correct and 2 incorrect samples to compute separation scores")
    
    # Stack activations
    correct_stack = torch.stack(correct_activations)  # Shape: (n_correct, sae_width)
    incorrect_stack = torch.stack(incorrect_activations)  # Shape: (n_incorrect, sae_width)
    
    # Calculate means for each dimension
    correct_means = correct_stack.mean(dim=0)  # Shape: (sae_width,)
    incorrect_means = incorrect_stack.mean(dim=0)  # Shape: (sae_width,)
    
    # Calculate sample standard deviations
    correct_stds = correct_stack.std(dim=0) + 1e-8  # Add small epsilon for numerical stability
    incorrect_stds = incorrect_stack.std(dim=0) + 1e-8
    
    # Calculate pooled standard deviation
    pooled_std = torch.sqrt(((n_correct - 1) * correct_stds**2 + 
                           (n_incorrect - 1) * incorrect_stds**2) / 
                          (n_correct + n_incorrect - 2))
    
    # Calculate separation score (absolute Cohen's d, so the sign of the difference is discarded)
    separation_scores = torch.abs(correct_means - incorrect_means) / pooled_std
    
    return separation_scores

def analyze_top_separating_dimensions(separation_scores: torch.Tensor, 
                                    top_k: int = 20) -> Dict:
    """
    Analyze top separating dimensions
    
    Args:
        separation_scores: Calculated separation scores
        top_k: Number of top dimensions to analyze
        
    Returns:
        dict: Analysis results
    """
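    # Get top separating dimensions (clamp top_k so torch.topk cannot exceed the tensor length)
    top_k = min(top_k, separation_scores.numel())
    top_scores, top_indices = torch.topk(separation_scores, top_k)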
    
    results = {
        'top_indices': top_indices.tolist(),
        'top_scores': top_scores.tolist(),
        'total_dimensions': len(separation_scores),
        'mean_separation': separation_scores.mean().item(),
        'std_separation': separation_scores.std().item(),
        'max_separation': separation_scores.max().item()
    }
    
    return results

print("🧮 Separation score calculation functions defined")
print("📊 Ready to analyze latent dimension separability")

## Pile Dataset Filtering

Filter out dimensions that activate frequently on general language (Pile dataset) to focus on code-specific patterns
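
One way to obtain per-latent activation rates is to run the SAE over a sample of general-text tokens and record, for each latent, the fraction of tokens on which it fires. The sketch below assumes that definition, which is an assumption of this example and not taken from GemmaScope documentation. It uses a small made-up activation matrix instead of real Pile data, and `activation_rates` is a hypothetical helper.

```python
import torch


def activation_rates(sae_acts: torch.Tensor) -> torch.Tensor:
    """Fraction of tokens on which each latent is active (> 0).

    sae_acts: (n_tokens, sae_width) SAE activations on general text.
    """
    return (sae_acts > 0).float().mean(dim=0)


# Made-up activations: 5 tokens, 3 latents
acts = torch.tensor([[0.0, 1.2, 0.0],
                     [0.0, 0.7, 0.0],
                     [0.3, 0.9, 0.0],
                     [0.0, 2.1, 0.0],
                     [0.0, 0.4, 0.0]])
rates = activation_rates(acts)
keep = rates <= 0.02   # same 2% threshold as in the methodology
print(rates.tolist(), keep.tolist())
```

A tensor produced this way has shape `(sae_width,)` and can be passed as the `pile_activation_rates` argument of the filtering function.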

In [None]:
def filter_general_language_patterns(separation_scores: torch.Tensor,
                                    pile_activation_rates: Optional[torch.Tensor] = None,
                                    pile_threshold: float = 0.02) -> torch.Tensor:
    """
    Filter out dimensions that activate frequently on general language
    
    Args:
        separation_scores: Calculated separation scores
        pile_activation_rates: Activation rates on Pile dataset (if available)
        pile_threshold: Threshold for filtering (default 2%)
        
    Returns:
        torch.Tensor: Filtered separation scores
    """
    if pile_activation_rates is not None:
        # Filter dimensions that activate > pile_threshold on Pile dataset
        code_specific_mask = pile_activation_rates <= pile_threshold
        filtered_scores = separation_scores * code_specific_mask.float()
        
        n_filtered = (~code_specific_mask).sum().item()
        print(f"🔍 Filtered out {n_filtered} dimensions (>{pile_threshold*100:.1f}% activation on Pile)")
        
    else:
        print("⚠️  Pile activation rates not available - using unfiltered scores")
        print("💡 Consider loading pre-computed Pile activation rates for better filtering")
        filtered_scores = separation_scores
    
    return filtered_scores

# Placeholder for Pile dataset activation rates
# TODO: Load pre-computed activation rates from GemmaScope
pile_rates = None

print("🔧 General language filtering function defined")
print("📝 Note: Pile activation rates need to be loaded from GemmaScope")

## Visualization Functions

Functions for visualizing SAE analysis results

In [None]:
def plot_separation_distribution(separation_scores: torch.Tensor, title: str = "Separation Score Distribution"):
    """
    Plot distribution of separation scores
    """
    plt.figure(figsize=(12, 6))
    
    # Plot histogram
    plt.subplot(1, 2, 1)
    plt.hist(separation_scores.numpy(), bins=50, alpha=0.7, edgecolor='black')
    plt.xlabel('Separation Score (Cohen\'s d)')
    plt.ylabel('Frequency')
    plt.title(f'{title}\nHistogram')
    plt.grid(True, alpha=0.3)
    
    # Plot non-zero scores with a log-scaled y-axis
    plt.subplot(1, 2, 2)
    non_zero_scores = separation_scores[separation_scores > 0]
    plt.hist(non_zero_scores.numpy(), bins=50, alpha=0.7, edgecolor='black')
    plt.xlabel('Separation Score (Cohen\'s d)')
    plt.ylabel('Frequency')
    plt.title('Non-zero Scores Only')
    plt.yscale('log')
    plt.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()

def plot_top_dimensions(analysis_results: Dict, top_k: int = 10):
    """
    Plot top separating dimensions
    """
    plt.figure(figsize=(12, 8))
    
    top_indices = analysis_results['top_indices'][:top_k]
    top_scores = analysis_results['top_scores'][:top_k]
    
    # Bar plot
    plt.subplot(2, 1, 1)
    bars = plt.bar(range(len(top_scores)), top_scores, alpha=0.8)
    plt.xlabel('Rank')
    plt.ylabel('Separation Score')
    plt.title(f'Top {top_k} Separating Dimensions')
    plt.xticks(range(len(top_scores)), [f'Dim {idx}' for idx in top_indices], rotation=45)
    plt.grid(True, alpha=0.3)
    
    # Add value labels on bars
    for bar, score in zip(bars, top_scores):
        plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                f'{score:.2f}', ha='center', va='bottom', fontsize=9)
    
    # Summary statistics
    plt.subplot(2, 1, 2)
    stats_text = f"""
    Analysis Summary:
    • Total SAE dimensions: {analysis_results['total_dimensions']:,}
    • Mean separation score: {analysis_results['mean_separation']:.3f}
    • Std separation score: {analysis_results['std_separation']:.3f}
    • Max separation score: {analysis_results['max_separation']:.3f}
    • Top dimension index: {analysis_results['top_indices'][0]}
    """
    plt.text(0.1, 0.5, stats_text, fontsize=12, verticalalignment='center',
             bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.8))
    plt.xlim(0, 1)
    plt.ylim(0, 1)
    plt.axis('off')
    
    plt.tight_layout()
    plt.show()

def plot_activation_comparison(correct_activations: List[torch.Tensor],
                             incorrect_activations: List[torch.Tensor],
                             dimension_idx: int):
    """
    Compare activations for a specific dimension between correct and incorrect solutions
    """
    correct_vals = [act[dimension_idx].item() for act in correct_activations]
    incorrect_vals = [act[dimension_idx].item() for act in incorrect_activations]
    
    plt.figure(figsize=(10, 6))
    
    # Box plot comparison
    plt.subplot(1, 2, 1)
    data_to_plot = [correct_vals, incorrect_vals]
    box_plot = plt.boxplot(data_to_plot, labels=['Correct', 'Incorrect'], patch_artist=True)
    box_plot['boxes'][0].set_facecolor('lightgreen')
    box_plot['boxes'][1].set_facecolor('lightcoral')
    plt.ylabel('Activation Value')
    plt.title(f'Dimension {dimension_idx} Activation Comparison')
    plt.grid(True, alpha=0.3)
    
    # Histogram comparison
    plt.subplot(1, 2, 2)
    plt.hist(correct_vals, alpha=0.6, label='Correct', color='green', bins=10)
    plt.hist(incorrect_vals, alpha=0.6, label='Incorrect', color='red', bins=10)
    plt.xlabel('Activation Value')
    plt.ylabel('Frequency')
    plt.title(f'Dimension {dimension_idx} Distribution')
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()

print("📊 Visualization functions defined")
print("🎨 Ready to create SAE analysis plots")

## Main SAE Analysis Pipeline

Run the complete SAE analysis pipeline on the generated dataset

In [None]:
def run_sae_analysis_pipeline(df: pd.DataFrame) -> Dict:
    """
    Run complete SAE analysis pipeline
    
    Args:
        df: Dataset with generated solutions
        
    Returns:
        dict: Complete analysis results
    """
    print("🚀 Starting SAE Analysis Pipeline")
    print("=" * 50)
    
    # Step 1: Prepare data
    correct_df = df[df['is_correct'] == True]
    incorrect_df = df[df['is_correct'] == False]
    
    print(f"📊 Data prepared:")
    print(f"   Correct solutions: {len(correct_df)}")
    print(f"   Incorrect solutions: {len(incorrect_df)}")
    
    if len(correct_df) < 2 or len(incorrect_df) < 2:
        # A per-group standard deviation needs at least two samples
        print("❌ Need at least 2 correct and 2 incorrect solutions for analysis")
        return {}
    
    # Step 2: Extract activations (placeholder)
    print("\n🧠 Extracting model activations...")
    print("⚠️  This step requires actual model and SAE loading")
    
    # Placeholder: In real implementation, extract activations here
    # correct_activations = [extract_sae_activations(row) for _, row in correct_df.iterrows()]
    # incorrect_activations = [extract_sae_activations(row) for _, row in incorrect_df.iterrows()]
    
    # Mock data for demonstration
    sae_width = SAE_CONFIG["width"]
    correct_activations = [torch.randn(sae_width) * 0.1 for _ in range(len(correct_df))]
    incorrect_activations = [torch.randn(sae_width) * 0.1 for _ in range(len(incorrect_df))]
    
    # Add some signal to demonstrate separation
    signal_dims = [100, 500, 1000, 2000, 5000]  # Mock signal dimensions
    for act in correct_activations:
        for dim in signal_dims:
            # torch.randn(()) is a 0-d tensor, so the in-place add matches act[dim]'s shape
            act[dim] += torch.randn(()) * 0.2 + 0.5  # Add positive signal
    
    for act in incorrect_activations:
        for dim in signal_dims:
            act[dim] += torch.randn(()) * 0.2 - 0.3  # Add negative signal
    
    print(f"✅ Mock activations generated (shape: {sae_width})")
    
    # Step 3: Calculate separation scores
    print("\n🧮 Calculating separation scores...")
    separation_scores = calculate_separation_scores(correct_activations, incorrect_activations)
    print(f"✅ Separation scores calculated")
    
    # Step 4: Filter general language patterns
    print("\n🔍 Filtering general language patterns...")
    filtered_scores = filter_general_language_patterns(separation_scores, pile_rates)
    
    # Step 5: Analyze top dimensions
    print("\n📊 Analyzing top separating dimensions...")
    analysis_results = analyze_top_separating_dimensions(filtered_scores, top_k=20)
    
    # Step 6: Create visualizations
    print("\n🎨 Creating visualizations...")
    plot_separation_distribution(filtered_scores, "Filtered Separation Scores")
    plot_top_dimensions(analysis_results, top_k=10)
    
    # Analyze a specific top dimension
    if analysis_results['top_indices']:
        top_dim = analysis_results['top_indices'][0]
        plot_activation_comparison(correct_activations, incorrect_activations, top_dim)
    
    print("\n✅ SAE Analysis Pipeline Complete!")
    print("=" * 50)
    
    return {
        'separation_scores': separation_scores,
        'filtered_scores': filtered_scores,
        'analysis_results': analysis_results,
        'correct_activations': correct_activations,
        'incorrect_activations': incorrect_activations
    }

print("🔧 SAE analysis pipeline function defined")
print("📋 Ready to run analysis when dataset is loaded")

## Run Analysis

Execute the SAE analysis on the loaded dataset

In [None]:
# Run the analysis if dataset is available
if df is not None:
    print("🚀 Running SAE Analysis Pipeline...")
    analysis_results = run_sae_analysis_pipeline(df)
    
    if analysis_results:
        print("\n🎉 Analysis completed successfully!")
        print("📊 Results available in 'analysis_results' variable")
    else:
        print("❌ Analysis failed - check data requirements")
else:
    print("❌ No dataset loaded. Please generate a dataset first:")
    print("   python3 run.py phase 1 --model google/gemma-2-9b --start 0 --end 9")

## Next Steps and TODOs

### Implementation TODOs:

1. **GemmaScope Integration**
   - Load actual GemmaScope SAE models
   - Implement SAE activation extraction
   - Load pre-computed Pile activation rates

2. **Model Loading**
   - Load the actual language model for activation extraction
   - Implement efficient activation caching
   - Handle memory management for large models

3. **Enhanced Analysis**
   - Implement statistical significance testing
   - Add more sophisticated filtering methods
   - Compute AUROC and F1 scores for validation

4. **Integration with Main Codebase**
   - Migrate working code to `phase2_sae_analysis/sae_analyzer.py`
   - Add configuration management
   - Implement proper logging and error handling
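
For the AUROC validation listed under Enhanced Analysis, a single latent's activations can be scored directly against the correctness labels. The sketch below computes AUROC by pairwise comparison: the probability that a random correct sample has a higher activation than a random incorrect one, counting ties as one half. The function name and the data are illustrative only. For large datasets a library routine such as `sklearn.metrics.roc_auc_score` would likely be the more practical choice.

```python
from typing import Sequence


def auroc(pos: Sequence[float], neg: Sequence[float]) -> float:
    """AUROC via pairwise comparison of positive and negative scores."""
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5   # ties count as half a win
    return wins / (len(pos) * len(neg))


correct_vals = [0.9, 0.7, 0.4]   # made-up latent activations, correct solutions
incorrect_vals = [0.5, 0.2]      # made-up latent activations, incorrect solutions
score = auroc(correct_vals, incorrect_vals)
print(f"AUROC = {score:.3f}")
```

A value near 0.5 means the latent carries little information about correctness. Values near 0 or 1 indicate strong separation in one direction or the other.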

### Expected Outcomes:
- Identification of SAE dimensions that distinguish code correctness
- Separation scores > 0.5 for meaningful latent directions
- Filtered dimensions specific to code patterns (not general language)
- Foundation for Phase 3 model steering experiments

In [None]:
# Summary of current state
print("📋 Current Implementation Status:")
print("✅ Dataset loading and exploration")
print("✅ Separation score calculation")
print("✅ Visualization functions")
print("✅ Analysis pipeline structure")
print("⏳ GemmaScope SAE integration (pending)")
print("⏳ Actual model activation extraction (pending)")
print("⏳ Pile dataset filtering (pending)")

print("\n🚀 Ready for GemmaScope integration and real data analysis!")