# PyTorch (fp32) -> ONNX (fp32) -> ONNX (QDQ/QOp)

## References
* https://pytorch.org/docs/stable/quantization.html#quantization-api-reference
* https://pytorch.org/tutorials/recipes/quantization.html
* https://tvm.apache.org/docs/how_to/compile_models/from_pytorch.html
* https://tvm.apache.org/docs/how_to/deploy_models/deploy_prequantized.html

## Setting parameters

In [1]:
# for dataset
dataset_root = "/datasets/IMAGENET"
dataset_samples = 1000
dataset_seed = 0
batch_size = 32

# for model
model_name = "resnet18"
weight_name = "ResNet18_Weights.DEFAULT"
input_name = "input"
height = 224
width = 224

# for onnx export
opset_version = 13

# for onnxruntime quantization
calib_samples = 256
quant_format = "QOperator"  #  ["Operator", "ODQ"]
activation_type = "QUInt8"  #  ["QInt8", "QUInt8"]
activation_symmetric = False  #  [True, False]
weight_type = "QInt8"  #  ["QInt8", "QUInt8"]
weight_symmetric = True  #  [True, False]
per_channel = False  #  [True, False]
calibrate_method_str = "MinMax"  #  ["MinMax", "Entropy", "Percentile"]

## 0. Configurations

### 0.1. Imports

In [2]:
import os
import pathlib

import onnx
import onnxruntime
import torch
import torchvision
from onnxruntime.quantization.registry import QLinearOpsRegistry

import onnx_util
import torch_util

  from .autonotebook import tqdm as notebook_tqdm


### 0.2. Setting device

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"device: {device}")

device: cuda:0


### 0.3. Make artifact directory

In [4]:
artifact_dir_name = "-".join(
    [
        f"model_name={model_name}",
        f"quant_format={quant_format}",
        f"per_channel={per_channel}",
        f"activation_type={activation_type}",
        f"activation_symmetric={activation_symmetric}"
        f"weight_type={weight_type}",
        f"weight_symmetric={weight_symmetric}"
        f"calib={calibrate_method_str}",
    ]
)
artifact_dir = pathlib.Path().cwd().resolve() / "artifacts" / artifact_dir_name
os.makedirs(artifact_dir, exist_ok=True)

## 1. Preparing

### 1.1. Prepare float32 model

In [5]:
weights = torchvision.models.get_weight(weight_name)
model = getattr(torchvision.models.quantization, model_name)(weights=weights).eval()
preprocess = weights.transforms()

print(model)

QuantizableResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU()
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): QuantizableBasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (add_relu): FloatFunctional(
        (activation_post_process): Identity()
      )
    )
    (1): QuantizableBasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, e

### 1.2. Prepare dataset

In [6]:
dataset = torchvision.datasets.ImageNet(
    dataset_root,
    split="val",
    transform=preprocess,
)
dataset = torch_util.subset_dataset(
    dataset=dataset,
    num_samples=dataset_samples,
    seed=dataset_seed,
)

### 1.3. Test float32 model

In [7]:
torch_util.test_torch(model, dataset, batch_size)

[Test 1000 img]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:20<00:00,  1.56it/s, Top1=tensor(69.9219), Top5=tensor(89.7461)]


## 2. Export to ONNX

### 2.1 Export

In [None]:
x = torch.randn(1, 3, height, width, requires_grad=True)

onnx_model_file_path = str(artifact_dir / "fp32.onnx")

# Export the model
torch.onnx.export(
    model,
    x,
    onnx_model_file_path,
    export_params=True,
    opset_version=opset_version,
    do_constant_folding=True,
    input_names=["input"],
    output_names=["output"],
)

#### 2.2 Test ONNX Model

In [None]:
session_fp32 = onnxruntime.InferenceSession(onnx_model_file_path)
onnx_util.test_onnx(session_fp32, dataset)

## 3 Static Quantization on ONNXRuntime

#### 3.1 Define DataReader for calibration

In [None]:
class ImageNetDataReader(onnxruntime.quantization.CalibrationDataReader):
    def __init__(
        self,
        dataset: torchvision.datasets.VisionDataset,
        batch_size: int = 1,
    ):
        self.dataloader = torch.utils.data.DataLoader(dataset, batch_size=1)
        self.dataiter = self.dataloader.__iter__()

    def get_next(self):
        out = next(self.dataiter, None)
        if out is None:
            return None

        image, _ = out
        return {"input": image.numpy()}

#### 3.2 Static Quantization

In [None]:
dataset_calib = torch_util.subset_dataset(dataset, calib_samples, seed=dataset_seed)
dr = ImageNetDataReader(dataset)

# Exclude gemm from quantization because it cannot be parse by TVM
op_types_to_quantize = list(QLinearOpsRegistry.keys())
op_types_to_quantize.remove("Gemm")

calibrate_method = onnx_util.calibration_method_from_str(calibrate_method_str)

quant_onnx_model_file_path = str(artifact_dir / "quant.onnx")

onnxruntime.quantization.quantize_static(
    model_input=onnx_model_file_path,
    model_output=quant_onnx_model_file_path,
    calibration_data_reader=dr,
    quant_format=onnxruntime.quantization.QuantFormat.from_string(quant_format),
    op_types_to_quantize=op_types_to_quantize,
    per_channel=per_channel,
    activation_type=onnxruntime.quantization.QuantType.from_string(activation_type),
    weight_type=onnxruntime.quantization.QuantType.from_string(weight_type),
    calibrate_method=calibrate_method,
    extra_options={
        "ActivationSymmetric": activation_symmetric,
        "WeightSymmetric": weight_symmetric,
    },
)

#### 3.3 Test Quantized ONNX Model

In [12]:
session_quant = onnxruntime.InferenceSession(quant_onnx_model_file_path)
onnx_util.test_onnx(session_quant, dataset)

[Test 1000 img]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:23<00:00, 42.45it/s, Top1=tensor(70.6000), Top5=tensor(89.8000)]
