# Quantization and Pruning

### Objective
This notebook explores two widely used **model compression techniques**, **quantization** and **pruning**, which make deep learning models smaller and faster without significantly compromising accuracy.

You’ll learn:
- What quantization and pruning are
- How to apply them in PyTorch
- How they impact model performance and size
- How to fine-tune a pruned model

## ⚙️ Setup

Let's import PyTorch and create a simple neural network for demonstration purposes.

In [ ]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Simple model
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

model = SimpleNN()
print(model)

## 🧮 1. Quantization

**Quantization** reduces the numerical precision of model parameters and activations (e.g., from 32-bit floating-point to 8-bit integer), which:
- Lowers memory usage
- Speeds up inference
- Reduces energy consumption
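
To make the precision reduction concrete, here is a minimal pure-Python sketch of affine (scale and zero-point) quantization to 8 bits. The helper names `quantize` and `dequantize` are illustrative and are not PyTorch APIs; PyTorch's own kernels differ in detail.

```python
def quantize(values, num_bits=8):
    """Map floats to signed integers using one scale and zero point per tensor."""
    qmin, qmax = -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1
    # The representable range must include 0 so that 0.0 maps to an exact integer.
    lo, hi = min(min(values), 0.0), max(max(values), 0.0)
    scale = (hi - lo) / (qmax - qmin) or 1.0  # fall back to 1.0 for all-zero input
    zero_point = round(qmin - lo / scale)
    q = [max(qmin, min(qmax, round(v / scale) + zero_point)) for v in values]
    return q, scale, zero_point


def dequantize(q, scale, zero_point):
    """Recover approximate floats from the integer codes."""
    return [(qi - zero_point) * scale for qi in q]


weights = [-1.0, -0.25, 0.0, 0.4, 1.5]
q, scale, zero_point = quantize(weights)
restored = dequantize(q, scale, zero_point)
max_error = max(abs(a - b) for a, b in zip(weights, restored))

print("integer codes:", q)
print(f"scale={scale:.5f}, zero_point={zero_point}")
print(f"max round-trip error: {max_error:.5f}")
```

The round-trip error stays on the order of one quantization step (`scale`), which is why accuracy often survives the move to INT8.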

PyTorch supports several types:
- **Dynamic Quantization** (post-training)
- **Static Quantization** (calibrated on data)
- **Quantization-Aware Training (QAT)**

### 🔹 Dynamic Quantization
Dynamic quantization converts weights to INT8 ahead of time and quantizes activations on the fly during inference. It needs no calibration data and works well for linear and recurrent layers.

In [ ]:
import torch.quantization as quant

model_fp32 = model.eval()

# Apply dynamic quantization
model_int8 = quant.quantize_dynamic(
    model_fp32, {nn.Linear}, dtype=torch.qint8
)

print("✅ Dynamic quantization done!")
print(model_int8)
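
The idea behind the "dynamic" part can be sketched in plain Python: weights are quantized once, while the activation scale is recomputed for every input. This is a simplified illustration using symmetric quantization; `symmetric_quantize` and `dynamic_linear` are hypothetical helpers, not what `quantize_dynamic` runs internally.

```python
def symmetric_quantize(values, qmax=127):
    """Quantize to integers in [-qmax, qmax] with a single scale (zero point is 0)."""
    scale = max(abs(v) for v in values) / qmax or 1.0
    return [round(v / scale) for v in values], scale


weights = [0.5, -1.0, 0.25]
q_w, s_w = symmetric_quantize(weights)  # done once, ahead of time


def dynamic_linear(x):
    """Dot product with integer accumulation; the input scale is chosen per call."""
    q_x, s_x = symmetric_quantize(x)                # done at runtime for each input
    acc = sum(a * b for a, b in zip(q_w, q_x))      # integer multiply-accumulate
    return acc * s_w * s_x                          # rescale back to float


results = []
for x in ([1.0, 2.0, 3.0], [100.0, -50.0, 10.0]):
    exact = sum(w * xi for w, xi in zip(weights, x))
    approx = dynamic_linear(x)
    results.append((exact, approx))
    print(f"exact={exact:.4f}  quantized={approx:.4f}")
```

Because the scale adapts to each input, inputs with very different magnitudes are both represented with the full 8-bit range.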

### 🔹 Checking Size Reduction
Let's compare model sizes before and after quantization.

In [ ]:
import io

def get_size_of_model(model):
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    return buffer.getbuffer().nbytes / 1e6  # in MB

size_fp32 = get_size_of_model(model_fp32)
size_int8 = get_size_of_model(model_int8)

print(f"FP32 model size: {size_fp32:.2f} MB")
print(f"INT8 model size: {size_int8:.2f} MB")

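💡 **Observation:** For a model made mostly of `nn.Linear` layers, the INT8 version is typically close to 4× smaller than the FP32 one, because each weight shrinks from 4 bytes to 1 byte. The accuracy drop is usually small, but measure it on your own validation data.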
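
As a rough sanity check on the measured sizes, the expected ratio can be estimated from the parameter counts of `SimpleNN`. This back-of-the-envelope sketch assumes biases stay in FP32 and ignores serialization overhead and the stored scales and zero points, so real files will be slightly larger.

```python
# (in_features, out_features) for fc1, fc2, fc3 in SimpleNN
layer_shapes = [(784, 256), (256, 128), (128, 10)]

num_weights = sum(i * o for i, o in layer_shapes)
num_biases = sum(o for _, o in layer_shapes)

fp32_bytes = 4 * (num_weights + num_biases)
int8_bytes = 1 * num_weights + 4 * num_biases  # assumption: biases remain FP32
ratio = fp32_bytes / int8_bytes

print(f"parameters: {num_weights + num_biases}")
print(f"estimated FP32 size: {fp32_bytes / 1e6:.3f} MB")
print(f"estimated INT8 size: {int8_bytes / 1e6:.3f} MB")
print(f"estimated ratio: {ratio:.2f}x")
```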

### 🔹 Static Quantization (with calibration)
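Static quantization quantizes both weights and activations ahead of time. In eager mode it needs a quantization config (`qconfig`), quantize/dequantize boundaries around the model, and a small representative dataset to calibrate the activation ranges.

In [ ]:
# QuantWrapper adds QuantStub/DeQuantStub around the model so inputs are
# quantized on entry and outputs dequantized on exit.
model_to_quantize = torch.quantization.QuantWrapper(SimpleNN()).eval()

# Without a qconfig, prepare/convert leave the model unchanged.
# 'fbgemm' targets x86 CPUs; use 'qnnpack' on ARM.
model_to_quantize.qconfig = torch.quantization.get_default_qconfig('fbgemm')
model_prepared = torch.quantization.prepare(model_to_quantize)

# Fake calibration with random data (use representative inputs in practice)
calib_data = torch.randn(100, 784)
with torch.no_grad():
    model_prepared(calib_data)

model_int8_static = torch.quantization.convert(model_prepared)
print("✅ Static quantization complete!")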
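
What calibration accomplishes can be sketched without PyTorch: an observer records the range of values it sees, and that range is turned into a fixed scale and zero point. `RangeObserver` below is a hypothetical stand-in written for illustration; PyTorch's real observers offer more options (for example histogram-based ranges).

```python
import random


class RangeObserver:
    """Hypothetical helper: tracks the running min/max of observed activations."""

    def __init__(self):
        self.lo = float("inf")
        self.hi = float("-inf")

    def observe(self, batch):
        self.lo = min(self.lo, min(batch))
        self.hi = max(self.hi, max(batch))

    def qparams(self, qmin=0, qmax=255):
        """Turn the observed range into a scale and zero point for unsigned 8-bit."""
        lo, hi = min(self.lo, 0.0), max(self.hi, 0.0)  # range must contain 0
        scale = (hi - lo) / (qmax - qmin) or 1.0
        zero_point = max(qmin, min(qmax, round(qmin - lo / scale)))
        return scale, zero_point


random.seed(0)
observer = RangeObserver()
for _ in range(10):  # ten calibration batches of made-up activations
    batch = [random.gauss(0.0, 1.0) for _ in range(64)]
    observer.observe(batch)

scale, zero_point = observer.qparams()
print(f"observed range: [{observer.lo:.3f}, {observer.hi:.3f}]")
print(f"scale={scale:.5f}, zero_point={zero_point}")
```

If the calibration data does not resemble real inputs, the recorded range is wrong and activations get clipped or wasted resolution, which is why random data is only acceptable for a demo.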

## ✂️ 2. Pruning

**Pruning** removes less important connections (weights or neurons) from a neural network, making it sparse.

Types of pruning:
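- **Unstructured Pruning:** Individual weights are zeroed out. The tensor keeps its shape, so speedups require sparse kernels or sparse storage.
- **Structured Pruning:** Entire filters, channels, or neurons are removed, which standard dense hardware can exploit more easily.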

### 🔹 Unstructured Pruning Example

In [ ]:
import torch.nn.utils.prune as prune

# Zero out the 30% of fc1 weights with the smallest absolute value (L1 magnitude)
prune.l1_unstructured(model.fc1, name='weight', amount=0.3)

sparsity = 100.0 * (model.fc1.weight == 0).sum().item() / model.fc1.weight.nelement()
print(f"Sparsity in fc1 weight tensor: {sparsity:.1f}%")
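
The selection rule behind magnitude-based unstructured pruning is simple enough to sketch in plain Python. `l1_unstructured_mask` is an illustrative helper operating on a flat list, not the PyTorch function of a similar name.

```python
def l1_unstructured_mask(weights, amount):
    """Return a 0/1 mask that drops the `amount` fraction with the smallest |w|."""
    num_to_prune = round(amount * len(weights))
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned = set(order[:num_to_prune])
    return [0 if i in pruned else 1 for i in range(len(weights))]


weights = [0.8, -0.05, 0.3, -1.2, 0.01, 0.6, -0.4, 0.09, 1.1, -0.02]
mask = l1_unstructured_mask(weights, amount=0.3)
pruned_weights = [w * m for w, m in zip(weights, mask)]
sparsity = mask.count(0) / len(mask)

print("mask:    ", mask)
print("weights: ", pruned_weights)
print(f"sparsity: {sparsity:.0%}")
```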

### 🔹 Removing the Pruning Reparametrization
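While a layer is pruned, PyTorch keeps the original values in `weight_orig` and a binary mask in `weight_mask`, and recomputes `weight` from the two. Calling `prune.remove` makes the pruning permanent: it bakes the zeros into `weight` and deletes both the mask and the saved original.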

In [ ]:
prune.remove(model.fc1, 'weight')
print("✅ Pruning mask removed, weights permanently pruned!")
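
The effect of the reparametrization can be inspected directly. The sketch below uses a small stand-in `nn.Linear(8, 4)` layer (not the notebook's model) and lists the layer's parameters and buffers before and after `prune.remove`.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(8, 4)  # small stand-in layer for illustration

prune.l1_unstructured(layer, name="weight", amount=0.5)

# While pruned: the trainable parameter is `weight_orig`, the mask is a buffer.
param_names_pruned = sorted(name for name, _ in layer.named_parameters())
buffer_names_pruned = sorted(name for name, _ in layer.named_buffers())
matches = torch.equal(layer.weight, layer.weight_orig * layer.weight_mask)
print("pruned params: ", param_names_pruned)
print("pruned buffers:", buffer_names_pruned)
print("weight == weight_orig * weight_mask:", matches)

# After remove: a plain `weight` parameter with the zeros baked in.
prune.remove(layer, "weight")
param_names_final = sorted(name for name, _ in layer.named_parameters())
buffer_names_final = sorted(name for name, _ in layer.named_buffers())
zero_fraction = (layer.weight == 0).float().mean().item()
print("final params:  ", param_names_final)
print("final buffers: ", buffer_names_final)
print(f"zero fraction: {zero_fraction:.2f}")
```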

### 🔹 Structured Pruning Example
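We can prune entire channels or filters (useful for CNNs). In the call below, `dim=0` selects output filters and `n=2` ranks them by L2 norm. Note that `ln_structured` zeroes the selected filters but does not shrink the tensor; physically removing them requires rebuilding the layer.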

In [ ]:
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3)
        self.fc = nn.Linear(32 * 26 * 26, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = x.view(x.size(0), -1)
        return self.fc(x)

cnn = ConvNet()
prune.ln_structured(cnn.conv1, name='weight', amount=0.3, n=2, dim=0)
print("Structured pruning applied on conv1 filters!")
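
The ranking step of structured pruning can be sketched in plain Python by treating each row of a matrix as one filter. `prune_rows_by_l2` is a hypothetical helper written for this illustration.

```python
import math


def prune_rows_by_l2(matrix, amount):
    """Zero the `amount` fraction of rows with the smallest L2 norm."""
    norms = [math.sqrt(sum(w * w for w in row)) for row in matrix]
    num_to_prune = round(amount * len(matrix))
    order = sorted(range(len(matrix)), key=lambda r: norms[r])
    pruned = set(order[:num_to_prune])
    result = [
        [0.0] * len(row) if r in pruned else list(row)
        for r, row in enumerate(matrix)
    ]
    return result, sorted(pruned)


filters = [
    [0.9, -0.7, 0.4],
    [0.01, 0.02, -0.01],
    [-0.5, 0.6, 0.8],
    [0.05, -0.03, 0.02],
]
pruned_filters, pruned_rows = prune_rows_by_l2(filters, amount=0.5)

print("pruned rows:", pruned_rows)
for row in pruned_filters:
    print(row)
```

The matrix keeps its original shape: whole rows are zeroed, not deleted, which mirrors the masking behavior described above.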

### 🔹 Fine-Tuning the Pruned Model
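After pruning, fine-tuning (retraining) helps recover accuracy. Keep the pruning mask attached while fine-tuning: once `prune.remove` has been called, the optimizer is free to make the zeroed weights nonzero again. The loop below uses random data only to show the mechanics.

In [ ]:
# Re-attach a mask to fc1 so its pruned weights stay at zero during training.
# The 30% smallest-magnitude weights are the ones already zeroed above.
prune.l1_unstructured(model.fc1, name='weight', amount=0.3)

model.train()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Example fine-tuning steps on random data
for _ in range(3):
    data = torch.randn(16, 784)
    target = torch.randint(0, 10, (16,))
    output = model(data)
    loss = F.cross_entropy(output, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print("✅ Fine-tuning complete!")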
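
Why the mask matters during fine-tuning can be shown with a toy update in plain Python. This is a simplified sketch with made-up gradients; `masked_sgd_step` is a hypothetical helper, and PyTorch achieves the same effect by recomputing the weight from the mask on every forward pass.

```python
def masked_sgd_step(weights, grads, mask, lr):
    """One SGD step followed by re-applying the pruning mask."""
    return [(w - lr * g) * m for w, g, m in zip(weights, grads, mask)]


weights = [0.8, 0.0, 0.3, 0.0]     # positions 1 and 3 were pruned
mask = [1, 0, 1, 0]
grads = [0.1, -0.5, 0.2, 0.4]      # made-up gradients for illustration
lr = 0.1

unmasked = [w - lr * g for w, g in zip(weights, grads)]
masked = masked_sgd_step(weights, grads, mask, lr)

print("without mask:", unmasked)
print("with mask:   ", masked)
```

Without the mask, the pruned positions drift away from zero after a single step, so the sparsity gained by pruning is lost.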

## 📊 3. Comparing Compression Techniques

| Technique | What It Does | Compression | Accuracy Drop | Use Case |
|------------|---------------|--------------|----------------|-----------|
| Quantization | Lower precision of weights and activations | Up to ~4x smaller (FP32 to INT8) | Low | Deployment on CPUs/Edge |
| Unstructured Pruning | Zero out small weights | 30-90% of weights zeroed (smaller only with sparse storage) | Moderate | Research or custom hardware |
| Structured Pruning | Remove entire filters | High | Low–Moderate | CNN deployment |
| Combined | Quantization + Pruning | Very High | Low–Moderate | Efficient inference |

## ✅ Summary

In this notebook, you learned:
- What quantization and pruning are
- How to apply dynamic & static quantization in PyTorch
- How to prune both individual weights and channels
- How to fine-tune a pruned model

Together, these techniques form the foundation of **energy-efficient AI deployment**.

🚀 **Next:** Try combining both quantization and pruning for your trained model and evaluate performance gains.