In [6]:
# %load ../firstcell.py
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [1]:
# Static quantization of a model consists of the following steps:

#     Fuse modules
#     Insert Quant/DeQuant Stubs
#     Prepare the fused module (insert observers before and after layers)
#     Calibrate the prepared module (pass it representative data)
#     Convert the calibrated module (replace with quantized version)

import torch
from torch import nn
import copy

In [2]:
backend = "qnnpack"  # running on a x86 CPU. Use "qnnpack" if running on ARM.

In [3]:
model = nn.Sequential(nn.Conv2d(2, 64, 3), nn.ReLU(), nn.Conv2d(64, 128, 3), nn.ReLU())

## EAGER MODE
m = copy.deepcopy(model)
m.eval()

Sequential(
  (0): Conv2d(2, 64, kernel_size=(3, 3), stride=(1, 1))
  (1): ReLU()
  (2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
  (3): ReLU()
)

In [4]:
"""Fuse
- Inplace fusion replaces the first module in the sequence with the fused module, and the rest with identity modules
"""
torch.quantization.fuse_modules(
    m, ["0", "1"], inplace=True
)  # fuse first Conv-ReLU pair
torch.quantization.fuse_modules(
    m, ["2", "3"], inplace=True
)  # fuse second Conv-ReLU pair

"""Insert stubs"""
m = nn.Sequential(torch.quantization.QuantStub(), *m, torch.quantization.DeQuantStub())

"""Prepare"""
m.qconfig = torch.quantization.get_default_qconfig(backend)
torch.quantization.prepare(m, inplace=True)

Sequential(
  (0): QuantStub(
    (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
  )
  (1): ConvReLU2d(
    (0): Conv2d(2, 64, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
  )
  (2): Identity()
  (3): ConvReLU2d(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
  )
  (4): Identity()
  (5): DeQuantStub()
)

In [5]:
"""Calibrate
- This example uses random data for convenience. Use representative (validation) data instead.
"""
with torch.inference_mode():
    for _ in range(10):
        x = torch.rand(1, 2, 28, 28)
        m(x)

"""Convert"""
torch.quantization.convert(m, inplace=True)

"""Check"""
print(m[[1]].weight().element_size())  # 1 byte instead of 4 bytes for FP32

RuntimeError: Didn't find engine for operation quantized::conv2d_prepack NoQEngine

In [None]:
## FX GRAPH
from torch.quantization import quantize_fx

m = copy.deepcopy(model)
m.eval()
qconfig_dict = {"": torch.quantization.get_default_qconfig(backend)}
# Prepare
model_prepared = quantize_fx.prepare_fx(m, qconfig_dict)
# Calibrate - Use representative (validation) data.
with torch.inference_mode():
    for _ in range(10):
        x = torch.rand(1, 2, 28, 28)
        model_prepared(x)
# quantize
model_quantized = quantize_fx.convert_fx(model_prepared)