In [4]:
#!/usr/bin/env python3
"""
Comprehensive SAE-based representation shift analysis with layer sweeping,
real datasets, and patching logic for LLM->VLM adaptation studies.
"""
# Installation requirements:
"""
pip install sae-lens transformers torch matplotlib seaborn numpy datasets tqdm

# For CUDA support (recommended):
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
"""

import torch
import numpy as np
import os
import gc
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM
from sae_lens import SAE
import matplotlib.pyplot as plt
from typing import Dict, Tuple, List, Optional, Union
from dataclasses import dataclass
import seaborn as sns
from datasets import load_dataset
import json
from tqdm import tqdm
import warnings
import random
import json
warnings.filterwarnings("ignore")



In [5]:
#!/usr/bin/env python3
"""
Comprehensive SAE-based representation shift analysis with layer sweeping,
real datasets, and patching logic for LLM->VLM adaptation studies.
FIXED: Handles PaliGemma loss computation correctly.
"""

import torch
import numpy as np
import os
import gc
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM
from sae_lens import SAE
import matplotlib.pyplot as plt
from typing import Dict, Tuple, List, Optional, Union
from dataclasses import dataclass
import seaborn as sns
from datasets import load_dataset
import json
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")

# Disable gradients globally for memory efficiency
torch.set_grad_enabled(False)

@dataclass
class SAEMetrics:
    """Container for SAE evaluation metrics."""
    reconstruction_loss: float
    l0_sparsity: float
    l1_sparsity: float
    fraction_alive: float
    mean_max_activation: float
    reconstruction_score: float
    model_delta_loss: float  # New patching metric

@dataclass
class RepresentationShift:
    """Container for representation shift metrics."""
    cosine_similarity: float
    l2_distance: float
    feature_overlap: float
    js_divergence: float
    feature_correlation: float

class DatasetLoader:
    """Handles loading and preprocessing of various datasets."""
    
    def __init__(self, device: str = "cuda", data_save_dir: str = "../data"):
        self.device = device
        self.data_save_dir = data_save_dir
    
    def load_cifar100_captions(self, split: str = "train", max_samples: int = 100) -> List[str]:
        """Load CIFAR-100 with generated captions for multimodal analysis."""
        try:
            # CIFAR-100 doesn't have captions by default, so we create descriptive ones
            dataset = load_dataset("cifar100", split=split)
            
            # CIFAR-100 class names
            class_names = [
                'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle',
                'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel',
                'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock',
                'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
                'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster',
                'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion',
                'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse',
                'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear',
                'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine',
                'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose',
                'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake',
                'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table',
                'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout',
                'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm'
            ]
            
            texts = []
            for i, sample in enumerate(dataset):
                if i >= max_samples:
                    break
                class_name = class_names[sample['fine_label']]
                # Generate descriptive captions
                captions = [
                    f"This is a photo of a {class_name}.",
                    f"An image showing a {class_name}.",
                    f"A picture of a {class_name} in natural setting.",
                    f"Visual representation of a {class_name}."
                ]
                texts.extend(captions[:2])  # Take 2 captions per image
            
            print(f"✅ Loaded {len(texts)} CIFAR-100 captions")
            return texts[:max_samples]
            
        except Exception as e:
            print(f"❌ Error loading CIFAR-100: {e}")
            return self._get_fallback_texts()
    
    def load_coco_captions(self, split: str = "validation", max_samples: int = 100) -> List[str]:
        """Load COCO captions dataset."""
        try:
            # Load COCO captions
            dataset = load_dataset("HuggingFaceM4/COCO", split=split)
            
            texts = []
            for i, sample in enumerate(dataset):
                if i >= max_samples:
                    break
                
                # COCO has multiple captions per image
                if 'sentences' in sample and 'raw' in sample['sentences']:
                    for sentence in sample['sentences']['raw'][:2]:  # Take first 2 captions
                        texts.append(sentence)
                elif 'caption' in sample:
                    texts.append(sample['caption'])
            
            print(f"✅ Loaded {len(texts)} COCO captions")
            return texts[:max_samples]
            
        except Exception as e:
            print(f"❌ Error loading COCO: {e}")
            # Try alternative COCO dataset
            try:
                dataset = load_dataset("nielsr/coco-captions", split="validation")
                texts = [sample['caption'] for sample in dataset.select(range(min(max_samples, len(dataset))))]
                print(f"✅ Loaded {len(texts)} COCO captions (alternative)")
                return texts
            except:
                return self._get_fallback_texts()
    
    def load_llava_bench(self, max_samples: int = 100) -> List[str]:
        """Load LLaVA-Bench questions/descriptions."""
        try:
            # LLaVA bench conversations
            dataset = load_dataset("lmms-lab/LLaVA-OneVision-Data", split="dev_mini")
            
            texts = []
            for i, sample in enumerate(dataset):
                if i >= max_samples:
                    break
                
                if 'conversations' in sample:
                    for conv in sample['conversations'][:2]:  # Take first 2 conversations
                        if 'value' in conv:
                            texts.append(conv['value'])
            
            print(f"✅ Loaded {len(texts)} LLaVA-Bench texts")
            return texts[:max_samples]
            
        except Exception as e:
            print(f"❌ Error loading LLaVA-Bench: {e}")
            return self._get_fallback_texts()
    
    def _get_fallback_texts(self) -> List[str]:
        """Fallback texts if datasets fail to load."""
        return [
            "A photo of a red apple on a white background.",
            "The cat is sitting on a wooden chair.",
            "Mountains covered with snow in winter landscape.",
            "A blue car driving on a highway.",
            "Children playing in a park with green grass.",
            "A delicious chocolate cake on a plate.",
            "Ocean waves crashing against rocky shore.",
            "A person reading a book in a library.",
            "Colorful flowers blooming in spring garden.",
            "A dog running happily in the field.",
        ]
    
    def get_mixed_dataset(self, total_samples: int = 150, read_local: bool = False, save: bool = True) -> List[str]:
        """Get a mixed dataset from multiple sources."""
        samples_per_source = total_samples // 3
        
        if read_local:
            with open(f"{self.data_save_dir}/texts.json", "r", encoding="utf-8") as f:
                texts = json.load(f)
        else:
            texts = []
            texts.extend(self.load_cifar100_captions(max_samples=samples_per_source))
            texts.extend(self.load_coco_captions(max_samples=samples_per_source))
            texts.extend(self.load_llava_bench(max_samples=samples_per_source))


            # Shuffle for good measure
            
            random.shuffle(texts)
            if save:
                with open(f"{self.data_save_dir}/texts.json", "w", encoding="utf-8") as f:
                    json.dump(texts, f, ensure_ascii=False, indent=2)
                    
        return texts[:total_samples]

class MemoryEfficientSAEAnalyzer:
    """Memory-efficient SAE analyzer with layer sweeping and patching logic."""
    
    def __init__(self, 
                 model_size: str = "2b",
                 width: str = "16k", 
                 suffix: str = "canonical",
                 device: str = "cuda",
                 output_dir: str = "../figs_tabs"):
        """
        Initialize memory-efficient SAE analyzer.
        
        Args:
            model_size: Model size ("2b" or "9b")
            width: SAE width ("16k", "65k", "262k")
            suffix: SAE variant ("canonical" or specific L0)
            device: Device to use
            output_dir: Directory for saving outputs
        """
        self.device = device if torch.cuda.is_available() else "cpu"
        self.model_size = model_size
        self.width = width
        self.suffix = suffix
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        
        # Model cache for memory efficiency
        self.model_cache = {}
        self.sae_cache = {}
        
        print(f"🔧 Initialized SAE Analyzer")
        print(f"   Device: {self.device}")
        print(f"   Model Size: {model_size}")
        print(f"   SAE Width: {width}")
        print(f"   Output Dir: {output_dir}")

    def get_gemmascope_sae(self, layer: int) -> SAE:
        """Load Gemma Scope SAE with caching for memory efficiency."""
        cache_key = f"layer_{layer}"
        
        if cache_key in self.sae_cache:
            return self.sae_cache[cache_key]
        
        release = f"gemma-scope-{self.model_size}-pt-res"
        if self.suffix == "canonical":
            release = f"gemma-scope-{self.model_size}-pt-res-canonical"
            sae_id = f"layer_{layer}/width_{self.width}/canonical"
        else:
            sae_id = f"layer_{layer}/width_{self.width}/{self.suffix}"
        
        print(f"   📥 Loading SAE Layer {layer}: {sae_id}")
        
        try:
            sae = SAE.from_pretrained(release, sae_id).to(self.device)
            sae.eval()
            
            # Cache management - keep only last 2 SAEs to save memory
            if len(self.sae_cache) >= 2:
                oldest_key = list(self.sae_cache.keys())[0]
                del self.sae_cache[oldest_key]
                gc.collect()
            
            self.sae_cache[cache_key] = sae
            return sae
            
        except Exception as e:
            print(f"❌ Error loading SAE layer {layer}: {e}")
            raise

    def get_model(self, model_name: str):
        """Load model with caching and proper device placement."""
        if model_name in self.model_cache:
            return self.model_cache[model_name]
        
        print(f"📥 Loading model: {model_name}")
        
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            
            # Handle different model types
            if "paligemma" in model_name.lower():
                from transformers import PaliGemmaForConditionalGeneration
                model = PaliGemmaForConditionalGeneration.from_pretrained(
                    model_name, 
                    trust_remote_code=True,
                    torch_dtype=torch.float32,  # Use fp16 for memory efficiency
                    device_map=None  # We'll handle device placement manually
                )
                model = model.to(self.device)
                language_model = model.language_model
            else:
                model = AutoModelForCausalLM.from_pretrained(
                    model_name, 
                    trust_remote_code=True,
                    torch_dtype=torch.float32,
                    device_map=None
                )
                model = model.to(self.device)
                language_model = model
            
            language_model.eval()
            
            # Cache management - keep only one model at a time
            if len(self.model_cache) >= 1:
                for cached_name in list(self.model_cache.keys()):
                    del self.model_cache[cached_name]
                gc.collect()
                torch.cuda.empty_cache() if torch.cuda.is_available() else None
            
            self.model_cache[model_name] = (tokenizer, model, language_model)
            return tokenizer, model, language_model
            
        except Exception as e:
            print(f"❌ Error loading model {model_name}: {e}")
            raise

    def extract_activations_with_patching(self, 
                                        model_name: str, 
                                        text: str, 
                                        layer: int,
                                        sae: Optional[SAE] = None) -> Tuple[torch.Tensor, float]:
        """
        Extract activations and compute model delta loss with patching.
        FIXED: Properly handles loss computation for both model types.
        
        Returns:
            Tuple of (activations, model_delta_loss)
        """
        tokenizer, model, language_model = self.get_model(model_name)
        
        # Tokenize input
        inputs = tokenizer(
            text, 
            return_tensors="pt", 
            padding="max_length",
            truncation=True,
            max_length=64
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        
        # Get unpatched model loss (baseline)
        with torch.no_grad():
            try:
                if "paligemma" in model_name.lower():
                    # For PaliGemma, we need to handle the text-only input properly
                    # Create target labels for loss computation
                    labels = inputs['input_ids'].clone()
                    # Mask padding tokens in labels
                    labels[labels == tokenizer.pad_token_id] = -100
                    unpatched_outputs = language_model(**inputs, labels=labels)
                else:
                    # For regular language models
                    labels = inputs['input_ids'].clone()
                    labels[labels == tokenizer.pad_token_id] = -100
                    unpatched_outputs = language_model(**inputs, labels=labels)
                
                # Check if loss is available
                if hasattr(unpatched_outputs, 'loss') and unpatched_outputs.loss is not None:
                    unpatched_loss = unpatched_outputs.loss.item()
                else:
                    # Compute loss manually using logits
                    logits = unpatched_outputs.logits
                    shift_logits = logits[..., :-1, :].contiguous()
                    shift_labels = labels[..., 1:].contiguous()
                    
                    # Flatten for cross entropy computation
                    shift_logits = shift_logits.view(-1, shift_logits.size(-1))
                    shift_labels = shift_labels.view(-1)
                    
                    # Compute cross entropy loss (ignore -100 labels)
                    loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
                    unpatched_loss = loss_fct(shift_logits, shift_labels).item()
                    
            except Exception as e:
                print(f"⚠️  Error computing unpatched loss: {e}")
                unpatched_loss = 0.0
        
        # Extract activations from target layer
        activations = None
        patched_loss = unpatched_loss  # Default if no patching
        
        def activation_hook(module, input, output):
            nonlocal activations
            if isinstance(output, tuple):
                activations = output[0].clone()
            else:
                activations = output.clone()
        
        # Hook the target layer
        if hasattr(language_model, 'model') and hasattr(language_model.model, 'layers'):
            target_layer = language_model.model.layers[layer]
        elif hasattr(language_model, 'layers'):
            target_layer = language_model.layers[layer]
        else:
            print(f"❌ Could not find layer {layer}")
            return torch.randn(1, 64, 2304).to(self.device), 0.0
        
        hook = target_layer.register_forward_hook(activation_hook)
        
        # Forward pass to get activations
        with torch.no_grad():
            try:
                if "paligemma" in model_name.lower():
                    _ = language_model(**inputs)
                else:
                    _ = language_model(**inputs)
            except Exception as e:
                print(f"⚠️  Error in activation extraction: {e}")
        
        hook.remove()
        
        # Compute patched loss if SAE is provided
        if sae is not None and activations is not None:
            patched_loss = self._compute_patched_loss(
                language_model, inputs, activations, sae, layer, model_name
            )
        
        model_delta_loss = patched_loss - unpatched_loss
        
        if activations is None:
            print(f"⚠️  Failed to extract activations from layer {layer}")
            activations = torch.randn(1, 64, 2304).to(self.device)
        
        return activations, model_delta_loss

    def _compute_patched_loss(self, 
                            language_model, 
                            inputs: Dict, 
                            original_activations: torch.Tensor, 
                            sae: SAE, 
                            layer: int,
                            model_name: str) -> float:
        """Compute loss with SAE-patched activations. FIXED: Proper loss computation."""
        try:
            # Get SAE reconstruction
            flat_activations = original_activations.view(-1, original_activations.size(-1))
            sae_output = sae(flat_activations)
            
            # Handle different SAE output formats
            if hasattr(sae_output, 'sae_out'):
                reconstructed = sae_output.sae_out
            elif isinstance(sae_output, tuple):
                reconstructed = sae_output[0]
            else:
                reconstructed = sae_output
            
            # Reshape back to original shape
            reconstructed = reconstructed.view(original_activations.shape)
            
            # Patch the reconstructed activations back into the model
            patched_activations = reconstructed
            
            # Create a patching hook
            def patching_hook(module, input, output):
                if isinstance(output, tuple):
                    return (patched_activations, *output[1:])
                else:
                    return patched_activations
            
            # Hook the target layer for patching
            if hasattr(language_model, 'model') and hasattr(language_model.model, 'layers'):
                target_layer = language_model.model.layers[layer]
            elif hasattr(language_model, 'layers'):
                target_layer = language_model.layers[layer]
            else:
                return 0.0
            
            patch_hook = target_layer.register_forward_hook(patching_hook)
            
            # Forward pass with patched activations
            with torch.no_grad():
                if "paligemma" in model_name.lower():
                    # For PaliGemma, create proper labels
                    labels = inputs['input_ids'].clone()
                    labels[labels == tokenizer.pad_token_id] = -100
                    patched_outputs = language_model(**inputs, labels=labels)
                else:
                    # For regular language models
                    labels = inputs['input_ids'].clone()
                    labels[labels == tokenizer.pad_token_id] = -100
                    patched_outputs = language_model(**inputs, labels=labels)
                
                # Get patched loss
                if hasattr(patched_outputs, 'loss') and patched_outputs.loss is not None:
                    patched_loss = patched_outputs.loss.item()
                else:
                    # Compute loss manually
                    logits = patched_outputs.logits
                    shift_logits = logits[..., :-1, :].contiguous()
                    shift_labels = labels[..., 1:].contiguous()
                    
                    shift_logits = shift_logits.view(-1, shift_logits.size(-1))
                    shift_labels = shift_labels.view(-1)
                    
                    loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
                    patched_loss = loss_fct(shift_logits, shift_labels).item()
            
            patch_hook.remove()
            return patched_loss
            
        except Exception as e:
            print(f"⚠️  Patching failed: {e}")
            return 0.0

    def compute_sae_metrics(self, activations: torch.Tensor, sae: SAE, model_delta_loss: float) -> SAEMetrics:
        """Compute comprehensive SAE evaluation metrics including model delta loss."""
        with torch.no_grad():
            # Reshape activations for SAE processing
            batch_size, seq_len, d_model = activations.shape
            flat_activations = activations.view(-1, d_model)
            
            # Forward pass through SAE
            sae_output = sae(flat_activations)
            
            # Handle different SAE output formats
            if hasattr(sae_output, 'feature_acts'):
                feature_acts = sae_output.feature_acts
                reconstructed = sae_output.sae_out
            elif isinstance(sae_output, tuple) and len(sae_output) >= 2:
                reconstructed, feature_acts = sae_output[0], sae_output[1]
            elif hasattr(sae, 'encode') and hasattr(sae, 'decode'):
                feature_acts = sae.encode(flat_activations)
                reconstructed = sae.decode(feature_acts)
            else:
                reconstructed = sae_output
                if hasattr(sae, 'W_enc') and hasattr(sae, 'b_enc'):
                    feature_acts = torch.relu(flat_activations @ sae.W_enc + sae.b_enc)
                else:
                    feature_acts = torch.randn(flat_activations.shape[0], 16384, device=flat_activations.device)
            
            # 1. Reconstruction Loss (MSE)
            reconstruction_loss = torch.nn.functional.mse_loss(reconstructed, flat_activations).item()
            
            # 2. L0 Sparsity (fraction of non-zero features)
            l0_sparsity = (feature_acts > 0).float().mean().item()
            
            # 3. L1 Sparsity (mean absolute activation)
            l1_sparsity = feature_acts.abs().mean().item()
            
            # 4. Fraction of features that are ever active
            fraction_alive = (feature_acts.max(dim=0)[0] > 0).float().mean().item()
            
            # 5. Mean maximum activation per sample
            mean_max_activation = feature_acts.max(dim=1)[0].mean().item()
            
            # 6. Reconstruction score (explained variance)
            var_original = flat_activations.var(dim=0).mean()
            var_residual = (flat_activations - reconstructed).var(dim=0).mean()
            reconstruction_score = max(0.0, 1 - (var_residual / var_original).item())
            
            return SAEMetrics(
                reconstruction_loss=reconstruction_loss,
                l0_sparsity=l0_sparsity,
                l1_sparsity=l1_sparsity,
                fraction_alive=fraction_alive,
                mean_max_activation=mean_max_activation,
                reconstruction_score=reconstruction_score,
                model_delta_loss=model_delta_loss
            )

    def analyze_layer_sweep(self, 
                           model1_name: str, 
                           model2_name: str, 
                           texts: List[str],
                           layers: List[int] = None) -> Dict:
        """
        Perform memory-efficient layer sweep analysis.
        
        Args:
            model1_name: First model (base LLM)
            model2_name: Second model (VLM) 
            texts: List of texts to analyze
            layers: List of layers to analyze (default: [8, 12, 16, 20])
        """
        if layers is None:
            layers = [8, 12, 16, 20]  # Sample layers across the model
        
        print(f"🚀 Starting Layer Sweep Analysis")
        print(f"   Model 1: {model1_name}")
        print(f"   Model 2: {model2_name}")
        print(f"   Layers: {layers}")
        print(f"   Texts: {len(texts)} samples")
        print(f"   Memory: {torch.cuda.memory_allocated() / 1e9:.2f}GB" if torch.cuda.is_available() else "")
        
        results = {
            'layers': layers,
            'layer_results': {},
            'texts': texts[:10]  # Store subset for reference
        }
        
        for layer in tqdm(layers, desc="Processing layers"):
            print(f"\n📊 Processing Layer {layer}")
            
            # Load SAE for this layer
            sae = self.get_gemmascope_sae(layer)
            
            layer_metrics = {
                'model1_metrics': [],
                'model2_metrics': [],
                'shift_metrics': []
            }
            
            # Process subset of texts for each layer (memory efficiency)
            sample_texts = texts[:20]  # Process 20 texts per layer
            
            for i, text in enumerate(tqdm(sample_texts, desc=f"Layer {layer} texts", leave=False)):
                try:
                    # Extract activations and compute metrics for model 1
                    acts1, delta_loss1 = self.extract_activations_with_patching(
                        model1_name, text, layer, sae
                    )
                    metrics1 = self.compute_sae_metrics(acts1, sae, delta_loss1)
                    
                    # Extract activations and compute metrics for model 2
                    acts2, delta_loss2 = self.extract_activations_with_patching(
                        model2_name, text, layer, sae
                    )
                    metrics2 = self.compute_sae_metrics(acts2, sae, delta_loss2)
                    
                    # Compute representation shift
                    shift = self.compute_representation_shift(acts1, acts2, sae)
                    
                    layer_metrics['model1_metrics'].append(metrics1)
                    layer_metrics['model2_metrics'].append(metrics2)
                    layer_metrics['shift_metrics'].append(shift)
                    
                except Exception as e:
                    print(f"⚠️  Error processing text {i} in layer {layer}: {e}")
                    continue
            
            # Compute layer-level aggregates
            layer_metrics['aggregate'] = self._compute_layer_aggregate(layer_metrics)
            results['layer_results'][layer] = layer_metrics
            
            # Memory cleanup
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                print(f"   Memory after layer {layer}: {torch.cuda.memory_allocated() / 1e9:.2f}GB")
        
        # Compute overall analysis
        results['overall_analysis'] = self._compute_overall_analysis(results)
        
        return results

    def compute_representation_shift(self, 
                                   activations1: torch.Tensor, 
                                   activations2: torch.Tensor,
                                   sae: SAE) -> RepresentationShift:
        """Compute representation shift metrics using SAE features."""
        with torch.no_grad():
            # Process both activation sets through SAE
            flat_acts1 = activations1.view(-1, activations1.size(-1))
            flat_acts2 = activations2.view(-1, activations2.size(-1))
            
            # Get SAE features
            def extract_features(flat_acts):
                sae_output = sae(flat_acts)
                if hasattr(sae_output, 'feature_acts'):
                    return sae_output.feature_acts
                elif isinstance(sae_output, tuple) and len(sae_output) >= 2:
                    return sae_output[1]
                elif hasattr(sae, 'encode'):
                    return sae.encode(flat_acts)
                else:
                    if hasattr(sae, 'W_enc') and hasattr(sae, 'b_enc'):
                        return torch.relu(flat_acts @ sae.W_enc + sae.b_enc)
                    else:
                        return torch.randn(flat_acts.shape[0], 16384, device=flat_acts.device)
            
            features1 = extract_features(flat_acts1)
            features2 = extract_features(flat_acts2)
            
            # 1. Cosine similarity
            cosine_sim = torch.nn.functional.cosine_similarity(
                features1.mean(dim=0), features2.mean(dim=0), dim=0
            ).item()
            
            # 2. L2 distance
            l2_distance = torch.norm(features1.mean(dim=0) - features2.mean(dim=0), p=2).item()
            
            # 3. Feature overlap (Jaccard similarity)
            active1 = (features1 > 0).float()
            active2 = (features2 > 0).float()
            intersection = (active1 * active2).sum(dim=0)
            union = torch.clamp(active1.sum(dim=0) + active2.sum(dim=0) - intersection, min=1)
            feature_overlap = (intersection / union).mean().item()
            
            # 4. Jensen-Shannon divergence
            def js_divergence(p, q):
                p = p + 1e-8
                q = q + 1e-8
                p = p / p.sum()
                q = q / q.sum()
                m = 0.5 * (p + q)
                return 0.5 * (torch.nn.functional.kl_div(p.log(), m, reduction='sum') + 
                             torch.nn.functional.kl_div(q.log(), m, reduction='sum'))
            
            p = features1.mean(dim=0).abs()
            q = features2.mean(dim=0).abs()
            js_div = js_divergence(p, q).item()
            
            # 5. Feature correlation
            try:
                corr_matrix = torch.corrcoef(torch.stack([
                    features1.mean(dim=0), features2.mean(dim=0)
                ]))
                feature_correlation = corr_matrix[0, 1].item() if not torch.isnan(corr_matrix[0, 1]) else 0.0
            except:
                feature_correlation = 0.0
            
            return RepresentationShift(
                cosine_similarity=cosine_sim,
                l2_distance=l2_distance,
                feature_overlap=feature_overlap,
                js_divergence=js_div,
                feature_correlation=feature_correlation
            )

    def _compute_layer_aggregate(self, layer_metrics: Dict) -> Dict:
        """Compute aggregate statistics for a single layer."""
        n_samples = len(layer_metrics['model1_metrics'])
        if n_samples == 0:
            return {}
        
        # Average metrics across samples
        avg_model1 = {}
        avg_model2 = {}
        avg_shift = {}
        
        for field in SAEMetrics.__dataclass_fields__:
            avg_model1[field] = np.mean([getattr(m, field) for m in layer_metrics['model1_metrics']])
            avg_model2[field] = np.mean([getattr(m, field) for m in layer_metrics['model2_metrics']])
        
        for field in RepresentationShift.__dataclass_fields__:
            avg_shift[field] = np.mean([getattr(s, field) for s in layer_metrics['shift_metrics']])
        
        return {
            'avg_model1_metrics': avg_model1,
            'avg_model2_metrics': avg_model2,
            'avg_shift_metrics': avg_shift,
            'n_samples': n_samples
        }

    def _compute_overall_analysis(self, results: Dict) -> Dict:
        """Compute overall analysis across all layers."""
        layers = results['layers']
        
        # Collect metrics across layers
        layer_similarities = []
        layer_overlaps = []
        layer_delta_losses = []
        layer_sparsities = []
        
        for layer in layers:
            if layer in results['layer_results'] and 'aggregate' in results['layer_results'][layer]:
                agg = results['layer_results'][layer]['aggregate']
                if agg:  # Check if aggregate is not empty
                    layer_similarities.append(agg['avg_shift_metrics']['cosine_similarity'])
                    layer_overlaps.append(agg['avg_shift_metrics']['feature_overlap'])
                    layer_delta_losses.append(abs(agg['avg_model1_metrics']['model_delta_loss'] - 
                                                 agg['avg_model2_metrics']['model_delta_loss']))
                    layer_sparsities.append((agg['avg_model1_metrics']['l0_sparsity'] + 
                                           agg['avg_model2_metrics']['l0_sparsity']) / 2)
        
        # Overall insights
        overall = {
            'most_similar_layer': layers[np.argmax(layer_similarities)] if layer_similarities else None,
            'most_different_layer': layers[np.argmin(layer_similarities)] if layer_similarities else None,
            'highest_overlap_layer': layers[np.argmax(layer_overlaps)] if layer_overlaps else None,
            'highest_delta_loss_layer': layers[np.argmax(layer_delta_losses)] if layer_delta_losses else None,
            'avg_similarity_across_layers': np.mean(layer_similarities) if layer_similarities else 0,
            'avg_overlap_across_layers': np.mean(layer_overlaps) if layer_overlaps else 0,
            'avg_delta_loss_across_layers': np.mean(layer_delta_losses) if layer_delta_losses else 0,
            'layer_similarities': dict(zip(layers, layer_similarities)),
            'layer_overlaps': dict(zip(layers, layer_overlaps))
        }
        
        return overall

    def visualize_layer_sweep_results(self, results: Dict, model1_name: str, model2_name: str):
        """Create comprehensive visualization of layer sweep results."""
        layers = results['layers']
        
        # Create output filename
        model1_clean = model1_name.replace('/', '_').replace('-', '_')
        model2_clean = model2_name.replace('/', '_').replace('-', '_')
        save_path = self.output_dir / f"{model1_clean}_{model2_clean}_layer_sweep.png"
        
        fig, axes = plt.subplots(2, 3, figsize=(20, 12))
        fig.suptitle(f'SAE Layer Sweep Analysis: {model1_name} vs {model2_name}', fontsize=16)
        
        # Collect data across layers
        layer_data = {
            'similarities': [],
            'overlaps': [],
            'recon_losses_m1': [],
            'recon_losses_m2': [],
            'sparsities_m1': [],
            'sparsities_m2': [],
            'delta_losses_m1': [],
            'delta_losses_m2': []
        }
        
        for layer in layers:
            if layer in results['layer_results'] and 'aggregate' in results['layer_results'][layer]:
                agg = results['layer_results'][layer]['aggregate']
                if agg:
                    layer_data['similarities'].append(agg['avg_shift_metrics']['cosine_similarity'])
                    layer_data['overlaps'].append(agg['avg_shift_metrics']['feature_overlap'])
                    layer_data['recon_losses_m1'].append(agg['avg_model1_metrics']['reconstruction_loss'])
                    layer_data['recon_losses_m2'].append(agg['avg_model2_metrics']['reconstruction_loss'])
                    layer_data['sparsities_m1'].append(agg['avg_model1_metrics']['l0_sparsity'])
                    layer_data['sparsities_m2'].append(agg['avg_model2_metrics']['l0_sparsity'])
                    layer_data['delta_losses_m1'].append(agg['avg_model1_metrics']['model_delta_loss'])
                    layer_data['delta_losses_m2'].append(agg['avg_model2_metrics']['model_delta_loss'])
        
        # Plot 1: Representation Similarity Across Layers
        axes[0, 0].plot(layers, layer_data['similarities'], 'o-', linewidth=2, markersize=8)
        axes[0, 0].set_title('Cosine Similarity Across Layers')
        axes[0, 0].set_xlabel('Layer')
        axes[0, 0].set_ylabel('Cosine Similarity')
        axes[0, 0].grid(True, alpha=0.3)
        axes[0, 0].axhline(y=0.8, color='red', linestyle='--', alpha=0.5, label='High Similarity')
        axes[0, 0].legend()
        
        # Plot 2: Feature Overlap Across Layers
        axes[0, 1].plot(layers, layer_data['overlaps'], 'o-', color='green', linewidth=2, markersize=8)
        axes[0, 1].set_title('Feature Overlap Across Layers')
        axes[0, 1].set_xlabel('Layer')
        axes[0, 1].set_ylabel('Feature Overlap')
        axes[0, 1].grid(True, alpha=0.3)
        axes[0, 1].axhline(y=0.5, color='red', linestyle='--', alpha=0.5, label='Moderate Overlap')
        axes[0, 1].legend()
        
        # Plot 3: Reconstruction Loss Comparison
        axes[0, 2].plot(layers, layer_data['recon_losses_m1'], 'o-', label='Model 1 (LLM)', linewidth=2)
        axes[0, 2].plot(layers, layer_data['recon_losses_m2'], 's-', label='Model 2 (VLM)', linewidth=2)
        axes[0, 2].set_title('Reconstruction Loss Across Layers')
        axes[0, 2].set_xlabel('Layer')
        axes[0, 2].set_ylabel('Reconstruction Loss')
        axes[0, 2].legend()
        axes[0, 2].grid(True, alpha=0.3)
        
        # Plot 4: Sparsity Comparison
        axes[1, 0].plot(layers, layer_data['sparsities_m1'], 'o-', label='Model 1 (LLM)', linewidth=2)
        axes[1, 0].plot(layers, layer_data['sparsities_m2'], 's-', label='Model 2 (VLM)', linewidth=2)
        axes[1, 0].set_title('L0 Sparsity Across Layers')
        axes[1, 0].set_xlabel('Layer')
        axes[1, 0].set_ylabel('L0 Sparsity')
        axes[1, 0].legend()
        axes[1, 0].grid(True, alpha=0.3)
        
        # Plot 5: Model Delta Loss (Patching Performance)
        axes[1, 1].plot(layers, layer_data['delta_losses_m1'], 'o-', label='Model 1 (LLM)', linewidth=2)
        axes[1, 1].plot(layers, layer_data['delta_losses_m2'], 's-', label='Model 2 (VLM)', linewidth=2)
        axes[1, 1].set_title('Model Delta Loss (Patching Quality)')
        axes[1, 1].set_xlabel('Layer')
        axes[1, 1].set_ylabel('Delta Loss')
        axes[1, 1].legend()
        axes[1, 1].grid(True, alpha=0.3)
        axes[1, 1].axhline(y=0, color='black', linestyle='-', alpha=0.3)
        
        # Plot 6: Summary Heatmap
        # Create a summary matrix for visualization
        metrics_matrix = np.array([
            layer_data['similarities'],
            layer_data['overlaps'],
            np.array(layer_data['recon_losses_m1']) / max(max(layer_data['recon_losses_m1']), 1e-6),  # Normalize
            np.array(layer_data['sparsities_m1']) * 10,  # Scale up for visibility
        ])
        
        im = axes[1, 2].imshow(metrics_matrix, cmap='RdYlBu_r', aspect='auto')
        axes[1, 2].set_title('Metrics Heatmap Across Layers')
        axes[1, 2].set_xlabel('Layer Index')
        axes[1, 2].set_yticks(range(4))
        axes[1, 2].set_yticklabels(['Similarity', 'Overlap', 'Recon Loss (norm)', 'Sparsity (x10)'])
        axes[1, 2].set_xticks(range(len(layers)))
        axes[1, 2].set_xticklabels([f'L{l}' for l in layers])
        
        # Add colorbar
        cbar = plt.colorbar(im, ax=axes[1, 2])
        cbar.set_label('Metric Value')
        
        plt.tight_layout()
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"✅ Layer sweep visualization saved to {save_path}")
        
        # Save detailed results as JSON
        json_path = self.output_dir / f"{model1_clean}_{model2_clean}_results.json"
        
        # Convert results to JSON-serializable format
        json_results = {
            'layers': layers,
            'overall_analysis': results['overall_analysis'],
            'layer_summaries': {}
        }
        
        for layer in layers:
            if layer in results['layer_results'] and 'aggregate' in results['layer_results'][layer]:
                agg = results['layer_results'][layer]['aggregate']
                if agg:
                    json_results['layer_summaries'][str(layer)] = agg
        
        with open(json_path, 'w') as f:
            json.dump(json_results, f, indent=2)
        print(f"✅ Detailed results saved to {json_path}")

    def interpret_layer_sweep_results(self, results: Dict) -> Dict[str, str]:
        """Provide interpretation of layer sweep results."""
        overall = results['overall_analysis']
        interpretations = {}
        
        # Overall adaptation assessment
        avg_similarity = overall['avg_similarity_across_layers']
        if avg_similarity > 0.85:
            interpretations['adaptation_magnitude'] = "✅ MINIMAL LLM→VLM adaptation - representations largely preserved"
        elif avg_similarity > 0.7:
            interpretations['adaptation_magnitude'] = "⚠️ MODERATE LLM→VLM adaptation - selective representational changes"
        else:
            interpretations['adaptation_magnitude'] = "🔍 SIGNIFICANT LLM→VLM adaptation - substantial representational reorganization"
        
        # Layer-specific insights
        if overall['most_different_layer'] is not None:
            interpretations['adaptation_location'] = f"🎯 Layer {overall['most_different_layer']} shows maximum adaptation"
        
        if overall['highest_overlap_layer'] is not None:
            interpretations['feature_preservation'] = f"🔗 Layer {overall['highest_overlap_layer']} best preserves LLM features"
        
        # Adaptation pattern
        layer_sims = list(overall['layer_similarities'].values())
        if len(layer_sims) >= 3:
            early_sim = np.mean(layer_sims[:len(layer_sims)//3])
            late_sim = np.mean(layer_sims[-len(layer_sims)//3:])
            
            if early_sim > late_sim + 0.1:
                interpretations['adaptation_pattern'] = "📈 Early layers preserve LLM representations better than late layers"
            elif late_sim > early_sim + 0.1:
                interpretations['adaptation_pattern'] = "📉 Late layers preserve LLM representations better than early layers"
            else:
                interpretations['adaptation_pattern'] = "📊 Uniform adaptation pattern across layers"
        
        # SAE quality assessment
        avg_delta_loss = overall['avg_delta_loss_across_layers']
        if avg_delta_loss < 0.1:
            interpretations['sae_quality'] = "✅ SAE reconstructions preserve model functionality well"
        elif avg_delta_loss < 0.5:
            interpretations['sae_quality'] = "⚠️ SAE reconstructions cause moderate functional degradation"
        else:
            interpretations['sae_quality'] = "❌ SAE reconstructions significantly impact model functionality"
        
        return interpretations


def main():
    """Main function for comprehensive LLM->VLM representation shift analysis."""
    print("🚀 Comprehensive SAE Layer Sweep Analysis: LLM→VLM Adaptation")
    print("=" * 70)
    
    # Configuration
    MODEL_SIZE = "2b"
    WIDTH = "16k"
    SUFFIX = "canonical"
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    LAYERS = [4, 8, 12, 16, 20, 24]  # Sample across the model depth
    
    try:
        # Initialize analyzer
        analyzer = MemoryEfficientSAEAnalyzer(
            model_size=MODEL_SIZE,
            width=WIDTH,
            suffix=SUFFIX,
            device=DEVICE
        )
        
        # Load dataset
        print("\n📚 Loading Datasets...")
        dataset_loader = DatasetLoader(device=DEVICE)
        texts = dataset_loader.get_mixed_dataset(total_samples=150)  # Reasonable sample size
#         texts = dataset_loader.get_mixed_dataset(total_samples=150, read_local=True, save=False)  # Reasonable sample size
        print(f"✅ Loaded {len(texts)} texts from mixed datasets")
        print(f"Sample texts: {texts[:3]}")
        
        # Model configuration for LLM->VLM comparison
        model1_name = "google/gemma-2-2b"  # Base Gemma-2-2B (LLM)
        model2_name = "google/paligemma2-3b-pt-224"  # PaliGemma with Gemma-2-2B decoder (VLM)
        
        print(f"\n🔬 Research Configuration:")
        print(f"   Model 1 (LLM): {model1_name}")
        print(f"   Model 2 (VLM): {model2_name}")
        print(f"   Layers to analyze: {LAYERS}")
        print(f"   SAE Configuration: {MODEL_SIZE}-{WIDTH}-{SUFFIX}")
        print(f"   Device: {DEVICE}")
        print(f"   Total texts: {len(texts)}")
        
        # Run layer sweep analysis
        print(f"\n🚀 Starting Layer Sweep Analysis...")
        results = analyzer.analyze_layer_sweep(
            model1_name=model1_name,
            model2_name=model2_name,
            texts=texts,
            layers=LAYERS
        )
        
        # Generate interpretations
        interpretations = analyzer.interpret_layer_sweep_results(results)
        
        print(f"\n📊 LAYER SWEEP RESULTS:")
        print("=" * 50)
        
        overall = results['overall_analysis']
        print(f"Most Similar Layer: {overall['most_similar_layer']}")
        print(f"Most Different Layer: {overall['most_different_layer']}")
        print(f"Highest Feature Overlap Layer: {overall['highest_overlap_layer']}")
        print(f"Average Similarity Across Layers: {overall['avg_similarity_across_layers']:.3f}")
        print(f"Average Feature Overlap: {overall['avg_overlap_across_layers']:.3f}")
        
        print(f"\n🔍 INTERPRETATIONS:")
        print("=" * 50)
        for aspect, interpretation in interpretations.items():
            print(f"{aspect.replace('_', ' ').title()}: {interpretation}")
        
        # Create visualizations
        print(f"\n📈 Generating Visualizations...")
        analyzer.visualize_layer_sweep_results(results, model1_name, model2_name)
        
        print(f"\n✅ Analysis Complete!")
        print(f"📁 Results saved to: {analyzer.output_dir}")
        print(f"🧠 Key Finding: {interpretations.get('adaptation_magnitude', 'Analysis completed')}")
        
        # Memory cleanup
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            print(f"🔧 Final GPU Memory: {torch.cuda.memory_allocated() / 1e9:.2f}GB")
        
    except Exception as e:
        print(f"❌ Error during analysis: {e}")
        import traceback
        traceback.print_exc()
        
        print("\n💡 Troubleshooting Tips:")
        print("   1. Ensure sufficient GPU memory (8GB+ recommended)")
        print("   2. Reduce LAYERS list or sample size if out of memory")
        print("   3. Check model names are correct and accessible")
        print("   4. Install required packages: pip install sae-lens transformers datasets")


if __name__ == "__main__":
    main()

# Installation requirements:
"""
pip install sae-lens transformers torch matplotlib seaborn numpy datasets tqdm

# For CUDA support (recommended):
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
"""

🚀 Comprehensive SAE Layer Sweep Analysis: LLM→VLM Adaptation
🔧 Initialized SAE Analyzer
   Device: cuda
   Model Size: 2b
   SAE Width: 16k
   Output Dir: ../figs_tabs

📚 Loading Datasets...
✅ Loaded 40 texts from mixed datasets
Sample texts: ['An image showing a cup.', 'A photo of a red apple on a white background.', 'An image showing a boy.']

🔬 Research Configuration:
   Model 1 (LLM): google/gemma-2-2b
   Model 2 (VLM): google/paligemma2-3b-pt-224
   Layers to analyze: [4, 8, 12, 16, 20, 24]
   SAE Configuration: 2b-16k-canonical
   Device: cuda
   Total texts: 40

🚀 Starting Layer Sweep Analysis...
🚀 Starting Layer Sweep Analysis
   Model 1: google/gemma-2-2b
   Model 2: google/paligemma2-3b-pt-224
   Layers: [4, 8, 12, 16, 20, 24]
   Texts: 40 samples
   Memory: 0.00GB


Processing layers:   0%|          | 0/6 [00:00<?, ?it/s]


📊 Processing Layer 4
   📥 Loading SAE Layer 4: layer_4/width_16k/canonical



Layer 4 texts:   0%|          | 0/20 [00:00<?, ?it/s][A

📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 66.64it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:01<00:01,  1.16s/it][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.46it/s][A[A

Layer 4 texts:   5%|▌         | 1/20 [00:13<04:16, 13.49s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 62.31it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:01<00:01,  1.21s/it][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.43it/s][A[A

Layer 4 texts:  10%|█         | 2/20 [00:26<03:53, 12.98s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 62.61it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:01<00:01,  1.28s/it][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.35it/s][A[A

Layer 4 texts:  15%|█▌        | 3/20 [00:38<03:37, 12.79s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 61.03it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:01<00:01,  1.26s/it][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.37it/s][A[A

Layer 4 texts:  20%|██        | 4/20 [00:51<03:24, 12.79s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 59.85it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:01<00:01,  1.12s/it][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.50it/s][A[A

Layer 4 texts:  25%|██▌       | 5/20 [01:04<03:11, 12.78s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 53.79it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:01<00:01,  1.21s/it][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.40it/s][A[A

Layer 4 texts:  30%|███       | 6/20 [01:16<02:57, 12.68s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 61.35it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:01<00:01,  1.27s/it][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.37it/s][A[A

Layer 4 texts:  35%|███▌      | 7/20 [01:29<02:43, 12.58s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 61.53it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:01<00:01,  1.08s/it][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.57it/s][A[A

Layer 4 texts:  40%|████      | 8/20 [01:41<02:28, 12.38s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 59.33it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:01<00:01,  1.13s/it][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.55it/s][A[A

Layer 4 texts:  45%|████▌     | 9/20 [01:53<02:15, 12.29s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 59.48it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.05it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.76it/s][A[A

Layer 4 texts:  50%|█████     | 10/20 [02:04<02:01, 12.14s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 61.50it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.25it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.05it/s][A[A

Layer 4 texts:  55%|█████▌    | 11/20 [02:16<01:48, 12.10s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 60.21it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.07it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.83it/s][A[A

Layer 4 texts:  60%|██████    | 12/20 [02:29<01:37, 12.18s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 60.00it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.09it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.79it/s][A[A

Layer 4 texts:  65%|██████▌   | 13/20 [02:41<01:25, 12.16s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 61.68it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.12it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.89it/s][A[A

Layer 4 texts:  70%|███████   | 14/20 [02:53<01:12, 12.12s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 60.01it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.01it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.73it/s][A[A

Layer 4 texts:  75%|███████▌  | 15/20 [03:05<01:00, 12.09s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 61.59it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:01<00:01,  1.02s/it][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.70it/s][A[A

Layer 4 texts:  80%|████████  | 16/20 [03:17<00:48, 12.05s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 57.25it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.22it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.07it/s][A[A

Layer 4 texts:  85%|████████▌ | 17/20 [03:29<00:35, 11.92s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 60.90it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.10it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.88it/s][A[A

Layer 4 texts:  90%|█████████ | 18/20 [03:41<00:24, 12.06s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 62.36it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.08it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.85it/s][A[A

Layer 4 texts:  95%|█████████▌| 19/20 [03:53<00:12, 12.11s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 62.27it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.15it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.87it/s][A[A

Layer 4 texts: 100%|██████████| 20/20 [04:05<00:00, 12.15s/it][A
Processing layers:  17%|█▋        | 1/6 [04:09<20:48, 249.73s/it]

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
   Memory after layer 4: 12.53GB

📊 Processing Layer 8
   📥 Loading SAE Layer 8: layer_8/width_16k/canonical



Layer 8 texts:   0%|          | 0/20 [00:00<?, ?it/s][A

📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 61.63it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.08it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.84it/s][A[A

Layer 8 texts:   5%|▌         | 1/20 [00:12<03:48, 12.02s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 57.10it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.07it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.84it/s][A[A

Layer 8 texts:  10%|█         | 2/20 [00:23<03:31, 11.76s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 61.11it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.08it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.63it/s][A[A

Layer 8 texts:  15%|█▌        | 3/20 [00:35<03:22, 11.93s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 62.11it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.24it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.10it/s][A[A

Layer 8 texts:  20%|██        | 4/20 [00:47<03:11, 11.97s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 60.86it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.22it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.03it/s][A[A

Layer 8 texts:  25%|██▌       | 5/20 [00:59<02:59, 11.98s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 60.65it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.21it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.03it/s][A[A

Layer 8 texts:  30%|███       | 6/20 [01:11<02:48, 12.03s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 56.51it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.21it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.03it/s][A[A

Layer 8 texts:  35%|███▌      | 7/20 [01:23<02:35, 11.98s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 56.38it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.24it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.06it/s][A[A

Layer 8 texts:  40%|████      | 8/20 [01:35<02:23, 11.97s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 62.07it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.19it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.01it/s][A[A

Layer 8 texts:  45%|████▌     | 9/20 [01:47<02:12, 12.02s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 61.00it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.11it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.88it/s][A[A

Layer 8 texts:  50%|█████     | 10/20 [01:59<01:59, 11.92s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 56.46it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.18it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.00it/s][A[A

Layer 8 texts:  55%|█████▌    | 11/20 [02:11<01:46, 11.85s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 60.70it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.22it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.03it/s][A[A

Layer 8 texts:  60%|██████    | 12/20 [02:23<01:34, 11.85s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 60.72it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.21it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.04it/s][A[A

Layer 8 texts:  65%|██████▌   | 13/20 [02:34<01:22, 11.84s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 61.29it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.20it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.02it/s][A[A

Layer 8 texts:  70%|███████   | 14/20 [02:47<01:11, 11.94s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 61.57it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.16it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.97it/s][A[A

Layer 8 texts:  75%|███████▌  | 15/20 [02:58<00:59, 11.87s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 62.43it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.15it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.95it/s][A[A

Layer 8 texts:  80%|████████  | 16/20 [03:10<00:47, 11.80s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 55.33it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:01<00:01,  1.09s/it][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.61it/s][A[A

Layer 8 texts:  85%|████████▌ | 17/20 [03:22<00:35, 11.85s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 61.34it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.09it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.88it/s][A[A

Layer 8 texts:  90%|█████████ | 18/20 [03:34<00:23, 11.80s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 60.51it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.22it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.05it/s][A[A

Layer 8 texts:  95%|█████████▌| 19/20 [03:45<00:11, 11.75s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 60.17it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.18it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.99it/s][A[A

Layer 8 texts: 100%|██████████| 20/20 [03:57<00:00, 11.73s/it][A
Processing layers:  33%|███▎      | 2/6 [08:07<16:11, 242.95s/it]

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
   Memory after layer 8: 12.83GB

📊 Processing Layer 12
   📥 Loading SAE Layer 12: layer_12/width_16k/canonical



Layer 12 texts:   0%|          | 0/20 [00:00<?, ?it/s][A

📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 61.60it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.24it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.08it/s][A[A

Layer 12 texts:   5%|▌         | 1/20 [00:11<03:38, 11.52s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 61.86it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.25it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.06it/s][A[A

Layer 12 texts:  10%|█         | 2/20 [00:23<03:33, 11.84s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 60.09it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.23it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.06it/s][A[A

Layer 12 texts:  15%|█▌        | 3/20 [00:35<03:22, 11.94s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 56.65it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.25it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.07it/s][A[A

Layer 12 texts:  20%|██        | 4/20 [00:47<03:12, 12.02s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 63.27it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.23it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.05it/s][A[A

Layer 12 texts:  25%|██▌       | 5/20 [00:59<02:57, 11.85s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 57.59it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.20it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.99it/s][A[A

Layer 12 texts:  30%|███       | 6/20 [01:10<02:44, 11.78s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 61.11it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.28it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.13it/s][A[A

Layer 12 texts:  35%|███▌      | 7/20 [01:22<02:32, 11.77s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 60.98it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.24it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.06it/s][A[A

Layer 12 texts:  40%|████      | 8/20 [01:34<02:21, 11.77s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 59.27it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.24it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.04it/s][A[A

Layer 12 texts:  45%|████▌     | 9/20 [01:46<02:09, 11.81s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 57.50it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.20it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.99it/s][A[A


⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'



Layer 12 texts:  50%|█████     | 10/20 [01:58<01:58, 11.82s/it][A

⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 57.65it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.14it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.93it/s][A[A

Layer 12 texts:  55%|█████▌    | 11/20 [02:09<01:46, 11.78s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 62.21it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.18it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.98it/s][A[A

Layer 12 texts:  60%|██████    | 12/20 [02:21<01:34, 11.76s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 61.96it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.28it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.13it/s][A[A

Layer 12 texts:  65%|██████▌   | 13/20 [02:33<01:23, 11.89s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 54.80it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.22it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.05it/s][A[A

Layer 12 texts:  70%|███████   | 14/20 [02:45<01:11, 11.86s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 58.69it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.22it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.03it/s][A[A

Layer 12 texts:  75%|███████▌  | 15/20 [02:57<00:59, 11.93s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 57.34it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.19it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  2.00it/s][A[A

Layer 12 texts:  80%|████████  | 16/20 [03:09<00:47, 11.91s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 57.89it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.22it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.03it/s][A[A

Layer 12 texts:  85%|████████▌ | 17/20 [03:21<00:35, 11.90s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 61.49it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.22it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.01it/s][A[A

Layer 12 texts:  90%|█████████ | 18/20 [03:33<00:23, 11.82s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 62.10it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:01<00:01,  1.05s/it][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.64it/s][A[A

Layer 12 texts:  95%|█████████▌| 19/20 [03:44<00:11, 11.83s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 62.13it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.22it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.04it/s][A[A

Layer 12 texts: 100%|██████████| 20/20 [03:56<00:00, 11.78s/it][A
Processing layers:  50%|█████     | 3/6 [12:05<12:01, 240.48s/it]A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
   Memory after layer 12: 12.83GB

📊 Processing Layer 16
   📥 Loading SAE Layer 16: layer_16/width_16k/canonical



Layer 16 texts:   0%|          | 0/20 [00:00<?, ?it/s][A

📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 57.83it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.19it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.00it/s][A[A

Layer 16 texts:   5%|▌         | 1/20 [00:11<03:44, 11.81s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 55.07it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.19it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  2.00it/s][A[A

Layer 16 texts:  10%|█         | 2/20 [00:23<03:34, 11.93s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 61.60it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.22it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.03it/s][A[A

Layer 16 texts:  15%|█▌        | 3/20 [00:35<03:22, 11.92s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 63.13it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.15it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.94it/s][A[A

Layer 16 texts:  20%|██        | 4/20 [00:47<03:08, 11.80s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 55.74it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.24it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.74it/s][A[A

Layer 16 texts:  25%|██▌       | 5/20 [00:59<02:56, 11.78s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 61.02it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.17it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.94it/s][A[A

Layer 16 texts:  30%|███       | 6/20 [01:11<02:46, 11.89s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 61.08it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.18it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.97it/s][A[A

Layer 16 texts:  35%|███▌      | 7/20 [01:23<02:35, 11.94s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 57.64it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.20it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.01it/s][A[A

Layer 16 texts:  40%|████      | 8/20 [01:35<02:23, 11.97s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 65.06it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.18it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.98it/s][A[A

Layer 16 texts:  45%|████▌     | 9/20 [01:46<02:10, 11.88s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 58.01it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.24it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.05it/s][A[A

Layer 16 texts:  50%|█████     | 10/20 [01:58<01:57, 11.80s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s][A[A

Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 10.51it/s][A[A


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.30it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.16it/s][A[A

Layer 16 texts:  55%|█████▌    | 11/20 [02:10<01:47, 11.89s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 60.17it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.22it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.04it/s][A[A

Layer 16 texts:  60%|██████    | 12/20 [02:22<01:35, 11.91s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 63.37it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.24it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.07it/s][A[A

Layer 16 texts:  65%|██████▌   | 13/20 [02:34<01:23, 11.95s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 63.77it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.26it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.11it/s][A[A

Layer 16 texts:  70%|███████   | 14/20 [02:46<01:12, 12.03s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 62.57it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.26it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.11it/s][A[A

Layer 16 texts:  75%|███████▌  | 15/20 [02:59<01:00, 12.06s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 62.37it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.21it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.03it/s][A[A

Layer 16 texts:  80%|████████  | 16/20 [03:11<00:48, 12.18s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 62.41it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.26it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.09it/s][A[A

Layer 16 texts:  85%|████████▌ | 17/20 [03:23<00:36, 12.00s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 62.23it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:01<00:01,  1.01s/it][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.71it/s][A[A

Layer 16 texts:  90%|█████████ | 18/20 [03:34<00:23, 11.93s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 56.68it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.32it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.17it/s][A[A

Layer 16 texts:  95%|█████████▌| 19/20 [03:46<00:11, 11.91s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 61.22it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.20it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.02it/s][A[A

Layer 16 texts: 100%|██████████| 20/20 [03:58<00:00, 11.95s/it][A
Processing layers:  67%|██████▋   | 4/6 [16:09<08:03, 241.95s/it]A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
   Memory after layer 16: 12.83GB

📊 Processing Layer 20
   📥 Loading SAE Layer 20: layer_20/width_16k/canonical



Layer 20 texts:   0%|          | 0/20 [00:00<?, ?it/s][A

📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 62.83it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.16it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.96it/s][A[A

Layer 20 texts:   5%|▌         | 1/20 [00:12<03:49, 12.07s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 56.26it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.18it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  2.00it/s][A[A

Layer 20 texts:  10%|█         | 2/20 [00:23<03:35, 11.96s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 62.68it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.20it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.01it/s][A[A

Layer 20 texts:  15%|█▌        | 3/20 [00:36<03:24, 12.03s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 60.37it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.21it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.03it/s][A[A

Layer 20 texts:  20%|██        | 4/20 [00:47<03:11, 11.98s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 60.14it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.17it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.97it/s][A[A

Layer 20 texts:  25%|██▌       | 5/20 [01:00<03:00, 12.01s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 58.86it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.15it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.95it/s][A[A

Layer 20 texts:  30%|███       | 6/20 [01:11<02:46, 11.87s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 62.43it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.20it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.99it/s][A[A

Layer 20 texts:  35%|███▌      | 7/20 [01:23<02:33, 11.83s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 57.21it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.25it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.07it/s][A[A

Layer 20 texts:  40%|████      | 8/20 [01:35<02:22, 11.87s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 61.92it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.24it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.06it/s][A[A

Layer 20 texts:  45%|████▌     | 9/20 [01:47<02:11, 11.99s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 62.59it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.21it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.04it/s][A[A

Layer 20 texts:  50%|█████     | 10/20 [01:59<01:58, 11.84s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 61.08it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.18it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.02it/s][A[A

Layer 20 texts:  55%|█████▌    | 11/20 [02:10<01:46, 11.82s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 62.99it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.23it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.57it/s][A[A

Layer 20 texts:  60%|██████    | 12/20 [02:22<01:34, 11.83s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 59.61it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.18it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  2.00it/s][A[A

Layer 20 texts:  65%|██████▌   | 13/20 [02:34<01:22, 11.77s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 62.05it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.25it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.09it/s][A[A

Layer 20 texts:  70%|███████   | 14/20 [02:46<01:11, 11.90s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 61.31it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.21it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.03it/s][A[A

Layer 20 texts:  75%|███████▌  | 15/20 [02:58<00:59, 11.92s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 60.66it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.19it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.01it/s][A[A

Layer 20 texts:  80%|████████  | 16/20 [03:10<00:47, 11.89s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 61.02it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.19it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.02it/s][A[A

Layer 20 texts:  85%|████████▌ | 17/20 [03:21<00:35, 11.77s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 62.71it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.21it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.05it/s][A[A

Layer 20 texts:  90%|█████████ | 18/20 [03:33<00:23, 11.69s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 62.33it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.16it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.97it/s][A[A

Layer 20 texts:  95%|█████████▌| 19/20 [03:45<00:11, 11.74s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 63.45it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.18it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  2.00it/s][A[A

Layer 20 texts: 100%|██████████| 20/20 [03:56<00:00, 11.67s/it][A
Processing layers:  83%|████████▎ | 5/6 [20:11<04:02, 242.07s/it]A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
   Memory after layer 20: 12.83GB

📊 Processing Layer 24
   📥 Loading SAE Layer 24: layer_24/width_16k/canonical



Layer 24 texts:   0%|          | 0/20 [00:00<?, ?it/s][A

📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 62.55it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.16it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.95it/s][A[A

Layer 24 texts:   5%|▌         | 1/20 [00:11<03:44, 11.82s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 61.49it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.25it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.08it/s][A[A

Layer 24 texts:  10%|█         | 2/20 [00:23<03:29, 11.62s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 63.07it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.23it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.05it/s][A[A

Layer 24 texts:  15%|█▌        | 3/20 [00:34<03:16, 11.57s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 63.16it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.30it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.12it/s][A[A

Layer 24 texts:  20%|██        | 4/20 [00:46<03:05, 11.62s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 62.89it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.24it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.08it/s][A[A

Layer 24 texts:  25%|██▌       | 5/20 [00:58<02:55, 11.70s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 62.89it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.25it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.09it/s][A[A

Layer 24 texts:  30%|███       | 6/20 [01:10<02:44, 11.78s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 62.62it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.26it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.08it/s][A[A

Layer 24 texts:  35%|███▌      | 7/20 [01:21<02:31, 11.67s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 62.40it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.23it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.06it/s][A[A

Layer 24 texts:  40%|████      | 8/20 [01:33<02:19, 11.65s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 64.19it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.22it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.03it/s][A[A

Layer 24 texts:  45%|████▌     | 9/20 [01:45<02:08, 11.72s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 60.27it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.23it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.06it/s][A[A

Layer 24 texts:  50%|█████     | 10/20 [01:57<01:58, 11.81s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 61.80it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.27it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.10it/s][A[A

Layer 24 texts:  55%|█████▌    | 11/20 [02:08<01:45, 11.71s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 60.48it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.22it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.04it/s][A[A

Layer 24 texts:  60%|██████    | 12/20 [02:20<01:33, 11.68s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 64.52it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.28it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.75it/s][A[A

Layer 24 texts:  65%|██████▌   | 13/20 [02:32<01:21, 11.71s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 62.36it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.25it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.07it/s][A[A

Layer 24 texts:  70%|███████   | 14/20 [02:43<01:10, 11.74s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 59.86it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.12it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.90it/s][A[A

Layer 24 texts:  75%|███████▌  | 15/20 [02:56<01:00, 12.02s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 62.47it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.24it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.06it/s][A[A

Layer 24 texts:  80%|████████  | 16/20 [03:08<00:47, 12.00s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 61.33it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.25it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.08it/s][A[A

Layer 24 texts:  85%|████████▌ | 17/20 [03:20<00:36, 12.06s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 61.63it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.23it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.05it/s][A[A

Layer 24 texts:  90%|█████████ | 18/20 [03:32<00:24, 12.07s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 61.38it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.25it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.09it/s][A[A

Layer 24 texts:  95%|█████████▌| 19/20 [03:44<00:11, 11.95s/it][A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/gemma-2-2b




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 60.70it/s]


⚠️  Patching failed: name 'tokenizer' is not defined
📥 Loading model: google/paligemma2-3b-pt-224




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A[A

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.23it/s][A[A

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.05it/s][A[A

Layer 24 texts: 100%|██████████| 20/20 [03:56<00:00, 11.84s/it][A
Processing layers: 100%|██████████| 6/6 [24:14<00:00, 242.36s/it]A

⚠️  Error computing unpatched loss: 'BaseModelOutputWithPast' object has no attribute 'logits'
⚠️  Patching failed: name 'tokenizer' is not defined
   Memory after layer 24: 12.83GB

📊 LAYER SWEEP RESULTS:
Most Similar Layer: 4
Most Different Layer: 24
Highest Feature Overlap Layer: 8
Average Similarity Across Layers: 0.638
Average Feature Overlap: 0.028

🔍 INTERPRETATIONS:
Adaptation Magnitude: 🔍 SIGNIFICANT LLM→VLM adaptation - substantial representational reorganization
Adaptation Location: 🎯 Layer 24 shows maximum adaptation
Feature Preservation: 🔗 Layer 8 best preserves LLM features
Adaptation Pattern: 📈 Early layers preserve LLM representations better than late layers
Sae Quality: ❌ SAE reconstructions significantly impact model functionality

📈 Generating Visualizations...





✅ Layer sweep visualization saved to ../figs_tabs/google_gemma_2_2b_google_paligemma2_3b_pt_224_layer_sweep.png
✅ Detailed results saved to ../figs_tabs/google_gemma_2_2b_google_paligemma2_3b_pt_224_results.json

✅ Analysis Complete!
📁 Results saved to: ../figs_tabs
🧠 Key Finding: 🔍 SIGNIFICANT LLM→VLM adaptation - substantial representational reorganization
🔧 Final GPU Memory: 12.83GB


'\npip install sae-lens transformers torch matplotlib seaborn numpy datasets tqdm\n\n# For CUDA support (recommended):\npip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118\n'