# Neural Network Quantization - Post-Training Quantization (PTQ) using PyTorch 2 Export Quantization and the XNNPACK Quantizer

## Basic Setup

In [1]:
import os
import time
import warnings
from packaging import version
import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets, transforms, models
from torch.quantization import quantize_dynamic
from torch.ao.quantization import get_default_qconfig, QConfigMapping
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
from torch.utils.data import DataLoader, Subset

# Silence warnings that are irrelevant to this notebook:
#   - https://github.com/pytorch/pytorch/issues/149829 (Intel-GPU TF32 notice)
#   - https://github.com/tensorflow/tensorflow/issues/77293 (erase_node on an already erased node)
_IGNORED_WARNING_PATTERNS = (
    ".*TF32 acceleration on top of oneDNN is available for Intel GPUs. The current Torch version does not have Intel GPU Support.*",
    ".*erase_node(.*) on an already erased node.*",
)
for _pattern in _IGNORED_WARNING_PATTERNS:
    warnings.filterwarnings("ignore", message=_pattern)

print(f"PyTorch Version: {torch.__version__}")

# prefer the GPU when one is visible, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Device used: {device.type}")

# set to True to skip the slow latency measurements on CPU
skip_cpu = False
print(f"Should skip CPU evaluations: {skip_cpu}")

PyTorch Version: 2.8.0.dev20250319+cu128
Device used: cuda
Should skip CPU evaluations: False


## Get CIFAR-10 train and test sets

In [2]:
# Images are converted to tensors and each channel is mapped from [0, 1] to [-1, 1].
transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Each split is constructed once and shared by the loader(s) that read it.
train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

test_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)
test_loader = DataLoader(
    test_dataset,
    batch_size=128,
    shuffle=False,
    num_workers=2,
    # NOTE(review): presumably set so every evaluation batch has exactly 128 images,
    # matching the fixed batch size of the example inputs used to export the model
    # for quantization — confirm. A final partial batch is therefore never evaluated.
    drop_last=True,
)

# Small fixed calibration set: the first 256 training images (two batches of 128),
# used to collect activation statistics for post-training quantization.
calibration_dataset = Subset(train_dataset, range(256))
calibration_loader = DataLoader(calibration_dataset, batch_size=128, shuffle=False)

## Adjust ResNet18 network for CIFAR-10 dataset

In [3]:
def get_resnet18_for_cifar10():
    """
    Build a randomly initialized ResNet-18 adapted to CIFAR-10 and move it to `device`.

    Differences from the stock torchvision ResNet-18:
    - the stem convolution is a 3x3, stride-1, padding-1 conv without bias
    - the stem max-pool is replaced by an identity
    - the classifier head has 10 outputs
    """
    net = models.resnet18(weights=None, num_classes=10)
    stem_conv = nn.Conv2d(
        in_channels=3,
        out_channels=64,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
    )
    net.conv1 = stem_conv
    net.maxpool = nn.Identity()
    net = net.to(device)
    return net

# the (not yet trained) full-precision model that will later be quantized
model_to_quantize = get_resnet18_for_cifar10()

## Define Train and Evaluate functions

In [4]:
def _set_train_mode(model):
    """Put `model` in training mode, also for exported graph modules."""
    try:
        model.train()
    except NotImplementedError:
        # exported (torch.export) graph modules may reject .train(); use the
        # quantization helper that switches their train-specific ops instead
        torch.ao.quantization.move_exported_model_to_train(model)


def train(model, loader, epochs, lr=0.01, save_path="model.pth", silent=False):
    """
    Train `model` with SGD (momentum 0.9) and cross-entropy loss.

    If `save_path` names an existing file, its state dict is loaded into `model`
    and no training is done. Otherwise the model is trained from its current
    state and, when `save_path` is truthy, the resulting state dict is saved there.

    Args:
        model: an ``nn.Module`` or an exported graph module.
        loader: iterable of ``(inputs, targets)`` batches.
        epochs: number of passes over `loader`.
        lr: SGD learning rate.
        save_path: checkpoint path; ``None`` or ``""`` disables loading and saving.
        silent: if True, suppress progress output and per-epoch evaluation.
    """
    _set_train_mode(model)

    # Check `save_path` first: os.path.exists(None) raises TypeError, and the save
    # step below already treats a falsy save_path as "no checkpoint".
    if save_path and os.path.exists(save_path):
        if not silent:
            print(f"Model already trained. Loading from {save_path}")
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine
        model.load_state_dict(torch.load(save_path, map_location=device))
        return

    # no saved model found. training from given model state
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    for epoch in range(epochs):
        loss = None  # stays None if the loader yields no batches
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()
        if not silent:
            # reports the loss of the epoch's last batch
            if loss is not None:
                print(f"Epoch {epoch+1}: loss={loss.item():.4f}")
            evaluate(model, f"Epoch {epoch+1}")
            # evaluate() switches the model to eval mode; switch it back
            _set_train_mode(model)

    if save_path:
        torch.save(model.state_dict(), save_path)
        if not silent:
            print(f"Training complete. Model saved to {save_path}")

def evaluate(model, tag, loader=None):
    """
    Compute top-1 accuracy of `model`, print it, and return it.

    Args:
        model: an ``nn.Module`` or exported graph module; it is switched to
            eval mode and moved to `device`.
        tag: label used in the printed line.
        loader: iterable of ``(inputs, targets)`` batches; defaults to the
            module-level `test_loader`.

    Returns:
        Accuracy as a fraction of correctly classified samples.
    """
    if loader is None:
        loader = test_loader

    try:
        model.eval()
    except NotImplementedError:
        # exported graph modules may reject .eval(); use the quantization helper
        model = torch.ao.quantization.move_exported_model_to_eval(model)

    model.to(device)
    correct = total = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            preds = model(x).argmax(1)
            correct += (preds == y).sum().item()
            total += y.size(0)
    accuracy = correct / total
    print(f"Accuracy ({tag}): {accuracy*100:.2f}%")
    # returned so callers such as `accuracy_full = evaluate(...)` receive a value
    return accuracy

## Define helper functions to measure latency and model size

In [5]:
class Timer:
    """
    A simple timer utility for measuring elapsed time in milliseconds.

    - With CUDA timing, a pair of torch.cuda.Event objects is recorded and the
      device is synchronized before the elapsed time is read.
    - Otherwise wall-clock time is measured with time.perf_counter(), a
      monotonic high-resolution clock.

    Args:
        use_cuda: force CUDA-event timing on (True) or off (False). ``None``
            (default) uses CUDA events whenever CUDA is available.

    Methods:
        start(): Start the timer.
        stop(): Stop the timer and return the elapsed time in milliseconds.
    """

    def __init__(self, use_cuda=None):
        if use_cuda is None:
            use_cuda = torch.cuda.is_available()
        self.use_cuda = use_cuda
        if self.use_cuda:
            self.starter = torch.cuda.Event(enable_timing=True)
            self.ender = torch.cuda.Event(enable_timing=True)

    def start(self):
        if self.use_cuda:
            self.starter.record()
        else:
            self.start_time = time.perf_counter()

    def stop(self):
        if self.use_cuda:
            self.ender.record()
            torch.cuda.synchronize()
            return self.starter.elapsed_time(self.ender)  # ms
        return (time.perf_counter() - self.start_time) * 1000  # ms


def estimate_latency(model, example_inputs, repetitions=50, warmup=5):
    """
    Return the mean and standard deviation of inference latency (ms).

    Args:
        model: callable invoked as ``model(example_inputs)``.
        example_inputs: input passed to `model`. If it is a tensor, its device
            selects the timing method (CUDA events for CUDA tensors, wall clock
            otherwise); for other inputs the Timer default applies.
        repetitions: number of timed runs.
        warmup: number of untimed runs executed before timing.

    Returns:
        ``(mean_ms, std_ms)`` over the timed runs.
    """
    # CUDA events timestamp work on the GPU stream; for a CPU workload the wall
    # clock is the direct measure, even on a machine that has a GPU.
    if isinstance(example_inputs, torch.Tensor):
        use_cuda = example_inputs.is_cuda
    else:
        use_cuda = None
    timer = Timer(use_cuda=use_cuda)
    timings = np.zeros(repetitions)

    # warm-up runs also execute without autograd so they don't build graphs
    with torch.no_grad():
        for _ in range(warmup):
            model(example_inputs)

        for rep in range(repetitions):
            timer.start()
            model(example_inputs)
            timings[rep] = timer.stop()

    return np.mean(timings), np.std(timings)

def estimate_latency_full(model, tag, skip_cpu, input_shape=(128, 3, 32, 32)):
    """
    Print the inference latency of `model` on CPU (unless `skip_cpu`) and on GPU.

    The GPU measurement runs only when CUDA is available. The model is moved in
    place (``.cpu()`` / ``.cuda()``) and is left on the last device measured.

    Args:
        model: the model to benchmark.
        tag: label used in the printed lines.
        skip_cpu: if True, skip the (slow) CPU measurement.
        input_shape: shape of the random input batch; the default is one batch
            of 128 RGB 32x32 images.
    """
    # estimate latency on CPU
    if not skip_cpu:
        example_input = torch.rand(input_shape, device="cpu")
        model.cpu()
        latency_mu, latency_std = estimate_latency(model, example_input)
        print(f"Latency ({tag}, on CPU): {latency_mu:.2f} ± {latency_std:.2f} ms")

    # estimate latency on GPU (the notebook also supports CPU-only machines)
    if not torch.cuda.is_available():
        print(f"Latency ({tag}, on GPU): skipped, CUDA not available")
        return

    example_input = torch.rand(input_shape, device="cuda")
    model.cuda()
    latency_mu, latency_std = estimate_latency(model, example_input)
    print(f"Latency ({tag}, on GPU): {latency_mu:.2f} ± {latency_std:.2f} ms")

def print_size_of_model(model, tag=""):
    """
    Print and return the size in MB (1 MB = 1e6 bytes) of `model`'s serialized state dict.

    The state dict is saved with torch.save into a private temporary directory
    that is always removed afterwards. No file in the working directory is
    overwritten or left behind, even if saving fails.
    """
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "temp.p")
        torch.save(model.state_dict(), path)
        size_mb_full = os.path.getsize(path) / 1e6
    print(f"Size ({tag}): {size_mb_full:.2f} MB")
    return size_mb_full

## Train full model

In [6]:
# Train the full-precision model (or load the checkpoint if it already exists).
train(
    model_to_quantize,
    train_loader,
    epochs=15,
    save_path="full_model.pth",
)

Epoch 1: loss=0.8905
Accuracy (Epoch 1): 58.78%
Epoch 2: loss=0.7371
Accuracy (Epoch 2): 67.51%
Epoch 3: loss=0.7273
Accuracy (Epoch 3): 74.37%
Epoch 4: loss=0.5133
Accuracy (Epoch 4): 74.88%
Epoch 5: loss=0.2877
Accuracy (Epoch 5): 74.88%
Epoch 6: loss=0.3319
Accuracy (Epoch 6): 67.36%
Epoch 7: loss=0.2166
Accuracy (Epoch 7): 73.47%
Epoch 8: loss=0.1160
Accuracy (Epoch 8): 76.27%
Epoch 9: loss=0.1579
Accuracy (Epoch 9): 74.73%
Epoch 10: loss=0.0222
Accuracy (Epoch 10): 77.65%
Epoch 11: loss=0.0090
Accuracy (Epoch 11): 78.60%
Epoch 12: loss=0.0024
Accuracy (Epoch 12): 79.19%
Epoch 13: loss=0.0005
Accuracy (Epoch 13): 80.18%
Epoch 14: loss=0.0002
Accuracy (Epoch 14): 80.48%
Epoch 15: loss=0.0005
Accuracy (Epoch 15): 80.53%
Training complete. Model saved to full_model.pth


## Evaluate full model

In [7]:
# Baseline (full-precision) measurements: serialized size, test accuracy, latency.
print_size_of_model(model_to_quantize, tag="full")
accuracy_full = evaluate(model_to_quantize, tag="full")
estimate_latency_full(model_to_quantize, tag="full", skip_cpu=skip_cpu)

Size (full): 44.77 MB
Accuracy (full): 80.53%
Latency (full, on CPU): 750.06 ± 76.23 ms
Latency (full, on GPU): 16.76 ± 0.32 ms


## Post Training Quantization (PTQ) with XNNPACK Quantizer

**Warning:** The following quantization code is not for the faint of heart. Since this is PyTorch’s third iteration of a quantization API, the only thing we can say with confidence is that it will change again. That said, it’s been tested and confirmed to work on the following PyTorch versions:
- 2.8.0.dev20250319+cu128
- 2.6.0+cu124
- 2.4.1+cu124

In [8]:
from torch.ao.quantization.quantize_pt2e import (
  prepare_pt2e,
  convert_pt2e,
)

from torch.ao.quantization.quantizer.xnnpack_quantizer import (
  XNNPACKQuantizer,
  get_symmetric_quantization_config,
)

# batch of 128 images, each with 3 color channels and 32x32 resolution (CIFAR-10)
example_inputs = (torch.rand(128, 3, 32, 32).to(device),)


def _export_model(model, inputs):
    """Export `model` to a graph module with the export API of this PyTorch version."""
    # Compare the release tuple rather than the full version: packaging orders
    # pre-release/dev builds such as "2.5.0a0+git..." or "2.5.0.dev..." *before*
    # "2.5", which would wrongly send them down the 2.4 path.
    if version.parse(torch.__version__).release >= (2, 5):  # for pytorch 2.5+
        return torch.export.export_for_training(model, inputs).module()
    # for pytorch 2.4
    from torch._export import capture_pre_autograd_graph
    return capture_pre_autograd_graph(model, inputs)


# export the model to a standardized format before quantization
exported_model = _export_model(model_to_quantize, example_inputs)

# quantization setup for XNNPACK quantizer
quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config())

# preparing for PTQ by folding batch-norm into preceding conv2d operators, and inserting observers in appropriate places
prepared_model = prepare_pt2e(exported_model, quantizer)


def calibrate(model, data_loader):
    """
    Run `model` over `data_loader` so its observers record activation statistics.

    The exported model is switched to eval mode first; labels are ignored and no
    gradients are computed.
    """
    torch.ao.quantization.move_exported_model_to_eval(model)
    with torch.no_grad():
        for images, _ in data_loader:
            model(images.to(device))


# run inference on calibration data to collect activation stats needed for activation quantization
calibrate(prepared_model, calibration_loader)

# converts calibrated model to a quantized model
quantized_model = convert_pt2e(prepared_model)

# export again to remove unused weights after quantization
quantized_model = _export_model(quantized_model, example_inputs)



## Evaluate quantized model

In [9]:
# get quantized model size
print_size_of_model(quantized_model, "quantized")

# evaluate quantized accuracy
accuracy_full = evaluate(quantized_model, 'quantized')

# estimate quantized model latency
estimate_latency_full(quantized_model, 'quantized', skip_cpu)

Size (quantized): 11.20 MB
Accuracy (quantized): 80.41%
Latency (quantized, on CPU): 2310.11 ± 116.63 ms
Latency (quantized, on GPU): 43.51 ± 0.23 ms


## References
1. PyTorch Documentation: [Quantization](https://pytorch.org/docs/stable/quantization.html)
2. PyTorch Documentation: [PyTorch 2 Export Post Training Quantization](https://pytorch.org/tutorials/prototype/pt2e_quant_ptq.html)
3. PyTorch Documentation: [PyTorch 2 Export Quantization-Aware Training (QAT)](https://pytorch.org/tutorials/prototype/pt2e_quant_qat.html)
4. PyTorch Documentation: [PyTorch 2 Export Quantization with X86 Backend through Inductor](https://pytorch.org/tutorials/prototype/pt2e_quant_x86_inductor.html)
5. PyTorch Dev Discussions: [TorchInductor Update 6: CPU backend performance update and new features in PyTorch 2.1](https://dev-discuss.pytorch.org/t/torchinductor-update-6-cpu-backend-performance-update-and-new-features-in-pytorch-2-1/1514)
