# Module 07: Dynamic Sparse Training

Dynamic sparse training evolves network connectivity during training, not just at the end.

## Learning Objectives
- Understand SET (Sparse Evolutionary Training)
- Learn the DEEP R (Deep Rewiring) algorithm
- Implement dynamic rewiring in PyTorch
- Visualize topology evolution during training

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset

# Seed the PyTorch and NumPy global random number generators so that the
# randomly generated tensors in the cells below are reproducible.
torch.manual_seed(42)
np.random.seed(42)

print("[OK] Libraries loaded")

## 1. Static vs Dynamic Sparsity

Static sparsity fixes the mask; dynamic sparsity evolves it during training.

In [None]:
# Visualize the difference between a fixed mask and one that has been rewired.
# Three panels: the static mask, the rewired mask, and their signed difference.
fig, axes = plt.subplots(1, 3, figsize=(14, 4))

# Static sparse: a random 20x20 binary mask with roughly 30% of entries active
# (an entry is active when its uniform sample exceeds 0.7).
static_mask = (torch.rand(20, 20) > 0.7).float()
axes[0].imshow(static_mask, cmap='Blues')
axes[0].set_title('Static Sparsity\n(fixed mask)')

# Dynamic sparse - evolves. Start from a copy so static_mask stays untouched.
dynamic_mask = static_mask.clone()
# Simulate some rewiring: five rounds of "drop 10% of the active entries, then
# activate the same number of inactive entries", all chosen at random.
for _ in range(5):
    # Remove some connections
    active = torch.where(dynamic_mask == 1)
    n_remove = len(active[0]) // 10
    indices_to_remove = torch.randperm(len(active[0]))[:n_remove]
    for idx in indices_to_remove:
        dynamic_mask[active[0][idx], active[1][idx]] = 0

    # Add new connections. The inactive set is recomputed after removal, so a
    # connection removed in this round may be re-added immediately.
    inactive = torch.where(dynamic_mask == 0)
    indices_to_add = torch.randperm(len(inactive[0]))[:n_remove]
    for idx in indices_to_add:
        dynamic_mask[inactive[0][idx], inactive[1][idx]] = 1

axes[1].imshow(dynamic_mask, cmap='Blues')
axes[1].set_title('Dynamic Sparsity\n(evolved mask)')

# Show difference: -1 where a connection was removed, +1 where one was added,
# 0 where nothing changed. vmin/vmax pin the colormap to that range.
diff = dynamic_mask - static_mask
axes[2].imshow(diff, cmap='RdBu', vmin=-1, vmax=1)
axes[2].set_title('Difference\n(red=removed, blue=added)')

for ax in axes:
    ax.axis('off')

plt.tight_layout()
plt.show()

# Count every position whose state differs between the two masks.
print(f"Connections changed: {(diff != 0).sum().item()} / {diff.numel()}")

## 2. SET Algorithm

Sparse Evolutionary Training (SET) prunes weak connections and regrows random new ones.

In [None]:
def set_rewire(weights, mask, prune_rate=0.3, regrow_rate=0.3):
    """
    SET (Sparse Evolutionary Training) rewiring.

    1. Prune: Remove the ``int(n_active * prune_rate)`` active connections
       with the smallest weight magnitude.
    2. Regrow: Activate ``int(n_pruned * regrow_rate)`` inactive positions
       chosen uniformly at random. Positions pruned in step 1 are eligible.

    Args:
        weights: Weight tensor with the same shape as ``mask``.
        mask: Connectivity mask; 1 marks an active connection, 0 an inactive
            one. The input tensor is never modified.
        prune_rate: Fraction of the active connections to prune.
        regrow_rate: Fraction of the pruned count to regrow. Values below 1.0
            reduce the number of active connections on every call; 1.0 keeps
            the count constant.

    Returns:
        A new mask tensor (always a fresh tensor, never ``mask`` itself).
    """
    new_mask = mask.clone()

    # Only active connections are candidates for pruning.
    active_indices = torch.where(mask == 1)
    n_active = len(active_indices[0])
    if n_active == 0:
        return new_mask

    # Clamp so a prune_rate above 1.0 removes everything instead of erroring.
    n_prune = min(int(n_active * prune_rate), n_active)
    if n_prune > 0:
        magnitudes = weights[active_indices].abs()
        # topk selects exactly n_prune entries. A "<= threshold" comparison
        # would remove every tied entry and could wipe far more than requested
        # (e.g. when many weights share the same magnitude).
        _, weakest = torch.topk(magnitudes, n_prune, largest=False)
        # Indexing every coordinate tensor keeps this correct for any rank.
        new_mask[tuple(coords[weakest] for coords in active_indices)] = 0

    # Regrow: add random new connections among all currently inactive slots.
    inactive = torch.where(new_mask == 0)
    n_inactive = len(inactive[0])
    n_regrow = min(int(n_prune * regrow_rate), n_prune, n_inactive)

    if n_regrow > 0:
        # randperm draws from the default CPU generator; move the selection to
        # the mask's device before using it as an index.
        chosen = torch.randperm(n_inactive)[:n_regrow].to(mask.device)
        new_mask[tuple(coords[chosen] for coords in inactive)] = 1

    return new_mask

# Demo: rewire a random 10x10 mask (about 30% active) once and compare the
# number of active connections before and after.
weights = torch.randn(10, 10)
mask = (torch.rand(10, 10) > 0.7).float()

print(f"Before: {mask.sum().item():.0f} active connections")
new_mask = set_rewire(weights, mask, prune_rate=0.3, regrow_rate=0.3)
print(f"After: {new_mask.sum().item():.0f} active connections")
# NOTE(review): with regrow_rate=0.3 only about 30% of the pruned connections
# are regrown, so the "After" count is lower than "Before". Pass
# regrow_rate=1.0 if the demo should show a constant connection count.
print(f"[OK] SET maintains similar sparsity while evolving topology")

## 3. DEEP R Algorithm

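DEEP R uses gradient information to decide which connections to regrow. Note that this notebook implements a simplified, gradient-guided variant: in the original DEEP R paper, a connection becomes dormant when its parameter changes sign and replacement connections are chosen at random, whereas regrowth driven by gradient magnitude is the approach taken by RigL.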

In [None]:
def deep_r_rewire(weights, mask, gradients, temperature=1.0, prune_rate=0.3):
    """
    DEEP R (Deep Rewiring) algorithm, gradient-guided variant.

    1. Prune: Remove the ``int(n_active * prune_rate)`` active connections
       with the smallest weight magnitude.
    2. Regrow: Re-activate as many inactive positions as were pruned, sampled
       without replacement with probability
       ``softmax(|gradient| / temperature)``. Positions pruned in step 1 are
       eligible.

    Args:
        weights: Weight tensor with the same shape as ``mask``.
        mask: Connectivity mask; 1 marks an active connection, 0 an inactive
            one. The input tensor is never modified.
        gradients: Tensor indexed like ``mask``; only the absolute values at
            inactive positions are used.
        temperature: Softmax temperature. Smaller positive values concentrate
            the sampling on the largest gradients. Must be non-zero.
        prune_rate: Fraction of the active connections to prune.

    Returns:
        A new mask tensor (always a fresh tensor, never ``mask`` itself).

    Raises:
        ValueError: If ``temperature`` is zero.
    """
    if temperature == 0:
        # Dividing by zero would turn the sampling distribution into NaNs.
        raise ValueError("temperature must be non-zero")

    new_mask = mask.clone()

    # Prune weak connections (same rule as SET).
    active_indices = torch.where(mask == 1)
    n_active = len(active_indices[0])
    if n_active == 0:
        return new_mask

    # Clamp so a prune_rate above 1.0 removes everything instead of erroring.
    n_prune = min(int(n_active * prune_rate), n_active)
    if n_prune <= 0:
        # Nothing pruned means nothing to regrow.
        return new_mask

    magnitudes = weights[active_indices].abs()
    # topk selects exactly n_prune entries; a "<= threshold" comparison would
    # also remove every entry tied with the threshold.
    _, weakest = torch.topk(magnitudes, n_prune, largest=False)
    # Indexing every coordinate tensor keeps this correct for any rank.
    new_mask[tuple(coords[weakest] for coords in active_indices)] = 0

    # DEEP R: gradient-guided regrowth among all currently inactive slots.
    inactive = torch.where(new_mask == 0)
    n_regrow = min(n_prune, len(inactive[0]))

    # Softmax with temperature turns gradient magnitudes into probabilities.
    logits = gradients[inactive].abs() / temperature
    probs = F.softmax(logits, dim=0)

    if int((probs > 0).sum()) >= n_regrow:
        # Sample connections to regrow based on gradient magnitude.
        chosen = torch.multinomial(probs, n_regrow, replacement=False)
    else:
        # The softmax underflowed, leaving fewer non-zero probabilities than
        # samples needed; multinomial would raise. The distribution has
        # collapsed onto the largest logits, so pick those directly.
        _, chosen = torch.topk(logits, n_regrow)

    new_mask[tuple(coords[chosen] for coords in inactive)] = 1

    return new_mask

# Demo with synthetic gradients: rewire a random 10x10 mask (about 30% active)
# once and compare the number of active connections before and after.
weights = torch.randn(10, 10)
mask = (torch.rand(10, 10) > 0.7).float()
gradients = torch.randn(10, 10)  # Simulated gradients (random, not from a real loss)

print(f"Before: {mask.sum().item():.0f} active connections")
new_mask = deep_r_rewire(weights, mask, gradients, temperature=1.0, prune_rate=0.3)
print(f"After: {new_mask.sum().item():.0f} active connections")
print(f"[OK] DEEP R uses gradients to guide where to add connections")

## 4. Dynamic Sparse Layer

In [None]:
class DynamicSparseLinear(nn.Module):
    """Linear layer with dynamic sparse training support.

    The layer keeps a dense ``weight`` parameter and a binary ``mask`` buffer
    of the same shape; the forward pass uses ``weight * mask``. ``rewire``
    replaces the mask using SET or DEEP R. Rewiring only changes the mask:
    a regrown connection resumes from whatever value ``weight`` currently
    holds at that position.
    """

    def __init__(self, in_features, out_features, density=0.3):
        """
        Args:
            in_features: Size of each input sample.
            out_features: Size of each output sample.
            density: Probability that any given connection starts active.
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.01)
        self.bias = nn.Parameter(torch.zeros(out_features))

        mask = (torch.rand(out_features, in_features) < density).float()
        self.register_buffer('mask', mask)
        self.register_buffer('_stored_grad', torch.zeros_like(self.weight))

        # Gradient w.r.t. the masked weight from the most recent backward
        # pass. Kept as a plain attribute so the state_dict layout is unchanged.
        self._dense_grad = None

    def forward(self, x):
        """Apply the masked linear transform to ``x``."""
        effective_weight = self.weight * self.mask
        if effective_weight.requires_grad:
            # d(loss)/d(weight) equals d(loss)/d(effective_weight) * mask, so
            # weight.grad is exactly zero wherever the mask is zero. Hook the
            # masked weight to capture the gradient at inactive positions too.
            effective_weight.register_hook(self._capture_dense_grad)
        return F.linear(x, effective_weight, self.bias)

    def _capture_dense_grad(self, grad):
        """Tensor hook: remember the gradient w.r.t. the masked weight."""
        self._dense_grad = grad.detach().clone()

    def store_gradients(self):
        """Store gradients for DEEP R.

        Does nothing when ``weight.grad`` is None. Otherwise stores the dense
        gradient captured during the most recent backward pass, falling back
        to ``weight.grad`` if no backward pass through ``forward`` has been
        captured. Unlike ``weight.grad``, the captured gradient is not
        accumulated across backward passes.
        """
        if self.weight.grad is None:
            return
        source = self._dense_grad if self._dense_grad is not None else self.weight.grad
        self._stored_grad = source.detach().clone()

    def rewire(self, method='set', prune_rate=0.3, temperature=1.0, regrow_rate=0.3):
        """Rewire the layer connectivity.

        Args:
            method: ``'set'`` for random regrowth or ``'deep_r'`` for
                gradient-guided regrowth using the stored gradients.
            prune_rate: Fraction of active connections to prune.
            temperature: Softmax temperature; used by ``'deep_r'`` only.
            regrow_rate: Fraction of pruned connections to regrow; used by
                ``'set'`` only.

        Raises:
            ValueError: If ``method`` is not ``'set'`` or ``'deep_r'``.
        """
        if method not in ('set', 'deep_r'):
            # A typo used to be a silent no-op, leaving the mask unchanged.
            raise ValueError(
                f"Unknown rewiring method {method!r}; expected 'set' or 'deep_r'"
            )

        with torch.no_grad():
            if method == 'set':
                new_mask = set_rewire(
                    self.weight.data, self.mask,
                    prune_rate=prune_rate, regrow_rate=regrow_rate
                )
            else:
                new_mask = deep_r_rewire(
                    self.weight.data, self.mask, self._stored_grad,
                    temperature=temperature, prune_rate=prune_rate
                )
            # Update the registered buffer in place so it keeps its identity.
            self.mask.copy_(new_mask)

    def get_sparsity(self):
        """Return the fraction of connections that are currently inactive."""
        return 1 - (self.mask.sum() / self.mask.numel()).item()

# Test: build a 100 -> 50 layer with about 20% of connections active, run one
# forward/backward pass so that gradients exist, then rewire once with SET.
layer = DynamicSparseLinear(100, 50, density=0.2)
x = torch.randn(8, 100)
y = layer(x)
loss = y.sum()
loss.backward()

print(f"Initial sparsity: {layer.get_sparsity():.1%}")
# The stored gradients are only consumed by method='deep_r'; storing them
# before a 'set' rewire is harmless.
layer.store_gradients()
# rewire() does not forward a regrow rate here, so set_rewire uses its default
# regrow_rate=0.3 and the reported sparsity rises slightly after the call.
layer.rewire(method='set', prune_rate=0.2)
print(f"After SET rewire: {layer.get_sparsity():.1%}")

## 5. Training with Dynamic Sparsity

In [None]:
class DynamicSparseMLP(nn.Module):
    """Two-layer perceptron whose layers are both ``DynamicSparseLinear``."""

    def __init__(self, input_dim, hidden_dim, output_dim, density=0.3):
        """Create the hidden and output layers with the same initial density."""
        super().__init__()
        self.fc1 = DynamicSparseLinear(input_dim, hidden_dim, density)
        self.fc2 = DynamicSparseLinear(hidden_dim, output_dim, density)

    def forward(self, x):
        """Return the output-layer activations for input ``x``."""
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)

    def rewire_all(self, method='set', prune_rate=0.3, temperature=1.0):
        """Snapshot gradients and rewire each sparse layer, first to last."""
        for sparse_layer in (self.fc1, self.fc2):
            # Gradients are stored right before rewiring so that 'deep_r'
            # sees the values from the latest backward pass.
            sparse_layer.store_gradients()
            sparse_layer.rewire(method, prune_rate, temperature)

# Create data: 500 samples with 20 random features. The label is 1 when the
# sum of the first two features is positive, otherwise 0.
X = torch.randn(500, 20)
y = (X[:, 0] + X[:, 1] > 0).long()
dataset = TensorDataset(X, y)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

# Train with dynamic sparsity
model = DynamicSparseMLP(20, 64, 2, density=0.3)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

losses = []       # mean training loss per epoch
sparsities = []   # mean sparsity of fc1 and fc2 at the end of each epoch
rewire_frequency = 50  # Rewire every 50 steps
step = 0          # global optimizer-step counter across all epochs

for epoch in range(50):
    epoch_loss = 0
    for batch_x, batch_y in loader:
        optimizer.zero_grad()
        out = model(batch_x)
        loss = criterion(out, batch_y)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

        step += 1
        if step % rewire_frequency == 0:
            # Rewiring runs after optimizer.step(); the gradients of this
            # batch are still in .grad because zero_grad() is only called at
            # the start of the next iteration.
            model.rewire_all(method='set', prune_rate=0.2)

    losses.append(epoch_loss / len(loader))
    # Unweighted average of the two layers' sparsity (the layers differ in size).
    avg_sparsity = (model.fc1.get_sparsity() + model.fc2.get_sparsity()) / 2
    sparsities.append(avg_sparsity)

    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}: Loss={losses[-1]:.4f}, Sparsity={avg_sparsity:.1%}")

# Plot the loss curve and the sparsity curve side by side.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

ax1.plot(losses)
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training Loss')

# Sparsity is stored as a fraction; convert to percent for the axis.
ax2.plot([s * 100 for s in sparsities])
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Sparsity (%)')
ax2.set_title('Network Sparsity During Training')

plt.tight_layout()
plt.show()

## 6. Visualizing Topology Evolution

In [None]:
# Track mask changes over time. masks_history[0] is the initial mask and each
# later entry is the mask after one more rewiring step.
torch.manual_seed(42)
layer = DynamicSparseLinear(16, 16, density=0.3)
masks_history = [layer.mask.clone()]

# Simulate training with rewiring
for i in range(10):
    # Fake forward/backward so the layer has gradients to store
    x = torch.randn(4, 16)
    y = layer(x).sum()
    y.backward()

    layer.store_gradients()
    layer.rewire(method='set', prune_rate=0.3)
    masks_history.append(layer.mask.clone())
    # Drop the gradient so it does not accumulate across iterations.
    layer.weight.grad = None

# Visualize evolution: one panel per recorded mask.
fig, axes = plt.subplots(2, 6, figsize=(15, 5))
axes = axes.flatten()

for i, (ax, mask) in enumerate(zip(axes, masks_history)):
    ax.imshow(mask, cmap='Blues')
    ax.set_title(f'Step {i}')
    ax.axis('off')

# The grid has 12 panels but only 11 masks are recorded; hide the leftover
# panel(s) so no empty frame with ticks is drawn.
for unused_ax in axes[len(masks_history):]:
    unused_ax.axis('off')

plt.suptitle('Topology Evolution with SET Rewiring', fontsize=14)
plt.tight_layout()
plt.show()

# Measure how much changed: count positions that differ between each pair of
# consecutive masks and sum over all steps.
total_changed = 0
for i in range(1, len(masks_history)):
    changed = (masks_history[i] != masks_history[i-1]).sum().item()
    total_changed += changed

print(f"Total connections changed: {total_changed}")
print(f"Average per step: {total_changed / (len(masks_history)-1):.1f}")

## Summary

Key concepts covered:

1. **Static vs Dynamic**: Dynamic sparsity evolves connectivity during training
2. **SET Algorithm**: Prune weak connections, regrow random new ones
3. **DEEP R Algorithm**: Gradient-guided regrowth for smarter topology evolution
4. **Implementation**: Store gradients, apply rewiring periodically
5. **Visualization**: Watch topology adapt as training progresses

## Benefits of Dynamic Sparse Training

- Explores more topologies than static sparsity
- Can escape local minima by changing connectivity
- Gradient-guided methods (DEEP R) find useful connections faster

## Next Steps

- [->] Module 08: Mixture of Experts
- [->] Module 09: Multi-Modal Architectures