# Notes on Exporting Models to ONNX Format

## 1. Before Quantization

### 1.1 ONNX-Compatible Modification Before Exporting

The model architecture has to be modified to be ONNX-compatible. The related code changes:

```python
x = F.avg_pool1d(x, x.shape[-1]) 
```

->

```python
x = F.adaptive_avg_pool1d(x, 1)
```
- `F.avg_pool1d(x, x.shape[-1])` uses a kernel size derived from the input shape, which PyTorch evaluates at runtime. ONNX export (`torch.onnx.export()`) relies on static tracing and cannot capture a kernel size that depends on the input shape.
- `adaptive_avg_pool1d` is supported natively by the ONNX exporter and fixes the output length symbolically (here, 1), regardless of the input length.
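A quick sanity check that the swap does not change the numerics: a minimal sketch, assuming a `(batch, channels, length)` input as expected by `avg_pool1d`. The helper names and shapes are made up for illustration, and only the numerical equivalence is checked, not the export itself.

```python
import torch
import torch.nn.functional as F

def global_avg_pool_dynamic(x: torch.Tensor) -> torch.Tensor:
    # Original form: kernel size read from the input shape at runtime
    return F.avg_pool1d(x, x.shape[-1])

def global_avg_pool_adaptive(x: torch.Tensor) -> torch.Tensor:
    # Export-friendly form: output length fixed to 1, independent of input length
    return F.adaptive_avg_pool1d(x, 1)

# Both compute the mean over the last (length) axis for any sequence length.
for length in (16, 80, 123):
    x = torch.randn(2, 32, length)  # (batch, channels, length), arbitrary sizes
    a = global_avg_pool_dynamic(x)
    b = global_avg_pool_adaptive(x)
    assert a.shape == b.shape == (2, 32, 1)
    assert torch.allclose(a, b, atol=1e-6)
    assert torch.allclose(b, x.mean(dim=-1, keepdim=True), atol=1e-6)
print("avg_pool1d(x, L) matches adaptive_avg_pool1d(x, 1) for all tested lengths")
```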

```python 
x = x.permute(0, 2, 1)
```

->
```python
x = x.transpose(1, 2)
```
- Exporting `permute()` on quantized tensors to ONNX is not supported. [Related issue here.](https://github.com/pytorch/pytorch/issues/109425)
- The reason is that the symbolic function of `permute` (a symbolic function describes how to map a PyTorch op to ONNX ops) assumes float tensors only. As shown in the [official symbolic_opset9](https://github.com/pytorch/pytorch/blob/ffaed8c569406839335bf46dafc4c3e8871e4b8a/torch/onnx/symbolic_opset9.py#L989), the symbolic function `def permute` is not decorated with `@symbolic_helper.quantized_args(True)`.
- Since `aten.permute` is implemented as an ATen operator ([listed here](https://docs.pytorch.org/docs/main/torch.compiler_ir.html)), the straightforward approach I tried was to override and register the symbolic function of `permute` via `torch.onnx.symbolic_registry.register_op`, but it turned out to be very complicated.
- Then I found that the symbolic function of `transpose` does support quantized inputs, hence the workaround: for a 3-D tensor, `permute(0, 2, 1)` swaps the last two axes, which is exactly `transpose(1, 2)`.
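A minimal check that the two forms are interchangeable on a 3-D tensor; the helper name `to_channels_first` and the shape are hypothetical, and this only verifies eager-mode equivalence, not the ONNX export path.

```python
import torch

def to_channels_first(x: torch.Tensor) -> torch.Tensor:
    # Swap axes 1 and 2 with transpose instead of permute(0, 2, 1)
    return x.transpose(1, 2)

x = torch.randn(4, 80, 32)  # arbitrary 3-D example shape
assert torch.equal(x.permute(0, 2, 1), to_channels_first(x))
assert to_channels_first(x).shape == (4, 32, 80)
print("permute(0, 2, 1) and transpose(1, 2) agree")
```

For tensors with more than three dimensions, a general `permute` cannot always be replaced by a single `transpose`; the trick works here because only two axes are swapped.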



### 1.2 Preprocessing with `onnxruntime.quantization.preprocess` (Quantization in ORT)

``` bash
python -m onnxruntime.quantization.preprocess \
    --input models/cnn_fp32.onnx \
    --output models/cnn_fp32_infer.onnx
```
Pre-processing transforms a float32 model to prepare it for quantization and to improve quantization quality. It consists of the following three optional steps:

- Symbolic shape inference: best suited for transformer models.
- ONNX shape inference.
- Model optimization: uses the ONNX Runtime native library to rewrite the computation graph, e.g. merging computation nodes and eliminating redundancies, to improve runtime efficiency.

In our case, judging from the computational graph, preprocessing helps to:
- infer the shape of every tensor in the graph (after preprocessing, the shape is shown next to each arrow in the graph visualization);
- fuse `MatMul` and `Add` operators into a single [`Gemm`](https://onnx.ai/onnx/operators/onnx__Gemm.html) (general matrix multiplication) operator.
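The fusion preserves the result because ONNX `Gemm` computes `Y = alpha * A' @ B' + beta * C`, where `A'` and `B'` are optionally transposed via `transA`/`transB`. The NumPy sketch below is a reference of those semantics (not ORT's fusion code) and checks that `MatMul` followed by `Add` equals a single `Gemm`; the shapes and the `transB=1` weight layout are illustrative assumptions.

```python
import numpy as np

def gemm(A, B, C, alpha=1.0, beta=1.0, transA=0, transB=0):
    """Reference semantics of ONNX Gemm: Y = alpha * A' @ B' + beta * C."""
    A_ = A.T if transA else A
    B_ = B.T if transB else B
    return alpha * (A_ @ B_) + beta * C  # C is broadcast to the output shape

rng = np.random.default_rng(0)
x = rng.standard_normal((8, 64)).astype(np.float32)   # activations (batch, in_features)
w = rng.standard_normal((10, 64)).astype(np.float32)  # weight stored as (out_features, in_features)
b = rng.standard_normal(10).astype(np.float32)        # bias

# Unfused pattern: MatMul followed by Add
unfused = np.add(np.matmul(x, w.T), b)
# Fused pattern: one Gemm with transB=1 and the bias passed as C
fused = gemm(x, w, b, transB=1)
np.testing.assert_allclose(fused, unfused, rtol=1e-6, atol=1e-6)
print("MatMul + Add == Gemm")
```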




## 2. Precision Alignment Before and After Exporting

[Reference](https://web.mit.edu/10.001/Web/Tips/Converge.htm)

I use `numpy.testing.assert_allclose` to check that the outputs of the PyTorch model and the exported ONNX model agree.

- For FP models, I set the tolerances to `rtol=1e-3, atol=1e-05` and the test passed.

- For quantized models, I set `rtol=0.1, atol=0.1` and the test still failed: only ~95% (QAT) and ~97% (PTQ) of the top-1 predictions of the PyTorch model and the ONNX model matched. However, the final accuracy stays almost the same.
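`assert_allclose` passes when `|actual - desired| <= atol + rtol * |desired|` holds elementwise. A hedged sketch pairing that check with a top-1 agreement rate; `outputs_match` and `top1_agreement` are hypothetical helper names and the logits are made up. The example also shows how outputs can pass a loose tolerance while the argmax still flips on a near-tie, which is one way top-1 predictions can differ while overall accuracy barely moves.

```python
import numpy as np

def outputs_match(ref, out, rtol, atol):
    """True if np.testing.assert_allclose(out, ref, rtol, atol) would pass."""
    try:
        np.testing.assert_allclose(out, ref, rtol=rtol, atol=atol)
        return True
    except AssertionError:
        return False

def top1_agreement(ref_logits, out_logits):
    """Fraction of samples whose argmax class is identical in both models."""
    return float(np.mean(ref_logits.argmax(axis=1) == out_logits.argmax(axis=1)))

# Made-up logits standing in for PyTorch (ref) and ONNX (out) outputs, 3 samples x 3 classes
ref = np.array([[2.0, 1.0, 0.0], [0.1, 0.2, 0.15], [0.0, 0.0, 5.0]])
out = np.array([[2.05, 0.98, 0.0], [0.1, 0.14, 0.16], [0.0, 0.02, 4.9]])

print(outputs_match(ref, out, rtol=1e-3, atol=1e-5))  # False: strict FP tolerance
print(outputs_match(ref, out, rtol=0.1, atol=0.1))    # True: loose tolerance
print(top1_agreement(ref, out))                       # 2 of 3 rows agree (row 2 is a near-tie)
```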







## 3. Evaluation after Quantization
Run `./onnx_static_quantize.sh` to execute the overall workflow of static quantization in ONNX Runtime followed by evaluation, starting from the well-trained FP32 model checkpoint. Run `./onnx_qat_quantize.sh` to export the QAT model to ONNX as well. The output should be close to the following:

| Metrics | FP32 ONNX | INT8 ONNX (ORT static) | QAT PyTorch | QAT ONNX | PTQ PyTorch | PTQ ONNX |
| ---- | ---- | ---- | ---- | ---- | ---- | ---- |
| Model Size | 106.51 KB | 44.09 KB | 40.62 KB | 46.18 KB | 40.62 KB | 46.18 KB |
| Accuracy | 83.14% | 78.41% | 79.63% | 79.73% | 75.85% | 75.89% |
| Average Inference Time | 9.62 ms | 2.85 ms | \- | 2.81 ms | \- | 2.89 ms |
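The scripts themselves are not shown here, so as a hedged sketch of how the size and latency rows could be measured: `model_size_kb` and `average_latency_ms` are hypothetical helpers, "KB" is assumed to mean 1024 bytes, and the warm-up and run counts are arbitrary choices.

```python
import os
import statistics
import time

def model_size_kb(path):
    """File size in KB, assuming 1 KB = 1024 bytes."""
    return os.path.getsize(path) / 1024

def average_latency_ms(run_once, n_warmup=10, n_runs=100):
    """Mean wall-clock time of run_once() in milliseconds, measured after warm-up calls."""
    for _ in range(n_warmup):
        run_once()  # warm-up: lazy initialization and allocations are excluded from timing
    timings = []
    for _ in range(n_runs):
        start = time.perf_counter()
        run_once()
        timings.append((time.perf_counter() - start) * 1000.0)
    return statistics.mean(timings)
```

With ONNX Runtime, `run_once` could be `lambda: sess.run(None, feed)`, where `sess = onnxruntime.InferenceSession(path, providers=["CPUExecutionProvider"])` and `feed = {sess.get_inputs()[0].name: x}` for some input batch `x` (hypothetical here). Latency depends heavily on batch size and hardware, so the numbers are only comparable within one machine.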

## 4. Check the Computational Graph of the Quantized Model

### 4.1 Node Categories
| Aspect | Initializer | Constant node |
| ------------- | ------------------------------------- | -------------------------------------- |
| **Purpose** | Model parameters (weights, biases) | Constants embedded in the graph as nodes |
| **Stored in** | `graph.initializer` | `graph.node` with `op_type="Constant"` |
| **Changes?** | Can be updated (e.g., for fine-tuning) | Fixed as part of the graph logic |
| **Usage in our graph** | Produced by ORT static quantization (`cnn_int8.onnx`) | Produced when a quantized PyTorch model is exported directly (`cnn_ptq.onnx`) |


```python
import onnx
from onnx import numpy_helper

# Model quantized with ONNX Runtime static quantization: quantized tensors are initializers
onnx_static = onnx.load("../models/cnn_int8.onnx")
o_graph = onnx_static.graph
o_node = o_graph.node

# Quantized PyTorch (PTQ) model exported directly: quantized tensors are Constant nodes
torch_ptq = onnx.load("../models/cnn_ptq.onnx")
t_graph = torch_ptq.graph
t_node = t_graph.node


def print_tensor_info(graph, tensor_name):
    """Print shape/dtype of a constant tensor and return it as a NumPy array (None if absent)."""
    # Check in initializers
    for init in graph.initializer:
        if init.name == tensor_name:
            val_array = numpy_helper.to_array(init)
            print(f"Found in initializer: `{tensor_name}`.")
            print(f"  Shape: {val_array.shape}")
            print(f"  Dtype: {val_array.dtype}")
            return val_array

    # Check in Constant nodes (the tensor is the node's output)
    for node in graph.node:
        if node.op_type == "Constant" and node.output[0] == tensor_name:
            for attr in node.attribute:
                if attr.name == "value":
                    val_array = numpy_helper.to_array(attr.t)
                    print(f"Found {tensor_name} in Constant node output `{node.name}`.")
                    print(f"  Shape: {val_array.shape}")
                    print(f"  Dtype: {val_array.dtype}")
                    return val_array

    print(f"Tensor '{tensor_name}' not found in initializers or Constant nodes.")
    return None


o_L1_bias_q = print_tensor_info(o_graph, "onnx::Conv_55_quantized")
o_L1_bias_scale = print_tensor_info(o_graph, "onnx::Conv_55_quantized_scale")
# o_L1_bias_zero_point is all-zero since it's symmetric quantization
t_L1_bias_q = print_tensor_info(t_graph, "/block1/block/0/Constant_6_output_0")
t_L1_bias_scale = print_tensor_info(t_graph, "/block1/block/0/Constant_7_output_0")
```

```
Found in initializer: `onnx::Conv_55_quantized`.
  Shape: (32,)
  Dtype: int32
Found in initializer: `onnx::Conv_55_quantized_scale`.
  Shape: (32,)
  Dtype: float32
Found /block1/block/0/Constant_6_output_0 in Constant node output `/block1/block/0/Constant_6`.
  Shape: (32,)
  Dtype: int32
Found /block1/block/0/Constant_7_output_0 in Constant node output `/block1/block/0/Constant_7`.
  Shape: (32,)
  Dtype: float32
```


```python
import numpy as np

# Dequantize the block-1 Conv bias: the zero point is 0, so fp = q * scale
o_L1_bias_fp = o_L1_bias_q * o_L1_bias_scale
t_L1_bias_fp = t_L1_bias_q * t_L1_bias_scale

try:
    np.testing.assert_allclose(o_L1_bias_fp, t_L1_bias_fp, rtol=1e-3, atol=1e-5)
except AssertionError as e:
    print(f"AssertionError: {e}")
```

```
AssertionError:
Not equal to tolerance rtol=0.001, atol=1e-05

Mismatched elements: 28 / 32 (87.5%)
Max absolute difference among violations: 0.0007678
Max relative difference among violations: 0.10642903
 ACTUAL: array([-0.00612 ,  0.028869,  0.016809, -0.003452,  0.004411, -0.001444,
       -0.02042 ,  0.092428,  0.022525,  0.006155, -0.019724, -0.005085,
       -0.00049 , -0.005574, -0.006422,  0.060337, -0.330705, -0.006446,...
 DESIRED: array([-0.006085,  0.028701,  0.016548, -0.00312 ,  0.004134, -0.001305,
       -0.020278,  0.092147,  0.022479,  0.005934, -0.019425, -0.004903,
        0.      , -0.005568, -0.006366,  0.06011 , -0.33066 , -0.006397,...
```
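For the bias, the ONNX `QLinearConv` specification requires an int32 bias quantized with `scale = x_scale * w_scale` and zero point 0, so the dequantized bias depends on the activation scale chosen during calibration as well as on the weight scale. Assuming the PyTorch export follows the same scheme (not verified here), small bias mismatches like the ones above are expected when the two toolchains calibrate slightly different activation scales. A sketch with made-up scale values:

```python
import numpy as np

def quantize_bias(bias, x_scale, w_scale):
    """int32 bias quantization with scale = x_scale * w_scale (per output channel), zero point 0."""
    bias_scale = (np.float32(x_scale) * np.asarray(w_scale, dtype=np.float32)).astype(np.float32)
    bias_q = np.round(np.asarray(bias, dtype=np.float32) / bias_scale).astype(np.int32)
    return bias_q, bias_scale

def dequantize(q, scale):
    """fp = q * scale, since the zero point is 0."""
    return q.astype(np.float32) * scale

# Hypothetical values: one float bias, one set of per-channel weight scales,
# and two different activation scales (e.g. from two different calibration runs).
bias = np.array([-0.0061, 0.0288, 0.0921], dtype=np.float32)
w_scale = np.array([0.0076, 0.0061, 0.0090], dtype=np.float32)

for x_scale in (0.020, 0.025):
    q, s = quantize_bias(bias, x_scale, w_scale)
    print(x_scale, q, dequantize(q, s))
```

Each reconstruction is within half a bias step of the float bias, but the two reconstructions differ from each other because their steps differ.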


```python
o_L1_weight_q = print_tensor_info(o_graph, "onnx::Conv_54_quantized")
o_L1_weight_scale = print_tensor_info(o_graph, "onnx::Conv_54_scale")

t_L1_weight_q = print_tensor_info(t_graph, "/block1/block/0/Constant_2_output_0")
t_L1_weight_scale = print_tensor_info(t_graph, "/block1/block/0/Constant_3_output_0")

o_L1_weight_fp = o_L1_weight_scale[:, np.newaxis, np.newaxis] * (o_L1_weight_q.astype(np.float32))
t_L1_weight_fp = t_L1_weight_scale[:, np.newaxis, np.newaxis] * (t_L1_weight_q.astype(np.float32))
try:
    np.testing.assert_allclose(o_L1_weight_fp, t_L1_weight_fp, rtol=1e-3, atol=1e-5)
except AssertionError as e:
    print(f"AssertionError: {e}")
```

```
Found in initializer: `onnx::Conv_54_quantized`.
  Shape: (32, 1, 80)
  Dtype: int8
Found in initializer: `onnx::Conv_54_scale`.
  Shape: (32,)
  Dtype: float32
Found /block1/block/0/Constant_2_output_0 in Constant node output `/block1/block/0/Constant_2`.
  Shape: (32, 1, 80)
  Dtype: int8
Found /block1/block/0/Constant_3_output_0 in Constant node output `/block1/block/0/Constant_3`.
  Shape: (32,)
  Dtype: float32
AssertionError:
Not equal to tolerance rtol=0.001, atol=1e-05

Mismatched elements: 2547 / 2560 (99.5%)
Max absolute difference among violations: 0.11737084
Max relative difference among violations: 1.
 ACTUAL: array([[[-0.966245,  0.191236,  0.462993, ..., -0.644163, -0.191236,
         -0.744814]],
...
 DESIRED: array([[[-0.962456,  0.200512,  0.461177, ..., -0.641637, -0.190486,
         -0.741893]],
...
```
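One plausible source of the weight mismatches is that the two toolchains derive slightly different per-channel scales from the same float weights. The sketch shows symmetric per-output-channel int8 quantization (zero point 0) with two divisors, `max|w| / 127` and `max|w| / 127.5`; which divisor, if either, each toolchain uses is an assumption for illustration, not something verified from these models. Both reconstructions stay within half a quantization step of the float weights, yet they disagree with each other at `rtol=1e-3`. When one side rounds a small weight to 0 and the other does not, the relative difference is exactly 1, which is consistent with the `Max relative difference among violations: 1.` lines.

```python
import numpy as np

def quantize_per_channel_symmetric(w, divisor=127.0):
    """Symmetric int8 quantization with one scale per output channel (axis 0), zero point 0."""
    max_abs = np.abs(w).reshape(w.shape[0], -1).max(axis=1)
    scale = (np.maximum(max_abs, 1e-12) / divisor).astype(np.float32)  # avoid a zero scale
    shape = (-1,) + (1,) * (w.ndim - 1)  # e.g. (out_ch, 1, 1) for a Conv1d weight
    q = np.clip(np.round(w / scale.reshape(shape)), -128, 127).astype(np.int8)
    return q, scale

def dequantize_per_channel(q, scale):
    """fp = scale[c] * q, broadcasting the per-channel scale over the remaining axes."""
    shape = (-1,) + (1,) * (q.ndim - 1)
    return scale.reshape(shape) * q.astype(np.float32)

rng = np.random.default_rng(0)
w = (0.1 * rng.standard_normal((32, 32, 3))).astype(np.float32)  # made-up Conv1d weight

q_a, s_a = quantize_per_channel_symmetric(w, divisor=127.0)
q_b, s_b = quantize_per_channel_symmetric(w, divisor=127.5)
w_a = dequantize_per_channel(q_a, s_a)
w_b = dequantize_per_channel(q_b, s_b)

# Per-channel worst-case reconstruction error, compared with half a quantization step
err_a = np.abs(w_a - w).reshape(w.shape[0], -1).max(axis=1)
err_b = np.abs(w_b - w).reshape(w.shape[0], -1).max(axis=1)
print(np.all(err_a <= s_a / 2 + 1e-6), np.all(err_b <= s_b / 2 + 1e-6))
print(np.allclose(w_a, w_b, rtol=1e-3, atol=1e-5))  # the two reconstructions disagree
```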


```python
o_L2_bias_q = print_tensor_info(o_graph, "onnx::Conv_58_quantized")
o_L2_bias_scale = print_tensor_info(o_graph, "onnx::Conv_58_quantized_scale")
# o_L2_bias_zero_point is all-zero since it's symmetric quantization
t_L2_bias_q = print_tensor_info(t_graph, "/block2/block/0/Constant_6_output_0")
t_L2_bias_scale = print_tensor_info(t_graph, "/block2/block/0/Constant_7_output_0")

o_L2_bias_fp = o_L2_bias_q * o_L2_bias_scale
t_L2_bias_fp = t_L2_bias_q * t_L2_bias_scale

try:
    np.testing.assert_allclose(o_L2_bias_fp, t_L2_bias_fp, rtol=1e-3, atol=1e-5)
except AssertionError as e:
    print(f"AssertionError: {e}")
```

```
Found in initializer: `onnx::Conv_58_quantized`.
  Shape: (32,)
  Dtype: int32
Found in initializer: `onnx::Conv_58_quantized_scale`.
  Shape: (32,)
  Dtype: float32
Found /block2/block/0/Constant_6_output_0 in Constant node output `/block2/block/0/Constant_6`.
  Shape: (32,)
  Dtype: int32
Found /block2/block/0/Constant_7_output_0 in Constant node output `/block2/block/0/Constant_7`.
  Shape: (32,)
  Dtype: float32
AssertionError:
Not equal to tolerance rtol=0.001, atol=1e-05

Mismatched elements: 9 / 32 (28.1%)
Max absolute difference among violations: 0.00045552
Max relative difference among violations: 0.0534982
 ACTUAL: array([-0.025805,  0.322926, -0.060563,  0.073844,  0.516396,  0.20104 ,
        0.500617,  0.469001,  0.395666,  0.270268,  0.501922,  0.298508,
        0.240467,  0.251023,  0.522687,  0.287385,  0.318088,  0.00897 ,...
 DESIRED: array([-0.025711,  0.323145, -0.060501,  0.07353 ,  0.516634,  0.200726,
        0.500608,  0.468857,  0.395571,  0.270032,  0.50180
```

```python
o_L2_weight_q = print_tensor_info(o_graph, "onnx::Conv_57_quantized")
o_L2_weight_scale = print_tensor_info(o_graph, "onnx::Conv_57_scale")

t_L2_weight_q = print_tensor_info(t_graph, "/block2/block/0/Constant_2_output_0")
t_L2_weight_scale = print_tensor_info(t_graph, "/block2/block/0/Constant_3_output_0")

o_L2_weight_fp = o_L2_weight_scale[:, np.newaxis, np.newaxis] * (o_L2_weight_q.astype(np.float32))
t_L2_weight_fp = t_L2_weight_scale[:, np.newaxis, np.newaxis] * (t_L2_weight_q.astype(np.float32))
try:
    np.testing.assert_allclose(o_L2_weight_fp, t_L2_weight_fp, rtol=1e-3, atol=1e-5)
except AssertionError as e:
    print(f"AssertionError: {e}")
```

```
Found in initializer: `onnx::Conv_57_quantized`.
  Shape: (32, 32, 3)
  Dtype: int8
Found in initializer: `onnx::Conv_57_scale`.
  Shape: (32,)
  Dtype: float32
Found /block2/block/0/Constant_2_output_0 in Constant node output `/block2/block/0/Constant_2`.
  Shape: (32, 32, 3)
  Dtype: int8
Found /block2/block/0/Constant_3_output_0 in Constant node output `/block2/block/0/Constant_3`.
  Shape: (32,)
  Dtype: float32
AssertionError:
Not equal to tolerance rtol=0.001, atol=1e-05

Mismatched elements: 2816 / 3072 (91.7%)
Max absolute difference among violations: 0.00203356
Max relative difference among violations: 0.16338585
 ACTUAL: array([[[ 0.032159,  0.041882,  0.041134],
        [ 0.08152 ,  0.044873,  0.051604],
        [-0.011966,  0.020193,  0.055344],...
 DESIRED: array([[[ 0.032778,  0.041717,  0.040972],
        [ 0.0812  ,  0.045442,  0.052147],
        [-0.011919,  0.020114,  0.055126],...
```


```python
o_L3_bias_q = print_tensor_info(o_graph, "onnx::Conv_61_quantized")
o_L3_bias_scale = print_tensor_info(o_graph, "onnx::Conv_61_quantized_scale")
t_L3_bias_q = print_tensor_info(t_graph, "/block3/block/0/Constant_6_output_0")
t_L3_bias_scale = print_tensor_info(t_graph, "/block3/block/0/Constant_7_output_0")

o_L3_bias_fp = o_L3_bias_q * o_L3_bias_scale
t_L3_bias_fp = t_L3_bias_q * t_L3_bias_scale

try:
    np.testing.assert_allclose(o_L3_bias_fp, t_L3_bias_fp, rtol=1e-3, atol=1e-5)
except AssertionError as e:
    print(f"AssertionError: {e}")
```

```
Found in initializer: `onnx::Conv_61_quantized`.
  Shape: (64,)
  Dtype: int32
Found in initializer: `onnx::Conv_61_quantized_scale`.
  Shape: (64,)
  Dtype: float32
Found /block3/block/0/Constant_6_output_0 in Constant node output `/block3/block/0/Constant_6`.
  Shape: (64,)
  Dtype: int32
Found /block3/block/0/Constant_7_output_0 in Constant node output `/block3/block/0/Constant_7`.
  Shape: (64,)
  Dtype: float32
AssertionError:
Not equal to tolerance rtol=0.001, atol=1e-05

Mismatched elements: 14 / 64 (21.9%)
Max absolute difference among violations: 0.0010885
Max relative difference among violations: 0.04639256
 ACTUAL: array([ 1.016544,  0.211402,  1.294162,  1.235873,  0.696388,  0.216166,
        0.591279,  0.390996,  0.13884 , -0.891073,  1.167723, -1.398085,
        0.161216,  1.517577,  0.353284, -0.489279,  0.181339, -0.009228,...
 DESIRED: array([ 1.016257,  0.211391,  1.294079,  1.235603,  0.695967,  0.215784,
        0.590768,  0.390695,  0.138811, -0.891067,  1.1674
```

```python
o_L3_weight_q = print_tensor_info(o_graph, "onnx::Conv_60_quantized")
o_L3_weight_scale = print_tensor_info(o_graph, "onnx::Conv_60_scale")

t_L3_weight_q = print_tensor_info(t_graph, "/block3/block/0/Constant_2_output_0")
t_L3_weight_scale = print_tensor_info(t_graph, "/block3/block/0/Constant_3_output_0")

o_L3_weight_fp = o_L3_weight_scale[:, np.newaxis, np.newaxis] * (o_L3_weight_q.astype(np.float32))
t_L3_weight_fp = t_L3_weight_scale[:, np.newaxis, np.newaxis] * (t_L3_weight_q.astype(np.float32))
try:
    np.testing.assert_allclose(o_L3_weight_fp, t_L3_weight_fp, rtol=1e-3, atol=1e-5)
except AssertionError as e:
    print(f"AssertionError: {e}")
```

```
Found in initializer: `onnx::Conv_60_quantized`.
  Shape: (64, 32, 3)
  Dtype: int8
Found in initializer: `onnx::Conv_60_scale`.
  Shape: (64,)
  Dtype: float32
Found /block3/block/0/Constant_2_output_0 in Constant node output `/block3/block/0/Constant_2`.
  Shape: (64, 32, 3)
  Dtype: int8
Found /block3/block/0/Constant_3_output_0 in Constant node output `/block3/block/0/Constant_3`.
  Shape: (64,)
  Dtype: float32
AssertionError:
Not equal to tolerance rtol=0.001, atol=1e-05

Mismatched elements: 6053 / 6144 (98.5%)
Max absolute difference among violations: 0.01690417
Max relative difference among violations: 1.
 ACTUAL: array([[[-0.188505,  0.081516,  0.005095],
        [-0.331158,  0.091705,  0.163032],
        [-0.045853, -0.035663, -0.081516],...
 DESIRED: array([[[-0.192841,  0.081196,  0.005075],
        [-0.334934,  0.091346,  0.162392],
        [-0.045673, -0.035523, -0.081196],...
```


```python
o_L4_bias_q = print_tensor_info(o_graph, "onnx::Conv_64_quantized")
o_L4_bias_scale = print_tensor_info(o_graph, "onnx::Conv_64_quantized_scale")
t_L4_bias_q = print_tensor_info(t_graph, "/block4/block/0/Constant_6_output_0")
t_L4_bias_scale = print_tensor_info(t_graph, "/block4/block/0/Constant_7_output_0")

o_L4_bias_fp = o_L4_bias_q * o_L4_bias_scale
t_L4_bias_fp = t_L4_bias_q * t_L4_bias_scale

try:
    np.testing.assert_allclose(o_L4_bias_fp, t_L4_bias_fp, rtol=1e-3, atol=1e-5)
except AssertionError as e:
    print(f"AssertionError: {e}")
```

```
Found in initializer: `onnx::Conv_64_quantized`.
  Shape: (64,)
  Dtype: int32
Found in initializer: `onnx::Conv_64_quantized_scale`.
  Shape: (64,)
  Dtype: float32
Found /block4/block/0/Constant_6_output_0 in Constant node output `/block4/block/0/Constant_6`.
  Shape: (64,)
  Dtype: int32
Found /block4/block/0/Constant_7_output_0 in Constant node output `/block4/block/0/Constant_7`.
  Shape: (64,)
  Dtype: float32
AssertionError:
Not equal to tolerance rtol=0.001, atol=1e-05

Mismatched elements: 7 / 64 (10.9%)
Max absolute difference among violations: 0.00041547
Max relative difference among violations: 0.00333149
 ACTUAL: array([-0.317625,  1.4661  , -0.215446, -0.61548 , -1.684858, -1.574953,
        1.452836, -0.67934 ,  0.495698, -0.343092, -2.538879, -2.524491,
       -0.080531, -2.704656, -1.390812,  0.305723,  0.029642, -0.738222,...
 DESIRED: array([-0.317305,  1.466154, -0.21533 , -0.615283, -1.684831, -1.574783,
        1.452739, -0.679174,  0.495546, -0.342677, -2.5387
```

```python
o_L4_weight_q = print_tensor_info(o_graph, "onnx::Conv_63_quantized")
o_L4_weight_scale = print_tensor_info(o_graph, "onnx::Conv_63_scale")

t_L4_weight_q = print_tensor_info(t_graph, "/block4/block/0/Constant_2_output_0")
t_L4_weight_scale = print_tensor_info(t_graph, "/block4/block/0/Constant_3_output_0")

o_L4_weight_fp = o_L4_weight_scale[:, np.newaxis, np.newaxis] * (o_L4_weight_q.astype(np.float32))
t_L4_weight_fp = t_L4_weight_scale[:, np.newaxis, np.newaxis] * (t_L4_weight_q.astype(np.float32))
try:
    np.testing.assert_allclose(o_L4_weight_fp, t_L4_weight_fp, rtol=1e-3, atol=1e-5)
except AssertionError as e:
    print(f"AssertionError: {e}")
```

```
Found in initializer: `onnx::Conv_63_quantized`.
  Shape: (64, 64, 3)
  Dtype: int8
Found in initializer: `onnx::Conv_63_scale`.
  Shape: (64,)
  Dtype: float32
Found /block4/block/0/Constant_2_output_0 in Constant node output `/block4/block/0/Constant_2`.
  Shape: (64, 64, 3)
  Dtype: int8
Found /block4/block/0/Constant_3_output_0 in Constant node output `/block4/block/0/Constant_3`.
  Shape: (64,)
  Dtype: float32
AssertionError:
Not equal to tolerance rtol=0.001, atol=1e-05

Mismatched elements: 12118 / 12288 (98.6%)
Max absolute difference among violations: 0.01933029
Max relative difference among violations: 1.
 ACTUAL: array([[[ 0.264763, -0.63249 ,  0.500108],
        [ 0.102963,  0.411854, -1.073762],
        [-0.294181,  0.      ,  0.102963],...
 DESIRED: array([[[ 0.263725, -0.644661,  0.498147],
        [ 0.10256 ,  0.410239, -1.069551],
        [-0.293028,  0.      ,  0.10256 ],...
```
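The layer-by-layer cells repeat the same comparison, and `assert_allclose` only reports its statistics inside an exception message. A hedged sketch of a non-raising helper (`mismatch_report` is a hypothetical name) that applies the same criterion, `|actual - desired| > atol + rtol * |desired|` counts as a violation, and returns the same statistics, so every layer can be summarized in one loop.

```python
import numpy as np

def mismatch_report(actual, desired, rtol=1e-3, atol=1e-5):
    """Summarize assert_allclose-style violations without raising."""
    actual = np.asarray(actual, dtype=np.float64)
    desired = np.asarray(desired, dtype=np.float64)
    abs_diff = np.abs(actual - desired)
    violations = abs_diff > atol + rtol * np.abs(desired)
    with np.errstate(divide="ignore", invalid="ignore"):
        rel_diff = abs_diff / np.abs(desired)  # inf where desired == 0 and the values differ
    any_bad = bool(violations.any())
    return {
        "mismatched": int(violations.sum()),
        "total": int(violations.size),
        "max_abs": float(abs_diff[violations].max()) if any_bad else 0.0,
        "max_rel": float(rel_diff[violations].max()) if any_bad else 0.0,
    }

print(mismatch_report([1.0, 0.0, 2.0], [1.0, 0.5, 2.1]))
```

With the arrays from the cells above, a loop such as `for name, (o, t) in {"L1 bias": (o_L1_bias_fp, t_L1_bias_fp), "L1 weight": (o_L1_weight_fp, t_L1_weight_fp)}.items(): print(name, mismatch_report(o, t))` would print one summary line per tensor (the dictionary layout is illustrative).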
