In [7]:
#!/usr/bin/env python3
"""
SAE-based representation shift analysis using SAE Lens library
for comparing Gemma and PaliGemma 2 with Google's Gemma Scope SAEs.
"""

import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
try:
    from sae_lens import SAE
except ImportError:
    # sae_lens is missing: install it into the interpreter that runs this kernel
    # (same effect as `%pip install`; a bare `!pip` may target another environment),
    # then retry the import.
    import importlib
    import subprocess
    import sys

    subprocess.check_call([sys.executable, "-m", "pip", "install", "--pre", "sae-lens"])
    importlib.invalidate_caches()  # make the freshly installed package visible to the importer
    from sae_lens import SAE
import matplotlib.pyplot as plt
from typing import Dict, Tuple, List, Optional
from dataclasses import dataclass
import seaborn as sns

# Disable autograd globally: the cells here only run forward passes / analysis.
torch.set_grad_enabled(False)

# NOTE(review): star import hides where names such as GemmaScopeAnalyzer come from;
# consider importing the needed names explicitly.
from utils import *

In [8]:
from huggingface_hub import login
import os
from dotenv import find_dotenv, load_dotenv
load_dotenv(find_dotenv())
# The Hugging Face token used to download the checkpoints below is read from the
# environment (or a .env file) so it never gets saved inside the notebook.
HUGGING_FACE_API_KEY = os.getenv('HUGGING_FACE_API_KEY')
# NOTE(review): if the variable is unset, token=None is passed and login() presumably
# falls back to an interactive prompt — confirm against huggingface_hub docs.
login(token=HUGGING_FACE_API_KEY)

# Prefer the second GPU when one exists (the original setup); otherwise fall back to
# the first GPU, or to the CPU on machines without CUDA.
if torch.cuda.device_count() > 1:
    device = 'cuda:1'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

In [None]:
def _print_metric_table(title: str, metrics: Dict[str, float]) -> None:
    """Print a metric mapping as a heading followed by ``  key: value`` lines.

    Args:
        title: Heading text; printed after a blank line and followed by a colon.
        metrics: Mapping of metric name to a numeric value (formatted to 4 d.p.).
    """
    print(f"\n{title}:")
    for key, value in metrics.items():
        print(f"  {key}: {value:.4f}")


def _print_adaptation_insights(agg: Dict[str, Dict[str, float]]) -> None:
    """Print threshold-based LLM->VLM adaptation verdicts from aggregate metrics.

    Args:
        agg: ``results['aggregate']``. Must contain ``avg_shift_metrics`` (with
            ``cosine_similarity`` and ``feature_overlap``) and
            ``avg_model1_metrics`` / ``avg_model2_metrics`` (with
            ``reconstruction_loss``).
    """
    shift = agg['avg_shift_metrics']
    cosine_sim = shift['cosine_similarity']
    feature_overlap = shift['feature_overlap']

    print("\n🧠 LLM->VLM Adaptation Insights:")
    print("-" * 40)

    # The cut-offs below look like hand-picked heuristics; no calibration is shown here.
    if cosine_sim < 0.7:
        print("🔍 SIGNIFICANT adaptation detected - multimodal training substantially changed representations")
    elif cosine_sim < 0.85:
        print("⚠️  MODERATE adaptation - some representational drift during VLM training")
    else:
        print("✅ MINIMAL adaptation - Gemma decoder largely preserved during VLM training")

    if feature_overlap < 0.4:
        print("🔄 FEATURE SPECIALIZATION - VLM uses different feature combinations than LLM")
    elif feature_overlap < 0.6:
        print("📊 PARTIAL REUSE - VLM partially reuses LLM features with modifications")
    else:
        print("🔗 FEATURE PRESERVATION - VLM largely reuses LLM feature representations")

    recon_diff = abs(agg['avg_model1_metrics']['reconstruction_loss']
                     - agg['avg_model2_metrics']['reconstruction_loss'])
    if recon_diff > 0.1:
        print("⚡ EFFICIENCY CHANGE - Multimodal training affected information encoding efficiency")


def main():
    """Run the Gemma-2-2B vs. PaliGemma-2 SAE representation-shift demo.

    Builds a ``GemmaScopeAnalyzer`` (not defined in this notebook; presumably
    provided by ``from utils import *``) on the notebook-level ``device``. It runs
    ``analyze_models`` on a fixed set of test texts and prints the aggregate SAE
    and shift metrics, the analyzer's interpretations, and threshold-based
    adaptation insights. Finally it calls ``visualize_results``.

    Returns:
        The dict returned by ``analyzer.analyze_models`` if that call completed,
        even when a later step (printing, interpretation, visualization) failed.
        Returns ``None`` if the analysis itself never finished. Errors are not
        re-raised: the message, traceback and troubleshooting tips are printed
        instead.
    """
    import traceback  # only needed for error reporting below

    print("🚀 SAE Lens - Gemma Scope Representation Shift Analysis")
    print("=" * 60)

    # Configuration
    LAYER = 12  # Middle layer for analysis
    WIDTH = "16k"  # SAE width
    MODEL_SIZE = "2b"  # Using 2B models for faster demo
    SUFFIX = "canonical"  # Use canonical SAEs (most stable)

    # Test texts covering different domains
    test_texts = [
        "The quick brown fox jumps over the lazy dog.",
        "In machine learning, neural networks learn complex patterns from data.",
        "The economy has shown resilience despite global challenges.",
        "Climate change affects weather patterns around the world.",
        "Artificial intelligence transforms how we work and live."
    ]

    # Stays None unless analyze_models() completes; returned either way so the
    # (expensive) results survive a failure in a later reporting step.
    results = None
    try:
        analyzer = GemmaScopeAnalyzer(
            layer=LAYER,
            width=WIDTH,
            model_size=MODEL_SIZE,
            suffix=SUFFIX,
            device=device
        )

        # Model names for LLM->VLM adaptation study
        model1_name = "google/gemma-2-2b"  # Base Gemma-2-2B (LLM)
        model2_name = "google/paligemma2-3b-pt-224"  # PaliGemma using Gemma-2-2B decoder (VLM)

        print("\n🔬 Research Question: Representational shift during LLM->VLM adaptation")
        print("   Comparing the SAME Gemma-2-2B architecture before and after multimodal training")

        print("\n🔬 Analysis Configuration:")
        print(f"   Layer: {LAYER} (Gemma decoder layer)")
        print(f"   SAE Width: {WIDTH}")
        print(f"   Model Size: {MODEL_SIZE}")
        print(f"   SAE Suffix: {SUFFIX}")
        print(f"   Test Texts: {len(test_texts)}")
        print("   Research Focus: LLM->VLM representational adaptation")
        print()

        results = analyzer.analyze_models(model1_name, model2_name, test_texts)

        print("\n📊 ANALYSIS RESULTS:")
        print("=" * 40)

        agg = results['aggregate']
        _print_metric_table("Average SAE Metrics - Model 1", agg['avg_model1_metrics'])
        _print_metric_table("Average SAE Metrics - Model 2", agg['avg_model2_metrics'])
        _print_metric_table("Average Representation Shift Metrics", agg['avg_shift_metrics'])

        interpretations = analyzer.interpret_results(results)

        print("\n🔍 INTERPRETATIONS - LLM->VLM Adaptation:")
        print("=" * 50)
        for aspect, interpretation in interpretations.items():
            print(f"{aspect.replace('_', ' ').title()}: {interpretation}")

        _print_adaptation_insights(agg)

        analyzer.visualize_results(results)

        print("\n✅ Analysis complete!")
        print("📈 Visualization saved as 'sae_analysis.png'")
        print(f"📋 Analyzed {len(test_texts)} texts across layer {LAYER}")

    except Exception as e:
        print(f"❌ Error during analysis: {e}")
        # Show where it failed; the one-line message alone is rarely enough to debug.
        traceback.print_exc()
        print("\n💡 Troubleshooting tips:")
        print("   1. Install SAE Lens: pip install sae-lens")
        print("   2. Ensure you have sufficient GPU memory")
        print("   3. Try with smaller models or fewer texts")
        print("   4. Check model names are correct and accessible")

    return results

if __name__ == "__main__":
    # Keep the returned value at notebook scope so later cells can inspect it.
    results = main()

# Installation requirements (kept as comments so the cell does not echo them as output):
#   pip install sae-lens transformers torch matplotlib seaborn numpy huggingface_hub python-dotenv
#   (the first cell installs a pre-release via `pip install --pre sae-lens` if sae_lens is missing)
#
# For CUDA support (recommended):
#   pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

🚀 SAE Lens - Gemma Scope Representation Shift Analysis
🔧 Initializing GemmaScope SAE (Layer 12, Width 16k, Size 2b, Suffix canonical)
📥 Loading Gemma Scope SAE...
   Loading from release: gemma-scope-2b-pt-res-canonical
   SAE ID: layer_12/width_16k/canonical


In [10]:
# Inspect the analysis output.
# NOTE(review): this needs a notebook-level `results`; if the analysis cell only keeps
# it as a local inside main(), this cell raises NameError.
results

NameError: name 'results' is not defined

In [None]:
# Deliberate halt: stop "Run All" here so the exploratory cells below are not executed.
# An explicit raise replaces the previous undefined-name trick, which would silently
# stop working if any import (e.g. the star import from utils) ever defined that name.
raise RuntimeError("Intentional stop: the cells below are exploratory; run them manually.")

In [5]:
from transformers import PaliGemmaForConditionalGeneration
# Load the full PaliGemma 2 checkpoint so its Gemma-2 text decoder can be inspected.
# PaliGemmaForConditionalGeneration is a built-in transformers class, so
# trust_remote_code is not needed; enabling it would allow code shipped with the
# checkpoint repo to run.
model = PaliGemmaForConditionalGeneration.from_pretrained(
    'google/paligemma2-3b-pt-224',
    torch_dtype=torch.float32,
    # Put the whole model on the device configured at the top of the notebook rather
    # than letting "auto" spread it across every visible GPU.
    device_map=device if torch.cuda.is_available() else None,
)

Fetching 2 files: 100%|██████████| 2/2 [00:58<00:00, 29.01s/it]
Loading checkpoint shards: 100%|██████████| 2/2 [00:05<00:00,  2.56s/it]


In [6]:
# Grab PaliGemma's text decoder sub-module and report which class it is.
language_model = model.language_model
print("   ✅ Extracted Gemma decoder: {}".format(type(language_model)))

   ✅ Extracted Gemma decoder: <class 'transformers.models.gemma2.modeling_gemma2.Gemma2Model'>


In [7]:
# Display the decoder's module tree: token embedding, decoder layers, and their norms.
language_model  # 26 layers (Gemma2DecoderLayer 0-25 in the repr)

Gemma2Model(
  (embed_tokens): Embedding(257216, 2304, padding_idx=0)
  (layers): ModuleList(
    (0-25): 26 x Gemma2DecoderLayer(
      (self_attn): Gemma2Attention(
        (q_proj): Linear(in_features=2304, out_features=2048, bias=False)
        (k_proj): Linear(in_features=2304, out_features=1024, bias=False)
        (v_proj): Linear(in_features=2304, out_features=1024, bias=False)
        (o_proj): Linear(in_features=2048, out_features=2304, bias=False)
      )
      (mlp): Gemma2MLP(
        (gate_proj): Linear(in_features=2304, out_features=9216, bias=False)
        (up_proj): Linear(in_features=2304, out_features=9216, bias=False)
        (down_proj): Linear(in_features=9216, out_features=2304, bias=False)
        (act_fn): PytorchGELUTanh()
      )
      (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
      (post_attention_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
      (pre_feedforward_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
      (post_feedforward_layernorm): 