In [1]:
# When True, the slow CPU latency measurement in the "Evaluate full model" cell is skipped.
skip_long_operations = True

# Neural Networks Quantization - Quantization Aware Training (QAT) using PyTorch 2 Export Quantization and X86 Backend

**Warning:** The following quantization code is not for the faint of heart. Since this is PyTorch’s third iteration of a quantization API, the only thing we can say with confidence is that it will change again. That said, it’s been tested and confirmed to work on the following PyTorch versions:
- 2.8.0.dev20250319+cu128
- 2.6.0+cu124


## Basic Setup

In [2]:
import os
import time
import warnings
import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets, transforms, models
from torch.quantization import quantize_dynamic
from torch.ao.quantization import get_default_qconfig, QConfigMapping
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
from torch.utils.data import DataLoader, Subset

# Silence a warning that is irrelevant for this notebook (the `message` argument is a
# regular expression), see: https://github.com/pytorch/pytorch/issues/149829
warnings.filterwarnings("ignore", message=".*TF32 acceleration on top of oneDNN is available for Intel GPUs. The current Torch version does not have Intel GPU Support.*")

# Silence another irrelevant warning, see: https://github.com/tensorflow/tensorflow/issues/77293
warnings.filterwarnings("ignore", message=".*erase_node(.*) on an already erased node.*")

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"{device=}")

device=device(type='cuda')


## Torch Version

In [3]:
import torch
# Record the PyTorch version this notebook was executed with.
print(torch.__version__)

2.6.0+cu124


## Get CIFAR-10 train and test sets

In [4]:
# Resize to 32 px, convert to a tensor, then normalize with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Build each dataset once and share it between the loader and any subsets,
# instead of constructing (and verifying the download of) CIFAR-10 twice per split.
train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

test_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)
test_loader = DataLoader(
    test_dataset,
    batch_size=128,
    shuffle=False,
    num_workers=2,
    # Drops the final partial batch so every test batch has exactly 128 samples, the same
    # batch size as the example input used for export later in this notebook. Accuracy is
    # therefore computed without the samples of that last partial batch.
    drop_last=True,
)

# Small, fixed subset of the training data (first 256 samples) served in order.
calibration_dataset = Subset(train_dataset, range(256))
calibration_loader = DataLoader(calibration_dataset, batch_size=128, shuffle=False)

## Adjust ResNet18 network for CIFAR-10 dataset

In [5]:
def get_resnet18_for_cifar10():
    """Build an untrained ResNet-18 with 10 output classes and a modified stem.

    The first convolution is replaced by a 3x3, stride-1, padding-1 convolution
    without bias, and the max-pool layer is replaced by an identity
    (presumably to suit the small 32x32 inputs used here).

    Returns:
        The model, moved to the notebook-level `device`.
    """
    net = models.resnet18(weights=None, num_classes=10)
    net.conv1 = nn.Conv2d(
        in_channels=3,
        out_channels=64,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
    )
    net.maxpool = nn.Identity()
    net = net.to(device)
    return net

model_to_quantize = get_resnet18_for_cifar10()

## Define Train and Evaluate functions

In [6]:
def train(model, loader, epochs, lr=0.01, save_path="model.pth", silent=False):
    """Train `model` with SGD, or restore it from a checkpoint if one already exists.

    If `save_path` is set and that file exists, the state dict stored there is loaded
    into `model` and no training is performed. Otherwise the model is trained for
    `epochs` passes over `loader` using cross-entropy loss and SGD with momentum 0.9,
    and the resulting state dict is written to `save_path` when it is set.

    This function does not change the model's train/eval mode; callers set it.

    Args:
        model: Module to train (updated in place).
        loader: Iterable yielding `(inputs, targets)` batches.
        epochs: Number of passes over `loader`.
        lr: SGD learning rate.
        save_path: Checkpoint file to load from / save to. A falsy value (e.g. None)
            disables both loading and saving.
        silent: When True, suppress progress messages.

    Returns:
        None.
    """
    # Guard on `save_path` first: os.path.exists(None) raises TypeError.
    if save_path and os.path.exists(save_path):
        if not silent:
            print(f"Model already trained. Loading from {save_path}")
        # Deserialize onto the CPU so a checkpoint written on a GPU machine can be
        # restored on a CPU-only one; load_state_dict then copies the values into the
        # model's existing parameters wherever they live.
        model.load_state_dict(torch.load(save_path, map_location="cpu"))
        return

    # no saved model found. training from given model state

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    for epoch in range(epochs):
        loss = None  # stays None if the loader yields no batches
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()
        # The reported value is the loss of the last batch of the epoch.
        if not silent and loss is not None:
            print(f"Epoch {epoch+1}: loss={loss.item():.4f}")

    if save_path:
        torch.save(model.state_dict(), save_path)
        if not silent:
            print(f"Training complete. Model saved to {save_path}")

def evaluate(model, device_str=None, loader=None):
    """Compute the top-1 accuracy of `model` over a data loader.

    This function does not change the model's train/eval mode; callers set it.

    Args:
        model: Callable mapping an input batch to per-class scores; the predicted
            class is the argmax over dimension 1.
        device_str: Optional device string (e.g. "cuda" or "cpu"). When given, the
            model and every batch are moved to that device; when falsy, nothing is
            moved.
        loader: Optional iterable of `(inputs, targets)` batches. Defaults to the
            notebook-level `test_loader`.

    Returns:
        Fraction of correctly classified samples, as a float in [0, 1].

    Raises:
        ValueError: If the loader yields no samples.
    """
    if loader is None:
        loader = test_loader
    if device_str:
        device = torch.device(device_str)
        model.to(device)
    correct = total = 0
    with torch.no_grad():
        for x, y in loader:
            if device_str:
                x, y = x.to(device), y.to(device)
            preds = model(x).argmax(1)
            correct += (preds == y).sum().item()
            total += y.size(0)
    if total == 0:
        raise ValueError("evaluate() received no samples; accuracy is undefined")
    return correct / total

## Define helper functions to measure latency

In [8]:
class Timer:
    """
    A simple timer utility for measuring elapsed time in milliseconds.

    Supports both GPU and CPU timing:
    - With CUDA timing enabled, uses torch.cuda.Event and synchronizes on stop().
    - Otherwise, uses the monotonic, high-resolution time.perf_counter() clock
      (time.time() can jump when the system clock is adjusted).

    Args:
        use_cuda: None (default) enables CUDA timing whenever CUDA is available.
            Pass False to force wall-clock timing, e.g. when timing a CPU model on a
            machine that also has a GPU. True is honoured only if CUDA is available.

    Methods:
        start(): Start the timer.
        stop(): Stop the timer and return the elapsed time in milliseconds.
    """

    def __init__(self, use_cuda=None):
        cuda_available = torch.cuda.is_available()
        if use_cuda is None:
            self.use_cuda = cuda_available
        else:
            self.use_cuda = bool(use_cuda) and cuda_available
        self.start_time = None  # set by start() on the CPU path
        if self.use_cuda:
            self.starter = torch.cuda.Event(enable_timing=True)
            self.ender = torch.cuda.Event(enable_timing=True)

    def start(self):
        """Start (or restart) the timer."""
        if self.use_cuda:
            self.starter.record()
        else:
            self.start_time = time.perf_counter()

    def stop(self):
        """Stop the timer and return the elapsed time since start() in milliseconds.

        Raises:
            RuntimeError: On the CPU path, if start() has not been called.
        """
        if self.use_cuda:
            self.ender.record()
            # Wait for the recorded events to complete before reading the elapsed time.
            torch.cuda.synchronize()
            return self.starter.elapsed_time(self.ender)  # ms
        if self.start_time is None:
            raise RuntimeError("Timer.stop() called before Timer.start()")
        return (time.perf_counter() - self.start_time) * 1000  # ms

def estimate_latency(model, example_inputs, repetitions=50):
    """Estimate the forward-pass latency of `model` on `example_inputs`.

    Runs 5 untimed warm-up passes followed by `repetitions` timed passes, all with
    gradient tracking disabled.

    Args:
        model: Callable invoked as `model(example_inputs)`.
        example_inputs: Input passed unchanged to every forward pass.
        repetitions: Number of timed forward passes.

    Returns:
        Tuple `(mean, std)` of the measured per-pass latency in milliseconds.
    """
    timer = Timer()
    timings = np.zeros(repetitions)

    with torch.no_grad():
        # warm-up (inside no_grad so it matches the timed passes and does not
        # build autograd state that is never used)
        for _ in range(5):
            model(example_inputs)

        for rep in range(repetitions):
            timer.start()
            model(example_inputs)
            timings[rep] = timer.stop()

    return np.mean(timings), np.std(timings)

In [19]:
def print_size_of_model(model):
    """Print the size, in MB (1e6 bytes), of `model`'s serialized state dict.

    The state dict is written with torch.save to a file inside a private temporary
    directory, so a pre-existing "temp.p" in the working directory is never
    overwritten or deleted, and the temporary file is cleaned up even if saving or
    measuring fails.

    Args:
        model: Object exposing `state_dict()`.

    Returns:
        None. The size is printed.
    """
    import tempfile  # local import keeps this cell self-contained

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Keep the original file name so the serialized archive is unchanged.
        tmp_path = os.path.join(tmp_dir, "temp.p")
        torch.save(model.state_dict(), tmp_path)
        print("Size (MB):", os.path.getsize(tmp_path)/1e6)

## Train full model

In [None]:
# Put the float model in training mode, then train it for 10 epochs. If
# "full_model.pth" already exists, train() loads those weights instead.
model_to_quantize.train()
train(
    model=model_to_quantize,
    loader=train_loader,
    epochs=10,
    save_path="full_model.pth",
)

Model already trained. Loading from full_model.pth


## Evaluate full model

In [11]:
# evaluate accuracy on the selected device (falls back to CPU when CUDA is unavailable,
# instead of crashing on a hard-coded 'cuda')
model_to_quantize.eval()
accuracy_full = evaluate(model_to_quantize, device.type)
print(f"Accuracy (full): {accuracy_full*100:.2f}%")

# get model size
size_mb_full = os.path.getsize("full_model.pth") / 1e6
print(f"Size (full): {size_mb_full:.2f} MB")

# estimate latency on GPU (only when one is available)
if device.type == "cuda":
    example_input = torch.rand(128, 3, 32, 32).cuda()
    model_to_quantize.cuda()
    latency_mu_full_gpu, latency_std_full_gpu = estimate_latency(model_to_quantize, example_input)
    print(f"Latency (full, on gpu): {latency_mu_full_gpu:.2f} ± {latency_std_full_gpu:.2f} ms")
else:
    print("Skipped GPU evaluation (CUDA not available)")

# estimate latency on CPU
if not skip_long_operations:
    example_input = torch.rand(128, 3, 32, 32).cpu()
    model_to_quantize.cpu()
    latency_mu_full_cpu, latency_std_full_cpu = estimate_latency(model_to_quantize, example_input)
    print(f"Latency (full, on cpu): {latency_mu_full_cpu:.2f} ± {latency_std_full_cpu:.2f} ms")
else:
    print("Skipped CPU evaluation")

Accuracy (full): 79.31%
Size (full): 44.78 MB
Latency (full, on gpu): 16.82 ± 0.03 ms
Skipped CPU evaluation


## Quantization Aware Training (QAT)

In [None]:
# Switch the float model to eval mode and place it on the target device.
model_to_quantize = model_to_quantize.eval().to(device)

# program capture: export the model using one example batch of 128 random 3x32x32 inputs
example_inputs = (torch.rand(128, 3, 32, 32).to(device),)
exported_program = torch.export.export_for_training(model_to_quantize, example_inputs)
exported_model = exported_program.module()

# quantization
from torch.ao.quantization.quantize_pt2e import (
  prepare_qat_pt2e,
  convert_pt2e,
)



In [None]:
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
  XNNPACKQuantizer,
  get_symmetric_quantization_config,
)

# Configure an XNNPACKQuantizer with the symmetric QAT configuration applied globally.
qat_config = get_symmetric_quantization_config(is_qat=True)
quantizer = XNNPACKQuantizer()
quantizer.set_global(qat_config)

# Prepare the exported model for quantization-aware training with that quantizer.
prepared_model = prepare_qat_pt2e(exported_model, quantizer)

In [None]:
# Switch the prepared (exported) model to training behaviour, then fine-tune it for
# 3 epochs. If "qat_model.pth" already exists, train() loads those weights instead.
torch.ao.quantization.move_exported_model_to_train(prepared_model)
train(
    model=prepared_model,
    loader=train_loader,
    epochs=3,
    save_path="qat_model.pth",
)

Model already trained. Loading from qat_model.pth


In [18]:

# Convert the QAT-prepared model into its quantized form.
quantized_model = convert_pt2e(prepared_model)

# At this point we have a model with aten ops doing integer computations when possible.

In [20]:
# Baseline model size and accuracy
# NOTE(review): the recorded outputs of this notebook show a different accuracy here
# than in the earlier "Evaluate full model" cell, which suggests the float model's
# weights changed in between (presumably they are shared with the exported/prepared
# model that was fine-tuned) — confirm before treating this as the true baseline.
print("Size of baseline model")
print_size_of_model(model_to_quantize)

model_to_quantize.eval()
# Evaluate on the selected device rather than a hard-coded 'cuda', so the cell also
# runs on CPU-only machines.
accuracy_baseline = evaluate(model_to_quantize, device_str=device.type)
print("Baseline Float Model Evaluation accuracy: %2.2f"%(100*accuracy_baseline))

# Quantized model size and accuracy
print("Size of model after quantization")
# export again to remove unused weights
quantized_model = torch.export.export_for_training(quantized_model, example_inputs).module()
print_size_of_model(quantized_model)

# Exported models are switched to eval behaviour through this helper.
torch.ao.quantization.move_exported_model_to_eval(quantized_model)
accuracy_quantized = evaluate(quantized_model, device_str=device.type)
print("[before serilaization] Evaluation accuracy on test dataset: %2.2f"%(100*accuracy_quantized))

Size of baseline model
Size (MB): 44.767812
Baseline Float Model Evaluation accuracy: 80.84
Size of model after quantization
Size (MB): 11.195534
[before serilaization] Evaluation accuracy on test dataset: 80.54


## References
1. PyTorch Documentation: [Quantization](https://pytorch.org/docs/stable/quantization.html)
2. PyTorch Documentation: [PyTorch 2 Export Post Training Quantization](https://pytorch.org/tutorials/prototype/pt2e_quant_ptq.html)
3. PyTorch Documentation: [PyTorch 2 Export Quantization-Aware Training (QAT)](https://pytorch.org/tutorials/prototype/pt2e_quant_qat.html)
4. PyTorch Documentation: [PyTorch 2 Export Quantization with X86 Backend through Inductor](https://pytorch.org/tutorials/prototype/pt2e_quant_x86_inductor.html)
