In [4]:
#!/usr/bin/env python3
"""
SAE-based representation shift analysis using SAE Lens library
for comparing Gemma and PaliGemma 2 with Google's Gemma Scope SAEs.
"""

import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
try:
    from sae_lens import SAE
except:
#     !pip install --upgrade pip setuptools wheel
    !pip install --pre sae-lens
from sae_lens import SAE
import matplotlib.pyplot as plt
from typing import Dict, Tuple, List, Optional
from dataclasses import dataclass
import seaborn as sns
torch.set_grad_enabled(False)
from utils import *

In [5]:
from huggingface_hub import login
import os
from dotenv import find_dotenv, load_dotenv
# Load environment variables from the .env file located by find_dotenv().
load_dotenv(find_dotenv())
# The Hugging Face token is read from HUGGING_FACE_API_KEY (environment or .env);
# do not paste the token into the notebook itself.
HUGGING_FACE_API_KEY = os.getenv('HUGGING_FACE_API_KEY')
if not HUGGING_FACE_API_KEY:
    # With token=None, huggingface_hub's login() prompts for the token interactively.
    print("⚠️ HUGGING_FACE_API_KEY is not set; falling back to interactive Hugging Face login.")
# `or None` normalizes an empty string to None so the interactive fallback is used.
login(token=HUGGING_FACE_API_KEY or None)

In [6]:
def _print_metric_block(title: str, metrics: Dict[str, float]) -> None:
    """Print ``title`` then one indented ``name: value`` line per metric (4 decimals)."""
    print(f"\n{title}:")
    for key, value in metrics.items():
        print(f"  {key}: {value:.4f}")


def main(
    layer: int = 12,
    width: str = "16k",
    model_size: str = "2b",
    suffix: str = "canonical",
    model1_name: str = "google/gemma-2-2b",
    model2_name: str = "google/paligemma2-3b-pt-224",
    test_texts: Optional[List[str]] = None,
) -> Optional[dict]:
    """Main demonstration of SAE-based representation shift analysis.

    Builds a ``GemmaScopeAnalyzer`` (presumably provided by ``from utils import *``
    — it is not defined in this notebook), compares two models' activations on
    ``test_texts`` through the Gemma Scope SAE, prints aggregate metrics and
    interpretations, and asks the analyzer to draw the visualization.

    The defaults reproduce the original hard-coded configuration.

    Args:
        layer: Model layer whose SAE is analyzed (a middle layer by default).
        width: SAE dictionary width identifier, e.g. ``"16k"``.
        model_size: Gemma Scope model size the SAE was trained for, e.g. ``"2b"``
            (a smaller size keeps the demo fast).
        suffix: SAE variant suffix; ``"canonical"`` by default.
        model1_name: Hugging Face id of the reference model (base Gemma 2).
        model2_name: Hugging Face id of the model compared against it.
            The default is the PaliGemma 2 3B pretrained checkpoint. Per the
            PaliGemma 2 model card, the ``-224`` suffix is the input image
            resolution (448 and 896 px checkpoints also exist). Use
            ``"google/gemma-2-2b-it"`` for an instruction-tuning comparison.
        test_texts: Texts to analyze. When ``None``, a small set of sentences
            from different domains is used.

    Returns:
        The ``results`` dict from ``analyzer.analyze_models`` on success, or
        ``None`` if any step raised. On failure, the error, its traceback and
        troubleshooting tips are printed instead of propagating.
    """
    import traceback  # only needed on the error path

    if test_texts is None:
        # Default test texts covering different domains
        test_texts = [
            "The quick brown fox jumps over the lazy dog.",
            "In machine learning, neural networks learn complex patterns from data.",
            "The economy has shown resilience despite global challenges.",
            "Climate change affects weather patterns around the world.",
            "Artificial intelligence transforms how we work and live."
        ]

    print("🚀 SAE Lens - Gemma Scope Representation Shift Analysis")
    print("=" * 60)

    # TODO: the recorded run could not load the PaliGemma model with
    # AutoModelForCausalLM ("Unrecognized configuration class ... PaliGemmaConfig").
    # The activation extractor must load PaliGemma through its own model class.
    try:
        analyzer = GemmaScopeAnalyzer(
            layer=layer,
            width=width,
            model_size=model_size,
            suffix=suffix,
        )

        print("\n🔬 Analysis Configuration:")
        print(f"   Layer: {layer}")
        print(f"   SAE Width: {width}")
        print(f"   Model Size: {model_size}")
        print(f"   SAE Suffix: {suffix}")
        print(f"   Test Texts: {len(test_texts)}")
        print()

        results = analyzer.analyze_models(model1_name, model2_name, test_texts)

        print("\n📊 ANALYSIS RESULTS:")
        print("=" * 40)

        agg = results['aggregate']
        _print_metric_block("Average SAE Metrics - Model 1", agg['avg_model1_metrics'])
        _print_metric_block("Average SAE Metrics - Model 2", agg['avg_model2_metrics'])
        _print_metric_block("Average Representation Shift Metrics", agg['avg_shift_metrics'])

        interpretations = analyzer.interpret_results(results)

        print("\n🔍 INTERPRETATIONS:")
        print("=" * 40)
        for aspect, interpretation in interpretations.items():
            print(f"{aspect.replace('_', ' ').title()}: {interpretation}")

        analyzer.visualize_results(results)

        print("\n✅ Analysis complete!")
        # NOTE(review): the model ids contain '/', so this path contains nested
        # directories (sae_analysis_google/gemma-2-2b_google/...). Confirm it
        # matches the file visualize_results actually writes.
        print(f"📈 Visualization saved as '../figs_tabs/sae_analysis_{model1_name}_{model2_name}.png'")
        print(f"📋 Analyzed {len(test_texts)} texts across layer {layer}")
        return results

    except Exception as e:
        # Best-effort reporting: don't raise, so later cells still run. Do show
        # the traceback, because the one-line message alone hides where it failed.
        print(f"❌ Error during analysis: {e}")
        traceback.print_exc()
        print("\n💡 Troubleshooting tips:")
        print("   1. Install SAE Lens: pip install sae-lens")
        print("   2. Ensure you have sufficient GPU memory")
        print("   3. Try with smaller models or fewer texts")
        print("   4. Check model names are correct and accessible")
        return None

if __name__ == "__main__":
    main()

# Installation requirements (kept as comments so the cell does not echo them as output):
#   pip install sae-lens transformers torch matplotlib seaborn numpy
#
# For CUDA support (recommended):
#   pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

🚀 SAE Lens - Gemma Scope Representation Shift Analysis
🔧 Initializing GemmaScope SAE (Layer 12, Width 16k, Size 2b, Suffix canonical)
📥 Loading Gemma Scope SAE...
   Loading from release: gemma-scope-2b-pt-res-canonical
   SAE ID: layer_12/width_16k/canonical
✅ SAE loaded successfully!
   - Dictionary size: 16384
   - Model dimension: 2304

🔬 Analysis Configuration:
   Layer: 12
   SAE Width: 16k
   Model Size: 2b
   SAE Suffix: canonical
   Test Texts: 5

🚀 Starting comparative analysis
   Model 1: google/gemma-2-2b
   Model 2: google/paligemma2-3b-pt-224
   Texts: 5 samples

📝 Processing text 1/5: 'The quick brown fox jumps over the lazy dog....'
🔍 Extracting activations from google/gemma-2-2b


Loading checkpoint shards: 100%|██████████| 3/3 [00:02<00:00,  1.09it/s]


   ✅ Extracted activations: torch.Size([1, 11, 2304])
🔍 Extracting activations from google/paligemma2-3b-pt-224
❌ Error extracting activations: Unrecognized configuration class <class 'transformers.models.paligemma.configuration_paligemma.PaliGemmaConfig'> for this kind of AutoModel: AutoModelForCausalLM.
Model type should be one of ArceeConfig, AriaTextConfig, BambaConfig, BartConfig, BertConfig, BertGenerationConfig, BigBirdConfig, BigBirdPegasusConfig, BioGptConfig, BitNetConfig, BlenderbotConfig, BlenderbotSmallConfig, BloomConfig, CamembertConfig, LlamaConfig, CodeGenConfig, CohereConfig, Cohere2Config, CpmAntConfig, CTRLConfig, Data2VecTextConfig, DbrxConfig, DeepseekV3Config, DiffLlamaConfig, Dots1Config, ElectraConfig, Emu3Config, ErnieConfig, FalconConfig, FalconH1Config, FalconMambaConfig, FuyuConfig, GemmaConfig, Gemma2Config, Gemma3Config, Gemma3TextConfig, Gemma3nConfig, Gemma3nTextConfig, GitConfig, GlmConfig, Glm4Config, GotOcr2Config, GPT2Config, GPT2Config, GPTBigCodeC

'\npip install sae-lens transformers torch matplotlib seaborn numpy\n\n# For CUDA support (recommended):\npip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118\n'