# Lesson 09: Quantization for Inference

Quantization is a way to store and run neural networks with fewer bits. The goal is to **trade a small amount of accuracy for much better speed and memory usage**.

**Common formats:**
- **FP32**: 32-bit floating point (baseline, largest memory).
- **FP16 / BF16**: 16-bit floating point (half memory, often faster on GPUs).
- **INT8**: 8-bit integers (much smaller, often faster on CPUs).
- **INT4**: 4-bit integers (very small, needs specialized kernels, often used on GPUs).
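
FP16 and BF16 both use 2 bytes but trade range against precision differently: BF16 keeps FP32's 8-bit exponent (huge range, coarser steps), while FP16 has more mantissa bits but a much smaller range. A quick sketch that asks PyTorch for each dtype's properties makes this concrete. The INT4 line is plain arithmetic, because INT4 storage layouts are specific to each kernel library.

```python
import torch

# Collect bytes per element, largest finite value, and machine epsilon
# for the floating-point formats listed above.
float_formats = {}
for dtype in (torch.float32, torch.float16, torch.bfloat16):
    info = torch.finfo(dtype)
    nbytes = torch.empty((), dtype=dtype).element_size()
    float_formats[str(dtype)] = (nbytes, info.max, info.eps)
    print(f"{str(dtype):15} {nbytes} B  max={info.max:.4g}  eps={info.eps:.4g}")

# INT8 has only 256 integer levels; a scale factor maps them back to real values.
int8_info = torch.iinfo(torch.int8)
print(f"{'torch.int8':15} 1 B  range=[{int8_info.min}, {int8_info.max}]")

# Signed INT4 has 16 levels; kernels typically pack two values into one byte.
print(f"{'int4 (packed)':15} 0.5 B  range=[{-2**3}, {2**3 - 1}]")
```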

In practice, quantization can reduce:
- **Model size** (number of parameters * bytes per parameter)
- **Memory bandwidth** (less data moved)
- **Latency** (especially on CPU)
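
The size formula is simple arithmetic. This sketch applies it to an illustrative 124M-parameter model (roughly GPT-2 small). It counts weights only and ignores activations, the KV cache, and the per-tensor or per-group scales that real INT8/INT4 formats also store.

```python
# Bytes needed to store one weight in each format (INT4 packs two weights per byte).
BYTES_PER_PARAM = {"FP32": 4.0, "FP16/BF16": 2.0, "INT8": 1.0, "INT4": 0.5}

def weight_memory_bytes(num_params: int, fmt: str) -> float:
    """Weights-only memory: number of parameters * bytes per parameter."""
    return num_params * BYTES_PER_PARAM[fmt]

num_params = 124_000_000  # illustrative, roughly the size of GPT-2 small
for fmt in BYTES_PER_PARAM:
    mib = weight_memory_bytes(num_params, fmt) / 1024**2
    print(f"{fmt:10} {mib:8.1f} MiB")
```

Going from FP32 to INT8 divides weight memory by 4, and INT4 halves it again, which is also why memory bandwidth per generated token drops.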

But it can also cause **accuracy drops** or degraded output (e.g., repetitive or off-topic text), especially for smaller models or more aggressive (lower-bit) quantization.

In this notebook, we will use **PyTorch dynamic quantization** for a reliable CPU demo. Dynamic quantization is simple: the weights of the selected layer types (here `nn.Linear`) are converted to INT8 once, ahead of time, while activations are quantized on the fly at runtime, using a scale computed from each input's observed range.
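
Both halves of that description can be sketched in plain Python. This is a toy, not PyTorch's implementation: it uses symmetric quantization (zero-point 0) with one per-tensor scale for weights and activations alike, whereas PyTorch quantizes activations with an affine scheme (scale plus zero-point) and runs optimized INT8 kernels. `quantize_symmetric` and `DynamicInt8Linear` are hypothetical names for illustration.

```python
def quantize_symmetric(values, num_bits=8):
    """Map floats to signed integers with one shared scale (zero-point 0)."""
    qmax = 2 ** (num_bits - 1) - 1  # 127 for 8 bits
    max_abs = max(abs(v) for v in values)
    scale = max_abs / qmax if max_abs > 0 else 1.0
    # round() is round-half-to-even; clamping keeps results in [-qmax, qmax]
    q = [max(-qmax, min(qmax, round(v / scale))) for v in values]
    return q, scale

class DynamicInt8Linear:
    """Toy y = W x with INT8 weights (quantized once) and INT8 inputs (quantized per call)."""

    def __init__(self, weight_rows):
        # Ahead of time: quantize all weights with one shared scale.
        flat = [w for row in weight_rows for w in row]
        q_flat, self.w_scale = quantize_symmetric(flat)
        n_in = len(weight_rows[0])
        self.w_q = [q_flat[i:i + n_in] for i in range(0, len(q_flat), n_in)]

    def __call__(self, x):
        # At runtime: choose the activation scale from this input's range.
        x_q, x_scale = quantize_symmetric(x)
        out = []
        for row in self.w_q:
            acc = sum(wq * xq for wq, xq in zip(row, x_q))  # integer multiply-accumulate
            out.append(acc * self.w_scale * x_scale)         # rescale back to float
        return out

W = [[0.5, -1.0, 0.25], [1.5, 0.0, -0.75]]
x = [1.0, 2.0, -0.5]
layer = DynamicInt8Linear(W)
print(layer(x))  # close to the exact float result [-1.625, 1.875]
```

The integer accumulation is where the speed comes from on real hardware; the two scales are applied once per output value rather than per multiplication.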

We will also briefly mention GPU 4-bit quantization (bitsandbytes), but treat it as optional.


In [None]:
# Setup + imports
import time
from typing import Tuple

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Helpers for timing and size estimates
def estimate_model_size_bytes(model: torch.nn.Module, dtype_bytes: int) -> int:
    # Rough estimate: number of parameters * bytes per parameter
    # (assumes every parameter uses the same dtype)
    num_params = sum(p.numel() for p in model.parameters())
    return num_params * dtype_bytes

def format_size(num_bytes: int) -> str:
    # Human-readable size
    for unit in ["B", "KB", "MB", "GB"]:
        if num_bytes < 1024:
            return f"{num_bytes:.2f} {unit}"
        num_bytes /= 1024
    return f"{num_bytes:.2f} TB"

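def timed_generate(model, tokenizer, prompt: str, max_new_tokens: int, device: str) -> Tuple[str, float]:
    # Time text generation for a single prompt (tokenization and decoding are excluded).
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    if device == "cuda":
        torch.cuda.synchronize()  # don't let earlier queued GPU work leak into the timing
    start = time.perf_counter()  # monotonic, high-resolution clock for benchmarking
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no pad token; avoids a warning
        )
    if device == "cuda":
        torch.cuda.synchronize()  # wait for all GPU kernels to finish
    end = time.perf_counter()
    text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return text, end - start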

# Use CUDA if available (enables the FP16 comparison), otherwise fall back to CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

In [None]:
# Load tokenizer + model
# We use a small pretrained model for a quick demo.
model_name = "gpt2"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Move to GPU if available for baseline fp16/fp32.
model = model.to(device)
model.eval()

print("Model loaded")
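
Dynamic quantization only swaps the module types you list, so it helps to check which layer types a model actually contains. A minimal sketch (`count_module_types` is a hypothetical helper, not a PyTorch API) that counts leaf modules by fully qualified class name, shown on a toy model:

```python
from collections import Counter

import torch

def count_module_types(model: torch.nn.Module) -> Counter:
    """Count leaf modules (modules without children) by fully qualified class name."""
    counts = Counter()
    for module in model.modules():
        if next(module.children(), None) is None:  # leaf: no submodules
            cls = type(module)
            counts[f"{cls.__module__}.{cls.__name__}"] += 1
    return counts

# Toy model to show the output format
toy = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))
for name, n in count_module_types(toy).most_common():
    print(f"{n:4d}  {name}")
```

Running `count_module_types(model).most_common(5)` on `gpt2` should show many `...Conv1D` modules (the exact module path depends on your transformers version) and a single `torch.nn.modules.linear.Linear`, the `lm_head`. Using the full class path also distinguishes `torch.nn.Linear` from the dynamically quantized `Linear` that replaces it.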

In [None]:
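# Baseline inference test (FP32, plus FP16 when a GPU is available)
import copy

prompt = "Explain quantization in simple terms."
max_new_tokens = 40

# Warm-up: the first generate() call pays one-time costs (e.g., CUDA context and
# cuBLAS setup on GPU) that would otherwise inflate the first measured latency.
timed_generate(model, tokenizer, prompt, 5, device)

# If on CUDA, we can optionally use fp16 to show the speed/memory tradeoff.
# .half() converts a module in place, so convert a copy; otherwise the FP32
# baseline below would run on weights that were already rounded to FP16.
if device == "cuda":
    model_fp16 = copy.deepcopy(model).half()
    text_fp16, t_fp16 = timed_generate(model_fp16, tokenizer, prompt, max_new_tokens, device)
    size_fp16 = estimate_model_size_bytes(model_fp16, dtype_bytes=2)
    print("--- FP16 (GPU) ---")
    print(text_fp16)
    print(f"Latency: {t_fp16:.3f}s")
    print(f"Size (est): {format_size(size_fp16)}")
    del model_fp16  # release the copy's GPU memory

# Always run the FP32 baseline (CPU or GPU); `model` itself is still FP32
model_fp32 = model.float()
text_fp32, t_fp32 = timed_generate(model_fp32, tokenizer, prompt, max_new_tokens, device)
size_fp32 = estimate_model_size_bytes(model_fp32, dtype_bytes=4)
print("--- FP32 ---")
print(text_fp32)
print(f"Latency: {t_fp32:.3f}s")
print(f"Size (est): {format_size(size_fp32)}")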
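
`estimate_model_size_bytes` assumes a single dtype for every parameter. A sketch that reads each parameter's real element size avoids that assumption; `param_bytes` is a hypothetical helper name, and it still cannot see the packed weights of dynamically quantized layers, which are not registered as parameters.

```python
import torch

def param_bytes(model: torch.nn.Module) -> int:
    """Bytes held by parameters, using each parameter's actual dtype."""
    # parameters() yields tied weights (e.g., GPT-2's wte / lm_head) only once.
    # Buffers and packed weights of quantized modules are not included.
    return sum(p.numel() * p.element_size() for p in model.parameters())

layer = torch.nn.Linear(10, 10)   # 100 weights + 10 biases = 110 parameters
print(param_bytes(layer))         # 110 parameters * 4 bytes (FP32)
print(param_bytes(layer.half()))  # 110 parameters * 2 bytes after in-place .half()
```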

In [None]:
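# Dynamic quantization on CPU
# This is the simplest built-in (eager-mode) quantization path in PyTorch.
# It replaces every nn.Linear with a dynamically quantized INT8 version.
# Note: HF GPT-2 uses transformers' Conv1D (not nn.Linear) for its attention
# and MLP projections, so for gpt2 only lm_head is replaced.
import io

def serialized_size_bytes(m: torch.nn.Module) -> int:
    # parameters() does not include the packed INT8 weights of quantized layers,
    # so measure the size of the saved state_dict instead.
    buf = io.BytesIO()
    torch.save(m.state_dict(), buf)
    return buf.getbuffer().nbytes

# Quantized inference runs on CPU, so load a fresh FP32 copy there
cpu_device = "cpu"
model_cpu = AutoModelForCausalLM.from_pretrained(model_name)
model_cpu.eval()

# Apply dynamic quantization to Linear layers. This returns a new model, so
# model_cpu stays FP32 and serves as the CPU baseline.
# (torch.quantization is an older alias of torch.ao.quantization.)
quantized_model = torch.ao.quantization.quantize_dynamic(
    model_cpu,
    {torch.nn.Linear},
    dtype=torch.qint8,
)

# Time FP32 and INT8 on the same device so the comparison is fair
_, t_fp32_cpu = timed_generate(model_cpu, tokenizer, prompt, max_new_tokens, cpu_device)
text_int8, t_int8 = timed_generate(quantized_model, tokenizer, prompt, max_new_tokens, cpu_device)

print("--- INT8 (dynamic, CPU) ---")
print(text_int8)
print(f"Latency: {t_int8:.3f}s (FP32 on CPU: {t_fp32_cpu:.3f}s)")
print(f"Size on disk: {format_size(serialized_size_bytes(quantized_model))} "
      f"(FP32: {format_size(serialized_size_bytes(model_cpu))})")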

## Summary: What changed?

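Dynamic quantization changes only the layer types you pass in (here **`nn.Linear`**), and PyTorch's eager-mode quantized kernels run only on **CPU**. This is why it is so reliable for a teaching demo: it uses mature CPU kernels that ship with PyTorch.

One caveat for this notebook: Hugging Face's GPT-2 builds its attention and MLP projections from transformers' `Conv1D` module rather than `nn.Linear`, so for `gpt2` only `lm_head` is quantized. Because `lm_head` shares its weight with the token embedding, the extra INT8 copy can even make the saved model slightly larger. Expect the size figures below only for models built from `nn.Linear`.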

Typical results (your numbers will vary):
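- **Model size**: weights of quantized layers shrink roughly 4x (FP32 -> INT8 is 4 bytes -> 1 byte per weight, plus a small scale/zero-point overhead). The whole model shrinks less, because embeddings, LayerNorms, and unquantized layers stay FP32.
- **Latency** can improve on CPU due to cheaper integer math and less memory traffic.
- **Output quality** usually stays close to FP32 for short prompts, but small differences can compound over longer generations.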
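
To put a number on "stays close", compare FP32 and INT8 outputs on the same input. This sketch uses a small random toy network so it runs in seconds; with the notebook's models you could compare `model_cpu(**enc).logits` and `quantized_model(**enc).logits` the same way, where `enc = tokenizer(prompt, return_tensors="pt")`. `max_abs_diff` is a hypothetical helper.

```python
import torch

def max_abs_diff(a: torch.Tensor, b: torch.Tensor) -> float:
    """Largest element-wise difference between two outputs."""
    return (a - b).abs().max().item()

torch.manual_seed(0)
toy = torch.nn.Sequential(
    torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 10)
).eval()
toy_int8 = torch.ao.quantization.quantize_dynamic(toy, {torch.nn.Linear}, dtype=torch.qint8)

x = torch.randn(8, 64)
with torch.no_grad():
    out_fp32, out_int8 = toy(x), toy_int8(x)
diff = max_abs_diff(out_fp32, out_int8)
print(f"max |FP32 - INT8| = {diff:.4f} (largest FP32 output: {out_fp32.abs().max().item():.4f})")
```

For a language model, a small logit difference only matters when it flips which token is chosen; once one token differs, the rest of the generation can diverge.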

### Why CPU quantization is easier in vanilla PyTorch
- PyTorch ships optimized CPU INT8 kernels (e.g., FBGEMM on x86, QNNPACK on ARM).
- GPU quantization needs specialized libraries and kernels.
- GPU 4-bit often relies on external packages (e.g., bitsandbytes).
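
To check which quantized CPU backends your PyTorch build includes, and which one quantized ops will use, you can query `torch.backends.quantized`:

```python
import torch

# Quantized CPU backends compiled into this PyTorch build, and the one in use.
engines = torch.backends.quantized.supported_engines
print("Supported quantized engines:", engines)
print("Active engine:", torch.backends.quantized.engine)
```

On ARM machines you may need to set `torch.backends.quantized.engine = "qnnpack"` before quantizing, provided it appears in the supported list.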


## Optional: GPU 4-bit notes (bitsandbytes)

If you want 4-bit quantization on GPU, you typically use:
- **bitsandbytes** with Hugging Face Transformers
- `BitsAndBytesConfig(load_in_4bit=True, ...)`, passed to `from_pretrained` as `quantization_config`

This is powerful but depends on your CUDA version and on bitsandbytes support for your GPU.
It is optional because it adds install complexity and can be fragile across systems.
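
A minimal sketch of the 4-bit path, assuming recent versions of `transformers`, `bitsandbytes`, and `accelerate` (needed for `device_map="auto"`) plus a CUDA GPU. The function name `load_4bit` is hypothetical; `BitsAndBytesConfig` and its `bnb_4bit_*` options are Transformers APIs, but check the documentation for your installed version.

```python
import torch

def load_4bit(model_name="gpt2"):
    """Load a causal LM with 4-bit weights via bitsandbytes (CUDA only)."""
    # Imported lazily so this cell can be defined without the optional packages.
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    if not torch.cuda.is_available():
        raise RuntimeError("bitsandbytes 4-bit loading needs a CUDA GPU")

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",             # 4-bit data type: "nf4" or "fp4"
        bnb_4bit_compute_dtype=torch.float16,  # weights are dequantized to this for matmuls
    )
    return AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        device_map="auto",  # requires the accelerate package
    )
```

On a supported GPU, `model_4bit = load_4bit("gpt2")` followed by `format_size(model_4bit.get_memory_footprint())` gives a quick size comparison against the FP16/FP32 numbers above.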


## Scaling notes (production)

Real systems often go beyond simple post-training quantization:
- **Quantization-aware training (QAT)** trains the model to tolerate low precision.
- **Specialized kernels** (e.g., fused attention / matmul) are usually needed to turn lower precision into real speedups.
- **Serving stacks** (like TensorRT, ONNX Runtime, or custom CUDA kernels) often give the biggest gains.

The key idea: quantization alone helps, but performance depends heavily on the runtime.
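
The QAT bullet above usually rests on *fake quantization*: during training, values are rounded to the low-precision grid in the forward pass, while the backward pass treats the rounding as the identity (the straight-through estimator). A minimal sketch, assuming symmetric per-tensor quantization with a scale taken from the current maximum; `fake_quant_ste` is a hypothetical name, and PyTorch's own QAT tooling (e.g., `torch.ao.quantization.prepare_qat`) inserts observers and fake-quantize modules for you.

```python
import torch

def fake_quant_ste(x: torch.Tensor, num_bits: int = 8) -> torch.Tensor:
    """Symmetric per-tensor fake quantization with a straight-through estimator."""
    qmax = 2 ** (num_bits - 1) - 1
    scale = x.detach().abs().max().clamp(min=1e-8) / qmax  # scale is not trained here
    x_q = torch.clamp(torch.round(x / scale), -qmax, qmax) * scale
    # Forward returns x_q; backward sees only `x`, so gradients flow
    # as if no rounding had happened.
    return x + (x_q - x).detach()

w = torch.tensor([0.1, -0.5, 1.0], requires_grad=True)
y = fake_quant_ste(w, num_bits=4)  # only 15 levels, so the rounding is visible
y.sum().backward()
print(y.detach(), w.grad)
```

Because the loss is computed on the rounded values, training nudges the weights toward settings that still work after real quantization.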


## Exercises

1) Try a different prompt and compare outputs between FP32 and INT8.
2) Increase `max_new_tokens` and check whether the FP32 and INT8 outputs drift further apart.
3) Measure latency over 5 runs and report the mean (consider discarding a first warm-up run).
4) Try a smaller model (e.g., `distilgpt2`) and compare speed/quality.
5) Research one production quantization method and summarize it.
