# In [None]:  (Jupyter notebook cell marker left over from the export)


import unittest
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
import pandas as pd
import json
from datetime import datetime
import os


def logistic_map(x, r=3.99):
    """Apply one step of the logistic map ``x -> r * x * (1 - x)``.

    Args:
        x: Current state of the map.
        r: Growth-rate parameter (defaults to 3.99).

    Returns:
        The next state of the map.
    """
    scaled = r * x
    complement = 1 - x
    return scaled * complement

def generate_logistic_map_image(image_size=28, initial_value=0.4, r=3.99):
    """Render a logistic-map trajectory as a square image.

    The map is iterated ``image_size * image_size`` times starting from
    ``initial_value``. Only the iterates are stored; the seed itself is not
    part of the image.

    Args:
        image_size: Side length of the square image.
        initial_value: Seed fed into the first iteration.
        r: Logistic-map parameter forwarded to ``logistic_map``.

    Returns:
        Array of shape ``(image_size, image_size)`` holding successive
        iterates in row-major order.
    """
    n_pixels = image_size * image_size
    trajectory = []
    state = initial_value
    for _step in range(n_pixels):
        state = logistic_map(state, r)
        trajectory.append(state)
    return np.array(trajectory).reshape((image_size, image_size))

def generate_logistic_dataset(num_images, image_size=28, r=3.99, fixed_initial=False):
    """Build a batch of logistic-map images.

    Args:
        num_images: Number of images to generate.
        image_size: Side length of each square image.
        r: Logistic-map parameter used for every image.
        fixed_initial: When true every image starts from 0.4; otherwise each
            image gets its own seed drawn with ``np.random.rand()``.

    Returns:
        ``float32`` array with a trailing axis of size 1 appended to the
        stacked images.
    """
    images = []
    for _ in range(num_images):
        # One RNG draw per image, and only when a random seed is requested.
        if fixed_initial:
            seed = 0.4
        else:
            seed = np.random.rand()
        images.append(
            generate_logistic_map_image(image_size=image_size, initial_value=seed, r=r)
        )
    stacked = np.array(images)
    return stacked[..., np.newaxis].astype('float32')

def henon_map(x, y, a=1.4, b=0.3):
    """Advance the 2D Henon map by one step.

    Both coordinates are clipped to ``[-2, 2]`` before the update so the
    iteration cannot run away numerically.

    Args:
        x: Current x coordinate.
        y: Current y coordinate.
        a: Quadratic coefficient of the map.
        b: Linear coefficient of the map.

    Returns:
        Tuple ``(x_new, y_new)`` with ``x_new = 1 - a * x**2 + y`` and
        ``y_new = b * x``, both computed from the clipped inputs.
    """
    bound = 2.0
    x_clipped = np.clip(x, -bound, bound)
    y_clipped = np.clip(y, -bound, bound)
    return 1 - a * x_clipped**2 + y_clipped, b * x_clipped

def generate_henon_image(image_size=28, x0=0.1, y0=0.1, a=1.4, b=0.3):
    """Render the x coordinate of a Henon trajectory as a square image.

    Args:
        image_size: Side length of the square image.
        x0: Initial x coordinate.
        y0: Initial y coordinate.
        a: Henon parameter forwarded to ``henon_map``.
        b: Henon parameter forwarded to ``henon_map``.

    Returns:
        Array of shape ``(image_size, image_size)`` whose entries are the
        rescaled x iterates, clipped to ``[0, 1]``, in row-major order.
    """
    n_steps = image_size * image_size
    pos_x, pos_y = x0, y0
    pixels = []
    for _step in range(n_steps):
        pos_x, pos_y = henon_map(pos_x, pos_y, a, b)
        # Affine map of x from [-2, 2] onto [0, 1]; the clip guards the edges.
        pixels.append(np.clip((pos_x + 2.0) / 4.0, 0, 1))
    return np.array(pixels).reshape((image_size, image_size))

def generate_henon_dataset(num_images, image_size=28, a=1.4, b=0.3):
    """Build a batch of Henon-attractor images from random starting points.

    Each image starts from ``(x0, y0)`` drawn independently and uniformly
    from ``[-0.5, 0.5)``; ``x0`` is drawn before ``y0``.

    Args:
        num_images: Number of images to generate.
        image_size: Side length of each square image.
        a: Henon parameter used for every image.
        b: Henon parameter used for every image.

    Returns:
        ``float32`` array with a trailing axis of size 1 appended to the
        stacked images.
    """
    # Call arguments are evaluated left to right, so x0 is sampled before y0.
    images = [
        generate_henon_image(
            image_size,
            np.random.uniform(-0.5, 0.5),
            np.random.uniform(-0.5, 0.5),
            a,
            b,
        )
        for _ in range(num_images)
    ]
    stacked = np.array(images)
    return stacked[..., np.newaxis].astype('float32')


class KSparseLayer(layers.Layer):
    """Keep the ``k`` largest-magnitude activations per sample and zero the rest.

    The selection is made on ``abs(inputs)``, and the surviving entries keep
    their original (signed) values.

    Args:
        k: Number of activations kept per sample.
        **kwargs: Forwarded to ``layers.Layer``.
    """

    def __init__(self, k=32, **kwargs):
        super().__init__(**kwargs)
        self.k = k

    def call(self, inputs, training=None):
        """Mask ``inputs`` down to its top-``k`` entries by absolute value.

        Args:
            inputs: Activations to sparsify. Assumes a 2-D
                ``(batch, latent_dim)`` tensor, since the mask is built from
                axis 1 — TODO confirm against callers.
            training: Accepted for Keras compatibility; not used, so the same
                masking is applied in training and inference.

        Returns:
            ``inputs`` multiplied elementwise by a 0/1 mask.
        """
        latent_dim = tf.shape(inputs)[1]
        # Only the positions matter; the magnitudes themselves are discarded.
        _, top_indices = tf.nn.top_k(tf.abs(inputs), k=self.k, sorted=False)
        # One-hot each selected index, then collapse the k axis into one mask.
        mask = tf.reduce_sum(
            tf.one_hot(top_indices, latent_dim, dtype=inputs.dtype),
            axis=1
        )
        return inputs * mask

    def get_config(self):
        """Return the layer config including ``k`` for serialization."""
        config = super().get_config()
        config.update({"k": self.k})
        return config

class TargetVarianceRegularizer(layers.Layer):
    """Identity layer that penalizes deviation from a target activation variance.

    On every call it adds
    ``lambda_reg * (mean_over_units(var_over_batch(inputs)) - target_variance)**2``
    through ``add_loss`` and passes ``inputs`` through unchanged.

    Args:
        lambda_reg: Weight of the penalty term.
        target_variance: Variance the mean per-unit variance is pulled toward.
        **kwargs: Forwarded to ``layers.Layer``.
    """

    def __init__(self, lambda_reg=0.01, target_variance=0.1, **kwargs):
        super().__init__(**kwargs)
        self.lambda_reg = lambda_reg
        self.target_variance = target_variance

    def call(self, inputs):
        """Register the variance penalty and return ``inputs`` untouched."""
        # Variance along axis 0 (across the batch), one value per unit.
        per_unit_variance = tf.math.reduce_variance(inputs, axis=0)
        deviation = tf.reduce_mean(per_unit_variance) - self.target_variance
        self.add_loss(self.lambda_reg * tf.square(deviation))
        return inputs

    def get_config(self):
        """Return the layer config including both regularizer settings."""
        base_config = super().get_config()
        base_config.update({
            "lambda_reg": self.lambda_reg,
            "target_variance": self.target_variance
        })
        return base_config

def chaos_activation(x):
    """Activation ``sin(8x) + 0.5 * tanh(4x)``.

    Args:
        x: Input tensor.

    Returns:
        Tensor produced by applying the activation elementwise.
    """
    oscillation = tf.sin(8.0 * x)
    saturation = 0.5 * tf.tanh(4.0 * x)
    return oscillation + saturation


def build_ksparse_chaos_ae(image_size=(28, 28), latent_dim=128, k_active=32):
    """Build and compile the K-Sparse Chaos autoencoder.

    Encoder: Flatten -> Dense(256) -> chaos activation -> Dropout(0.2) ->
    Dense(latent_dim) -> chaos activation -> KSparseLayer(k_active) ->
    TargetVarianceRegularizer. Decoder: Dense(256) -> BatchNormalization ->
    chaos activation -> Dropout(0.1) -> Dense(pixels, sigmoid) -> Reshape.

    Args:
        image_size: ``(height, width)`` of the single-channel input images.
        latent_dim: Width of the latent layer.
        k_active: Number of latent units kept active per sample.

    Returns:
        Tuple ``(autoencoder, encoder)``. The autoencoder is compiled with
        the Adam optimizer and MSE loss; the encoder shares its layers.

    Raises:
        ValueError: If ``k_active`` exceeds ``latent_dim``.
    """
    # Fail early with a clear message instead of deep inside tf.nn.top_k.
    if k_active > latent_dim:
        raise ValueError(
            f"k_active ({k_active}) cannot exceed latent_dim ({latent_dim})"
        )

    input_img = keras.Input(shape=(*image_size, 1))
    x = layers.Flatten()(input_img)

    x = layers.Dense(256)(x)
    x = layers.Activation(chaos_activation)(x)
    x = layers.Dropout(0.2)(x)

    latent_pre = layers.Dense(latent_dim, name='latent_pre')(x)
    latent_pre = layers.Activation(chaos_activation)(latent_pre)

    latent = KSparseLayer(k=k_active, name='latent_ksparse')(latent_pre)
    # Identity layer; it only contributes an auxiliary loss term.
    latent = TargetVarianceRegularizer(
        lambda_reg=0.01,
        target_variance=0.1
    )(latent)

    encoder = keras.Model(input_img, latent, name='ksparse_chaos_encoder')

    x = layers.Dense(256)(latent)
    x = layers.BatchNormalization()(x)
    x = layers.Activation(chaos_activation)(x)
    x = layers.Dropout(0.1)(x)

    # np.prod returns a NumPy integer; Dense expects a plain Python int.
    num_pixels = int(np.prod(image_size))
    decoded = layers.Dense(num_pixels, activation='sigmoid')(x)
    decoded = layers.Reshape((*image_size, 1))(decoded)

    autoencoder = keras.Model(input_img, decoded, name='ksparse_chaos_autoencoder')
    autoencoder.compile(optimizer='adam', loss='mse')

    return autoencoder, encoder

def build_dense_relu_ae(image_size=(28, 28), latent_dim=64):
    """Build and compile the dense ReLU baseline autoencoder.

    Encoder: Flatten -> Dense(256) -> Dense(128) -> Dense(latent_dim), all
    ReLU. Decoder mirrors it (128 -> 256) and ends with a sigmoid layer
    reshaped back to the input image shape.

    Args:
        image_size: ``(height, width)`` of the single-channel input images.
        latent_dim: Width of the latent layer.

    Returns:
        Tuple ``(autoencoder, encoder)``. The autoencoder is compiled with
        Adam (learning rate 1e-3) and MSE loss.
    """
    height, width = image_size
    input_img = keras.Input(shape=(height, width, 1))

    # Encoder stack
    hidden = layers.Flatten()(input_img)
    for units in (256, 128):
        hidden = layers.Dense(units, activation="relu")(hidden)
    latent = layers.Dense(latent_dim, activation="relu", name="latent")(hidden)

    encoder = keras.Model(input_img, latent, name="dense_encoder")

    # Decoder stack
    hidden = latent
    for units in (128, 256):
        hidden = layers.Dense(units, activation="relu")(hidden)
    flat_output = layers.Dense(height * width, activation="sigmoid")(hidden)
    decoded = layers.Reshape((height, width, 1))(flat_output)

    autoencoder = keras.Model(input_img, decoded)
    autoencoder.compile(optimizer=keras.optimizers.Adam(1e-3), loss="mse")

    return autoencoder, encoder


def analyze_latent_statistics(encoder, images, zero_threshold=1e-6):
    """Summarize variance, sparsity and dead units of an encoder's latent codes.

    Args:
        encoder: Object exposing ``predict(images, verbose=0)`` that returns a
            2-D array of latent codes (samples along axis 0).
        images: Inputs handed to ``encoder.predict``.
        zero_threshold: Absolute values below this count as zero.

    Returns:
        Dict with keys ``variance_per_dim``, ``mean_variance``,
        ``dead_neurons``, ``total_neurons``, ``dead_percentage`` (a fraction
        in [0, 1], not a percent), ``overall_sparsity``,
        ``mean_active_neurons``, ``mean_variance_active`` and ``latents``.
    """
    codes = encoder.predict(images, verbose=0)
    n_units = codes.shape[1]

    # Per-unit variance across samples, and its average.
    per_unit_variance = np.var(codes, axis=0)
    average_variance = float(np.mean(per_unit_variance))

    # A unit is "dead" when it is (near) zero for every sample.
    n_dead = int(np.sum(np.all(np.abs(codes) < zero_threshold, axis=0)))

    # Fraction of all entries that are (near) zero.
    near_zero = np.abs(codes) < zero_threshold
    sparsity = np.mean(near_zero)

    # Average number of non-zero units per sample.
    firing = ~near_zero
    average_active = np.mean(np.sum(firing, axis=1))

    # Variance over firing samples only; units firing at most once are skipped.
    active_variances = [
        np.var(codes[:, unit][firing[:, unit]])
        for unit in range(n_units)
        if np.sum(firing[:, unit]) > 1
    ]
    if active_variances:
        average_active_variance = np.mean(active_variances)
    else:
        average_active_variance = 0.0

    return {
        'variance_per_dim': per_unit_variance,
        'mean_variance': average_variance,
        'dead_neurons': n_dead,
        'total_neurons': n_units,
        'dead_percentage': n_dead / n_units,
        'overall_sparsity': sparsity,
        'mean_active_neurons': average_active,
        'mean_variance_active': average_active_variance,
        'latents': codes
    }

def track_dead_neurons_over_time(model, encoder, images, epochs=50, batch_size=64,
                                 eval_samples=200):
    """Train one epoch at a time and record latent statistics after each.

    Args:
        model: Autoencoder trained with ``images`` as both input and target.
        encoder: Encoder analyzed after every epoch. Presumably shares weights
            with ``model`` — verify against the caller.
        images: Training images; the leading ``eval_samples`` of them are
            also used for the per-epoch analysis.
        epochs: Number of single-epoch training rounds.
        batch_size: Batch size passed to ``model.fit``.
        eval_samples: How many leading images are analyzed after each epoch
            (defaults to 200).

    Returns:
        List with one dict per epoch holding ``epoch`` (0-based),
        ``dead_neurons``, ``mean_variance`` and ``sparsity``.
    """
    trajectory = []
    probe_images = images[:eval_samples]

    for epoch in range(epochs):
        # Advance training by exactly one epoch.
        model.fit(images, images, epochs=1, batch_size=batch_size, verbose=0)

        # Local name avoids shadowing the module-level ``scipy.stats`` import.
        snapshot = analyze_latent_statistics(encoder, probe_images)

        trajectory.append({
            'epoch': epoch,
            'dead_neurons': snapshot['dead_neurons'],
            'mean_variance': snapshot['mean_variance'],
            'sparsity': snapshot['overall_sparsity']
        })

    return trajectory


class TestKSparseAblation(unittest.TestCase):
    """
    CRITICAL EXPERIMENT #1: K-Sparse Ablation Study
    Test different K values to find optimal sparsity

    One K-Sparse Chaos autoencoder is trained per K value in ``setUpClass``;
    the test methods then read the shared class-level ``results`` dict.
    """

    @classmethod
    def setUpClass(cls):
        """Generate the data and train one model per K value (runs once)."""
        print("\n" + "="*80)
        print("K-SPARSE ABLATION STUDY")
        print("="*80)

        # Generate data once
        cls.train_images = generate_logistic_dataset(2000, fixed_initial=False)
        cls.test_images = generate_logistic_dataset(500, fixed_initial=False)

        # Test different K values
        cls.k_values = [4, 8, 16, 32, 64, 96, 112]
        cls.latent_dim = 128
        cls.results = {}

        print(f"\nTesting K values: {cls.k_values}")
        print(f"Latent dimension: {cls.latent_dim}")
        print(f"Sparsity range: {(cls.latent_dim - max(cls.k_values))/cls.latent_dim:.1%} to {(cls.latent_dim - min(cls.k_values))/cls.latent_dim:.1%}")

        # Train models for each K
        for k in cls.k_values:
            print(f"\n[K={k}] Training K-Sparse Chaos AE...")
            ae, enc = build_ksparse_chaos_ae(
                latent_dim=cls.latent_dim,
                k_active=k
            )

            history = ae.fit(
                cls.train_images, cls.train_images,
                epochs=10,
                batch_size=64,
                validation_split=0.1,
                verbose=1
            )

            # Analyze
            # NOTE: this local ``stats`` shadows the module-level scipy ``stats``
            # import inside this method only.
            stats = analyze_latent_statistics(enc, cls.test_images)

            # 'sparsity' here is the nominal value implied by K, not a measured one.
            cls.results[k] = {
                'autoencoder': ae,
                'encoder': enc,
                'stats': stats,
                'val_loss': history.history['val_loss'][-1],
                'sparsity': (cls.latent_dim - k) / cls.latent_dim
            }

            print(f"  Variance: {stats['mean_variance']:.6f}")
            print(f"  Dead neurons: {stats['dead_neurons']}/{cls.latent_dim}")
            print(f"  Sparsity: {cls.results[k]['sparsity']:.1%}")

    def test_01_variance_vs_k(self):
        """Test how variance changes with K

        Counts adjacent K pairs where the mean variance decreases and requires
        fewer than ``len(variances) // 2`` such decreases.
        """
        print("\n" + "="*80)
        print("TEST: Variance vs K")
        print("="*80)

        for k in self.k_values:
            variance = self.results[k]['stats']['mean_variance']
            print(f"K={k:3d}: variance = {variance:.6f}")

        # Check that variance generally increases with K
        variances = [self.results[k]['stats']['mean_variance'] for k in self.k_values]

        # At least monotonic trend (allowing some noise)
        trend_violations = 0
        for i in range(len(variances)-1):
            if variances[i+1] < variances[i]:
                trend_violations += 1

        print(f"\nTrend violations: {trend_violations}/{len(variances)-1}")
        self.assertLess(trend_violations, len(variances)//2,
                       "Variance should generally increase with K")

    def test_02_find_optimal_k(self):
        """Find K that balances variance and sparsity

        Scores each K as ``variance / (1 - sparsity + 0.01)`` and reports the
        K with the highest score. Makes no assertions.
        """
        print("\n" + "="*80)
        print("TEST: Find Optimal K")
        print("="*80)

        # Define metric: variance / (1 - sparsity)
        # Higher = better efficiency
        scores = {}
        for k in self.k_values:
            variance = self.results[k]['stats']['mean_variance']
            sparsity = self.results[k]['sparsity']
            score = variance / (1 - sparsity + 0.01)  # avoid division by zero
            scores[k] = score
            print(f"K={k:3d}: score = {score:.3f} (var={variance:.3f}, sparsity={sparsity:.1%})")

        optimal_k = max(scores, key=scores.get)
        print(f"\n‚≠ê Optimal K: {optimal_k}")
        print(f"   Variance: {self.results[optimal_k]['stats']['mean_variance']:.6f}")
        print(f"   Sparsity: {self.results[optimal_k]['sparsity']:.1%}")

        # Save result
        # NOTE(review): this is stored on the test instance; unittest builds a
        # separate instance per test method, so other tests will not see it.
        self.optimal_k = optimal_k

    def test_03_create_ablation_plot(self):
        """Create comprehensive ablation visualization

        Writes a 2x2 figure to ``k_sparse_ablation.png`` in the current working
        directory and then prints the LaTeX summary table.
        """
        print("\n" + "="*80)
        print("Creating K-Sparse Ablation Plots")
        print("="*80)

        fig, axes = plt.subplots(2, 2, figsize=(14, 10))

        # Per-K series, all in the order of ``k_values``.
        k_vals = self.k_values
        variances = [self.results[k]['stats']['mean_variance'] for k in k_vals]
        sparsities = [self.results[k]['sparsity'] for k in k_vals]
        dead = [self.results[k]['stats']['dead_neurons'] for k in k_vals]
        losses = [self.results[k]['val_loss'] for k in k_vals]

        # Plot 1: Variance vs K
        axes[0, 0].plot(k_vals, variances, 'o-', linewidth=2, markersize=8)
        axes[0, 0].set_xlabel('K (Active Neurons)')
        axes[0, 0].set_ylabel('Mean Variance')
        axes[0, 0].set_title('Variance vs K', fontweight='bold')
        axes[0, 0].grid(True, alpha=0.3)
        axes[0, 0].axvline(x=32, color='red', linestyle='--', alpha=0.5, label='K=32 (your choice)')
        axes[0, 0].legend()

        # Plot 2: Sparsity vs Variance
        axes[0, 1].plot(sparsities, variances, 'o-', linewidth=2, markersize=8)
        axes[0, 1].set_xlabel('Sparsity')
        axes[0, 1].set_ylabel('Mean Variance')
        axes[0, 1].set_title('Sparsity-Variance Trade-off', fontweight='bold')
        axes[0, 1].grid(True, alpha=0.3)

        # Annotate K=32
        # NOTE: ``index`` raises ValueError if 32 is ever removed from k_values.
        idx_32 = k_vals.index(32)
        axes[0, 1].annotate('K=32',
                           xy=(sparsities[idx_32], variances[idx_32]),
                           xytext=(10, 10), textcoords='offset points',
                           bbox=dict(boxstyle='round', facecolor='yellow', alpha=0.7),
                           arrowprops=dict(arrowstyle='->', color='red'))

        # Plot 3: Dead Neurons vs K
        # Bars sit at evenly spaced positions and are labelled with the K values.
        axes[1, 0].bar(range(len(k_vals)), dead, color='steelblue', alpha=0.7)
        axes[1, 0].set_xticks(range(len(k_vals)))
        axes[1, 0].set_xticklabels(k_vals)
        axes[1, 0].set_xlabel('K (Active Neurons)')
        axes[1, 0].set_ylabel('Dead Neurons')
        axes[1, 0].set_title('Dead Neurons vs K', fontweight='bold')
        axes[1, 0].grid(True, alpha=0.3, axis='y')

        # Plot 4: Reconstruction Loss vs K
        axes[1, 1].plot(k_vals, losses, 'o-', linewidth=2, markersize=8, color='green')
        axes[1, 1].set_xlabel('K (Active Neurons)')
        axes[1, 1].set_ylabel('Validation Loss (MSE)')
        axes[1, 1].set_title('Reconstruction Quality vs K', fontweight='bold')
        axes[1, 1].grid(True, alpha=0.3)

        plt.suptitle('K-Sparse Ablation Study: Finding Optimal Sparsity',
                    fontsize=14, fontweight='bold')
        plt.tight_layout()
        plt.savefig('k_sparse_ablation.png', dpi=300, bbox_inches='tight')
        print("‚úì Saved: k_sparse_ablation.png")
        plt.close()

        # Also create a summary table
        self.create_ablation_table()

    def create_ablation_table(self):
        """Create LaTeX-ready table

        Prints one row per K (sparsity, mean variance, dead-neuron count and
        final validation loss) to stdout. Not collected as a test because the
        name lacks the ``test`` prefix; it is called from
        ``test_03_create_ablation_plot``.
        """
        print("\n" + "="*80)
        print("ABLATION RESULTS TABLE (LaTeX format)")
        print("="*80)

        print("\\begin{table}[h]")
        print("\\centering")
        print("\\begin{tabular}{ccccc}")
        print("\\hline")
        print("K & Sparsity & Variance & Dead Neurons & Val Loss \\\\")
        print("\\hline")

        for k in self.k_values:
            r = self.results[k]
            print(f"{k} & {r['sparsity']:.1%} & {r['stats']['mean_variance']:.4f} & "
                  f"{r['stats']['dead_neurons']}/{self.latent_dim} & {r['val_loss']:.4f} \\\\")

        print("\\hline")
        print("\\end{tabular}")
        print("\\caption{K-Sparse ablation study results}")
        print("\\end{table}")


class TestFairBaselineComparison(unittest.TestCase):
    """
    CRITICAL EXPERIMENT #2: FAIR BASELINE COMPARISON
    Compare Dense_64 vs Dense_128 vs V4_128 (fair comparison!)

    ``setUpClass`` trains every architecture ``num_runs`` times and stores the
    per-run metrics in the class-level ``results`` dict. unittest creates a
    fresh instance for every test method, so anything shared between tests
    must live on the class or be recomputed; ``_summarize_results`` does the
    latter.
    """

    @classmethod
    def setUpClass(cls):
        """Generate the data and train every architecture ``num_runs`` times."""
        print("\n" + "="*80)
        print("üéØ FAIR BASELINE COMPARISON (Dense_64 vs Dense_128 vs V4_128)")
        print("="*80)

        cls.num_runs = 10  # Increased from 5 to 10 for better statistics
        cls.architectures = {
            'Dense_ReLU_64': lambda: build_dense_relu_ae(latent_dim=64),
            'Dense_ReLU_128': lambda: build_dense_relu_ae(latent_dim=128),  # new: same latent size as V4
            'V4_KSparse_128': lambda: build_ksparse_chaos_ae(latent_dim=128, k_active=32)
        }

        cls.results = {name: [] for name in cls.architectures.keys()}

        # Generate data once
        cls.train_images = generate_logistic_dataset(2000, fixed_initial=False)
        cls.test_images = generate_logistic_dataset(500, fixed_initial=False)

        print(f"\nüìä –ó–∞–ø—É—Å–∫–∞–µ–º N={cls.num_runs} runs –¥–ª—è –∫–∞–∂–¥–æ–π –∞—Ä—Ö–∏—Ç–µ–∫—Ç—É—Ä—ã...")
        print(f"   Train: {cls.train_images.shape}")
        print(f"   Test: {cls.test_images.shape}")
        print("\n‚ö†Ô∏è –≠—Ç–æ –∑–∞–π–º–µ—Ç ~2-3 —á–∞—Å–∞. –ù–∞–±–µ—Ä–∏—Ç–µ—Å—å —Ç–µ—Ä–ø–µ–Ω–∏—è!\n")

        # Run experiments
        for arch_idx, (arch_name, builder) in enumerate(cls.architectures.items(), 1):
            print(f"\n{'='*70}")
            # Denominator follows the dict size instead of a hard-coded 3.
            print(f"[{arch_idx}/{len(cls.architectures)}] –ê—Ä—Ö–∏—Ç–µ–∫—Ç—É—Ä–∞: {arch_name}")
            print(f"{'='*70}")

            for run in range(cls.num_runs):
                print(f"  [Run {run+1}/{cls.num_runs}] ", end="")

                # Set seeds (the run index doubles as the seed)
                np.random.seed(run)
                tf.random.set_seed(run)

                # Build and train
                ae, enc = builder()
                history = ae.fit(
                    cls.train_images, cls.train_images,
                    epochs=10,
                    batch_size=64,
                    validation_split=0.1,
                    verbose=0  # quiet mode
                )

                # Analyze (local name avoids shadowing the scipy ``stats`` import)
                latent_stats = analyze_latent_statistics(enc, cls.test_images)

                cls.results[arch_name].append({
                    'run': run,
                    'variance': latent_stats['mean_variance'],
                    'dead_neurons': latent_stats['dead_neurons'],
                    'total_neurons': latent_stats['total_neurons'],
                    'dead_percentage': latent_stats['dead_percentage'],
                    'val_loss': history.history['val_loss'][-1]
                })

                print(f"var={latent_stats['mean_variance']:.4f}, "
                      f"dead={latent_stats['dead_neurons']}/{latent_stats['total_neurons']} "
                      f"({latent_stats['dead_percentage']:.0%}), "
                      f"loss={history.history['val_loss'][-1]:.4f}")

    @classmethod
    def _summarize_results(cls):
        """Aggregate ``cls.results`` into per-architecture summary statistics.

        Returns:
            Dict mapping architecture name to a dict with the raw
            ``variances`` and ``dead_pcts`` lists plus ``var_mean``,
            ``var_std``, ``var_ci``, ``dead_mean``, ``dead_std``,
            ``dead_pct_mean``, ``loss_mean`` and ``loss_std``.
        """
        summary = {}
        for arch_name, runs in cls.results.items():
            # Extract metrics
            variances = [r['variance'] for r in runs]
            dead_counts = [r['dead_neurons'] for r in runs]
            dead_pcts = [r['dead_percentage'] for r in runs]
            losses = [r['val_loss'] for r in runs]

            # NOTE(review): np.std defaults to the population std (ddof=0) and
            # the interval uses the normal quantile 1.96; with few runs a
            # sample std and a t quantile would give a wider interval.
            var_std = np.std(variances)

            summary[arch_name] = {
                'variances': variances,
                'dead_pcts': dead_pcts,
                'var_mean': np.mean(variances),
                'var_std': var_std,
                'var_ci': 1.96 * var_std / np.sqrt(len(variances)),  # 95% CI
                'dead_mean': np.mean(dead_counts),
                'dead_std': np.std(dead_counts),
                'dead_pct_mean': np.mean(dead_pcts),
                'loss_mean': np.mean(losses),
                'loss_std': np.std(losses)
            }
        return summary

    def test_01_compute_statistics(self):
        """Compute mean, std, confidence intervals"""
        print("\n" + "="*80)
        print("TEST: Statistical Summary")
        print("="*80)

        summary = self._summarize_results()

        for arch_name, entry in summary.items():
            print(f"\n{arch_name}:")
            print(f"{'='*50}")

            var_mean = entry['var_mean']
            var_std = entry['var_std']
            var_ci = entry['var_ci']
            dead_mean = entry['dead_mean']
            dead_std = entry['dead_std']
            dead_pct_mean = entry['dead_pct_mean']
            loss_mean = entry['loss_mean']
            loss_std = entry['loss_std']

            print(f"Variance:     {var_mean:.6f} ¬± {var_std:.6f} (CI: ¬±{var_ci:.6f})")
            print(f"Dead neurons: {dead_mean:.1f} ¬± {dead_std:.1f} ({dead_pct_mean:.1%})")
            print(f"Val loss:     {loss_mean:.6f} ¬± {loss_std:.6f}")

        # Publish on the class: an instance attribute would be lost, because
        # every test method runs on its own TestCase instance.
        type(self).summary = summary

    def test_02_fair_statistical_comparison(self):
        """Compare Dense_128 vs V4_128 (FAIR)"""
        print("\n" + "="*80)
        print("TEST: FAIR Statistical Comparison (t-tests)")
        print("="*80)

        # Recompute from the class-level results so this test does not depend
        # on test_01 having run (or on sharing its instance).
        summary = self._summarize_results()

        # 1. Dense_64 vs Dense_128 (effect of dimensionality)
        if 'Dense_ReLU_64' in summary and 'Dense_ReLU_128' in summary:
            print("\n1. Dense_64 vs Dense_128 (–í–ª–∏—è–Ω–∏–µ —Ä–∞–∑–º–µ—Ä–Ω–æ—Å—Ç–∏):")
            print("   " + "-"*60)

            vars_64 = summary['Dense_ReLU_64']['variances']
            vars_128 = summary['Dense_ReLU_128']['variances']

            t_stat, p_value = stats.ttest_ind(vars_64, vars_128)

            mean_64 = np.mean(vars_64)
            mean_128 = np.mean(vars_128)
            improvement = mean_128 / mean_64

            print(f"   Variance: {mean_64:.4f} ‚Üí {mean_128:.4f}")
            print(f"   t-statistic: {t_stat:.4f}")
            print(f"   p-value: {p_value:.6f}")

            if p_value < 0.05:
                print(f"   ‚úì –ó–ù–ê–ß–ò–ú–û: Dense_128 –∏–º–µ–µ—Ç –≤ {improvement:.2f}√ó –≤—ã—à–µ variance (p<0.05)")
                print(f"   ‚Üí –ü—Ä–æ—Å—Ç–æ –±–æ–ª—å—à–µ dimensions –ü–û–ú–û–ì–ê–ï–¢")
            else:
                print(f"   ‚úó –ù–ï –ó–ù–ê–ß–ò–ú–û")

        # 2. Dense_128 vs V4_128 (the fair comparison)
        if 'Dense_ReLU_128' in summary and 'V4_KSparse_128' in summary:
            print("\n2. Dense_128 vs V4_128 (üéØ –ß–ï–°–¢–ù–û–ï –°–†–ê–í–ù–ï–ù–ò–ï):")
            print("   " + "-"*60)

            vars_dense = summary['Dense_ReLU_128']['variances']
            vars_v4 = summary['V4_KSparse_128']['variances']

            dead_dense = summary['Dense_ReLU_128']['dead_pct_mean']
            dead_v4 = summary['V4_KSparse_128']['dead_pct_mean']

            t_stat, p_value = stats.ttest_ind(vars_dense, vars_v4)

            mean_dense = np.mean(vars_dense)
            mean_v4 = np.mean(vars_v4)
            improvement = mean_v4 / mean_dense

            print(f"   Variance: {mean_dense:.4f} ‚Üí {mean_v4:.4f}")
            print(f"   Dead neurons: {dead_dense:.1%} ‚Üí {dead_v4:.1%}")
            print(f"   t-statistic: {t_stat:.4f}")
            print(f"   p-value: {p_value:.6f}")

            if p_value < 0.05:
                print(f"   ‚úì –ó–ù–ê–ß–ò–ú–û: V4 –∏–º–µ–µ—Ç –≤ {improvement:.2f}√ó –≤—ã—à–µ variance (p<0.05)")
                print(f"   ‚úì –ó–ù–ê–ß–ò–ú–û: V4 –∏–º–µ–µ—Ç {dead_v4:.1%} dead vs {dead_dense:.1%}")
                print(f"\n   üéØ –≠–¢–û –í–ê–® –ì–õ–ê–í–ù–´–ô –†–ï–ó–£–õ–¨–¢–ê–¢ –î–õ–Ø –°–¢–ê–¢–¨–ò!")
                print(f"   ‚Üí K-Sparse + Chaos activation –†–ê–ë–û–¢–ê–ï–¢ –ø—Ä–∏ –æ–¥–∏–Ω–∞–∫–æ–≤–æ–π capacity")
            else:
                print(f"   ‚úó –ù–ï –ó–ù–ê–ß–ò–ú–û")

            # Test that improvement is significant
            self.assertLess(p_value, 0.05, "V4 should significantly outperform Dense_128")
            self.assertGreater(improvement, 1.5, "V4 should have at least 1.5√ó higher variance")

    def test_03_create_fair_comparison_plots(self):
        """Create detailed comparison visualization

        Writes a 2x2 figure to ``fair_baseline_comparison.png`` in the current
        working directory.
        """
        print("\n" + "="*80)
        print("Creating Fair Comparison Plots")
        print("="*80)

        fig, axes = plt.subplots(2, 2, figsize=(14, 10))

        arch_names = list(self.results.keys())
        x = np.arange(len(arch_names))

        # Colors
        colors = {
            'Dense_ReLU_64': 'lightcoral',
            'Dense_ReLU_128': 'skyblue',
            'V4_KSparse_128': 'lightgreen'
        }
        color_list = [colors[name] for name in arch_names]

        # Collect data
        variances_mean = []
        variances_std = []
        dead_pct_mean = []
        dead_pct_std = []
        loss_mean = []
        loss_std = []

        for arch in arch_names:
            vars_list = [r['variance'] for r in self.results[arch]]
            # Stored as a fraction; scaled to percent for plotting.
            dead_list = [r['dead_percentage'] * 100 for r in self.results[arch]]
            loss_list = [r['val_loss'] for r in self.results[arch]]

            variances_mean.append(np.mean(vars_list))
            variances_std.append(np.std(vars_list))
            dead_pct_mean.append(np.mean(dead_list))
            dead_pct_std.append(np.std(dead_list))
            loss_mean.append(np.mean(loss_list))
            loss_std.append(np.std(loss_list))

        # Plot 1: Variance (the main one)
        axes[0, 0].bar(x, variances_mean, yerr=variances_std, capsize=5,
                       alpha=0.7, color=color_list, edgecolor='black', linewidth=1.5)
        axes[0, 0].set_xticks(x)
        axes[0, 0].set_xticklabels(arch_names, rotation=15, ha='right', fontsize=10)
        axes[0, 0].set_ylabel('Mean Variance', fontsize=12, fontweight='bold')
        axes[0, 0].set_title('Variance Comparison\n(FAIR: Same capacity for 128-dim models)',
                             fontweight='bold', fontsize=13)
        axes[0, 0].grid(True, alpha=0.3, axis='y')

        # Annotation for the fair comparison.
        # Relies on dict order: index 0 = Dense_64, 1 = Dense_128, 2 = V4_128.
        if len(arch_names) >= 3:
            # Arrow between Dense_128 and V4_128
            axes[0, 0].annotate('', xy=(2, variances_mean[2]), xytext=(1, variances_mean[1]),
                                arrowprops=dict(arrowstyle='<->', color='red', lw=2.5))

            improvement = variances_mean[2] / variances_mean[1]
            mid_y = (variances_mean[1] + variances_mean[2]) / 2

            axes[0, 0].text(1.5, mid_y, f'{improvement:.2f}√ó\nFAIR',
                           ha='center', va='center', fontsize=12, color='red', fontweight='bold',
                           bbox=dict(boxstyle='round,pad=0.5', facecolor='yellow',
                                     edgecolor='red', linewidth=2, alpha=0.8))

            # Unfair comparison, drawn in gray
            axes[0, 0].annotate('', xy=(2, variances_mean[2]), xytext=(0, variances_mean[0]),
                                arrowprops=dict(arrowstyle='<->', color='gray', lw=1,
                                                linestyle='--', alpha=0.5))

            unfair_improvement = variances_mean[2] / variances_mean[0]
            axes[0, 0].text(1.0, (variances_mean[0] + variances_mean[2])/2 + 0.05,
                           f'{unfair_improvement:.2f}√ó (unfair)',
                           ha='center', fontsize=9, color='gray', style='italic', alpha=0.7)

        # Plot 2: Dead Neurons %
        axes[0, 1].bar(x, dead_pct_mean, yerr=dead_pct_std, capsize=5,
                       alpha=0.7, color=color_list, edgecolor='black', linewidth=1.5)
        axes[0, 1].set_xticks(x)
        axes[0, 1].set_xticklabels(arch_names, rotation=15, ha='right', fontsize=10)
        axes[0, 1].set_ylabel('Dead Neurons (%)', fontsize=12, fontweight='bold')
        axes[0, 1].set_title('Dead Neurons Comparison\n(Lower is better)',
                             fontweight='bold', fontsize=13)
        axes[0, 1].grid(True, alpha=0.3, axis='y')
        # Fall back to a fixed upper limit when no architecture has dead neurons.
        axes[0, 1].set_ylim(0, max(dead_pct_mean) * 1.2 if max(dead_pct_mean) > 0 else 10)

        # Plot 3: Loss
        axes[1, 0].bar(x, loss_mean, yerr=loss_std, capsize=5,
                       alpha=0.7, color=color_list, edgecolor='black', linewidth=1.5)
        axes[1, 0].set_xticks(x)
        axes[1, 0].set_xticklabels(arch_names, rotation=15, ha='right', fontsize=10)
        axes[1, 0].set_ylabel('Validation Loss (MSE)', fontsize=12, fontweight='bold')
        axes[1, 0].set_title('Reconstruction Quality\n(Lower is better)',
                             fontweight='bold', fontsize=13)
        axes[1, 0].grid(True, alpha=0.3, axis='y')

        # Plot 4: Summary table
        axes[1, 1].axis('off')

        table_data = []
        for idx, arch in enumerate(arch_names):
            # Latent size is inferred from the architecture name.
            latent_dim = 64 if '64' in arch else 128
            table_data.append([
                arch.replace('_', '\n'),
                f"{latent_dim}",
                f"{variances_mean[idx]:.3f}",
                f"{dead_pct_mean[idx]:.1f}%"
            ])

        table = axes[1, 1].table(cellText=table_data,
                                 colLabels=['Architecture', 'Dims', 'Variance', 'Dead %'],
                                 cellLoc='center',
                                 loc='center',
                                 bbox=[0, 0.3, 1, 0.6])

        table.auto_set_font_size(False)
        table.set_fontsize(10)
        table.scale(1, 2)

        # Highlight the fair comparison (row 0 of the table is the header)
        for i, arch in enumerate(arch_names):
            if '128' in arch:
                for j in range(4):
                    table[(i+1, j)].set_facecolor('lightyellow')
                    table[(i+1, j)].set_edgecolor('orange')
                    table[(i+1, j)].set_linewidth(2)

        axes[1, 1].text(0.5, 0.15, '‚ö†Ô∏è Yellow rows: Fair comparison (same capacity)',
                       ha='center', fontsize=10, style='italic',
                       bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

        axes[1, 1].set_title('Summary Table', fontweight='bold', fontsize=13)

        plt.suptitle(f'Fair Baseline Comparison (N={self.num_runs} runs)',
                     fontsize=16, fontweight='bold')
        plt.tight_layout()

        plt.savefig('fair_baseline_comparison.png', dpi=300, bbox_inches='tight')
        print("‚úì Saved: fair_baseline_comparison.png")
        plt.close()

    def test_04_create_latex_table(self):
        """Create LaTeX table for paper

        Prints one row per architecture plus a closing "fair improvement" row
        to stdout. Makes no assertions.
        """
        print("\n" + "="*80)
        print("LATEX TABLE FOR PAPER (Fair Comparison)")
        print("="*80)

        print("\n\\begin{table}[h]")
        print("\\centering")
        print("\\caption{Fair Baseline Comparison on Logistic Map Dataset}")
        print("\\label{tab:fair_comparison}")
        print("\\begin{tabular}{lcccc}")
        print("\\hline")
        print("Architecture & Latent Dim & Variance & Dead Neurons & Val Loss \\\\")
        print("\\hline")

        for arch_name, runs in self.results.items():
            variances = [r['variance'] for r in runs]
            dead_pcts = [r['dead_percentage'] for r in runs]
            losses = [r['val_loss'] for r in runs]

            var_mean = np.mean(variances)
            var_std = np.std(variances)
            dead_pct = np.mean(dead_pcts)
            loss_mean = np.mean(losses)
            loss_std = np.std(losses)

            # Latent size is inferred from the architecture name.
            latent_dim = 64 if '64' in arch_name else 128

            # Format name
            if 'Dense_ReLU_64' in arch_name:
                name = "Dense ReLU"
            elif 'Dense_ReLU_128' in arch_name:
                name = "Dense ReLU (fair)"
            else:
                name = "\\textbf{K-Sparse Chaos (V4)}"

            print(f"{name} & {latent_dim} & "
                  f"${var_mean:.3f} \\pm {var_std:.3f}$ & "
                  f"{dead_pct:.1%} & "
                  f"${loss_mean:.4f} \\pm {loss_std:.4f}$ \\\\")

        print("\\hline")

        # Add improvement
        if 'Dense_ReLU_128' in self.results and 'V4_KSparse_128' in self.results:
            vars_dense = [r['variance'] for r in self.results['Dense_ReLU_128']]
            vars_v4 = [r['variance'] for r in self.results['V4_KSparse_128']]

            dead_dense = np.mean([r['dead_percentage'] for r in self.results['Dense_ReLU_128']])
            dead_v4 = np.mean([r['dead_percentage'] for r in self.results['V4_KSparse_128']])

            improvement = np.mean(vars_v4) / np.mean(vars_dense)

            print("\\hline")
            print(f"\\multicolumn{{5}}{{l}}{{\\textit{{Fair improvement (V4 vs Dense-128): "
                  f"{improvement:.2f}√ó variance, {dead_v4:.0%} vs {dead_dense:.0%} dead neurons}}}} \\\\")

        print("\\end{tabular}")
        print("\\end{table}")

        print("\n" + "="*80)
        print("üìã –°–∫–æ–ø–∏—Ä—É–π—Ç–µ —ç—Ç—É —Ç–∞–±–ª–∏—Ü—É –≤ –≤–∞—à—É —Å—Ç–∞—Ç—å—é!")
        print("="*80)

############################################
# TEST CLASS 3: MULTIPLE RUNS (legacy)
############################################

class TestMultipleRuns(unittest.TestCase):
    """
    CRITICAL EXPERIMENT #3: Multiple Runs for Statistical Significance
    (Legacy test kept for compatibility - use TestFairBaselineComparison instead.)

    ``setUpClass`` trains every architecture ``num_runs`` times, seeding NumPy
    and TensorFlow with the run index, and stores the per-run metrics in
    ``cls.results``. The test methods only aggregate, compare and plot those
    stored metrics; they do not train anything themselves.
    """

    @classmethod
    def setUpClass(cls):
        """Train each architecture ``num_runs`` times and record per-run metrics."""
        print("\n" + "="*80)
        print("MULTIPLE RUNS FOR STATISTICAL SIGNIFICANCE (Legacy)")
        print("="*80)
        print("⚠️ Используйте TestFairBaselineComparison для честного сравнения")

        cls.num_runs = 5
        # Builders are wrapped in lambdas so every run gets a freshly built model.
        cls.architectures = {
            'Dense_ReLU': lambda: build_dense_relu_ae(latent_dim=64),
            'KSparse_Chaos': lambda: build_ksparse_chaos_ae(latent_dim=128, k_active=32)
        }

        cls.results = {name: [] for name in cls.architectures}
        # unittest builds a new instance for every test method, so state that
        # tests want to share has to live on the class, not on ``self``.
        cls.summary = {}

        # Generate data once (before the per-run seeding below) so that all
        # runs train and evaluate on the same images.
        cls.train_images = generate_logistic_dataset(2000, fixed_initial=False)
        cls.test_images = generate_logistic_dataset(500, fixed_initial=False)

        # Run experiments
        for arch_name, builder in cls.architectures.items():
            print(f"\n{'='*70}")
            print(f"Architecture: {arch_name}")
            print(f"{'='*70}")

            for run in range(cls.num_runs):
                print(f"\n[Run {run+1}/{cls.num_runs}]")

                # Seed with the run index so each run is individually repeatable.
                np.random.seed(run)
                tf.random.set_seed(run)

                # Build and train
                ae, enc = builder()
                history = ae.fit(
                    cls.train_images, cls.train_images,
                    epochs=10,
                    batch_size=64,
                    validation_split=0.1,
                    verbose=0
                )

                # Named ``latent_stats`` so the imported ``scipy.stats`` module
                # is not shadowed inside this method.
                latent_stats = analyze_latent_statistics(enc, cls.test_images)
                final_val_loss = history.history['val_loss'][-1]

                cls.results[arch_name].append({
                    'run': run,
                    'stats': latent_stats,
                    'val_loss': final_val_loss
                })

                print(f"  Variance: {latent_stats['mean_variance']:.6f}")
                print(f"  Dead neurons: {latent_stats['dead_neurons']}/{latent_stats['total_neurons']}")
                print(f"  Val loss: {final_val_loss:.6f}")

    @classmethod
    def _metric_series(cls, arch_name):
        """Return ``(variances, dead_counts, val_losses)`` lists, one entry per run."""
        runs = cls.results[arch_name]
        variances = [r['stats']['mean_variance'] for r in runs]
        dead_counts = [r['stats']['dead_neurons'] for r in runs]
        losses = [r['val_loss'] for r in runs]
        return variances, dead_counts, losses

    def test_01_compute_statistics(self):
        """Compute mean, std, confidence intervals"""
        print("\n" + "="*80)
        print("TEST: Statistical Summary")
        print("="*80)

        # Written to the class-level dict so the summary outlives this test
        # instance (an attribute set on ``self`` would be discarded with it).
        summary = type(self).summary

        for arch_name in self.results:
            print(f"\n{arch_name}:")
            print(f"{'='*50}")

            variances, dead_counts, losses = self._metric_series(arch_name)

            # np.std uses its default ddof=0 here (population std).
            var_mean = np.mean(variances)
            var_std = np.std(variances)
            # NOTE(review): 1.96 is the normal-approximation critical value;
            # with only ``num_runs`` samples a t-based interval would be wider.
            var_ci = 1.96 * var_std / np.sqrt(len(variances))

            dead_mean = np.mean(dead_counts)
            dead_std = np.std(dead_counts)

            loss_mean = np.mean(losses)
            loss_std = np.std(losses)

            print(f"Variance:     {var_mean:.6f} ± {var_std:.6f} (CI: ±{var_ci:.6f})")
            print(f"Dead neurons: {dead_mean:.1f} ± {dead_std:.1f}")
            print(f"Val loss:     {loss_mean:.6f} ± {loss_std:.6f}")

            summary[arch_name] = {
                'variance_mean': var_mean,
                'variance_std': var_std,
                'variance_ci': var_ci,
                'dead_mean': dead_mean,
                'loss_mean': loss_mean
            }

    def test_02_statistical_comparison(self):
        """Compare architectures with t-test"""
        print("\n" + "="*80)
        print("TEST: Statistical Comparison (t-test)")
        print("="*80)

        arch_names = list(self.results)

        # Only the first two architectures are compared; with fewer there is
        # nothing to test.
        if len(arch_names) < 2:
            return

        arch1, arch2 = arch_names[0], arch_names[1]
        vars1, _, _ = self._metric_series(arch1)
        vars2, _, _ = self._metric_series(arch2)

        # NOTE(review): ttest_ind defaults to the pooled-variance test; the two
        # groups may have very different spreads, in which case
        # ``equal_var=False`` (Welch) is the safer choice.
        t_stat, p_value = stats.ttest_ind(vars1, vars2)

        print(f"\nComparing {arch1} vs {arch2}:")
        print(f"  t-statistic: {t_stat:.4f}")
        print(f"  p-value: {p_value:.6f}")

        if p_value < 0.05:
            print("  ✓ SIGNIFICANT difference (p < 0.05)")
            if np.mean(vars2) > np.mean(vars1):
                print(f"  → {arch2} has significantly higher variance")
        else:
            print("  ✗ No significant difference (p >= 0.05)")

    def test_03_create_errorbar_plot(self):
        """Create plot with error bars"""
        print("\n" + "="*80)
        print("Creating Error Bar Plots")
        print("="*80)

        arch_names = list(self.results)
        x = np.arange(len(arch_names))

        # Per-architecture mean/std for each of the three metrics, in the
        # order returned by _metric_series: variance, dead neurons, val loss.
        means = ([], [], [])
        stds = ([], [], [])
        for arch in arch_names:
            for slot, series in enumerate(self._metric_series(arch)):
                means[slot].append(np.mean(series))
                stds[slot].append(np.std(series))

        # (ylabel, title, extra bar kwargs) for each panel, left to right.
        panels = [
            ('Mean Variance',
             'Variance Comparison\n(with std error bars)', {}),
            ('Dead Neurons',
             'Dead Neurons Comparison\n(with std error bars)', {'color': 'orange'}),
            ('Validation Loss (MSE)',
             'Reconstruction Loss Comparison\n(with std error bars)', {'color': 'green'}),
        ]

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        try:
            for ax, mean_vals, std_vals, (ylabel, title, bar_kwargs) in zip(
                    axes, means, stds, panels):
                ax.bar(x, mean_vals, yerr=std_vals, capsize=5, alpha=0.7, **bar_kwargs)
                ax.set_xticks(x)
                ax.set_xticklabels(arch_names, rotation=45, ha='right')
                ax.set_ylabel(ylabel)
                ax.set_title(title, fontweight='bold')
                ax.grid(True, alpha=0.3, axis='y')

            fig.suptitle(f'Multiple Runs Comparison (N={self.num_runs} runs)',
                         fontsize=14, fontweight='bold')
            fig.tight_layout()
            fig.savefig('multiple_runs_comparison.png', dpi=300, bbox_inches='tight')
            print("✓ Saved: multiple_runs_comparison.png")
        finally:
            # Close this specific figure even if plotting/saving raised, so
            # figures do not accumulate across tests.
            plt.close(fig)


class TestHenonGeneralization(unittest.TestCase):
    """
    CRITICAL EXPERIMENT #4: Henon Map Dataset
    Test if results generalize to different chaotic system

    ``setUpClass`` trains one K-Sparse Chaos autoencoder per dataset
    (logistic map and Henon map) and stores the trained models together with
    their latent statistics in ``cls.results``, keyed by dataset name.
    """

    # Added before log10 so that a zero variance maps to a finite value; the
    # histogram and the mean marker must use the same offset to line up.
    _LOG_EPS = 1e-10

    @classmethod
    def setUpClass(cls):
        """Generate both datasets, train one model on each and record metrics."""
        print("\n" + "="*80)
        print("HENON MAP GENERALIZATION TEST")
        print("="*80)

        # Generate both datasets
        cls.logistic_train = generate_logistic_dataset(2000, fixed_initial=False)
        cls.henon_train = generate_henon_dataset(2000)

        cls.logistic_test = generate_logistic_dataset(500, fixed_initial=False)
        cls.henon_test = generate_henon_dataset(500)

        print(f"Logistic train: {cls.logistic_train.shape}")
        print(f"Henon train: {cls.henon_train.shape}")

        cls.results = {}

        # Train the same architecture on both datasets
        datasets = [
            ('Logistic', cls.logistic_train, cls.logistic_test),
            ('Henon', cls.henon_train, cls.henon_test),
        ]
        for dataset_name, train_data, test_data in datasets:
            print(f"\n{'='*70}")
            print(f"Training K-Sparse Chaos on {dataset_name} Map")
            print(f"{'='*70}")

            ae, enc = build_ksparse_chaos_ae(latent_dim=128, k_active=32)

            history = ae.fit(
                train_data, train_data,
                epochs=10,
                batch_size=64,
                validation_split=0.1,
                verbose=1
            )

            # Named ``latent_stats`` so the imported ``scipy.stats`` module is
            # not shadowed inside this method.
            latent_stats = analyze_latent_statistics(enc, test_data)
            final_val_loss = history.history['val_loss'][-1]

            cls.results[dataset_name] = {
                'autoencoder': ae,
                'encoder': enc,
                'stats': latent_stats,
                'val_loss': final_val_loss
            }

            print("\nResults:")
            print(f"  Variance: {latent_stats['mean_variance']:.6f}")
            print(f"  Dead neurons: {latent_stats['dead_neurons']}/{latent_stats['total_neurons']}")
            print(f"  Val loss: {final_val_loss:.6f}")

    def test_01_compare_datasets(self):
        """Compare results across datasets"""
        print("\n" + "="*80)
        print("TEST: Logistic vs Henon Comparison")
        print("="*80)

        logistic = self.results['Logistic']
        henon = self.results['Henon']

        print(f"\n{'Metric':<25} {'Logistic':<15} {'Henon':<15} {'Ratio'}")
        print("-" * 60)

        log_var = logistic['stats']['mean_variance']
        hen_var = henon['stats']['mean_variance']
        print(f"{'Variance':<25} {log_var:<15.6f} {hen_var:<15.6f} {hen_var/log_var:.2f}×")

        log_dead = logistic['stats']['dead_neurons']
        hen_dead = henon['stats']['dead_neurons']
        print(f"{'Dead Neurons':<25} {log_dead:<15} {hen_dead:<15} -")

        log_loss = logistic['val_loss']
        hen_loss = henon['val_loss']
        print(f"{'Val Loss':<25} {log_loss:<15.6f} {hen_loss:<15.6f} {hen_loss/log_loss:.2f}×")

        # Check consistency
        self.assertEqual(log_dead, 0, "Logistic: should have 0 dead neurons")
        self.assertEqual(hen_dead, 0, "Henon: should have 0 dead neurons")

        # Variance should be in same ballpark (within 3×); the ratio is
        # symmetric so it does not matter which dataset is larger.
        variance_ratio = max(log_var, hen_var) / min(log_var, hen_var)
        self.assertLess(variance_ratio, 3.0,
                        f"Variance too different between datasets: {variance_ratio:.2f}×")

    def test_02_visualize_latent_spaces(self):
        """Visualize and compare latent spaces"""
        print("\n" + "="*80)
        print("Visualizing Latent Spaces")
        print("="*80)

        fig, axes = plt.subplots(2, 2, figsize=(12, 10))
        try:
            # One column per dataset: scatter on the top row, histogram below.
            for idx, (name, dataset) in enumerate([
                ('Logistic', self.logistic_test),
                ('Henon', self.henon_test)
            ]):
                enc = self.results[name]['encoder']
                # Only the first 100 test images are encoded for the plots.
                latents = enc.predict(dataset[:100], verbose=0)

                # 2D projection
                scatter_ax = axes[0, idx]
                scatter_ax.scatter(latents[:, 0], latents[:, 1],
                                   alpha=0.6, s=30, edgecolors='black', linewidths=0.5)
                scatter_ax.set_xlabel('Dim 0')
                scatter_ax.set_ylabel('Dim 1')
                scatter_ax.set_title(f'{name} Map\nLatent Space (first 2 dims)',
                                     fontweight='bold')
                scatter_ax.grid(True, alpha=0.3)

                # Variance distribution (per latent dimension, log scale)
                hist_ax = axes[1, idx]
                variance = np.var(latents, axis=0)
                hist_ax.hist(np.log10(variance + self._LOG_EPS), bins=20, alpha=0.7,
                             edgecolor='black')
                hist_ax.set_xlabel('Log10(Variance)')
                hist_ax.set_ylabel('Frequency')
                hist_ax.set_title(f'{name} Map\nVariance Distribution',
                                  fontweight='bold')
                # Same epsilon as the histogram so the marker is placed on the
                # same scale and stays finite for a zero mean variance.
                mean_variance = self.results[name]['stats']['mean_variance']
                hist_ax.axvline(x=np.log10(mean_variance + self._LOG_EPS),
                                color='red', linestyle='--', linewidth=2, label='Mean')
                hist_ax.legend()
                hist_ax.grid(True, alpha=0.3)

            fig.suptitle('Generalization Test: Logistic vs Henon Map',
                         fontsize=14, fontweight='bold')
            fig.tight_layout()
            fig.savefig('henon_generalization.png', dpi=300, bbox_inches='tight')
            print("✓ Saved: henon_generalization.png")
        finally:
            # Close this specific figure even if plotting/saving raised.
            plt.close(fig)


class TestDeadNeuronTrajectory(unittest.TestCase):
    """
    EXPERIMENT #5: Track how neurons die during training
    """

    # Index of the trajectory entry used as the variance baseline in
    # test_02. Clamped to the trajectory length at use, so a shorter
    # trajectory does not raise IndexError.
    BASELINE_INDEX = 5

    def test_01_track_broken_ae_death(self):
        """Track neuron death in broken L1 architecture"""
        print("\n" + "="*80)
        print("TEST: Tracking Neuron Death (Broken L1 AE)")
        print("="*80)

        # Skipped at runtime (after printing the banner) because the builder
        # this experiment needs is not available.
        self.skipTest("Requires build_sparse_ae_broken function")

    def test_02_track_chaos_ae_stability(self):
        """Verify neurons stay alive in Chaos AE"""
        print("\n" + "="*80)
        print("TEST: Neuron Stability (K-Sparse Chaos AE)")
        print("="*80)

        images = generate_logistic_dataset(1000, fixed_initial=False)

        ae, enc = build_ksparse_chaos_ae(latent_dim=128, k_active=32)

        # Each trajectory entry is read below via the keys 'epoch',
        # 'dead_neurons' and 'mean_variance'.
        trajectory = track_dead_neurons_over_time(ae, enc, images, epochs=30)
        self.assertGreater(len(trajectory), 0,
                           "track_dead_neurons_over_time returned no measurements")

        # Plot trajectory
        epochs = [t['epoch'] for t in trajectory]
        dead = [t['dead_neurons'] for t in trajectory]
        variance = [t['mean_variance'] for t in trajectory]

        fig, axes = plt.subplots(1, 2, figsize=(12, 5))
        try:
            dead_ax, var_ax = axes

            dead_ax.plot(epochs, dead, 'o-', linewidth=2)
            dead_ax.set_xlabel('Epoch')
            dead_ax.set_ylabel('Dead Neurons')
            dead_ax.set_title('Dead Neurons Over Time', fontweight='bold')
            dead_ax.grid(True, alpha=0.3)

            var_ax.plot(epochs, variance, 'o-', linewidth=2, color='green')
            var_ax.set_xlabel('Epoch')
            var_ax.set_ylabel('Mean Variance')
            var_ax.set_title('Variance Over Time', fontweight='bold')
            var_ax.grid(True, alpha=0.3)

            fig.suptitle('K-Sparse Chaos AE: Training Stability',
                         fontsize=14, fontweight='bold')
            fig.tight_layout()
            fig.savefig('training_stability.png', dpi=300, bbox_inches='tight')
            print("✓ Saved: training_stability.png")
        finally:
            # Close this specific figure even if plotting/saving raised.
            plt.close(fig)

        # Assertions
        final_dead = trajectory[-1]['dead_neurons']
        self.assertEqual(final_dead, 0, "Should maintain 0 dead neurons")

        final_variance = trajectory[-1]['mean_variance']
        # Baseline is taken a few entries into training rather than at entry 0;
        # clamp the index so short trajectories fall back to their last entry.
        baseline_idx = min(self.BASELINE_INDEX, len(trajectory) - 1)
        initial_variance = trajectory[baseline_idx]['mean_variance']
        self.assertGreater(final_variance, initial_variance * 0.8,
                           "Variance should not collapse during training")


def run_all_critical_experiments():
    """Run all critical experiments for paper.

    Loads the five experiment test classes into a single suite, runs it with
    a verbose text runner and prints a summary of the outcome.

    Returns:
        bool: the value of ``TestResult.wasSuccessful()`` for the run.
    """
    print("\n" + "="*80)
    print("RUNNING ALL CRITICAL EXPERIMENTS (WITH FAIR BASELINE)")
    print("="*80)

    # Create test suite; test classes are added in order of priority.
    loader = unittest.TestLoader()
    suite = unittest.TestSuite()
    for test_case in (
        TestKSparseAblation,
        TestFairBaselineComparison,  # <- NEW!
        TestMultipleRuns,
        TestHenonGeneralization,
        TestDeadNeuronTrajectory,
    ):
        suite.addTests(loader.loadTestsFromTestCase(test_case))

    # Run with detailed output
    runner = unittest.TextTestRunner(verbosity=2)
    result = runner.run(suite)

    # A test only counts as a success if it actually ran and passed: skipped
    # tests, expected failures and unexpected successes are all included in
    # ``testsRun`` but are not passes, so they are subtracted as well.
    not_passed = (
        len(result.failures)
        + len(result.errors)
        + len(result.skipped)
        + len(result.expectedFailures)
        + len(result.unexpectedSuccesses)
    )

    # Summary
    print("\n" + "="*80)
    print("EXPERIMENT SUMMARY")
    print("="*80)
    print(f"Total tests run: {result.testsRun}")
    print(f"Successes: {result.testsRun - not_passed}")
    print(f"Failures: {len(result.failures)}")
    print(f"Errors: {len(result.errors)}")
    print(f"Skipped: {len(result.skipped)}")

    succeeded = result.wasSuccessful()
    if succeeded:
        print("\n✅ ALL CRITICAL EXPERIMENTS PASSED!")
        print("\nGenerated files:")
        print("  - k_sparse_ablation.png")
        print("  - fair_baseline_comparison.png")
        print("  - multiple_runs_comparison.png")
        print("  - henon_generalization.png")
        print("  - training_stability.png")
        print("\n🎉 Ready for paper with FAIR BASELINE COMPARISON!")
    else:
        print("\n⚠️ Some experiments failed. Review output above.")

    return succeeded

if __name__ == '__main__':
    # Report the suite outcome through the process exit status (0 = success).
    # SystemExit is raised directly instead of calling the ``exit`` helper,
    # which is injected by the ``site`` module for interactive use and is not
    # guaranteed to exist in every interpreter configuration.
    success = run_all_critical_experiments()
    raise SystemExit(0 if success else 1)


RUNNING ALL CRITICAL EXPERIMENTS (WITH FAIR BASELINE)

K-SPARSE ABLATION STUDY

Testing K values: [4, 8, 16, 32, 64, 96, 112]
Latent dimension: 128
Sparsity range: 12.5% to 96.9%

[K=4] Training K-Sparse Chaos AE...
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
  Variance: 0.068163
  Dead neurons: 0/128
  Sparsity: 96.9%

[K=8] Training K-Sparse Chaos AE...
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
  Variance: 0.130711
  Dead neurons: 0/128
  Sparsity: 93.8%

[K=16] Training K-Sparse Chaos AE...
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
  Variance: 0.238414
  Dead neurons: 0/128
  Sparsity: 87.5%

[K=32] Training K-Sparse Chaos AE...
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
  Variance: 0.417582
  Dead neurons: 0/128


test_01_variance_vs_k (__main__.TestKSparseAblation)
Test how variance changes with K ... ok
test_02_find_optimal_k (__main__.TestKSparseAblation)
Find K that balances variance and sparsity ... ok
test_03_create_ablation_plot (__main__.TestKSparseAblation)
Create comprehensive ablation visualization ... 

  Variance: 0.660226
  Dead neurons: 0/128
  Sparsity: 12.5%

TEST: Variance vs K
K=  4: variance = 0.068163
K=  8: variance = 0.130711
K= 16: variance = 0.238414
K= 32: variance = 0.417582
K= 64: variance = 0.571774
K= 96: variance = 0.651265
K=112: variance = 0.660226

Trend violations: 0/6

TEST: Find Optimal K
K=  4: score = 1.652 (var=0.068, sparsity=96.9%)
K=  8: score = 1.803 (var=0.131, sparsity=93.8%)
K= 16: score = 1.766 (var=0.238, sparsity=87.5%)
K= 32: score = 1.606 (var=0.418, sparsity=75.0%)
K= 64: score = 1.121 (var=0.572, sparsity=50.0%)
K= 96: score = 0.857 (var=0.651, sparsity=25.0%)
K=112: score = 0.746 (var=0.660, sparsity=12.5%)

⭐ Optimal K: 8
   Variance: 0.130711
   Sparsity: 93.8%

Creating K-Sparse Ablation Plots


ok


✓ Saved: k_sparse_ablation.png

ABLATION RESULTS TABLE (LaTeX format)
\begin{table}[h]
\centering
\begin{tabular}{ccccc}
\hline
K & Sparsity & Variance & Dead Neurons & Val Loss \\
\hline
4 & 96.9% & 0.0682 & 0/128 & 0.1187 \\
8 & 93.8% & 0.1307 & 0/128 & 0.1186 \\
16 & 87.5% & 0.2384 & 0/128 & 0.1187 \\
32 & 75.0% & 0.4176 & 0/128 & 0.1197 \\
64 & 50.0% & 0.5718 & 0/128 & 0.1206 \\
96 & 25.0% & 0.6513 & 0/128 & 0.1214 \\
112 & 12.5% & 0.6602 & 0/128 & 0.1218 \\
\hline
\end{tabular}
\caption{K-Sparse ablation study results}
\end{table}

🎯 FAIR BASELINE COMPARISON (Dense_64 vs Dense_128 vs V4_128)

📊 Запускаем N=10 runs для каждой архитектуры...
   Train: (2000, 28, 28, 1)
   Test: (500, 28, 28, 1)

⚠️ Это займет ~2-3 часа. Наберитесь терпения!


[1/3] Архитектура: Dense_ReLU_64
  [Run 1/10] var=0.0783, dead=27/64 (42%), loss=0.1137
  [Run 2/10] var=0.1069, dead=24/64 (38%), loss=0.1136
  [Run 3/10] var=0

test_01_compute_statistics (__main__.TestFairBaselineComparison)
Compute mean, std, confidence intervals ... ok
test_02_fair_statistical_comparison (__main__.TestFairBaselineComparison)
Compare Dense_128 vs V4_128 (FAIR) ... ERROR
test_03_create_fair_comparison_plots (__main__.TestFairBaselineComparison)
Create detailed comparison visualization ... 

var=0.4180, dead=0/128 (0%), loss=0.1198

TEST: Statistical Summary

Dense_ReLU_64:
Variance:     0.091792 ± 0.007784 (CI: ±0.004825)
Dead neurons: 28.5 ± 3.4 (44.5%)
Val loss:     0.113511 ± 0.000145

Dense_ReLU_128:
Variance:     0.061902 ± 0.005692 (CI: ±0.003528)
Dead neurons: 45.3 ± 6.6 (35.4%)
Val loss:     0.113181 ± 0.000160

V4_KSparse_128:
Variance:     0.417723 ± 0.002239 (CI: ±0.001388)
Dead neurons: 0.0 ± 0.0 (0.0%)
Val loss:     0.119851 ± 0.000091

TEST: FAIR Statistical Comparison (t-tests)

Creating Fair Comparison Plots


ok
test_04_create_latex_table (__main__.TestFairBaselineComparison)
Create LaTeX table for paper ... ok


✓ Saved: fair_baseline_comparison.png

LATEX TABLE FOR PAPER (Fair Comparison)

\begin{table}[h]
\centering
\caption{Fair Baseline Comparison on Logistic Map Dataset}
\label{tab:fair_comparison}
\begin{tabular}{lcccc}
\hline
Architecture & Latent Dim & Variance & Dead Neurons & Val Loss \\
\hline
Dense ReLU & 64 & $0.092 \pm 0.008$ & 44.5% & $0.1135 \pm 0.0001$ \\
Dense ReLU (fair) & 128 & $0.062 \pm 0.006$ & 35.4% & $0.1132 \pm 0.0002$ \\
\textbf{K-Sparse Chaos (V4)} & 128 & $0.418 \pm 0.002$ & 0.0% & $0.1199 \pm 0.0001$ \\
\hline
\hline
\multicolumn{5}{l}{\textit{Fair improvement (V4 vs Dense-128): 6.75× variance, 0% vs 35% dead neurons}} \\
\end{tabular}
\end{table}

📋 Скопируйте эту таблицу в вашу статью!

MULTIPLE RUNS FOR STATISTICAL SIGNIFICANCE (Legacy)
⚠️ Используйте TestFairBaselineComparison для честного сравнения

Architecture: Dense_ReLU

[Run 1/5]
  Variance: 0.070269
  Dead neurons: 26/64
  Val loss:

test_01_compute_statistics (__main__.TestMultipleRuns)
Compute mean, std, confidence intervals ... ok
test_02_statistical_comparison (__main__.TestMultipleRuns)
Compare architectures with t-test ... ok
test_03_create_errorbar_plot (__main__.TestMultipleRuns)
Create plot with error bars ... 

  Variance: 0.420392
  Dead neurons: 0/128
  Val loss: 0.119869

TEST: Statistical Summary

Dense_ReLU:
Variance:     0.076053 ± 0.008961 (CI: ±0.007855)
Dead neurons: 29.2 ± 3.6
Val loss:     0.113784 ± 0.000094

KSparse_Chaos:
Variance:     0.418768 ± 0.000956 (CI: ±0.000838)
Dead neurons: 0.0 ± 0.0
Val loss:     0.119885 ± 0.000083

TEST: Statistical Comparison (t-test)

Comparing Dense_ReLU vs KSparse_Chaos:
  t-statistic: -76.0567
  p-value: 0.000000
  ✓ SIGNIFICANT difference (p < 0.05)
  → KSparse_Chaos has significantly higher variance

Creating Error Bar Plots


ok


✓ Saved: multiple_runs_comparison.png

HENON MAP GENERALIZATION TEST
Logistic train: (2000, 28, 28, 1)
Henon train: (2000, 28, 28, 1)

Training K-Sparse Chaos on Logistic Map
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10

Results:
  Variance: 0.417578
  Dead neurons: 0/128
  Val loss: 0.120130

Training K-Sparse Chaos on Henon Map
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


test_01_compare_datasets (__main__.TestHenonGeneralization)
Compare results across datasets ... ok
test_02_visualize_latent_spaces (__main__.TestHenonGeneralization)
Visualize and compare latent spaces ... 


Results:
  Variance: 0.422372
  Dead neurons: 0/128
  Val loss: 0.045649

TEST: Logistic vs Henon Comparison

Metric                    Logistic        Henon           Ratio
------------------------------------------------------------
Variance                  0.417578        0.422372        1.01×
Dead Neurons              0               0               -
Val Loss                  0.120130        0.045649        0.38×

Visualizing Latent Spaces


ok
test_01_track_broken_ae_death (__main__.TestDeadNeuronTrajectory)
Track neuron death in broken L1 architecture ... skipped 'Requires build_sparse_ae_broken function'
test_02_track_chaos_ae_stability (__main__.TestDeadNeuronTrajectory)
Verify neurons stay alive in Chaos AE ... 

✓ Saved: henon_generalization.png

TEST: Tracking Neuron Death (Broken L1 AE)

TEST: Neuron Stability (K-Sparse Chaos AE)


ok

ERROR: test_02_fair_statistical_comparison (__main__.TestFairBaselineComparison)
Compare Dense_128 vs V4_128 (FAIR)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "C:\Users\Admin\AppData\Local\Temp\ipykernel_5884\3247979534.py", line 587, in test_02_fair_statistical_comparison
    if 'Dense_ReLU_64' in self.summary and 'Dense_ReLU_128' in self.summary:
AttributeError: 'TestFairBaselineComparison' object has no attribute 'summary'

----------------------------------------------------------------------
Ran 14 tests in 275.970s

FAILED (errors=1, skipped=1)


✓ Saved: training_stability.png

EXPERIMENT SUMMARY
Total tests run: 14
Successes: 13
Failures: 0
Errors: 1

⚠️ Some experiments failed. Review output above.
