Explicit quantization in PyTorch before ONNX leads to slower TRT engine than ONNX PTQ #207

@liukang1811

Description

Describe the bug

I’m experiencing unexpected behavior when performing PTQ (post-training quantization) using NVIDIA ModelOpt versus ONNX quantization. Here’s a comparison of two workflows:

  1. Workflow 1 (trt1):

    • Perform PTQ on the PyTorch model (INT8) using NVIDIA ModelOpt.
    • Export the quantized PyTorch model to ONNX.
    • Convert the ONNX model to a TensorRT engine.
    • Observation: The TensorRT engine (trt1) contains many QuantizeLinear operators, and its inference time is 7.4 ms.
  2. Workflow 2 (trt2):

    • Export the float PyTorch model directly to ONNX.
    • Apply PTQ (INT8) on the ONNX model.
    • Convert the quantized ONNX model to a TensorRT engine.
    • Observation: The TensorRT engine (trt2) has no QuantizeLinear operators, and its inference time is 5.9 ms.
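
For readers less familiar with the operator named in the observations above, here is a minimal sketch of what an ONNX QuantizeLinear node computes for a per-tensor INT8 case: divide by the scale, round to nearest (ties to even), add the zero point, and saturate to [-128, 127]. The `scale` value and the sample inputs are hypothetical, and this only illustrates the arithmetic, not how TensorRT implements or fuses the layer.

```python
def quantize_linear(x, scale, zero_point=0, qmin=-128, qmax=127):
    """Per-tensor INT8 quantization following the ONNX QuantizeLinear formula."""
    q = round(x / scale) + zero_point  # Python's round() rounds ties to even
    return max(qmin, min(qmax, q))     # saturate to the INT8 range


def dequantize_linear(q, scale, zero_point=0):
    """Inverse mapping, as computed by ONNX DequantizeLinear."""
    return (q - zero_point) * scale


scale = 0.05                  # hypothetical calibration scale
values = [0.12, -1.0, 10.0]   # hypothetical activations
quantized = [quantize_linear(v, scale) for v in values]
restored = [dequantize_linear(q, scale) for q in quantized]
print(quantized)
print(restored)
```

When such a node is not fused into a neighboring layer, it shows up as a separate entry in the engine profile, which is what the trt1 observation describes.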

So trt2 is about 20% faster than trt1 (5.9 ms vs. 7.4 ms, i.e. roughly 20% lower latency). I suspect the explicit quantization nodes in trt1 are causing the slowdown. Is this expected behavior? Does doing PTQ in ONNX space produce a better-optimized engine than doing PTQ in PyTorch before the ONNX export?
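
To make the "about 20%" figure concrete, here is the arithmetic on the two latencies reported above. It only restates the numbers from this issue; the variable names are mine.

```python
# Latencies reported in this issue, in milliseconds.
trt1_ms = 7.4  # PyTorch PTQ -> ONNX -> TensorRT
trt2_ms = 5.9  # ONNX PTQ -> TensorRT

latency_reduction = (trt1_ms - trt2_ms) / trt1_ms  # fraction of time saved
speedup = trt1_ms / trt2_ms                        # throughput ratio

print(f"latency reduction: {latency_reduction:.1%}")
print(f"speedup factor:    {speedup:.2f}x")
```

Measured as latency saved, trt2 is roughly 20% better; measured as throughput, it is roughly 1.25x.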

I checked the extra operators in trt1 compared to trt2, and they look like this:
"/image_encoder/image_encoder/_blocks_conv.21/_expand_conv/_expand_conv.0/input_quantizer/QuantizeLinear", "timeMs" : 3.38419, "averageMs" : 0.0129663, "medianMs" : 0.012448, "percentage" : 0.137381 }

System information

  • Container used (if applicable): ?
  • OS (e.g., Ubuntu 22.04, CentOS 7, Windows 10): ?
  • CPU architecture (x86_64, aarch64): x86_64
  • GPU name (e.g. H100, A100, L40S): A800
  • GPU memory size: 80 GB
  • Number of GPUs: 1
  • Library versions (if applicable):
    • Python: 3.10
    • ModelOpt version or commit hash: 0.23
    • CUDA: 11.8
    • PyTorch: 2.1
    • Transformers: ?
    • TensorRT-LLM: ?
    • ONNXRuntime: 1.20
    • TensorRT: 10.7
  • Any other details that may help: ?

Code used for Workflow 1 (trt1):

    import onnx
    import torch
    from onnxsim import simplify  # assumption: simplify() comes from onnx-simplifier
    from tqdm import tqdm
    import modelopt.torch.quantization as mtq
    from modelopt.onnx.quantization.qdq_utils import qdq_to_dq

    def forward_loop(model):
        # Calibration: run representative inputs through the model.
        ...
        model_outputs_iter = model(tensor_iamge)

    # INT8 PTQ in PyTorch, then export the quantized model to ONNX.
    quant_model = mtq.quantize(torch_model, mtq.INT8_DEFAULT_CFG, forward_loop)
    with torch.no_grad():
        torch.onnx.export(model, input_data, "./net_quant.onnx",
                          input_names=input_names, output_names=output,
                          opset_version=17, verbose=False)

    # Simplify the exported graph, then run qdq_to_dq on it.
    model_onnx = onnx.load("./image_encoder_muti_input_ori_quant.onnx")
    model_simp, check = simplify(model_onnx)
    assert check, "Simplified ONNX model could not be validated"
    model_dq = qdq_to_dq(model_simp)
    onnx.save(model_dq, "./image_encoder_muti_input_simplified_quant.onnx")
