# Split-MNIST: 30-Seed Run with Simple MLP

## 🎯 Goal: Robust Statistics for Paper

**Context**: Ablation study showed NN1-Simple (89.1%) > NN1-Similarity (87.9%)

**This Experiment**:
- Run 30 independent seeds for robust statistics
- Report mean ± std for retention accuracy
- Statistical significance testing
- Compare to baseline (no consolidation)

**Expected Results**:
- Mean retention: ~89% (matching ablation)
- Std dev: ~1.5%
- Significant improvement over baseline

**Timeline**: ~4–8 hours to complete (30 seeds × ~8 min each, plus overhead)

---

**Date**: November 9, 2025  
**Priority**: CRITICAL (needed for paper revision)

In [2]:
import sys
sys.path.insert(0, '../..')

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from datetime import datetime
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import random
from tqdm.auto import tqdm
from scipy import stats

from src.models import (
    NN1_SimpleMLP,
    NN2_ConsolidationNet,
    ReplayBuffer,
    evaluate_models,
    train_task_with_replay,
    consolidate_nn2
)

# Global plotting style for every figure produced later in the notebook.
sns.set_style('whitegrid')

# Run on the GPU when one is visible to PyTorch; otherwise fall back to CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

# Record the execution environment at the top of the run log.
started_at = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
print(f"üñ•Ô∏è  Device: {device}")
print(f"üìÖ {started_at}")
print(f"üî• PyTorch: {torch.__version__}")
print("\n‚≠ê‚≠ê‚≠ê CRITICAL: 30-Seed Split-MNIST Experiment ‚≠ê‚≠ê‚≠ê\n")

üñ•Ô∏è  Device: cpu
üìÖ 2025-11-09 17:10:26
üî• PyTorch: 2.9.0+cu128

‚≠ê‚≠ê‚≠ê CRITICAL: 30-Seed Split-MNIST Experiment ‚≠ê‚≠ê‚≠ê



  from .autonotebook import tqdm as notebook_tqdm


## 1. Load Split-MNIST Dataset

In [3]:
# MNIST preprocessing: convert each image to a tensor, then normalise it with
# mean 0.1307 and std 0.3081.
MNIST_ROOT = '../../data'
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize((0.1307,), (0.3081,))
transform = transforms.Compose([to_tensor, normalize])

# Train and test splits stored under MNIST_ROOT (download=True fetches them
# when they are not already on disk).
train_dataset = datasets.MNIST(MNIST_ROOT, train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(MNIST_ROOT, train=False, download=True, transform=transform)

# Split into 5 tasks (2 digits each)
def create_task_split(dataset, digit_pairs):
    """Create a dataset subset containing only the labels in *digit_pairs*.

    Args:
        dataset: Map-style dataset yielding ``(image, label)`` pairs.
        digit_pairs: Container of labels to keep (e.g. ``[0, 1]``).

    Returns:
        ``torch.utils.data.Subset`` over the matching indices, in the
        dataset's own order.
    """
    targets = getattr(dataset, 'targets', None)
    has_plain_targets = (
        targets is not None
        and getattr(dataset, 'target_transform', None) is None
        and len(targets) == len(dataset)
    )
    if has_plain_targets:
        # Fast path: read labels from the label array instead of calling
        # __getitem__, which would decode and transform every image just to
        # look at its label. Skipped when a target_transform is set, because
        # the yielded labels could then differ from the raw targets.
        labels = targets.tolist() if hasattr(targets, 'tolist') else list(targets)
    else:
        # Generic path: scan the dataset sample by sample.
        labels = (label for _, label in dataset)

    indices = [idx for idx, label in enumerate(labels) if label in digit_pairs]
    return Subset(dataset, indices)

# Five consecutive two-digit tasks: (0,1), (2,3), (4,5), (6,7), (8,9).
# Each entry is ([digit_a, digit_b], display_name).
tasks = [
    ([2 * k, 2 * k + 1], f"Task {k + 1}: Digits {2 * k}-{2 * k + 1}")
    for k in range(5)
]

# One train subset and one test subset per task, in task order.
task_digits = [pair for pair, _ in tasks]
train_tasks = [create_task_split(train_dataset, pair) for pair in task_digits]
test_tasks = [create_task_split(test_dataset, pair) for pair in task_digits]

# Report how many training samples each task received.
print("‚úÖ Split-MNIST Created:")
for (_, name), train_task in zip(tasks, train_tasks):
    n_train = len(train_task)
    print(f"   {name}: {n_train} train samples")

‚úÖ Split-MNIST Created:
   Task 1: Digits 0-1: 12665 train samples
   Task 2: Digits 2-3: 12089 train samples
   Task 3: Digits 4-5: 11263 train samples
   Task 4: Digits 6-7: 12183 train samples
   Task 5: Digits 8-9: 11800 train samples


## 2. Single-Seed Experiment Function

In [None]:
def run_single_seed(seed, verbose=False):
    """
    Run one complete continual learning experiment

    For each task in order: train NN1/NN2 with replay, add the task to the
    replay buffer, consolidate NN2, and (from the second task on) evaluate
    both networks on the combined test sets of every task seen so far.
    Reads the notebook-level ``device``, ``tasks``, ``train_tasks`` and
    ``test_tasks``.

    Args:
        seed: Value used to seed torch, NumPy and Python's ``random``.
        verbose: When True, print a header and per-checkpoint retention.

    Returns:
        dict with final retention accuracies for NN1 and NN2:
            'seed'      - the seed that was used
            'nn1_final' - NN1 retention (%) after the last task
            'nn2_final' - NN2 retention (%) after the last task
            'nn1_all'   - NN1 retention (%) at each checkpoint (task 2 onward)
            'nn2_all'   - NN2 retention (%) at each checkpoint (task 2 onward)

    Raises:
        IndexError: if fewer than two tasks are configured, since no
            retention checkpoint is recorded in that case.
    """
    # Set seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # Initialize models
    nn1 = NN1_SimpleMLP(
        in_dim=784,
        neuron_dim=64,
        num_classes=10
    ).to(device)

    nn2 = NN2_ConsolidationNet(
        in_dim=784,
        summary_dim=64,
        num_classes=10
    ).to(device)

    # Optimizers
    opt1 = torch.optim.Adam(nn1.parameters(), lr=1e-3)
    opt2 = torch.optim.Adam(nn2.parameters(), lr=5e-4)

    # Loss functions
    ce_loss = nn.CrossEntropyLoss()
    kl_loss = nn.KLDivLoss(reduction='batchmean')

    # Replay buffer
    replay_buffer = ReplayBuffer(buffer_size_per_task=200)

    # Track results
    results = {
        'nn1_retention': [],  # Retention on all previous tasks
        'nn2_retention': [],
    }

    # Hyperparameters
    batch_size = 64
    epochs_per_task = 5

    if verbose:
        print(f"\n{'='*60}")
        print(f"üå± Seed {seed}")
        print(f"{'='*60}")

    # Train on each task sequentially
    for task_id, train_task in enumerate(train_tasks):
        if verbose:
            print(f"\nüìö {tasks[task_id][1]}")

        # Create data loaders
        train_loader = DataLoader(train_task, batch_size=batch_size, shuffle=True)

        # Train with replay (the buffer still holds only earlier tasks here)
        train_task_with_replay(
            nn1, nn2,
            train_loader,
            replay_buffer.get_dataset(),
            opt1, opt2,
            ce_loss, kl_loss,
            device=device,
            epochs=epochs_per_task,
            consolidation_interval=10,
            lambda_distill=0.3,
            temperature=2.0,
            grad_clip=1.0,
            replay_ratio=0.3
        )

        # Add to replay buffer
        replay_buffer.add_task(train_task)

        # Consolidate NN2 (same batch size as training, not a second literal)
        consolidate_nn2(
            nn1, nn2,
            replay_buffer.get_dataset(),
            opt2,
            ce_loss, kl_loss,
            device=device,
            consolidation_epochs=2,
            batch_size=batch_size,
            lambda_distill=0.5,
            temperature=2.0,
            grad_clip=1.0
        )

        # Evaluate retention on ALL tasks seen so far
        if task_id >= 1:  # After task 2+
            # ConcatDataset chains the per-task test subsets lazily, in the
            # same order, instead of copying every transformed sample into a
            # Python list at each checkpoint.
            seen_test_data = torch.utils.data.ConcatDataset(test_tasks[:task_id + 1])

            test_loader = DataLoader(seen_test_data, batch_size=128, shuffle=False)
            # NOTE(review): accuracies are scaled by 100 below, so they are
            # presumably fractions in [0, 1] - confirm against evaluate_models.
            acc1, acc2 = evaluate_models(nn1, nn2, test_loader, device=device)

            results['nn1_retention'].append(acc1 * 100)
            results['nn2_retention'].append(acc2 * 100)

            if verbose:
                print(f"   üìä Retention (Tasks 1-{task_id+1}): NN1={acc1*100:.1f}%, NN2={acc2*100:.1f}%")

    # Return final retention (after the last task)
    return {
        'seed': seed,
        'nn1_final': results['nn1_retention'][-1],
        'nn2_final': results['nn2_retention'][-1],
        'nn1_all': results['nn1_retention'],
        'nn2_all': results['nn2_retention'],
    }

print("‚úÖ Experiment function ready")

## 3. Test with Single Seed (Sanity Check)

In [None]:
# Smoke-test the full pipeline on one seed before starting the long sweep.
print("üß™ Testing with seed 42...")
test_result = run_single_seed(42, verbose=True)

print("\n‚úÖ Test Complete!")
for net_label in ("NN1", "NN2"):
    final_key = net_label.lower() + "_final"
    print(f"   Final {net_label} retention: {test_result[final_key]:.2f}%")
print("\n   Expected: ~89% (matching ablation study)")

## 4. Run 30 Seeds (CRITICAL EXPERIMENT)

⚠️ **WARNING**: This will take roughly 4–8 hours to complete!

**Timeline**:
- ~8 minutes per seed (5 tasks × 5 epochs × ~20s)
- 30 seeds × 8 min = 240 minutes = 4 hours (if sequential)
- With overhead: ~6-8 hours total

**Note**: If running on CPU, this could take much longer. Consider:
1. Using GPU if available
2. Running seeds in parallel (if multiple GPUs)
3. Running overnight

In [None]:
# Run NUM_SEEDS independent seeds, checkpointing results to CSV as we go
import os

NUM_SEEDS = 30
SAVE_EVERY = 5  # write a partial CSV after this many completed seeds
seeds = list(range(42, 42 + NUM_SEEDS))  # Seeds 42-71 when NUM_SEEDS is 30

# Create the output directory up front: a missing directory would otherwise
# only surface at the first partial save, after several seeds have already run.
PARTIAL_DIR = "../../results/simple_mlp/csv"
os.makedirs(PARTIAL_DIR, exist_ok=True)

print(f"üöÄ Starting {NUM_SEEDS}-seed run...")
print(f"   Seeds: {seeds[0]} to {seeds[-1]}")
print(f"   Estimated time: 4-8 hours\n")

all_results = []

for done, seed in enumerate(tqdm(seeds, desc="Seeds"), start=1):
    result = run_single_seed(seed, verbose=False)
    all_results.append(result)

    # Save intermediate results so an interrupted run keeps what it finished
    if done % SAVE_EVERY == 0:
        df_temp = pd.DataFrame(all_results)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        df_temp.to_csv(f"{PARTIAL_DIR}/split_mnist_30seeds_partial_{done}_{timestamp}.csv", index=False)
        print(f"   üíæ Saved partial results ({done}/{NUM_SEEDS} seeds)")

print(f"\n‚úÖ {NUM_SEEDS}-seed run complete!")

## 5. Statistical Analysis

In [None]:
# Convert to DataFrame
import os

df_results = pd.DataFrame(all_results)
n_seeds = len(df_results)  # report the number of seeds actually collected

# Summary statistics (pandas .std() is the sample standard deviation, ddof=1)
print(f"üìä RESULTS SUMMARY ({n_seeds} Seeds)")
print("="*60)
for net_label, column in (("NN1", "nn1_final"), ("NN2", "nn2_final")):
    finals = df_results[column]
    print(f"\n{net_label} Final Retention (After Task 5):")
    print(f"   Mean: {finals.mean():.2f}%")
    print(f"   Std:  {finals.std():.2f}%")
    print(f"   Min:  {finals.min():.2f}%")
    print(f"   Max:  {finals.max():.2f}%")

# Confidence intervals (95%): Student-t with n-1 degrees of freedom, centred
# on the sample mean and scaled by the standard error of the mean.
nn1_ci = stats.t.interval(0.95, n_seeds - 1,
                          loc=df_results['nn1_final'].mean(),
                          scale=stats.sem(df_results['nn1_final']))
nn2_ci = stats.t.interval(0.95, n_seeds - 1,
                          loc=df_results['nn2_final'].mean(),
                          scale=stats.sem(df_results['nn2_final']))

print(f"\n95% Confidence Intervals:")
print(f"   NN1: [{nn1_ci[0]:.2f}%, {nn1_ci[1]:.2f}%]")
print(f"   NN2: [{nn2_ci[0]:.2f}%, {nn2_ci[1]:.2f}%]")

# Save full results (create the directory first so the final save cannot fail
# on a fresh checkout after the whole sweep has finished)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
csv_path = f"../../results/simple_mlp/csv/split_mnist_30seeds_final_{timestamp}.csv"
os.makedirs(os.path.dirname(csv_path), exist_ok=True)
df_results.to_csv(csv_path, index=False)
print(f"\nüíæ Results saved: {csv_path}")

## 6. Visualization

In [None]:
# Create comprehensive visualization (2x2 panel figure)
import os

n_seeds = len(df_results)
nn1_final = df_results['nn1_final']
nn2_final = df_results['nn2_final']
ABLATION_NN1 = 89.1  # single-seed NN1 figure quoted from the ablation study

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Plot 1: Distribution of final retention
ax = axes[0, 0]
ax.hist(nn1_final, bins=15, alpha=0.7, label='NN1', color='steelblue', edgecolor='black')
ax.hist(nn2_final, bins=15, alpha=0.7, label='NN2', color='coral', edgecolor='black')
ax.axvline(nn1_final.mean(), color='steelblue', linestyle='--', linewidth=2)
ax.axvline(nn2_final.mean(), color='coral', linestyle='--', linewidth=2)
ax.set_xlabel('Final Retention (%)', fontsize=12)
ax.set_ylabel('Frequency', fontsize=12)
ax.set_title(f'Distribution of Final Retention ({n_seeds} Seeds)', fontsize=13, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)

# Plot 2: Retention across tasks
ax = axes[0, 1]
# Rows are seeds, columns are evaluation checkpoints (first one after task 2).
nn1_curves = np.asarray([r['nn1_all'] for r in all_results], dtype=float)
nn2_curves = np.asarray([r['nn2_all'] for r in all_results], dtype=float)
task_labels = [f'T{k}' for k in range(2, nn1_curves.shape[1] + 2)]
# Sample std (ddof=1) so the error bars agree with the pandas .std() values
# reported in the summary panel; fall back to ddof=0 for a single seed.
ddof = 1 if nn1_curves.shape[0] > 1 else 0
nn1_means = nn1_curves.mean(axis=0)
nn2_means = nn2_curves.mean(axis=0)
nn1_stds = nn1_curves.std(axis=0, ddof=ddof)
nn2_stds = nn2_curves.std(axis=0, ddof=ddof)

x = np.arange(len(task_labels))
ax.errorbar(x, nn1_means, yerr=nn1_stds, marker='o', capsize=5, capthick=2,
            linewidth=2, markersize=8, label='NN1 (Simple MLP)', color='steelblue')
ax.errorbar(x, nn2_means, yerr=nn2_stds, marker='s', capsize=5, capthick=2,
            linewidth=2, markersize=8, label='NN2 (Consolidation)', color='coral')
ax.set_xticks(x)
ax.set_xticklabels(task_labels)
ax.set_xlabel('After Task', fontsize=12)
ax.set_ylabel('Retention Accuracy (%)', fontsize=12)
ax.set_title('Retention Across Tasks (Mean ¬± Std)', fontsize=13, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
ax.set_ylim([75, 100])

# Plot 3: Box plot comparison
ax = axes[1, 0]
data_to_plot = [nn1_final, nn2_final]
# Tick labels are set separately: boxplot's `labels` keyword is deprecated
# since Matplotlib 3.9, while set_xticklabels works on old and new versions.
bp = ax.boxplot(data_to_plot, patch_artist=True,
                boxprops=dict(facecolor='lightblue', edgecolor='black'),
                medianprops=dict(color='red', linewidth=2),
                whiskerprops=dict(color='black'),
                capprops=dict(color='black'))
ax.set_xticklabels(['NN1', 'NN2'])
ax.set_ylabel('Final Retention (%)', fontsize=12)
ax.set_title('Box Plot: Final Retention Distribution', fontsize=13, fontweight='bold')
ax.grid(True, alpha=0.3, axis='y')

# Plot 4: Summary statistics table
ax = axes[1, 1]
ax.axis('off')
summary_text = f"""
{n_seeds}-SEED SPLIT-MNIST RESULTS
{'='*40}

NN1 (Simple MLP) Final Retention:
  Mean:  {nn1_final.mean():.2f}%
  Std:   {nn1_final.std():.2f}%
  95% CI: [{nn1_ci[0]:.2f}%, {nn1_ci[1]:.2f}%]

NN2 (Consolidation) Final Retention:
  Mean:  {nn2_final.mean():.2f}%
  Std:   {nn2_final.std():.2f}%
  95% CI: [{nn2_ci[0]:.2f}%, {nn2_ci[1]:.2f}%]

Comparison to Ablation Study:
  Ablation (1 seed):  {ABLATION_NN1}%
  This ({n_seeds} seeds):    {nn1_final.mean():.2f}% ¬± {nn1_final.std():.2f}%
  
Status: {'‚úÖ VALIDATED' if abs(nn1_final.mean() - ABLATION_NN1) < 2 else '‚ö†Ô∏è INVESTIGATE'}
"""
ax.text(0.1, 0.5, summary_text, fontsize=11, family='monospace',
        verticalalignment='center', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.3))

plt.tight_layout()
# `timestamp` comes from the statistics cell, so the figure and the final CSV
# share the same suffix.
fig_path = f"../../results/simple_mlp/figures/split_mnist_30seeds_{timestamp}.png"
os.makedirs(os.path.dirname(fig_path), exist_ok=True)  # savefig does not create directories
plt.savefig(fig_path, dpi=300, bbox_inches='tight')
print(f"üíæ Figure saved: {fig_path}")
plt.show()

## 7. Statistical Significance vs Baseline

Compare to a baseline without consolidation (NN1 only, no NN2)

In [None]:
# The paper needs a significance test against a no-consolidation baseline.
# No separate baseline results are loaded in this notebook, so for now this
# cell reports the NN2-vs-NN1 difference; add baseline results here for a
# proper t-test once they are available.

nn1_scores = df_results['nn1_final']
nn2_scores = df_results['nn2_final']
nn1_mean, nn1_std = nn1_scores.mean(), nn1_scores.std()
nn2_mean, nn2_std = nn2_scores.mean(), nn2_scores.std()

# Paired test: each row holds the NN1 and NN2 scores from the same seed.
improvement = nn2_mean - nn1_mean
t_stat, p_value = stats.ttest_rel(nn2_scores, nn1_scores)

if p_value < 0.05:
    significance = '‚úÖ Significant (p<0.05)'
else:
    significance = '‚ö†Ô∏è Not significant'

# The 30-seed NN1 mean counts as matching the ablation when it lies within
# 2 points of the 89.1% ablation figure.
if abs(nn1_mean - 89.1) < 2:
    matches_ablation = '‚úÖ YES'
else:
    matches_ablation = '‚ö†Ô∏è NO'

print("\nüìà NN2 vs NN1 Comparison:")
print(f"   Improvement: {improvement:+.2f}%")
print(f"   t-statistic: {t_stat:.3f}")
print(f"   p-value: {p_value:.6f}")
print(f"   Significance: {significance}")

banner = "=" * 60
print("\n" + banner)
print("üéâ EXPERIMENT COMPLETE!")
print(banner)
print("\nKey Findings:")
print(f"  ‚Ä¢ Simple MLP retention: {nn1_mean:.2f}% ¬± {nn1_std:.2f}%")
print(f"  ‚Ä¢ Consolidation retention: {nn2_mean:.2f}% ¬± {nn2_std:.2f}%")
print(f"  ‚Ä¢ Matches ablation finding: {matches_ablation}")
print("\nNext Steps:")
print("  1. Update paper with these robust statistics")
print("  2. Run CIFAR-10 validation experiment")
print("  3. Update rebuttal with statistical analysis")