# Module 05: Sparse Neural Networks

Sparse networks use far fewer connections than dense networks while often maintaining comparable performance.

## Learning Objectives
- Understand sparsity and its benefits
- Learn about pruning techniques
- Implement sparse layers with binary masks
- Explore the Lottery Ticket Hypothesis

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

# Seed both random number generators (PyTorch first, then NumPy) so every
# cell below produces the same numbers on every run.
for _seed_rng in (torch.manual_seed, np.random.seed):
    _seed_rng(42)
print("[OK] Libraries loaded")

## 1. Why Sparsity?

Dense networks have many parameters that may not all be necessary.

In [None]:
# Compare parameter counts (weight matrices only; biases are not counted)
input_dim = 784  # e.g., MNIST
hidden_dim = 512
output_dim = 10

dense_params = input_dim * hidden_dim + hidden_dim * output_dim
print(f"Dense network parameters: {dense_params:,}")

# Fraction of weights removed. The printed label is derived from this value,
# so changing it keeps the output consistent.
sparsity = 0.9
# round() rather than int(): 1 - 0.9 is slightly below 0.1 in floating point,
# and truncation would undercount by one.
sparse_params = round(dense_params * (1 - sparsity))
print(f"{sparsity:.0%} sparse network parameters: {sparse_params:,}")
print(f"Compression ratio: {dense_params / sparse_params:.1f}x")

## 2. Types of Sparsity

Sparsity can be unstructured (individual weights) or structured (entire neurons/channels).

In [None]:
# Visualize different sparsity patterns
def create_sparse_mask(shape, sparsity, pattern='random', block_size=4):
    """Create a binary (0/1 float) mask illustrating a sparsity pattern.

    Args:
        shape: shape of the mask. The 'row' and 'block' patterns index
            ``shape[0]`` and ``shape[1]``, so they need a 2-D shape.
        sparsity: target fraction of zeros, in [0, 1].
        pattern: one of
            - 'random': each entry is zeroed independently with probability
              ``sparsity`` (unstructured);
            - 'row': the first ``int(shape[0] * sparsity)`` rows are zeroed
              (structured);
            - 'block': each ``block_size`` x ``block_size`` tile is zeroed
              independently with probability ``sparsity``.
        block_size: tile edge length for the 'block' pattern.

    Returns:
        Float tensor of ones (kept) and zeros (removed).

    Raises:
        ValueError: if ``pattern`` is unknown or ``block_size`` < 1.
    """
    if pattern == 'random':
        return (torch.rand(shape) > sparsity).float()

    if pattern == 'row':  # Structured: remove entire rows
        mask = torch.ones(shape)
        rows_to_remove = int(shape[0] * sparsity)
        mask[:rows_to_remove, :] = 0
        return mask

    if pattern == 'block':  # Block sparsity
        if block_size < 1:
            raise ValueError(f"block_size must be >= 1, got {block_size}")
        mask = torch.ones(shape)
        for i in range(0, shape[0], block_size):
            for j in range(0, shape[1], block_size):
                # One independent draw per tile decides whether it is removed.
                if torch.rand(1) < sparsity:
                    mask[i:i + block_size, j:j + block_size] = 0
        return mask

    # Previously an unknown pattern fell through to an UnboundLocalError.
    raise ValueError(
        f"unknown pattern {pattern!r}; expected 'random', 'row' or 'block'"
    )

# Draw the three mask patterns side by side: 32x32 grid at 70% sparsity.
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
patterns = ['random', 'row', 'block']
titles = ['Unstructured (Random)', 'Structured (Row)', 'Block Sparsity']

for ax, (pattern, title) in zip(axes, zip(patterns, titles)):
    mask = create_sparse_mask((32, 32), 0.7, pattern)
    n_zero = (mask == 0).sum().item()
    ax.imshow(mask, cmap='Blues')
    ax.set_title(f"{title}\n({n_zero}/{mask.numel()} zeros)")
    ax.axis('off')

plt.tight_layout()
plt.show()

## 3. Pruning Methods

Pruning removes connections from a trained network, typically those judged least important (e.g., the weights with the smallest magnitude).

In [None]:
def magnitude_prune(weights, sparsity):
    """Return a mask that prunes the smallest-magnitude weights.

    Exactly ``round(sparsity * weights.numel())`` entries are zeroed: those
    with the smallest absolute value. Ties are broken arbitrarily by
    ``torch.topk``.

    Args:
        weights: tensor of any shape.
        sparsity: fraction of entries to prune, in [0, 1].

    Returns:
        float32 0/1 mask with the same shape and device as ``weights``.

    Raises:
        ValueError: if ``sparsity`` is outside [0, 1].
    """
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError(f"sparsity must be in [0, 1], got {sparsity}")

    magnitudes = weights.detach().abs().flatten()
    n_total = magnitudes.numel()
    n_keep = n_total - round(sparsity * n_total)

    # Select the survivors directly instead of using a quantile threshold.
    # A threshold with '>' always dropped the minimum weight, even at
    # sparsity=0, and torch.quantile rejects very large tensors.
    mask = torch.zeros(n_total, dtype=torch.float32, device=weights.device)
    if n_keep > 0:
        keep_idx = torch.topk(magnitudes, n_keep).indices
        mask[keep_idx] = 1.0
    return mask.reshape(weights.shape)

def random_prune(weights, sparsity):
    """Baseline pruning mask that ignores the weight values.

    Each entry survives independently when a uniform draw exceeds
    ``sparsity``. On average, a ``1 - sparsity`` fraction is kept.

    Returns:
        float 0/1 tensor shaped like ``weights``.
    """
    survives = torch.rand_like(weights).gt(sparsity)
    return survives.float()

# Compare pruning methods at the same 80% sparsity on a random weight matrix
weights = torch.randn(100, 100)

mag_mask = magnitude_prune(weights, 0.8)
rand_mask = random_prune(weights, 0.8)

# Sum of |w| over the surviving entries, computed for each mask.
mag_preserved = (weights.abs() * mag_mask).sum()
rand_preserved = (weights.abs() * rand_mask).sum()

print(f"Magnitude pruning preserves: {mag_preserved:.2f} total magnitude")
print(f"Random pruning preserves: {rand_preserved:.2f} total magnitude")
# Both masks keep about the same NUMBER of weights, so the ratio below
# compares how much total magnitude they keep, not how many weights.
print(f"Weights kept: magnitude={int(mag_mask.sum())}, random={int(rand_mask.sum())}")
print(f"[OK] Magnitude pruning retains {mag_preserved/rand_preserved:.2f}x more total weight magnitude")

## 4. Implementing Sparse Linear Layer

In [None]:
class SparseLinear(nn.Module):
    """Fully connected layer restricted to a fixed, randomly sampled set of connections.

    At construction, every connection is kept independently with probability
    ``density``. The resulting 0/1 mask is registered as a buffer rather than
    a parameter, so the optimizer never changes it. ``forward`` multiplies the
    weight by the mask, so masked entries contribute nothing to the output
    and receive zero gradient.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        density: probability that a given connection is kept.
        bias: if True, add a learnable bias initialised to zero.
    """

    def __init__(self, in_features, out_features, density=0.3, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        weight_shape = (out_features, in_features)

        # Weights are drawn before the mask; keep this order so seeded runs reproduce.
        self.weight = nn.Parameter(0.01 * torch.randn(weight_shape))
        if not bias:
            self.register_parameter('bias', None)
        else:
            self.bias = nn.Parameter(torch.zeros(out_features))

        keep = torch.rand(weight_shape).lt(density)
        self.register_buffer('mask', keep.float())

    def forward(self, x):
        return F.linear(x, self.mask * self.weight, self.bias)

    def get_sparsity(self):
        """Fraction of connections masked out (0.0 = dense, 1.0 = no connections)."""
        pruned = (self.mask == 0).sum().item()
        return pruned / self.mask.numel()

# Smoke test: a 100 -> 50 layer that keeps roughly 20% of its connections
layer = SparseLinear(100, 50, density=0.2)
x = torch.randn(8, 100)
y = layer(x)

n_active = (layer.mask == 1).sum().item()
report = [
    f"Input shape: {x.shape}",
    f"Output shape: {y.shape}",
    f"Sparsity: {layer.get_sparsity():.1%}",
    f"Active connections: {n_active} / {layer.mask.numel()}",
]
print("\n".join(report))

## 5. Training Sparse Networks

In [None]:
class SparseMLP(nn.Module):
    """Two-layer perceptron built from SparseLinear layers with a ReLU in between.

    Both layers sample their connectivity masks with the same ``density``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, density=0.3):
        super().__init__()
        # fc1 is built before fc2 so random draws happen in the same order.
        self.fc1 = SparseLinear(input_dim, hidden_dim, density=density)
        self.fc2 = SparseLinear(hidden_dim, output_dim, density=density)

    def forward(self, x):
        return self.fc2(self.fc1(x).relu())

# Synthetic binary task: the label is 1 when the first two features sum to > 0
X = torch.randn(500, 20)
y = X[:, 0].add(X[:, 1]).gt(0).long()

# Train sparse vs dense
def train_model(model, X, y, epochs=100, lr=0.01):
    """Train ``model`` full-batch with Adam and cross-entropy; report accuracy.

    Args:
        model: module mapping ``X`` to class logits.
        X: input batch, used in full for every optimizer step.
        y: integer class labels aligned with ``X``.
        epochs: number of optimizer steps (one per full pass over ``X``).
        lr: Adam learning rate (previously fixed at 0.01, still the default).

    Returns:
        ``(losses, acc)``: the loss recorded at each epoch (before that
        epoch's update), and accuracy on the same data after the final update.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    losses = []

    for _ in range(epochs):
        optimizer.zero_grad()
        loss = criterion(model(X), y)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

    # NOTE: this is training accuracy; there is no held-out split.
    with torch.no_grad():
        pred = model(X).argmax(1)
        acc = (pred == y).float().mean().item()

    return losses, acc

# Sweep connection density; 1.0 means an ordinary dense MLP baseline.
densities = [1.0, 0.5, 0.3, 0.1]
results = {}

for d in densities:
    torch.manual_seed(42)  # every run starts from the same RNG state
    model = (
        nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 2))
        if d == 1.0
        else SparseMLP(20, 64, 2, density=d)
    )
    losses, acc = train_model(model, X, y)
    results[d] = (losses, acc)
    print(f"Density {d:.0%}: Final accuracy = {acc:.1%}")

# Plot one training-loss curve per density, labelled with its final accuracy
plt.figure(figsize=(10, 5))
for d, (losses, acc) in results.items():
    kind = 'Dense' if d == 1.0 else f'{(1-d):.0%} Sparse'
    plt.plot(losses, label=f"{kind} (acc={acc:.1%})")

plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Sparse vs Dense Training')
plt.legend()
plt.show()

## 6. The Lottery Ticket Hypothesis

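Randomly initialized dense networks contain sparse subnetworks ("winning tickets") that can reach accuracy comparable to the full network. This holds when they are reset to their original initialization and trained in isolation.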

In [None]:
# Simplified Lottery Ticket experiment
def lottery_ticket_experiment(X, y, target_sparsity=0.9, iterations=3):
    """Find a sparse "winning ticket" by iterative magnitude pruning with rewinding.

    Each round does the following:
    - builds a fresh 20 -> 64 -> 2 MLP;
    - loads the *same* initial weights, restricted to the current masks;
    - trains it with ``train_model`` and records (sparsity, accuracy);
    - prunes a fixed fraction of the still-active weights in each layer by
      magnitude.

    The per-round fraction is chosen so that roughly ``target_sparsity`` of
    all weights are pruned after ``iterations`` rounds. Biases are never
    pruned.

    Gradients of pruned weights are masked, so those weights stay at exactly
    zero during training. Without this, the optimizer would update them like
    any other weight, and the "sparse" network being evaluated would be dense.

    Args:
        X: input features; the architecture is fixed, so X needs 20 columns.
        y: class labels for the 2-way output (0 or 1).
        target_sparsity: fraction of weights pruned after the final round.
        iterations: number of prune-and-retrain rounds.

    Returns:
        List of ``(sparsity, accuracy)`` float tuples, one per round
        (``iterations + 1`` entries; the first is the unpruned network).
    """
    torch.manual_seed(42)

    # Initial weights (the "lottery ticket"); every round rewinds to these.
    init_weights = {
        'fc1.weight': torch.randn(64, 20) * 0.1,
        'fc1.bias': torch.zeros(64),
        'fc2.weight': torch.randn(2, 64) * 0.1,
        'fc2.bias': torch.zeros(2),
    }

    # Create masks (start fully connected)
    masks = {
        'fc1.weight': torch.ones(64, 20),
        'fc2.weight': torch.ones(2, 64),
    }
    # Position of each layer inside the nn.Sequential built below.
    layer_index = {'fc1': 0, 'fc2': 2}

    # Remove this fraction of surviving weights per round, so that
    # (1 - prune_rate) ** iterations == 1 - target_sparsity.
    # Guard iterations == 0, which would otherwise divide by zero; no pruning
    # happens in that case anyway.
    prune_rate = 1 - (1 - target_sparsity) ** (1 / iterations) if iterations > 0 else 0.0

    results = []

    for i in range(iterations + 1):
        model = nn.Sequential(
            nn.Linear(20, 64),
            nn.ReLU(),
            nn.Linear(64, 2)
        )

        # Rewind to the initial weights under the current masks, and freeze
        # pruned weights at zero by zeroing their gradients. With Adam and no
        # weight decay, a zero gradient yields a zero update.
        for prefix, idx in layer_index.items():
            layer = model[idx]
            mask = masks[f'{prefix}.weight']
            with torch.no_grad():
                layer.weight.copy_(init_weights[f'{prefix}.weight'] * mask)
                layer.bias.copy_(init_weights[f'{prefix}.bias'])
            layer.weight.register_hook(lambda grad, m=mask: grad * m)

        # Train
        _, acc = train_model(model, X, y, epochs=100)

        n_active = sum(m.sum() for m in masks.values())
        n_total = sum(m.numel() for m in masks.values())
        current_sparsity = (1 - n_active / n_total).item()

        results.append((current_sparsity, acc))
        print(f"Iteration {i}: Sparsity={current_sparsity:.1%}, Accuracy={acc:.1%}")

        # Prune for next iteration: drop the smallest still-active weights.
        if i < iterations:
            for prefix, idx in layer_index.items():
                name = f'{prefix}.weight'
                magnitudes = model[idx].weight.detach().abs()
                active = masks[name] == 1
                threshold = torch.quantile(magnitudes[active], prune_rate)
                masks[name] = ((magnitudes > threshold) & active).float()

    return results

print("Lottery Ticket Hypothesis Demonstration:")
print("-" * 50)
results = lottery_ticket_experiment(X, y, target_sparsity=0.8, iterations=3)

# Report the measured outcome instead of asserting the conclusion unconditionally.
# results[0] is the unpruned network; results[-1] is the sparsest ticket.
dense_acc = results[0][1]
final_sparsity, final_acc = results[-1]
print(f"\nDense accuracy: {dense_acc:.1%} | "
      f"{final_sparsity:.1%} sparse ticket accuracy: {final_acc:.1%}")
# Within one percentage point of dense counts as "matching".
if final_acc >= dense_acc - 0.01:
    print("[OK] Finding: Sparse subnetworks can match dense performance!")
else:
    print("[!] In this run the sparsest subnetwork fell short of dense accuracy.")

## Summary

Key concepts covered:

1. **Sparsity Benefits**: Fewer parameters, less compute, potential regularization
2. **Types**: Unstructured (random), structured (rows/channels), block
3. **Pruning**: Magnitude-based pruning keeps the largest-magnitude weights, a common proxy for importance
4. **Sparse Layers**: Binary masks multiply weights to enforce sparsity
5. **Lottery Ticket**: Dense nets contain sparse subnetworks that, retrained from their original initialization, can match the dense network's accuracy

## Next Steps

- [->] Module 06: Unsupervised Learning
- [->] Module 07: Dynamic Sparse Training (SET, DEEP R)