# 🌿 Tutorial: Deep Learning for Neural Network Pruning

![Neural Network Pruning](https://upload.wikimedia.org/wikipedia/commons/thumb/c/c2/Neural_network.svg/500px-Neural_network.svg.png)

## Welcome to the Fascinating World of Neural Network Pruning! ✂️

In this comprehensive tutorial, you'll learn:
- 🌿 What is neural network pruning and why it matters
- 🧮 The mathematics behind sparsity and compression
- 🤖 How to use advanced techniques for intelligent weight selection
- 💻 Hands-on implementation with PyTorch
- 🎯 Visualization of pruning effects on model performance
- 🧪 Interactive exercises to build your skills

By the end, you'll be ready to implement sophisticated pruning algorithms that maintain model performance while dramatically reducing model size!


## 📚 Table of Contents

1. [🎓 Understanding Neural Network Pruning](#1--understanding-neural-network-pruning)
2. [🧮 Mathematical Foundation](#2--mathematical-foundation)
3. [🔧 Setting Up the Environment](#3--setting-up-the-environment)
4. [🏗️ Model Architecture](#4--model-architecture)
5. [✂️ Classical Pruning Approaches](#5--classical-pruning-approaches)
6. [🤖 Advanced Pruning Techniques](#6--advanced-pruning-techniques)
7. [🧠 Linear Approximation Strategy](#7--linear-approximation-strategy)
8. [🎯 Custom Scoring Function](#8--custom-scoring-function)
9. [💼 Complete Solution Strategy](#9--complete-solution-strategy)
10. [🎨 Visualizing Pruning Effects](#10--visualizing-pruning-effects)
11. [🎮 Interactive Exercises](#11--interactive-exercises)
12. [🚀 Advanced Techniques](#12--advanced-techniques)
13. [📖 Summary and Next Steps](#13--summary-and-next-steps)


## 1. 🎓 Understanding Neural Network Pruning

### What is Neural Network Pruning?

Imagine you have a beautiful, fully grown tree 🌳. While every branch and leaf contributes to the tree's function, you might find that you can carefully trim certain branches without significantly affecting the tree's health or appearance. **Neural network pruning** works on the same principle!

In deep learning, we often create networks that are **over-parameterized** - they have many more parameters than strictly necessary. Pruning allows us to:

- **Reduce model size** 📉 (fewer parameters to store)
- **Speed up inference** ⚡ (fewer computations, provided the hardware or kernels can exploit the zeros)
- **Reduce overfitting** 🎯 (sparser models can generalize better, though this is not guaranteed)
- **Improve interpretability** 🔍 (focus on important connections)

### Key Concepts:

- **Sparsity**: The fraction of parameters that are zero (higher = more pruned)
- **Structured vs Unstructured**: Removing entire neurons/channels vs individual weights
- **Magnitude-based**: Removing weights with smallest absolute values
- **Gradient-based**: Removing weights with least impact on loss
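
To make the structured/unstructured distinction concrete, here is a minimal NumPy sketch on a toy weight matrix. The matrix size, the 50% pruning ratio, and the row-norm criterion are illustrative assumptions, not part of the tutorial's method.

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(6, 4))  # toy layer: 6 output neurons, 4 inputs

# Unstructured: zero the individual weights with the smallest magnitudes.
k = W.size // 2  # prune 50% of the entries
threshold = np.sort(np.abs(W), axis=None)[k - 1]
W_unstructured = np.where(np.abs(W) > threshold, W, 0.0)

# Structured: zero whole rows (neurons) with the smallest L2 norm.
row_norms = np.linalg.norm(W, axis=1)
rows_to_drop = np.argsort(row_norms)[: W.shape[0] // 2]
W_structured = W.copy()
W_structured[rows_to_drop, :] = 0.0


def sparsity(a):
    """Fraction of entries that are exactly zero."""
    return float(np.mean(a == 0))


print("unstructured sparsity:", sparsity(W_unstructured))
print("structured sparsity:  ", sparsity(W_structured))
print("fully zero rows (structured):", int(np.sum(~W_structured.any(axis=1))))
```

Both variants reach the same sparsity here, but only the structured one removes whole neurons, which is what lets a layer physically shrink.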

### Mathematical Definition:

Given a neural network with parameters $\theta$, pruning creates a **sparse** version $\theta_{sparse}$ where:

$$\theta_{sparse}[i] = \begin{cases} 
\theta[i] & \text{if weight is important} \\
0 & \text{if weight is pruned}
\end{cases}$$

The **sparsity** is defined as:

$$s = \frac{\text{number of zero parameters}}{\text{total number of parameters}}$$
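
As a small sketch of the two definitions above, the snippet below builds $\theta_{sparse}$ from a keep-mask and computes $s$. The rule "keep weights with $|\theta| \ge 0.5$" is a hypothetical importance criterion chosen only for illustration.

```python
import torch

torch.manual_seed(0)
theta = torch.randn(10)

# Hypothetical importance rule for illustration: keep weights with |theta| >= 0.5.
keep_mask = theta.abs() >= 0.5
theta_sparse = torch.where(keep_mask, theta, torch.zeros_like(theta))

# Sparsity = number of zero parameters / total number of parameters.
sparsity = (theta_sparse == 0).sum().item() / theta_sparse.numel()
print(f"kept {int(keep_mask.sum())} of {theta.numel()} weights, sparsity = {sparsity:.2f}")
```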

### Why Should We Care?

🚀 **Mobile Deployment**: Smaller models fit on phones and edge devices  
💰 **Cost Reduction**: Less computation = lower cloud costs  
🌱 **Environmental Impact**: Fewer operations = less energy consumption  
⚡ **Real-time Applications**: Faster inference for time-critical tasks  
🎯 **Better Understanding**: Sparse models reveal which connections matter most


## 2. 🧮 Mathematical Foundation

Before diving into implementation, let's understand the mathematical framework behind pruning.

### The Core Trade-off

Pruning is fundamentally about balancing two competing objectives:

1. **Model Performance**: We want to maintain low prediction error
2. **Model Sparsity**: We want to zero out as many weights as possible

This creates a **multi-objective optimization problem**:

$$\min_{\theta_{sparse}} \left[ \mathcal{L}(\theta_{sparse}) + \lambda \cdot \text{Complexity}(\theta_{sparse}) \right]$$

Where:
- $\mathcal{L}$ is our loss function (e.g., MSE for regression)
- $\lambda$ controls the sparsity-performance trade-off
- $\text{Complexity}$ measures how "complex" our model is
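
The objective above can be sketched on a toy linear regression. This example assumes that $\text{Complexity}$ is the number of non-zero weights; the tutorial does not fix a particular choice, and the data, weights, and threshold are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))
true_w = np.array([2.0, 0.0, -1.0, 0.0, 0.05])
y = X @ true_w + 0.01 * rng.normal(size=200)


def objective(w, lam):
    mse = np.mean((X @ w - y) ** 2)   # L(theta_sparse)
    complexity = np.count_nonzero(w)  # assumed Complexity: number of non-zero weights
    return mse + lam * complexity


w_dense = np.linalg.lstsq(X, y, rcond=None)[0]            # least-squares fit, all weights non-zero
w_pruned = np.where(np.abs(w_dense) > 0.5, w_dense, 0.0)  # drop the small weights

for lam in (0.0, 0.01):
    print(
        f"lambda={lam}: dense={objective(w_dense, lam):.4f}, "
        f"pruned={objective(w_pruned, lam):.4f}"
    )
```

With $\lambda = 0$ the dense solution cannot lose, since it minimizes the MSE; once $\lambda$ is large enough, the pruned weights achieve the lower objective.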

### Our Specific Scoring Function

In our pruning challenge, we use a sophisticated scoring function that captures this trade-off:

$$\text{score}(s, \epsilon) = \begin{cases}
0 & \text{if } \epsilon > 1000 \\
(1 - \frac{\epsilon}{1000})^{1.5} \cdot s^{1.5} & \text{otherwise}
\end{cases}$$

Where:
- $s$ = sparsity (fraction of zero weights)
- $\epsilon$ = MSE on test set

### Understanding the Scoring Function

This function is designed to:
- **Reward high sparsity**: $s^{1.5}$ grows super-linearly with sparsity
- **Penalize high error**: $(1 - \frac{\epsilon}{1000})^{1.5}$ decreases as MSE increases
- **Set hard limits**: Score is 0 if MSE > 1000 (model completely broken)
- **Balance both objectives**: Neither sparsity nor accuracy alone is sufficient
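
A quick way to build intuition is to evaluate the formula at a few points. The helper below is a standalone sketch of the formula as written above (the name `score_sketch` and the sample values are illustrative), not the evaluation code itself.

```python
def score_sketch(sparsity, mse, cap=1000.0, exponent=1.5):
    """Score from the formula above: 0 when the MSE exceeds the cap."""
    if mse > cap:
        return 0.0
    return (1 - mse / cap) ** exponent * sparsity**exponent


for s, eps in [(0.50, 10.0), (0.90, 10.0), (0.99, 10.0), (0.99, 500.0), (1.00, 1500.0)]:
    print(f"s={s:.2f}, mse={eps:7.1f} -> score={score_sketch(s, eps):.3f}")
```

Because both factors are multiplied, a fully sparse model with a broken MSE and an accurate model with no zero weights both score 0.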

### The Challenge

This is a **discrete, non-convex optimization problem**. We need to decide for each weight: keep it or zero it out. With $n$ parameters there are $2^n$ possible keep/zero patterns, so even a small network creates an enormous search space!
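
To get a feel for the size of that search space, the sketch below counts the decimal digits of the number of keep/zero patterns, assuming the 128 → 1024 → 10 MLP (weights and biases) used later in this tutorial.

```python
import math

# Weights and biases of a 128 -> 1024 -> 10 MLP.
n_params = 128 * 1024 + 1024 + 1024 * 10 + 10

# Each parameter is either kept or zeroed, so there are 2**n_params masks.
# Count the decimal digits of that number without building it.
n_digits = math.floor(n_params * math.log10(2)) + 1

print(f"{n_params:,} parameters -> the number of masks has {n_digits:,} decimal digits")
```

Exhaustive search is clearly out of the question, which is why every practical method relies on a heuristic or a structural shortcut.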

Traditional approaches:
- **Magnitude pruning**: Remove smallest weights 📏
- **Gradient-based**: Remove weights with small gradients 📈
- **Structured pruning**: Remove entire neurons/filters 🏗️

**Our Innovation**: Use intelligent approximation strategies to find near-optimal sparse representations! 🧠


## 3. 🔧 Setting Up the Environment

Let's start by importing all the necessary libraries and setting up our environment. We'll be working with PyTorch for neural networks and various other libraries for data processing and visualization.


In [None]:
# Essential imports for our neural network pruning tutorial
import copy
import pickle
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from IPython.display import clear_output

# PyTorch for neural networks
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# Set random seeds for reproducibility
np.random.seed(42)
torch.manual_seed(42)

# Set up device - GPU greatly speeds up neural network training
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"🚀 Using device: {device}")

print("✅ Environment setup complete!")
print(f"📦 PyTorch version: {torch.__version__}")
print(f"🖥️  CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"🎮 GPU name: {torch.cuda.get_device_name(0)}")
    print(
        f"💾 GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB"
    )

## 4. 🏗️ Model Architecture

Let's understand the specific neural network we'll be working with. Our model is a **Multi-Layer Perceptron (MLP)** with a specific architecture that we cannot change.


In [None]:
# Define our Multi-Layer Perceptron (MLP) architecture
class MLP(nn.Module):
    """
    Multi-Layer Perceptron for regression task.

    Architecture (FIXED - cannot be changed):
    - Input layer: 128 features
    - Hidden layer: 1024 neurons with Sigmoid activation
    - Output layer: 10 targets (regression outputs)

    Total parameters: 128*1024 + 1024 + 1024*10 + 10 = 142,346 parameters
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.layers = nn.Sequential(
            nn.Linear(128, 1024),  # First layer: 128 -> 1024
            nn.Sigmoid(),  # Activation function
            nn.Linear(1024, 10),  # Output layer: 1024 -> 10
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.layers(x)
        return logits

    def loss(self, input, target, reduction="mean"):
        """Mean Squared Error loss for regression"""
        mse_loss = nn.MSELoss(reduction=reduction)
        return mse_loss(input, target)


# Create our model and move to device
model = MLP().to(device)
print(f"🧠 Created MLP model!")

# Count and display parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"📊 Total parameters: {total_params:,}")

# Break down parameters by layer
for name, param in model.named_parameters():
    print(f"   {name}: {param.shape} = {param.numel():,} parameters")

print(f"\n🎯 This model performs regression: 128 inputs → 10 outputs")
print(f"📏 Model size in memory: ~{total_params * 4 / 1024:.1f} KB (float32)")

# Test the model with dummy data
with torch.no_grad():
    dummy_input = torch.randn(32, 128).to(device)  # Batch of 32 samples
    output = model(dummy_input)
    print(f"\n✅ Model test successful!")
    print(f"   Input shape: {dummy_input.shape}")
    print(f"   Output shape: {output.shape}")

## 5. 🎯 Understanding the Scoring System

Before we start pruning, let's implement and understand the evaluation functions that will measure our success.


In [None]:
def get_sparsity(model):
    """
    Calculate the sparsity of a model (fraction of weights that are zero).

    Args:
        model: PyTorch model

    Returns:
        float: Sparsity value between 0 and 1 (1 = all weights are zero)
    """
    total_params = 0
    zero_params = 0

    for name, param in model.named_parameters():
        if "weight" in name or "bias" in name:
            total_params += param.numel()
            zero_params += torch.sum(param == 0).item()

    sparsity = zero_params / total_params
    return sparsity


def compute_error(model, data_loader):
    """
    Compute Mean Squared Error on a dataset.

    Args:
        model: PyTorch model
        data_loader: DataLoader with test data

    Returns:
        float: Average MSE across all samples
    """
    model.eval()
    total_loss = 0
    total_samples = 0

    with torch.no_grad():
        for x, y in data_loader:
            outputs = model(x)
            total_samples += x.shape[0] * y.shape[1]
            total_loss += model.loss(outputs, y, reduction="sum").item()

    return total_loss / total_samples


def score(mse_loss, sparsity, mse_weight=1.5, sparsity_weight=1.5):
    """
    Calculate the final score that balances MSE and sparsity.

    This is the exact function used for evaluation!

    Args:
        mse_loss: Mean squared error on test set
        sparsity: Fraction of zero weights (0 to 1)
        mse_weight: Exponent for MSE term (default 1.5)
        sparsity_weight: Exponent for sparsity term (default 1.5)

    Returns:
        float: Score (higher is better)
    """
    # Handle different input types
    if isinstance(mse_loss, np.ndarray):
        mse_loss = np.clip(mse_loss, 0, 1000)
    else:
        mse_loss = min(max(mse_loss, 0), 1000)

    # If MSE is too high, score is 0
    if mse_loss >= 1000:
        return 0.0

    # Calculate the balanced score
    score_value = (1 - mse_loss / 1000) ** mse_weight * sparsity**sparsity_weight
    return score_value


def points(score_value):
    """
    Convert score to final points (0 to 1.5).

    Args:
        score_value: Score from score() function

    Returns:
        float: Points between 0 and 1.5
    """

    def scale(x, lower=0.085, upper=0.95, max_points=1.5):
        scaled = min(max(x, lower), upper)
        return (scaled - lower) / (upper - lower) * max_points

    return scale(score_value)


print("🎯 Scoring functions implemented!")
print("   - get_sparsity(): Calculates fraction of zero weights")
print("   - compute_error(): Calculates MSE on dataset")
print("   - score(): Combines MSE and sparsity into final score")
print("   - points(): Converts score to assignment points")

# Let's test these functions with our model
model_sparsity = get_sparsity(model)
print(f"\n📊 Current model statistics:")
print(
    f"   Sparsity: {model_sparsity:.3f} ({model_sparsity*100:.1f}% of weights are zero)"
)

# Create some dummy data to test error calculation


class DummyDataset(Dataset):
    def __init__(self, size=1000):
        self.X = torch.randn(size, 128)
        self.y = torch.randn(size, 10)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx].to(device), self.y[idx].to(device)


dummy_loader = DataLoader(DummyDataset(500), batch_size=32, shuffle=False)
dummy_mse = compute_error(model, dummy_loader)
dummy_score = score(dummy_mse, model_sparsity)

print(f"   MSE on dummy data: {dummy_mse:.3f}")
print(f"   Combined score: {dummy_score:.3f}")
print(f"   Assignment points: {points(dummy_score):.3f}/1.5")

## 6. ✂️ Classical Pruning Approaches

Before we dive into advanced techniques, let's understand and implement traditional pruning methods. This will help us appreciate why more sophisticated approaches are needed.

### Magnitude-Based Pruning: The Classic Approach

The most intuitive pruning method is **magnitude-based pruning**:
1. Calculate the absolute value of each weight
2. Sort weights by magnitude (smallest first)
3. Zero out the smallest weights until desired sparsity is reached
4. Keep the largest weights (they're presumably most important)
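
For reference, PyTorch ships pruning utilities in `torch.nn.utils.prune` that follow the same steps. The sketch below applies global magnitude pruning to a toy network; the layer sizes and the 90% amount are illustrative assumptions, and unlike the hand-written version in this tutorial it leaves the biases untouched.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
toy = nn.Sequential(nn.Linear(8, 16), nn.Sigmoid(), nn.Linear(16, 2))

# Global magnitude pruning over both weight matrices (biases are not pruned here).
targets = [(toy[0], "weight"), (toy[2], "weight")]
prune.global_unstructured(targets, pruning_method=prune.L1Unstructured, amount=0.9)

# Make the pruning permanent: drop the mask buffers and keep the zeroed weights.
for module, name in targets:
    prune.remove(module, name)

n_weights = sum(m.weight.numel() for m, _ in targets)
n_zero = sum(int((m.weight == 0).sum()) for m, _ in targets)
print(f"{n_zero}/{n_weights} weights are zero ({n_zero / n_weights:.1%})")
```

Writing the method by hand, as done next, makes the threshold and the mask explicit, which is helpful for learning.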

Let's implement this and see how it performs!


In [None]:
def magnitude_based_pruning(model, target_sparsity=0.9):
    """
    Prune model weights based on magnitude (smallest weights are removed).

    Args:
        model: PyTorch model to prune
        target_sparsity: Fraction of weights to zero out (between 0 and 1)

    Returns:
        model: Pruned model (modified in-place)
    """
    # Collect all weights and their absolute values
    all_weights = []
    for name, param in model.named_parameters():
        if "weight" in name or "bias" in name:
            all_weights.extend(param.data.abs().flatten().cpu().numpy())

    # Find the threshold: weights below this value will be pruned
    all_weights = np.array(all_weights)
    threshold = np.percentile(all_weights, target_sparsity * 100)

    print(f"🔪 Magnitude-based pruning:")
    print(f"   Target sparsity: {target_sparsity:.1%}")
    print(f"   Pruning threshold: {threshold:.6f}")
    print(f"   Weights below {threshold:.6f} will be set to zero")

    # Apply pruning
    pruned_count = 0
    total_count = 0

    with torch.no_grad():
        for name, param in model.named_parameters():
            if "weight" in name or "bias" in name:
                mask = param.data.abs() < threshold
                pruned_count += mask.sum().item()
                total_count += param.numel()
                param.data[mask] = 0

    actual_sparsity = pruned_count / total_count
    print(f"   Actual sparsity achieved: {actual_sparsity:.1%}")

    return model


def random_pruning(model, target_sparsity=0.9):
    """
    Randomly prune model weights (baseline comparison).

    Args:
        model: PyTorch model to prune
        target_sparsity: Fraction of weights to zero out

    Returns:
        model: Pruned model (modified in-place)
    """
    print(f"🎲 Random pruning:")
    print(f"   Target sparsity: {target_sparsity:.1%}")

    with torch.no_grad():
        for name, param in model.named_parameters():
            if "weight" in name or "bias" in name:
                # Create random mask
                mask = torch.rand_like(param.data) < target_sparsity
                param.data[mask] = 0

    actual_sparsity = get_sparsity(model)
    print(f"   Actual sparsity achieved: {actual_sparsity:.1%}")

    return model


# Let's test different pruning methods on copies of our model
print("🧪 Testing classical pruning methods...")

# We have no trained model in this demo, so we use a freshly initialized one (random weights)


def init_weights(m):
    """Initialize model weights"""
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_normal_(m.weight)
        m.bias.data.fill_(0.01)


# Test magnitude-based pruning
print("\n" + "=" * 50)
model_magnitude = MLP().to(device)
model_magnitude.apply(init_weights)

# Record original performance
original_sparsity = get_sparsity(model_magnitude)
original_mse = compute_error(model_magnitude, dummy_loader)
original_score = score(original_mse, original_sparsity)

print(f"📊 Original model:")
print(f"   Sparsity: {original_sparsity:.1%}")
print(f"   MSE: {original_mse:.3f}")
print(f"   Score: {original_score:.3f}")

# Apply magnitude-based pruning
model_magnitude = magnitude_based_pruning(model_magnitude, target_sparsity=0.95)

# Record pruned performance
pruned_sparsity = get_sparsity(model_magnitude)
pruned_mse = compute_error(model_magnitude, dummy_loader)
pruned_score = score(pruned_mse, pruned_sparsity)

print(f"\n📊 After magnitude-based pruning:")
print(
    f"   Sparsity: {pruned_sparsity:.1%} (↑{(pruned_sparsity-original_sparsity)*100:.1f}pp)"
)
print(f"   MSE: {pruned_mse:.3f} (↑{pruned_mse-original_mse:.3f})")
print(
    f"   Score: {pruned_score:.3f} ({'↑' if pruned_score > original_score else '↓'}{abs(pruned_score-original_score):.3f})"
)

# Compare with random pruning
print("\n" + "=" * 50)
model_random = MLP().to(device)
model_random.apply(init_weights)
model_random = random_pruning(model_random, target_sparsity=0.95)

random_sparsity = get_sparsity(model_random)
random_mse = compute_error(model_random, dummy_loader)
random_score = score(random_mse, random_sparsity)

print(f"\n📊 After random pruning:")
print(f"   Sparsity: {random_sparsity:.1%}")
print(f"   MSE: {random_mse:.3f}")
print(f"   Score: {random_score:.3f}")

print(
    f"\n🏆 Winner: {'Magnitude-based' if pruned_score > random_score else 'Random'} pruning!"
)
print("(Though both are quite naive approaches...)")

## 7. 🧠 Linear Approximation Strategy

Now comes the exciting part! Instead of using simple heuristics like magnitude-based pruning, we'll use a sophisticated **linear approximation strategy**.

### The Big Idea 💡

What if we could find a way to make our one-hidden-layer network (128 → 1024 → 10) behave like a simple linear model (128 → 10) while giving up little accuracy on this task?

The key insight is:
1. **Train a linear approximation** of the entire network (128 → 10 directly)
2. **Use the hidden layer as a "bridge"** to transfer this linear knowledge
3. **Create a sparse representation** that maintains the linear approximation's behavior

### Why This Works

When the underlying function of a regression task is close to linear, a linear model approximates it well. In that case our hidden layer with 1024 neurons is **over-parameterized** - we don't need all that complexity!

By training a linear model first, we learn the **essential linear relationships** in the data. Then we embed this knowledge into our larger network in a sparse way.

### The Mathematical Strategy

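Our approach creates a specific sparse pattern:

1. **First layer (128 → 1024)**: Identity-like mapping with small coefficients: hidden neuron $i$ receives only input $x_i$, scaled by a small factor `eps`
2. **Second layer (1024 → 10)**: Scaled version of the linear model's weights

For small arguments the sigmoid is nearly linear, $\sigma(z) \approx \frac{1}{2} + \frac{z}{4}$. Dividing the linear model's weights by `eps` times the slope $\frac{1}{4}$, and subtracting the constant $\frac{1}{2}$ offset through the output bias, therefore recovers the linear prediction. This creates a "pathway" through the network that implements the linear approximation while keeping most weights at zero.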
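
The construction can be checked numerically before touching the real model. The following is a NumPy sketch under stated assumptions: toy layer sizes stand in for 128, 1024 and 10, random weights stand in for a trained linear model, and everything runs in float64.

```python
import numpy as np

rng = np.random.default_rng(0)
n_in, n_hidden, n_out = 4, 16, 3  # toy sizes standing in for 128, 1024, 10
eps, gamma = 1e-3, 0.25           # gamma = slope of the sigmoid at 0

W_lin = rng.normal(size=(n_out, n_in))  # stand-in for a trained linear model
b_lin = rng.normal(size=n_out)

# First layer: hidden neuron i reads only input i, scaled by eps.
W1 = np.zeros((n_hidden, n_in))
W1[np.arange(n_in), np.arange(n_in)] = eps
b1 = np.zeros(n_hidden)

# Output layer: undo the eps * gamma scaling, then cancel the sigmoid(0) = 0.5 offset.
W2 = np.zeros((n_out, n_hidden))
W2[:, :n_in] = W_lin / (eps * gamma)
b2 = b_lin - W2.sum(axis=1) / 2


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


x = rng.normal(size=(5, n_in))
y_linear = x @ W_lin.T + b_lin
y_sparse = sigmoid(x @ W1.T + b1) @ W2.T + b2

max_diff = np.max(np.abs(y_sparse - y_linear))
print("max abs difference:", max_diff)
```

At full size this pattern needs at most 128 + 10·128 + 10 = 1,418 non-zero parameters out of 142,346, which is roughly 99% sparsity.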


In [None]:
def train_linear_approximation(data_loader, epochs=100, lr=0.01):
    """
    Train a linear model to approximate the target function.

    This linear model will serve as our "teacher" for the pruning strategy.

    Args:
        data_loader: DataLoader with training data
        epochs: Number of training epochs
        lr: Learning rate

    Returns:
        trained linear model
    """
    # Create simple linear model (128 -> 10, no hidden layers)
    linear_model = nn.Linear(128, 10).to(device)
    linear_model.loss = lambda input, target, reduction="mean": nn.MSELoss(
        reduction=reduction
    )(input, target)

    # Initialize weights
    torch.nn.init.xavier_normal_(linear_model.weight)
    linear_model.bias.data.fill_(0.01)

    # Train the linear model
    optimizer = optim.Adam(linear_model.parameters(), lr=lr)

    print(f"🚀 Training linear approximation...")
    print(f"   Architecture: 128 → 10 (linear)")
    print(
        f"   Parameters: {sum(p.numel() for p in linear_model.parameters())} (vs 142,346 in full model)"
    )

    best_loss = float("inf")
    for epoch in range(epochs):
        linear_model.train()
        total_loss = 0
        num_batches = 0

        for inputs, targets in data_loader:
            optimizer.zero_grad()
            outputs = linear_model(inputs)
            loss = linear_model.loss(outputs, targets)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        avg_loss = total_loss / num_batches
        if avg_loss < best_loss:
            best_loss = avg_loss

        if (epoch + 1) % 20 == 0:
            print(f"   Epoch {epoch+1}: Loss = {avg_loss:.4f}")

    final_loss = compute_error(linear_model, data_loader)
    print(f"✅ Linear approximation trained! Final MSE: {final_loss:.4f}")

    return linear_model


def advanced_pruning_strategy(model, linear_model, eps=0.001, gamma=0.25):
    """
    Advanced pruning using linear approximation strategy.

    This implements the core idea from the solution:
    1. Zero out most weights
    2. Create a "bridge" through the hidden layer using diagonal connections
    3. Transfer linear model knowledge to the output layer

    Args:
        model: Original MLP model to prune
        linear_model: Trained linear approximation
        eps: Small value for diagonal "bridge" weights
        gamma: Scaling factor for compensation; the default 0.25 is the
            slope of the sigmoid at 0, since sigmoid(z) ≈ 0.5 + z / 4 for small z

    Returns:
        Pruned model (modified in-place)
    """
    print(f"🔬 Applying advanced linear approximation pruning...")
    print(f"   Bridge coefficient (eps): {eps}")
    print(f"   Scaling factor (gamma): {gamma}")

    with torch.no_grad():
        # STEP 1: Zero out all weights we'll be modifying
        model.layers[0].weight.data.zero_()  # First layer weights (1024 x 128)
        model.layers[0].bias.data.zero_()  # First layer biases (1024)
        model.layers[2].weight.data.zero_()  # Output layer weights (10 x 1024)
        model.layers[2].bias.data.zero_()  # Output layer biases (10)

        # STEP 2: Create diagonal "bridge" in first layer
        # Hidden neuron i receives only input feature i, scaled by eps
        for i in range(128):  # Only use first 128 neurons of hidden layer
            model.layers[0].weight.data[i, i] = eps

        print(f"   Created diagonal bridge: {128} connections with weight {eps}")

        # STEP 3: Transfer linear model knowledge to output layer
        # We need to scale appropriately to compensate for eps and gamma
        for j in range(10):  # For each output
            for i in range(128):  # For each input feature
                # Scale the linear model weight by the bridge scaling
                model.layers[2].weight.data[j, i] = linear_model.weight[j, i] / (
                    eps * gamma
                )

        print(f"   Transferred linear weights with scaling factor: {1/(eps*gamma):.1f}")

        # STEP 4: Set output biases with correction
        for j in range(10):
            # Copy bias from linear model with correction for systematic shift
            bias_correction = torch.sum(model.layers[2].weight.data[j]) / 2
            model.layers[2].bias.data[j] = linear_model.bias.data[j] - bias_correction

        print(f"   Set output biases with correction")

    # Calculate final sparsity
    final_sparsity = get_sparsity(model)
    print(f"✅ Advanced pruning complete!")
    print(f"   Final sparsity: {final_sparsity:.1%}")

    return model


# Let's test this approach with our dummy data
print("🧪 Testing Linear Approximation Strategy")
print("=" * 60)

# Train linear approximation
linear_model = train_linear_approximation(dummy_loader, epochs=50, lr=0.01)

# Test linear model performance
linear_mse = compute_error(linear_model, dummy_loader)
linear_score = score(linear_mse, 0.0)  # Linear model has 0% sparsity
print(f"\n📊 Linear model performance:")
print(f"   MSE: {linear_mse:.4f}")
print(f"   Score (no sparsity): {linear_score:.4f}")

# Create and prune a new MLP model
print("\n" + "=" * 60)
advanced_model = MLP().to(device)
advanced_model.apply(init_weights)

# Apply our advanced pruning strategy
advanced_model = advanced_pruning_strategy(advanced_model, linear_model)

# Evaluate the pruned model
advanced_mse = compute_error(advanced_model, dummy_loader)
advanced_sparsity = get_sparsity(advanced_model)
advanced_score = score(advanced_mse, advanced_sparsity)

print(f"\n📊 Advanced pruned model performance:")
print(f"   MSE: {advanced_mse:.4f} (vs {linear_mse:.4f} linear)")
print(f"   Sparsity: {advanced_sparsity:.1%}")
print(f"   Score: {advanced_score:.4f}")
print(f"   Points: {points(advanced_score):.3f}/1.5")

# Compare with magnitude-based pruning
print(f"\n🏆 Comparison:")
print(f"   Advanced strategy score: {advanced_score:.4f}")
print(f"   Magnitude pruning score: {pruned_score:.4f}")
print(f"   Score improvement: {advanced_score - pruned_score:.4f}")

## 8. 🎨 Visualizing Pruning Effects

Understanding what pruning does to our model is crucial. Let's create visualizations to see the structure and sparsity patterns.


In [None]:
def visualize_weight_distribution(model, title="Model Weight Distribution"):
    """
    Visualize the distribution of weights in the model.

    Args:
        model: PyTorch model
        title: Title for the plot
    """
    # Collect all weights
    all_weights = []
    layer_weights = {}

    for name, param in model.named_parameters():
        if "weight" in name:
            weights = param.data.cpu().numpy().flatten()
            all_weights.extend(weights)
            layer_weights[name] = weights

    # Create subplots
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle(title, fontsize=16, fontweight="bold")

    # Plot 1: Overall weight distribution
    axes[0, 0].hist(all_weights, bins=50, alpha=0.7, edgecolor="black")
    axes[0, 0].set_title("All Weights Distribution")
    axes[0, 0].set_xlabel("Weight Value")
    axes[0, 0].set_ylabel("Frequency")
    axes[0, 0].axvline(0, color="red", linestyle="--", alpha=0.7, label="Zero")
    axes[0, 0].legend()

    # Plot 2: Layer-wise weight distributions
    for i, (name, weights) in enumerate(layer_weights.items()):
        if i < 3:  # Only plot first 3 layers
            axes[0, 1].hist(weights, bins=30, alpha=0.5, label=name, density=True)
    axes[0, 1].set_title("Layer-wise Weight Distributions")
    axes[0, 1].set_xlabel("Weight Value")
    axes[0, 1].set_ylabel("Density")
    axes[0, 1].legend()

    # Plot 3: Sparsity per layer
    layer_names = []
    sparsity_values = []

    for name, param in model.named_parameters():
        if "weight" in name or "bias" in name:
            total = param.numel()
            zeros = torch.sum(param == 0).item()
            sparsity = zeros / total
            layer_names.append(name.replace("layers.", ""))
            sparsity_values.append(sparsity)

    axes[1, 0].bar(range(len(layer_names)), sparsity_values, alpha=0.7)
    axes[1, 0].set_title("Sparsity per Layer")
    axes[1, 0].set_xlabel("Layer")
    axes[1, 0].set_ylabel("Sparsity")
    axes[1, 0].set_xticks(range(len(layer_names)))
    axes[1, 0].set_xticklabels(layer_names, rotation=45)

    # Plot 4: Weight matrix heatmap (first layer only)
    first_layer_weights = model.layers[0].weight.data.cpu().numpy()
    # Show only a subset for visibility
    subset = first_layer_weights[:64, :64]  # Top-left 64x64 block
    im = axes[1, 1].imshow(subset, cmap="RdBu", aspect="auto")
    axes[1, 1].set_title("First Layer Weight Matrix (64x64 subset)")
    axes[1, 1].set_xlabel("Input Features")
    axes[1, 1].set_ylabel("Hidden Neurons")
    plt.colorbar(im, ax=axes[1, 1])

    plt.tight_layout()
    plt.show()


def analyze_pruning_patterns(models_dict):
    """
    Compare pruning patterns across different methods.

    Args:
        models_dict: Dictionary of {name: model} pairs
    """
    fig, axes = plt.subplots(2, len(models_dict), figsize=(5 * len(models_dict), 8))
    if len(models_dict) == 1:
        axes = axes.reshape(-1, 1)

    for col, (name, model) in enumerate(models_dict.items()):
        # Plot 1: Weight magnitude vs position
        all_weights = []
        positions = []

        pos = 0
        for param_name, param in model.named_parameters():
            if "weight" in param_name:
                weights = param.data.cpu().numpy().flatten()
                all_weights.extend(np.abs(weights))
                positions.extend(range(pos, pos + len(weights)))
                pos += len(weights)

        axes[0, col].scatter(positions, all_weights, alpha=0.5, s=1)
        axes[0, col].set_title(f"{name}\nWeight Magnitudes")
        axes[0, col].set_xlabel("Parameter Index")
        axes[0, col].set_ylabel("Absolute Weight Value")
        axes[0, col].set_yscale("log")

        # Plot 2: Sparsity breakdown
        layer_sparsity = []
        layer_names = []

        for param_name, param in model.named_parameters():
            if "weight" in param_name or "bias" in param_name:
                total = param.numel()
                zeros = torch.sum(param == 0).item()
                sparsity = zeros / total
                layer_sparsity.append(sparsity)
                # Keep the layer index so the labels stay distinct (e.g. "0.weight", "2.bias")
                layer_names.append(param_name.replace("layers.", ""))

        bars = axes[1, col].bar(range(len(layer_sparsity)), layer_sparsity, alpha=0.7)
        axes[1, col].set_title(f"{name}\nLayer Sparsity")
        axes[1, col].set_xlabel("Layer")
        axes[1, col].set_ylabel("Sparsity")
        axes[1, col].set_xticks(range(len(layer_names)))
        axes[1, col].set_xticklabels(layer_names, rotation=45)

        # Add sparsity values on bars
        for bar, sparsity in zip(bars, layer_sparsity):
            height = bar.get_height()
            axes[1, col].text(
                bar.get_x() + bar.get_width() / 2.0,
                height + 0.01,
                f"{sparsity:.2f}",
                ha="center",
                va="bottom",
            )

    plt.tight_layout()
    plt.show()


# Let's visualize our different pruning approaches
print("🎨 Visualizing Pruning Effects")
print("=" * 50)

# Create models with different pruning strategies for comparison
models_to_compare = {}

# Original model (no pruning)
original_model = MLP().to(device)
original_model.apply(init_weights)
models_to_compare["Original"] = original_model

# Magnitude-based pruned model
magnitude_model = MLP().to(device)
magnitude_model.apply(init_weights)
magnitude_based_pruning(magnitude_model, target_sparsity=0.90)
models_to_compare["Magnitude"] = magnitude_model

# Advanced pruned model
advanced_model_viz = MLP().to(device)
advanced_model_viz.apply(init_weights)
advanced_pruning_strategy(advanced_model_viz, linear_model)
models_to_compare["Advanced"] = advanced_model_viz

# Show weight distributions for the advanced model
print("\n📊 Weight Distribution Analysis:")
visualize_weight_distribution(advanced_model_viz, "Advanced Pruning Strategy")

# Compare all pruning patterns
print("\n🔍 Pruning Pattern Comparison:")
analyze_pruning_patterns(models_to_compare)

# Print summary statistics
print("\n📈 Summary Statistics:")
print("-" * 60)
for name, model in models_to_compare.items():
    sparsity = get_sparsity(model)
    mse = compute_error(model, dummy_loader)
    score_val = score(mse, sparsity)

    print(f"{name:>10}: Sparsity={sparsity:.1%}, MSE={mse:.4f}, Score={score_val:.4f}")

## 9. 🎮 Interactive Exercises

Now it's your turn to experiment and learn! Try these challenges to deepen your understanding of neural network pruning.

### 🎯 Exercise 1: Hyperparameter Tuning

The linear approximation strategy has several hyperparameters. Let's explore how they affect performance:

1. **eps** (bridge coefficient): Try values like 0.1, 0.01, 0.001, 0.0001
2. **gamma** (scaling factor): Try values like 0.1, 0.25, 0.5, 1.0
3. **Training epochs**: Try different numbers of epochs for the linear model
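
Before running the full experiment, it can help to isolate what `eps` and `gamma` do to the approximation itself. The sketch below is a float64 NumPy model of the sparse pathway; the helper name `sparse_bridge_output`, the toy sizes, and the random stand-in for a trained linear model are all assumptions for illustration.

```python
import numpy as np


def sparse_bridge_output(x, W_lin, b_lin, eps, gamma):
    """Output of the sparse sigmoid pathway built from a linear model (float64)."""
    hidden = 1.0 / (1.0 + np.exp(-eps * x))  # only the "bridge" hidden units matter
    W2 = W_lin / (eps * gamma)               # rescaled output weights
    b2 = b_lin - W2.sum(axis=1) / 2          # cancel the sigmoid(0) = 0.5 offset
    return hidden @ W2.T + b2


rng = np.random.default_rng(0)
W_lin, b_lin = rng.normal(size=(3, 4)), rng.normal(size=3)  # stand-in linear model
x = rng.normal(size=(100, 4))
y_linear = x @ W_lin.T + b_lin

for eps in (1e-1, 1e-2, 1e-3):
    for gamma in (0.1, 0.25, 0.5):
        y_sparse = sparse_bridge_output(x, W_lin, b_lin, eps, gamma)
        err = np.mean((y_sparse - y_linear) ** 2)
        print(f"eps={eps:g}, gamma={gamma:g}: MSE vs linear model = {err:.2e}")
```

In this idealized setting the pathway computes roughly $\frac{0.25}{\gamma} W x + b$, so any `gamma` other than 0.25 rescales the linear part, while a smaller `eps` tightens the linearization. The float32 PyTorch model may behave differently at very small `eps`, because the large output weights amplify rounding error.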

Use the cell below to experiment!


In [None]:
# 🧪 Experiment Playground - Try different hyperparameters!

# Define hyperparameter combinations to test
hyperparams_to_try = [
    {"eps": 0.01, "gamma": 0.1, "epochs": 30, "name": "High eps, Low gamma"},
    {"eps": 0.001, "gamma": 0.25, "epochs": 50, "name": "Medium eps, Medium gamma"},
    {"eps": 0.0001, "gamma": 0.5, "epochs": 70, "name": "Low eps, High gamma"},
    {"eps": 0.001, "gamma": 0.25, "epochs": 100, "name": "Medium eps, More training"},
]

print("🔬 Testing different hyperparameter combinations...")
print("This may take a few minutes...")

results = []

for i, params in enumerate(hyperparams_to_try):
    print(f"\n📊 Testing combination {i+1}/{len(hyperparams_to_try)}: {params['name']}")

    # Train linear model with specified epochs
    linear_model_exp = train_linear_approximation(
        dummy_loader, epochs=params["epochs"], lr=0.01
    )

    # Create and prune model with specified hyperparameters
    test_model = MLP().to(device)
    test_model.apply(init_weights)

    # Apply pruning with custom hyperparameters
    advanced_pruning_strategy(
        test_model, linear_model_exp, eps=params["eps"], gamma=params["gamma"]
    )

    # Evaluate performance
    test_mse = compute_error(test_model, dummy_loader)
    test_sparsity = get_sparsity(test_model)
    test_score = score(test_mse, test_sparsity)

    results.append(
        {
            "name": params["name"],
            "eps": params["eps"],
            "gamma": params["gamma"],
            "epochs": params["epochs"],
            "mse": test_mse,
            "sparsity": test_sparsity,
            "score": test_score,
            "points": points(test_score),
        }
    )

    print(
        f"   Results: MSE={test_mse:.4f}, Sparsity={test_sparsity:.1%}, Score={test_score:.4f}"
    )

# Visualize results
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
fig.suptitle("🧪 Hyperparameter Experiment Results", fontsize=16, fontweight="bold")

names = [r["name"] for r in results]
scores = [r["score"] for r in results]
mses = [r["mse"] for r in results]
sparsities = [r["sparsity"] for r in results]
eps_values = [r["eps"] for r in results]

# Plot 1: Scores comparison
bars1 = axes[0, 0].bar(range(len(names)), scores, alpha=0.7, color="skyblue")
axes[0, 0].set_title("Final Scores")
axes[0, 0].set_xlabel("Configuration")
axes[0, 0].set_ylabel("Score")
axes[0, 0].set_xticks(range(len(names)))
axes[0, 0].set_xticklabels([f"Config {i+1}" for i in range(len(names))], rotation=45)

# Add score values on bars
# (the loop variable is named score_val so it does not shadow the score() function)
for bar, score_val in zip(bars1, scores):
    height = bar.get_height()
    axes[0, 0].text(
        bar.get_x() + bar.get_width() / 2.0,
        height + 0.001,
        f"{score_val:.3f}",
        ha="center",
        va="bottom",
    )

# Plot 2: MSE vs Sparsity trade-off
scatter = axes[0, 1].scatter(
    sparsities, mses, c=scores, cmap="viridis", s=100, alpha=0.7
)
axes[0, 1].set_title("MSE vs Sparsity Trade-off")
axes[0, 1].set_xlabel("Sparsity")
axes[0, 1].set_ylabel("MSE")
plt.colorbar(scatter, ax=axes[0, 1], label="Score")

# Plot 3: Effect of eps parameter
axes[1, 0].scatter(eps_values, scores, s=100, alpha=0.7, color="orange")
axes[1, 0].set_title("Effect of eps Parameter")
axes[1, 0].set_xlabel("eps (bridge coefficient)")
axes[1, 0].set_ylabel("Score")
axes[1, 0].set_xscale("log")

# Plot 4: Points earned
bars2 = axes[1, 1].bar(
    range(len(names)), [points(s) for s in scores], alpha=0.7, color="lightgreen"
)
axes[1, 1].set_title("Assignment Points Earned")
axes[1, 1].set_xlabel("Configuration")
axes[1, 1].set_ylabel("Points (out of 1.5)")
axes[1, 1].set_xticks(range(len(names)))
axes[1, 1].set_xticklabels([f"Config {i+1}" for i in range(len(names))], rotation=45)
axes[1, 1].axhline(y=1.5, color="red", linestyle="--", alpha=0.7, label="Maximum")
axes[1, 1].legend()

plt.tight_layout()
plt.show()

# Print detailed results
print("\n📈 Detailed Results:")
print("=" * 80)
for i, result in enumerate(results):
    print(f"Configuration {i+1}: {result['name']}")
    print(f"   eps={result['eps']}, gamma={result['gamma']}, epochs={result['epochs']}")
    print(f"   MSE={result['mse']:.4f}, Sparsity={result['sparsity']:.1%}")
    print(f"   Score={result['score']:.4f}, Points={result['points']:.3f}/1.5")
    print()

best_config = max(results, key=lambda x: x["score"])
print(f"🏆 Best configuration: {best_config['name']}")
print(f"   Final score: {best_config['score']:.4f}")
print(f"   Assignment points: {best_config['points']:.3f}/1.5")

## 10. 🚀 Advanced Topics and Extensions

Our linear approximation strategy is only one approach. Several other compression and pruning techniques are worth exploring:

### 🎓 Knowledge Distillation
- Train a large "teacher" model, then distill knowledge into a smaller "student"
- The student learns to mimic the teacher's outputs, enabling better compression
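
To make the teacher-student idea concrete, here is a minimal sketch of a distillation loss. It assumes a classification setting with logits, which differs from the MSE-based task in this tutorial. The names `distillation_loss`, `temperature`, and `alpha`, and the toy `teacher`/`student` models, are illustrative choices and not part of the tutorial's code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def distillation_loss(student_logits, teacher_logits, targets, temperature=4.0, alpha=0.5):
    """Blend a soft-target KL term with the usual hard-label cross-entropy."""
    soft_student = F.log_softmax(student_logits / temperature, dim=1)
    soft_teacher = F.softmax(teacher_logits / temperature, dim=1)
    # KL(teacher || student); the T^2 factor keeps gradient scale comparable across temperatures
    kd = F.kl_div(soft_student, soft_teacher, reduction="batchmean") * temperature**2
    ce = F.cross_entropy(student_logits, targets)
    return alpha * kd + (1.0 - alpha) * ce


torch.manual_seed(0)
teacher = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 5))  # hypothetical large model
student = nn.Sequential(nn.Linear(20, 8), nn.ReLU(), nn.Linear(8, 5))    # hypothetical small model

x = torch.randn(16, 20)
y = torch.randint(0, 5, (16,))

with torch.no_grad():  # the teacher is frozen
    teacher_logits = teacher(x)

loss = distillation_loss(student(x), teacher_logits, y)
loss.backward()  # gradients flow only into the student
```

In a real training loop this loss would replace the plain cross-entropy term, and `alpha` and `temperature` would be tuned like any other hyperparameter.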

### 🎰 Lottery Ticket Hypothesis
- Dense networks contain sparse subnetworks that, when trained in isolation from their original initialization, can match the performance of the full network
- Finding these "winning tickets" can lead to highly efficient models
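
The sketch below shows one round of the "winning ticket" procedure as it is usually described: save the initial weights, train, keep the largest-magnitude weights, then rewind the survivors to their initial values. The single `nn.Linear` layer, the random data, and `keep_fraction` are stand-ins chosen for brevity and are not taken from this tutorial.

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(10, 4)
initial_state = copy.deepcopy(model.state_dict())  # saved before any training

# Stand-in for real training: a few SGD steps on random data
x, y = torch.randn(64, 10), torch.randn(64, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
for _ in range(20):
    optimizer.zero_grad()
    nn.functional.mse_loss(model(x), y).backward()
    optimizer.step()

# Keep the largest-magnitude fraction of the *trained* weights
keep_fraction = 0.2
trained_abs = model.weight.detach().abs()
k = int(keep_fraction * trained_abs.numel())
threshold = torch.topk(trained_abs.flatten(), k).values.min()
mask = (trained_abs >= threshold).float()

# Rewind surviving weights to their initial values: this is the candidate ticket
model.load_state_dict(initial_state)
with torch.no_grad():
    model.weight.mul_(mask)

ticket_sparsity = 1.0 - mask.mean().item()
```

The masked network would then be retrained with the mask held fixed. Whether it matches the dense network is an empirical question for each task.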

### 🏗️ Structured Pruning
- Instead of individual weights, remove entire neurons, channels, or layers
- Provides actual speedup (not just memory savings) without specialized hardware, because what remains is a smaller dense network
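
As a rough illustration of why structured pruning yields a smaller dense network, the sketch below ranks hidden neurons by the L2 norm of their incoming weights and copies the survivors into physically smaller layers. The two-layer setup and the names `fc1`, `fc2`, `small_fc1`, and `small_fc2` are hypothetical. PyTorch's `torch.nn.utils.prune.ln_structured` applies a similar norm-based criterion, but it only masks weights and does not shrink the tensors.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
fc1, fc2 = nn.Linear(16, 8), nn.Linear(8, 4)

# Rank hidden neurons by the L2 norm of their incoming weights (one row per neuron)
neuron_norms = fc1.weight.detach().norm(p=2, dim=1)
keep = torch.topk(neuron_norms, k=4).indices.sort().values

# Build physically smaller layers that contain only the surviving neurons
small_fc1, small_fc2 = nn.Linear(16, 4), nn.Linear(4, 4)
with torch.no_grad():
    small_fc1.weight.copy_(fc1.weight[keep])     # rows of the first layer
    small_fc1.bias.copy_(fc1.bias[keep])
    small_fc2.weight.copy_(fc2.weight[:, keep])  # matching columns of the next layer
    small_fc2.bias.copy_(fc2.bias)

x = torch.randn(2, 16)
out = small_fc2(torch.relu(small_fc1(x)))
```

Because the hidden width drops from 8 to 4, both matrix multiplications shrink, which is where the speedup on ordinary hardware comes from.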

### 🔄 Iterative Pruning
- Gradually increase sparsity over multiple rounds
- Train → Prune → Train → Prune... often achieves better results than one-shot pruning, especially at high sparsity
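
The prune-then-retrain loop can be sketched as follows. This is a simplified illustration on a single linear layer with random data. The helper `magnitude_mask` and the sparsity schedule `[0.3, 0.6, 0.9]` are assumptions made for the example, not the tutorial's `magnitude_based_pruning` function.

```python
import torch
import torch.nn as nn


def magnitude_mask(weight, sparsity):
    """Return a 0/1 mask that zeroes the smallest-magnitude `sparsity` fraction."""
    n_prune = round(sparsity * weight.numel())
    if n_prune == 0:
        return torch.ones_like(weight)
    threshold = torch.kthvalue(weight.abs().flatten(), n_prune).values
    return (weight.abs() > threshold).float()


torch.manual_seed(0)
model = nn.Linear(20, 5)
x, y = torch.randn(128, 20), torch.randn(128, 5)
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
mask = torch.ones_like(model.weight)

for target_sparsity in [0.3, 0.6, 0.9]:
    # Prune: already-zeroed weights have magnitude 0, so they stay pruned
    with torch.no_grad():
        mask = magnitude_mask(model.weight * mask, target_sparsity)
        model.weight.mul_(mask)

    # Fine-tune the surviving weights
    for _ in range(50):
        optimizer.zero_grad()
        nn.functional.mse_loss(model(x), y).backward()
        optimizer.step()
        with torch.no_grad():
            model.weight.mul_(mask)  # re-apply the mask so pruned weights stay at zero

final_sparsity = (model.weight == 0).float().mean().item()
```

Re-applying the mask after every optimizer step is the simplest way to keep pruned weights at zero. Masking the gradients instead would work as well.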

### 🎯 Gradient-Based Pruning
- Use gradient information to determine weight importance
- Methods like SNIP, GraSP, and SynFlow score weights using gradients computed at initialization, before any training
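
A SNIP-style importance score can be sketched as the magnitude of weight times gradient, computed from one batch before training. The toy model, the random batch, and `keep_fraction` below are illustrative assumptions. The published methods add details, such as score normalization, that are omitted here.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(12, 16), nn.ReLU(), nn.Linear(16, 3))
x, y = torch.randn(32, 12), torch.randint(0, 3, (32,))

# One forward/backward pass on a single batch, before any training
loss = nn.functional.cross_entropy(model(x), y)
loss.backward()

# Connection sensitivity |w * dL/dw| for every weight matrix (biases are left alone)
weights = [p for name, p in model.named_parameters() if name.endswith("weight")]
scores_per_layer = [(p.detach() * p.grad).abs() for p in weights]
all_scores = torch.cat([s.flatten() for s in scores_per_layer])

# Keep the globally highest-scoring connections
keep_fraction = 0.1
k = max(1, int(keep_fraction * all_scores.numel()))
threshold = torch.topk(all_scores, k).values.min()
masks = [(s >= threshold).float() for s in scores_per_layer]

with torch.no_grad():
    for p, m in zip(weights, masks):
        p.mul_(m)

kept = sum(int(m.sum().item()) for m in masks)
total = sum(m.numel() for m in masks)
```

Unlike magnitude pruning, this criterion can keep a small weight if the loss is very sensitive to it.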


## 11. 📖 Summary and Next Steps

Congratulations! 🎉 You've learned how to use advanced techniques for neural network pruning!

### What You've Learned:

1. **🌿 Neural Network Pruning Fundamentals**:
   - Understanding sparsity and its benefits
   - The trade-off between model size and performance
   - Different types of pruning approaches

2. **🧮 Mathematical Framework**:
   - Complex scoring functions balancing multiple objectives
   - The challenge of discrete optimization in neural networks
   - Why simple heuristics often fail

3. **🤖 Linear Approximation Strategy**:
   - Training simplified models as "teachers"
   - Creating sparse pathways through hidden layers
   - Intelligent weight transfer and scaling

4. **💻 Implementation Skills**:
   - PyTorch model manipulation and weight setting
   - Custom evaluation functions and metrics
   - Visualization and analysis of pruning effects

5. **🔬 Experimental Methodology**:
   - Hyperparameter tuning for pruning algorithms
   - Comparing different pruning strategies
   - Understanding the performance trade-offs

### For the Solution Implementation:

You now have all the knowledge to implement the complete solution! The key components are:

```python
import torch


def your_pruning_algorithm(model, data_loader):
    # 1. Train a linear approximation of the target function
    linear_model = train_linear_approximation(data_loader, epochs=300)

    eps = 0.001   # bridge coefficient
    gamma = 0.25  # scaling factor

    with torch.no_grad():  # edit parameters in place without tracking gradients
        # 2. Zero out weights in layers to be modified
        model.layers[0].weight.zero_()  # First layer
        model.layers[2].weight.zero_()  # Output layer
        # (Keep biases and middle activation as-is)

        # 3. Create diagonal "bridge" in first layer
        idx = torch.arange(128, device=model.layers[0].weight.device)
        model.layers[0].weight[idx, idx] = eps

        # 4. Transfer linear knowledge to output layer with scaling
        model.layers[2].weight[:10, :128] = linear_model.weight[:10, :128] / (eps * gamma)

        # 5. Set output biases with correction
        bias_correction = model.layers[2].weight[:10].sum(dim=1) / 2
        model.layers[2].bias[:10] = linear_model.bias[:10] - bias_correction

    return model
```
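
One way to read the constants in this construction is as a first-order expansion of a sigmoid around zero, where `sigmoid(z) ≈ 1/2 + z/4`. That would explain both `gamma = 0.25` and the division by 2 in the bias correction. This section does not state which activation the model uses, so the sigmoid, the zero first-layer bias, and the small layer sizes in the check below are all assumptions.

```python
import torch

torch.manual_seed(0)
dtype = torch.float64  # double precision so rounding does not hide the approximation error
n_in, n_out = 8, 3     # small hypothetical sizes
eps, gamma = 1e-3, 0.25

# A stand-in "trained" linear model
W_lin = torch.randn(n_out, n_in, dtype=dtype)
b_lin = torch.randn(n_out, dtype=dtype)

# Same construction as steps 3-5: diagonal bridge, rescaled weights, corrected bias
W1 = eps * torch.eye(n_in, dtype=dtype)
W2 = W_lin / (eps * gamma)
b2 = b_lin - W2.sum(dim=1) / 2

x = torch.randn(5, n_in, dtype=dtype)
hidden = torch.sigmoid(x @ W1.T)  # assumed activation, with zero first-layer bias
y_pruned = hidden @ W2.T + b2
y_linear = x @ W_lin.T + b_lin

max_gap = (y_pruned - y_linear).abs().max().item()
```

Under these assumptions the sparse network reproduces the linear model up to a small higher-order error that shrinks as `eps` gets smaller. In single precision, a very small `eps` makes the output-layer weights so large that rounding error can dominate.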

### 🚀 Advanced Topics to Explore:

- **Knowledge Distillation**: Using teacher-student training paradigms
- **Lottery Ticket Hypothesis**: Finding sparse subnetworks that train well
- **Structured Pruning**: Removing entire neurons/channels for real speedup
- **Neural Architecture Search**: Automatically finding efficient architectures
- **Quantization**: Reducing precision of weights and activations

### 📚 Useful Resources:

- 📖 [The Lottery Ticket Hypothesis Paper](https://arxiv.org/abs/1803.03635)
- 🛠️ [PyTorch Pruning Tutorial](https://pytorch.org/tutorials/intermediate/pruning_tutorial.html)
- 📑 [Magnitude-based Pruning Paper](https://arxiv.org/abs/1506.02626)
- 🎯 [SNIP: Single-shot Network Pruning](https://arxiv.org/abs/1810.02340)
- 🧠 [What's Hidden in a Randomly Weighted Neural Network?](https://arxiv.org/abs/1911.13299)

**Good luck with your implementation!** 🌟

Remember: The key insight is that many neural networks are over-parameterized, and clever approximation strategies can maintain performance while dramatically reducing model complexity. Your linear approximation approach leverages the power of well-trained simple models to guide the sparsification of complex networks!
