In [1]:
#!/usr/bin/env python3
"""
Comprehensive SAE-based representation shift analysis with layer sweeping,
real datasets, and patching logic for LLM->VLM adaptation studies.
"""
# Installation requirements:
"""
pip install sae-lens transformers torch matplotlib seaborn numpy datasets tqdm

# For CUDA support (recommended):
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
"""

import torch
import numpy as np
import os
import gc
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM, PaliGemmaForConditionalGeneration
from sae_lens import SAE
import matplotlib.pyplot as plt
from typing import Dict, Tuple, List, Optional, Union
from dataclasses import dataclass
import seaborn as sns
from datasets import load_dataset
import json
from tqdm import tqdm
import warnings
import random
import json
warnings.filterwarnings("ignore")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Test to see if both models have the output layer number in the hidden states
tokenizer = AutoTokenizer.from_pretrained('google/paligemma2-3b-pt-224', trust_remote_code=True)
# Ayda: Gemma models add a bos token but Paligemma models don‚Äôt, so it messes up comparison.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    print(f"tokenizer.pad_token:{tokenizer.pad_token}")

model1 = AutoModelForCausalLM.from_pretrained(
        "google/gemma-2-2b", 
        trust_remote_code=True,
        torch_dtype=torch.float16,
        device_map=None
    )

language_model1 = model1
language_model1.eval()
    
model2 = PaliGemmaForConditionalGeneration.from_pretrained(
    'google/paligemma2-3b-pt-224', 
    trust_remote_code=True,
    torch_dtype=torch.float16,  
    device_map=None  
)

language_model2 = model2.language_model
language_model2.eval()
inputs = tokenizer(
            "cat", 
            return_tensors="pt", 
            padding="max_length",
            truncation=True,
            max_length=64,
            add_special_tokens=True  # Ensure special tokens are added properly
        )

#vocab_size = tokenizer.vocab_size
input_ids = inputs['input_ids']

outputs1 = model1(input_ids, output_hidden_states=True)
outputs2 = model2(input_ids, output_hidden_states=True)
print(len(outputs1.hidden_states), len(outputs2.hidden_states))

Loading checkpoint shards: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3/3 [00:00<00:00,  4.17it/s]
Loading checkpoint shards: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2/2 [00:01<00:00,  1.80it/s]


27 27


In [3]:
"""
Comprehensive SAE-based representation shift analysis with layer sweeping,
real datasets, and patching logic for LLM->VLM adaptation studies.
FIXED: Handles PaliGemma loss computation correctly.
"""

import torch
import numpy as np
import os
import gc
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM
from sae_lens import SAE
import matplotlib.pyplot as plt
from typing import Dict, Tuple, List, Optional, Union
from dataclasses import dataclass
import seaborn as sns
from datasets import load_dataset
import json
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")
import torch.nn.functional as F

# Disable gradients globally for memory efficiency
torch.set_grad_enabled(False)

@dataclass
class SAEMetrics:
    """Container for SAE evaluation metrics."""
    reconstruction_loss: float
    l0_sparsity: float
    l1_sparsity: float
    fraction_alive: float
    mean_max_activation: float
    reconstruction_score: float
    model_delta_loss: float 
    rec_loss_topk: float

@dataclass
class RepresentationShift:
    """Container for representation shift metrics."""
    cosine_similarity: float
    l2_distance: float
    feature_overlap: float
    js_divergence: float
    feature_correlation: float

class DatasetLoader:
    """Handles loading and preprocessing of various datasets."""
    
    def __init__(self, device: str = "cuda"):
        self.device = device
    
    def load_cifar100_captions(self, split: str = "train", max_samples: int = 100) -> List[str]:
        """Load CIFAR-100 with generated captions for multimodal analysis."""
        try:
            # CIFAR-100 doesn't have captions by default, so we create descriptive ones
            dataset = load_dataset("cifar100", split=split)
            
            # CIFAR-100 class names
            class_names = [
                'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle',
                'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel',
                'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock',
                'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
                'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster',
                'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion',
                'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse',
                'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear',
                'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine',
                'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose',
                'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake',
                'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table',
                'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout',
                'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm'
            ]
            
            texts = []
            for i, sample in enumerate(dataset):
                if i >= max_samples:
                    break
                class_name = class_names[sample['fine_label']]
                # Generate descriptive captions
                captions = [
                    f"This is a photo of a {class_name}.",
                    f"An image showing a {class_name}.",
                    f"A picture of a {class_name} in natural setting.",
                    f"Visual representation of a {class_name}."
                ]
                texts.extend(captions[:2])  # Take 2 captions per image
            
            print(f"‚úÖ Loaded {len(texts)} CIFAR-100 captions")
            return texts[:max_samples]
            
        except Exception as e:
            print(f"‚ùå Error loading CIFAR-100: {e}")
            return self._get_fallback_texts()
    
    def load_coco_captions(self, split: str = "validation", max_samples: int = 100) -> List[str]:
        """Load COCO captions dataset."""
        try:
            # Load COCO captions
            dataset = load_dataset("HuggingFaceM4/COCO", split=split)
            
            texts = []
            for i, sample in enumerate(dataset):
                if i >= max_samples:
                    break
                
                # COCO has multiple captions per image
                if 'sentences' in sample and 'raw' in sample['sentences']:
                    for sentence in sample['sentences']['raw'][:2]:  # Take first 2 captions
                        texts.append(sentence)
                elif 'caption' in sample:
                    texts.append(sample['caption'])
            
            print(f"‚úÖ Loaded {len(texts)} COCO captions")
            return texts[:max_samples]
            
        except Exception as e:
            print(f"‚ùå Error loading COCO: {e}")
            # Try alternative COCO dataset
            try:
                dataset = load_dataset("nielsr/coco-captions", split="validation")
                texts = [sample['caption'] for sample in dataset.select(range(min(max_samples, len(dataset))))]
                print(f"‚úÖ Loaded {len(texts)} COCO captions (alternative)")
                return texts
            except:
                return self._get_fallback_texts()
    
    def load_llava_bench(self, max_samples: int = 100) -> List[str]:
        """Load LLaVA-Bench questions/descriptions."""
        try:
            # LLaVA bench conversations
            dataset = load_dataset("lmms-lab/LLaVA-OneVision-Data", split="dev_mini")
            
            texts = []
            for i, sample in enumerate(dataset):
                if i >= max_samples:
                    break
                
                if 'conversations' in sample:
                    for conv in sample['conversations'][:2]:  # Take first 2 conversations
                        if 'value' in conv:
                            texts.append(conv['value'])
            
            print(f"‚úÖ Loaded {len(texts)} LLaVA-Bench texts")
            return texts[:max_samples]
            
        except Exception as e:
            print(f"‚ùå Error loading LLaVA-Bench: {e}")
            return self._get_fallback_texts()
    
    def _get_fallback_texts(self) -> List[str]:
        """Fallback texts if datasets fail to load."""
        return [
            "A photo of a red apple on a white background.",
            "The cat is sitting on a wooden chair.",
            "Mountains covered with snow in winter landscape.",
            "A blue car driving on a highway.",
            "Children playing in a park with green grass.",
            "A delicious chocolate cake on a plate.",
            "Ocean waves crashing against rocky shore.",
            "A person reading a book in a library.",
            "Colorful flowers blooming in spring garden.",
            "A dog running happily in the field.",
        ]
    
    def get_mixed_dataset(self, total_samples: int = 150) -> List[str]:
        """Get a mixed dataset from multiple sources."""
        samples_per_source = total_samples // 3
        
        texts = []
        texts.extend(self.load_cifar100_captions(max_samples=samples_per_source))
        texts.extend(self.load_coco_captions(max_samples=samples_per_source))
        texts.extend(self.load_llava_bench(max_samples=samples_per_source))
        
        # Shuffle for good measure
        import random
        random.shuffle(texts)
        
        return texts[:total_samples]

class MemoryEfficientSAEAnalyzer:
    """Memory-efficient SAE analyzer with layer sweeping and patching logic."""
    
    def __init__(self, 
                 model_size: str = "2b",
                 width: str = "16k", 
                 suffix: str = "canonical",
                 device: str = "cuda",
                 output_dir: str = "../figs_tabs"):
        """
        Initialize memory-efficient SAE analyzer.
        
        Args:
            model_size: Model size ("2b" or "9b")
            width: SAE width ("16k", "65k", "262k")
            suffix: SAE variant ("canonical" or specific L0)
            device: Device to use
            output_dir: Directory for saving outputs
        """
        self.device = device if torch.cuda.is_available() else "cpu"
        self.model_size = model_size
        self.width = width
        self.suffix = suffix
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        
        # Model cache for memory efficiency
        self.model_cache = {}
        self.sae_cache = {}
        
        print(f"üîß Initialized SAE Analyzer")
        print(f"   Device: {self.device}")
        print(f"   Model Size: {model_size}")
        print(f"   SAE Width: {width}")
        print(f"   Output Dir: {output_dir}")

    def get_gemmascope_sae(self, layer: int) -> SAE:
        """Load Gemma Scope SAE with caching for memory efficiency."""
        cache_key = f"layer_{layer}"
        
        if cache_key in self.sae_cache:
            return self.sae_cache[cache_key]
        
        release = f"gemma-scope-{self.model_size}-pt-res"
        if self.suffix == "canonical":
            release = f"gemma-scope-{self.model_size}-pt-res-canonical"
            sae_id = f"layer_{layer}/width_{self.width}/canonical"
        else:
            sae_id = f"layer_{layer}/width_{self.width}/{self.suffix}"
        
        print(f"   üì• Loading SAE Layer {layer}: {sae_id}")
        
        try:
            sae = SAE.from_pretrained(release, sae_id).to(self.device)
            sae.eval()
            
            # Cache management - keep only last 2 SAEs to save memory
            if len(self.sae_cache) >= 2:
                oldest_key = list(self.sae_cache.keys())[0]
                del self.sae_cache[oldest_key]
                gc.collect()
            
            self.sae_cache[cache_key] = sae
            return sae
            
        except Exception as e:
            print(f"‚ùå Error loading SAE layer {layer}: {e}")
            raise

    def get_model(self, model_name: str):
        """Load model with caching and proper device placement."""
        if model_name in self.model_cache:
            return self.model_cache[model_name]
        
        print(f"üì• Loading model: {model_name}")
        
        try:
#             tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
            tokenizer = AutoTokenizer.from_pretrained('google/paligemma2-3b-pt-224', trust_remote_code=True)
            # Ayda: Gemma models add a bos token but Paligemma models don‚Äôt, so it messes up comparison.
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
#             print(f"model_name: {model_name}, tokenizer.pad_token:{tokenizer.pad_token}")
            # Handle different model types
            if "paligemma" in model_name.lower():
                from transformers import PaliGemmaForConditionalGeneration
                model = PaliGemmaForConditionalGeneration.from_pretrained(
                    model_name, 
                    trust_remote_code=True,
                    torch_dtype=torch.float16,  # Use fp16 for memory efficiency
                    device_map=None  # We'll handle device placement manually
                )
                model = model.to(self.device)
                language_model = model.language_model
            else:
                model = AutoModelForCausalLM.from_pretrained(
                    model_name, 
                    trust_remote_code=True,
                    torch_dtype=torch.float16,
                    device_map=None
                )
                model = model.to(self.device)
                language_model = model
            
            language_model.eval()
#             print(model_name, language_model)
            # Cache management - keep only one model at a time
            if len(self.model_cache) >= 1:
                for cached_name in list(self.model_cache.keys()):
                    del self.model_cache[cached_name]
                gc.collect()
                torch.cuda.empty_cache() if torch.cuda.is_available() else None
            
            self.model_cache[model_name] = (tokenizer, model, language_model)
            return tokenizer, model, language_model
            
        except Exception as e:
            print(f"‚ùå Error loading model {model_name}: {e}")
            raise

    def extract_activations_with_patching(self, 
                                      model_name: str, 
                                      text: str, 
                                      layer: int,
                                      sae: Optional[SAE] = None) -> Tuple[torch.Tensor, float]:
        """
        Extract activations ~and compute model delta loss with patching~.

        NOTE: Patching is DISABLED to save time. We keep the original code but comment it out.
        Always returns model_delta_loss = 0.0.
        """

        tokenizer, model, language_model = self.get_model(model_name)

        # FIXED: More robust tokenization with proper padding token handling
        # Ensure we have a pad token
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            tokenizer.pad_token_id = tokenizer.eos_token_id

        # Tokenize with safer parameters
        inputs = tokenizer(
            text, 
            return_tensors="pt", 
            padding="max_length",
            truncation=True,
            max_length=64,
            add_special_tokens=True  # Ensure special tokens are added properly
        )

        # FIXED: Validate token IDs are within vocabulary range
        vocab_size = tokenizer.vocab_size
        input_ids = inputs['input_ids']

        # Check for out-of-bounds token IDs
        if torch.any(input_ids >= vocab_size) or torch.any(input_ids < 0):
            print(f"‚ö†Ô∏è  Invalid token IDs detected. Max ID: {input_ids.max()}, Vocab size: {vocab_size}")
            # Clamp invalid IDs to valid range
            input_ids = torch.clamp(input_ids, 0, vocab_size - 1)
            inputs['input_ids'] = input_ids

        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # FIXED: More robust label creation
        def create_labels(input_ids, pad_token_id):
            """Create labels with proper masking for loss computation"""
            labels = input_ids.clone()
            # Mask padding tokens
            labels[labels == pad_token_id] = -100
            # FIXED: Also mask the first token (often BOS) to avoid issues
            if labels.size(1) > 1:
                labels[:, 0] = -100
            return labels

        # Get unpatched model loss (baseline) ‚Äî kept for potential logging/consistency
        unpatched_loss = 0.0
#         with torch.no_grad():
#             try:
#                 if "paligemma" in model_name.lower():
#                     # For PaliGemma, we need to handle text-only input differently
#                     labels = create_labels(inputs['input_ids'], tokenizer.pad_token_id)

#                     # Get outputs from language model
#                     unpatched_outputs = language_model(**inputs)

#                     # Check if we have logits to compute loss
#                     if hasattr(unpatched_outputs, 'logits'):
#                         logits = unpatched_outputs.logits

#                         # FIXED: More robust loss computation with better shape handling
#                         if logits.size(1) > 1 and labels.size(1) > 1:
#                             shift_logits = logits[..., :-1, :].contiguous()
#                             shift_labels = labels[..., 1:].contiguous()

#                             # Ensure we have valid data for loss computation
#                             valid_mask = shift_labels != -100
#                             if valid_mask.any():
#                                 shift_logits = shift_logits.view(-1, shift_logits.size(-1))
#                                 shift_labels = shift_labels.view(-1)

#                                 # FIXED: Use reduction='mean' and handle empty tensors
#                                 loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction='mean')
#                                 unpatched_loss = loss_fct(shift_logits, shift_labels).item()
#                             else:
#                                 print("‚ö†Ô∏è  No valid tokens for loss computation")
#                                 unpatched_loss = 0.0
#                         else:
#                             print("‚ö†Ô∏è  Insufficient sequence length for loss computation")
#                             unpatched_loss = 0.0
#                     else:
#                         # Fallback for models without logits
#                         unpatched_loss = 0.0
#                         print(f"‚ö†Ô∏è  No logits available for {model_name}, using zero loss")

#                 else:
#                     # For regular language models
#                     labels = create_labels(inputs['input_ids'], tokenizer.pad_token_id)
#                     unpatched_outputs = language_model(**inputs, labels=labels)

#                     if hasattr(unpatched_outputs, 'loss') and unpatched_outputs.loss is not None:
#                         unpatched_loss = unpatched_outputs.loss.item()
#                     else:
#                         # FIXED: Same robust loss computation as above
#                         if hasattr(unpatched_outputs, 'logits'):
#                             logits = unpatched_outputs.logits

#                             if logits.size(1) > 1 and labels.size(1) > 1:
#                                 shift_logits = logits[..., :-1, :].contiguous()
#                                 shift_labels = labels[..., 1:].contiguous()

#                                 valid_mask = shift_labels != -100
#                                 if valid_mask.any():
#                                     shift_logits = shift_logits.view(-1, shift_logits.size(-1))
#                                     shift_labels = shift_labels.view(-1)

#                                     loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction='mean')
#                                     unpatched_loss = loss_fct(shift_logits, shift_labels).item()
#                                 else:
#                                     unpatched_loss = 0.0
#                             else:
#                                 unpatched_loss = 0.0
#                         else:
#                             unpatched_loss = 0.0

#             except Exception as e:
#                 print(f"‚ö†Ô∏è  Error computing unpatched loss: {e}")
#                 unpatched_loss = 0.0

        # Extract activations from target layer
        activations = None
        # patched_loss = unpatched_loss  # DEFAULT if patching were enabled (kept for reference)

        def activation_hook(module, input, output):
            nonlocal activations
            try:
                if isinstance(output, tuple):
                    activations = output[0].clone().detach()
                else:
                    activations = output.clone().detach()
            except Exception as e:
                print(f"‚ö†Ô∏è  Error in activation hook: {e}")

        # FIXED: More robust layer identification
        target_layer = None
        try:
            if hasattr(language_model, 'model') and hasattr(language_model.model, 'layers'):
                if layer < len(language_model.model.layers):
                    target_layer = language_model.model.layers[layer]
                else:
                    print(f"‚ùå Layer {layer} out of range. Model has {len(language_model.model.layers)} layers")
                    return torch.randn(1, 64, 2304).to(self.device), 0.0
            elif hasattr(language_model, 'layers'):
                if layer < len(language_model.layers):
                    target_layer = language_model.layers[layer]
                else:
                    print(f"‚ùå Layer {layer} out of range. Model has {len(language_model.layers)} layers")
                    return torch.randn(1, 64, 2304).to(self.device), 0.0
            else:
                print(f"‚ùå Could not find layers in model structure")
                return torch.randn(1, 64, 2304).to(self.device), 0.0
        except Exception as e:
            print(f"‚ùå Error accessing layer {layer}: {e}")
            return torch.randn(1, 64, 2304).to(self.device), 0.0

        if target_layer is None:
            print(f"‚ùå Could not find layer {layer}")
            return torch.randn(1, 64, 2304).to(self.device), 0.0

        hook = target_layer.register_forward_hook(activation_hook)

        # Forward pass to get activations
        with torch.no_grad():
            try:
                if "paligemma" in model_name.lower():
                    _ = language_model(**inputs)
                else:
                    _ = language_model(**inputs)
            except Exception as e:
                print(f"‚ö†Ô∏è  Error in activation extraction: {e}")

        hook.remove()

        # === PATCHING DISABLED ===
        # Compute patched loss if SAE is provided
        # if sae is not None and activations is not None:
        #     patched_loss = self._compute_patched_loss(
        #         language_model, inputs, activations, sae, layer, model_name, tokenizer
        #     )

        # model_delta_loss = patched_loss - unpatched_loss
        model_delta_loss = 0.0  # Always return 0 for delta_loss when patching is disabled

        if activations is None:
            print(f"‚ö†Ô∏è  Failed to extract activations from layer {layer}")
            # FIXED: Return appropriate tensor size based on model
            try:
                # Try to get the actual hidden size from the model config
                if hasattr(language_model, 'config') and hasattr(language_model.config, 'hidden_size'):
                    hidden_size = language_model.config.hidden_size
                else:
                    hidden_size = 2304  # fallback
                activations = torch.randn(1, 64, hidden_size).to(self.device)
            except:
                activations = torch.randn(1, 64, 2304).to(self.device)

        return activations, model_delta_loss


    def _compute_patched_loss(self, 
                            language_model, 
                            inputs: Dict, 
                            original_activations: torch.Tensor, 
                            sae: SAE, 
                            layer: int,
                            model_name: str,
                            tokenizer) -> float:
        """Compute loss with SAE-patched activations. FIXED: Robust error handling and loss computation."""
        try:
            # Get SAE reconstruction
            flat_activations = original_activations.view(-1, original_activations.size(-1))
            print(f"Activations shape: {flat_activations.shape}")

            sae_output = sae(flat_activations)

            # Handle different SAE output formats
            if hasattr(sae_output, 'sae_out'):
                reconstructed = sae_output.sae_out
            elif isinstance(sae_output, tuple):
                reconstructed = sae_output[0]
            else:
                reconstructed = sae_output

            # Reshape back to original shape
            reconstructed = reconstructed.view(original_activations.shape)

            # Patch the reconstructed activations back into the model
            patched_activations = reconstructed.detach()  # FIXED: Ensure no gradients

            # Create a patching hook
            def patching_hook(module, input, output):
                try:
                    if isinstance(output, tuple):
                        return (patched_activations, *output[1:])
                    else:
                        return patched_activations
                except Exception as e:
                    print(f"‚ö†Ô∏è  Error in patching hook: {e}")
                    return output  # Return original if patching fails

            # Hook the target layer for patching
            target_layer = None
            if hasattr(language_model, 'model') and hasattr(language_model.model, 'layers'):
                if layer < len(language_model.model.layers):
                    target_layer = language_model.model.layers[layer]
            elif hasattr(language_model, 'layers'):
                if layer < len(language_model.layers):
                    target_layer = language_model.layers[layer]

            if target_layer is None:
                return 0.0

            patch_hook = target_layer.register_forward_hook(patching_hook)

            # FIXED: Use the same robust label creation as in main function
            def create_labels(input_ids, pad_token_id):
                labels = input_ids.clone()
                labels[labels == pad_token_id] = -100
                if labels.size(1) > 1:
                    labels[:, 0] = -100
                return labels

            # Forward pass with patched activations
            patched_loss = 0.0
            with torch.no_grad():
                if "paligemma" in model_name.lower():
                    labels = create_labels(inputs['input_ids'], tokenizer.pad_token_id)
                    patched_outputs = language_model(**inputs)

                    if hasattr(patched_outputs, 'logits'):
                        logits = patched_outputs.logits

                        if logits.size(1) > 1 and labels.size(1) > 1:
                            shift_logits = logits[..., :-1, :].contiguous()
                            shift_labels = labels[..., 1:].contiguous()

                            valid_mask = shift_labels != -100
                            if valid_mask.any():
                                shift_logits = shift_logits.view(-1, shift_logits.size(-1))
                                shift_labels = shift_labels.view(-1)

                                loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction='mean')
                                patched_loss = loss_fct(shift_logits, shift_labels).item()
                    else:
                        patched_loss = 0.0
                else:
                    labels = create_labels(inputs['input_ids'], tokenizer.pad_token_id)
                    patched_outputs = language_model(**inputs, labels=labels)

                    if hasattr(patched_outputs, 'loss') and patched_outputs.loss is not None:
                        patched_loss = patched_outputs.loss.item()
                    else:
                        if hasattr(patched_outputs, 'logits'):
                            logits = patched_outputs.logits

                            if logits.size(1) > 1 and labels.size(1) > 1:
                                shift_logits = logits[..., :-1, :].contiguous()
                                shift_labels = labels[..., 1:].contiguous()

                                valid_mask = shift_labels != -100
                                if valid_mask.any():
                                    shift_logits = shift_logits.view(-1, shift_logits.size(-1))
                                    shift_labels = shift_labels.view(-1)

                                    loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction='mean')
                                    patched_loss = loss_fct(shift_logits, shift_labels).item()

            patch_hook.remove()
            return patched_loss

        except Exception as e:
            print(f"‚ö†Ô∏è  Patching failed: {e}")
            return 0.0

    def compute_sae_metrics(self, activations: torch.Tensor, sae: SAE, model_delta_loss: float) -> SAEMetrics:
        """Compute comprehensive SAE evaluation metrics including model delta loss and top-20 features."""
        with torch.no_grad():
            # Reshape activations for SAE processing
            batch_size, seq_len, d_model = activations.shape
            flat_activations = activations.view(-1, d_model)
            
            # Forward pass through SAE
            sae_output = sae(flat_activations) 
            
            # Handle different SAE output formats
            if hasattr(sae_output, 'feature_acts'):
                feature_acts = sae_output.feature_acts # shape (batch_size * seq_len,  latent_dim)
                reconstructed = sae_output.sae_out
            elif isinstance(sae_output, tuple) and len(sae_output) >= 2:
                reconstructed, feature_acts = sae_output[0], sae_output[1]
            elif hasattr(sae, 'encode') and hasattr(sae, 'decode'):
                feature_acts = sae.encode(flat_activations)
                reconstructed = sae.decode(feature_acts)
            else:
                reconstructed = sae_output
                if hasattr(sae, 'W_enc') and hasattr(sae, 'b_enc'):
                    feature_acts = torch.relu(flat_activations @ sae.W_enc + sae.b_enc)
                else:
                    print(f"Failed retrieving SAE reconstructions, random intialisign...")
                    feature_acts = torch.randn(flat_activations.shape[0], 16384, device=flat_activations.device)
            
            # 1. Reconstruction Loss (MSE)
            reconstruction_loss = torch.nn.functional.mse_loss(reconstructed, flat_activations).item()
            
            # 2. L0 Sparsity (fraction of non-zero features)
            l0_sparsity = (feature_acts > 0).float().mean().item()
            
            # 3. L1 Sparsity (mean absolute activation)
            l1_sparsity = feature_acts.abs().mean().item()
            
            # 4. Fraction of features that are ever active
            fraction_alive = (feature_acts.max(dim=0)[0] > 0).float().mean().item()
            
            # 5. Mean maximum activation per sample
            mean_max_activation = feature_acts.max(dim=1)[0].mean().item()
            
            # 6. Reconstruction score (explained variance)
            var_original = flat_activations.var(dim=0).mean()
            var_residual = (flat_activations - reconstructed).var(dim=0).mean()
            reconstruction_score = max(0.0, 1 - (var_residual / var_original).item())
            
            # Store top-20 features for analysis
            mean_feature_acts = feature_acts.mean(dim=0)  # Average across all tokens/samples
            top_20_indices = torch.topk(mean_feature_acts, k=min(20, feature_acts.size(-1)))[1]
            self._store_top_features(top_20_indices, mean_feature_acts, 
                                   reconstruction_loss, l0_sparsity, model_delta_loss)
            
            # top-20 rec loss
            top_acts = feature_acts[..., top_20_indices] # shape (batch_size * seq_len,  latent_dim)
            if hasattr(sae, 'decode'):
                latent_dim = feature_acts.size(-1)  # e.g., 16384
                z_sparse = torch.zeros(feature_acts.size(0), latent_dim,
                                       device=feature_acts.device, dtype=feature_acts.dtype)
                z_sparse[:, top_20_indices] = top_acts  # place the 20 activations at their true indices
                recon_from_topk = sae.decode(z_sparse)  # ‚úÖ correct shape
#                 recon_from_topk = sae.decode( top_acts )  # if your SAE supports that
            else:
                if hasattr(sae, 'W_dec') and hasattr(sae, 'b_dec'):
                    # Select the relevant columns from W_dec for the top-20 features
                    W_dec_topk = sae.W_dec[:, top_20_indices]  # Select columns corresponding to top-20 activations

                    # If necessary, apply a bias term (assuming b_dec is shared across all features)
                    b_dec_topk = sae.b_dec  # Bias term stays the same for all activations
                    print(top_acts.shape, W_dec_topk.shape, b_dec_topk.shape)
                    # Reconstruct the activations from the top-20 features
                    recon_from_topk = torch.relu(top_acts @ W_dec_topk + b_dec_topk)
            
            rec_loss_topk = F.mse_loss(recon_from_topk, flat_activations).item()
            
            return SAEMetrics(
                reconstruction_loss=reconstruction_loss,
                l0_sparsity=l0_sparsity,
                l1_sparsity=l1_sparsity,
                fraction_alive=fraction_alive,
                mean_max_activation=mean_max_activation,
                reconstruction_score=reconstruction_score,
                model_delta_loss=model_delta_loss,
                rec_loss_topk=rec_loss_topk
            )
    
    def _store_top_features(self, top_indices: torch.Tensor, feature_acts: torch.Tensor, 
                           recon_loss: float, sparsity: float, delta_loss: float):
        """Store top-20 activated features for analysis."""
        if not hasattr(self, 'top_features_log'):
            self.top_features_log = []
        
        top_features_info = {
            'top_20_indices': top_indices.cpu().tolist(),
            'top_20_activations': feature_acts[top_indices].cpu().tolist(),
            'reconstruction_loss': recon_loss,
            'sparsity': sparsity,
            'delta_loss': delta_loss,
            'timestamp': len(self.top_features_log)  # Simple counter
        }
        
        self.top_features_log.append(top_features_info)

    def analyze_layer_sweep(self, 
                           model1_name: str, 
                           model2_name: str, 
                           texts: List[str],
                           layers: List[int] = None) -> Dict:
        """
        Perform memory-efficient layer sweep analysis.
        
        Args:
            model1_name: First model (base LLM)
            model2_name: Second model (VLM) 
            texts: List of texts to analyze
            layers: List of layers to analyze (default: [8, 12, 16, 20])
        """
        if layers is None:
            layers = [8, 12, 16, 20]  # Sample layers across the model
        
        print(f"üöÄ Starting Layer Sweep Analysis")
        print(f"   Model 1: {model1_name}")
        print(f"   Model 2: {model2_name}")
        print(f"   Layers: {layers}")
        print(f"   Texts: {len(texts)} samples")
        print(f"   Memory: {torch.cuda.memory_allocated() / 1e9:.2f}GB" if torch.cuda.is_available() else "")
        
        results = {
            'layers': layers,
            'layer_results': {},
            'texts': texts[:10]  # Store subset for reference
        }
        
        for layer in tqdm(layers, desc="Processing layers"):
            print(f"\nüìä Processing Layer {layer}")
            
            # Load SAE for this layer
            #############
            sae = self.get_gemmascope_sae(layer-1)
     #############

            layer_metrics = {
                'model1_metrics': [],
                'model2_metrics': [],
                'shift_metrics': []
            }
            
            # Process subset of texts for each layer (memory efficiency)
            sample_texts = texts[:100]  # Process 100 texts per layer (increased from 20)
            
            for i, text in enumerate(tqdm(sample_texts, desc=f"Layer {layer} texts", leave=False)):
                try:
                    # Extract activations and compute metrics for model 1
                    acts1, delta_loss1 = self.extract_activations_with_patching(
                        model1_name, text, layer, sae
                    )
                    metrics1 = self.compute_sae_metrics(acts1, sae, delta_loss1)
                    
                    # Extract activations and compute metrics for model 2
                    acts2, delta_loss2 = self.extract_activations_with_patching(
                        model2_name, text, layer, sae
                    )
                    metrics2 = self.compute_sae_metrics(acts2, sae, delta_loss2)
                    
                    # Compute representation shift
                    shift = self.compute_representation_shift(acts1, acts2, sae)
                    
                    layer_metrics['model1_metrics'].append(metrics1)
                    layer_metrics['model2_metrics'].append(metrics2)
                    layer_metrics['shift_metrics'].append(shift)
                    
                except Exception as e:
                    print(f"‚ö†Ô∏è  Error processing text {i} in layer {layer}: {e}")
                    continue
            
            # Compute layer-level aggregates
            layer_metrics['aggregate'] = self._compute_layer_aggregate(layer_metrics)
            results['layer_results'][layer] = layer_metrics
            
            # Memory cleanup
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                print(f"   Memory after layer {layer}: {torch.cuda.memory_allocated() / 1e9:.2f}GB")
        
        # Compute overall analysis
        results['overall_analysis'] = self._compute_overall_analysis(results)
        
        return results

    def compute_representation_shift(self, 
                                   activations1: torch.Tensor, 
                                   activations2: torch.Tensor,
                                   sae: SAE) -> RepresentationShift:
        """Compute representation shift metrics using SAE features."""
        with torch.no_grad():
            # Process both activation sets through SAE
            flat_acts1 = activations1.view(-1, activations1.size(-1))
            flat_acts2 = activations2.view(-1, activations2.size(-1))
            
            # Get SAE features
            def extract_features(flat_acts):
                sae_output = sae(flat_acts)
                if hasattr(sae_output, 'feature_acts'):
                    return sae_output.feature_acts
                elif isinstance(sae_output, tuple) and len(sae_output) >= 2:
                    return sae_output[1]
                elif hasattr(sae, 'encode'):
                    return sae.encode(flat_acts)
                else:
                    if hasattr(sae, 'W_enc') and hasattr(sae, 'b_enc'):
                        print(f"flat_acts: {flat_acts.shape}, sae.W_enc: {sae.W_enc.shape}, sae.b_enc: {sae.b_enc.shape}")
                        return torch.relu(flat_acts @ sae.W_enc + sae.b_enc)
                    else:
                        return torch.randn(flat_acts.shape[0], 16384, device=flat_acts.device)
            
            features1 = extract_features(flat_acts1)
            features2 = extract_features(flat_acts2)
            
            # 1. Cosine similarity
            cosine_sim = torch.nn.functional.cosine_similarity(
                features1.mean(dim=0), features2.mean(dim=0), dim=0
            ).item()
            
            # 2. L2 distance
            l2_distance = torch.norm(features1.mean(dim=0) - features2.mean(dim=0), p=2).item()
            
            # 3. Feature overlap (Jaccard similarity)
            active1 = (features1 > 0).float()
            active2 = (features2 > 0).float()
            intersection = (active1 * active2).sum(dim=0)
            union = torch.clamp(active1.sum(dim=0) + active2.sum(dim=0) - intersection, min=1)
            feature_overlap = (intersection / union).mean().item()
            
            # 4. Jensen-Shannon divergence
            def js_divergence(p, q):
                p = p + 1e-8
                q = q + 1e-8
                p = p / p.sum()
                q = q / q.sum()
                m = 0.5 * (p + q)
                return 0.5 * (torch.nn.functional.kl_div(p.log(), m, reduction='sum') + 
                             torch.nn.functional.kl_div(q.log(), m, reduction='sum'))
            
            p = features1.mean(dim=0).abs()
            q = features2.mean(dim=0).abs()
            js_div = js_divergence(p, q).item()
            
            # 5. Feature correlation
            try:
                corr_matrix = torch.corrcoef(torch.stack([
                    features1.mean(dim=0), features2.mean(dim=0)
                ]))
                feature_correlation = corr_matrix[0, 1].item() if not torch.isnan(corr_matrix[0, 1]) else 0.0
            except:
                feature_correlation = 0.0
            
            return RepresentationShift(
                cosine_similarity=cosine_sim,
                l2_distance=l2_distance,
                feature_overlap=feature_overlap,
                js_divergence=js_div,
                feature_correlation=feature_correlation
            )

    def _compute_layer_aggregate(self, layer_metrics: Dict) -> Dict:
        """Compute aggregate statistics for a single layer."""
        n_samples = len(layer_metrics['model1_metrics'])
        if n_samples == 0:
            return {}
        
        # Average metrics across samples
        avg_model1 = {}
        avg_model2 = {}
        avg_shift = {}
        
        for field in SAEMetrics.__dataclass_fields__:
            avg_model1[field] = np.mean([getattr(m, field) for m in layer_metrics['model1_metrics']])
            avg_model2[field] = np.mean([getattr(m, field) for m in layer_metrics['model2_metrics']])
        
        for field in RepresentationShift.__dataclass_fields__:
            avg_shift[field] = np.mean([getattr(s, field) for s in layer_metrics['shift_metrics']])
        
        return {
            'avg_model1_metrics': avg_model1,
            'avg_model2_metrics': avg_model2,
            'avg_shift_metrics': avg_shift,
            'n_samples': n_samples
        }

    def _compute_overall_analysis(self, results: Dict) -> Dict:
        """Compute overall analysis across all layers."""
        layers = results['layers']
        
        # Collect metrics across layers
        layer_similarities = []
        layer_overlaps = []
        layer_delta_losses = []
        layer_sparsities = []
        layer_rec_loss_topk = []
        
        for layer in layers:
            if layer in results['layer_results'] and 'aggregate' in results['layer_results'][layer]:
                agg = results['layer_results'][layer]['aggregate']
                if agg:  # Check if aggregate is not empty
                    layer_similarities.append(agg['avg_shift_metrics']['cosine_similarity'])
                    layer_overlaps.append(agg['avg_shift_metrics']['feature_overlap'])
                    layer_delta_losses.append(abs(agg['avg_model1_metrics']['model_delta_loss'] - 
                                                 agg['avg_model2_metrics']['model_delta_loss']))
                    layer_sparsities.append((agg['avg_model1_metrics']['l0_sparsity'] + 
                                           agg['avg_model2_metrics']['l0_sparsity']) / 2)

        
        # Overall insights
        overall = {
            'most_similar_layer': layers[np.argmax(layer_similarities)] if layer_similarities else None,
            'most_different_layer': layers[np.argmin(layer_similarities)] if layer_similarities else None,
            'highest_overlap_layer': layers[np.argmax(layer_overlaps)] if layer_overlaps else None,
            'highest_delta_loss_layer': layers[np.argmax(layer_delta_losses)] if layer_delta_losses else None,
            'avg_similarity_across_layers': np.mean(layer_similarities) if layer_similarities else 0,
            'avg_overlap_across_layers': np.mean(layer_overlaps) if layer_overlaps else 0,
            'avg_delta_loss_across_layers': np.mean(layer_delta_losses) if layer_delta_losses else 0,
            'layer_similarities': dict(zip(layers, layer_similarities)),
            'layer_overlaps': dict(zip(layers, layer_overlaps))
        }
        
        return overall

    def visualize_layer_sweep_results(self, results: Dict, model1_name: str, model2_name: str):
        """Create comprehensive visualization of layer sweep results."""
        layers = results['layers']
        
        # Create output filename
        model1_clean = model1_name.replace('/', '_').replace('-', '_')
        model2_clean = model2_name.replace('/', '_').replace('-', '_')
        save_path = self.output_dir / f"{model1_clean}_{model2_clean}_layer_sweep.png"
        
        fig, axes = plt.subplots(2, 3, figsize=(20, 12))
        fig.suptitle(f'SAE Layer Sweep Analysis: {model1_name} vs {model2_name}', fontsize=16)
        
        # Collect data across layers
        layer_data = {
            'similarities': [],
            'overlaps': [],
            'recon_losses_m1': [],
            'recon_losses_m2': [],
            'sparsities_m1': [],
            'sparsities_m2': [],
            'delta_losses_m1': [],
            'delta_losses_m2': []
        }
        
        for layer in layers:
            if layer in results['layer_results'] and 'aggregate' in results['layer_results'][layer]:
                agg = results['layer_results'][layer]['aggregate']
                if agg:
                    layer_data['similarities'].append(agg['avg_shift_metrics']['cosine_similarity'])
                    layer_data['overlaps'].append(agg['avg_shift_metrics']['feature_overlap'])
                    layer_data['recon_losses_m1'].append(agg['avg_model1_metrics']['reconstruction_loss'])
                    layer_data['recon_losses_m2'].append(agg['avg_model2_metrics']['reconstruction_loss'])
                    layer_data['sparsities_m1'].append(agg['avg_model1_metrics']['l0_sparsity'])
                    layer_data['sparsities_m2'].append(agg['avg_model2_metrics']['l0_sparsity'])
                    layer_data['delta_losses_m1'].append(agg['avg_model1_metrics']['model_delta_loss'])
                    layer_data['delta_losses_m2'].append(agg['avg_model2_metrics']['model_delta_loss'])
        
        # Plot 1: Representation Similarity Across Layers
        axes[0, 0].plot(layers, layer_data['similarities'], 'o-', linewidth=2, markersize=8)
        axes[0, 0].set_title('Cosine Similarity Across Layers')
        axes[0, 0].set_xlabel('Layer')
        axes[0, 0].set_ylabel('Cosine Similarity')
        axes[0, 0].grid(True, alpha=0.3)
#         axes[0, 0].axhline(y=0.8, color='red', linestyle='--', alpha=0.5, label='High Similarity')
        axes[0, 0].legend()
        
        # Plot 2: Feature Overlap Across Layers
        axes[0, 1].plot(layers, layer_data['overlaps'], 'o-', color='green', linewidth=2, markersize=8)
        axes[0, 1].set_title('Feature Overlap Across Layers')
        axes[0, 1].set_xlabel('Layer')
        axes[0, 1].set_ylabel('Feature Overlap')
        axes[0, 1].grid(True, alpha=0.3)
#         axes[0, 1].axhline(y=0.5, color='red', linestyle='--', alpha=0.5, label='Moderate Overlap')
        axes[0, 1].legend()
        
        # Plot 3: Reconstruction Loss Comparison
#         axes[0, 2].plot(layers, layer_data['recon_losses_m1'], 'o-', label='Model 1 (LLM)', linewidth=2)
        axes[0, 2].plot(layers, layer_data['recon_losses_m2'], 's-', label='Model 2 (VLM)', linewidth=2)
        axes[0, 2].set_title('Reconstruction Loss Across Layers')
        axes[0, 2].set_xlabel('Layer')
        axes[0, 2].set_ylabel('Reconstruction Loss')
        axes[0, 2].legend()
        axes[0, 2].grid(True, alpha=0.3)
        
        # Plot 4: Sparsity Comparison
        axes[1, 0].plot(layers, layer_data['sparsities_m1'], 'o-', label='Model 1 (LLM)', linewidth=2)
        axes[1, 0].plot(layers, layer_data['sparsities_m2'], 's-', label='Model 2 (VLM)', linewidth=2)
        axes[1, 0].set_title('L0 Sparsity Across Layers')
        axes[1, 0].set_xlabel('Layer')
        axes[1, 0].set_ylabel('L0 Sparsity')
        axes[1, 0].legend()
        axes[1, 0].grid(True, alpha=0.3)
        
        # Plot 5: Model Delta Loss (Patching Performance)
        axes[1, 1].plot(layers, layer_data['delta_losses_m1'], 'o-', label='Model 1 (LLM)', linewidth=2)
        axes[1, 1].plot(layers, layer_data['delta_losses_m2'], 's-', label='Model 2 (VLM)', linewidth=2)
        axes[1, 1].set_title('Model Delta Loss (Patching Quality)')
        axes[1, 1].set_xlabel('Layer')
        axes[1, 1].set_ylabel('Delta Loss')
        axes[1, 1].legend()
        axes[1, 1].grid(True, alpha=0.3)
        axes[1, 1].axhline(y=0, color='black', linestyle='-', alpha=0.3)
        
        # Plot 6: Summary Heatmap
        # Create a summary matrix for visualization
        metrics_matrix = np.array([
            layer_data['similarities'],
            layer_data['overlaps'],
            np.array(layer_data['recon_losses_m1']) / max(max(layer_data['recon_losses_m1']), 1e-6),  # Normalize
            np.array(layer_data['sparsities_m1']) * 10,  # Scale up for visibility
        ])
        
        im = axes[1, 2].imshow(metrics_matrix, cmap='RdYlBu_r', aspect='auto')
        axes[1, 2].set_title('Metrics Heatmap Across Layers')
        axes[1, 2].set_xlabel('Layer Index')
        axes[1, 2].set_yticks(range(4))
        axes[1, 2].set_yticklabels(['Similarity', 'Overlap', 'Recon Loss (norm)', 'Sparsity (x10)'])
        axes[1, 2].set_xticks(range(len(layers)))
        axes[1, 2].set_xticklabels([f'L{l}' for l in layers])
        
        # Add colorbar
        cbar = plt.colorbar(im, ax=axes[1, 2])
        cbar.set_label('Metric Value')
        
        plt.tight_layout()
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"‚úÖ Layer sweep visualization saved to {save_path}")
        
        # Save detailed results as JSON including top features
        json_path = self.output_dir / f"{model1_clean}_{model2_clean}_results.json"
        
        # Convert results to JSON-serializable format
        json_results = {
            'layers': layers,
            'overall_analysis': results['overall_analysis'],
            'layer_summaries': {},
            'top_features_analysis': getattr(self, 'top_features_log', [])
        }
        
        for layer in layers:
            if layer in results['layer_results'] and 'aggregate' in results['layer_results'][layer]:
                agg = results['layer_results'][layer]['aggregate']
                if agg:
                    json_results['layer_summaries'][str(layer)] = agg
        
        with open(json_path, 'w') as f:
            json.dump(json_results, f, indent=2)
        print(f"‚úÖ Detailed results saved to {json_path}")
        
        # Create top features analysis
#         self._analyze_top_features_trends()

    def interpret_layer_sweep_results(self, results: Dict) -> Dict[str, str]:
        """Provide interpretation of layer sweep results."""
        overall = results['overall_analysis']
        interpretations = {}
        
        # Overall adaptation assessment
        avg_similarity = overall['avg_similarity_across_layers']
        if avg_similarity > 0.85:
            interpretations['adaptation_magnitude'] = "‚úÖ MINIMAL LLM‚ÜíVLM adaptation - representations largely preserved"
        elif avg_similarity > 0.7:
            interpretations['adaptation_magnitude'] = "‚ö†Ô∏è MODERATE LLM‚ÜíVLM adaptation - selective representational changes"
        else:
            interpretations['adaptation_magnitude'] = "üîç SIGNIFICANT LLM‚ÜíVLM adaptation - substantial representational reorganization"
        
        # Layer-specific insights
        if overall['most_different_layer'] is not None:
            interpretations['adaptation_location'] = f"üéØ Layer {overall['most_different_layer']} shows maximum adaptation"
        
        if overall['highest_overlap_layer'] is not None:
            interpretations['feature_preservation'] = f"üîó Layer {overall['highest_overlap_layer']} best preserves LLM features"
        
        # Adaptation pattern
        layer_sims = list(overall['layer_similarities'].values())
        if len(layer_sims) >= 3:
            early_sim = np.mean(layer_sims[:len(layer_sims)//3])
            late_sim = np.mean(layer_sims[-len(layer_sims)//3:])
            
            if early_sim > late_sim + 0.1:
                interpretations['adaptation_pattern'] = "üìà Early layers preserve LLM representations better than late layers"
            elif late_sim > early_sim + 0.1:
                interpretations['adaptation_pattern'] = "üìâ Late layers preserve LLM representations better than early layers"
            else:
                interpretations['adaptation_pattern'] = "üìä Uniform adaptation pattern across layers"
        
        # SAE quality assessment
        avg_delta_loss = overall['avg_delta_loss_across_layers']
        if avg_delta_loss < 0.1:
            interpretations['sae_quality'] = "‚úÖ SAE reconstructions preserve model functionality well"
        elif avg_delta_loss < 0.5:
            interpretations['sae_quality'] = "‚ö†Ô∏è SAE reconstructions cause moderate functional degradation"
        else:
            interpretations['sae_quality'] = "‚ùå SAE reconstructions significantly impact model functionality"
        
        return interpretations




In [None]:
print("üöÄ Comprehensive SAE Layer Sweep Analysis: LLM‚ÜíVLM Adaptation")
print("=" * 70)

# Configuration
MODEL_SIZE = "2b"
WIDTH = "16k"
SUFFIX = "canonical"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
#     LAYERS = [4, 8, 12, 16, 20, 24]  # Sample across the model depth
LAYERS = list(range(1,27)) # for SAE, should be 0-25
try:
    # Initialize analyzer
    analyzer = MemoryEfficientSAEAnalyzer(
        model_size=MODEL_SIZE,
        width=WIDTH,
        suffix=SUFFIX,
        device=DEVICE
    )

    # Load dataset
    print("\nüìö Loading Datasets...")
    dataset_loader = DatasetLoader(device=DEVICE)
    texts = dataset_loader.get_mixed_dataset(total_samples=1000)  # Use 1K data as requested

    print(f"‚úÖ Loaded {len(texts)} texts from mixed datasets")
    print(f"Sample texts: {texts[:3]}")

    # Model configuration for LLM->VLM comparison
    model1_name = "google/gemma-2-2b"  # Base Gemma-2-2B (LLM)
    model2_name = "google/paligemma2-3b-pt-224"  # PaliGemma with Gemma-2-2B decoder (VLM)

    print(f"\nüî¨ Research Configuration:")
    print(f"   Model 1 (LLM): {model1_name}")
    print(f"   Model 2 (VLM): {model2_name}")
    print(f"   Layers to analyze: {LAYERS}")
    print(f"   SAE Configuration: {MODEL_SIZE}-{WIDTH}-{SUFFIX}")
    print(f"   Device: {DEVICE}")
    print(f"   Total texts: {len(texts)}")

    # Run layer sweep analysis
    print(f"\nüöÄ Starting Layer Sweep Analysis...")
    results = analyzer.analyze_layer_sweep(
        model1_name=model1_name,
        model2_name=model2_name,
        texts=texts,
        layers=LAYERS
    )

    # Generate interpretations
    interpretations = analyzer.interpret_layer_sweep_results(results)

    print(f"\nüìä LAYER SWEEP RESULTS:")
    print("=" * 50)

    overall = results['overall_analysis']
    print(f"Most Similar Layer: {overall['most_similar_layer']}")
    print(f"Most Different Layer: {overall['most_different_layer']}")
    print(f"Highest Feature Overlap Layer: {overall['highest_overlap_layer']}")
    print(f"Average Similarity Across Layers: {overall['avg_similarity_across_layers']:.3f}")
    print(f"Average Feature Overlap: {overall['avg_overlap_across_layers']:.3f}")

    print(f"\nüîç INTERPRETATIONS:")
    print("=" * 50)
    for aspect, interpretation in interpretations.items():
        print(f"{aspect.replace('_', ' ').title()}: {interpretation}")

    # Create visualizations
    print(f"\nüìà Generating Visualizations...")
    analyzer.visualize_layer_sweep_results(results, model1_name, model2_name)

    print(f"\n‚úÖ Analysis Complete!")
    print(f"üìÅ Results saved to: {analyzer.output_dir}")
    print(f"üß† Key Finding: {interpretations.get('adaptation_magnitude', 'Analysis completed')}")

    # Memory cleanup
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        print(f"üîß Final GPU Memory: {torch.cuda.memory_allocated() / 1e9:.2f}GB")

except Exception as e:
    print(f"‚ùå Error during analysis: {e}")
    import traceback
    traceback.print_exc()

    print("\nüí° Troubleshooting Tips:")
    print("   1. Ensure sufficient GPU memory (8GB+ recommended)")
    print("   2. Reduce LAYERS list or sample size if out of memory")
    print("   3. Check model names are correct and accessible")
    print("   4. Install required packages: pip install sae-lens transformers datasets")

üöÄ Comprehensive SAE Layer Sweep Analysis: LLM‚ÜíVLM Adaptation
üîß Initialized SAE Analyzer
   Device: cuda
   Model Size: 2b
   SAE Width: 16k
   Output Dir: ../figs_tabs

üìö Loading Datasets...
‚úÖ Loaded 666 CIFAR-100 captions
‚ùå Error loading COCO: 'utf-8' codec can't decode byte 0xc4 in position 4: invalid continuation byte
‚ùå Error loading LLaVA-Bench: Config name is missing.
Please pick one among the available configs: ['CLEVR-Math(MathV360K)', 'Evol-Instruct-GPT4-Turbo', 'FigureQA(MathV360K)', 'GEOS(MathV360K)', 'GeoQA+(MathV360K)', 'Geometry3K(MathV360K)', 'IconQA(MathV360K)', 'MapQA(MathV360K)', 'MathV360K_TQA', 'MathV360K_VQA-AS', 'MathV360K_VQA-RAD', 'PMC-VQA(MathV360K)', 'Super-CLEVR(MathV360K)', 'TabMWP(MathV360K)', 'UniGeo(MathV360K)', 'VisualWebInstruct(filtered)', 'VizWiz(MathV360K)', 'ai2d(cauldron,llava_format)', 'ai2d(gpt4v)', 'ai2d(internvl)', 'allava_instruct_laion4v', 'allava_instruct_vflan4v', 'aokvqa(cauldron,llava_format)', 'chart2text(cauldron)', 'cha

Processing layers:   0%|          | 0/26 [00:00<?, ?it/s]


üìä Processing Layer 1
   üì• Loading SAE Layer 0: layer_0/width_16k/canonical
