
# TinyML Seminar — Quantization Hands‑On (PyTorch)

**Facilitator:** Obed Mogaka  

**Department Seminar — Mixed Audience (Students & Faculty)**  

**Focus:** Quantization (PTQ, QAT, Custom Quantizer)  

**Runtime:** ~3 hours (or two 90‑minute sessions)

> This notebook is intentionally scaffolded with **explanations and TODOs**. 
> You will fill in the code live during the session. Keep it lightweight and interactive.



## Agenda & Learning Outcomes

**Modules**
1. Visualization — FP32 vs INT* discretization + error metrics  
2. DNN Primer — CIFAR‑10, small CNN, MobileNetV2 overview; histograms, size, latency  
3. PTQ — PyTorch built‑ins on small CNN & MobileNetV2; compare metrics  
4. QAT — Built‑in QAT on small CNN (and/or MobileNetV2)  
5. User‑Defined Quantizer — Manual linear quantization + STE for QAT (advanced)

**You will be able to:**
- Explain quantization intuitively and mathematically.
- Run PTQ and QAT in PyTorch and interpret the trade‑offs.
- Implement and experiment with a custom quantizer.



## 0. Setup & Environment

> Keep installs minimal for CPU‑only environments. Pre‑download CIFAR‑10 to avoid bandwidth surprises.


In [None]:

# TODO: (optional) installs if running in a fresh environment
# %pip install torch torchvision matplotlib numpy --quiet

# TODO: standard imports (keep here for participants to run once)
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# import torch.optim as optim
# from torch.utils.data import DataLoader
# import torchvision
# from torchvision import datasets, transforms, models
# import numpy as np
# import time
# import os
# import math
# import matplotlib.pyplot as plt

# SEED = 42
# torch.manual_seed(SEED)
# np.random.seed(SEED)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False



---
# Module 1 — Quantization Visualization

**Goal:** Build intuition with numbers **before** models.  
**Concepts:** Discretization, step size (Δ), clipping, rounding, error metrics (MSE, max error).

### What to do
1. Generate continuous numeric data (e.g., sine wave).
2. Quantize to 16‑, 8‑, 4‑, 2‑bit with **uniform** quantization.
3. Plot original vs quantized; show error curve; histogram of errors.


In [None]:

# TODO: Generate a smooth signal and visualize quantization
# x = np.linspace(-1, 1, 1000)
# y = np.sin(2 * np.pi * x)  # or try gaussian noise

# def uniform_quantize(arr, bits):
#     # Symmetric uniform quantization (conceptual)
#     # TODO: compute min/max, step size Δ, quantize-dequantize, return y_hat and Δ
#     pass

# for bits in [16, 8, 4, 2]:
#     # TODO: y_hat, delta = uniform_quantize(y, bits)
#     # TODO: compute MSE, max error, plot overlays
#     # TODO: plot histogram of (y - y_hat)
#     pass
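
One possible way to fill in the cell above, offered as a sketch rather than the reference solution. It assumes a symmetric clipping range `[-beta, beta]` with `beta = max|y|`, uses the step size Δ = (β − α) / (2^b − 1), and clamps the integer levels to ±(2^(b−1) − 1) so that zero is exactly representable. The `results` dictionary is a hypothetical helper for collecting metrics; plotting is left out to keep the block dependency-light.

```python
import numpy as np

def uniform_quantize(arr, bits):
    """Symmetric uniform quantize-dequantize. Returns (arr_hat, delta)."""
    beta = float(np.max(np.abs(arr)))          # clipping range is [-beta, beta]
    if beta == 0.0:                            # all-zero input: nothing to quantize
        return arr.copy(), 0.0
    alpha = -beta
    delta = (beta - alpha) / (2 ** bits - 1)   # step size
    q_max = 2 ** (bits - 1) - 1                # largest integer level (assumed convention)
    q = np.clip(np.round(arr / delta), -q_max, q_max)
    return q * delta, delta

x = np.linspace(-1, 1, 1000)
y = np.sin(2 * np.pi * x)

results = {}
for bits in [16, 8, 4, 2]:
    y_hat, delta = uniform_quantize(y, bits)
    err = y - y_hat
    results[bits] = {
        "delta": delta,
        "mse": float(np.mean(err ** 2)),
        "max_err": float(np.max(np.abs(err))),
    }
    print(f"{bits:2d}-bit  delta={delta:.6f}  "
          f"mse={results[bits]['mse']:.3e}  max_err={results[bits]['max_err']:.3e}")
```

With this convention the error inside the clipping range stays within Δ/2, which is a useful sanity check before plotting the overlays and the histogram of `y - y_hat`.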



---
# Module 2 — DNN Primer (CIFAR‑10 + Small CNN + MobileNetV2)

**Goal:** Introduce a fast, interpretable DNN to be quantized later.  
**Dataset:** CIFAR‑10 (32×32 RGB).  
**Models:** 
- Small CNN (custom) — train a few epochs for ~60‑70% acc (CPU‑friendly).
- Pretrained MobileNetV2 — use for PTQ/QAT comparison and realism.

### What to do
1. Load CIFAR‑10 with basic transforms.
2. Define a **small CNN** (`Conv → ReLU → Pool → Conv → ReLU → FC`) < 1M params.
3. Train briefly; evaluate test accuracy and latency (CPU).
4. Inspect and **visualize weight histograms** (conv layers & FC).
5. Print **model size** (saved .pt) and basic metrics.
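6. Load **MobileNetV2** from `torchvision.models` and run a quick eval on CIFAR‑10 (or subset). Print its size. The pretrained head predicts 1000 ImageNet classes, so accuracy on CIFAR‑10 is only meaningful after replacing the classifier and fine‑tuning.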


In [None]:

# TODO: CIFAR-10 dataset & dataloaders
# transform_train = transforms.Compose([
#     transforms.ToTensor(),
# ])
# transform_test = transforms.Compose([
#     transforms.ToTensor(),
# ])
# train_ds = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
# test_ds = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
# train_loader = DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=2)
# test_loader = DataLoader(test_ds, batch_size=256, shuffle=False, num_workers=2)

# classes = train_ds.classes
# classes


In [None]:

# TODO: Define a small CNN model (keep it <1M params; simple and didactic)
# class SmallCNN(nn.Module):
#     def __init__(self):
#         super().__init__()
#         # TODO: define conv layers, pooling, and a small classifier head
#     def forward(self, x):
#         # TODO: forward pass
#         pass

# model = SmallCNN()
# model
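
A minimal sketch of a model that follows the `Conv → ReLU → Pool → Conv → ReLU → FC` pattern from the list above. The channel widths (32 and 64) and the 3×3 kernels are assumptions chosen to stay well under 1M parameters; ReLU is kept as a module (not `F.relu`) so the layers can be fused later for static quantization.

```python
import torch
import torch.nn as nn

class SmallCNN(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.pool = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.fc = nn.Linear(64 * 16 * 16, num_classes)

    def forward(self, x):
        x = self.pool(self.relu1(self.conv1(x)))   # 3x32x32 -> 32x16x16
        x = self.relu2(self.conv2(x))              # 32x16x16 -> 64x16x16
        return self.fc(torch.flatten(x, 1))        # 64*16*16 features -> logits

model = SmallCNN()
n_params = sum(p.numel() for p in model.parameters())
logits = model(torch.randn(2, 3, 32, 32))          # dummy CIFAR-10-sized batch
print("parameters:", n_params, "| output shape:", tuple(logits.shape))
```

Most of the parameters sit in the final `fc` layer, which is worth pointing out when the weight histograms and the dynamic-quantization results come up later.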


In [None]:

# TODO: Train the small CNN briefly (1-3 epochs) and report accuracy
# device = torch.device('cpu')
# model.to(device)

# optimizer = optim.Adam(model.parameters(), lr=1e-3)
# criterion = nn.CrossEntropyLoss()

# def train_one_epoch(model, loader):
#     # TODO: simple training loop
#     pass

# def evaluate(model, loader):
#     # TODO: compute accuracy; also measure avg forward latency for a few batches
#     pass

# for epoch in range(1):  # increase to 2-3 if time allows
#     # train_one_epoch(model, train_loader)
#     # acc, lat_ms = evaluate(model, test_loader)
#     # print(f"Epoch {epoch}: acc={acc:.2f}% latency={lat_ms:.2f} ms")
#     pass

# # TODO: Save FP32 model and report file size
# # torch.save(model.state_dict(), "smallcnn_fp32.pt")
# # print("SmallCNN FP32 size (MB):", os.path.getsize("smallcnn_fp32.pt")/1e6)
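
A hedged sketch of the two helpers from the cell above. Unlike the scaffold, these versions take the optimizer and criterion as explicit arguments (an assumption, to keep the block self-contained), and the demo runs on random tensors with a toy linear model instead of CIFAR‑10, so the printed accuracy is meaningless. Latency is the average wall-clock time of the forward pass over the first few batches.

```python
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

def train_one_epoch(model, loader, optimizer, criterion, device="cpu"):
    """One pass over the loader; returns the mean training loss."""
    model.train()
    total_loss, n_seen = 0.0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * labels.size(0)
        n_seen += labels.size(0)
    return total_loss / n_seen

@torch.no_grad()
def evaluate(model, loader, device="cpu", timed_batches=5):
    """Returns (accuracy in %, mean forward latency in ms per batch)."""
    model.eval()
    correct, total, latencies_ms = 0, 0, []
    for i, (images, labels) in enumerate(loader):
        images, labels = images.to(device), labels.to(device)
        start = time.perf_counter()
        logits = model(images)
        elapsed_ms = (time.perf_counter() - start) * 1000
        if i < timed_batches:                      # only time the first few batches
            latencies_ms.append(elapsed_ms)
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)
    return 100.0 * correct / total, sum(latencies_ms) / len(latencies_ms)

# Synthetic stand-in for CIFAR-10 so the sketch runs without a download.
torch.manual_seed(0)
fake_ds = TensorDataset(torch.randn(64, 3, 32, 32), torch.randint(0, 10, (64,)))
fake_loader = DataLoader(fake_ds, batch_size=16)
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
toy_opt = optim.Adam(toy_model.parameters(), lr=1e-3)

train_loss = train_one_epoch(toy_model, fake_loader, toy_opt, nn.CrossEntropyLoss())
acc, lat_ms = evaluate(toy_model, fake_loader)
print(f"loss={train_loss:.3f}  acc={acc:.2f}%  latency={lat_ms:.3f} ms/batch")
```

In the session, swap in `SmallCNN`, `train_loader`, and `test_loader`; the first timed batch often includes warm-up overhead, so averaging over several batches gives a steadier number.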


In [None]:

# TODO: Visualize weight histograms (conv layers & fc)
# with torch.no_grad():
#     # for each named parameter, if 'weight' in name: collect and plot histogram
#     pass
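
A small sketch for collecting per-layer weight histograms without committing to a plotting backend. `weight_histograms` is a hypothetical helper name, and the demo network is a stand-in for the trained Small CNN. Passing the same flattened weights to `plt.hist(w.numpy(), bins=50)` would produce the actual figures.

```python
import torch
import torch.nn as nn

def weight_histograms(model, bins=50):
    """Return {parameter_name: (bin_counts, min_value, max_value)} for weight tensors."""
    hists = {}
    with torch.no_grad():
        for name, param in model.named_parameters():
            if "weight" in name:
                w = param.detach().flatten().float()
                lo, hi = w.min().item(), w.max().item()
                counts = torch.histc(w, bins=bins, min=lo, max=hi)
                hists[name] = (counts, lo, hi)
    return hists

# Stand-in network; replace with the trained SmallCNN in the session.
demo_net = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8 * 32 * 32, 10),
)
hists = weight_histograms(demo_net)
for name, (counts, lo, hi) in hists.items():
    print(f"{name}: range=[{lo:.4f}, {hi:.4f}]  peak bin count={int(counts.max())}")
```

The `[lo, hi]` range per layer is the quantity that later determines the step size Δ, so layers with a few large outliers are the ones to watch.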


In [None]:

# TODO: Load pretrained MobileNetV2 and adapt to CIFAR-10 (optional: replace classifier head)
# mobilenet = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.DEFAULT)
# # Optional: replace final classifier for CIFAR-10; or evaluate feature extractor qualitatively
# # TODO: quick eval / dummy pass / size reporting
# # torch.save(mobilenet.state_dict(), "mobilenetv2_fp32.pt")
# # print("MobileNetV2 FP32 size (MB):", os.path.getsize("mobilenetv2_fp32.pt")/1e6)



---
# Module 3 — Post‑Training Quantization (PTQ)

**Goal:** Use PyTorch built‑ins to quantize FP32 models and compare **accuracy**, **latency**, and **model size**.

### What to do
1. PTQ on the **Small CNN** (static or dynamic quantization depending on layer support).
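2. PTQ on **MobileNetV2** (dynamic quantization is the quickest to demo, but it only converts the final Linear classifier; static quantization is needed to quantize the conv layers).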
3. Build a comparison table:
   - FP32 vs INT8: accuracy (%), latency (ms), size (MB).

> Tip: Dynamic quantization targets Linear/LSTM layers and leaves convolutions in FP32; static quantization covers conv layers but requires calibration on representative data.


In [None]:

# TODO: PTQ for Small CNN (static or dynamic)
# from torch.ao.quantization import quantize_dynamic, get_default_qconfig, prepare, convert

# # Option A: Dynamic quantization (works well for Linear layers)
# # smallcnn_int8 = quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
# # torch.save(smallcnn_int8.state_dict(), "smallcnn_int8_ptq.pt")

# # Option B: Static quantization (requires qconfig, fuse modules, prepare, calibration, convert)
# # Note: eager-mode static quantization expects QuantStub/DeQuantStub around the model's forward pass
# # model.eval()
# # model.qconfig = get_default_qconfig('fbgemm')  # 'fbgemm' targets x86; 'qnnpack' targets ARM
# # fused = torch.ao.quantization.fuse_modules(...)  # TODO if applicable
# # prepared = prepare(model, inplace=False)
# # # Calibration: run a few batches from train_loader through 'prepared' in eval mode
# # # convert
# # quantized_smallcnn = convert(prepared)

# # TODO: Evaluate metrics (accuracy, latency, size)
# # print("SmallCNN INT8 PTQ size (MB):", os.path.getsize("smallcnn_int8_ptq.pt")/1e6)
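
A sketch of the eager-mode static PTQ flow from Option B, end to end. Several things here are assumptions: the `QuantReadyCNN` class (a copy of the small CNN with `QuantStub`/`DeQuantStub` added), random tensors as calibration data, and the backend choice, which falls back to `qnnpack` when `fbgemm` is unavailable. In the session, calibrate with real batches from `train_loader`.

```python
import io
import torch
import torch.nn as nn
from torch.ao.quantization import (
    QuantStub, DeQuantStub, fuse_modules, get_default_qconfig, prepare, convert,
)

class QuantReadyCNN(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.quant = QuantStub()                   # FP32 -> INT8 at the input
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.relu1 = nn.ReLU()
        self.pool = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.relu2 = nn.ReLU()
        self.fc = nn.Linear(64 * 16 * 16, num_classes)
        self.dequant = DeQuantStub()               # INT8 -> FP32 at the output

    def forward(self, x):
        x = self.quant(x)
        x = self.pool(self.relu1(self.conv1(x)))
        x = self.relu2(self.conv2(x))
        x = self.fc(torch.flatten(x, 1))
        return self.dequant(x)

def size_mb(model):
    buf = io.BytesIO()
    torch.save(model.state_dict(), buf)
    return len(buf.getvalue()) / 1e6

# Pick a quantized backend that this build of PyTorch supports.
backend = "fbgemm" if "fbgemm" in torch.backends.quantized.supported_engines else "qnnpack"
torch.backends.quantized.engine = backend

torch.manual_seed(0)
fp32_model = QuantReadyCNN().eval()                # static PTQ works on an eval-mode model
fused = fuse_modules(fp32_model, [["conv1", "relu1"], ["conv2", "relu2"]])
fused.qconfig = get_default_qconfig(backend)
prepared = prepare(fused, inplace=False)           # inserts observers

with torch.no_grad():                              # calibration pass (random data here)
    for _ in range(4):
        prepared(torch.rand(8, 3, 32, 32))

int8_model = convert(prepared, inplace=False)      # swaps in quantized modules
with torch.no_grad():
    out = int8_model(torch.rand(2, 3, 32, 32))

fp32_mb, int8_mb = size_mb(fp32_model), size_mb(int8_model)
print(f"FP32: {fp32_mb:.3f} MB | INT8: {int8_mb:.3f} MB | output {tuple(out.shape)}")
```

Calibration only records activation ranges, so a handful of representative batches is usually enough; no labels or gradients are involved.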


In [None]:

# TODO: PTQ for MobileNetV2 (dynamic quantization converts only the Linear classifier; conv layers stay FP32)
# # mobilenet_int8 = quantize_dynamic(mobilenet, {nn.Linear}, dtype=torch.qint8)
# # torch.save(mobilenet_int8.state_dict(), "mobilenetv2_int8_ptq.pt")
# # print("MobileNetV2 INT8 PTQ size (MB):", os.path.getsize("mobilenetv2_int8_ptq.pt")/1e6)

# # TODO: quick accuracy and latency check (on CIFAR-10 subset)
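
A small sketch showing which layers `quantize_dynamic` actually touches. The tiny conv-plus-linear network is a stand-in (an assumption, to avoid loading MobileNetV2 here); the same inspection loop can be run on `mobilenet_int8` in the session.

```python
import torch
import torch.nn as nn
from torch.ao.quantization import quantize_dynamic

# Pick a quantized backend that this build of PyTorch supports.
backend = "fbgemm" if "fbgemm" in torch.backends.quantized.supported_engines else "qnnpack"
torch.backends.quantized.engine = backend

# Stand-in for a conv-heavy model with a Linear classifier head.
net = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
).eval()

net_dyn = quantize_dynamic(net, {nn.Linear}, dtype=torch.qint8)

# Inspect which modules were replaced.
for idx, module in enumerate(net_dyn):
    print(idx, type(module).__module__, type(module).__name__)

with torch.no_grad():
    out = net_dyn(torch.randn(2, 3, 32, 32))
```

Only the last module changes type; the convolution is still a regular `nn.Conv2d`. For a conv-dominated network this means the size and latency gains from dynamic quantization are likely to be small.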


In [None]:

# TODO: Summarize PTQ results in a small table/dict for display
# results_ptq = {
#     'Model': ['SmallCNN FP32', 'SmallCNN INT8 PTQ', 'MobileNetV2 FP32', 'MobileNetV2 INT8 PTQ'],
#     'Accuracy_%': [None, None, None, None],
#     'Latency_ms': [None, None, None, None],
#     'Size_MB': [None, None, None, None],
# }
# results_ptq



---
# Module 4 — Quantization‑Aware Training (QAT)

**Goal:** Show that training with quantization awareness can recover accuracy lost under PTQ.

### What to do
1. Configure QAT on **Small CNN** using PyTorch QAT APIs.
2. Fine‑tune for 1–2 epochs with fake quantization enabled.
3. Convert to a quantized model and compare **accuracy**, **latency**, and **size** with FP32 and PTQ.

> Tip: Keep QAT minimal for the session; even a small amount of fine‑tuning is illustrative.


In [None]:

# TODO: QAT flow (PyTorch built-ins)
# from torch.ao.quantization import get_default_qat_qconfig, prepare_qat, convert

# # model_fp32_for_qat = SmallCNN()
# # model_fp32_for_qat.load_state_dict(torch.load("smallcnn_fp32.pt"))
# # model_fp32_for_qat.train()
# # model_fp32_for_qat.qconfig = get_default_qat_qconfig('fbgemm')

# # prepared_qat = prepare_qat(model_fp32_for_qat, inplace=False)
# # # Train for a small number of epochs with standard optimizer/loss
# # # After fine-tune: convert to quantized
# # quantized_qat_model = convert(prepared_qat.eval(), inplace=False)

# # TODO: Evaluate metrics and record in a results table
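
A compact sketch of the built-in QAT flow: prepare in train mode, fine-tune with fake quantization, then convert. `TinyQATNet`, the random training data, and the five optimizer steps are all assumptions that keep the block fast and offline; in the session, start from the saved Small CNN weights and use `train_loader`.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.ao.quantization import (
    QuantStub, DeQuantStub, get_default_qat_qconfig, prepare_qat, convert,
)

class TinyQATNet(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.quant = QuantStub()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2)
        self.fc = nn.Linear(8 * 16 * 16, num_classes)
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.pool(self.relu(self.conv(x)))
        x = self.fc(torch.flatten(x, 1))
        return self.dequant(x)

# Pick a quantized backend that this build of PyTorch supports.
backend = "fbgemm" if "fbgemm" in torch.backends.quantized.supported_engines else "qnnpack"
torch.backends.quantized.engine = backend

torch.manual_seed(0)
qat_model = TinyQATNet().train()                   # prepare_qat expects train mode
qat_model.qconfig = get_default_qat_qconfig(backend)
prepared_qat = prepare_qat(qat_model, inplace=False)   # inserts fake-quant modules

optimizer = optim.SGD(prepared_qat.parameters(), lr=1e-2)
criterion = nn.CrossEntropyLoss()
for _ in range(5):                                 # stand-in for 1-2 epochs of fine-tuning
    images = torch.rand(16, 3, 32, 32)
    labels = torch.randint(0, 10, (16,))
    optimizer.zero_grad()
    loss = criterion(prepared_qat(images), labels)
    loss.backward()
    optimizer.step()

quantized_qat_model = convert(prepared_qat.eval(), inplace=False)
with torch.no_grad():
    out = quantized_qat_model(torch.rand(2, 3, 32, 32))
print("final loss:", round(loss.item(), 4), "| output shape:", tuple(out.shape))
```

During fine-tuning the weights stay in FP32 and the fake-quant modules simulate INT8 rounding in the forward pass; only `convert` produces real INT8 modules.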


In [None]:

# TODO: Summarize QAT results
# results_qat = {
#     'Model': ['SmallCNN FP32', 'SmallCNN INT8 PTQ', 'SmallCNN INT8 QAT'],
#     'Accuracy_%': [None, None, None],
#     'Latency_ms': [None, None, None],
#     'Size_MB': [None, None, None],
# }
# results_qat



---
# Module 5 — User‑Defined Quantizer (Advanced)

**Goal:** Implement linear quantization by hand and make it QAT‑aware with **Straight‑Through Estimator (STE)**.

### Concepts to Emphasize
- Linear quantization: \( \hat{x} = \mathrm{round}(x / \Delta) \cdot \Delta \)
- Step size: \( \Delta = \frac{\beta - \alpha}{2^b - 1} \) given clipping range \([\alpha, \beta]\)
- **Symmetric vs Asymmetric**, **per‑tensor vs per‑channel**
- STE treats rounding as the identity in the backward pass, because the true gradient of rounding is zero almost everywhere

### What to do
1. Write a pure‑PyTorch function to quantize **weights** given bit‑width **b**, range, and rounding mode.
2. Wrap it in an `nn.Module` that applies quantization in `forward()` for **QAT** (with STE).
3. Insert into Small CNN and fine‑tune briefly; compare results across bit‑widths.
4. Visualize quantized vs original weight histograms and report accuracy/latency/size.

> Keep the first pass **per‑tensor symmetric** for clarity, then optionally extend to **per‑channel**.


In [None]:

# TODO: Manual weight quantizer (linear, symmetric)
# def linear_quantize_tensor(x, bits=8, clip_ratio=0.999):
#     # TODO: choose alpha/beta from observed range or percentile; compute Δ; quantize/dequantize
#     pass
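
One way the weight quantizer above might look, as a sketch under stated assumptions: per-tensor symmetric range, `beta` taken as the `clip_ratio` quantile of `|x|`, step size Δ = (β − α) / (2^b − 1) as in the formula above, and integer levels clamped to ±(2^(b−1) − 1). Rounding is round-to-nearest only; other rounding modes are left as an exercise.

```python
import torch

def linear_quantize_tensor(x, bits=8, clip_ratio=0.999):
    """Per-tensor symmetric quantize-dequantize. Returns (x_hat, delta)."""
    # Clipping range from a percentile of |x| to reduce the impact of outliers.
    beta = torch.quantile(x.detach().abs().flatten().float(), clip_ratio)
    beta = torch.clamp(beta, min=1e-8)             # guard against all-zero tensors
    alpha = -beta
    delta = (beta - alpha) / (2 ** bits - 1)       # step size
    q_max = 2 ** (bits - 1) - 1                    # largest integer level (assumed convention)
    q = torch.clamp(torch.round(x / delta), -q_max, q_max)
    return q * delta, delta

torch.manual_seed(0)
w = torch.randn(1000)
for bits in [8, 4, 2]:
    w_hat, delta = linear_quantize_tensor(w, bits=bits)
    mse = torch.mean((w - w_hat) ** 2).item()
    n_levels = torch.unique(w_hat).numel()
    print(f"{bits}-bit  delta={delta.item():.5f}  mse={mse:.3e}  distinct levels={n_levels}")
```

Values beyond the percentile range are clipped to the outermost level, so lowering `clip_ratio` trades a larger clipping error on outliers for a finer step size on the bulk of the weights.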


In [None]:

# TODO: STE-enabled module for QAT
# class STEQuantizer(nn.Module):
#     def __init__(self, bits=8):
#         super().__init__()
#         self.bits = bits
#     def forward(self, x):
#         # TODO: quantize in forward; identity gradient in backward (via .detach() trick or custom autograd.Function)
#         pass
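
A sketch of the `.detach()` variant of STE mentioned in the TODO above. It assumes a max-based symmetric range (no percentile clipping) so that nothing is clipped; the `quantize` method is a hypothetical helper split out so the forward pass reads clearly.

```python
import torch
import torch.nn as nn

class STEQuantizer(nn.Module):
    def __init__(self, bits=8):
        super().__init__()
        self.bits = bits

    def quantize(self, x):
        """Plain quantize-dequantize; not differentiable in a useful way."""
        beta = x.detach().abs().max().clamp(min=1e-8)
        delta = 2 * beta / (2 ** self.bits - 1)
        q_max = 2 ** (self.bits - 1) - 1
        return torch.clamp(torch.round(x / delta), -q_max, q_max) * delta

    def forward(self, x):
        x_q = self.quantize(x)
        # Forward value equals x_q; backward sees only the `x` term,
        # so the gradient with respect to x is the identity.
        return x + (x_q - x).detach()

torch.manual_seed(0)
quantizer = STEQuantizer(bits=4)
w = torch.randn(64, requires_grad=True)
w_q = quantizer(w)
w_q.sum().backward()
print("distinct forward values:", torch.unique(w_q.detach()).numel())
print("gradient is all ones:", bool(torch.allclose(w.grad, torch.ones_like(w))))
```

A custom `torch.autograd.Function` would give the same behavior with more control, for example zeroing the gradient for values that were clipped.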


In [None]:

# TODO: Integrate STEQuantizer into SmallCNN (e.g., wrapping weights or as fake-quant nodes)
# class SmallCNN_QAT(nn.Module):
#     def __init__(self, bits=8):
#         super().__init__()
#         # TODO: define layers; include STEQuantizer modules where appropriate
#     def forward(self, x):
#         # TODO: apply quantization to weights/activations (start with weights)
#         pass
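
One possible integration, quantizing weights only. `ste_quantize`, `QuantConv2d`, and `QuantLinear` are hypothetical names introduced for this sketch; they re-quantize the FP32 weights on every forward pass and rely on the `.detach()` trick so gradients still reach the underlying weights. Activations are left in FP32.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def ste_quantize(w, bits):
    """Symmetric per-tensor fake quantization with a straight-through gradient."""
    beta = w.detach().abs().max().clamp(min=1e-8)
    delta = 2 * beta / (2 ** bits - 1)
    q_max = 2 ** (bits - 1) - 1
    w_q = torch.clamp(torch.round(w / delta), -q_max, q_max) * delta
    return w + (w_q - w).detach()

class QuantConv2d(nn.Conv2d):
    def __init__(self, *args, bits=8, **kwargs):
        super().__init__(*args, **kwargs)
        self.bits = bits

    def forward(self, x):
        w_q = ste_quantize(self.weight, self.bits)
        return F.conv2d(x, w_q, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

class QuantLinear(nn.Linear):
    def __init__(self, *args, bits=8, **kwargs):
        super().__init__(*args, **kwargs)
        self.bits = bits

    def forward(self, x):
        return F.linear(x, ste_quantize(self.weight, self.bits), self.bias)

class SmallCNN_QAT(nn.Module):
    def __init__(self, bits=8, num_classes=10):
        super().__init__()
        self.conv1 = QuantConv2d(3, 32, kernel_size=3, padding=1, bits=bits)
        self.pool = nn.MaxPool2d(2)
        self.conv2 = QuantConv2d(32, 64, kernel_size=3, padding=1, bits=bits)
        self.fc = QuantLinear(64 * 16 * 16, num_classes, bits=bits)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = F.relu(self.conv2(x))
        return self.fc(torch.flatten(x, 1))

torch.manual_seed(0)
model_q = SmallCNN_QAT(bits=4)
out = model_q(torch.randn(2, 3, 32, 32))
out.sum().backward()                               # gradients flow through the STE
print("output shape:", tuple(out.shape))
```

Because the layers subclass `nn.Conv2d` and `nn.Linear`, a trained FP32 `state_dict` with matching layer names can be loaded before fine-tuning. Note that this simulates quantization in FP32, so it affects accuracy but not latency or file size.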


In [None]:

# TODO: Experiment: train/evaluate with bits in [8, 4, 2]; log accuracy and errors
# for b in [8, 4, 2]:
#     # model_q = SmallCNN_QAT(bits=b)
#     # TODO: brief fine-tuning + evaluation
#     pass



---
## Wrap‑Up & Discussion

- **PTQ vs QAT:** When is each appropriate? What accuracy trade‑offs did you observe?
- **Custom quantization:** How close did manual QAT get to built‑in QAT?
- **Bit‑width sensitivity:** Which layers were most sensitive? (Conv vs FC)
- **Deployment:** How would these INT8 models map to MCUs/NPUs/FPGAs?

> Next steps (beyond this notebook): add **pruning** and **mixed‑precision** experiments.



---
*Prepared on:* 2025-10-15 07:54 UTC  
*Notes:* Replace TODO blocks with live coding during the seminar. Keep iterations short and show results early.
