# Lesson 4: Model Pruning

**Module 4: Model Development & Optimization**  
**Estimated Time**: 1 hour  
**Difficulty**: Advanced

---

## 🎯 Learning Objectives

By the end of this lesson, you will:

✅ Understand the difference between Structured and Unstructured Pruning  
✅ Learn why sparse matrices don't always speed up inference  
✅ Implement Pruning using `torch.nn.utils.prune`  
✅ Answer interview questions on model compression  

---

## 📚 Table of Contents

1. [Why Prune? (The Lottery Ticket Hypothesis)](#1-why)
2. [Unstructured vs Structured Pruning](#2-types)
3. [Hands-On: Pruning a PyTorch Linear Layer](#3-hands-on)
4. [Interview Preparation](#4-interview-questions)

---

## 1. Why Prune?

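**Insight**: Neural networks are heavily over-parameterized. A large share of the weights in a model such as ResNet-50 (figures as high as 90% are often quoted) may contribute little to the prediction. The Lottery Ticket Hypothesis goes a step further: a dense network contains a much smaller subnetwork that, trained on its own from the same initialization, can match the accuracy of the full network.

**Pruning**: Setting "unimportant" weights to zero, where importance is usually judged by magnitude.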

**Goal**: Reduce model size (Memory) and inference time (Latency).

## 2. Unstructured vs Structured Pruning

### Unstructured Pruning
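- Set individual weights to 0 (e.g., "Remove any weight whose absolute value is < 0.01").
- **Result**: A sparse matrix of the same shape, with zeros scattered in an irregular pattern.
- **Problem**: GPUs/CPUs are optimized for dense matrix multiplication. Scattered zeros are still multiplied, so they don't improve speed unless you have sparse-aware kernels or specialized hardware.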
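
To make the "same shape, scattered zeros" point concrete, here is a minimal sketch of magnitude pruning done by hand. The weight matrix, its scale, and the 0.01 threshold are illustrative assumptions, not values taken from a real model.

```python
import torch

torch.manual_seed(0)  # fixed seed so the sketch is repeatable

# Hypothetical weight matrix standing in for one layer's parameters.
weight = torch.randn(64, 64) * 0.02

# Unstructured pruning by magnitude: zero every entry below a threshold.
threshold = 0.01
mask = weight.abs() >= threshold  # True where the weight survives
pruned = weight * mask

sparsity = 1.0 - mask.float().mean().item()
print(f"Sparsity: {sparsity:.1%}")

# The tensor is still dense: same shape, same number of stored values.
print(pruned.shape == weight.shape)
print(pruned.nelement() * pruned.element_size(), "bytes, before and after")
```

Nothing about the tensor's layout changes, which is why a dense matmul on `pruned` does the same amount of work as on `weight`.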

### Structured Pruning
- Remove entire **Rows, Columns, or Channels** (e.g., "Remove kernel 5 in Conv layer 1").
- **Result**: A smaller dense matrix.
- **Benefit**: Real speedup on standard hardware.
- **Risk**: Higher impact on accuracy (removing a whole channel is aggressive).
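
A minimal sketch of the idea, assuming a single `nn.Linear` layer: `prune.ln_structured` zeroes whole output rows by L2 norm, and the surviving rows are then copied into a physically smaller layer. The layer sizes and the helper name `small` are made up for illustration, and the manual copy step is one possible approach rather than something PyTorch does for you.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)

layer = nn.Linear(8, 6)  # weight shape: (out_features=6, in_features=8)

# Zero the 50% of output rows with the smallest L2 norm (dim=0 indexes rows).
prune.ln_structured(layer, name="weight", amount=0.5, n=2, dim=0)
prune.remove(layer, "weight")  # bake the mask into layer.weight

# Rows that survived pruning.
keep = layer.weight.abs().sum(dim=1) != 0

# Build a physically smaller dense layer from the surviving rows.
small = nn.Linear(8, int(keep.sum()))
with torch.no_grad():
    small.weight.copy_(layer.weight[keep])
    small.bias.copy_(layer.bias[keep])

x = torch.randn(4, 8)
print(layer(x).shape)  # still 6 outputs; the pruned rows only emit their bias
print(small(x).shape)  # 3 outputs, computed with a smaller dense matmul
```

In a real network the next layer's input columns would have to be sliced the same way, which this sketch leaves out.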

## 3. Hands-On: Pruning a PyTorch Linear Layer

We use PyTorch's built-in `torch.nn.utils.prune` module to prune a small linear layer.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

# 1. Create a module to prune
module = nn.Linear(5, 5)
print("Original Weights:")
print(module.weight)

# 2. Unstructured pruning (L1 magnitude)
# Prune the 30% of connections with the lowest absolute value
prune.l1_unstructured(module, name="weight", amount=0.3)

# module.weight is now computed as weight_orig * weight_mask
print("\nPruned Weights (masked entries are zero):")
print(module.weight)

# 3. Check sparsity
zeros = int(torch.sum(module.weight == 0))
total = module.weight.nelement()
print(f"\nSparsity: {zeros}/{total} ({zeros / total:.0%})")

# 4. Make it permanent
# Pruning in PyTorch applies a mask on top of the original weights.
# prune.remove() bakes the mask in, leaving a single `weight` parameter.
# The tensor is still stored densely; the zeros only save space once the
# model is compressed or exported to a sparse format.
prune.remove(module, "weight")
print("\nPruning made permanent (mask removed, weights are actually zeroed)")
```
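
To see what "applies a mask" means in practice, the following sketch inspects the module's `state_dict` before and after `prune.remove`. It repeats the setup from the cell above so it can run on its own; the variable names `keys_before` and `keys_after` are just for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(5, 5)
prune.l1_unstructured(layer, name="weight", amount=0.3)

# While the mask is active, the original values and the mask are stored
# separately, and `weight` is recomputed as weight_orig * weight_mask.
keys_before = sorted(layer.state_dict().keys())
print(keys_before)
print(prune.is_pruned(layer))

prune.remove(layer, "weight")

# After remove(), only a plain `weight` parameter is left.
keys_after = sorted(layer.state_dict().keys())
print(keys_after)
print(prune.is_pruned(layer))
```

A checkpoint saved before `prune.remove` therefore holds both `weight_orig` and `weight_mask`, so it is larger than the unpruned model, not smaller.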

## 4. Interview Preparation

### Common Questions

#### Q1: "If I prune 50% of weights, is my model 2x faster?"
**Answer**: "Usually not, if you used unstructured pruning. Standard dense matrix libraries (BLAS/cuBLAS) still multiply the zeros, so the amount of arithmetic is unchanged. You only get speedups with **Structured Pruning** (which changes the physical shape of the tensors) or with specialized sparse inference engines (e.g., Neural Magic's DeepSparse)."
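
A follow-up that sometimes comes up is "why not just store the weights in a sparse format?" The sketch below is an illustration under stated assumptions: a 256x256 layer, 50% sparsity, and PyTorch's COO layout, which is only one of several sparse formats. It compares storage and checks that the sparse product matches the dense one. It makes no timing claims, since those depend heavily on hardware and kernels.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)

layer = nn.Linear(256, 256, bias=False)
prune.l1_unstructured(layer, name="weight", amount=0.5)
prune.remove(layer, "weight")

dense = layer.weight.detach()
sparse = dense.to_sparse().coalesce()  # COO format: indices + values

dense_bytes = dense.nelement() * dense.element_size()
sparse_bytes = (
    sparse.indices().nelement() * sparse.indices().element_size()
    + sparse.values().nelement() * sparse.values().element_size()
)
print(f"dense:  {dense_bytes} bytes")
print(f"sparse: {sparse_bytes} bytes")

# Both representations compute the same product.
x = torch.randn(256, 32)
same = torch.allclose(dense @ x, torch.sparse.mm(sparse, x), atol=1e-4)
print(same)
```

At this sparsity level the COO tensor is larger than the dense one, because every surviving value also carries two 64-bit indices. Sparse storage tends to pay off only at much higher sparsity or with more compact formats.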

#### Q2: "What is Iterative Pruning?"
**Answer**: "Instead of pruning 50% in one shot (which often causes a large accuracy drop), prune a small fraction (e.g., 10%), fine-tune to recover accuracy, then prune another 10% and fine-tune again, repeating until you reach the target sparsity. This lets the network adapt gradually to the loss of capacity."
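
The prune-then-fine-tune loop can be sketched as follows. Everything task-specific here is a hypothetical stand-in: the toy regression data, the single linear layer, the `fine_tune` helper, and the choice of three rounds at 20% each. Repeated calls to `prune.l1_unstructured` on the same parameter prune a fraction of the weights that are still unpruned, so the masks accumulate.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)

# Hypothetical toy regression task, only to make the loop runnable.
X = torch.randn(512, 100)
y = X @ torch.randn(100, 1)

model = nn.Linear(100, 1)
loss_fn = nn.MSELoss()

def fine_tune(steps=100):
    """Run a few SGD steps and return the last training loss."""
    optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
    for _ in range(steps):
        optimizer.zero_grad()
        loss = loss_fn(model(X), y)
        loss.backward()
        optimizer.step()
    return loss.item()

fine_tune()  # train the dense model first

for round_idx in range(3):
    # Prune 20% of the weights that are still unpruned; PyTorch combines
    # the new mask with the masks from earlier rounds.
    prune.l1_unstructured(model, name="weight", amount=0.2)
    loss = fine_tune()  # let the surviving weights adapt
    sparsity = (model.weight == 0).float().mean().item()
    print(f"round {round_idx + 1}: sparsity={sparsity:.0%}, loss={loss:.4f}")

prune.remove(model, "weight")  # make the accumulated mask permanent
final_sparsity = (model.weight == 0).float().mean().item()
```

Because each round removes 20% of what is left, three rounds leave roughly 0.8^3, or about half, of the weights in place. A real schedule would also track validation accuracy and stop once it degrades.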