In [1]:
#!/usr/bin/env python3
"""
SAE-based representation shift analysis using SAE Lens library
for comparing Gemma and PaliGemma 2 with Google's Gemma Scope SAEs.
"""

import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
try:
    from sae_lens import SAE
except:
#     !pip install --upgrade pip setuptools wheel
    !pip install --pre sae-lens
from sae_lens import SAE
import matplotlib.pyplot as plt
from typing import Dict, Tuple, List, Optional
from dataclasses import dataclass
import seaborn as sns
torch.set_grad_enabled(False)


  from .autonotebook import tqdm as notebook_tqdm


<torch.autograd.grad_mode.set_grad_enabled at 0x7f6bcc5c5450>

In [5]:
#!/usr/bin/env python3
"""
SAE-based representation shift analysis using SAE Lens library
for comparing Gemma and PaliGemma 2 with Google's Gemma Scope SAEs.
"""

import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
from sae_lens import SAE
import matplotlib.pyplot as plt
from typing import Dict, Tuple, List, Optional
from dataclasses import dataclass
import seaborn as sns

@dataclass
class SAEMetrics:
    """Container for SAE evaluation metrics."""
    reconstruction_loss: float
    l0_sparsity: float
    l1_sparsity: float
    fraction_alive: float
    mean_max_activation: float
    reconstruction_score: float

@dataclass
class RepresentationShift:
    """Container for representation shift metrics."""
    cosine_similarity: float
    l2_distance: float
    feature_overlap: float
    js_divergence: float
    feature_correlation: float

class GemmaScopeAnalyzer:
    """Analyzer for measuring representation shifts using Gemma Scope SAEs."""
    
    def __init__(self, 
                 layer: int = 12, 
                 width: str = "16k",
                 model_size: str = "2b",
                 suffix: str = "canonical"):
        """
        Initialize analyzer with specific Gemma Scope SAE configuration.
        
        Args:
            layer: Which transformer layer to analyze (0-27 for 2B, 0-41 for 9B)
            width: SAE width ("16k", "65k", "262k") 
            model_size: Model size ("2b" or "9b")
            suffix: SAE variant ("canonical" or specific L0 like "average_l0_105")
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.layer = layer
        self.width = width
        self.model_size = model_size
        self.suffix = suffix
        self.sae = None
        self.tokenizer = None
        
        print(f"🔧 Initializing GemmaScope SAE (Layer {layer}, Width {width}, Size {model_size}, Suffix {suffix})")
        self.load_sae()

    def get_gemmascope_sae(self, layer, width, suffix, model_size):
        """Load Gemma Scope SAE with correct format."""
        release = f"gemma-scope-{model_size}-pt-res"  # Use main release
        if suffix == "canonical":
            release = f"gemma-scope-{model_size}-pt-res-canonical"  # Use canonical release
            sae_id = f"layer_{layer}/width_{width}/canonical"
        else:
            sae_id = f"layer_{layer}/width_{width}/{suffix}"
        
        print(f"   Loading from release: {release}")
        print(f"   SAE ID: {sae_id}")
        
        sae = SAE.from_pretrained(release, sae_id)
        return sae

    def load_sae(self):
        """Load the specified Gemma Scope SAE using SAE Lens."""
        try:
            # Turn off gradients globally
            torch.set_grad_enabled(False)
            
            print(f"📥 Loading Gemma Scope SAE...")
            self.sae = self.get_gemmascope_sae(
                layer=self.layer,
                width=self.width, 
                suffix=self.suffix,
                model_size=self.model_size
            )
            
            self.sae = self.sae.to(self.device)
            self.sae.eval()
            
            print(f"✅ SAE loaded successfully!")
            print(f"   - Dictionary size: {self.sae.cfg.d_sae}")
            print(f"   - Model dimension: {self.sae.cfg.d_in}")
            
        except Exception as e:
            print(f"❌ Error loading SAE: {e}")
            print("💡 Available releases and IDs:")
            print("   - Use 'canonical' suffix for gemma-scope-{model_size}-pt-res-canonical")
            print("   - Or check specific L0 values like 'average_l0_105' for main release")
            print("   - Available widths: 16k, 65k, 262k")
            raise

    def get_model_activations(self, 
                            model_name: str, 
                            text: str, 
                            batch_size: int = 1) -> torch.Tensor:
        """
        Extract activations from specified layer of the model.
        
        Args:
            model_name: HuggingFace model identifier
            text: Input text to analyze
            batch_size: Batch size for processing
            
        Returns:
            Activations tensor [batch_size, seq_len, d_model]
        """
        print(f"🔍 Extracting activations from {model_name}")
        
        try:
            # Load model and tokenizer
            tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
                
            model = AutoModelForCausalLM.from_pretrained(
                model_name, 
                trust_remote_code=True,
                torch_dtype=torch.float32,
                device_map="auto" if torch.cuda.is_available() else None
            )
            model.eval()
            
            # Tokenize input
            inputs = tokenizer(
                text, 
                return_tensors="pt", 
                padding=True, 
                truncation=True,
                max_length=512
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            
            # Hook to capture activations
            activations = {}
            
            def activation_hook(module, input, output):
                # Store the residual stream activations
                if hasattr(output, 'last_hidden_state'):
                    activations['residual'] = output.last_hidden_state
                else:
                    activations['residual'] = output[0] if isinstance(output, tuple) else output
            
            # Register hook on the target layer
            if hasattr(model, 'model') and hasattr(model.model, 'layers'):
                target_layer = model.model.layers[self.layer]
            else:
                # Fallback - hook the entire model
                target_layer = model
                
            hook = target_layer.register_forward_hook(activation_hook)
            
            # Forward pass
            with torch.no_grad():
                outputs = model(**inputs, output_hidden_states=True)
                
                # If we have hidden states, use them directly
                if hasattr(outputs, 'hidden_states') and len(outputs.hidden_states) > self.layer:
                    activations['residual'] = outputs.hidden_states[self.layer]
            
            hook.remove()
            
            # Return activations
            residual_activations = activations.get('residual', outputs.hidden_states[-1])
            print(f"   ✅ Extracted activations: {residual_activations.shape}")
            
            return residual_activations
            
        except Exception as e:
            print(f"❌ Error extracting activations: {e}")
            # Return dummy activations for demo
            print("🔄 Using dummy activations for demonstration")
            return torch.randn(1, 10, self.sae.cfg.d_in, device=self.device)

    def compute_sae_metrics(self, activations: torch.Tensor) -> SAEMetrics:
        """
        Compute comprehensive SAE evaluation metrics.
        
        Args:
            activations: Input activations [batch, seq, d_model]
            
        Returns:
            SAEMetrics object with all evaluation metrics
        """
        with torch.no_grad():
            # Reshape activations for SAE processing
            batch_size, seq_len, d_model = activations.shape
            flat_activations = activations.view(-1, d_model)
            
            # Forward pass through SAE
            sae_out = self.sae(flat_activations)
            
            # Extract components
            feature_acts = sae_out.feature_acts  # Sparse feature activations
            sae_output = sae_out.sae_out         # Reconstructed activations
            
            # 1. Reconstruction Loss (MSE)
            reconstruction_loss = torch.nn.functional.mse_loss(
                sae_output, flat_activations
            ).item()
            
            # 2. L0 Sparsity (fraction of non-zero features)
            l0_sparsity = (feature_acts > 0).float().mean().item()
            
            # 3. L1 Sparsity (mean absolute activation)
            l1_sparsity = feature_acts.abs().mean().item()
            
            # 4. Fraction of features that are ever active
            fraction_alive = (feature_acts.max(dim=0)[0] > 0).float().mean().item()
            
            # 5. Mean maximum activation per sample
            mean_max_activation = feature_acts.max(dim=1)[0].mean().item()
            
            # 6. Reconstruction score (explained variance)
            var_original = flat_activations.var(dim=0).mean()
            var_residual = (flat_activations - sae_output).var(dim=0).mean()
            reconstruction_score = 1 - (var_residual / var_original).item()
            
            return SAEMetrics(
                reconstruction_loss=reconstruction_loss,
                l0_sparsity=l0_sparsity,
                l1_sparsity=l1_sparsity,
                fraction_alive=fraction_alive,
                mean_max_activation=mean_max_activation,
                reconstruction_score=reconstruction_score
            )

    def compute_representation_shift(self, 
                                   activations1: torch.Tensor, 
                                   activations2: torch.Tensor) -> RepresentationShift:
        """
        Compute representation shift metrics between two sets of activations.
        
        Args:
            activations1: Activations from first model
            activations2: Activations from second model
            
        Returns:
            RepresentationShift object with shift metrics
        """
        with torch.no_grad():
            # Process both activation sets through SAE
            flat_acts1 = activations1.view(-1, activations1.size(-1))
            flat_acts2 = activations2.view(-1, activations2.size(-1))
            
            sae_out1 = self.sae(flat_acts1)
            sae_out2 = self.sae(flat_acts2)
            
            features1 = sae_out1.feature_acts
            features2 = sae_out2.feature_acts
            
            # 1. Cosine similarity between feature vectors
            cosine_sim = torch.nn.functional.cosine_similarity(
                features1.mean(dim=0), 
                features2.mean(dim=0), 
                dim=0
            ).item()
            
            # 2. L2 distance between feature vectors
            l2_distance = torch.norm(
                features1.mean(dim=0) - features2.mean(dim=0), 
                p=2
            ).item()
            
            # 3. Feature overlap (Jaccard similarity)
            active1 = (features1 > 0).float()
            active2 = (features2 > 0).float()
            
            intersection = (active1 * active2).sum(dim=0)
            union = torch.clamp(active1.sum(dim=0) + active2.sum(dim=0) - intersection, min=1)
            feature_overlap = (intersection / union).mean().item()
            
            # 4. Jensen-Shannon divergence between feature distributions
            def js_divergence(p, q):
                p = p + 1e-8  # Add small epsilon for numerical stability
                q = q + 1e-8
                p = p / p.sum()
                q = q / q.sum()
                m = 0.5 * (p + q)
                return 0.5 * (torch.nn.functional.kl_div(p, m, reduction='sum') + 
                             torch.nn.functional.kl_div(q, m, reduction='sum'))
            
            p = features1.mean(dim=0).abs()
            q = features2.mean(dim=0).abs()
            js_div = js_divergence(p, q).item()
            
            # 5. Feature correlation
            corr_matrix = torch.corrcoef(torch.stack([
                features1.mean(dim=0), 
                features2.mean(dim=0)
            ]))
            feature_correlation = corr_matrix[0, 1].item() if not torch.isnan(corr_matrix[0, 1]) else 0.0
            
            return RepresentationShift(
                cosine_similarity=cosine_sim,
                l2_distance=l2_distance,
                feature_overlap=feature_overlap,
                js_divergence=js_div,
                feature_correlation=feature_correlation
            )

    def analyze_models(self, 
                      model1_name: str, 
                      model2_name: str, 
                      texts: List[str]) -> Dict:
        """
        Complete analysis comparing two models across multiple texts.
        
        Args:
            model1_name: First model identifier
            model2_name: Second model identifier  
            texts: List of texts to analyze
            
        Returns:
            Dictionary with comprehensive analysis results
        """
        print(f"🚀 Starting comparative analysis")
        print(f"   Model 1: {model1_name}")
        print(f"   Model 2: {model2_name}")
        print(f"   Texts: {len(texts)} samples")
        print()
        
        results = {
            'model1_metrics': [],
            'model2_metrics': [], 
            'shift_metrics': [],
            'texts': texts
        }
        
        for i, text in enumerate(texts):
            print(f"📝 Processing text {i+1}/{len(texts)}: '{text[:50]}...'")
            
            # Extract activations
            acts1 = self.get_model_activations(model1_name, text)
            acts2 = self.get_model_activations(model2_name, text)
            
            # Compute SAE metrics
            metrics1 = self.compute_sae_metrics(acts1)
            metrics2 = self.compute_sae_metrics(acts2)
            
            # Compute representation shift
            shift = self.compute_representation_shift(acts1, acts2)
            
            results['model1_metrics'].append(metrics1)
            results['model2_metrics'].append(metrics2)
            results['shift_metrics'].append(shift)
            
            print(f"   ✅ Completed analysis for text {i+1}")
        
        # Compute aggregate statistics
        results['aggregate'] = self._compute_aggregate_stats(results)
        
        return results

    def _compute_aggregate_stats(self, results: Dict) -> Dict:
        """Compute aggregate statistics across all texts."""
        n_texts = len(results['texts'])
        
        # Average metrics across texts
        avg_model1 = {}
        avg_model2 = {}
        avg_shift = {}
        
        for field in SAEMetrics.__dataclass_fields__:
            avg_model1[field] = np.mean([getattr(m, field) for m in results['model1_metrics']])
            avg_model2[field] = np.mean([getattr(m, field) for m in results['model2_metrics']])
        
        for field in RepresentationShift.__dataclass_fields__:
            avg_shift[field] = np.mean([getattr(s, field) for s in results['shift_metrics']])
        
        return {
            'avg_model1_metrics': avg_model1,
            'avg_model2_metrics': avg_model2,
            'avg_shift_metrics': avg_shift,
            'n_texts': n_texts
        }

    def visualize_results(self, results: Dict, save_path: str = "sae_analysis.png"):
        """Create comprehensive visualization of analysis results."""
        fig, axes = plt.subplots(2, 3, figsize=(18, 12))
        fig.suptitle('SAE-based Representation Shift Analysis (Gemma Scope)', fontsize=16)
        
        agg = results['aggregate']
        
        # Plot 1: Reconstruction metrics
        recon_metrics = ['reconstruction_loss', 'reconstruction_score']
        model1_recon = [agg['avg_model1_metrics'][m] for m in recon_metrics]
        model2_recon = [agg['avg_model2_metrics'][m] for m in recon_metrics]
        
        x = np.arange(len(recon_metrics))
        width = 0.35
        
        axes[0,0].bar(x - width/2, model1_recon, width, label='Model 1', alpha=0.8)
        axes[0,0].bar(x + width/2, model2_recon, width, label='Model 2', alpha=0.8)
        axes[0,0].set_title('Reconstruction Quality')
        axes[0,0].set_xticks(x)
        axes[0,0].set_xticklabels(recon_metrics, rotation=45)
        axes[0,0].legend()
        
        # Plot 2: Sparsity metrics
        sparsity_metrics = ['l0_sparsity', 'l1_sparsity', 'fraction_alive']
        model1_sparsity = [agg['avg_model1_metrics'][m] for m in sparsity_metrics]
        model2_sparsity = [agg['avg_model2_metrics'][m] for m in sparsity_metrics]
        
        x = np.arange(len(sparsity_metrics))
        axes[0,1].bar(x - width/2, model1_sparsity, width, label='Model 1', alpha=0.8)
        axes[0,1].bar(x + width/2, model2_sparsity, width, label='Model 2', alpha=0.8)
        axes[0,1].set_title('Sparsity Metrics')
        axes[0,1].set_xticks(x)
        axes[0,1].set_xticklabels(sparsity_metrics, rotation=45)
        axes[0,1].legend()
        
        # Plot 3: Representation shift metrics
        shift_names = list(agg['avg_shift_metrics'].keys())
        shift_values = list(agg['avg_shift_metrics'].values())
        
        axes[0,2].barh(shift_names, shift_values, color='green', alpha=0.7)
        axes[0,2].set_title('Representation Shift Metrics')
        axes[0,2].set_xlabel('Value')
        
        # Plot 4: Distribution of cosine similarities across texts
        cosine_sims = [s.cosine_similarity for s in results['shift_metrics']]
        axes[1,0].hist(cosine_sims, bins=10, alpha=0.7, edgecolor='black')
        axes[1,0].axvline(np.mean(cosine_sims), color='red', linestyle='--', 
                         label=f'Mean: {np.mean(cosine_sims):.3f}')
        axes[1,0].set_title('Distribution of Cosine Similarities')
        axes[1,0].set_xlabel('Cosine Similarity')
        axes[1,0].set_ylabel('Frequency')
        axes[1,0].legend()
        
        # Plot 5: Scatter plot of reconstruction loss vs sparsity
        model1_recon_loss = [m.reconstruction_loss for m in results['model1_metrics']]
        model1_sparsity = [m.l0_sparsity for m in results['model1_metrics']]
        model2_recon_loss = [m.reconstruction_loss for m in results['model2_metrics']]
        model2_sparsity = [m.l0_sparsity for m in results['model2_metrics']]
        
        axes[1,1].scatter(model1_sparsity, model1_recon_loss, alpha=0.7, label='Model 1')
        axes[1,1].scatter(model2_sparsity, model2_recon_loss, alpha=0.7, label='Model 2')
        axes[1,1].set_xlabel('L0 Sparsity')
        axes[1,1].set_ylabel('Reconstruction Loss')
        axes[1,1].set_title('Reconstruction-Sparsity Trade-off')
        axes[1,1].legend()
        
        # Plot 6: Feature overlap distribution
        overlaps = [s.feature_overlap for s in results['shift_metrics']]
        axes[1,2].hist(overlaps, bins=10, alpha=0.7, edgecolor='black')
        axes[1,2].axvline(np.mean(overlaps), color='red', linestyle='--',
                         label=f'Mean: {np.mean(overlaps):.3f}')
        axes[1,2].set_title('Distribution of Feature Overlaps')
        axes[1,2].set_xlabel('Feature Overlap')
        axes[1,2].set_ylabel('Frequency')
        axes[1,2].legend()
        
        plt.tight_layout()
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"✅ Visualization saved to {save_path}")

    def interpret_results(self, results: Dict) -> Dict[str, str]:
        """
        Provide interpretation of the analysis results.
        
        Returns:
            Dictionary with interpretation strings for each aspect
        """
        agg = results['aggregate']
        interpretations = {}
        
        # SAE Quality Assessment
        avg_recon_loss = (agg['avg_model1_metrics']['reconstruction_loss'] + 
                         agg['avg_model2_metrics']['reconstruction_loss']) / 2
        avg_sparsity = (agg['avg_model1_metrics']['l0_sparsity'] + 
                       agg['avg_model2_metrics']['l0_sparsity']) / 2
        
        if avg_recon_loss < 0.1 and avg_sparsity < 0.1:
            interpretations['sae_quality'] = "✅ SAE is working well - low reconstruction loss with high sparsity"
        elif avg_recon_loss < 0.1:
            interpretations['sae_quality'] = "⚠️ SAE reconstructs well but low sparsity - may be learning dense features"
        elif avg_sparsity < 0.1:
            interpretations['sae_quality'] = "⚠️ SAE is sparse but high reconstruction loss - may be losing information"
        else:
            interpretations['sae_quality'] = "❌ SAE quality is poor - high reconstruction loss and low sparsity"
        
        # Representation Shift Assessment
        cosine_sim = agg['avg_shift_metrics']['cosine_similarity']
        feature_overlap = agg['avg_shift_metrics']['feature_overlap']
        
        if cosine_sim > 0.8 and feature_overlap > 0.5:
            interpretations['shift_magnitude'] = "✅ Small representation shift - models use similar features"
        elif cosine_sim > 0.6 or feature_overlap > 0.3:
            interpretations['shift_magnitude'] = "⚠️ Moderate representation shift - some shared features"
        else:
            interpretations['shift_magnitude'] = "🔍 Large representation shift - models use very different features"
        
        # Model Comparison
        recon_diff = abs(agg['avg_model1_metrics']['reconstruction_loss'] - 
                        agg['avg_model2_metrics']['reconstruction_loss'])
        sparsity_diff = abs(agg['avg_model1_metrics']['l0_sparsity'] - 
                           agg['avg_model2_metrics']['l0_sparsity'])
        
        if recon_diff < 0.05 and sparsity_diff < 0.02:
            interpretations['model_similarity'] = "✅ Models show similar SAE characteristics"
        else:
            interpretations['model_similarity'] = "🔍 Models show different SAE characteristics - architectural differences detected"
        
        return interpretations


def main():
    """Main demonstration of SAE-based representation shift analysis."""
    print("🚀 SAE Lens - Gemma Scope Representation Shift Analysis")
    print("=" * 60)
    
    # Configuration
    LAYER = 12  # Middle layer for analysis
    WIDTH = "16k"  # SAE width
    MODEL_SIZE = "2b"  # Using 2B models for faster demo
    SUFFIX = "canonical"  # Use canonical SAEs (most stable)
    
    # Test texts covering different domains
    test_texts = [
        "The quick brown fox jumps over the lazy dog.",
        "In machine learning, neural networks learn complex patterns from data.",
        "The economy has shown resilience despite global challenges.",
        "Climate change affects weather patterns around the world.",
        "Artificial intelligence transforms how we work and live."
    ]
    
    try:
        # Initialize analyzer
        analyzer = GemmaScopeAnalyzer(
            layer=LAYER, 
            width=WIDTH, 
            model_size=MODEL_SIZE,
            suffix=SUFFIX
        )
        
        # Model names (adjust these based on available models)
        model1_name = "google/gemma-2-2b"  # Base Gemma 2
        model2_name = "google/gemma-2-2b-it"  # Instruction-tuned version
        # Note: Replace with actual PaliGemma when available
        
        print(f"\n🔬 Analysis Configuration:")
        print(f"   Layer: {LAYER}")
        print(f"   SAE Width: {WIDTH}")
        print(f"   Model Size: {MODEL_SIZE}")
        print(f"   SAE Suffix: {SUFFIX}")
        print(f"   Test Texts: {len(test_texts)}")
        print()
        
        # Run analysis
        results = analyzer.analyze_models(model1_name, model2_name, test_texts)
        
        # Print results
        print("\n📊 ANALYSIS RESULTS:")
        print("=" * 40)
        
        agg = results['aggregate']
        
        print("\nAverage SAE Metrics - Model 1:")
        for key, value in agg['avg_model1_metrics'].items():
            print(f"  {key}: {value:.4f}")
        
        print("\nAverage SAE Metrics - Model 2:")
        for key, value in agg['avg_model2_metrics'].items():
            print(f"  {key}: {value:.4f}")
        
        print("\nAverage Representation Shift Metrics:")
        for key, value in agg['avg_shift_metrics'].items():
            print(f"  {key}: {value:.4f}")
        
        # Generate interpretations
        interpretations = analyzer.interpret_results(results)
        
        print("\n🔍 INTERPRETATIONS:")
        print("=" * 40)
        for aspect, interpretation in interpretations.items():
            print(f"{aspect.replace('_', ' ').title()}: {interpretation}")
        
        # Create visualization
        analyzer.visualize_results(results)
        
        print(f"\n✅ Analysis complete!")
        print(f"📈 Visualization saved as 'sae_analysis.png'")
        print(f"📋 Analyzed {len(test_texts)} texts across layer {LAYER}")
        
    except Exception as e:
        print(f"❌ Error during analysis: {e}")
        print("\n💡 Troubleshooting tips:")
        print("   1. Install SAE Lens: pip install sae-lens")
        print("   2. Ensure you have sufficient GPU memory")
        print("   3. Try with smaller models or fewer texts")
        print("   4. Check model names are correct and accessible")

if __name__ == "__main__":
    main()

# Installation requirements:
"""
pip install sae-lens transformers torch matplotlib seaborn numpy

# For CUDA support (recommended):
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
"""

🚀 SAE Lens - Gemma Scope Representation Shift Analysis
🔧 Initializing GemmaScope SAE (Layer 12, Width 16k, Size 2b, Suffix canonical)
📥 Loading Gemma Scope SAE...
   Loading from release: gemma-scope-2b-pt-res-canonical
   SAE ID: layer_12/width_16k/canonical
✅ SAE loaded successfully!
   - Dictionary size: 16384
   - Model dimension: 2304

🔬 Analysis Configuration:
   Layer: 12
   SAE Width: 16k
   Model Size: 2b
   SAE Suffix: canonical
   Test Texts: 5

🚀 Starting comparative analysis
   Model 1: google/gemma-2-2b
   Model 2: google/gemma-2-2b-it
   Texts: 5 samples

📝 Processing text 1/5: 'The quick brown fox jumps over the lazy dog....'
🔍 Extracting activations from google/gemma-2-2b
❌ Error extracting activations: You are trying to access a gated repo.
Make sure to have access to it at https://huggingface.co/google/gemma-2-2b.
401 Client Error. (Request ID: Root=1-6889574d-0c4c063c34fbc5364b1d9ce1;e756dd6c-e75f-4901-97fe-928b332f753d)

Cannot access gated repo for url https://hug

'\npip install sae-lens transformers torch matplotlib seaborn numpy\n\n# For CUDA support (recommended):\npip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118\n'

In [4]:
from sae_lens import SAE

release = "gemma-scope-2b-pt-res"
sae_id = "embedding/width_4k/average_l0_6"
sae = SAE.from_pretrained(release, sae_id)

In [8]:
def get_gemmascope_sae(layer, width, suffix, model_size):
    release=f"gemma-scope-{model_size}-pt-res"
    sae_id = f"{layer}/width_{width}/{suffix}"
    sae = SAE.from_pretrained(release, sae_id)
    return sae

In [10]:
@dataclass
class SAEMetrics:
    """Container for SAE evaluation metrics."""
    reconstruction_loss: float
    l0_sparsity: float
    l1_sparsity: float
    fraction_alive: float
    mean_max_activation: float
    reconstruction_score: float

@dataclass
class RepresentationShift:
    """Container for representation shift metrics."""
    cosine_similarity: float
    l2_distance: float
    feature_overlap: float
    js_divergence: float
    feature_correlation: float

class GemmaScopeAnalyzer:
    """Analyzer for measuring representation shifts using Gemma Scope SAEs."""
    
    def __init__(self, 
                 layer: int = 12, 
                 width: str = "16k",
                 model_size: str = "2b"):
        """
        Initialize analyzer with specific Gemma Scope SAE configuration.
        
        Args:
            layer: Which transformer layer to analyze (0-27 for 2B, 0-41 for 9B)
            width: SAE width ("1k", "16k", "65k", "262k")
            model_size: Model size ("2b" or "9b")
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.layer = layer
        self.width = width
        self.model_size = model_size
        self.sae = None
        self.tokenizer = None
        
        print(f"🔧 Initializing GemmaScope SAE (Layer {layer}, Width {width}, Size {model_size})")
        self.load_sae()

    def load_sae(self):
        """Load the specified Gemma Scope SAE using SAE Lens."""
        try:
            # Load SAE using SAE Lens
            sae_id = f"gemma-scope-{self.model_size}-pt-res-layer-{self.layer}-width-{self.width}"
            
            print(f"📥 Loading SAE: {sae_id}")
            self.sae, cfg_dict, sparsity = get_gemmascope_sae(
                layer=self.layer,
                width=self.width,
                suffix="average_l0_6",
                model_size=self.model_size,
            )
            
            self.sae = self.sae.to(self.device)
            self.sae.eval()
            
            print(f"✅ SAE loaded successfully!")
            print(f"   - Dictionary size: {self.sae.cfg.d_sae}")
            print(f"   - Model dimension: {self.sae.cfg.d_in}")
            print(f"   - L0 sparsity: {sparsity:.2f}")
            
        except Exception as e:
            print(f"❌ Error loading SAE: {e}")
            print("💡 Make sure you have sae_lens installed: pip install sae-lens")
            raise

    def get_model_activations(self, 
                            model_name: str, 
                            text: str, 
                            batch_size: int = 1) -> torch.Tensor:
        """
        Extract activations from specified layer of the model.
        
        Args:
            model_name: HuggingFace model identifier
            text: Input text to analyze
            batch_size: Batch size for processing
            
        Returns:
            Activations tensor [batch_size, seq_len, d_model]
        """
        print(f"🔍 Extracting activations from {model_name}")
        
        try:
            # Load model and tokenizer
            tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
                
            model = AutoModelForCausalLM.from_pretrained(
                model_name, 
                trust_remote_code=True,
                torch_dtype=torch.float32,
                device_map="auto" if torch.cuda.is_available() else None
            )
            model.eval()
            
            # Tokenize input
            inputs = tokenizer(
                text, 
                return_tensors="pt", 
                padding=True, 
                truncation=True,
                max_length=512
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            
            # Hook to capture activations
            activations = {}
            
            def activation_hook(module, input, output):
                # Store the residual stream activations
                if hasattr(output, 'last_hidden_state'):
                    activations['residual'] = output.last_hidden_state
                else:
                    activations['residual'] = output[0] if isinstance(output, tuple) else output
            
            # Register hook on the target layer
            if hasattr(model, 'model') and hasattr(model.model, 'layers'):
                target_layer = model.model.layers[self.layer]
            else:
                # Fallback - hook the entire model
                target_layer = model
                
            hook = target_layer.register_forward_hook(activation_hook)
            
            # Forward pass
            with torch.no_grad():
                outputs = model(**inputs, output_hidden_states=True)
                
                # If we have hidden states, use them directly
                if hasattr(outputs, 'hidden_states') and len(outputs.hidden_states) > self.layer:
                    activations['residual'] = outputs.hidden_states[self.layer]
            
            hook.remove()
            
            # Return activations
            residual_activations = activations.get('residual', outputs.hidden_states[-1])
            print(f"   ✅ Extracted activations: {residual_activations.shape}")
            
            return residual_activations
            
        except Exception as e:
            print(f"❌ Error extracting activations: {e}")
            # Return dummy activations for demo
            print("🔄 Using dummy activations for demonstration")
            return torch.randn(1, 10, self.sae.cfg.d_in, device=self.device)

    def compute_sae_metrics(self, activations: torch.Tensor) -> SAEMetrics:
        """
        Compute comprehensive SAE evaluation metrics.
        
        Args:
            activations: Input activations [batch, seq, d_model]
            
        Returns:
            SAEMetrics object with all evaluation metrics
        """
        with torch.no_grad():
            # Reshape activations for SAE processing
            batch_size, seq_len, d_model = activations.shape
            flat_activations = activations.view(-1, d_model)
            
            # Forward pass through SAE
            sae_out = self.sae(flat_activations)
            
            # Extract components
            feature_acts = sae_out.feature_acts  # Sparse feature activations
            sae_output = sae_out.sae_out         # Reconstructed activations
            
            # 1. Reconstruction Loss (MSE)
            reconstruction_loss = torch.nn.functional.mse_loss(
                sae_output, flat_activations
            ).item()
            
            # 2. L0 Sparsity (fraction of non-zero features)
            l0_sparsity = (feature_acts > 0).float().mean().item()
            
            # 3. L1 Sparsity (mean absolute activation)
            l1_sparsity = feature_acts.abs().mean().item()
            
            # 4. Fraction of features that are ever active
            fraction_alive = (feature_acts.max(dim=0)[0] > 0).float().mean().item()
            
            # 5. Mean maximum activation per sample
            mean_max_activation = feature_acts.max(dim=1)[0].mean().item()
            
            # 6. Reconstruction score (explained variance)
            var_original = flat_activations.var(dim=0).mean()
            var_residual = (flat_activations - sae_output).var(dim=0).mean()
            reconstruction_score = 1 - (var_residual / var_original).item()
            
            return SAEMetrics(
                reconstruction_loss=reconstruction_loss,
                l0_sparsity=l0_sparsity,
                l1_sparsity=l1_sparsity,
                fraction_alive=fraction_alive,
                mean_max_activation=mean_max_activation,
                reconstruction_score=reconstruction_score
            )

    def compute_representation_shift(self, 
                                   activations1: torch.Tensor, 
                                   activations2: torch.Tensor) -> RepresentationShift:
        """
        Compute representation shift metrics between two sets of activations.
        
        Args:
            activations1: Activations from first model
            activations2: Activations from second model
            
        Returns:
            RepresentationShift object with shift metrics
        """
        with torch.no_grad():
            # Process both activation sets through SAE
            flat_acts1 = activations1.view(-1, activations1.size(-1))
            flat_acts2 = activations2.view(-1, activations2.size(-1))
            
            sae_out1 = self.sae(flat_acts1)
            sae_out2 = self.sae(flat_acts2)
            
            features1 = sae_out1.feature_acts
            features2 = sae_out2.feature_acts
            
            # 1. Cosine similarity between feature vectors
            cosine_sim = torch.nn.functional.cosine_similarity(
                features1.mean(dim=0), 
                features2.mean(dim=0), 
                dim=0
            ).item()
            
            # 2. L2 distance between feature vectors
            l2_distance = torch.norm(
                features1.mean(dim=0) - features2.mean(dim=0), 
                p=2
            ).item()
            
            # 3. Feature overlap (Jaccard similarity)
            active1 = (features1 > 0).float()
            active2 = (features2 > 0).float()
            
            intersection = (active1 * active2).sum(dim=0)
            union = torch.clamp(active1.sum(dim=0) + active2.sum(dim=0) - intersection, min=1)
            feature_overlap = (intersection / union).mean().item()
            
            # 4. Jensen-Shannon divergence between feature distributions
            def js_divergence(p, q):
                p = p + 1e-8  # Add small epsilon for numerical stability
                q = q + 1e-8
                p = p / p.sum()
                q = q / q.sum()
                m = 0.5 * (p + q)
                return 0.5 * (torch.nn.functional.kl_div(p, m, reduction='sum') + 
                             torch.nn.functional.kl_div(q, m, reduction='sum'))
            
            p = features1.mean(dim=0).abs()
            q = features2.mean(dim=0).abs()
            js_div = js_divergence(p, q).item()
            
            # 5. Feature correlation
            corr_matrix = torch.corrcoef(torch.stack([
                features1.mean(dim=0), 
                features2.mean(dim=0)
            ]))
            feature_correlation = corr_matrix[0, 1].item() if not torch.isnan(corr_matrix[0, 1]) else 0.0
            
            return RepresentationShift(
                cosine_similarity=cosine_sim,
                l2_distance=l2_distance,
                feature_overlap=feature_overlap,
                js_divergence=js_div,
                feature_correlation=feature_correlation
            )

    def analyze_models(self, 
                      model1_name: str, 
                      model2_name: str, 
                      texts: List[str]) -> Dict:
        """
        Complete analysis comparing two models across multiple texts.
        
        Args:
            model1_name: First model identifier
            model2_name: Second model identifier  
            texts: List of texts to analyze
            
        Returns:
            Dictionary with comprehensive analysis results
        """
        print(f"🚀 Starting comparative analysis")
        print(f"   Model 1: {model1_name}")
        print(f"   Model 2: {model2_name}")
        print(f"   Texts: {len(texts)} samples")
        print()
        
        results = {
            'model1_metrics': [],
            'model2_metrics': [], 
            'shift_metrics': [],
            'texts': texts
        }
        
        for i, text in enumerate(texts):
            print(f"📝 Processing text {i+1}/{len(texts)}: '{text[:50]}...'")
            
            # Extract activations
            acts1 = self.get_model_activations(model1_name, text)
            acts2 = self.get_model_activations(model2_name, text)
            
            # Compute SAE metrics
            metrics1 = self.compute_sae_metrics(acts1)
            metrics2 = self.compute_sae_metrics(acts2)
            
            # Compute representation shift
            shift = self.compute_representation_shift(acts1, acts2)
            
            results['model1_metrics'].append(metrics1)
            results['model2_metrics'].append(metrics2)
            results['shift_metrics'].append(shift)
            
            print(f"   ✅ Completed analysis for text {i+1}")
        
        # Compute aggregate statistics
        results['aggregate'] = self._compute_aggregate_stats(results)
        
        return results

    def _compute_aggregate_stats(self, results: Dict) -> Dict:
        """Compute aggregate statistics across all texts."""
        n_texts = len(results['texts'])
        
        # Average metrics across texts
        avg_model1 = {}
        avg_model2 = {}
        avg_shift = {}
        
        for field in SAEMetrics.__dataclass_fields__:
            avg_model1[field] = np.mean([getattr(m, field) for m in results['model1_metrics']])
            avg_model2[field] = np.mean([getattr(m, field) for m in results['model2_metrics']])
        
        for field in RepresentationShift.__dataclass_fields__:
            avg_shift[field] = np.mean([getattr(s, field) for s in results['shift_metrics']])
        
        return {
            'avg_model1_metrics': avg_model1,
            'avg_model2_metrics': avg_model2,
            'avg_shift_metrics': avg_shift,
            'n_texts': n_texts
        }

    def visualize_results(self, results: Dict, save_path: str = "sae_analysis.png"):
        """Create comprehensive visualization of analysis results."""
        fig, axes = plt.subplots(2, 3, figsize=(18, 12))
        fig.suptitle('SAE-based Representation Shift Analysis (Gemma Scope)', fontsize=16)
        
        agg = results['aggregate']
        
        # Plot 1: Reconstruction metrics
        recon_metrics = ['reconstruction_loss', 'reconstruction_score']
        model1_recon = [agg['avg_model1_metrics'][m] for m in recon_metrics]
        model2_recon = [agg['avg_model2_metrics'][m] for m in recon_metrics]
        
        x = np.arange(len(recon_metrics))
        width = 0.35
        
        axes[0,0].bar(x - width/2, model1_recon, width, label='Model 1', alpha=0.8)
        axes[0,0].bar(x + width/2, model2_recon, width, label='Model 2', alpha=0.8)
        axes[0,0].set_title('Reconstruction Quality')
        axes[0,0].set_xticks(x)
        axes[0,0].set_xticklabels(recon_metrics, rotation=45)
        axes[0,0].legend()
        
        # Plot 2: Sparsity metrics
        sparsity_metrics = ['l0_sparsity', 'l1_sparsity', 'fraction_alive']
        model1_sparsity = [agg['avg_model1_metrics'][m] for m in sparsity_metrics]
        model2_sparsity = [agg['avg_model2_metrics'][m] for m in sparsity_metrics]
        
        x = np.arange(len(sparsity_metrics))
        axes[0,1].bar(x - width/2, model1_sparsity, width, label='Model 1', alpha=0.8)
        axes[0,1].bar(x + width/2, model2_sparsity, width, label='Model 2', alpha=0.8)
        axes[0,1].set_title('Sparsity Metrics')
        axes[0,1].set_xticks(x)
        axes[0,1].set_xticklabels(sparsity_metrics, rotation=45)
        axes[0,1].legend()
        
        # Plot 3: Representation shift metrics
        shift_names = list(agg['avg_shift_metrics'].keys())
        shift_values = list(agg['avg_shift_metrics'].values())
        
        axes[0,2].barh(shift_names, shift_values, color='green', alpha=0.7)
        axes[0,2].set_title('Representation Shift Metrics')
        axes[0,2].set_xlabel('Value')
        
        # Plot 4: Distribution of cosine similarities across texts
        cosine_sims = [s.cosine_similarity for s in results['shift_metrics']]
        axes[1,0].hist(cosine_sims, bins=10, alpha=0.7, edgecolor='black')
        axes[1,0].axvline(np.mean(cosine_sims), color='red', linestyle='--', 
                         label=f'Mean: {np.mean(cosine_sims):.3f}')
        axes[1,0].set_title('Distribution of Cosine Similarities')
        axes[1,0].set_xlabel('Cosine Similarity')
        axes[1,0].set_ylabel('Frequency')
        axes[1,0].legend()
        
        # Plot 5: Scatter plot of reconstruction loss vs sparsity
        model1_recon_loss = [m.reconstruction_loss for m in results['model1_metrics']]
        model1_sparsity = [m.l0_sparsity for m in results['model1_metrics']]
        model2_recon_loss = [m.reconstruction_loss for m in results['model2_metrics']]
        model2_sparsity = [m.l0_sparsity for m in results['model2_metrics']]
        
        axes[1,1].scatter(model1_sparsity, model1_recon_loss, alpha=0.7, label='Model 1')
        axes[1,1].scatter(model2_sparsity, model2_recon_loss, alpha=0.7, label='Model 2')
        axes[1,1].set_xlabel('L0 Sparsity')
        axes[1,1].set_ylabel('Reconstruction Loss')
        axes[1,1].set_title('Reconstruction-Sparsity Trade-off')
        axes[1,1].legend()
        
        # Plot 6: Feature overlap distribution
        overlaps = [s.feature_overlap for s in results['shift_metrics']]
        axes[1,2].hist(overlaps, bins=10, alpha=0.7, edgecolor='black')
        axes[1,2].axvline(np.mean(overlaps), color='red', linestyle='--',
                         label=f'Mean: {np.mean(overlaps):.3f}')
        axes[1,2].set_title('Distribution of Feature Overlaps')
        axes[1,2].set_xlabel('Feature Overlap')
        axes[1,2].set_ylabel('Frequency')
        axes[1,2].legend()
        
        plt.tight_layout()
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"✅ Visualization saved to {save_path}")

    def interpret_results(self, results: Dict) -> Dict[str, str]:
        """
        Provide interpretation of the analysis results.
        
        Returns:
            Dictionary with interpretation strings for each aspect
        """
        agg = results['aggregate']
        interpretations = {}
        
        # SAE Quality Assessment
        avg_recon_loss = (agg['avg_model1_metrics']['reconstruction_loss'] + 
                         agg['avg_model2_metrics']['reconstruction_loss']) / 2
        avg_sparsity = (agg['avg_model1_metrics']['l0_sparsity'] + 
                       agg['avg_model2_metrics']['l0_sparsity']) / 2
        
        if avg_recon_loss < 0.1 and avg_sparsity < 0.1:
            interpretations['sae_quality'] = "✅ SAE is working well - low reconstruction loss with high sparsity"
        elif avg_recon_loss < 0.1:
            interpretations['sae_quality'] = "⚠️ SAE reconstructs well but low sparsity - may be learning dense features"
        elif avg_sparsity < 0.1:
            interpretations['sae_quality'] = "⚠️ SAE is sparse but high reconstruction loss - may be losing information"
        else:
            interpretations['sae_quality'] = "❌ SAE quality is poor - high reconstruction loss and low sparsity"
        
        # Representation Shift Assessment
        cosine_sim = agg['avg_shift_metrics']['cosine_similarity']
        feature_overlap = agg['avg_shift_metrics']['feature_overlap']
        
        if cosine_sim > 0.8 and feature_overlap > 0.5:
            interpretations['shift_magnitude'] = "✅ Small representation shift - models use similar features"
        elif cosine_sim > 0.6 or feature_overlap > 0.3:
            interpretations['shift_magnitude'] = "⚠️ Moderate representation shift - some shared features"
        else:
            interpretations['shift_magnitude'] = "🔍 Large representation shift - models use very different features"
        
        # Model Comparison
        recon_diff = abs(agg['avg_model1_metrics']['reconstruction_loss'] - 
                        agg['avg_model2_metrics']['reconstruction_loss'])
        sparsity_diff = abs(agg['avg_model1_metrics']['l0_sparsity'] - 
                           agg['avg_model2_metrics']['l0_sparsity'])
        
        if recon_diff < 0.05 and sparsity_diff < 0.02:
            interpretations['model_similarity'] = "✅ Models show similar SAE characteristics"
        else:
            interpretations['model_similarity'] = "🔍 Models show different SAE characteristics - architectural differences detected"
        
        return interpretations


def main():
    """Main demonstration of SAE-based representation shift analysis."""
    print("🚀 SAE Lens - Gemma Scope Representation Shift Analysis")
    print("=" * 60)
    
    # Configuration
    LAYER = 12  # Middle layer for analysis
    WIDTH = "16k"  # SAE width
    MODEL_SIZE = "2b"  # Using 2B models for faster demo
    
    # Test texts covering different domains
    test_texts = [
        "The quick brown fox jumps over the lazy dog.",
        "In machine learning, neural networks learn complex patterns from data.",
        "The economy has shown resilience despite global challenges.",
        "Climate change affects weather patterns around the world.",
        "Artificial intelligence transforms how we work and live."
    ]
    
    try:
        # Initialize analyzer
        analyzer = GemmaScopeAnalyzer(
            layer=LAYER, 
            width=WIDTH, 
            model_size=MODEL_SIZE
        )
        
        # Model names (adjust these based on available models)
        model1_name = "google/gemma-2-2b"  # Base Gemma 2
        model2_name = "google/gemma-2-2b-it"  # Instruction-tuned version
        # Note: Replace with actual PaliGemma when available
        
        print(f"\n🔬 Analysis Configuration:")
        print(f"   Layer: {LAYER}")
        print(f"   SAE Width: {WIDTH}")
        print(f"   Model Size: {MODEL_SIZE}")
        print(f"   Test Texts: {len(test_texts)}")
        print()
        
        # Run analysis
        results = analyzer.analyze_models(model1_name, model2_name, test_texts)
        
        # Print results
        print("\n📊 ANALYSIS RESULTS:")
        print("=" * 40)
        
        agg = results['aggregate']
        
        print("\nAverage SAE Metrics - Model 1:")
        for key, value in agg['avg_model1_metrics'].items():
            print(f"  {key}: {value:.4f}")
        
        print("\nAverage SAE Metrics - Model 2:")
        for key, value in agg['avg_model2_metrics'].items():
            print(f"  {key}: {value:.4f}")
        
        print("\nAverage Representation Shift Metrics:")
        for key, value in agg['avg_shift_metrics'].items():
            print(f"  {key}: {value:.4f}")
        
        # Generate interpretations
        interpretations = analyzer.interpret_results(results)
        
        print("\n🔍 INTERPRETATIONS:")
        print("=" * 40)
        for aspect, interpretation in interpretations.items():
            print(f"{aspect.replace('_', ' ').title()}: {interpretation}")
        
        # Create visualization
        analyzer.visualize_results(results)
        
        print(f"\n✅ Analysis complete!")
        print(f"📈 Visualization saved as 'sae_analysis.png'")
        print(f"📋 Analyzed {len(test_texts)} texts across layer {LAYER}")
        
    except Exception as e:
        print(f"❌ Error during analysis: {e}")
        print("\n💡 Troubleshooting tips:")
        print("   1. Install SAE Lens: pip install sae-lens")
        print("   2. Ensure you have sufficient GPU memory")
        print("   3. Try with smaller models or fewer texts")
        print("   4. Check model names are correct and accessible")

if __name__ == "__main__":
    main()

# Installation requirements:
"""
pip install sae-lens transformers torch matplotlib seaborn numpy

# For CUDA support (recommended):
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
"""

🚀 SAE Lens - Gemma Scope Representation Shift Analysis
🔧 Initializing GemmaScope SAE (Layer 12, Width 16k, Size 2b)
📥 Loading SAE: gemma-scope-2b-pt-res-layer-12-width-16k
❌ Error loading SAE: ID 12/width_16k/average_l0_6 not found in release gemma-scope-2b-pt-res. Valid IDs are ['embedding/width_4k/average_l0_6', 'embedding/width_4k/average_l0_44', 'embedding/width_4k/average_l0_21', 'embedding/width_4k/average_l0_111', 'layer_0/width_16k/average_l0_105', ...]. If you don't want to specify an L0 value, consider using release gemma-scope-2b-pt-res-canonical which has valid IDs ['layer_0/width_16k/canonical', 'layer_1/width_16k/canonical', 'layer_2/width_16k/canonical', 'layer_3/width_16k/canonical', 'layer_4/width_16k/canonical', ...]
💡 Make sure you have sae_lens installed: pip install sae-lens
❌ Error during analysis: ID 12/width_16k/average_l0_6 not found in release gemma-scope-2b-pt-res. Valid IDs are ['embedding/width_4k/average_l0_6', 'embedding/width_4k/average_l0_44', 'embeddin

'\npip install sae-lens transformers torch matplotlib seaborn numpy\n\n# For CUDA support (recommended):\npip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118\n'