# Recursive Self-Distillation Experiments

This notebook demonstrates recursive self-distillation: a network trains a student with the same architecture, and that student then serves as the teacher for the next generation.

## Key Concepts:

1. **Born-Again Networks**: Each generation has the same architecture as the previous one
2. **Soft Target Distillation**: Students learn from the teacher's softened probability distributions
3. **Noisy Student Training**: Noise added during training (dropout, augmentation) is intended to make students more robust
4. **EMA Teachers**: An exponential moving average of the weights provides more stable distillation targets
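
Of the concepts above, the EMA teacher is the quickest to state in code. The sketch below is a dependency-free illustration on plain lists of floats, not this project's implementation; the name `ema_update` and the `decay` default are assumptions made for the example.

```python
def ema_update(teacher, student, decay=0.99):
    """Return teacher weights moved a small step toward the student weights."""
    if len(teacher) != len(student):
        raise ValueError("teacher and student must have the same number of weights")
    # Each teacher weight keeps `decay` of its old value and takes the rest
    # from the corresponding student weight.
    return [decay * t + (1.0 - decay) * s for t, s in zip(teacher, student)]


# Toy example: the student weights jump around, the EMA teacher follows slowly.
teacher = [0.0, 0.0]
student_steps = [[1.0, -1.0], [1.2, -0.8], [0.9, -1.1]]
for student in student_steps:
    teacher = ema_update(teacher, student, decay=0.9)
print(teacher)
```

A `decay` close to 1 makes the teacher change slowly, which is what gives the student a steadier target than the raw, step-to-step weights would.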

In [None]:
import sys
sys.path.insert(0, '../src')

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path

from models import ResNet18, MLP, VisionTransformerSmall
from data import get_cifar10_loaders, get_mnist_loaders
from distillation import distillation_loss, TemperatureScheduler
from generations import GenerationManager, CheckpointManager

%matplotlib inline
plt.style.use('seaborn-v0_8-darkgrid')

## 1. Understanding Distillation Loss

Let's visualize how temperature affects the soft targets.
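
The body of the imported `distillation_loss` is not shown in this notebook. As a rough guide, a common formulation mixes a temperature-softened KL term with ordinary cross-entropy on the hard label. The sketch below assumes that form, and assumes `alpha` weights the soft term; `softmax` and `kd_loss` are illustrative names, written in plain Python so the arithmetic is easy to follow.

```python
import math


def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax over a list of logits."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kd_loss(student_logits, teacher_logits, label, temperature=4.0, alpha=0.7):
    """alpha * T^2 * KL(teacher || student) + (1 - alpha) * cross-entropy."""
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    # KL divergence between the softened teacher and student distributions.
    kl = sum(pt * (math.log(pt) - math.log(ps))
             for pt, ps in zip(p_teacher, p_student) if pt > 0.0)
    # Cross-entropy of the unsoftened student prediction against the true label.
    ce = -math.log(softmax(student_logits)[label])
    return alpha * temperature ** 2 * kl + (1.0 - alpha) * ce


print(kd_loss([1.5, 0.2, 0.1], [2.0, 1.0, 0.5], label=0))
```

The `temperature ** 2` factor is the usual correction that keeps the gradient scale of the soft term comparable as the temperature changes.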

In [None]:
# Create sample logits
logits = torch.tensor([[2.0, 1.0, 0.5, 0.1, 0.05, 0.01, 0.001, 0.0001, 0.0, 0.0]])

temperatures = [1.0, 2.0, 4.0, 10.0, 20.0]
fig, axes = plt.subplots(1, len(temperatures), figsize=(15, 3))

for i, T in enumerate(temperatures):
    probs = torch.softmax(logits / T, dim=-1).numpy()[0]
    axes[i].bar(range(10), probs)
    axes[i].set_title(f'T = {T}')
    axes[i].set_xlabel('Class')
    axes[i].set_ylabel('Probability')
    axes[i].set_ylim(0, 1)

plt.suptitle('Effect of Temperature on Soft Targets')
plt.tight_layout()
plt.show()

print("Higher temperature → softer distribution → more 'dark knowledge'")

## 2. Temperature Scheduling

Different scheduling strategies for temperature across generations.
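
The internals of `TemperatureScheduler` are not shown here, so the following is only a sketch of what the four schedule names might compute. The function `temperature_at` is hypothetical, and the `step` rule in particular (a single switch at the halfway point) is an assumption.

```python
import math


def temperature_at(step, total_steps, initial=20.0, final=1.0, schedule="linear"):
    """Temperature for a given generation under a few common schedules."""
    frac = min(max(step / total_steps, 0.0), 1.0)  # progress clamped to [0, 1]
    if schedule == "constant":
        return initial
    if schedule == "linear":
        return initial + (final - initial) * frac
    if schedule == "cosine":
        # Half-cosine decay from `initial` at frac=0 to `final` at frac=1.
        return final + 0.5 * (initial - final) * (1.0 + math.cos(math.pi * frac))
    if schedule == "step":
        # Assumed rule: switch from initial to final at the halfway point.
        return initial if frac < 0.5 else final
    raise ValueError(f"unknown schedule: {schedule}")


for name in ["constant", "linear", "cosine", "step"]:
    temps = [round(temperature_at(g, 10, schedule=name), 2) for g in range(11)]
    print(name, temps)
```

Linear and cosine agree at the endpoints and at the midpoint; cosine simply spends more generations near the initial and final temperatures.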

In [None]:
schedules = ['constant', 'linear', 'cosine', 'step']
generations = 10

fig, ax = plt.subplots(figsize=(10, 5))

for schedule in schedules:
    scheduler = TemperatureScheduler(
        initial_temp=20.0,
        final_temp=1.0,
        schedule=schedule,
        total_steps=generations,
    )
    temps = [scheduler.get_temperature(g) for g in range(generations + 1)]
    ax.plot(range(generations + 1), temps, marker='o', label=schedule)

ax.set_xlabel('Generation')
ax.set_ylabel('Temperature')
ax.set_title('Temperature Scheduling Strategies')
ax.legend()
ax.grid(True, alpha=0.3)
plt.show()

## 3. Quick Test: MNIST with MLP

Run a quick two-generation self-distillation on MNIST with a small MLP.
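
The config in the next cell carries `plateau_threshold` and `plateau_patience`, although this quick run disables early stopping. How `GenerationManager` uses them is not shown, so the sketch below is one plausible reading: stop once several consecutive generations each gain less than the threshold. The function `should_stop` is hypothetical, and the threshold is assumed to be in the same units as the accuracies.

```python
def should_stop(accuracies, threshold=0.001, patience=2):
    """True once the last `patience` generations each gained less than `threshold`."""
    if len(accuracies) <= patience:
        return False  # not enough history to judge a plateau
    recent = accuracies[-(patience + 1):]
    # Generation-to-generation gains over the most recent window.
    gains = [b - a for a, b in zip(recent, recent[1:])]
    return all(g < threshold for g in gains)


history = [0.90, 0.95, 0.9505, 0.9506]
print(should_stop(history))
```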

In [None]:
# Configuration
config = {
    'model': {'name': 'mlp', 'num_classes': 10},
    'data': {'name': 'mnist', 'batch_size': 128, 'num_workers': 2, 'augmentation': False},
    'training': {
        'epochs': 5,  # Quick training
        'learning_rate': 0.01,
        'momentum': 0.9,
        'weight_decay': 1e-4,
        'scheduler': 'cosine',
    },
    'distillation': {
        'temperature': 4.0,
        'alpha': 0.7,
        'noisy_student': False,
        'dropout_rate': 0.0,
    },
    'generations': {
        'max_generations': 2,
        'plateau_threshold': 0.001,
        'plateau_patience': 2,
    },
    'seed': 42,
}

# Data loaders
train_loader, val_loader = get_mnist_loaders(
    batch_size=config['data']['batch_size'],
    num_workers=config['data']['num_workers'],
)

print(f"Training samples: {len(train_loader.dataset)}")
print(f"Validation samples: {len(val_loader.dataset)}")

In [None]:
# Model factory
def create_model():
    return MLP(input_dim=784, num_classes=10, hidden_dims=[256, 128])

# Test model
test_model = create_model()
print(f"Model parameters: {test_model.count_parameters():,}")
del test_model

In [None]:
# Run self-distillation evolution
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

manager = GenerationManager(
    model_factory=create_model,
    train_loader=train_loader,
    val_loader=val_loader,
    config=config,
    checkpoint_dir='./checkpoints_mnist',
    device=device,
)

summary = manager.run_evolution(
    num_generations=config['generations']['max_generations'],
    early_stop=False,
)

In [None]:
# Visualize results
generations = list(range(summary['total_generations']))
accuracies = [summary['all_metrics'][g]['best_val_acc'] for g in generations]

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(generations, accuracies, marker='o', linewidth=2, markersize=10)
ax.set_xlabel('Generation')
ax.set_ylabel('Validation Accuracy (%)')
ax.set_title('Self-Distillation on MNIST')
ax.grid(True, alpha=0.3)

# Annotate improvements
for i in range(1, len(accuracies)):
    improvement = accuracies[i] - accuracies[i-1]
    ax.annotate(
        f'{improvement:+.2f}%',  # '+' in the format spec prints the correct sign for drops too
        xy=(i, accuracies[i]),
        xytext=(10, 10),
        textcoords='offset points',
        fontsize=9,
        color='green' if improvement > 0 else 'red'
    )

plt.show()

## 4. CIFAR-10 with ResNet-18

A more realistic experiment with a deeper CNN.

In [None]:
# CIFAR-10 configuration
cifar_config = {
    'model': {'name': 'resnet18', 'num_classes': 10},
    'data': {'name': 'cifar10', 'batch_size': 128, 'num_workers': 4, 'augmentation': True},
    'training': {
        'epochs': 50,  # More epochs for better results
        'learning_rate': 0.1,
        'momentum': 0.9,
        'weight_decay': 5e-4,
        'scheduler': 'cosine',
    },
    'distillation': {
        'temperature': 4.0,
        'alpha': 0.7,
        'noisy_student': True,
        'dropout_rate': 0.1,
    },
    'generations': {
        'max_generations': 3,
        'plateau_threshold': 0.001,
        'plateau_patience': 2,
    },
    'seed': 42,
}

# Note: This will take longer to run. Uncomment if you want to run it.
# train_loader, val_loader = get_cifar10_loaders(
#     batch_size=cifar_config['data']['batch_size'],
#     num_workers=cifar_config['data']['num_workers'],
#     augmentation=cifar_config['data']['augmentation'],
# )

# def create_resnet():
#     return ResNet18(num_classes=10, dropout_rate=0.1)

# manager = GenerationManager(
#     model_factory=create_resnet,
#     train_loader=train_loader,
#     val_loader=val_loader,
#     config=cifar_config,
#     checkpoint_dir='./checkpoints_cifar10',
#     device=device,
# )

# summary_cifar = manager.run_evolution(
#     num_generations=cifar_config['generations']['max_generations'],
# )

## 5. Analyzing Checkpoints

Load and compare different generations.

In [None]:
# Load checkpoint manager
ckpt_manager = CheckpointManager('./checkpoints_mnist')
lineage = ckpt_manager.get_lineage()

print("Training Lineage:")
print(f"Created: {lineage['created_at']}")
print(f"Generations: {len(lineage['generations'])}")
print()

for gen in lineage['generations']:
    print(f"Generation {gen['generation']}:")
    print(f"  Parent: {gen['parent']}")
    print(f"  Metrics: {gen['metrics']}")
    print(f"  Timestamp: {gen['timestamp']}")

In [None]:
# Compare predictions between generations
def compare_predictions(gen1, gen2, data_loader, num_batches=5):
    """Compare prediction agreement between two generations."""
    model1 = create_model()
    model2 = create_model()
    
    ckpt_manager.load_generation(gen1, model1)
    ckpt_manager.load_generation(gen2, model2)
    
    model1.eval()
    model2.eval()
    model1.to(device)
    model2.to(device)
    
    agreement_count = 0
    total_count = 0
    
    with torch.no_grad():
        for i, (inputs, labels) in enumerate(data_loader):
            if i >= num_batches:
                break
                
            inputs = inputs.to(device)
            
            out1 = model1(inputs)
            out2 = model2(inputs)
            
            pred1 = out1.argmax(dim=1)
            pred2 = out2.argmax(dim=1)
            
            agreement_count += (pred1 == pred2).sum().item()
            total_count += len(labels)
    
    return agreement_count / total_count * 100

# Compare Gen 0 vs Gen 1
if len(lineage['generations']) >= 2:
    agreement = compare_predictions(0, 1, val_loader)
    print(f"Prediction agreement between Gen 0 and Gen 1: {agreement:.2f}%")

## 6. Confidence and Calibration Analysis

Analyze how confidence changes across generations.
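
The histograms in this section show confidence but not calibration itself. One standard summary is the expected calibration error (ECE), which bins predictions by confidence and compares each bin's accuracy with its mean confidence. The sketch below is a plain-Python illustration; the function name and the choice of ten equal-width bins are assumptions, not part of this project's code.

```python
def expected_calibration_error(confidences, correct, num_bins=10):
    """ECE: sample-weighted mean of |accuracy - confidence| over equal-width bins."""
    n = len(confidences)
    if n == 0:
        raise ValueError("need at least one prediction")
    bins = [[] for _ in range(num_bins)]
    for conf, hit in zip(confidences, correct):
        idx = min(int(conf * num_bins), num_bins - 1)  # conf == 1.0 goes in the last bin
        bins[idx].append((conf, hit))
    ece = 0.0
    for members in bins:
        if not members:
            continue  # empty bins contribute nothing
        avg_conf = sum(c for c, _ in members) / len(members)
        accuracy = sum(1.0 for _, h in members if h) / len(members)
        ece += len(members) / n * abs(accuracy - avg_conf)
    return ece


print(expected_calibration_error([0.95, 0.80, 0.60, 0.99], [True, True, False, True]))
```

With the output of `analyze_confidence`, the inputs could be built by concatenating `conf['correct']` and `conf['incorrect']` and pairing them with matching `True` and `False` flags.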

In [None]:
def analyze_confidence(generation, data_loader):
    """Analyze model confidence."""
    model = create_model()
    ckpt_manager.load_generation(generation, model)
    model.eval()
    model.to(device)
    
    confidences = []
    correct_confidences = []
    incorrect_confidences = []
    
    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            probs = torch.softmax(outputs, dim=1)
            max_probs, predictions = probs.max(dim=1)
            
            confidences.extend(max_probs.cpu().numpy())
            
            correct_mask = predictions == labels
            correct_confidences.extend(max_probs[correct_mask].cpu().numpy())
            incorrect_confidences.extend(max_probs[~correct_mask].cpu().numpy())
    
    return {
        'all': np.array(confidences),
        'correct': np.array(correct_confidences),
        'incorrect': np.array(incorrect_confidences),
    }

# Analyze confidence for each generation
num_gens = len(lineage['generations'])
fig, axes = plt.subplots(1, num_gens, figsize=(5*num_gens, 4))

if num_gens == 1:
    axes = [axes]

for gen in range(num_gens):
    conf = analyze_confidence(gen, val_loader)
    
    axes[gen].hist(conf['correct'], bins=30, alpha=0.7, label='Correct', color='green')
    axes[gen].hist(conf['incorrect'], bins=30, alpha=0.7, label='Incorrect', color='red')
    axes[gen].set_xlabel('Confidence')
    axes[gen].set_ylabel('Count')
    axes[gen].set_title(f'Generation {gen}')
    axes[gen].legend()
    
    print(f"Gen {gen}: Mean confidence = {conf['all'].mean():.3f}")

plt.suptitle('Confidence Distribution by Generation')
plt.tight_layout()
plt.show()

## 7. Feature Visualization

Extract features from the best generation and display the input samples they were computed from.

In [None]:
# Get a batch of data
inputs, labels = next(iter(val_loader))
inputs = inputs[:16].to(device)  # Just 16 samples

# Load best generation
best_gen = summary['best_generation']
model = create_model()
ckpt_manager.load_generation(best_gen, model)
model.eval()
model.to(device)

# Extract features
with torch.no_grad():
    features = model.get_features(inputs)

print(f"Feature shape: {features.shape}")

# Show the input images the features were extracted from
fig, axes = plt.subplots(2, 8, figsize=(16, 4))
axes = axes.flatten()

for i in range(16):
    img = inputs[i].cpu().squeeze()
    if img.dim() == 1:
        # Reshape for MNIST
        img = img.view(28, 28)
    axes[i].imshow(img, cmap='gray')
    axes[i].axis('off')
    axes[i].set_title(f'Label: {labels[i].item()}')

plt.suptitle('Sample Images')
plt.tight_layout()
plt.show()

## 8. Training from Command Line

For full training, use the command-line scripts:

```bash
# Train with default config
cd ..
python scripts/train.py

# Evaluate all generations
python scripts/evaluate.py --checkpoint-dir ./checkpoints

# Custom config
python scripts/train.py \
    model=resnet18 \
    data=cifar10 \
    training.epochs=100 \
    generations.max_generations=5
```

## 9. Summary

Key observations from self-distillation:

1. **Each generation typically improves**: The student learns from the teacher's soft targets, though a gain is not guaranteed on every run
2. **Dark knowledge transfer**: The teacher's probability distributions encode similarities between classes
3. **Diminishing returns**: Improvements tend to shrink with each additional generation
4. **Noisy student**: Adding noise (dropout, augmentation) can help generalization

### Next Steps:

- Try different architectures (Vision Transformer)
- Experiment with temperature scheduling
- Use EMA teacher for more stable training
- Try feature-based distillation
- Scale to larger datasets (ImageNet)