diff --git a/examples/onnx_ptq/README.md b/examples/onnx_ptq/README.md
index b1d22b896..6128ba24e 100644
--- a/examples/onnx_ptq/README.md
+++ b/examples/onnx_ptq/README.md
@@ -13,7 +13,7 @@ Model Optimizer enables highly performant quantization formats including NVFP4,
 | Pre-Requisites | Required & optional packages to use this technique | [Link](#pre-requisites) | |
 | Getting Started | Learn how to optimize your models using PTQ to reduce precision and improve inference efficiency | [Link](#getting-started) | [docs](https://nvidia.github.io/TensorRT-Model-Optimizer/guides/_onnx_quantization.html) |
 | Support Matrix | View the ONNX export supported LLM models | [Link](#onnx-export-supported-llm-models) | |
-| PyTorch to ONNX | Example scripts demonstrating how to quantize with PyTorch and then convert to ONNX | [Link](#torch-quantization-to-onnx-example-for-mxfp8-int4-or-nvfp4-precision) | |
+| PyTorch to ONNX | Example scripts demonstrating how to quantize with PyTorch and then convert to ONNX | [Link](#torch-quantization-to-onnx-export-example) | |
 | Advanced Features | Examples demonstrating use advanced ONNX quantization features | [Link](#advanced-features) | |
 | Pre-Quantized Checkpoints | Ready to deploy Hugging Face pre-quantized checkpoints | [Link](#pre-quantized-checkpoints) | |
 | Resources | Extra links to relevant resources | [Link](#resources) | |
@@ -80,7 +80,7 @@ python image_prep.py \
 
 The model can be quantized as an FP8, INT8 or INT4 model using either the CLI or Python API. For FP8 and INT8 quantization, you have a choice between `max` and `entropy` calibration algorithms. For INT4 quantization, [awq_clip](https://arxiv.org/abs/2306.00978) or [rtn_dq](https://ar5iv.labs.arxiv.org/html/2301.12017) algorithms can be chosen.
 
-> *For NVFP4 and MXFP8 ONNX, see the [PyTorch to ONNX section](#torch-quantization-to-onnx-example-for-mxfp8-int4-or-nvfp4-precision).*
+> *For NVFP4 and MXFP8 ONNX, see the [PyTorch to ONNX section](#torch-quantization-to-onnx-export-example).*
 
 > *Minimum opset requirements: int8 (13+), fp8 (21+), int4 (21+). ModelOpt will automatically upgrade lower opset versions to meet these requirements.*
 
@@ -129,9 +129,9 @@ The top5 accuracy of the model is
 Inference latency of the model is ms
 ```
 
-## Torch quantization to ONNX example for MXFP8, INT4 or NVFP4 precision
+## Torch quantization to ONNX export example
 
-This example demonstrates how to quantize a [timm](https://github.com/huggingface/pytorch-image-models) vision model using MXFP8, INT4 or NVFP4 precision formats, and then export it to ONNX. The script leverages the ModelOpt toolkit for both quantization and ONNX export.
+This example demonstrates how to quantize a [timm](https://github.com/huggingface/pytorch-image-models) vision model in various precision formats and then export it to ONNX. The script leverages the ModelOpt toolkit for both quantization and ONNX export.
 
 > *Opset 20 is used to export the torch models to ONNX.*
 
@@ -148,7 +148,7 @@ This example demonstrates how to quantize a [timm](https://github.com/huggingfac
 ```bash
 python torch_quant_to_onnx.py \
     --timm_model_name=vit_base_patch16_224 \
-    --quantize_mode= \
+    --quantize_mode= \
     --onnx_save_path=
 ```
 
diff --git a/examples/onnx_ptq/evaluation.py b/examples/onnx_ptq/evaluation.py
index ad32323d9..8b96f3d95 100644
--- a/examples/onnx_ptq/evaluation.py
+++ b/examples/onnx_ptq/evaluation.py
@@ -152,8 +152,9 @@ def evaluate_accuracy(
         # Calculate accuracy
         outputs = outputs[0] if isinstance(outputs, list) else outputs.data
-
         labels_size = labels.size(0)
+        outputs = outputs[:labels_size]
+        total += labels_size
         labels = labels.to(outputs.device)
diff --git a/examples/onnx_ptq/torch_quant_to_onnx.py b/examples/onnx_ptq/torch_quant_to_onnx.py
index a89497f67..06f1b1db8 100644
--- a/examples/onnx_ptq/torch_quant_to_onnx.py
+++ b/examples/onnx_ptq/torch_quant_to_onnx.py
@@ -323,7 +323,7 @@ def main():
     )
     print(f"Quantized Model - Top-1 Accuracy: {top1:.2f}%, Top-5 Accuracy: {top5:.2f}%")
 
-    if args.quantize_mode in ["fp8", "int8", "auto"]:
+    if args.quantize_mode in ["auto"]:
         print(
             f"The selected quantization mode {args.quantize_mode} is not supported for ONNX export yet."
         )
diff --git a/modelopt/onnx/quantization/qdq_utils.py b/modelopt/onnx/quantization/qdq_utils.py
index 66c613a6c..a140d7e29 100644
--- a/modelopt/onnx/quantization/qdq_utils.py
+++ b/modelopt/onnx/quantization/qdq_utils.py
@@ -1037,6 +1037,21 @@ def remove_graph_input_q(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
     return onnx_model
 
 
+def replace_zero_scale_with_smallest_nonzero(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
+    """Replace zero scale values with the smallest nonzero fp16 value in the ONNX model."""
+    graph = onnx_model.graph
+    fp16_smallest_nonzero = np.float16(6e-08)
+    scale_nodes = [node.input[1] for node in graph.node if node.op_type == "QuantizeLinear"]
+    for node in graph.node:
+        if node.op_type == "Constant" and node.output[0] in scale_nodes:
+            for attr in node.attribute:
+                if attr.name == "value":
+                    tensor = numpy_helper.to_array(attr.t)
+                    new_tensor = np.where(tensor == 0, fp16_smallest_nonzero, tensor)
+                    attr.t.CopyFrom(numpy_helper.from_array(new_tensor, attr.t.name))
+    return onnx_model
+
+
 def _cast_initializer_to_dtype(
     node: onnx.NodeProto, dtype: str, initializer_map: dict[str, onnx.TensorProto]
 ):
diff --git a/modelopt/torch/_deploy/utils/torch_onnx.py b/modelopt/torch/_deploy/utils/torch_onnx.py
index cfebd0dc1..12f6893f3 100644
--- a/modelopt/torch/_deploy/utils/torch_onnx.py
+++ b/modelopt/torch/_deploy/utils/torch_onnx.py
@@ -37,6 +37,7 @@
     qdq_to_dq,
     quantize_weights_to_int4,
     quantize_weights_to_mxfp8,
+    replace_zero_scale_with_smallest_nonzero,
 )
 from modelopt.onnx.utils import (
     get_input_names,
@@ -336,6 +337,32 @@ def is_mxfp8_quantized(model: nn.Module) -> bool:
     return False
 
 
+def is_int8_quantized(model: nn.Module) -> bool:
+    """Check if the model is quantized in INT8 mode."""
+    for _, module in model.named_modules():
+        if (
+            hasattr(module, "weight_quantizer")
+            and hasattr(module, "input_quantizer")
+            and module.weight_quantizer._num_bits == 8
+            and module.input_quantizer._num_bits == 8
+        ):
+            return True
+    return False
+
+
+def is_fp8_quantized(model: nn.Module) -> bool:
+    """Check if the model is quantized in FP8 mode."""
+    for _, module in model.named_modules():
+        if (
+            hasattr(module, "weight_quantizer")
+            and hasattr(module, "input_quantizer")
+            and module.weight_quantizer._num_bits == (4, 3)
+            and module.input_quantizer._num_bits == (4, 3)
+        ):
+            return True
+    return False
+
+
 def get_onnx_bytes_and_metadata(
     model: nn.Module,
     dummy_input: Any | tuple,
@@ -510,6 +537,9 @@ def get_onnx_bytes_and_metadata(
         onnx_opt_graph, low_precision_type=weights_dtype, keep_io_types=False
     )
 
+    # TensorRT expects all scales to be positive
+    onnx_opt_graph = replace_zero_scale_with_smallest_nonzero(onnx_opt_graph)
+
     # If the onnx model contains external data store the external tensors in one file and save the onnx model
     if has_external_data(onnx_save_path):
         tensor_paths = get_external_tensor_paths(onnx_path)
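
Illustration (not part of the patch): a minimal sketch of the zero-scale clamping that replace_zero_scale_with_smallest_nonzero applies to each QuantizeLinear scale constant before export. The scale values below are made up for this example; only the np.where clamp and the 6e-08 constant mirror the patch.

    import numpy as np

    # A hypothetical per-channel scale tensor as it could appear in a
    # QuantizeLinear Constant node. TensorRT expects strictly positive
    # quantization scales, so a zero entry would be rejected.
    scale = np.array([2.5e-3, 0.0, 1.2e-4], dtype=np.float16)

    # ~6e-08 is the smallest positive (subnormal) float16 value; the helper
    # in qdq_utils.py uses the same constant.
    fp16_smallest_nonzero = np.float16(6e-08)

    # Replace exact zeros and leave every other scale untouched, so the
    # exported Q/DQ nodes carry valid, positive scales.
    scale = np.where(scale == 0, fp16_smallest_nonzero, scale)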