diff --git a/modelopt/onnx/export/__init__.py b/modelopt/onnx/export/__init__.py new file mode 100644 index 000000000..39d05aff4 --- /dev/null +++ b/modelopt/onnx/export/__init__.py @@ -0,0 +1,16 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ONNX export utilities.""" diff --git a/modelopt/onnx/export/quant_exporter.py b/modelopt/onnx/export/quant_exporter.py new file mode 100644 index 000000000..81ccecf9b --- /dev/null +++ b/modelopt/onnx/export/quant_exporter.py @@ -0,0 +1,319 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""ONNX quantizer exporters.""" + +from abc import ABC, abstractmethod + +import onnx +from onnx import numpy_helper + +from modelopt.onnx.logging_config import logger +from modelopt.onnx.quantization.graph_utils import get_tensor_producer_nodes +from modelopt.onnx.quantization.qdq_utils import cast_initializer_to_dtype +from modelopt.onnx.quantization.quant_utils import pack_weights_to_int4 + + +class ONNXQuantExporter(ABC): + """Base class for ONNX quantizer exporters.""" + + @staticmethod + @abstractmethod + def compute_scales(onnx_model: onnx.ModelProto) -> onnx.ModelProto: + """Computes the scales for the weights in the ONNX model.""" + + @staticmethod + @abstractmethod + def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto: + """Compresses the weights in the ONNX model.""" + + @staticmethod + @abstractmethod + def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto: + """Post-processes the ONNX model.""" + + +# TODO: Implement the MXFP8QuantExporter +class MXFP8QuantExporter(ONNXQuantExporter): + """Exporter for MXFP8 quantization.""" + + @staticmethod + def compute_scales(onnx_model: onnx.ModelProto) -> onnx.ModelProto: + """Computes the scales for the weights in the ONNX model for MXFP8 quantization.""" + + @staticmethod + def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto: + """Compresses the weights in the ONNX model for MXFP8 quantization.""" + + @staticmethod + def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto: + """Post-processes the ONNX model for MXFP8 quantization.""" + + +# TODO: Implement the FP8QuantExporter +class FP8QuantExporter(ONNXQuantExporter): + """Exporter for FP8 quantization.""" + + @staticmethod + def compute_scales(onnx_model: onnx.ModelProto) -> onnx.ModelProto: + """Computes the scales for the weights in the ONNX model for FP8 quantization.""" + + @staticmethod + def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto: + """Compresses the weights in the ONNX model for FP8 quantization.""" + + @staticmethod + def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto: + """Post-processes the ONNX model for FP8 quantization.""" + + +# TODO: Implement the INT8QuantExporter +class INT8QuantExporter(ONNXQuantExporter): + """Exporter for INT8 quantization.""" + + @staticmethod + def compute_scales(onnx_model: onnx.ModelProto) -> onnx.ModelProto: + """Computes the scales for the weights in the ONNX model for INT8 quantization.""" + + @staticmethod + def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto: + """Compresses the weights in the ONNX model for INT8 quantization.""" + + @staticmethod + def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto: + """Post-processes the ONNX model for INT8 quantization.""" + + +class INT4QuantExporter(ONNXQuantExporter): + """Exporter for INT4 quantization.""" + + @staticmethod + def compute_scales(onnx_model: onnx.ModelProto) -> onnx.ModelProto: + """Computes the scales for the weights in the ONNX model for INT4 quantization.""" + graph = onnx_model.graph + initializer_map = {initializer.name: initializer for initializer in graph.initializer} + value_info_map = {value_info.name: value_info for value_info in graph.value_info} + weight_dq_nodes = [node for node in graph.node if node.op_type == "DequantizeLinear"] + tensor_producer_map = get_tensor_producer_nodes(graph) + + nodes_to_remove = [] + for node in weight_dq_nodes: + weight_name = node.input[0] + scale_name = node.input[1] + logger.debug(f"Processing INT4 conversion 
for weight {weight_name}") + weight = numpy_helper.to_array(initializer_map[weight_name]) + if scale_name in initializer_map: + scale = numpy_helper.to_array(initializer_map[scale_name]) + else: + scale_constant_node = tensor_producer_map[scale_name] + for attr in scale_constant_node.attribute: + if attr.name == "value": + tensor = attr.t + scale = numpy_helper.to_array(tensor) + + weight = weight / scale + block_size = weight.shape[-1] + + ## Convert DequantizeLinear -> Reshape -> Transpose -> MatMul/Gemm to DequantizeLinear -> Matmul/Gemm + dq_child_nodes = [n for n in graph.node if node.output[0] in n.input] + reshape_node = dq_child_nodes[0] + nodes_to_remove.append(reshape_node.name) + assert reshape_node.op_type == "Reshape", f"Expected Reshape node for {node.name}" + reshape_node_output = reshape_node.output[0] + + # Remove constant node from reshape node + shape_constant_name = next(input for input in reshape_node.input if "Constant" in input) + nodes_to_remove.append(tensor_producer_map[shape_constant_name].name) + + # Get the shape of the output of the reshape node + reshape_output_value_info = value_info_map.get(reshape_node_output) + if reshape_output_value_info is not None: + weight_shape = [ + dim.dim_value for dim in reshape_output_value_info.type.tensor_type.shape.dim + ] + else: + raise ValueError(f"Unable to determine shape of weight tensor {weight_name}") + + # Reshape weights and scales + weight = weight.reshape(weight_shape) + assert weight_shape[-1] % block_size == 0, ( + f"Block size {block_size} is not divisible by {weight_shape[-1]}" + ) + scale_shape = [*weight_shape[:-1], weight_shape[-1] // block_size] + scale = scale.reshape(scale_shape) + reshape_child_nodes = [n for n in graph.node if reshape_node.output[0] in n.input] + assert len(reshape_child_nodes) == 1, f"Expected exactly one child node for {node.name}" + + # Check if there's an optional Cast node between Reshape and Transpose/MatMul/Gemm + next_node = reshape_child_nodes[0] + if next_node.op_type == "Cast": + # Remove unnecessary Cast node + cast_node = next_node + nodes_to_remove.append(cast_node.name) + cast_child_nodes = [n for n in graph.node if cast_node.output[0] in n.input] + next_node = cast_child_nodes[0] + + # Transpose weights and scales if present + if next_node.op_type == "Transpose": + transpose_node = next_node + nodes_to_remove.append(transpose_node.name) + assert transpose_node.op_type == "Transpose", ( + f"Expected Transpose node for {node.name}" + ) + perm = None + for attr in transpose_node.attribute: + if attr.name == "perm": + perm = [x for x in attr.ints] # noqa: C416 + assert perm is not None, f"Permutation not found for {node.name}" + weight = weight.transpose(perm) + scale = scale.transpose(perm) + transpose_child_nodes = [ + n for n in graph.node if transpose_node.output[0] in n.input + ] + # transpose_node.input = [] + assert len(transpose_child_nodes) == 1, ( + f"Expected exactly one matmul node for {node.name}" + ) + matmul_node = transpose_child_nodes[0] + else: + matmul_node = next_node + assert matmul_node.op_type in ["MatMul", "Gemm"], ( + f"Expected MatMul or Gemm node for {node.name}" + ) + matmul_node.input[1] = node.output[0] + + if scale_name not in initializer_map: + # Remove scale producer if it's a Constant node + scale_name = node.input[1] + scale_producer = tensor_producer_map[scale_name] + if scale_producer.op_type == "Constant": + graph.node.remove(scale_producer) + + # Create a new scale tensor + scale_name = scale_name.replace("Constant_output_0", "scale") 
+                scale_tensor = onnx.numpy_helper.from_array(scale, scale_name)
+                graph.initializer.append(scale_tensor)
+                node.input[1] = scale_name
+            else:
+                scale_tensor = onnx.numpy_helper.from_array(scale, scale_name)
+                initializer_map[scale_name].CopyFrom(scale_tensor)
+
+            weight = numpy_helper.from_array(weight, weight_name)
+            initializer_map[weight_name].CopyFrom(weight)
+            logger.debug(f"Computed scales for weight {weight_name} for INT4 quantization")
+
+        # Remove transpose and reshape nodes
+        new_nodes = [node for node in graph.node if node.name not in nodes_to_remove]
+        del graph.node[:]
+        graph.node.extend(new_nodes)
+
+        return onnx_model
+
+    @staticmethod
+    def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
+        """Compresses the weights in the ONNX model for INT4 quantization."""
+        graph = onnx_model.graph
+        initializer_map = {initializer.name: initializer for initializer in graph.initializer}
+        weight_dq_nodes = [node for node in graph.node if node.op_type == "DequantizeLinear"]
+
+        for node in weight_dq_nodes:
+            weight_name = node.input[0]
+            weight = numpy_helper.to_array(initializer_map[weight_name])
+            weight_shape = weight.shape
+            weights_int4_np = pack_weights_to_int4(weight)
+            weights_int4_onnx = onnx.numpy_helper.from_array(weights_int4_np, weight_name)
+            weights_int4_onnx.data_type = onnx.TensorProto.INT4
+            weights_int4_onnx.dims[0] = weight_shape[0]
+            initializer_map[weight_name].CopyFrom(weights_int4_onnx)
+            logger.debug(f"Converted {weight_name} to INT4 precision")
+
+        return onnx_model
+
+    @staticmethod
+    def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
+        """Post-processes the ONNX model for INT4 quantization."""
+
+        def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool:
+            has_pqs_input = any(input for input in node.input if "_pre_quant_scale" in input)
+            return node.op_type == "Mul" and has_pqs_input
+
+        graph = onnx_model.graph
+        initializer_map = {initializer.name: initializer for initializer in graph.initializer}
+        nodes_to_remove = []
+
+        def is_fp32_cast(node: onnx.NodeProto) -> bool:
+            return node.op_type == "Cast" and any(
+                attr.name == "to" and attr.i == onnx.TensorProto.FLOAT for attr in node.attribute
+            )
+
+        # Remove Cast nodes after specific operators
+        for node in graph.node:
+            if node.op_type in ["Transpose", "Reshape", "Sqrt", "Add", "Gelu"]:
+                child_nodes = [n for n in graph.node if node.output[0] in n.input]
+                if len(child_nodes) == 1 and is_fp32_cast(child_nodes[0]):
+                    cast_node = child_nodes[0]
+                    node.output.clear()
+                    node.output.extend(cast_node.output)
+                    nodes_to_remove.append(cast_node.name)
+
+        # Remove unnecessary Cast after pre-quant scale
+        for node in graph.node:
+            if is_pre_quant_scale_node(node):
+                pqs_child_nodes = [n for n in graph.node if node.output[0] in n.input]
+                assert len(pqs_child_nodes) == 1, f"Expected exactly one child node for {node.name}"
+                cast_node = pqs_child_nodes[0]
+                assert cast_node.op_type == "Cast", f"Expected Cast node for {node.name}"
+                node.output.clear()
+                node.output.extend(cast_node.output)
+                nodes_to_remove.append(cast_node.name)
+
+        # Remove unnecessary casts
+        new_nodes = [node for node in graph.node if node.name not in nodes_to_remove]
+        del graph.node[:]
+        graph.node.extend(new_nodes)
+
+        # Cast bias to float16
+        for node in graph.node:
+            if node.op_type == "Add" and "proj/Add" in node.name:
+                cast_initializer_to_dtype(node, "Half", initializer_map)
+
+        # Cast pre-quant scales of o_proj and down_proj to float16
+        for node in graph.node:
+            if node.op_type == "Mul" and (
+                any(
+                    x
in node.name + for x in ("o_proj/input_quantizer/Mul", "down_proj/input_quantizer/Mul") + ) + ): + cast_initializer_to_dtype(node, "Half", initializer_map) + + return onnx_model + + +# TODO: Implement the NVFP4QuantExporter +class NVFP4QuantExporter(ONNXQuantExporter): + """Exporter for NVFP4 quantization.""" + + @staticmethod + def compute_scales(onnx_model: onnx.ModelProto) -> onnx.ModelProto: + """Computes the scales for the weights in the ONNX model for NVFP4 quantization.""" + + @staticmethod + def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto: + """Compresses the weights in the ONNX model for NVFP4 quantization.""" + + @staticmethod + def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto: + """Post-processes the ONNX model for NVFP4 quantization.""" diff --git a/modelopt/onnx/quantization/qdq_utils.py b/modelopt/onnx/quantization/qdq_utils.py index 66c613a6c..ab8c5c2c7 100644 --- a/modelopt/onnx/quantization/qdq_utils.py +++ b/modelopt/onnx/quantization/qdq_utils.py @@ -38,7 +38,6 @@ get_num_bits, get_weights_scaling_factor, get_weights_scaling_factor_2, - pack_weights_to_int4, quantize, ) from modelopt.onnx.utils import get_attribute, has_attribute @@ -1037,9 +1036,10 @@ def remove_graph_input_q(onnx_model: onnx.ModelProto) -> onnx.ModelProto: return onnx_model -def _cast_initializer_to_dtype( +def cast_initializer_to_dtype( node: onnx.NodeProto, dtype: str, initializer_map: dict[str, onnx.TensorProto] ): + """Casts the initializer to the given dtype.""" for id, input_name in enumerate(node.input): if input_name in initializer_map: input_id = id @@ -1051,180 +1051,6 @@ def _cast_initializer_to_dtype( initializer_map[input_name].CopyFrom(input_onnx) -def quantize_weights_to_int4( - onnx_model: onnx.ModelProto, -) -> onnx.ModelProto: - """Converts ONNX model weights from higher precision to INT4 precision with graph optimization. - - This function performs a comprehensive transformation of quantized weights in an ONNX model: - 1. Identifies DequantizeLinear nodes that represent quantized weights - 2. Extracts and processes weights and their corresponding scales - 3. Simplifies the graph by removing unnecessary Reshape/Transpose operations - 4. Converts weights to INT4 precision while maintaining numerical accuracy - 5. Updates Cast operations to use float16 instead of float32 - - The transformation optimizes the typical pattern: - DequantizeLinear -> Reshape -> Transpose -> MatMul/Gemm - Into the simplified pattern: - DequantizeLinear -> MatMul/Gemm - - Args: - onnx_model (onnx.ModelProto): Input ONNX model containing quantized weights. 
- - Returns: - onnx.ModelProto: Weights converted to INT4 precision - """ - graph = onnx_model.graph - initializer_map = {initializer.name: initializer for initializer in graph.initializer} - value_info_map = {value_info.name: value_info for value_info in graph.value_info} - weight_dq_nodes = [node for node in graph.node if node.op_type == "DequantizeLinear"] - tensor_producer_map = get_tensor_producer_nodes(graph) - - nodes_to_remove = [] - for node in weight_dq_nodes: - weight_name = node.input[0] - scale_name = node.input[1] - logger.debug(f"Processing INT4 conversion for weight {weight_name}") - weight = numpy_helper.to_array(initializer_map[weight_name]) - if scale_name in initializer_map: - scale = numpy_helper.to_array(initializer_map[scale_name]) - else: - scale_constant_node = tensor_producer_map[scale_name] - for attr in scale_constant_node.attribute: - if attr.name == "value": - tensor = attr.t - scale = numpy_helper.to_array(tensor) - - weight = weight / scale - block_size = weight.shape[-1] - - ## Convert DequantizeLinear -> Reshape -> Transpose -> MatMul/Gemm to DequantizeLinear -> Matmul/Gemm - dq_child_nodes = [n for n in graph.node if node.output[0] in n.input] - reshape_node = dq_child_nodes[0] - nodes_to_remove.append(reshape_node.name) - assert reshape_node.op_type == "Reshape", f"Expected Reshape node for {node.name}" - reshape_node_output = reshape_node.output[0] - - # Remove constant node from reshape node - shape_constant_name = next(input for input in reshape_node.input if "Constant" in input) - nodes_to_remove.append(tensor_producer_map[shape_constant_name].name) - - # Get the shape of the output of the reshape node - reshape_output_value_info = value_info_map.get(reshape_node_output) - if reshape_output_value_info is not None: - weight_shape = [ - dim.dim_value for dim in reshape_output_value_info.type.tensor_type.shape.dim - ] - else: - raise ValueError(f"Unable to determine shape of weight tensor {weight_name}") - - # Reshape weights and scales - weight = weight.reshape(weight_shape) - assert weight_shape[-1] % block_size == 0, ( - f"Block size {block_size} is not divisible by {weight_shape[-1]}" - ) - scale_shape = [*weight_shape[:-1], weight_shape[-1] // block_size] - scale = scale.reshape(scale_shape) - reshape_child_nodes = [n for n in graph.node if reshape_node.output[0] in n.input] - assert len(reshape_child_nodes) == 1, f"Expected exactly one child node for {node.name}" - - # Check if there's an optional Cast node between Reshape and Transpose/MatMul/Gemm - next_node = reshape_child_nodes[0] - if next_node.op_type == "Cast": - # Remove unnecessary Cast node - cast_node = next_node - nodes_to_remove.append(cast_node.name) - cast_child_nodes = [n for n in graph.node if cast_node.output[0] in n.input] - next_node = cast_child_nodes[0] - - # Transpose weights and scales if present - if next_node.op_type == "Transpose": - transpose_node = next_node - nodes_to_remove.append(transpose_node.name) - assert transpose_node.op_type == "Transpose", f"Expected Transpose node for {node.name}" - perm = None - for attr in transpose_node.attribute: - if attr.name == "perm": - perm = [x for x in attr.ints] # noqa: C416 - assert perm is not None, f"Permutation not found for {node.name}" - weight = weight.transpose(perm) - scale = scale.transpose(perm) - transpose_child_nodes = [n for n in graph.node if transpose_node.output[0] in n.input] - # transpose_node.input = [] - assert len(transpose_child_nodes) == 1, ( - f"Expected exactly one matmul node for {node.name}" - ) - 
matmul_node = transpose_child_nodes[0] - else: - matmul_node = next_node - assert matmul_node.op_type in ["MatMul", "Gemm"], ( - f"Expected MatMul or Gemm node for {node.name}" - ) - matmul_node.input[1] = node.output[0] - - if scale_name not in initializer_map: - # Remove scale producer if it's a Constant node - scale_name = node.input[1] - scale_producer = tensor_producer_map[scale_name] - if scale_producer.op_type == "Constant": - graph.node.remove(scale_producer) - - # Create a new scale tensor - scale_name = scale_name.replace("Constant_output_0", "scale") - scale_tensor = onnx.numpy_helper.from_array(scale, scale_name) - graph.initializer.append(scale_tensor) - node.input[1] = scale_name - else: - scale_tensor = onnx.numpy_helper.from_array(scale, scale_name) - initializer_map[scale_name].CopyFrom(scale_tensor) - - # Convert weights to INT4 precision - weight_shape = weight.shape - weights_int4_np = pack_weights_to_int4(weight) - weights_int4_onnx = onnx.numpy_helper.from_array(weights_int4_np, weight_name) - weights_int4_onnx.data_type = onnx.TensorProto.INT4 - weights_int4_onnx.dims[0] = weight_shape[0] - initializer_map[weight_name].CopyFrom(weights_int4_onnx) - logger.debug(f"Converted {weight_name} to INT4 precision") - - def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool: - has_pqs_input = any(input for input in node.input if "_pre_quant_scale" in input) - return node.op_type == "Mul" and has_pqs_input - - # Remove unnecessay Cast after Pre-quant scale - for node in graph.node: - if is_pre_quant_scale_node(node): - pqs_child_nodes = [n for n in graph.node if node.output[0] in n.input] - assert len(pqs_child_nodes) == 1, f"Expected exactly one child node for {node.name}" - cast_node = pqs_child_nodes[0] - assert cast_node.op_type == "Cast", f"Expected Cast node for {node.name}" - node.output.clear() - node.output.extend(cast_node.output) - nodes_to_remove.append(cast_node.name) - - # Remove transpose and reshape nodes - new_nodes = [node for node in graph.node if node.name not in nodes_to_remove] - del graph.node[:] - graph.node.extend(new_nodes) - - # Cast bias to float16 - for node in graph.node: - if node.op_type == "Add" and "proj/Add" in node.name: - _cast_initializer_to_dtype(node, "Half", initializer_map) - - # Cast pre quant scales of o_proj and down_proj to float16 - for node in graph.node: - if node.op_type == "Mul" and ( - any( - x in node.name - for x in ("o_proj/input_quantizer/Mul", "down_proj/input_quantizer/Mul") - ) - ): - _cast_initializer_to_dtype(node, "Half", initializer_map) - - return onnx_model - - def quantize_weights_to_mxfp8( onnx_model: onnx.ModelProto, ) -> onnx.ModelProto: diff --git a/modelopt/torch/_deploy/utils/torch_onnx.py b/modelopt/torch/_deploy/utils/torch_onnx.py index cfebd0dc1..1641a788f 100644 --- a/modelopt/torch/_deploy/utils/torch_onnx.py +++ b/modelopt/torch/_deploy/utils/torch_onnx.py @@ -32,12 +32,13 @@ from torch.nn.parallel import DataParallel, DistributedDataParallel from modelopt.onnx.autocast.convert import convert_to_f16 -from modelopt.onnx.quantization.qdq_utils import ( - fp4qdq_to_2dq, - qdq_to_dq, - quantize_weights_to_int4, - quantize_weights_to_mxfp8, +from modelopt.onnx.export.quant_exporter import ( + INT4QuantExporter, + MXFP8QuantExporter, + NVFP4QuantExporter, + ONNXQuantExporter, ) +from modelopt.onnx.quantization.qdq_utils import fp4qdq_to_2dq, qdq_to_dq, quantize_weights_to_mxfp8 from modelopt.onnx.utils import ( get_input_names, get_input_shapes, @@ -336,6 +337,49 @@ def is_mxfp8_quantized(model: 
nn.Module) -> bool: return False +def quantize_weights(model: nn.Module, onnx_model: onnx.ModelProto) -> onnx.ModelProto: + """Real quantizes the weights in the onnx model. + + Applies weight quantization to an ONNX model based on the quantization scheme detected + in the PyTorch model. Supports INT4, FP4, and MXFP8 quantization formats. + + The function performs a three-stage process for each detected quantization type: + 1. Compute scales - Calculate quantization scaling factors + 2. Compress weights - Convert weights to the target quantized format + 3. Post-process - Apply any final transformations or cleanup + + Args: + model (nn.Module): The original PyTorch model used to detect quantization schemes. + This model should have been quantized using modelopt's quantization APIs. + onnx_model (onnx.ModelProto): The ONNX model whose weights will be quantized. + + Returns: + onnx.ModelProto: The ONNX model with quantized weights applied. The returned model + contains compressed weight tensors in the appropriate quantization format. + + Notes: + - Multiple quantization formats can be applied sequentially if the model contains + different quantization schemes for different layers + - The function checks for INT4, FP4, and MXFP8 quantization in the PyTorch model + - Each quantization exporter modifies the ONNX graph in-place before returning + """ + + onnx_exporters: list[type[ONNXQuantExporter]] = [] + if is_int4_quantized(model): + onnx_exporters.append(INT4QuantExporter) + if is_fp4_quantized(model): + onnx_exporters.append(NVFP4QuantExporter) + if is_mxfp8_quantized(model): + onnx_exporters.append(MXFP8QuantExporter) + + for onnx_exporter in onnx_exporters: + onnx_model = onnx_exporter.compute_scales(onnx_model) + onnx_model = onnx_exporter.compress_weights(onnx_model) + onnx_model = onnx_exporter.post_process(onnx_model) + + return onnx_model + + def get_onnx_bytes_and_metadata( model: nn.Module, dummy_input: Any | tuple, @@ -482,10 +526,12 @@ def get_onnx_bytes_and_metadata( # Convert dummy TRT_FP4QDQ nodes to 2DQ format if the model is quantized in FP4 mode # Or convert weights to MXFP8 format if the model is quantized in MXFP8 mode if is_int4_quantized(model): - onnx_opt_graph = quantize_weights_to_int4(onnx_opt_graph) + onnx_opt_graph = quantize_weights(model, onnx_opt_graph) elif is_fp4_quantized(model): + # TODO: Implement the NVFP4QuantExporter onnx_opt_graph = fp4qdq_to_2dq(onnx_opt_graph) elif is_mxfp8_quantized(model): + # TODO: Implement the MXFP8QuantExporter onnx_opt_graph = quantize_weights_to_mxfp8(onnx_opt_graph) if dq_only: diff --git a/tests/unit/onnx/test_qdq_utils.py b/tests/unit/onnx/test_qdq_utils.py index e661b7c78..978d877cb 100644 --- a/tests/unit/onnx/test_qdq_utils.py +++ b/tests/unit/onnx/test_qdq_utils.py @@ -17,11 +17,11 @@ import pytest from onnx import TensorProto, helper, numpy_helper +from modelopt.onnx.export.quant_exporter import INT4QuantExporter from modelopt.onnx.quantization.qdq_utils import ( _cast_fp4, _cast_fp8, fp4qdq_to_2dq, - quantize_weights_to_int4, quantize_weights_to_mxfp8, ) @@ -337,7 +337,9 @@ def test_basic_quantization_with_reshape_transpose(self): model = create_test_model_with_int4_dq_reshape_transpose_matmul() # Run quantization - quantized_model = quantize_weights_to_int4(model) + quantized_model = INT4QuantExporter.compute_scales(model) + quantized_model = INT4QuantExporter.compress_weights(quantized_model) + quantized_model = INT4QuantExporter.post_process(quantized_model) # Verify weight is converted to INT4 weight_tensor = next( 
@@ -362,7 +364,9 @@ def test_quantization_with_constant_scale(self): model = create_test_model_with_int4_dq_reshape_transpose_matmul(constant_scale=True) # Run quantization - quantized_model = quantize_weights_to_int4(model) + quantized_model = INT4QuantExporter.compute_scales(model) + quantized_model = INT4QuantExporter.compress_weights(quantized_model) + quantized_model = INT4QuantExporter.post_process(quantized_model) # Verify Constant node is removed constant_nodes = [node for node in quantized_model.graph.node if node.op_type == "Constant"] @@ -385,7 +389,9 @@ def test_projection_bias_and_scale_casting(self): model = create_test_model_with_proj_nodes() # Run quantization - quantized_model = quantize_weights_to_int4(model) + quantized_model = INT4QuantExporter.compute_scales(model) + quantized_model = INT4QuantExporter.compress_weights(quantized_model) + quantized_model = INT4QuantExporter.post_process(quantized_model) # Verify bias tensor is cast to float16 bias_tensor = next(
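
For reference, a minimal usage sketch of the staged exporter interface introduced in modelopt/onnx/export/quant_exporter.py, mirroring how the new quantize_weights() helper in torch_onnx.py and the updated unit tests drive it. The file paths and the helper name export_int4_model are illustrative assumptions, not part of this change; only INT4QuantExporter is implemented here, while the FP8/MXFP8/INT8/NVFP4 exporters remain TODO stubs.

# Hedged sketch, not part of the patch: drive the three-stage INT4 exporter directly.
import onnx

from modelopt.onnx.export.quant_exporter import INT4QuantExporter, ONNXQuantExporter


def export_int4_model(input_path: str, output_path: str) -> None:
    """Apply the compute_scales -> compress_weights -> post_process pipeline to an ONNX model."""
    onnx_model = onnx.load(input_path)  # hypothetical path to a fake-quantized INT4 ONNX model

    exporter: type[ONNXQuantExporter] = INT4QuantExporter
    onnx_model = exporter.compute_scales(onnx_model)    # fold scales, drop Reshape/Transpose around weight DQ
    onnx_model = exporter.compress_weights(onnx_model)  # pack weight initializers into INT4
    onnx_model = exporter.post_process(onnx_model)      # drop redundant Casts, cast bias/pre-quant scales to fp16

    onnx.save(onnx_model, output_path)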