In [15]:
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from datasets import load_dataset
from transformers import AutoModel, AutoTokenizer
from tqdm.auto import tqdm
import numpy as np
from typing import Optional, List, Dict
import matplotlib.pyplot as plt


class MemoryCachedTransformerDataset(Dataset):
    """Dataset that extracts one transformer layer's activations and caches them in memory.

    On construction every text sample is tokenized (truncated / padded to
    ``max_length``), run through the model, and the output of the hooked layer
    is captured with a forward hook. The per-sample results are stacked into a
    single tensor (one ``[seq_len, hidden_size]`` slice per sample) which is
    what ``__getitem__`` indexes into. The model and tokenizer are deleted once
    the cache is built.

    Notes:
        * The cached tensor stays on ``device`` (it is never moved to CPU).
        * Samples are padded to ``max_length`` and the hook output is stored
          unmasked, so activations at padding positions are cached too.

    Args:
        dataset_name: Name passed to ``datasets.load_dataset``.
        dataset_config: Configuration name passed to ``load_dataset``.
        split: Dataset split to load.
        model_name: Checkpoint name for ``AutoModel`` / ``AutoTokenizer``.
        layer_idx: Index of the transformer block to hook.
        max_length: Token length every sample is truncated / padded to.
        max_samples: Keep only the first ``max_samples`` rows; ``None`` keeps
            every row.
        device: Device the model runs on.

    Raises:
        ValueError: If the (possibly truncated) dataset is empty, or the model
            exposes none of the known layer containers.
    """
    def __init__(
        self,
        dataset_name: str = "wikitext",
        dataset_config: str = "wikitext-103-raw-v1",
        split: str = "train",
        model_name: str = "gpt2",
        layer_idx: int = 6,
        max_length: int = 128,
        max_samples: Optional[int] = None,
        device: str = "cuda" if torch.cuda.is_available() else "cpu"
    ):
        print(f"Loading dataset {dataset_name}/{dataset_config}...")
        self.dataset = load_dataset(dataset_name, dataset_config, split=split)
        # `is not None` so an explicit 0 is not silently treated as "no limit".
        if max_samples is not None:
            self.dataset = self.dataset.select(range(min(max_samples, len(self.dataset))))

        # Fail before loading the model: torch.stack on an empty list would
        # otherwise raise a much less helpful error after all the setup work.
        if len(self.dataset) == 0:
            raise ValueError(
                f"No samples to cache from {dataset_name}/{dataset_config} "
                f"(split={split!r}, max_samples={max_samples})"
            )

        print(f"Loading model {model_name}...")
        self.model = AutoModel.from_pretrained(model_name).to(device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Some tokenizers ship without a pad token; reuse EOS so that
        # padding="max_length" works.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.layer_idx = layer_idx
        self.max_length = max_length
        self.device = device
        self.hidden_size = self.model.config.hidden_size

        # Compute and store all activations in memory
        print("Computing and caching activations in memory...")
        # The hooked module is the same for every sample, so resolve it once.
        hook_layer = self._resolve_layer()
        cached = []

        with torch.no_grad():
            for i in tqdm(range(len(self.dataset))):
                tokens = self.tokenizer(
                    self.dataset[i]['text'],
                    max_length=max_length,
                    truncation=True,
                    padding="max_length",
                    return_tensors="pt"
                )

                # Move tokens to device
                tokens = {k: v.to(device) for k, v in tokens.items()}

                cached.append(self._extract_activations(tokens, layer=hook_layer))

        # Convert to single tensor for more efficient storage and indexing
        self.activations = torch.stack(cached)

        # Clean up the model and tokenizer since we don't need them anymore
        del self.model
        del self.tokenizer
        torch.cuda.empty_cache()

    def _resolve_layer(self) -> torch.nn.Module:
        """Return the module whose forward output should be captured.

        The block at ``self.layer_idx`` is looked up in the first container
        found among ``encoder.layer``, ``h``, ``layers`` and ``model.layers``.
        When the config's ``model_type`` contains "llama", the block's ``mlp``
        (or ``feed_forward``) submodule is used instead, if it has one.

        Raises:
            ValueError: If none of the known containers exists on the model.
        """
        model = self.model

        if hasattr(model, 'encoder'):
            layer = model.encoder.layer[self.layer_idx]
        elif hasattr(model, 'h'):
            layer = model.h[self.layer_idx]
        elif hasattr(model, 'layers'):
            layer = model.layers[self.layer_idx]
        elif hasattr(model, 'model') and hasattr(model.model, 'layers'):
            layer = model.model.layers[self.layer_idx]
        else:
            raise ValueError(f"Unknown model architecture: {type(model).__name__}")

        # For LLaMA models
        if 'llama' in model.config.model_type.lower():
            if hasattr(layer, 'mlp'):
                layer = layer.mlp
            elif hasattr(layer, 'feed_forward'):
                layer = layer.feed_forward

        return layer

    def _extract_activations(
        self,
        tokens: Dict[str, torch.Tensor],
        layer: Optional[torch.nn.Module] = None
    ) -> torch.Tensor:
        """Run the model on ``tokens`` and return the hooked layer's output.

        Args:
            tokens: Keyword arguments for the model's forward call.
            layer: Module to hook; resolved via ``_resolve_layer`` when omitted.

        Returns:
            The captured output with its leading dimension squeezed,
            i.e. ``[seq_len, hidden_size]`` for a single-sample batch.

        Raises:
            ValueError: If the hook output is not a tensor (or a tuple whose
                first element is a tensor), or the hook never fired.
        """
        if layer is None:
            layer = self._resolve_layer()

        captured = []

        def hook(module, input, output):
            # Some blocks return a tuple; the tensor of interest is element 0.
            act = output[0] if isinstance(output, tuple) else output
            if not isinstance(act, torch.Tensor):
                raise ValueError(f"Unexpected activation type: {type(act)}")
            captured.append(act.detach())

        handle = layer.register_forward_hook(hook)
        try:
            self.model(**tokens)
        finally:
            # Always detach the hook so repeated calls do not stack hooks.
            handle.remove()

        if not captured:
            raise ValueError("No activations captured by hook")
        return captured[0].squeeze(0)  # [seq_len, hidden_size]

    def __len__(self):
        """Number of cached samples."""
        return len(self.activations)

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Return the cached activation tensor of sample ``idx``."""
        return self.activations[idx]
        

class SparseAutoencoder(torch.nn.Module):
    """Sparse autoencoder: bias-free linear encoder + ReLU, bias-free linear decoder.

    The training objective (see ``loss``) is::

        mse(x_recon, x)
        + sparsity_lambda  * (mean|h| + mean KL(target_sparsity || mean_batch(h)))
        + dict_norm_lambda * mean((||decoder column|| - 1)^2)

    Args:
        input_dim: Dimensionality of the inputs being reconstructed.
        hidden_dim: Number of hidden (dictionary) units.
        sparsity_lambda: Weight of the combined L1 + KL sparsity penalty.
        dict_norm_lambda: Weight of the penalty pulling decoder column norms to 1.
        target_sparsity: Target value used in the KL penalty.
        activation_threshold: A unit counts as "active" when ``|h|`` exceeds this.
        device: Device the module is moved to.
    """
    def __init__(
        self, 
        input_dim: int,
        hidden_dim: int,
        sparsity_lambda: float = 0.1,  # Increased from 1e-3
        dict_norm_lambda: float = 1e-3,
        target_sparsity: float = 0.05,  # Target activation rate (5%)
        activation_threshold: float = 1e-6,  # Increased threshold
        device: str = "cuda" if torch.cuda.is_available() else "cpu"
    ):
        super().__init__()

        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(input_dim, hidden_dim, bias=False),
            torch.nn.ReLU(),  # ReLU for positive sparse activations
        )

        self.decoder = torch.nn.Linear(hidden_dim, input_dim, bias=False)

        # Initialize weights
        with torch.no_grad():
            # Encoder: one unit-norm row per hidden unit
            encoder_layer = self.encoder[0]  # Get the linear layer
            encoder_init = torch.randn(hidden_dim, input_dim)
            encoder_init = torch.nn.functional.normalize(encoder_init, dim=1)
            encoder_layer.weight.data = encoder_init

            # Decoder: one unit-norm column per hidden unit
            decoder_init = torch.randn(input_dim, hidden_dim)
            decoder_init = torch.nn.functional.normalize(decoder_init, dim=0)
            self.decoder.weight.data = decoder_init

        self.sparsity_lambda = sparsity_lambda
        self.dict_norm_lambda = dict_norm_lambda
        self.target_sparsity = target_sparsity
        self.activation_threshold = activation_threshold
        self.device = device
        self.to(device)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode and decode ``x``; returns ``(reconstruction, hidden_activations)``."""
        h = self.encoder(x)
        x_recon = self.decoder(h)
        return x_recon, h

    def _loss_terms(
        self,
        x: torch.Tensor,
        x_recon: torch.Tensor,
        h: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute the objective's components; shared by ``loss`` and ``compute_metrics``.

        Returns:
            ``(recon_loss, l1_loss, kl_loss, dict_norm_loss, total_loss)``.
        """
        recon_loss = torch.nn.functional.mse_loss(x_recon, x)

        # L1 sparsity
        l1_loss = torch.mean(torch.abs(h))

        # KL divergence sparsity penalty
        rho_hat = torch.mean(h, dim=0)  # Average activation of each hidden unit
        kl_loss = torch.mean(self.kl_divergence_sparsity(rho_hat))

        # Dictionary element norm loss
        dict_norm_loss = torch.mean((torch.norm(self.decoder.weight, dim=0) - 1)**2)

        # Combined loss with both L1 and KL sparsity penalties
        total_loss = (
            recon_loss +
            self.sparsity_lambda * (l1_loss + kl_loss) +
            self.dict_norm_lambda * dict_norm_loss
        )
        return recon_loss, l1_loss, kl_loss, dict_norm_loss, total_loss

    def loss(self, x: torch.Tensor) -> tuple[torch.Tensor, dict]:
        """Training loss for a batch.

        Args:
            x: Input tensor of shape [batch_size, input_dim]

        Returns:
            ``(total_loss, metrics)``: the differentiable total loss and a dict
            of plain Python numbers describing the batch.
        """
        x_recon, h = self.forward(x)
        recon_loss, l1_loss, kl_loss, dict_norm_loss, total_loss = self._loss_terms(x, x_recon, h)

        # Compute metrics
        active = (torch.abs(h) > self.activation_threshold).float()
        sparsity_ratio = torch.mean(active)  # fraction of (sample, unit) pairs that are active
        max_activation = torch.max(torch.abs(h))
        feature_usage = torch.mean(active, dim=0)  # per-unit activation frequency in this batch
        dead_features = torch.sum(feature_usage == 0).item()

        metrics = {
            'recon_loss': recon_loss.item(),
            'l1_loss': l1_loss.item(),
            'kl_loss': kl_loss.item(),
            'dict_norm_loss': dict_norm_loss.item(),
            'total_loss': total_loss.item(),
            'sparsity_ratio': sparsity_ratio.item(),
            'max_activation': max_activation.item(),
            'dead_features': dead_features,
            'feature_usage_std': torch.std(feature_usage).item()
        }

        return total_loss, metrics

    def compute_metrics(self, x: torch.Tensor) -> dict:
        """
        Compute detailed metrics for sparse autoencoder analysis
        
        Args:
            x: Input tensor of shape [batch_size, input_dim]
            
        Returns:
            Dictionary containing various metrics:
            - Basic loss terms (reconstruction, L1, KL divergence)
            - Sparsity measurements
            - Feature activation statistics
            - Neuron specialization metrics
        """
        x_recon, h = self.forward(x)

        # 1-4. Loss components and total (same objective as `loss`)
        recon_loss, l1_loss, kl_loss, dict_norm_loss, total_loss = self._loss_terms(x, x_recon, h)

        # 5. Sparsity Metrics
        # Count activations above threshold
        active_neurons = (torch.abs(h) > self.activation_threshold).float()
        sparsity_ratio = torch.mean(active_neurons)

        # Per-neuron activation frequencies
        neuron_activity = torch.mean(active_neurons, dim=0)  # [hidden_dim]
        dead_features = torch.sum(neuron_activity == 0).item()
        rarely_active = torch.sum(neuron_activity < 0.01).item()  # <1% activation rate
        hyperactive = torch.sum(neuron_activity > 0.5).item()  # >50% activation rate

        # 6. Activation Statistics
        max_activation = torch.max(torch.abs(h))
        mean_activation = torch.mean(torch.abs(h))
        std_activation = torch.std(torch.abs(h))

        # 7. Feature Usage Distribution
        feature_usage_std = torch.std(neuron_activity).item()
        feature_usage_entropy = -torch.sum(
            neuron_activity * torch.log(neuron_activity + 1e-10)
        ).item()

        # 8. Neuron Correlation Analysis
        # Compute pairwise correlations between most active neurons
        top_k = min(100, h.size(1))  # Use top 100 neurons or all if less
        _, top_indices = torch.topk(neuron_activity, top_k)
        top_activations = h[:, top_indices]
        correlations = torch.corrcoef(top_activations.T)
        # A unit that is constant over the batch has zero variance, which makes
        # its correlations NaN; count those as 0 so a single such unit does not
        # turn the whole metric into NaN.
        correlations = torch.nan_to_num(correlations, nan=0.0)
        mean_correlation = torch.mean(torch.abs(correlations - torch.eye(top_k, device=correlations.device))).item()

        # 9. Reconstruction Quality per Feature
        # Reconstruction error when one feature is zeroed out, measured for the
        # first 100 features only (to bound the cost).
        with torch.no_grad():
            feature_importance = []
            for i in range(min(100, h.size(1))):
                h_zeroed = h.clone()
                h_zeroed[:, i] = 0
                x_recon_zeroed = self.decoder(h_zeroed)
                feature_importance.append(
                    torch.nn.functional.mse_loss(x_recon_zeroed, x).item()
                )
            feature_importance = torch.tensor(feature_importance)
            feature_importance_std = torch.std(feature_importance).item()

        metrics = {
            # Loss components
            'recon_loss': recon_loss.item(),
            'l1_loss': l1_loss.item(),
            'kl_loss': kl_loss.item(),
            'dict_norm_loss': dict_norm_loss.item(),
            'total_loss': total_loss.item(),

            # Sparsity metrics
            'sparsity_ratio': sparsity_ratio.item(),
            'dead_features': dead_features,
            'rarely_active': rarely_active,
            'hyperactive': hyperactive,

            # Activation statistics
            'max_activation': max_activation.item(),
            'mean_activation': mean_activation.item(),
            'std_activation': std_activation.item(),

            # Feature usage metrics
            'feature_usage_std': feature_usage_std,
            'feature_usage_entropy': feature_usage_entropy,
            'mean_correlation': mean_correlation,
            'feature_importance_std': feature_importance_std
        }

        return metrics

    def kl_divergence_sparsity(self, rho_hat: torch.Tensor) -> torch.Tensor:
        """
        Compute KL divergence between average activation and target sparsity
        
        Args:
            rho_hat: Average activation of hidden units [hidden_dim]
            
        Returns:
            Per-unit KL divergence measuring deviation from target sparsity.
            ``rho_hat`` is clamped to [epsilon, 1 - epsilon] first, so average
            activations outside that range are evaluated at the boundary.
        """
        epsilon = 1e-10  # Small constant for numerical stability
        rho = self.target_sparsity

        # Clip values to prevent log(0)
        rho_hat = torch.clamp(rho_hat, epsilon, 1 - epsilon)

        # KL divergence
        kl_div = rho * torch.log((rho + epsilon) / (rho_hat + epsilon)) + \
                 (1 - rho) * torch.log((1 - rho + epsilon) / (1 - rho_hat + epsilon))

        return kl_div

    def analyze_feature(self, feature_idx: int, dataloader: torch.utils.data.DataLoader) -> dict:
        """
        Analyze a specific feature's behavior across the dataset
        
        Args:
            feature_idx: Index of the feature to analyze
            dataloader: DataLoader containing samples to analyze
            
        Returns:
            Dictionary containing feature analysis. ``top_coactivated_features``
            lists up to 5 *other* features, ordered by how often they are
            active on samples where ``feature_idx`` is active.

        Raises:
            ValueError: If ``dataloader`` yields no batches.
        """
        self.eval()
        feature_acts = []
        # Running per-feature co-activation counts. Summing per batch avoids
        # holding a [num_samples, hidden_dim] mask for the whole dataset.
        coactivation_counts = None

        with torch.no_grad():
            for batch in dataloader:
                batch = batch.reshape(-1, batch.size(-1)).to(self.device)
                _, h = self.forward(batch)

                # Get activations for this feature
                feature_acts.append(h[:, feature_idx])

                # Features active on the same rows as the analyzed feature
                active_mask = torch.abs(h) > self.activation_threshold
                coactive = active_mask & (active_mask[:, feature_idx].unsqueeze(1))
                batch_counts = coactive.sum(dim=0)
                if coactivation_counts is None:
                    coactivation_counts = batch_counts
                else:
                    coactivation_counts = coactivation_counts + batch_counts

        if coactivation_counts is None:
            raise ValueError("analyze_feature received an empty dataloader")

        activations = torch.cat(feature_acts)

        # A feature trivially co-activates with itself on every row where it is
        # active; exclude it so the ranking only reports other features.
        coactivation_counts[feature_idx] = -1
        top_k = min(5, coactivation_counts.numel() - 1)

        analysis = {
            'mean_activation': torch.mean(activations).item(),
            'std_activation': torch.std(activations).item(),
            'activation_rate': torch.mean((torch.abs(activations) > self.activation_threshold).float()).item(),
            'max_activation': torch.max(torch.abs(activations)).item(),
            'top_coactivated_features': torch.topk(coactivation_counts, k=top_k)[1].tolist()
        }

        return analysis

def evaluate_model(
    model: "SparseAutoencoder",
    dataloader: DataLoader,
    device: str
) -> dict:
    """Evaluate model on given dataloader.

    Puts the model in eval mode (it is left in eval mode afterwards), flattens
    every batch to ``[-1, last_dim]``, calls ``model.compute_metrics`` on it
    under ``torch.no_grad()``, and averages each metric over batches. Batches
    are weighted equally regardless of their size.

    Args:
        model: Model exposing ``eval()`` and ``compute_metrics(batch) -> dict``.
        dataloader: Iterable of tensors.
        device: Device each batch is moved to before evaluation.

    Returns:
        Dict mapping each metric name to its mean over all batches.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.eval()
    per_batch = []

    with torch.no_grad():
        for batch in dataloader:
            batch = batch.reshape(-1, batch.size(-1)).to(device)
            per_batch.append(model.compute_metrics(batch))

    # Without this guard an empty loader fails on `per_batch[0]` with a bare
    # IndexError that says nothing about the cause.
    if not per_batch:
        raise ValueError("evaluate_model received an empty dataloader; nothing to evaluate")

    # Average metrics
    return {
        name: np.mean([batch_metrics[name] for batch_metrics in per_batch])
        for name in per_batch[0].keys()
    }

def plot_training_history(history: dict, save_path: str = 'training_history.png'):
    """Plot training and validation metrics and save the figure to disk.

    Draws a 2x2 grid with one panel each for total loss, reconstruction loss,
    sparsity ratio and dead-feature count, showing the train and validation
    curves per epoch. The figure is written to ``save_path`` and then closed,
    so nothing is displayed inline.

    Args:
        history: Dict with ``train_<metric>`` and ``val_<metric>`` lists for
            each of the four plotted metrics.
        save_path: Output image path; defaults to the previously hard-coded
            ``training_history.png``.
    """
    tracked = ['total_loss', 'recon_loss', 'sparsity_ratio', 'dead_features']
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    for ax, metric in zip(axes.flat, tracked):
        ax.plot(history[f'train_{metric}'], label='Train')
        ax.plot(history[f'val_{metric}'], label='Validation')
        ax.set_title(metric)
        ax.set_xlabel('Epoch')
        ax.legend()

    fig.tight_layout()
    fig.savefig(save_path)
    # Close this specific figure: the function is called once per epoch and
    # unclosed figures would accumulate in memory.
    plt.close(fig)

def train_autoencoder(
    dataset_name: str = "wikitext",
    dataset_config: str = "wikitext-103-raw-v1",
    model_name: str = "gpt2",
    layer_idx: int = 6,
    hidden_dim: int = 4096,
    batch_size: int = 32,
    num_epochs: int = 50,
    max_samples: Optional[int] = 10000,
    lr: float = 5e-4,  # Reduced learning rate
    sparsity_lambda: float = 0.1,  # Increased sparsity penalty
    target_sparsity: float = 0.05,  # Target 5% activation
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    val_split: float = 0.1,
    seed: Optional[int] = None
) -> "tuple[SparseAutoencoder, dict]":
    """Train sparse autoencoder with validation.

    Caches layer activations for the dataset, splits them into train and
    validation subsets, and trains a ``SparseAutoencoder`` with Adam, gradient
    clipping (max norm 1.0) and ``ReduceLROnPlateau`` on the validation total
    loss. After every epoch the metrics are printed, the history plot is
    rewritten, and the weights are saved to ``best_sae.pt`` whenever the
    validation total loss improves.

    Args:
        dataset_name: Dataset name forwarded to the activation dataset.
        dataset_config: Dataset configuration forwarded to the activation dataset.
        model_name: Transformer checkpoint whose activations are modelled.
        layer_idx: Index of the transformer block to hook.
        hidden_dim: Number of hidden units of the autoencoder.
        batch_size: Number of cached samples per batch (each sample contributes
            all of its token positions after flattening).
        num_epochs: Number of training epochs.
        max_samples: Maximum number of dataset rows to cache.
        lr: Adam learning rate.
        sparsity_lambda: Weight of the sparsity penalty.
        target_sparsity: Target value for the KL sparsity penalty.
        device: Device used for caching and training.
        val_split: Fraction of samples held out for validation; must be
            strictly between 0 and 1.
        seed: If given, seeds torch's global RNG before the split so the
            train/val split, weight initialisation and shuffling are
            repeatable. ``None`` keeps the previous unseeded behaviour.

    Returns:
        ``(sae, history)``. ``sae`` holds the weights from the *final* epoch;
        the best-validation weights are the ones written to ``best_sae.pt``.
        ``history`` maps ``train_<metric>`` / ``val_<metric>`` to per-epoch lists.

    Raises:
        ValueError: If ``val_split`` is not in (0, 1), or the resulting train
            or validation subset would be empty.
    """
    # Validate before the expensive activation caching: an empty train or
    # validation subset would otherwise only fail once training is under way.
    if not 0.0 < val_split < 1.0:
        raise ValueError(f"val_split must be strictly between 0 and 1, got {val_split}")

    # Create full dataset
    full_dataset = MemoryCachedTransformerDataset(
        dataset_name=dataset_name,
        dataset_config=dataset_config,
        model_name=model_name,
        layer_idx=layer_idx,
        max_samples=max_samples,
        device=device
    )

    if seed is not None:
        torch.manual_seed(seed)

    # Split into train/val
    val_size = int(len(full_dataset) * val_split)
    train_size = len(full_dataset) - val_size
    if val_size == 0 or train_size == 0:
        raise ValueError(
            f"val_split={val_split} on {len(full_dataset)} samples gives "
            f"train_size={train_size}, val_size={val_size}; both must be non-zero"
        )
    train_dataset, val_dataset = random_split(
        full_dataset, [train_size, val_size]
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0
    )

    print(f"Train size: {len(train_dataset)}, Val size: {len(val_dataset)}")

    # Initialize model with new parameters
    sae = SparseAutoencoder(
        input_dim=full_dataset.hidden_size,
        hidden_dim=hidden_dim,
        sparsity_lambda=sparsity_lambda,
        target_sparsity=target_sparsity,
        device=device
    )

    optimizer = torch.optim.Adam(sae.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=2
    )

    # Training history: one train/val list per tracked metric
    tracked = ['total_loss', 'recon_loss', 'sparsity_ratio', 'dead_features']
    history = {}
    for name in tracked:
        history[f'train_{name}'] = []
        history[f'val_{name}'] = []

    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        # Training
        sae.train()
        epoch_metrics = []

        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            # Flatten [batch, seq_len, hidden] into one row per token position
            batch = batch.reshape(-1, full_dataset.hidden_size).to(device)

            loss, metrics = sae.loss(batch)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(sae.parameters(), max_norm=1.0)
            optimizer.step()

            epoch_metrics.append(metrics)

        train_metrics = {
            k: np.mean([m[k] for m in epoch_metrics])
            for k in epoch_metrics[0].keys()
        }

        # Validation
        val_metrics = evaluate_model(sae, val_loader, device)

        # Update history
        for name in tracked:
            history[f'train_{name}'].append(train_metrics[name])
            history[f'val_{name}'].append(val_metrics[name])

        # Print metrics
        print(f"\nEpoch {epoch+1} metrics:")
        print("Train:", {k: f"{v:.4f}" for k, v in train_metrics.items()})
        print("Val:", {k: f"{v:.4f}" for k, v in val_metrics.items()})

        # Learning rate scheduling
        scheduler.step(val_metrics['total_loss'])

        # Save best model
        if val_metrics['total_loss'] < best_val_loss:
            best_val_loss = val_metrics['total_loss']
            torch.save(sae.state_dict(), 'best_sae.pt')

        # Plot training history
        plot_training_history(history)

    return sae, history

if __name__ == "__main__":
    # Train an SAE on cached Llama-3.2-1B activations from the first 3000
    # WikiText-103 rows, holding out 20% of them for validation.
    sae, history = train_autoencoder(
        # data
        dataset_name="wikitext",
        dataset_config="wikitext-103-raw-v1",
        max_samples=3000,
        val_split=0.2,
        # source model
        model_name="meta-llama/Llama-3.2-1B",
        # autoencoder
        hidden_dim=5000,
        sparsity_lambda=100,  # far above the function default of 0.1
        # optimisation
        batch_size=16,
        num_epochs=50,
    )

Loading dataset wikitext/wikitext-103-raw-v1...
Loading model meta-llama/Llama-3.2-1B...
Computing and caching activations in memory...


100%|███████████████████████████████████████| 3000/3000 [01:02<00:00, 48.30it/s]


Train size: 2400, Val size: 600


Epoch 1/50: 100%|█████████████████████████████| 150/150 [00:02<00:00, 65.03it/s]



Epoch 1 metrics:
Train: {'recon_loss': '0.0027', 'l1_loss': '0.0256', 'kl_loss': '0.0121', 'dict_norm_loss': '0.0010', 'total_loss': '3.7662', 'sparsity_ratio': '0.4017', 'max_activation': '0.4348', 'dead_features': '0.0000', 'feature_usage_std': '0.1807'}
Val: {'recon_loss': '0.0017', 'l1_loss': '0.0248', 'kl_loss': '0.0117', 'dict_norm_loss': '0.0026', 'total_loss': '3.6483', 'sparsity_ratio': '0.3789', 'dead_features': '0.0000', 'rarely_active': '0.0000', 'hyperactive': '1014.8158', 'max_activation': '0.4849', 'mean_activation': '0.0248', 'std_activation': '0.0443', 'feature_usage_std': '0.1629', 'feature_usage_entropy': '1654.7842', 'mean_correlation': '0.0654', 'feature_importance_std': '0.0000'}


Epoch 2/50: 100%|█████████████████████████████| 150/150 [00:02<00:00, 68.73it/s]



Epoch 2 metrics:
Train: {'recon_loss': '0.0014', 'l1_loss': '0.0254', 'kl_loss': '0.0114', 'dict_norm_loss': '0.0044', 'total_loss': '3.6814', 'sparsity_ratio': '0.3533', 'max_activation': '0.5857', 'dead_features': '0.0000', 'feature_usage_std': '0.1613'}
Val: {'recon_loss': '0.0012', 'l1_loss': '0.0243', 'kl_loss': '0.0121', 'dict_norm_loss': '0.0060', 'total_loss': '3.6390', 'sparsity_ratio': '0.3112', 'dead_features': '0.0000', 'rarely_active': '0.0000', 'hyperactive': '495.2105', 'max_activation': '0.6668', 'mean_activation': '0.0243', 'std_activation': '0.0539', 'feature_usage_std': '0.1400', 'feature_usage_entropy': '1662.1155', 'mean_correlation': '0.0768', 'feature_importance_std': '0.0000'}


Epoch 3/50: 100%|█████████████████████████████| 150/150 [00:02<00:00, 68.80it/s]



Epoch 3 metrics:
Train: {'recon_loss': '0.0011', 'l1_loss': '0.0255', 'kl_loss': '0.0110', 'dict_norm_loss': '0.0074', 'total_loss': '3.6528', 'sparsity_ratio': '0.2931', 'max_activation': '0.7499', 'dead_features': '0.0000', 'feature_usage_std': '0.1341'}
Val: {'recon_loss': '0.0011', 'l1_loss': '0.0255', 'kl_loss': '0.0106', 'dict_norm_loss': '0.0086', 'total_loss': '3.6135', 'sparsity_ratio': '0.2776', 'dead_features': '0.0000', 'rarely_active': '0.0000', 'hyperactive': '401.1842', 'max_activation': '0.8354', 'mean_activation': '0.0255', 'std_activation': '0.0655', 'feature_usage_std': '0.1319', 'feature_usage_entropy': '1630.9165', 'mean_correlation': '0.0843', 'feature_importance_std': '0.0000'}


Epoch 4/50: 100%|█████████████████████████████| 150/150 [00:02<00:00, 68.44it/s]



Epoch 4 metrics:
Train: {'recon_loss': '0.0010', 'l1_loss': '0.0251', 'kl_loss': '0.0112', 'dict_norm_loss': '0.0097', 'total_loss': '3.6320', 'sparsity_ratio': '0.2538', 'max_activation': '0.9204', 'dead_features': '0.0000', 'feature_usage_std': '0.1318'}
Val: {'recon_loss': '0.0011', 'l1_loss': '0.0268', 'kl_loss': '0.0094', 'dict_norm_loss': '0.0107', 'total_loss': '3.6137', 'sparsity_ratio': '0.2540', 'dead_features': '0.0000', 'rarely_active': '0.0000', 'hyperactive': '351.0526', 'max_activation': '1.0099', 'mean_activation': '0.0268', 'std_activation': '0.0769', 'feature_usage_std': '0.1297', 'feature_usage_entropy': '1588.9095', 'mean_correlation': '0.1120', 'feature_importance_std': '0.0000'}


Epoch 5/50: 100%|█████████████████████████████| 150/150 [00:02<00:00, 68.52it/s]



Epoch 5 metrics:
Train: {'recon_loss': '0.0011', 'l1_loss': '0.0252', 'kl_loss': '0.0110', 'dict_norm_loss': '0.0128', 'total_loss': '3.6170', 'sparsity_ratio': '0.2271', 'max_activation': '1.0791', 'dead_features': '0.0000', 'feature_usage_std': '0.1210'}
Val: {'recon_loss': '0.0011', 'l1_loss': '0.0252', 'kl_loss': '0.0109', 'dict_norm_loss': '0.0141', 'total_loss': '3.6086', 'sparsity_ratio': '0.2163', 'dead_features': '0.0000', 'rarely_active': '0.0000', 'hyperactive': '181.5789', 'max_activation': '1.1666', 'mean_activation': '0.0252', 'std_activation': '0.0831', 'feature_usage_std': '0.1085', 'feature_usage_entropy': '1532.3412', 'mean_correlation': '0.1553', 'feature_importance_std': '0.0000'}


Epoch 6/50: 100%|█████████████████████████████| 150/150 [00:02<00:00, 68.38it/s]



Epoch 6 metrics:
Train: {'recon_loss': '0.0011', 'l1_loss': '0.0253', 'kl_loss': '0.0107', 'dict_norm_loss': '0.0161', 'total_loss': '3.6017', 'sparsity_ratio': '0.1981', 'max_activation': '1.2436', 'dead_features': '0.0000', 'feature_usage_std': '0.1014'}
Val: {'recon_loss': '0.0012', 'l1_loss': '0.0237', 'kl_loss': '0.0123', 'dict_norm_loss': '0.0184', 'total_loss': '3.6065', 'sparsity_ratio': '0.1837', 'dead_features': '0.0000', 'rarely_active': '0.0000', 'hyperactive': '96.8684', 'max_activation': '1.3056', 'mean_activation': '0.0237', 'std_activation': '0.0862', 'feature_usage_std': '0.0910', 'feature_usage_entropy': '1455.1582', 'mean_correlation': '0.2358', 'feature_importance_std': '0.0000'}


Epoch 7/50: 100%|█████████████████████████████| 150/150 [00:02<00:00, 69.08it/s]



Epoch 7 metrics:
Train: {'recon_loss': '0.0010', 'l1_loss': '0.0250', 'kl_loss': '0.0109', 'dict_norm_loss': '0.0225', 'total_loss': '3.5966', 'sparsity_ratio': '0.1812', 'max_activation': '1.3900', 'dead_features': '0.0000', 'feature_usage_std': '0.0859'}
Val: {'recon_loss': '0.0010', 'l1_loss': '0.0257', 'kl_loss': '0.0102', 'dict_norm_loss': '0.0262', 'total_loss': '3.5947', 'sparsity_ratio': '0.1837', 'dead_features': '0.0000', 'rarely_active': '0.0000', 'hyperactive': '73.3158', 'max_activation': '1.4584', 'mean_activation': '0.0257', 'std_activation': '0.0946', 'feature_usage_std': '0.0850', 'feature_usage_entropy': '1461.8641', 'mean_correlation': '0.2477', 'feature_importance_std': '0.0000'}


Epoch 8/50: 100%|█████████████████████████████| 150/150 [00:02<00:00, 68.86it/s]



Epoch 8 metrics:
Train: {'recon_loss': '0.0010', 'l1_loss': '0.0253', 'kl_loss': '0.0107', 'dict_norm_loss': '0.0288', 'total_loss': '3.6007', 'sparsity_ratio': '0.1684', 'max_activation': '1.5360', 'dead_features': '0.0000', 'feature_usage_std': '0.0774'}
Val: {'recon_loss': '0.0010', 'l1_loss': '0.0251', 'kl_loss': '0.0107', 'dict_norm_loss': '0.0297', 'total_loss': '3.5877', 'sparsity_ratio': '0.1666', 'dead_features': '0.0000', 'rarely_active': '0.0000', 'hyperactive': '40.7105', 'max_activation': '1.6190', 'mean_activation': '0.0251', 'std_activation': '0.0999', 'feature_usage_std': '0.0728', 'feature_usage_entropy': '1414.9563', 'mean_correlation': '0.2948', 'feature_importance_std': '0.0000'}


Epoch 9/50: 100%|█████████████████████████████| 150/150 [00:02<00:00, 68.55it/s]



Epoch 9 metrics:
Train: {'recon_loss': '0.0011', 'l1_loss': '0.0252', 'kl_loss': '0.0108', 'dict_norm_loss': '0.0338', 'total_loss': '3.5939', 'sparsity_ratio': '0.1547', 'max_activation': '1.7143', 'dead_features': '0.0000', 'feature_usage_std': '0.0712'}
Val: {'recon_loss': '0.0012', 'l1_loss': '0.0252', 'kl_loss': '0.0106', 'dict_norm_loss': '0.0354', 'total_loss': '3.5853', 'sparsity_ratio': '0.1523', 'dead_features': '0.0000', 'rarely_active': '0.0000', 'hyperactive': '35.2632', 'max_activation': '1.7907', 'mean_activation': '0.0252', 'std_activation': '0.1076', 'feature_usage_std': '0.0686', 'feature_usage_entropy': '1360.5371', 'mean_correlation': '0.3003', 'feature_importance_std': '0.0000'}


Epoch 10/50: 100%|████████████████████████████| 150/150 [00:02<00:00, 68.97it/s]



Epoch 10 metrics:
Train: {'recon_loss': '0.0010', 'l1_loss': '0.0251', 'kl_loss': '0.0107', 'dict_norm_loss': '0.0340', 'total_loss': '3.5803', 'sparsity_ratio': '0.1421', 'max_activation': '1.8685', 'dead_features': '0.0000', 'feature_usage_std': '0.0687'}
Val: {'recon_loss': '0.0011', 'l1_loss': '0.0261', 'kl_loss': '0.0098', 'dict_norm_loss': '0.0353', 'total_loss': '3.5906', 'sparsity_ratio': '0.1477', 'dead_features': '0.0000', 'rarely_active': '0.0000', 'hyperactive': '34.6579', 'max_activation': '1.9451', 'mean_activation': '0.0261', 'std_activation': '0.1152', 'feature_usage_std': '0.0718', 'feature_usage_entropy': '1332.9890', 'mean_correlation': '0.3463', 'feature_importance_std': '0.0000'}


Epoch 11/50: 100%|████████████████████████████| 150/150 [00:02<00:00, 68.70it/s]



Epoch 11 metrics:
Train: {'recon_loss': '0.0010', 'l1_loss': '0.0253', 'kl_loss': '0.0104', 'dict_norm_loss': '0.0376', 'total_loss': '3.5768', 'sparsity_ratio': '0.1350', 'max_activation': '2.0043', 'dead_features': '0.0000', 'feature_usage_std': '0.0656'}
Val: {'recon_loss': '0.0010', 'l1_loss': '0.0252', 'kl_loss': '0.0106', 'dict_norm_loss': '0.0393', 'total_loss': '3.5815', 'sparsity_ratio': '0.1349', 'dead_features': '0.0000', 'rarely_active': '0.0000', 'hyperactive': '31.4211', 'max_activation': '2.0690', 'mean_activation': '0.0252', 'std_activation': '0.1193', 'feature_usage_std': '0.0655', 'feature_usage_entropy': '1280.5245', 'mean_correlation': '0.3448', 'feature_importance_std': '0.0000'}


Epoch 12/50: 100%|████████████████████████████| 150/150 [00:02<00:00, 68.68it/s]



Epoch 12 metrics:
Train: {'recon_loss': '0.0011', 'l1_loss': '0.0252', 'kl_loss': '0.0104', 'dict_norm_loss': '0.0412', 'total_loss': '3.5602', 'sparsity_ratio': '0.1287', 'max_activation': '2.1346', 'dead_features': '0.0000', 'feature_usage_std': '0.0637'}
Val: {'recon_loss': '0.0010', 'l1_loss': '0.0241', 'kl_loss': '0.0118', 'dict_norm_loss': '0.0414', 'total_loss': '3.5876', 'sparsity_ratio': '0.1245', 'dead_features': '0.0000', 'rarely_active': '0.0000', 'hyperactive': '31.1053', 'max_activation': '2.1877', 'mean_activation': '0.0241', 'std_activation': '0.1223', 'feature_usage_std': '0.0631', 'feature_usage_entropy': '1230.2170', 'mean_correlation': '0.3520', 'feature_importance_std': '0.0000'}


Epoch 13/50: 100%|████████████████████████████| 150/150 [00:02<00:00, 68.76it/s]



Epoch 13 metrics:
Train: {'recon_loss': '0.0011', 'l1_loss': '0.0252', 'kl_loss': '0.0104', 'dict_norm_loss': '0.0411', 'total_loss': '3.5668', 'sparsity_ratio': '0.1221', 'max_activation': '2.2457', 'dead_features': '0.0000', 'feature_usage_std': '0.0622'}
Val: {'recon_loss': '0.0012', 'l1_loss': '0.0257', 'kl_loss': '0.0101', 'dict_norm_loss': '0.0438', 'total_loss': '3.5771', 'sparsity_ratio': '0.1241', 'dead_features': '0.0000', 'rarely_active': '0.0000', 'hyperactive': '29.0000', 'max_activation': '2.3080', 'mean_activation': '0.0257', 'std_activation': '0.1302', 'feature_usage_std': '0.0630', 'feature_usage_entropy': '1226.5381', 'mean_correlation': '0.3701', 'feature_importance_std': '0.0000'}


Epoch 14/50: 100%|████████████████████████████| 150/150 [00:02<00:00, 68.68it/s]



Epoch 14 metrics:
Train: {'recon_loss': '0.0011', 'l1_loss': '0.0252', 'kl_loss': '0.0105', 'dict_norm_loss': '0.0447', 'total_loss': '3.5743', 'sparsity_ratio': '0.1150', 'max_activation': '2.3621', 'dead_features': '0.0000', 'feature_usage_std': '0.0595'}
Val: {'recon_loss': '0.0011', 'l1_loss': '0.0250', 'kl_loss': '0.0107', 'dict_norm_loss': '0.0445', 'total_loss': '3.5711', 'sparsity_ratio': '0.1144', 'dead_features': '0.0000', 'rarely_active': '0.0000', 'hyperactive': '27.8158', 'max_activation': '2.4237', 'mean_activation': '0.0250', 'std_activation': '0.1352', 'feature_usage_std': '0.0583', 'feature_usage_entropy': '1179.8540', 'mean_correlation': '0.3910', 'feature_importance_std': '0.0000'}


Epoch 15/50: 100%|████████████████████████████| 150/150 [00:02<00:00, 68.74it/s]



Epoch 15 metrics:
Train: {'recon_loss': '0.0013', 'l1_loss': '0.0252', 'kl_loss': '0.0105', 'dict_norm_loss': '0.0456', 'total_loss': '3.5776', 'sparsity_ratio': '0.1074', 'max_activation': '2.4988', 'dead_features': '0.0000', 'feature_usage_std': '0.0557'}
Val: {'recon_loss': '0.0012', 'l1_loss': '0.0247', 'kl_loss': '0.0111', 'dict_norm_loss': '0.0532', 'total_loss': '3.5748', 'sparsity_ratio': '0.1040', 'dead_features': '0.0000', 'rarely_active': '0.0000', 'hyperactive': '19.2105', 'max_activation': '2.5517', 'mean_activation': '0.0247', 'std_activation': '0.1426', 'feature_usage_std': '0.0513', 'feature_usage_entropy': '1124.9290', 'mean_correlation': '0.4264', 'feature_importance_std': '0.0000'}


Epoch 16/50: 100%|████████████████████████████| 150/150 [00:02<00:00, 68.63it/s]



Epoch 16 metrics:
Train: {'recon_loss': '0.0014', 'l1_loss': '0.0253', 'kl_loss': '0.0105', 'dict_norm_loss': '0.0546', 'total_loss': '3.5804', 'sparsity_ratio': '0.1026', 'max_activation': '2.6203', 'dead_features': '0.0000', 'feature_usage_std': '0.0527'}
Val: {'recon_loss': '0.0014', 'l1_loss': '0.0242', 'kl_loss': '0.0115', 'dict_norm_loss': '0.0593', 'total_loss': '3.5706', 'sparsity_ratio': '0.0986', 'dead_features': '0.0000', 'rarely_active': '0.0000', 'hyperactive': '19.0526', 'max_activation': '2.6736', 'mean_activation': '0.0242', 'std_activation': '0.1482', 'feature_usage_std': '0.0506', 'feature_usage_entropy': '1089.3126', 'mean_correlation': '0.4208', 'feature_importance_std': '0.0000'}


Epoch 17/50: 100%|████████████████████████████| 150/150 [00:02<00:00, 68.80it/s]



Epoch 17 metrics:
Train: {'recon_loss': '0.0011', 'l1_loss': '0.0252', 'kl_loss': '0.0104', 'dict_norm_loss': '0.0612', 'total_loss': '3.5533', 'sparsity_ratio': '0.0961', 'max_activation': '2.7241', 'dead_features': '0.0000', 'feature_usage_std': '0.0511'}
Val: {'recon_loss': '0.0011', 'l1_loss': '0.0242', 'kl_loss': '0.0115', 'dict_norm_loss': '0.0631', 'total_loss': '3.5705', 'sparsity_ratio': '0.0910', 'dead_features': '0.0000', 'rarely_active': '0.0000', 'hyperactive': '17.8421', 'max_activation': '2.7543', 'mean_activation': '0.0242', 'std_activation': '0.1541', 'feature_usage_std': '0.0491', 'feature_usage_entropy': '1040.1262', 'mean_correlation': '0.4230', 'feature_importance_std': '0.0000'}


Epoch 18/50: 100%|████████████████████████████| 150/150 [00:02<00:00, 69.33it/s]



Epoch 18 metrics:
Train: {'recon_loss': '0.0012', 'l1_loss': '0.0251', 'kl_loss': '0.0103', 'dict_norm_loss': '0.0648', 'total_loss': '3.5494', 'sparsity_ratio': '0.0916', 'max_activation': '2.7880', 'dead_features': '0.0000', 'feature_usage_std': '0.0492'}
Val: {'recon_loss': '0.0013', 'l1_loss': '0.0268', 'kl_loss': '0.0089', 'dict_norm_loss': '0.0679', 'total_loss': '3.5716', 'sparsity_ratio': '0.0989', 'dead_features': '0.0000', 'rarely_active': '0.0000', 'hyperactive': '14.9211', 'max_activation': '2.8285', 'mean_activation': '0.0268', 'std_activation': '0.1630', 'feature_usage_std': '0.0517', 'feature_usage_entropy': '1086.0735', 'mean_correlation': '0.4537', 'feature_importance_std': '0.0000'}


Epoch 19/50: 100%|████████████████████████████| 150/150 [00:02<00:00, 68.54it/s]



Epoch 19 metrics:
Train: {'recon_loss': '0.0012', 'l1_loss': '0.0254', 'kl_loss': '0.0101', 'dict_norm_loss': '0.0702', 'total_loss': '3.5435', 'sparsity_ratio': '0.0887', 'max_activation': '2.8518', 'dead_features': '0.0000', 'feature_usage_std': '0.0484'}
Val: {'recon_loss': '0.0014', 'l1_loss': '0.0251', 'kl_loss': '0.0105', 'dict_norm_loss': '0.0706', 'total_loss': '3.5601', 'sparsity_ratio': '0.0871', 'dead_features': '0.0000', 'rarely_active': '0.0000', 'hyperactive': '15.1053', 'max_activation': '2.8807', 'mean_activation': '0.0251', 'std_activation': '0.1652', 'feature_usage_std': '0.0478', 'feature_usage_entropy': '1010.9999', 'mean_correlation': '0.4409', 'feature_importance_std': '0.0000'}


Epoch 20/50: 100%|████████████████████████████| 150/150 [00:02<00:00, 68.73it/s]



Epoch 20 metrics:
Train: {'recon_loss': '0.0013', 'l1_loss': '0.0253', 'kl_loss': '0.0102', 'dict_norm_loss': '0.0716', 'total_loss': '3.5533', 'sparsity_ratio': '0.0846', 'max_activation': '2.9213', 'dead_features': '0.0000', 'feature_usage_std': '0.0482'}
Val: {'recon_loss': '0.0012', 'l1_loss': '0.0240', 'kl_loss': '0.0117', 'dict_norm_loss': '0.0704', 'total_loss': '3.5680', 'sparsity_ratio': '0.0807', 'dead_features': '0.0000', 'rarely_active': '0.0000', 'hyperactive': '15.3158', 'max_activation': '2.9446', 'mean_activation': '0.0240', 'std_activation': '0.1685', 'feature_usage_std': '0.0463', 'feature_usage_entropy': '966.7092', 'mean_correlation': '0.4361', 'feature_importance_std': '0.0000'}


Epoch 21/50: 100%|████████████████████████████| 150/150 [00:02<00:00, 68.76it/s]



Epoch 21 metrics:
Train: {'recon_loss': '0.0013', 'l1_loss': '0.0252', 'kl_loss': '0.0103', 'dict_norm_loss': '0.0718', 'total_loss': '3.5513', 'sparsity_ratio': '0.0813', 'max_activation': '2.9731', 'dead_features': '0.0000', 'feature_usage_std': '0.0476'}
Val: {'recon_loss': '0.0026', 'l1_loss': '0.0255', 'kl_loss': '0.0100', 'dict_norm_loss': '0.0698', 'total_loss': '3.5550', 'sparsity_ratio': '0.0821', 'dead_features': '0.0000', 'rarely_active': '0.0000', 'hyperactive': '15.0000', 'max_activation': '3.0048', 'mean_activation': '0.0255', 'std_activation': '0.1766', 'feature_usage_std': '0.0470', 'feature_usage_entropy': '973.8613', 'mean_correlation': '0.4626', 'feature_importance_std': '0.0000'}


Epoch 22/50: 100%|████████████████████████████| 150/150 [00:02<00:00, 69.21it/s]



Epoch 22 metrics:
Train: {'recon_loss': '0.0013', 'l1_loss': '0.0252', 'kl_loss': '0.0102', 'dict_norm_loss': '0.0729', 'total_loss': '3.5422', 'sparsity_ratio': '0.0776', 'max_activation': '3.0200', 'dead_features': '0.0000', 'feature_usage_std': '0.0468'}
Val: {'recon_loss': '0.0014', 'l1_loss': '0.0256', 'kl_loss': '0.0099', 'dict_norm_loss': '0.0756', 'total_loss': '3.5513', 'sparsity_ratio': '0.0792', 'dead_features': '0.0000', 'rarely_active': '0.0000', 'hyperactive': '14.9211', 'max_activation': '3.0380', 'mean_activation': '0.0256', 'std_activation': '0.1816', 'feature_usage_std': '0.0473', 'feature_usage_entropy': '950.2690', 'mean_correlation': '0.4757', 'feature_importance_std': '0.0000'}


Epoch 23/50: 100%|████████████████████████████| 150/150 [00:02<00:00, 69.06it/s]



Epoch 23 metrics:
Train: {'recon_loss': '0.0012', 'l1_loss': '0.0251', 'kl_loss': '0.0102', 'dict_norm_loss': '0.0763', 'total_loss': '3.5385', 'sparsity_ratio': '0.0752', 'max_activation': '3.0532', 'dead_features': '0.0000', 'feature_usage_std': '0.0467'}
Val: {'recon_loss': '0.0013', 'l1_loss': '0.0260', 'kl_loss': '0.0095', 'dict_norm_loss': '0.0774', 'total_loss': '3.5519', 'sparsity_ratio': '0.0793', 'dead_features': '0.0000', 'rarely_active': '0.0000', 'hyperactive': '14.8421', 'max_activation': '3.0749', 'mean_activation': '0.0260', 'std_activation': '0.1864', 'feature_usage_std': '0.0496', 'feature_usage_entropy': '945.1851', 'mean_correlation': '0.4742', 'feature_importance_std': '0.0000'}


Epoch 24/50: 100%|████████████████████████████| 150/150 [00:02<00:00, 69.42it/s]



Epoch 24 metrics:
Train: {'recon_loss': '0.0014', 'l1_loss': '0.0254', 'kl_loss': '0.0101', 'dict_norm_loss': '0.0787', 'total_loss': '3.5483', 'sparsity_ratio': '0.0728', 'max_activation': '3.0833', 'dead_features': '0.0000', 'feature_usage_std': '0.0471'}
Val: {'recon_loss': '0.0013', 'l1_loss': '0.0248', 'kl_loss': '0.0106', 'dict_norm_loss': '0.0770', 'total_loss': '3.5499', 'sparsity_ratio': '0.0699', 'dead_features': '0.0000', 'rarely_active': '0.0000', 'hyperactive': '15.0000', 'max_activation': '3.0886', 'mean_activation': '0.0248', 'std_activation': '0.1896', 'feature_usage_std': '0.0460', 'feature_usage_entropy': '876.5588', 'mean_correlation': '0.4617', 'feature_importance_std': '0.0000'}


Epoch 25/50: 100%|████████████████████████████| 150/150 [00:02<00:00, 69.05it/s]



Epoch 25 metrics:
Train: {'recon_loss': '0.0013', 'l1_loss': '0.0252', 'kl_loss': '0.0102', 'dict_norm_loss': '0.0758', 'total_loss': '3.5377', 'sparsity_ratio': '0.0686', 'max_activation': '3.1058', 'dead_features': '0.0000', 'feature_usage_std': '0.0460'}
Val: {'recon_loss': '0.0016', 'l1_loss': '0.0260', 'kl_loss': '0.0095', 'dict_norm_loss': '0.0800', 'total_loss': '3.5494', 'sparsity_ratio': '0.0707', 'dead_features': '0.0000', 'rarely_active': '0.0000', 'hyperactive': '14.9211', 'max_activation': '3.1259', 'mean_activation': '0.0260', 'std_activation': '0.1967', 'feature_usage_std': '0.0475', 'feature_usage_entropy': '878.4620', 'mean_correlation': '0.4968', 'feature_importance_std': '0.0000'}


Epoch 26/50: 100%|████████████████████████████| 150/150 [00:02<00:00, 68.99it/s]



Epoch 26 metrics:
Train: {'recon_loss': '0.0014', 'l1_loss': '0.0253', 'kl_loss': '0.0102', 'dict_norm_loss': '0.0802', 'total_loss': '3.5538', 'sparsity_ratio': '0.0649', 'max_activation': '3.1325', 'dead_features': '0.0000', 'feature_usage_std': '0.0460'}
Val: {'recon_loss': '0.0017', 'l1_loss': '0.0253', 'kl_loss': '0.0101', 'dict_norm_loss': '0.0810', 'total_loss': '3.5468', 'sparsity_ratio': '0.0636', 'dead_features': '0.0000', 'rarely_active': '0.0000', 'hyperactive': '15.1842', 'max_activation': '3.1510', 'mean_activation': '0.0253', 'std_activation': '0.2019', 'feature_usage_std': '0.0458', 'feature_usage_entropy': '821.6993', 'mean_correlation': '0.4857', 'feature_importance_std': '0.0000'}


Epoch 27/50: 100%|████████████████████████████| 150/150 [00:02<00:00, 68.89it/s]



Epoch 27 metrics:
Train: {'recon_loss': '0.0014', 'l1_loss': '0.0252', 'kl_loss': '0.0101', 'dict_norm_loss': '0.0789', 'total_loss': '3.5343', 'sparsity_ratio': '0.0612', 'max_activation': '3.1506', 'dead_features': '0.0000', 'feature_usage_std': '0.0448'}
Val: {'recon_loss': '0.0013', 'l1_loss': '0.0260', 'kl_loss': '0.0095', 'dict_norm_loss': '0.0800', 'total_loss': '3.5457', 'sparsity_ratio': '0.0633', 'dead_features': '0.0000', 'rarely_active': '0.0000', 'hyperactive': '14.7632', 'max_activation': '3.1604', 'mean_activation': '0.0260', 'std_activation': '0.2070', 'feature_usage_std': '0.0463', 'feature_usage_entropy': '815.1338', 'mean_correlation': '0.5103', 'feature_importance_std': '0.0000'}


Epoch 28/50: 100%|████████████████████████████| 150/150 [00:02<00:00, 68.73it/s]



Epoch 28 metrics:
Train: {'recon_loss': '0.0014', 'l1_loss': '0.0253', 'kl_loss': '0.0100', 'dict_norm_loss': '0.0766', 'total_loss': '3.5320', 'sparsity_ratio': '0.0597', 'max_activation': '3.1618', 'dead_features': '0.0000', 'feature_usage_std': '0.0449'}
Val: {'recon_loss': '0.0014', 'l1_loss': '0.0257', 'kl_loss': '0.0097', 'dict_norm_loss': '0.0702', 'total_loss': '3.5427', 'sparsity_ratio': '0.0614', 'dead_features': '0.0000', 'rarely_active': '0.0000', 'hyperactive': '14.5263', 'max_activation': '3.1630', 'mean_activation': '0.0257', 'std_activation': '0.2098', 'feature_usage_std': '0.0464', 'feature_usage_entropy': '795.2557', 'mean_correlation': '0.5446', 'feature_importance_std': '0.0000'}


Epoch 29/50: 100%|████████████████████████████| 150/150 [00:02<00:00, 68.90it/s]



Epoch 29 metrics:
Train: {'recon_loss': '0.0014', 'l1_loss': '0.0253', 'kl_loss': '0.0100', 'dict_norm_loss': '0.0769', 'total_loss': '3.5297', 'sparsity_ratio': '0.0583', 'max_activation': '3.1709', 'dead_features': '0.0000', 'feature_usage_std': '0.0456'}
Val: {'recon_loss': '0.0013', 'l1_loss': '0.0249', 'kl_loss': '0.0105', 'dict_norm_loss': '0.0830', 'total_loss': '3.5428', 'sparsity_ratio': '0.0568', 'dead_features': '0.0000', 'rarely_active': '0.0000', 'hyperactive': '14.4474', 'max_activation': '3.1725', 'mean_activation': '0.0249', 'std_activation': '0.2121', 'feature_usage_std': '0.0447', 'feature_usage_entropy': '756.6307', 'mean_correlation': '0.5322', 'feature_importance_std': '0.0000'}


Epoch 30/50: 100%|████████████████████████████| 150/150 [00:02<00:00, 68.82it/s]



Epoch 30 metrics:
Train: {'recon_loss': '0.0014', 'l1_loss': '0.0253', 'kl_loss': '0.0100', 'dict_norm_loss': '0.0817', 'total_loss': '3.5294', 'sparsity_ratio': '0.0559', 'max_activation': '3.1795', 'dead_features': '0.0000', 'feature_usage_std': '0.0448'}
Val: {'recon_loss': '0.0017', 'l1_loss': '0.0253', 'kl_loss': '0.0101', 'dict_norm_loss': '0.0836', 'total_loss': '3.5410', 'sparsity_ratio': '0.0554', 'dead_features': '0.0000', 'rarely_active': '0.0000', 'hyperactive': '13.6579', 'max_activation': '3.1836', 'mean_activation': '0.0253', 'std_activation': '0.2161', 'feature_usage_std': '0.0442', 'feature_usage_entropy': '742.7288', 'mean_correlation': '0.5573', 'feature_importance_std': '0.0000'}


Epoch 31/50: 100%|████████████████████████████| 150/150 [00:02<00:00, 68.74it/s]



Epoch 31 metrics:
Train: {'recon_loss': '0.0014', 'l1_loss': '0.0251', 'kl_loss': '0.0102', 'dict_norm_loss': '0.0864', 'total_loss': '3.5330', 'sparsity_ratio': '0.0527', 'max_activation': '3.1861', 'dead_features': '0.0000', 'feature_usage_std': '0.0433'}
Val: {'recon_loss': '0.0013', 'l1_loss': '0.0261', 'kl_loss': '0.0094', 'dict_norm_loss': '0.0863', 'total_loss': '3.5466', 'sparsity_ratio': '0.0552', 'dead_features': '0.0000', 'rarely_active': '0.0000', 'hyperactive': '13.1053', 'max_activation': '3.1897', 'mean_activation': '0.0261', 'std_activation': '0.2206', 'feature_usage_std': '0.0455', 'feature_usage_entropy': '734.9283', 'mean_correlation': '0.5949', 'feature_importance_std': '0.0000'}


Epoch 32/50: 100%|████████████████████████████| 150/150 [00:02<00:00, 68.82it/s]



Epoch 32 metrics:
Train: {'recon_loss': '0.0015', 'l1_loss': '0.0254', 'kl_loss': '0.0099', 'dict_norm_loss': '0.0888', 'total_loss': '3.5317', 'sparsity_ratio': '0.0511', 'max_activation': '3.1928', 'dead_features': '0.0000', 'feature_usage_std': '0.0430'}
Val: {'recon_loss': '0.0014', 'l1_loss': '0.0253', 'kl_loss': '0.0101', 'dict_norm_loss': '0.0889', 'total_loss': '3.5405', 'sparsity_ratio': '0.0509', 'dead_features': '0.0000', 'rarely_active': '0.1316', 'hyperactive': '13.3158', 'max_activation': '3.1952', 'mean_activation': '0.0253', 'std_activation': '0.2228', 'feature_usage_std': '0.0433', 'feature_usage_entropy': '698.2536', 'mean_correlation': '0.5868', 'feature_importance_std': '0.0000'}


Epoch 33/50: 100%|████████████████████████████| 150/150 [00:02<00:00, 68.81it/s]



Epoch 33 metrics:
Train: {'recon_loss': '0.0014', 'l1_loss': '0.0253', 'kl_loss': '0.0099', 'dict_norm_loss': '0.0896', 'total_loss': '3.5275', 'sparsity_ratio': '0.0492', 'max_activation': '3.1979', 'dead_features': '0.0000', 'feature_usage_std': '0.0428'}
Val: {'recon_loss': '0.0013', 'l1_loss': '0.0252', 'kl_loss': '0.0102', 'dict_norm_loss': '0.0894', 'total_loss': '3.5399', 'sparsity_ratio': '0.0485', 'dead_features': '0.0000', 'rarely_active': '0.2895', 'hyperactive': '12.9474', 'max_activation': '3.1967', 'mean_activation': '0.0252', 'std_activation': '0.2251', 'feature_usage_std': '0.0425', 'feature_usage_entropy': '675.6532', 'mean_correlation': '0.5964', 'feature_importance_std': '0.0000'}


Epoch 34/50: 100%|████████████████████████████| 150/150 [00:02<00:00, 68.85it/s]



Epoch 34 metrics:
Train: {'recon_loss': '0.0014', 'l1_loss': '0.0253', 'kl_loss': '0.0100', 'dict_norm_loss': '0.0914', 'total_loss': '3.5358', 'sparsity_ratio': '0.0470', 'max_activation': '3.2023', 'dead_features': '0.0000', 'feature_usage_std': '0.0423'}
Val: {'recon_loss': '0.0016', 'l1_loss': '0.0248', 'kl_loss': '0.0107', 'dict_norm_loss': '0.0892', 'total_loss': '3.5439', 'sparsity_ratio': '0.0449', 'dead_features': '0.0000', 'rarely_active': '0.6579', 'hyperactive': '12.8158', 'max_activation': '3.2052', 'mean_activation': '0.0248', 'std_activation': '0.2275', 'feature_usage_std': '0.0408', 'feature_usage_entropy': '641.9308', 'mean_correlation': '0.5905', 'feature_importance_std': '0.0000'}


Epoch 35/50: 100%|████████████████████████████| 150/150 [00:02<00:00, 68.77it/s]



Epoch 35 metrics:
Train: {'recon_loss': '0.0015', 'l1_loss': '0.0253', 'kl_loss': '0.0101', 'dict_norm_loss': '0.0884', 'total_loss': '3.5332', 'sparsity_ratio': '0.0446', 'max_activation': '3.2065', 'dead_features': '0.0000', 'feature_usage_std': '0.0413'}
Val: {'recon_loss': '0.0015', 'l1_loss': '0.0253', 'kl_loss': '0.0101', 'dict_norm_loss': '0.0883', 'total_loss': '3.5388', 'sparsity_ratio': '0.0440', 'dead_features': '0.0000', 'rarely_active': '1.6053', 'hyperactive': '12.4474', 'max_activation': '3.2062', 'mean_activation': '0.0253', 'std_activation': '0.2309', 'feature_usage_std': '0.0410', 'feature_usage_entropy': '630.8351', 'mean_correlation': '0.5933', 'feature_importance_std': '0.0000'}


Epoch 36/50: 100%|████████████████████████████| 150/150 [00:02<00:00, 68.77it/s]



Epoch 36 metrics:
Train: {'recon_loss': '0.0015', 'l1_loss': '0.0254', 'kl_loss': '0.0099', 'dict_norm_loss': '0.0880', 'total_loss': '3.5271', 'sparsity_ratio': '0.0430', 'max_activation': '3.2110', 'dead_features': '0.0000', 'feature_usage_std': '0.0410'}
Val: {'recon_loss': '0.0014', 'l1_loss': '0.0254', 'kl_loss': '0.0100', 'dict_norm_loss': '0.0915', 'total_loss': '3.5356', 'sparsity_ratio': '0.0426', 'dead_features': '0.0000', 'rarely_active': '4.4211', 'hyperactive': '12.4474', 'max_activation': '3.2053', 'mean_activation': '0.0254', 'std_activation': '0.2330', 'feature_usage_std': '0.0412', 'feature_usage_entropy': '614.5599', 'mean_correlation': '0.6231', 'feature_importance_std': '0.0000'}


Epoch 37/50: 100%|████████████████████████████| 150/150 [00:02<00:00, 68.89it/s]



Epoch 37 metrics:
Train: {'recon_loss': '0.0014', 'l1_loss': '0.0252', 'kl_loss': '0.0100', 'dict_norm_loss': '0.0941', 'total_loss': '3.5268', 'sparsity_ratio': '0.0410', 'max_activation': '3.2145', 'dead_features': '0.0000', 'feature_usage_std': '0.0409'}
Val: {'recon_loss': '0.0014', 'l1_loss': '0.0257', 'kl_loss': '0.0097', 'dict_norm_loss': '0.0953', 'total_loss': '3.5377', 'sparsity_ratio': '0.0413', 'dead_features': '0.0000', 'rarely_active': '6.8684', 'hyperactive': '12.5000', 'max_activation': '3.2167', 'mean_activation': '0.0257', 'std_activation': '0.2361', 'feature_usage_std': '0.0414', 'feature_usage_entropy': '599.9575', 'mean_correlation': '0.6263', 'feature_importance_std': '0.0000'}


Epoch 38/50: 100%|████████████████████████████| 150/150 [00:02<00:00, 69.00it/s]



Epoch 38 metrics:
Train: {'recon_loss': '0.0015', 'l1_loss': '0.0253', 'kl_loss': '0.0099', 'dict_norm_loss': '0.0928', 'total_loss': '3.5275', 'sparsity_ratio': '0.0396', 'max_activation': '3.2161', 'dead_features': '0.0000', 'feature_usage_std': '0.0405'}
Val: {'recon_loss': '0.0021', 'l1_loss': '0.0256', 'kl_loss': '0.0097', 'dict_norm_loss': '0.0916', 'total_loss': '3.5348', 'sparsity_ratio': '0.0398', 'dead_features': '0.0000', 'rarely_active': '11.7368', 'hyperactive': '11.8421', 'max_activation': '3.2255', 'mean_activation': '0.0256', 'std_activation': '0.2384', 'feature_usage_std': '0.0406', 'feature_usage_entropy': '583.2651', 'mean_correlation': '0.6468', 'feature_importance_std': '0.0000'}


Epoch 39/50: 100%|████████████████████████████| 150/150 [00:02<00:00, 69.02it/s]



Epoch 39 metrics:
Train: {'recon_loss': '0.0016', 'l1_loss': '0.0253', 'kl_loss': '0.0100', 'dict_norm_loss': '0.0908', 'total_loss': '3.5312', 'sparsity_ratio': '0.0381', 'max_activation': '3.2188', 'dead_features': '0.0000', 'feature_usage_std': '0.0404'}
Val: {'recon_loss': '0.0015', 'l1_loss': '0.0254', 'kl_loss': '0.0099', 'dict_norm_loss': '0.0927', 'total_loss': '3.5304', 'sparsity_ratio': '0.0378', 'dead_features': '0.0000', 'rarely_active': '18.6316', 'hyperactive': '11.8684', 'max_activation': '3.2239', 'mean_activation': '0.0254', 'std_activation': '0.2401', 'feature_usage_std': '0.0403', 'feature_usage_entropy': '561.2244', 'mean_correlation': '0.6603', 'feature_importance_std': '0.0000'}


Epoch 40/50: 100%|████████████████████████████| 150/150 [00:02<00:00, 68.87it/s]



Epoch 40 metrics:
Train: {'recon_loss': '0.0016', 'l1_loss': '0.0253', 'kl_loss': '0.0099', 'dict_norm_loss': '0.0916', 'total_loss': '3.5225', 'sparsity_ratio': '0.0368', 'max_activation': '3.2206', 'dead_features': '0.0000', 'feature_usage_std': '0.0401'}
Val: {'recon_loss': '0.0013', 'l1_loss': '0.0250', 'kl_loss': '0.0103', 'dict_norm_loss': '0.0957', 'total_loss': '3.5288', 'sparsity_ratio': '0.0358', 'dead_features': '0.0000', 'rarely_active': '30.8158', 'hyperactive': '11.6316', 'max_activation': '3.2213', 'mean_activation': '0.0250', 'std_activation': '0.2413', 'feature_usage_std': '0.0391', 'feature_usage_entropy': '541.0594', 'mean_correlation': '0.6565', 'feature_importance_std': '0.0000'}


Epoch 41/50: 100%|████████████████████████████| 150/150 [00:02<00:00, 69.33it/s]



Epoch 41 metrics:
Train: {'recon_loss': '0.0015', 'l1_loss': '0.0253', 'kl_loss': '0.0099', 'dict_norm_loss': '0.0987', 'total_loss': '3.5221', 'sparsity_ratio': '0.0357', 'max_activation': '3.2231', 'dead_features': '0.0000', 'feature_usage_std': '0.0394'}
Val: {'recon_loss': '0.0019', 'l1_loss': '0.0251', 'kl_loss': '0.0102', 'dict_norm_loss': '0.0987', 'total_loss': '3.5281', 'sparsity_ratio': '0.0346', 'dead_features': '0.0000', 'rarely_active': '43.7895', 'hyperactive': '11.0263', 'max_activation': '3.2230', 'mean_activation': '0.0251', 'std_activation': '0.2440', 'feature_usage_std': '0.0383', 'feature_usage_entropy': '527.8495', 'mean_correlation': '0.6582', 'feature_importance_std': '0.0000'}


Epoch 42/50: 100%|████████████████████████████| 150/150 [00:02<00:00, 68.69it/s]



Epoch 42 metrics:
Train: {'recon_loss': '0.0016', 'l1_loss': '0.0253', 'kl_loss': '0.0099', 'dict_norm_loss': '0.0968', 'total_loss': '3.5240', 'sparsity_ratio': '0.0342', 'max_activation': '3.2247', 'dead_features': '0.0000', 'feature_usage_std': '0.0391'}
Val: {'recon_loss': '0.0021', 'l1_loss': '0.0257', 'kl_loss': '0.0096', 'dict_norm_loss': '0.0952', 'total_loss': '3.5278', 'sparsity_ratio': '0.0342', 'dead_features': '0.0000', 'rarely_active': '60.1842', 'hyperactive': '11.3684', 'max_activation': '3.2273', 'mean_activation': '0.0257', 'std_activation': '0.2471', 'feature_usage_std': '0.0396', 'feature_usage_entropy': '519.4700', 'mean_correlation': '0.6766', 'feature_importance_std': '0.0000'}


Epoch 43/50: 100%|████████████████████████████| 150/150 [00:02<00:00, 68.60it/s]



Epoch 43 metrics:
Train: {'recon_loss': '0.0015', 'l1_loss': '0.0253', 'kl_loss': '0.0099', 'dict_norm_loss': '0.0986', 'total_loss': '3.5198', 'sparsity_ratio': '0.0327', 'max_activation': '3.2262', 'dead_features': '0.0000', 'feature_usage_std': '0.0383'}
Val: {'recon_loss': '0.0014', 'l1_loss': '0.0251', 'kl_loss': '0.0101', 'dict_norm_loss': '0.1002', 'total_loss': '3.5246', 'sparsity_ratio': '0.0320', 'dead_features': '0.0000', 'rarely_active': '84.4474', 'hyperactive': '10.9211', 'max_activation': '3.2219', 'mean_activation': '0.0251', 'std_activation': '0.2478', 'feature_usage_std': '0.0379', 'feature_usage_entropy': '496.9270', 'mean_correlation': '0.6742', 'feature_importance_std': '0.0000'}


Epoch 44/50: 100%|████████████████████████████| 150/150 [00:02<00:00, 68.69it/s]



Epoch 44 metrics:
Train: {'recon_loss': '0.0015', 'l1_loss': '0.0253', 'kl_loss': '0.0099', 'dict_norm_loss': '0.1017', 'total_loss': '3.5235', 'sparsity_ratio': '0.0318', 'max_activation': '3.2279', 'dead_features': '0.0000', 'feature_usage_std': '0.0386'}
Val: {'recon_loss': '0.0013', 'l1_loss': '0.0251', 'kl_loss': '0.0101', 'dict_norm_loss': '0.1028', 'total_loss': '3.5242', 'sparsity_ratio': '0.0312', 'dead_features': '0.0000', 'rarely_active': '105.5000', 'hyperactive': '10.9737', 'max_activation': '3.2312', 'mean_activation': '0.0251', 'std_activation': '0.2500', 'feature_usage_std': '0.0380', 'feature_usage_entropy': '486.3195', 'mean_correlation': '0.6832', 'feature_importance_std': '0.0000'}


Epoch 45/50: 100%|████████████████████████████| 150/150 [00:02<00:00, 69.13it/s]



Epoch 45 metrics:
Train: {'recon_loss': '0.0016', 'l1_loss': '0.0253', 'kl_loss': '0.0099', 'dict_norm_loss': '0.1022', 'total_loss': '3.5213', 'sparsity_ratio': '0.0306', 'max_activation': '3.2294', 'dead_features': '0.0000', 'feature_usage_std': '0.0380'}
Val: {'recon_loss': '0.0015', 'l1_loss': '0.0256', 'kl_loss': '0.0096', 'dict_norm_loss': '0.1036', 'total_loss': '3.5245', 'sparsity_ratio': '0.0309', 'dead_features': '0.0000', 'rarely_active': '143.6316', 'hyperactive': '10.6579', 'max_activation': '3.2317', 'mean_activation': '0.0256', 'std_activation': '0.2528', 'feature_usage_std': '0.0383', 'feature_usage_entropy': '479.2269', 'mean_correlation': '0.6961', 'feature_importance_std': '0.0000'}


Epoch 46/50: 100%|████████████████████████████| 150/150 [00:02<00:00, 68.84it/s]



Epoch 46 metrics:
Train: {'recon_loss': '0.0015', 'l1_loss': '0.0253', 'kl_loss': '0.0098', 'dict_norm_loss': '0.1039', 'total_loss': '3.5153', 'sparsity_ratio': '0.0297', 'max_activation': '3.2303', 'dead_features': '0.0000', 'feature_usage_std': '0.0378'}
Val: {'recon_loss': '0.0015', 'l1_loss': '0.0255', 'kl_loss': '0.0097', 'dict_norm_loss': '0.1012', 'total_loss': '3.5245', 'sparsity_ratio': '0.0298', 'dead_features': '0.0000', 'rarely_active': '183.1316', 'hyperactive': '10.9737', 'max_activation': '3.2300', 'mean_activation': '0.0255', 'std_activation': '0.2537', 'feature_usage_std': '0.0384', 'feature_usage_entropy': '466.7622', 'mean_correlation': '0.6864', 'feature_importance_std': '0.0000'}


Epoch 47/50: 100%|████████████████████████████| 150/150 [00:02<00:00, 68.83it/s]



Epoch 47 metrics:
Train: {'recon_loss': '0.0016', 'l1_loss': '0.0253', 'kl_loss': '0.0098', 'dict_norm_loss': '0.1030', 'total_loss': '3.5169', 'sparsity_ratio': '0.0289', 'max_activation': '3.2314', 'dead_features': '0.0000', 'feature_usage_std': '0.0378'}
Val: {'recon_loss': '0.0015', 'l1_loss': '0.0257', 'kl_loss': '0.0096', 'dict_norm_loss': '0.1063', 'total_loss': '3.5243', 'sparsity_ratio': '0.0294', 'dead_features': '0.0000', 'rarely_active': '210.6316', 'hyperactive': '10.8158', 'max_activation': '3.2293', 'mean_activation': '0.0257', 'std_activation': '0.2552', 'feature_usage_std': '0.0383', 'feature_usage_entropy': '461.0586', 'mean_correlation': '0.6964', 'feature_importance_std': '0.0000'}


Epoch 48/50: 100%|████████████████████████████| 150/150 [00:02<00:00, 68.96it/s]



Epoch 48 metrics:
Train: {'recon_loss': '0.0012', 'l1_loss': '0.0253', 'kl_loss': '0.0099', 'dict_norm_loss': '0.1071', 'total_loss': '3.5186', 'sparsity_ratio': '0.0283', 'max_activation': '3.2308', 'dead_features': '0.0000', 'feature_usage_std': '0.0377'}
Val: {'recon_loss': '0.0012', 'l1_loss': '0.0251', 'kl_loss': '0.0101', 'dict_norm_loss': '0.1078', 'total_loss': '3.5240', 'sparsity_ratio': '0.0277', 'dead_features': '0.0000', 'rarely_active': '233.5263', 'hyperactive': '10.8947', 'max_activation': '3.2315', 'mean_activation': '0.0251', 'std_activation': '0.2550', 'feature_usage_std': '0.0370', 'feature_usage_entropy': '444.8228', 'mean_correlation': '0.6871', 'feature_importance_std': '0.0000'}


Epoch 49/50: 100%|████████████████████████████| 150/150 [00:02<00:00, 68.75it/s]



Epoch 49 metrics:
Train: {'recon_loss': '0.0012', 'l1_loss': '0.0253', 'kl_loss': '0.0098', 'dict_norm_loss': '0.1092', 'total_loss': '3.5179', 'sparsity_ratio': '0.0277', 'max_activation': '3.2318', 'dead_features': '0.0000', 'feature_usage_std': '0.0375'}
Val: {'recon_loss': '0.0012', 'l1_loss': '0.0252', 'kl_loss': '0.0100', 'dict_norm_loss': '0.1094', 'total_loss': '3.5227', 'sparsity_ratio': '0.0273', 'dead_features': '0.0000', 'rarely_active': '269.3947', 'hyperactive': '10.9211', 'max_activation': '3.2324', 'mean_activation': '0.0252', 'std_activation': '0.2559', 'feature_usage_std': '0.0370', 'feature_usage_entropy': '438.5318', 'mean_correlation': '0.6882', 'feature_importance_std': '0.0000'}


Epoch 50/50: 100%|████████████████████████████| 150/150 [00:02<00:00, 68.77it/s]



Epoch 50 metrics:
Train: {'recon_loss': '0.0012', 'l1_loss': '0.0253', 'kl_loss': '0.0099', 'dict_norm_loss': '0.1097', 'total_loss': '3.5208', 'sparsity_ratio': '0.0272', 'max_activation': '3.2320', 'dead_features': '0.0000', 'feature_usage_std': '0.0374'}
Val: {'recon_loss': '0.0012', 'l1_loss': '0.0255', 'kl_loss': '0.0097', 'dict_norm_loss': '0.1095', 'total_loss': '3.5222', 'sparsity_ratio': '0.0276', 'dead_features': '0.0000', 'rarely_active': '308.9211', 'hyperactive': '10.8158', 'max_activation': '3.2313', 'mean_activation': '0.0255', 'std_activation': '0.2569', 'feature_usage_std': '0.0377', 'feature_usage_entropy': '439.4990', 'mean_correlation': '0.6909', 'feature_importance_std': '0.0000'}


In [16]:
# NOTE(review): this cell fails when executed -- its recorded output is
# "NameError: name 'vae' is not defined", i.e. no variable named `vae` exists
# in the kernel namespace at this point.
# Presumably a leftover attempt to inspect the trained model object -- TODO
# confirm the intended variable name, or delete this cell so the notebook
# runs cleanly from a fresh kernel.
vae

NameError: name 'vae' is not defined