# Investigation: λ=0.0 for Consolidation

## 🔬 Research Question

**Finding**: Hyperparameter sensitivity analysis showed λ=0.0 (no distillation during consolidation) performed **best** (90.61% ± 0.34%)

**Current approach**: Uses λ=0.5 during consolidation

**Hypothesis**: Replay alone during consolidation might be more effective than replay+distillation

---

## 🎯 Experiments

1. **Split-MNIST (10 seeds)**: Compare λ=0.0 vs λ=0.5 consolidation
2. **Split-CIFAR-10 (5 seeds)**: Validate if λ=0.0 improves CIFAR-10 results too
3. **Analysis**: Why does no distillation work better during consolidation?

---

**Date**: November 9, 2025  
**Priority**: HIGH (potential paper improvement)  
**Expected Runtime**: ~1.5 hours on Colab T4

In [None]:
import sys
sys.path.insert(0, '../..')

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from datetime import datetime
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import random
from tqdm.auto import tqdm
from scipy import stats

from src.models import (
    NN1_SimpleMLP,
    NN2_ConsolidationNet,
    ReplayBuffer,
    evaluate_models,
    train_task_with_replay,
    consolidate_nn2
)

sns.set_style('whitegrid')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"🖥️  Device: {device}")
print(f"📅 {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print(f"🔥 PyTorch: {torch.__version__}")
print("\n⭐ λ=0.0 Investigation ⭐\n")

> ℹ️ The shared training utilities now provide task-balanced replay sampling and support λ=0.0 without computing teacher logits, so this notebook mirrors the production code path.
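
The λ=0.0 shortcut mentioned in the note can be illustrated with a small framework-free sketch. It is not the code in `src.models`: the function names, the `(1 - λ)·CE + λ·T²·KL` weighting, and the callable `teacher_logits_fn` are all assumptions made for illustration. The only point it demonstrates is that the teacher is never queried when λ is zero.

```python
import math

def softmax(logits, temperature=1.0):
    """Numerically stable softmax over a list of logits."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def consolidation_loss(student_logits, label, teacher_logits_fn, lam, temperature=2.0):
    """Per-example loss: (1 - lam) * CE + lam * T^2 * KL(teacher || student).

    The weighting is an assumed form, not necessarily the one used in src.models.
    teacher_logits_fn is a zero-argument callable so the teacher forward pass
    can be skipped entirely.
    """
    ce = -math.log(softmax(student_logits)[label])
    if lam == 0.0:
        return ce  # pure replay: the teacher is never evaluated
    teacher = softmax(teacher_logits_fn(), temperature)
    student = softmax(student_logits, temperature)
    kl = sum(p * math.log(p / q) for p, q in zip(teacher, student) if p > 0)
    return (1.0 - lam) * ce + lam * temperature ** 2 * kl

# Hypothetical logits; `calls` records how often the teacher was queried.
calls = []
def teacher():
    calls.append(1)
    return [2.0, 0.5, -1.0]

loss_no_distill = consolidation_loss([1.0, 0.2, -0.5], 0, teacher, lam=0.0)
loss_distill = consolidation_loss([1.0, 0.2, -0.5], 0, teacher, lam=0.5)
print(f"lambda=0.0: {loss_no_distill:.4f}, lambda=0.5: {loss_distill:.4f}, teacher calls: {len(calls)}")
```

Passing the teacher as a callable keeps the skip explicit: at λ=0.0 the loss reduces to plain cross-entropy on the replayed labels.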

## 1. Load Datasets

In [None]:
# MNIST
transform_mnist = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

mnist_train = datasets.MNIST('../../data', train=True, download=True, transform=transform_mnist)
mnist_test = datasets.MNIST('../../data', train=False, download=True, transform=transform_mnist)

# CIFAR-10
transform_cifar_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_cifar_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

cifar_train = datasets.CIFAR10('../../data', train=True, download=True, transform=transform_cifar_train)
cifar_test = datasets.CIFAR10('../../data', train=False, download=True, transform=transform_cifar_test)

print("✅ Datasets loaded")

In [None]:
# Create task splits
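def create_task_split(dataset, class_pairs):
    """Return the subset of `dataset` whose labels are in `class_pairs`."""
    wanted = set(class_pairs)
    if hasattr(dataset, 'targets'):
        # Read labels directly; avoids loading and transforming every image
        labels = [int(label) for label in dataset.targets]
    else:
        labels = [int(dataset[idx][1]) for idx in range(len(dataset))]
    indices = [idx for idx, label in enumerate(labels) if label in wanted]
    return Subset(dataset, indices)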

# MNIST tasks
mnist_tasks_def = [
    ([0, 1], "Task 1: 0-1"),
    ([2, 3], "Task 2: 2-3"),
    ([4, 5], "Task 3: 4-5"),
    ([6, 7], "Task 4: 6-7"),
    ([8, 9], "Task 5: 8-9"),
]

mnist_train_tasks = [create_task_split(mnist_train, digits) for digits, _ in mnist_tasks_def]
mnist_test_tasks = [create_task_split(mnist_test, digits) for digits, _ in mnist_tasks_def]

# CIFAR-10 tasks
cifar_tasks_def = [
    ([0, 1], "Task 1: airplane, automobile"),
    ([2, 3], "Task 2: bird, cat"),
    ([4, 5], "Task 3: deer, dog"),
    ([6, 7], "Task 4: frog, horse"),
    ([8, 9], "Task 5: ship, truck"),
]

cifar_train_tasks = [create_task_split(cifar_train, classes) for classes, _ in cifar_tasks_def]
cifar_test_tasks = [create_task_split(cifar_test, classes) for classes, _ in cifar_tasks_def]

print("✅ Task splits created")

## 2. Experiment Function (Configurable λ)

In [None]:
def run_experiment(dataset='mnist', seed=42, consolidation_lambda=0.5, verbose=False):
    """
    Run FSC-Net with configurable consolidation lambda
    
    Args:
        dataset: 'mnist' or 'cifar10'
        seed: Random seed
        consolidation_lambda: Lambda for consolidation phase (0.0 = no distillation)
        verbose: Print progress
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    
    # Select dataset
    if dataset == 'mnist':
        train_tasks = mnist_train_tasks
        test_tasks = mnist_test_tasks
        in_dim = 784
    else:
        train_tasks = cifar_train_tasks
        test_tasks = cifar_test_tasks
        in_dim = 3072
    
    if verbose:
        print(f"\n{'='*60}")
        print(f"🌱 {dataset.upper()} | Seed {seed} | Consolidation λ={consolidation_lambda}")
        print(f"{'='*60}")
    
    # Initialize models
    nn1 = NN1_SimpleMLP(in_dim=in_dim, neuron_dim=64, num_classes=10).to(device)
    nn2 = NN2_ConsolidationNet(in_dim=in_dim, summary_dim=64, num_classes=10).to(device)
    
    opt1 = torch.optim.Adam(nn1.parameters(), lr=1e-3)
    opt2 = torch.optim.Adam(nn2.parameters(), lr=5e-4)
    ce_loss = nn.CrossEntropyLoss()
    kl_loss = nn.KLDivLoss(reduction='batchmean')
    
    replay_buffer = ReplayBuffer(buffer_size_per_task=200)
    
    # Train on each task
    for task_id, train_task in enumerate(train_tasks):
        train_loader = DataLoader(train_task, batch_size=64, shuffle=True)
        
        # Task training (use λ=0.3 for task training - this stays the same)
        train_task_with_replay(
            nn1, nn2, train_loader, replay_buffer.get_dataset(),
            opt1, opt2, ce_loss, kl_loss,
            device=device, epochs=5,
            lambda_distill=0.3,  # Keep task training lambda at 0.3
            temperature=2.0
        )
        
        replay_buffer.add_task(train_task)
        
        # Offline consolidation (TEST different lambda here)
        consolidate_nn2(
            nn1, nn2, replay_buffer.get_dataset(),
            opt2, ce_loss, kl_loss,
            device=device,
            consolidation_epochs=2,
            lambda_distill=consolidation_lambda,  # TEST PARAMETER
            temperature=2.0
        )
    
    # Final evaluation on the union of all task test sets
    all_test_data = torch.utils.data.ConcatDataset(test_tasks)
    
    test_loader = DataLoader(all_test_data, batch_size=128, shuffle=False)
    acc1, acc2 = evaluate_models(nn1, nn2, test_loader, device=device)
    
    if verbose:
        print(f"📊 Final: NN1={acc1*100:.2f}%, NN2={acc2*100:.2f}%")
    
    return {
        'dataset': dataset,
        'seed': seed,
        'consolidation_lambda': consolidation_lambda,
        'nn1_final': acc1 * 100,
        'nn2_final': acc2 * 100,
        'improvement': (acc2 - acc1) * 100
    }

print("✅ Experiment function ready")
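
The experiment function relies on `ReplayBuffer(buffer_size_per_task=200)` and on the task-balanced replay sampling provided by the shared utilities. The sketch below shows one way such balanced sampling could work. It is a hypothetical stand-in, not the project's `ReplayBuffer`: the dict-of-lists layout and the `sample_balanced` name are assumptions.

```python
import random

def sample_balanced(buffer, batch_size, rng):
    """Draw up to batch_size items with (near-)equal counts per stored task.

    buffer: dict mapping task_id -> list of stored examples (assumed layout).
    Any remainder is given to the lowest task ids, one extra item each.
    """
    if not buffer:
        return []
    task_ids = sorted(buffer)
    per_task, remainder = divmod(batch_size, len(task_ids))
    batch = []
    for i, task_id in enumerate(task_ids):
        k = per_task + (1 if i < remainder else 0)
        k = min(k, len(buffer[task_id]))  # never request more than is stored
        batch.extend(rng.sample(buffer[task_id], k))
    rng.shuffle(batch)  # avoid task-ordered batches
    return batch

# Toy buffer: 3 tasks with 200 placeholder examples each, stored as (task_id, index).
rng = random.Random(42)
buffer = {t: [(t, i) for i in range(200)] for t in range(3)}
batch = sample_balanced(buffer, batch_size=64, rng=rng)
counts = {t: sum(1 for task, _ in batch if task == t) for t in buffer}
print(counts)
```

Balancing per task keeps early tasks from being under-represented as the buffer grows, which matters most during the consolidation phase where replay is the only data source.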

## 3. Experiment 1: MNIST Comparison (10 seeds)

In [None]:
# Compare λ=0.0 vs λ=0.5 on MNIST
seeds = list(range(42, 52))  # 10 seeds
lambdas = [0.0, 0.5]

print("🔬 MNIST: λ=0.0 vs λ=0.5 (10 seeds)\n")

mnist_results = []

for lam in lambdas:
    print(f"\n📊 Testing λ={lam}...")
    for seed in tqdm(seeds, desc=f"λ={lam}"):
        result = run_experiment(
            dataset='mnist',
            seed=seed,
            consolidation_lambda=lam,
            verbose=False
        )
        mnist_results.append(result)

df_mnist = pd.DataFrame(mnist_results)
print("\n✅ MNIST experiments complete!")

## 4. Experiment 2: CIFAR-10 Comparison (5 seeds)

In [None]:
# Compare λ=0.0 vs λ=0.5 on CIFAR-10
seeds = [42, 43, 44, 45, 46]  # 5 seeds (same as validation)
lambdas = [0.0, 0.5]

print("🔬 CIFAR-10: λ=0.0 vs λ=0.5 (5 seeds)\n")

cifar_results = []

for lam in lambdas:
    print(f"\n📊 Testing λ={lam}...")
    for seed in tqdm(seeds, desc=f"λ={lam}"):
        result = run_experiment(
            dataset='cifar10',
            seed=seed,
            consolidation_lambda=lam,
            verbose=False
        )
        cifar_results.append(result)

df_cifar = pd.DataFrame(cifar_results)
print("\n✅ CIFAR-10 experiments complete!")

## 5. Statistical Analysis

In [None]:
def analyze_results(df, dataset_name):
    """Analyze and compare λ=0.0 vs λ=0.5"""
    
    lambda_0 = df[df['consolidation_lambda'] == 0.0]['nn2_final'].values
    lambda_5 = df[df['consolidation_lambda'] == 0.5]['nn2_final'].values
    
    mean_0, std_0 = np.mean(lambda_0), np.std(lambda_0, ddof=1)
    mean_5, std_5 = np.mean(lambda_5), np.std(lambda_5, ddof=1)
    
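    # Paired t-test: runs are paired by seed, so both arrays must be in the same seed order
    t_stat, p_value = stats.ttest_rel(lambda_0, lambda_5)
    
    # Effect size for paired samples (Cohen's d_z = mean(diff) / std(diff))
    diff = lambda_0 - lambda_5
    cohens_d = np.mean(diff) / np.std(diff, ddof=1)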
    
    print(f"\n{'='*60}")
    print(f"{dataset_name} RESULTS")
    print(f"{'='*60}")
    
    print(f"\n📊 λ=0.0 (No distillation during consolidation):")
    print(f"   Mean: {mean_0:.2f}%")
    print(f"   Std:  {std_0:.2f}%")
    print(f"   Values: {[f'{v:.2f}' for v in lambda_0]}")
    
    print(f"\n📊 λ=0.5 (Current approach):")
    print(f"   Mean: {mean_5:.2f}%")
    print(f"   Std:  {std_5:.2f}%")
    print(f"   Values: {[f'{v:.2f}' for v in lambda_5]}")
    
    print(f"\n📈 Comparison:")
    print(f"   Difference: {mean_0 - mean_5:+.2f}%")
    print(f"   t-statistic: {t_stat:.3f}")
    print(f"   p-value: {p_value:.4f}")
    print(f"   Cohen's d: {cohens_d:.3f}")
    print(f"   Significant (p<0.05): {'YES ✅' if p_value < 0.05 else 'NO ❌'}")
    
    if mean_0 > mean_5:
        print(f"\n🎯 Verdict: λ=0.0 is {'SIGNIFICANTLY ' if p_value < 0.05 else ''}BETTER by {mean_0 - mean_5:.2f}%")
    else:
        print(f"\n🎯 Verdict: λ=0.5 is {'SIGNIFICANTLY ' if p_value < 0.05 else ''}BETTER by {mean_5 - mean_0:.2f}%")
    
    print(f"{'='*60}\n")
    
    return {
        'mean_lambda0': mean_0,
        'std_lambda0': std_0,
        'mean_lambda5': mean_5,
        'std_lambda5': std_5,
        'difference': mean_0 - mean_5,
        'p_value': p_value,
        'cohens_d': cohens_d
    }

mnist_stats = analyze_results(df_mnist, "MNIST")
cifar_stats = analyze_results(df_cifar, "CIFAR-10")
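
The effect size computed in `analyze_results` is the paired-samples variant (often written d_z), and it is tied to the paired t-statistic by t = d_z · √n. The standard-library sketch below works this through on the per-seed MNIST accuracies listed in the recorded output (Section 9). The values are rounded to two decimals there, so the results should agree with the recorded 0.884 and 2.796 only up to rounding.

```python
import math
import statistics

# Per-seed NN2 accuracies (%) copied from the recorded MNIST output, seeds 42..51.
lambda_0 = [90.39, 90.45, 91.00, 92.48, 91.10, 92.28, 92.13, 90.55, 92.12, 92.07]
lambda_5 = [90.30, 88.74, 86.06, 91.23, 90.34, 91.12, 91.04, 90.58, 91.90, 90.70]

# Pair runs by seed and work on the per-seed differences.
diff = [a - b for a, b in zip(lambda_0, lambda_5)]
n = len(diff)
mean_diff = statistics.mean(diff)
sd_diff = statistics.stdev(diff)  # sample standard deviation (ddof=1)

d_z = mean_diff / sd_diff                       # paired Cohen's d
t_stat = mean_diff / (sd_diff / math.sqrt(n))   # paired t-statistic

print(f"mean diff = {mean_diff:.2f}, d_z = {d_z:.3f}, t = {t_stat:.3f}")
print(f"d_z * sqrt(n) = {d_z * math.sqrt(n):.3f}")
```

Because both quantities depend on the standard deviation of the differences, a single large per-seed gap (such as the 4.94-point gap at the third seed) inflates that standard deviation and lowers both d_z and t.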

## 6. Visualization

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# MNIST comparison
ax = axes[0]
mnist_lambda0 = df_mnist[df_mnist['consolidation_lambda'] == 0.0]['nn2_final'].values
mnist_lambda5 = df_mnist[df_mnist['consolidation_lambda'] == 0.5]['nn2_final'].values

bp = ax.boxplot([mnist_lambda5, mnist_lambda0], positions=[1, 2], widths=0.6,
                 patch_artist=True, showmeans=True)
for patch, color in zip(bp['boxes'], ['lightcoral', 'lightgreen']):
    patch.set_facecolor(color)

ax.set_xticks([1, 2])
ax.set_xticklabels(['λ=0.5\n(Current)', 'λ=0.0\n(No Distill)'], fontweight='bold')
ax.set_ylabel('NN2 Retention (%)', fontsize=12, fontweight='bold')
ax.set_title(f'MNIST: λ=0.0 vs λ=0.5\n(p={mnist_stats["p_value"]:.4f})', fontsize=13, fontweight='bold')
ax.grid(True, alpha=0.3)

# Add significance annotation
if mnist_stats['p_value'] < 0.05:
    y_max = max(max(mnist_lambda0), max(mnist_lambda5))
    ax.plot([1, 2], [y_max + 0.5, y_max + 0.5], 'k-', linewidth=1.5)
    ax.text(1.5, y_max + 0.7, '***' if mnist_stats['p_value'] < 0.001 else '**' if mnist_stats['p_value'] < 0.01 else '*',
            ha='center', fontsize=14, fontweight='bold')

# CIFAR-10 comparison
ax = axes[1]
cifar_lambda0 = df_cifar[df_cifar['consolidation_lambda'] == 0.0]['nn2_final'].values
cifar_lambda5 = df_cifar[df_cifar['consolidation_lambda'] == 0.5]['nn2_final'].values

bp = ax.boxplot([cifar_lambda5, cifar_lambda0], positions=[1, 2], widths=0.6,
                 patch_artist=True, showmeans=True)
for patch, color in zip(bp['boxes'], ['lightcoral', 'lightgreen']):
    patch.set_facecolor(color)

ax.set_xticks([1, 2])
ax.set_xticklabels(['λ=0.5\n(Current)', 'λ=0.0\n(No Distill)'], fontweight='bold')
ax.set_ylabel('NN2 Retention (%)', fontsize=12, fontweight='bold')
ax.set_title(f'CIFAR-10: λ=0.0 vs λ=0.5\n(p={cifar_stats["p_value"]:.4f})', fontsize=13, fontweight='bold')
ax.grid(True, alpha=0.3)

# Add significance annotation
if cifar_stats['p_value'] < 0.05:
    y_max = max(max(cifar_lambda0), max(cifar_lambda5))
    ax.plot([1, 2], [y_max + 0.5, y_max + 0.5], 'k-', linewidth=1.5)
    ax.text(1.5, y_max + 0.7, '***' if cifar_stats['p_value'] < 0.001 else '**' if cifar_stats['p_value'] < 0.01 else '*',
            ha='center', fontsize=14, fontweight='bold')

plt.tight_layout()

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
plt.savefig(f'../../results/simple_mlp/figures/lambda_zero_investigation_{timestamp}.png',
            dpi=300, bbox_inches='tight')
print(f"💾 Saved: results/simple_mlp/figures/lambda_zero_investigation_{timestamp}.png")
plt.show()

## 7. Save Results

In [None]:
# Combine results
df_all = pd.concat([df_mnist, df_cifar], ignore_index=True)

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
csv_path = f'../../results/simple_mlp/csv/lambda_zero_investigation_{timestamp}.csv'
df_all.to_csv(csv_path, index=False)
print(f"💾 Saved: {csv_path}")

print("\n📊 Combined Results:")
print(df_all.groupby(['dataset', 'consolidation_lambda'])['nn2_final'].agg(['mean', 'std', 'count']))

## 8. Analysis: Why Does λ=0.0 Work Better?

In [None]:
print("="*70)
print("ANALYSIS: WHY DOES λ=0.0 WORK BETTER DURING CONSOLIDATION?")
print("="*70)

print("\n🔬 Possible Explanations:\n")

print("1. 📚 REPLAY IS SUFFICIENT:")
print("   - Replay buffer already contains diverse examples from all tasks")
print("   - NN2 learns directly from ground truth labels (via CE loss)")
print("   - Distillation from NN1 may introduce noise/bias from fast learning")

print("\n2. ⚡ NN1 IS TOO TASK-SPECIFIC:")
print("   - NN1 adapts quickly to current task (high learning rate)")
print("   - NN1's predictions may be overfit to recent tasks")
print("   - Distilling from NN1 transfers this recency bias to NN2")

print("\n3. 🎯 CONSOLIDATION PHASE IS DIFFERENT:")
print("   - During task training: Distillation helps (λ=0.3 is beneficial)")
print("   - During consolidation: Only replay data, no new task interference")
print("   - Direct learning from labels may be cleaner than distillation")

print("\n4. 🔄 DISTILLATION TEMPERATURE MISMATCH:")
print("   - Temperature T=2.0 may not be optimal for consolidation phase")
print("   - Higher temperature softens targets, may lose important distinctions")

print("\n5. 📊 OPTIMIZATION DYNAMICS:")
print("   - Pure CE loss: Clear gradient signal from labels")
print("   - Mixed CE+KL loss: Gradient conflict between objectives")
print("   - NN2 already benefits from NN1's summary embedding (input fusion)")

print("\n" + "="*70)
print("RECOMMENDATION")
print("="*70)

if mnist_stats['difference'] > 0 and mnist_stats['p_value'] < 0.05:
    print("\n✅ UPDATE PAPER: Use λ=0.0 for consolidation phase")
    print("\n   Modified training protocol:")
    print("   • Task training: λ=0.3 (keep distillation)")
    print("   • Consolidation: λ=0.0 (pure replay, no distillation)")
    print(f"\n   Expected improvement:")
    print(f"   • MNIST: +{mnist_stats['difference']:.2f}%")
    if cifar_stats['difference'] > 0:
        print(f"   • CIFAR-10: +{cifar_stats['difference']:.2f}%")
else:
    print("\n⚠️ KEEP CURRENT: λ=0.5 performs similarly or better")
    print("   No significant benefit from switching to λ=0.0")

print("\n🎉 Investigation Complete!")
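
Explanation 4 above says that a higher temperature softens the distillation targets. A minimal sketch of that effect follows, using made-up teacher logits (not taken from any run) and a plain-Python softmax.

```python
import math

def softmax(logits, temperature=1.0):
    """Numerically stable softmax with temperature scaling."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical teacher logits for a 3-class problem.
teacher_logits = [4.0, 2.0, 0.0]

softened = {}
for T in (1.0, 2.0, 4.0):
    softened[T] = softmax(teacher_logits, T)
    formatted = ", ".join(f"{p:.3f}" for p in softened[T])
    print(f"T={T}: [{formatted}]")
```

As T grows, probability mass moves from the top class toward the others, so the KL term rewards matching a flatter distribution. Whether T=2.0 actually hurts consolidation here is one of the hypotheses listed above, not something this sketch establishes.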

## 9. Recorded Output (Colab Run, Nov 10, 2025)
```
============================================================
MNIST RESULTS
============================================================

📊 λ=0.0 (No distillation during consolidation):
   Mean: 91.46%
   Std:  0.84%
   Values: ['90.39', '90.45', '91.00', '92.48', '91.10', '92.28', '92.13', '90.55', '92.12', '92.07']

📊 λ=0.5 (Current approach):
   Mean: 90.20%
   Std:  1.67%
   Values: ['90.30', '88.74', '86.06', '91.23', '90.34', '91.12', '91.04', '90.58', '91.90', '90.70']

📈 Comparison:
   Difference: +1.26%
   t-statistic: 2.796
   p-value: 0.0208
   Cohen's d: 0.884
   Significant (p<0.05): YES ✅

🎯 Verdict: λ=0.0 is SIGNIFICANTLY BETTER by 1.26%
============================================================


============================================================
CIFAR-10 RESULTS
============================================================

📊 λ=0.0 (No distillation during consolidation):
   Mean: 34.38%
   Std:  0.67%
   Values: ['34.22', '33.64', '35.48', '34.34', '34.22']

📊 λ=0.5 (Current approach):
   Mean: 32.62%
   Std:  1.70%
   Values: ['33.34', '32.77', '34.07', '33.22', '29.70']

📈 Comparison:
   Difference: +1.76%
   t-statistic: 2.525
   p-value: 0.0650
   Cohen's d: 1.129
   Significant (p<0.05): NO ❌

🎯 Verdict: λ=0.0 is BETTER by 1.76%
============================================================

💾 Saved: /content/lambda_zero_investigation_20251110_194805.csv

📊 Combined Results:
                                mean       std  count
dataset consolidation_lambda                         
cifar10 0.0                   34.380  0.672756      5
        0.5                   32.620  1.697778      5
mnist   0.0                   91.457  0.836900     10
        0.5                   90.201  1.673257     10

======================================================================
ANALYSIS: WHY DOES λ=0.0 WORK BETTER DURING CONSOLIDATION?
======================================================================

🔬 Possible Explanations:

1. 📚 REPLAY IS SUFFICIENT:
   - Replay buffer already contains diverse examples from all tasks
   - NN2 learns directly from ground truth labels (via CE loss)
   - Distillation from NN1 may introduce noise/bias from fast learning

2. ⚡ NN1 IS TOO TASK-SPECIFIC:
   - NN1 adapts quickly to current task (high learning rate)
   - NN1's predictions may be overfit to recent tasks
   - Distilling from NN1 transfers this recency bias to NN2

3. 🎯 CONSOLIDATION PHASE IS DIFFERENT:
   - During task training: Distillation helps (λ=0.3 is beneficial)
   - During consolidation: Only replay data, no new task interference
   - Direct learning from labels may be cleaner than distillation

4. 🔄 DISTILLATION TEMPERATURE MISMATCH:
   - Temperature T=2.0 may not be optimal for consolidation phase
   - Higher temperature softens targets, may lose important distinctions

5. 📊 OPTIMIZATION DYNAMICS:
   - Pure CE loss: Clear gradient signal from labels
   - Mixed CE+KL loss: Gradient conflict between objectives
   - NN2 already benefits from NN1's summary embedding (input fusion)

======================================================================
RECOMMENDATION
======================================================================

✅ UPDATE PAPER: Use λ=0.0 for consolidation phase

   Modified training protocol:
   • Task training: λ=0.3 (keep distillation)
   • Consolidation: λ=0.0 (pure replay, no distillation)

   Expected improvement:
   • MNIST: +1.26%
   • CIFAR-10: +1.76%

🎉 Investigation Complete!
```
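
The CIFAR-10 comparison above rests on only five seeds and does not reach p<0.05 in the paired t-test. As a distribution-free cross-check, the sketch below runs an exact sign-flip (paired permutation) test on the recorded per-seed CIFAR-10 values. This test is not part of the original notebook, and `sign_flip_p_value` is an illustrative name.

```python
from itertools import product

def sign_flip_p_value(diffs):
    """Exact two-sided paired permutation test on the sum of differences.

    Under the null hypothesis each paired difference is equally likely to be
    positive or negative, so every sign pattern is enumerated and the p-value
    is the share of patterns at least as extreme as the observed one.
    """
    observed = abs(sum(diffs))
    hits = 0
    total = 0
    for signs in product((1, -1), repeat=len(diffs)):
        total += 1
        flipped = abs(sum(s * d for s, d in zip(signs, diffs)))
        if flipped >= observed - 1e-12:  # tolerance for float round-off
            hits += 1
    return hits / total

# Per-seed NN2 accuracies (%) copied from the recorded CIFAR-10 output, seeds 42..46.
cifar_lambda_0 = [34.22, 33.64, 35.48, 34.34, 34.22]
cifar_lambda_5 = [33.34, 32.77, 34.07, 33.22, 29.70]
cifar_diffs = [a - b for a, b in zip(cifar_lambda_0, cifar_lambda_5)]

p_cifar = sign_flip_p_value(cifar_diffs)
print(f"all differences positive: {all(d > 0 for d in cifar_diffs)}")
print(f"exact sign-flip p-value: {p_cifar}")
```

With five pairs there are only 2^5 = 32 sign patterns, so the smallest two-sided p-value this test can produce is 2/32 = 0.0625. Under this check, five seeds cannot reach p<0.05 regardless of effect size, which suggests reading the CIFAR-10 result as underpowered rather than as evidence against λ=0.0.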