# AAAI 2026 Tutorial: A Decade of Sparse Training

## Dynamic Sparse Training (DST) Tutorial

This notebook demonstrates Dynamic Sparse Training concepts and implementations.

**Learning Objectives:**
1. Understand the difference between simulated and truly sparse implementations
2. Implement DST algorithms from scratch
3. Observe the accuracy-efficiency trade-offs
4. Recognize system-level barriers to sparse training adoption

**Dataset:** MNIST (handwritten digits)
- 60,000 training samples, 10,000 test samples
- 28x28 grayscale images, 10 classes


In [None]:
# Import required libraries
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import copy

# Import utilities from utils.py
from utils import (
    SimpleNet, MaskedNet,
    train_epoch, evaluate,
    apply_dst_step,
    benchmark_inference,
    plot_mask_evolution_per_epoch,
    plot_mask_2d_evolution,
    plot_dense_vs_masked_comparison,
    CUPY_GPU_AVAILABLE
)

print("✓ All imports successful!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
print(f"CuPy GPU available: {CUPY_GPU_AVAILABLE}")


## Configuration

Set hyperparameters and global settings here.


In [None]:
# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Inform about CuPy GPU availability
if CUPY_GPU_AVAILABLE:
    print("✓ CuPy GPU acceleration available - will use GPU CSR matrices")
else:
    print("⚠ CuPy GPU not available - will use CPU CSR matrices")

# Hyperparameters
GLOBAL_SPARSITY = 0.95  # 95% sparsity (adjust to experiment)
batch_size = 64
num_epochs = 10
learning_rate = 0.001
dst_frequency = 2  # Apply DST every N epochs

print(f"\nHyperparameters:")
print(f"  Sparsity: {GLOBAL_SPARSITY*100}%")
print(f"  Batch size: {batch_size}")
print(f"  Epochs: {num_epochs}")
print(f"  Learning rate: {learning_rate}")
print(f"  DST frequency: every {dst_frequency} epochs")


## Step 1: Load MNIST Dataset


In [None]:
# Data transformations
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # MNIST mean and std
])

# Download and load training data
train_dataset = torchvision.datasets.MNIST(
    root='./data', 
    train=True, 
    download=True, 
    transform=transform
)
train_loader = DataLoader(
    train_dataset, 
    batch_size=batch_size, 
    shuffle=True,
    num_workers=2 if torch.cuda.is_available() else 0
)

# Download and load test data
test_dataset = torchvision.datasets.MNIST(
    root='./data', 
    train=False, 
    download=True, 
    transform=transform
)
test_loader = DataLoader(
    test_dataset, 
    batch_size=batch_size, 
    shuffle=False,
    num_workers=2 if torch.cuda.is_available() else 0
)

print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")
print(f"Image size: {train_dataset[0][0].shape}")
print(f"Number of classes: {len(train_dataset.classes)}")


## Step 2: Create Networks


In [None]:
# Create base network
base_net = SimpleNet(num_classes=10)
print(f"Base network created with {sum(p.numel() for p in base_net.parameters())} parameters")

# Create dense network (copy of base)
dense_net = copy.deepcopy(base_net).to(device)

# Create masked network
print(f"\nCreating masked network with {GLOBAL_SPARSITY*100}% sparsity")
masked_net = MaskedNet(base_net, sparsity=GLOBAL_SPARSITY).to(device)

# Print sparsity statistics
stats = masked_net.get_sparsity_stats()
print("\nSparsity Statistics:")
for name, stat in stats.items():
    print(f"  {name}: {stat['sparsity']*100:.2f}% sparse ({stat['zeros']}/{stat['total']} zeros)")

# Loss and optimizers
criterion = nn.CrossEntropyLoss()
dense_optimizer = optim.Adam(dense_net.parameters(), lr=learning_rate)
masked_optimizer = optim.Adam(masked_net.parameters(), lr=learning_rate)


## Step 3: Training Loop

Train both dense and masked networks, applying Dynamic Sparse Training periodically.
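
`apply_dst_step` is imported from `utils.py` and its internals are not shown here. As a rough sketch of what one prune-and-regrow update on a single layer might look like, the NumPy function below drops the smallest-magnitude active weights and regrows the same number of connections at random among previously inactive positions. Random regrowth, the per-layer scope, and the name `dst_step_sketch` are assumptions for illustration; the tutorial's helper may use a different criterion (for example, gradient-based regrowth).

```python
import numpy as np


def dst_step_sketch(weights, mask, prune_ratio=0.1, rng=None):
    """Return an updated copy of `mask` after one prune-and-regrow step.

    Hypothetical helper: magnitude pruning plus random regrowth on one layer.
    """
    rng = np.random.default_rng() if rng is None else rng
    flat_mask = mask.ravel().copy()
    flat_abs = np.abs(weights.ravel())

    active = np.flatnonzero(flat_mask)         # currently kept connections
    inactive = np.flatnonzero(flat_mask == 0)  # candidates for regrowth

    # Swap at most as many connections as there are free positions.
    n_swap = min(int(prune_ratio * active.size), inactive.size)
    if n_swap == 0:
        return flat_mask.reshape(mask.shape)

    # Prune: deactivate the active weights with the smallest magnitude.
    order = np.argsort(flat_abs[active])
    flat_mask[active[order[:n_swap]]] = 0

    # Regrow: activate the same number of previously inactive positions.
    grown = rng.choice(inactive, size=n_swap, replace=False)
    flat_mask[grown] = 1

    return flat_mask.reshape(mask.shape)
```

Because the number of pruned and regrown connections is equal, the layer's sparsity stays constant across updates. A full implementation would typically also re-initialize the regrown weights (often to zero or small random values).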


In [None]:
# Track metrics for both networks
dense_train_losses = []
dense_train_accs = []
dense_test_losses = []
dense_test_accs = []

masked_train_losses = []
masked_train_accs = []
masked_test_losses = []
masked_test_accs = []
sparsity_per_epoch = []  # Track sparsity per layer per epoch
masks_per_epoch = []  # Track actual mask values per epoch for visualization

print("="*60)
print("Training Dense and Dense+Mask Networks")
print("="*60)

for epoch in range(num_epochs):
    # Train dense network
    dense_train_loss, dense_train_acc = train_epoch(dense_net, train_loader, criterion, dense_optimizer, device)
    dense_train_losses.append(dense_train_loss)
    dense_train_accs.append(dense_train_acc)
    dense_test_acc, dense_test_loss = evaluate(dense_net, test_loader, device, criterion)
    dense_test_accs.append(dense_test_acc)
    dense_test_losses.append(dense_test_loss)
    
    # Train masked network
    masked_train_loss, masked_train_acc = train_epoch(masked_net, train_loader, criterion, masked_optimizer, device)
    masked_train_losses.append(masked_train_loss)
    masked_train_accs.append(masked_train_acc)
    masked_test_acc, masked_test_loss = evaluate(masked_net, test_loader, device, criterion)
    masked_test_accs.append(masked_test_acc)
    masked_test_losses.append(masked_test_loss)
    
    # Track sparsity per epoch (per layer)
    stats = masked_net.get_sparsity_stats()
    sparsity_dict = {name: stat['sparsity'] for name, stat in stats.items()}
    sparsity_per_epoch.append(sparsity_dict)
    
    # Track mask values per epoch for visualization
    masks_2d = masked_net.get_masks_as_2d()
    masks_per_epoch.append(masks_2d)
    
    # Apply DST periodically
    if (epoch + 1) % dst_frequency == 0:
        print(f"\nEpoch {epoch+1}: Applying Dynamic Sparse Training...")
        apply_dst_step(masked_net, prune_ratio=0.1)
        
        # Update sparsity statistics
        stats = masked_net.get_sparsity_stats()
        avg_sparsity = np.mean([s['sparsity'] for s in stats.values()])
        
        print(f"  Mean per-layer sparsity after DST: {avg_sparsity*100:.2f}%")
    
    print(f"Epoch [{epoch+1}/{num_epochs}]")
    print(f"  Dense:      Train Loss: {dense_train_loss:.4f}, Train Acc: {dense_train_acc:.2f}%, Test Loss: {dense_test_loss:.4f}, Test Acc: {dense_test_acc:.2f}%")
    print(f"  Dense+Mask: Train Loss: {masked_train_loss:.4f}, Train Acc: {masked_train_acc:.2f}%, Test Loss: {masked_test_loss:.4f}, Test Acc: {masked_test_acc:.2f}%")
    print(f"  Difference: Loss Δ: {masked_test_loss - dense_test_loss:+.4f}, Acc Δ: {masked_test_acc - dense_test_acc:+.2f}%")

print("\n✓ Training complete!")


## Step 4: Final Evaluation


In [None]:
final_test_acc_dense, final_test_loss_dense = evaluate(dense_net, test_loader, device, criterion)
final_test_acc_masked, final_test_loss_masked = evaluate(masked_net, test_loader, device, criterion)

print("="*60)
print("Final Test Results")
print("="*60)
print(f"  Dense:      Loss: {final_test_loss_dense:.4f}, Acc: {final_test_acc_dense:.2f}%")
print(f"  Dense+Mask: Loss: {final_test_loss_masked:.4f}, Acc: {final_test_acc_masked:.2f}%")
print(f"  Difference: Loss Δ: {final_test_loss_masked - final_test_loss_dense:+.4f}, Acc Δ: {final_test_acc_masked - final_test_acc_dense:+.2f}%")

final_stats = masked_net.get_sparsity_stats()
print("\nFinal Sparsity Statistics:")
for name, stat in final_stats.items():
    print(f"  {name}: {stat['sparsity']*100:.2f}% sparse ({stat['zeros']}/{stat['total']} zeros)")
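
The statistics printed above are per layer. As a small sketch, assuming masks are available as a dictionary of binary arrays, the code below shows how such per-layer numbers could be derived and how the unweighted mean over layers differs from the global sparsity weighted by layer size. The function names are hypothetical; only the `sparsity`, `zeros`, and `total` keys mirror what `get_sparsity_stats()` is seen to return in this notebook.

```python
import numpy as np


def sparsity_stats_sketch(masks):
    """Compute per-layer sparsity from a dict of binary masks (hypothetical helper)."""
    stats = {}
    for name, mask in masks.items():
        total = int(mask.size)
        zeros = total - int(np.count_nonzero(mask))
        stats[name] = {"sparsity": zeros / total, "zeros": zeros, "total": total}
    return stats


def mean_layer_sparsity(stats):
    """Unweighted mean: every layer counts equally, regardless of size."""
    return float(np.mean([s["sparsity"] for s in stats.values()]))


def global_sparsity(stats):
    """Weighted by layer size: fraction of all weights that are zero."""
    zeros = sum(s["zeros"] for s in stats.values())
    total = sum(s["total"] for s in stats.values())
    return zeros / total
```

The two summaries can diverge noticeably when layer sizes differ, so it helps to state which one is being reported.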


## Step 5: Generate Visualizations


In [None]:
print("="*60)
print("Generating plots...")
print("="*60)

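# Make sure the output directory exists before saving
# (harmless if the plotting helpers in utils.py already create it)
import os
os.makedirs('plots', exist_ok=True)

# Plot mask evolution per epoch (sparsity line plot)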
plot_mask_evolution_per_epoch(sparsity_per_epoch, masks_per_epoch, save_path='plots/mask_evolution.png')

# Plot 2D mask evolution as images
plot_mask_2d_evolution(masks_per_epoch, save_path='plots/mask_2d_evolution.png')

# Plot dense vs masked comparison
plot_dense_vs_masked_comparison(
    dense_train_losses, dense_train_accs, dense_test_losses, dense_test_accs,
    masked_train_losses, masked_train_accs, masked_test_losses, masked_test_accs,
    save_path='plots/dense_vs_masked.png'
)

print("\n✓ All plots saved to 'plots/' directory")


## Step 6: Benchmarking - Simulated vs Truly Sparse

Compare inference time for different implementations to demonstrate the difference between simulated and truly sparse approaches.
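
To make the distinction concrete, the sketch below computes the same matrix-vector product two ways: "simulated" (dense weights multiplied by a binary mask) and "truly sparse" (only the non-zero values stored in CSR form). It is written in plain NumPy for readability, not speed, and the helper names `to_csr` and `csr_matvec` are hypothetical; they are unrelated to whatever `benchmark_inference` uses internally.

```python
import numpy as np


def to_csr(dense):
    """Convert a 2D array to CSR arrays: (data, column indices, row pointers)."""
    rows, cols = np.nonzero(dense)  # returned in row-major order
    data = dense[rows, cols]
    counts = np.bincount(rows, minlength=dense.shape[0])
    indptr = np.concatenate(([0], np.cumsum(counts)))
    return data, cols, indptr


def csr_matvec(data, indices, indptr, x):
    """Multiply a CSR matrix by a vector, touching only stored non-zeros."""
    y = np.zeros(len(indptr) - 1)
    for i in range(len(y)):
        start, end = indptr[i], indptr[i + 1]
        y[i] = data[start:end] @ x[indices[start:end]]
    return y


# Simulated sparsity: the full dense matrix is stored and multiplied.
rng = np.random.default_rng(0)
weights = rng.normal(size=(6, 8))
mask = rng.random((6, 8)) < 0.25
x = rng.normal(size=8)
y_masked = (weights * mask) @ x

# True sparsity: only the surviving weights and their positions are stored.
data, indices, indptr = to_csr(weights * mask)
y_sparse = csr_matvec(data, indices, indptr, x)
```

Both paths produce the same output; they differ in how much is stored and how many multiply-adds are performed.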


In [None]:
print("="*60)
print("Benchmarking - Simulated vs Truly Sparse")
print("="*60)
print("\nComparing inference time for different implementations...")
print("(This demonstrates the difference between simulated and truly sparse)")

benchmark_results = benchmark_inference(masked_net, test_loader, device, num_samples=100)

print("\nInference Time Comparison (100 samples, 10 runs):")
print("-" * 60)
for method, timing in benchmark_results.items():
    print(f"{method:30s}: {timing['mean']:6.2f} ± {timing['std']:5.2f} ms")
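
`benchmark_inference` is defined in `utils.py` and not shown here. As a minimal sketch of how mean ± std timings like those above could be gathered, the hypothetical helper below runs a few warm-up calls and then times repeated calls with `time.perf_counter`. On a GPU, the callable would also need to synchronize the device (for example with `torch.cuda.synchronize()`) before the clock is read, since CUDA kernels launch asynchronously.

```python
import statistics
import time


def time_callable_ms(fn, runs=10, warmup=2):
    """Time `fn()` over several runs and return mean and std in milliseconds."""
    # Warm-up calls absorb one-off costs such as caching and lazy initialization.
    for _ in range(warmup):
        fn()

    samples = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1000.0)

    return {"mean": statistics.mean(samples), "std": statistics.pstdev(samples)}
```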


## Step 7: Speedup Analysis


In [None]:
print("="*60)
print("Speedup Analysis")
print("="*60)

# GPU comparison: GPU Dense+Mask vs GPU Truly Sparse (if available)
if 'Dense+Mask (GPU)' in benchmark_results and 'Truly Sparse CSR (GPU)' in benchmark_results:
    gpu_dense = benchmark_results['Dense+Mask (GPU)']['mean']
    gpu_sparse = benchmark_results['Truly Sparse CSR (GPU)']['mean']
    # Guard against zero timings so neither ratio below divides by zero
    speedup = gpu_dense / gpu_sparse if gpu_dense > 0 and gpu_sparse > 0 else None
    print(f"\nMain Comparison: GPU Dense+Mask vs GPU Truly Sparse")
    print(f"  Dense+Mask (GPU): {gpu_dense:.2f} ms")
    print(f"  Truly Sparse CSR (GPU): {gpu_sparse:.2f} ms")
    if speedup is None:
        print("  → Timings too small to compare reliably")
    elif speedup > 1:
        print(f"  → GPU Truly Sparse is {speedup:.2f}x faster than GPU Dense+Mask!")
    elif speedup < 1:
        print(f"  → GPU Dense+Mask is {1/speedup:.2f}x faster")
    else:
        print("  → Similar performance")
    print("\n  Note: The GPU CSR path relies on cuSPARSE kernels. Whether it beats")
    print("        Dense+Mask depends on sparsity level, layer size, and batch size.")

# Cross-platform comparison: GPU Dense+Mask vs CPU Truly Sparse
if 'Dense+Mask (GPU)' in benchmark_results and 'Truly Sparse CSR (CPU)' in benchmark_results:
    gpu_dense = benchmark_results['Dense+Mask (GPU)']['mean']
    cpu_sparse = benchmark_results['Truly Sparse CSR (CPU)']['mean']
    # Guard against zero timings so neither ratio below divides by zero
    speedup = gpu_dense / cpu_sparse if gpu_dense > 0 and cpu_sparse > 0 else None
    print(f"\nCross-Platform Comparison: GPU Dense+Mask vs CPU Truly Sparse")
    print(f"  Dense+Mask (GPU): {gpu_dense:.2f} ms")
    print(f"  Truly Sparse CSR (CPU): {cpu_sparse:.2f} ms")
    if speedup is None:
        print("  → Timings too small to compare reliably")
    elif speedup > 1:
        print(f"  → CPU Truly Sparse is {speedup:.2f}x faster than GPU Dense+Mask!")
    elif speedup < 1:
        print(f"  → GPU Dense+Mask is {1/speedup:.2f}x faster")
    else:
        print("  → Similar performance")
    print("\n  Note: At high sparsity, a truly sparse CPU implementation can be")
    print("        competitive with a dense GPU implementation. The numbers above")
    print("        show whether that holds for this model and hardware.")

# CPU comparison (for reference)
if 'Dense+Mask (CPU)' in benchmark_results and 'Truly Sparse CSR (CPU)' in benchmark_results:
    cpu_dense = benchmark_results['Dense+Mask (CPU)']['mean']
    cpu_sparse = benchmark_results['Truly Sparse CSR (CPU)']['mean']
    speedup = cpu_dense / cpu_sparse if cpu_sparse > 0 else 0
    print(f"\nCPU Comparison (for reference):")
    print(f"  Dense+Mask (CPU): {cpu_dense:.2f} ms")
    print(f"  Truly Sparse CSR (CPU): {cpu_sparse:.2f} ms")
    print(f"  Speedup: {speedup:.2f}x")


## Summary

### Key Concepts Demonstrated:
1. ✓ Simulated sparsity using binary masks (Dense+Mask)
2. ✓ Masked Conv2d and Linear layers
3. ✓ Dynamic Sparse Training (pruning + regrowing)
4. ✓ Gradient masking to prevent updates to pruned weights
5. ✓ Truly sparse implementation using CSR matrices
6. ✓ Benchmarking comparison: Dense+Mask vs Truly Sparse
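
Items 1, 2, and 4 in the list above can be illustrated with a few lines of PyTorch. `MaskedNet` comes from `utils.py` and is not shown in this notebook, so the class below (`MaskedLinearSketch`, a hypothetical name) is only a guess at the general pattern: a dense weight tensor, a fixed binary mask stored as a buffer, and a forward pass that multiplies the two. The random mask initialization is an assumption.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MaskedLinearSketch(nn.Module):
    """Linear layer with simulated sparsity: dense weights times a binary mask."""

    def __init__(self, in_features, out_features, sparsity=0.95):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        # Assumption: a uniformly random mask that keeps roughly (1 - sparsity)
        # of the weights. A buffer moves with .to(device) but is never trained.
        mask = (torch.rand(out_features, in_features) >= sparsity).float()
        self.register_buffer("mask", mask)

    def forward(self, x):
        # All weights are still stored and multiplied; the mask only zeroes them.
        # Since the mask is part of the forward pass, autograd yields zero
        # gradients at pruned positions.
        return F.linear(x, self.linear.weight * self.mask, self.linear.bias)
```

Because the mask takes part in the forward computation, gradient masking falls out of the chain rule here. Implementations that instead zero the weights in place need to mask the gradients explicitly, and weight decay can still alter pruned entries in either case.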

### Understanding the Limitations

**THE PARADOX: Why Do We Still Stick to Dense Training?**

DST is algorithmically mature and, in reported results, often matches or outperforms dense training, yet most practical training pipelines remain dense. Here is why:

1. **SIMULATED SPARSITY (What We Just Implemented):**
   - Uses binary masks over dense weights
   - Convenient: works with existing frameworks
   - Problem: Still stores ALL weights in memory
   - Problem: Still computes with dense kernels
   - Result: NO memory savings (the mask adds storage on top of the dense weights) and little to no speedup

2. **TRULY SPARSE IMPLEMENTATIONS:**
   - Uses sparse matrix formats (COO, CSR, CSC)
   - Stores only non-zero weights
   - Requires custom kernels for efficiency
   - Challenge: Hardware optimized for dense operations
   - Challenge: Requires significant engineering effort

3. **SYSTEM-LEVEL BARRIERS:**
   - Hardware: GPUs optimized for dense matrix multiplication
   - Software: Deep learning frameworks favor dense operations
   - Memory: Sparse formats store indices alongside values, so they cost more than dense storage at low sparsity
   - Engineering: Truly sparse implementations are complex
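
The memory points above can be checked with simple arithmetic. The sketch below estimates storage for one weight matrix under three schemes: dense, dense plus mask, and CSR. The byte sizes (4-byte values, 4-byte indices, 1-byte mask entries) and the 784x300 layer shape are assumptions chosen for illustration, not measurements from this notebook.

```python
def layer_memory_bytes(rows, cols, sparsity, value_bytes=4, index_bytes=4):
    """Estimate storage in bytes for one weight matrix (hypothetical helper)."""
    total = rows * cols
    nnz = round(total * (1.0 - sparsity))

    dense = total * value_bytes
    # Assumption: the mask costs one byte per entry on top of the dense weights.
    dense_mask = dense + total
    # CSR stores a value and a column index per non-zero, plus rows + 1 row pointers.
    csr = nnz * (value_bytes + index_bytes) + (rows + 1) * index_bytes

    return {"dense": dense, "dense_mask": dense_mask, "csr": csr}


for sparsity in (0.30, 0.50, 0.95):
    sizes = layer_memory_bytes(784, 300, sparsity)
    print(f"sparsity={sparsity:.2f}: {sizes}")
```

Under these assumptions CSR needs roughly twice the bytes per stored weight, so it only saves memory once more than about half of the weights are pruned, while the Dense+Mask variant is always larger than plain dense storage.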

### Next Steps

- Understand sparse matrix formats and operations (COO, CSR, CSC)
- Consider hardware-aware sparsity patterns
- Join the sparse training community for collaboration

For more information, see the AAAI 2026 Tutorial:
**"A Decade of Sparse Training: Why Do We Still Stick to Dense Training?"**
