# Elastic Training PoC

A small-scale proof of concept of the elastic training idea from ERNIE 5.0:  
Extract sub-models of different sizes from a **single training run** and compare them against classical compression methods (pruning, distillation).

In [None]:
import torch
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import time, json, os

from models import ElasticMoEModel, BaselineCNN
from training.elastic_trainer import ElasticTrainer, StandardTrainer, get_warmup_cosine_scheduler
from training.pruning import StructuredPruner
from training.distillation import DistillationTrainer
from evaluation.extract_submodel import SubModelExtractor
from evaluation.benchmark import Benchmarker
from visualization.plots import Visualizer

# Run on the GPU when PyTorch reports CUDA as available, otherwise on the CPU.
# The device is kept as a plain string because it is passed as-is to the
# trainers and the benchmarker further down.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Device: {device}")

## 1. Data

In [None]:
# Per-channel normalization constants applied to the CIFAR-10 images.
# They are defined once so the train and test pipelines cannot drift apart.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)

# Training pipeline: random crop / flip / colour jitter augmentation,
# followed by tensor conversion and normalization.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])
# Evaluation pipeline: no augmentation, same normalization as training.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Both splits are downloaded into ./data on first use.
train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

# Pinned host memory only speeds up host-to-GPU copies. Requesting it when the
# run is on the CPU makes DataLoader emit a warning, so it is enabled only
# when the selected device is CUDA.
use_pinned_memory = device == "cuda"

train_loader = DataLoader(train_set, batch_size=256, shuffle=True, num_workers=4, pin_memory=use_pinned_memory)
test_loader = DataLoader(test_set, batch_size=256, shuffle=False, num_workers=4, pin_memory=use_pinned_memory)

print(f"Train: {len(train_set)}, Test: {len(test_set)}")

## 2. Elastic MoE Training

Three elasticity axes are trained simultaneously:
- **Depth**: 3-6 layers active per step
- **Width**: 4-16 experts available per step  
- **Sparsity**: top-1 to top-3 routing per step

New features:
- **Progressive elastic**: gradually expand config space (large -> medium -> all)
- **Warmup + cosine scheduler**: per-step LR with 5% warmup
- **Loss weighting**: sub-model losses are weighted `[0.5, 0.2, 0.3]`, with the highest weight (0.5) on the largest sub-model
- **ColorJitter** data augmentation

In [None]:
# Hyperparameters shared by the elastic run and the baseline runs below.
EPOCHS = 120      # number of epochs passed to each trainer
LR = 1e-3         # base learning rate handed to AdamW
NUM_EXPERTS = 16  # 8 or 16

In [None]:
# Instantiate the elastic MoE network; the sub-models evaluated later in the
# notebook are all configurations of this one model.
elastic_model = ElasticMoEModel(
    num_classes=10,
    num_blocks=6,
    embed_dim=128,
    moe_hidden_dim=256,
    num_experts=NUM_EXPERTS,
    default_top_k=2,
)

# Count every parameter tensor element in the full model.
total_params = 0
for param in elastic_model.parameters():
    total_params += param.numel()

# Report the model size and the width/depth choices the model exposes.
print(f"Total parameters: {total_params:,}")
print(f"Num experts: {NUM_EXPERTS}")
print(f"Width choices: {elastic_model.width_choices}")
print(f"Depth choices: {elastic_model.depth_choices}")

# Alternative to training: uncomment to restore previously saved weights.
# elastic_model.load_state_dict(torch.load('checkpoints/elastic_moe.pt', weights_only=True))
# print("Loaded from checkpoint")

In [None]:
# AdamW over all parameters of the elastic model.
optimizer = optim.AdamW(elastic_model.parameters(), lr=LR, weight_decay=1e-4)

# The schedule is sized in optimizer steps (epochs x batches per epoch),
# with the first 5% of those steps used as warmup.
steps_per_epoch = len(train_loader)
num_training_steps = EPOCHS * steps_per_epoch
num_warmup_steps = int(0.05 * num_training_steps)
scheduler = get_warmup_cosine_scheduler(optimizer, num_warmup_steps, num_training_steps)

# Elastic trainer with the sandwich rule, progressive expansion and
# per-sub-model loss weights enabled. The test split doubles as the
# validation loader here.
trainer = ElasticTrainer(
    model=elastic_model,
    train_loader=train_loader,
    val_loader=test_loader,
    optimizer=optimizer,
    scheduler=scheduler,
    device=device,
    use_sandwich_rule=True,
    use_progressive=True,
    loss_weights=[0.5, 0.2, 0.3],
)

# The returned history is used by the plotting cells later in the notebook.
elastic_history = trainer.train(num_epochs=EPOCHS)

# Persist the trained weights so the run can be reloaded without retraining.
os.makedirs('checkpoints', exist_ok=True)
torch.save(elastic_model.state_dict(), 'checkpoints/elastic_moe.pt')
print("Checkpoint saved.")

## 3. Sub-Model Extraction

One checkpoint -> three different models:

In [None]:
# Pull every preset sub-model out of the single trained checkpoint.
extractor = SubModelExtractor(elastic_model)
sub_models = extractor.extract_all_presets()

# Report accuracy, active parameter count and configuration per preset.
print("\nSub-model accuracies:")
for preset in ('large', 'medium', 'small'):
    preset_config = elastic_model.get_submodel_config(preset)
    accuracy = trainer.evaluate(elastic_config=preset_config)
    active_params = elastic_model.count_active_params(preset_config)
    print(f"  {preset:8s} | Acc: {accuracy:.4f} | Params: {active_params:,} | Config: {preset_config}")

## 4. Baselines

Train comparison models to check how the elastic sub-models compare with conventionally trained and compressed models:

In [None]:
# Baseline 1: a large CNN trained on its own with the shared hyperparameters.
print("=" * 50, "Baseline 1: Large CNN", "=" * 50, sep="\n")

large_cnn = BaselineCNN(num_classes=10, size='large')

# AdamW with a cosine-annealed learning rate whose period is the epoch count.
opt = optim.AdamW(large_cnn.parameters(), lr=LR, weight_decay=1e-4)
sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=EPOCHS)

large_trainer = StandardTrainer(large_cnn, train_loader, test_loader, opt, sch, device)
large_history = large_trainer.train(EPOCHS)

In [None]:
# Baseline 2: structured pruning of the trained large CNN, then fine-tuning.
print("=" * 50, "Baseline 2: Pruning (50%)", "=" * 50, sep="\n")

# NOTE(review): large_cnn is reused below as the distillation teacher and as
# the 'large_cnn' benchmark entry. Confirm that StructuredPruner.prune does not
# modify it in place; if it does, prune a copy instead.
pruner = StructuredPruner(large_cnn, train_loader, test_loader, device)

# Fine-tune for a quarter of the full schedule (at least 5 epochs) at a
# tenth of the base learning rate.
finetune_epochs = max(5, EPOCHS // 4)
pruned_model, prune_history = pruner.prune(
    prune_ratio=0.5,
    finetune_epochs=finetune_epochs,
    lr=LR * 0.1,
)

In [None]:
# Baseline 3: distil the trained large CNN (teacher) into a small CNN (student).
print("=" * 50, "Baseline 3: Knowledge Distillation", "=" * 50, sep="\n")

student = BaselineCNN(num_classes=10, size='small')
opt = optim.AdamW(student.parameters(), lr=LR, weight_decay=1e-4)

# Unlike the two StandardTrainer baselines, no LR scheduler is passed here.
distiller = DistillationTrainer(
    teacher_model=large_cnn,
    student_model=student,
    train_loader=train_loader,
    val_loader=test_loader,
    optimizer=opt,
    device=device,
    temperature=4.0,
    alpha=0.7,
)
distill_history = distiller.train(EPOCHS)

In [None]:
# Baseline 4: the same small CNN architecture trained from scratch, set up
# the same way as the large CNN baseline.
print("=" * 50, "Baseline 4: Small CNN (from scratch)", "=" * 50, sep="\n")

small_cnn = BaselineCNN(num_classes=10, size='small')

# AdamW with a cosine-annealed learning rate whose period is the epoch count.
opt = optim.AdamW(small_cnn.parameters(), lr=LR, weight_decay=1e-4)
sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=EPOCHS)

small_trainer = StandardTrainer(small_cnn, train_loader, test_loader, opt, sch, device)
small_history = small_trainer.train(EPOCHS)

## 5. Benchmark

In [None]:
# Benchmark on the test loader with 5 warmup passes and 50 measured runs.
benchmarker = Benchmarker(test_loader, device=device, num_warmup=5, num_runs=50)

# Baselines keyed by the name they are passed to the benchmarker under.
baseline_models = dict(
    large_cnn=large_cnn,
    pruned=pruned_model,
    distilled=student,
    small_scratch=small_cnn,
)

# Compare the elastic model against every baseline; the plotting cells read
# the 'name', 'accuracy', 'latency_ms' and 'method' fields of each result.
benchmark_results = benchmarker.run_full_comparison(elastic_model, baseline_models)

## 6. Visualizations

In [None]:
# Training curves: elastic training loss (left) and per-preset validation
# accuracy (right), both plotted per epoch from elastic_history.
import matplotlib
matplotlib.use('module://matplotlib_inline.backend_inline')  # inline for notebook

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# One x position per recorded epoch, starting at 1.
epochs = range(1, len(elastic_history['train_loss']) + 1)

ax1.plot(epochs, elastic_history['train_loss'], linewidth=2)
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Elastic Training Loss')
ax1.grid(True, alpha=0.3)

# Legend entries are built from the live preset configs. Hard-coded text such
# as "6L/8E/top-2" goes stale when NUM_EXPERTS or the presets change.
for preset in ('large', 'medium', 'small'):
    preset_config = elastic_model.get_submodel_config(preset)
    # History stores accuracy as a fraction; convert to percent for display.
    accuracy_pct = [a * 100 for a in elastic_history[f'val_acc_{preset}']]
    ax2.plot(epochs, accuracy_pct, label=f'{preset.capitalize()} ({preset_config})', linewidth=2)
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy (%)')
ax2.set_title('Sub-Model Accuracy')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Method comparison bar chart: one bar per benchmark result, blue for the
# elastic sub-models and orange for the baselines.
ELASTIC_COLOR, BASELINE_COLOR = '#58a6ff', '#f78166'

names = [result['name'] for result in benchmark_results]
accs = [result['accuracy'] * 100 for result in benchmark_results]
colors = [
    ELASTIC_COLOR if result.get('method') == 'elastic' else BASELINE_COLOR
    for result in benchmark_results
]
positions = range(len(names))

fig, ax = plt.subplots(figsize=(12, 5))
bars = ax.bar(positions, accs, color=colors, edgecolor='white', linewidth=0.5)

# Write each accuracy just above the top of its bar.
for bar, acc in zip(bars, accs):
    label_x = bar.get_x() + bar.get_width() / 2
    label_y = bar.get_height() + 0.3
    ax.text(label_x, label_y, f'{acc:.1f}%', ha='center', fontsize=10)

ax.set_xticks(positions)
ax.set_xticklabels(names, rotation=30, ha='right')
ax.set(ylabel='Accuracy (%)', title='Elastic Training vs Baselines')
ax.grid(True, alpha=0.3, axis='y')
plt.tight_layout()
plt.show()

In [None]:
# Accuracy vs latency trade-off: one labelled point per benchmark result,
# blue for the elastic sub-models and orange for the baselines.
fig, ax = plt.subplots(figsize=(10, 6))
for result in benchmark_results:
    latency = result['latency_ms']
    accuracy_pct = result['accuracy'] * 100
    if result.get('method') == 'elastic':
        point_color = '#58a6ff'
    else:
        point_color = '#f78166'
    ax.scatter(latency, accuracy_pct, color=point_color, s=150, edgecolors='white', zorder=5)
    # Offset the name slightly so it does not sit on top of the marker.
    ax.annotate(result['name'], (latency, accuracy_pct),
                textcoords='offset points', xytext=(10, 5), fontsize=9)
ax.set(xlabel='Latency (ms)', ylabel='Accuracy (%)', title='Accuracy vs Inference Speed')
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

In [None]:
# Expert routing heatmap: one panel per entry of router_info, showing the
# routing probabilities of a single test batch under the 'large' preset.
import seaborn as sns

elastic_model.eval()
elastic_model.to(device)
with torch.no_grad():
    # Only the first test batch is used; its labels are not needed.
    images, _ = next(iter(test_loader))
    config = elastic_model.get_submodel_config('large')
    # The third output of the forward pass is the per-block router info.
    _, _, router_info = elastic_model(images.to(device), elastic_config=config)

num_panels = len(router_info)
# squeeze=False always returns a 2-D grid, so a single panel needs no special case.
fig, axes_grid = plt.subplots(1, num_panels, figsize=(3.5 * num_panels, 4), squeeze=False)
axes = axes_grid[0]
for idx, (panel, info) in enumerate(zip(axes, router_info)):
    # Average over axis 0 (presumably the batch dimension; TODO confirm) and
    # lay the result out as a single heatmap row.
    probs = info['probs'].cpu().numpy().mean(axis=0).reshape(1, -1)
    expert_labels = [f'E{i}' for i in range(probs.shape[1])]
    sns.heatmap(probs, ax=panel, cmap='YlOrRd', annot=True, fmt='.3f',
                xticklabels=expert_labels, yticklabels=['Prob'])
    panel.set_title(f'Block {idx}')
plt.suptitle('Expert Routing Probabilities', y=1.02)
plt.tight_layout()
plt.show()

In [None]:
# Hand the training history, benchmark results and router info to the project
# Visualizer, which is configured to write its output under 'plots'.
viz = Visualizer(save_dir='plots')
viz.generate_all_plots(
    elastic_history,
    benchmark_results,
    router_info,
)
print("High-res plots saved to plots/ folder.")