diff --git a/modelopt/onnx/export/nvfp4_exporter.py b/modelopt/onnx/export/nvfp4_exporter.py
index a80a9845fb5..e8bdfa2db1f 100644
--- a/modelopt/onnx/export/nvfp4_exporter.py
+++ b/modelopt/onnx/export/nvfp4_exporter.py
@@ -39,6 +39,10 @@ def _cast_fp4(array: np.ndarray) -> np.ndarray:
     Note:
         The first dimension of the array must be divisible by 2 as two FP4 values
        are packed into a single byte.
+
+    Also reused by the deprecated ``modelopt.onnx.quantization.qdq_utils.fp4qdq_to_2dq``
+    compatibility shim. Do not rename or change the signature without updating that
+    shim (it is a load-bearing re-export for TensorRT-Edge-LLM 0.6.1).
     """
     array_f32_t = torch.from_numpy(array)
     array_f32_t_shape = array_f32_t.shape
@@ -76,6 +80,10 @@ def _replace_fp4qdq_with_2dq(
 ):
     """Replaces the given node in the ONNX graph with a subgraph consisting of two DequantizeLinear nodes.
 
+    Also reused by the deprecated ``modelopt.onnx.quantization.qdq_utils.fp4qdq_to_2dq``
+    compatibility shim. Do not rename or change the signature without updating that
+    shim (it is a load-bearing re-export for TensorRT-Edge-LLM 0.6.1).
+
     Args:
         graph: The ONNX graph containing the node to replace.
         node: The node to be replaced.
diff --git a/modelopt/onnx/quantization/qdq_utils.py b/modelopt/onnx/quantization/qdq_utils.py
index 265bcf36b2a..28e6f8ada8b 100644
--- a/modelopt/onnx/quantization/qdq_utils.py
+++ b/modelopt/onnx/quantization/qdq_utils.py
@@ -15,6 +15,7 @@
 
 """Various utils to support inserting Q/DQ nodes."""
 
+import warnings
 from collections.abc import Sequence
 from typing import Any
 
@@ -31,7 +32,16 @@
     get_tensor_producer_nodes,
     remove_redundant_cast_nodes,
 )
-from modelopt.onnx.quantization.quant_utils import get_num_bits
+from modelopt.onnx.quantization.quant_utils import (
+    compute_e8m0,
+    get_amax,
+    get_num_bits,
+    get_weights_scaling_factor,
+    get_weights_scaling_factor_2,
+    pack_weights_to_int4,
+    quantize,
+)
+from modelopt.onnx.utils import get_attribute, has_attribute, read_f16_tensor_as_fp32
 
 QUANTIZE_NODE_NAME = "QuantizeLinear"
 DEQUANTIZE_NODE_NAME = "DequantizeLinear"
@@ -1224,3 +1234,384 @@ def get_quantized_tensors(onnx_model: onnx.ModelProto) -> set[str]:
     logger.debug(f"Found {len(quantized_tensors)} dequantized tensors in ONNX model")
 
     return quantized_tensors
+
+
+_LEGACY_LLM_EXPORT_DEPRECATION_MSG = (
+    "{name} in modelopt.onnx.quantization.qdq_utils is deprecated and will be "
+    "removed in a future release. Use modelopt.onnx.export "
+    "(INT4QuantExporter / NVFP4QuantExporter / MXFP8QuantExporter), or migrate to "
+    "TensorRT-Edge-LLM (https://github.com/NVIDIA/TensorRT-Edge-LLM)."
+)
+
+
+def quantize_weights_to_int4(
+    onnx_model: onnx.ModelProto,
+) -> onnx.ModelProto:
+    """Deprecated: convert ONNX model weights to INT4 with graph optimization.
+
+    Preserved as a compatibility shim for TensorRT-Edge-LLM 0.6.1 and earlier.
+    New code should use :class:`modelopt.onnx.export.int4_exporter.INT4QuantExporter`.
+    """
+    warnings.warn(
+        _LEGACY_LLM_EXPORT_DEPRECATION_MSG.format(name="quantize_weights_to_int4"),
+        DeprecationWarning,
+        stacklevel=2,
+    )
+
+    graph = onnx_model.graph
+    initializer_map = {initializer.name: initializer for initializer in graph.initializer}
+    value_info_map = {value_info.name: value_info for value_info in graph.value_info}
+    weight_dq_nodes = [node for node in graph.node if node.op_type == "DequantizeLinear"]
+    tensor_producer_map = get_tensor_producer_nodes(graph)
+
+    nodes_to_remove = []
+    for node in weight_dq_nodes:
+        weight_name = node.input[0]
+        scale_name = node.input[1]
+        logger.debug(f"Processing INT4 conversion for weight {weight_name}")
+        weight = numpy_helper.to_array(initializer_map[weight_name])
+        if scale_name in initializer_map:
+            scale = numpy_helper.to_array(initializer_map[scale_name])
+        else:
+            scale_constant_node = tensor_producer_map[scale_name]
+            for attr in scale_constant_node.attribute:
+                if attr.name == "value":
+                    tensor = attr.t
+                    scale = numpy_helper.to_array(tensor)
+
+        weight = weight / scale
+        block_size = weight.shape[-1]
+
+        # Convert DequantizeLinear -> Reshape -> Transpose -> MatMul/Gemm to DequantizeLinear -> MatMul/Gemm
+        dq_child_nodes = [n for n in graph.node if node.output[0] in n.input]
+        reshape_node = dq_child_nodes[0]
+        nodes_to_remove.append(reshape_node.name)
+        assert reshape_node.op_type == "Reshape", f"Expected Reshape node for {node.name}"
+        reshape_node_output = reshape_node.output[0]
+
+        # Remove constant node from reshape node
+        shape_constant_name = next(input for input in reshape_node.input if "Constant" in input)
+        nodes_to_remove.append(tensor_producer_map[shape_constant_name].name)
+
+        # Get the shape of the output of the reshape node
+        reshape_output_value_info = value_info_map.get(reshape_node_output)
+        if reshape_output_value_info is not None:
+            weight_shape = [
+                dim.dim_value for dim in reshape_output_value_info.type.tensor_type.shape.dim
+            ]
+        else:
+            raise ValueError(f"Unable to determine shape of weight tensor {weight_name}")
+
+        # Reshape weights and scales
+        weight = weight.reshape(weight_shape)
+        assert weight_shape[-1] % block_size == 0, (
+            f"Weight dimension {weight_shape[-1]} is not divisible by block size {block_size}"
+        )
+        scale_shape = [*weight_shape[:-1], weight_shape[-1] // block_size]
+        scale = scale.reshape(scale_shape)
+        reshape_child_nodes = [n for n in graph.node if reshape_node.output[0] in n.input]
+        assert len(reshape_child_nodes) == 1, f"Expected exactly one child node for {node.name}"
+
+        # Check if there's an optional Cast node between Reshape and Transpose/MatMul/Gemm
+        next_node = reshape_child_nodes[0]
+        if next_node.op_type == "Cast":
+            # Remove unnecessary Cast node
+            cast_node = next_node
+            nodes_to_remove.append(cast_node.name)
+            cast_child_nodes = [n for n in graph.node if cast_node.output[0] in n.input]
+            next_node = cast_child_nodes[0]
+
+        # Transpose weights and scales if present
+        if next_node.op_type == "Transpose":
+            transpose_node = next_node
+            nodes_to_remove.append(transpose_node.name)
+            assert transpose_node.op_type == "Transpose", f"Expected Transpose node for {node.name}"
+            perm = None
+            for attr in transpose_node.attribute:
+                if attr.name == "perm":
+                    perm = list(attr.ints)
+            assert perm is not None, f"Permutation not found for {node.name}"
+            weight = weight.transpose(perm)
+            scale = scale.transpose(perm)
+            transpose_child_nodes = [n for n in graph.node if transpose_node.output[0] in n.input]
+            assert len(transpose_child_nodes) == 1, (
+                f"Expected exactly one MatMul node for {node.name}"
+            )
+            matmul_node = transpose_child_nodes[0]
+        else:
+            matmul_node = next_node
+        assert matmul_node.op_type in ["MatMul", "Gemm"], (
+            f"Expected MatMul or Gemm node for {node.name}"
+        )
+        matmul_node.input[1] = node.output[0]
+
+        if scale_name not in initializer_map:
+            # Remove scale producer if it's a Constant node
+            scale_name = node.input[1]
+            scale_producer = tensor_producer_map[scale_name]
+            if scale_producer.op_type == "Constant":
+                graph.node.remove(scale_producer)
+
+            # Create a new scale tensor
+            scale_name = scale_name.replace("Constant_output_0", "scale")
+            scale_tensor = onnx.numpy_helper.from_array(scale, scale_name)
+            graph.initializer.append(scale_tensor)
+            node.input[1] = scale_name
+        else:
+            scale_tensor = onnx.numpy_helper.from_array(scale, scale_name)
+            initializer_map[scale_name].CopyFrom(scale_tensor)
+
+        # Convert weights to INT4 precision
+        weight_shape = weight.shape
+        weights_int4_np = pack_weights_to_int4(weight)
+        weights_int4_onnx = onnx.numpy_helper.from_array(weights_int4_np, weight_name)
+        weights_int4_onnx.data_type = onnx.TensorProto.INT4
+        weights_int4_onnx.dims[0] = weight_shape[0]
+        initializer_map[weight_name].CopyFrom(weights_int4_onnx)
+        logger.debug(f"Converted {weight_name} to INT4 precision")
+
+    def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool:
+        has_pqs_input = any(input for input in node.input if "_pre_quant_scale" in input)
+        return node.op_type == "Mul" and has_pqs_input
+
+    # Remove unnecessary Cast after pre-quant scale
+    for node in graph.node:
+        if is_pre_quant_scale_node(node):
+            pqs_child_nodes = [n for n in graph.node if node.output[0] in n.input]
+            assert len(pqs_child_nodes) == 1, f"Expected exactly one child node for {node.name}"
+            cast_node = pqs_child_nodes[0]
+            assert cast_node.op_type == "Cast", f"Expected Cast node for {node.name}"
+            node.output.clear()
+            node.output.extend(cast_node.output)
+            nodes_to_remove.append(cast_node.name)
+
+    # Remove transpose and reshape nodes
+    new_nodes = [node for node in graph.node if node.name not in nodes_to_remove]
+    del graph.node[:]
+    graph.node.extend(new_nodes)
+
+    # Cast bias to float16
+    for node in graph.node:
+        if node.op_type == "Add" and "proj/Add" in node.name:
+            cast_initializer_to_dtype(node, "Half", initializer_map)
+
+    # Cast pre-quant scales of o_proj and down_proj to float16
+    for node in graph.node:
+        if node.op_type == "Mul" and (
+            any(
+                x in node.name
+                for x in ("o_proj/input_quantizer/Mul", "down_proj/input_quantizer/Mul")
+            )
+        ):
+            cast_initializer_to_dtype(node, "Half", initializer_map)
+
+    return onnx_model
+
+
+def quantize_weights_to_mxfp8(
+    onnx_model: onnx.ModelProto,
+) -> onnx.ModelProto:
+    """Deprecated: convert weights to MXFP8 (FP8 with e8m0 per-block scales).
+
+    Preserved as a compatibility shim for TensorRT-Edge-LLM 0.6.1 and earlier.
+    New code should use :class:`modelopt.onnx.export.mxfp8_exporter.MXFP8QuantExporter`.
+    """
+    warnings.warn(
+        _LEGACY_LLM_EXPORT_DEPRECATION_MSG.format(name="quantize_weights_to_mxfp8"),
+        DeprecationWarning,
+        stacklevel=2,
+    )
+
+    logger.info("Converting weights to MXFP8 precision")
+    graph = onnx_model.graph
+    initializer_map = {initializer.name: initializer for initializer in graph.initializer}
+    tensor_producer_map = get_tensor_producer_nodes(graph)
+    e8_m0_bias = 127
+    weight_dq_nodes = [
+        node
+        for node in graph.node
+        if node.op_type == "TRT_MXFP8DequantizeLinear"
+        and any(".weight" in input for input in node.input)
+    ]
+    gelu_nodes = [node for node in graph.node if node.op_type == "Gelu"]
+    logger.debug(f"Found {len(weight_dq_nodes)} weight DQ nodes and {len(gelu_nodes)} GELU nodes")
+
+    for node in weight_dq_nodes:
+        # Get weights and node attributes
+        weight_name = node.input[0]
+        logger.debug(f"Processing MXFP8 conversion for weight {weight_name}")
+        weight = numpy_helper.to_array(initializer_map[weight_name])
+        if has_attribute(node, "axis"):
+            quant_axis = int(get_attribute(node, "axis"))
+        else:
+            quant_axis = -1
+            logger.warning(
+                "axis attribute not found for MXFP8DequantizeLinear node. Setting axis to -1"
+            )
+
+        if has_attribute(node, "block_size"):
+            block_size = int(get_attribute(node, "block_size"))
+        else:
+            block_size = 32
+            logger.warning(
+                "block_size attribute not found for MXFP8DequantizeLinear node. "
+                "Setting block_size to 32"
+            )
+
+        # Compute and save scales as uint8
+        amax = get_amax(weight, quant_axis, block_size)
+        se8m0_fp32 = compute_e8m0(amax, weight.shape, quant_axis, block_size)
+        se8m0 = se8m0_fp32.astype(np.uint8)
+
+        # Remove scale producer if it's a Constant node
+        scale_name = node.input[1]
+        scale_producer = tensor_producer_map[scale_name]
+        if scale_producer.op_type == "Constant":
+            graph.node.remove(scale_producer)
+
+        # Create a new scale tensor
+        scale_name = scale_name.replace("Constant_output_0", "scale")
+        scale_tensor = onnx.numpy_helper.from_array(se8m0, scale_name)
+        graph.initializer.append(scale_tensor)
+        node.input[1] = scale_name
+
+        # Convert weights to FP8
+        # Expand the block scale array so that it can be broadcast against the weight
+        se8m0_fp32 = np.repeat(se8m0_fp32, block_size, axis=quant_axis)
+        scaled_weight = weight / np.exp2(se8m0_fp32 - e8_m0_bias)
+        weights_e4m3 = onnx.helper.make_tensor(
+            name=weight_name,
+            data_type=onnx_dtype_map["Float8"],
+            dims=[*scaled_weight.shape],
+            vals=_cast_fp8(scaled_weight).tobytes(),
+            raw=True,
+        )
+        initializer_map[weight_name].CopyFrom(weights_e4m3)
+        logger.debug(f"Converted {weight_name} to MXFP8")
+
+    # Set the output type of the DQ nodes to FP16
+    for node in graph.node:
+        if node.op_type in ["TRT_MXFP8DequantizeLinear"]:
+            for attr in node.attribute:
+                if attr.name == "output_dtype":
+                    attr.i = onnx_dtype_map["Half"]
+
+    # Currently only the tanh approximation is supported for Gelu
+    for node in gelu_nodes:
+        for attr in node.attribute:
+            if attr.name == "approximate":
+                attr.s = b"tanh"
+                logger.debug(f"Updated GELU node {node.name} to use tanh approximation")
+
+    return onnx_model
+
+
+def fp4qdq_to_2dq(onnx_model: onnx.ModelProto, verbose: bool = False) -> onnx.ModelProto:
+    """Deprecated: convert FP32/FP16 weights of TRT_FP4QDQ nodes to FP4 + 2 DQ subgraph.
+
+    Preserved as a compatibility shim for TensorRT-Edge-LLM 0.6.1 and earlier.
+    New code should use :class:`modelopt.onnx.export.nvfp4_exporter.NVFP4QuantExporter`.
+    """
+    warnings.warn(
+        _LEGACY_LLM_EXPORT_DEPRECATION_MSG.format(name="fp4qdq_to_2dq"),
+        DeprecationWarning,
+        stacklevel=2,
+    )
+
+    # Lazy import to avoid a circular import: nvfp4_exporter imports from this module.
+    from modelopt.onnx.export.nvfp4_exporter import _cast_fp4, _replace_fp4qdq_with_2dq
+
+    logger.info("Converting model with FP4QDQ nodes to 2DQ only model")
+    graph = onnx_model.graph
+    initializers = graph.initializer
+    initializers_to_delete = []
+    tensor_consumers = get_tensor_consumer_nodes(graph)
+    initializer_indices = {
+        initializer.name: idx for idx, initializer in enumerate(graph.initializer)
+    }
+    value_info_map = {vi.name: vi for vi in graph.value_info}
+    graph_inputs = {inp.name for inp in graph.input}
+
+    def _cast_input_dtypes(node: onnx.NodeProto, precision_dtype: str):
+        # Change the input types to match the weight precision (precision_dtype)
+        if node.op_type == "Transpose":
+            maybe_matmul = tensor_consumers[node.output[0]][0]
+            assert maybe_matmul.op_type == "MatMul"
+            node = maybe_matmul
+
+        # Create Cast nodes for each input of the target node except bias
+        for i, input_name in enumerate(node.input[:2]):
+            cast_output_name = input_name + "_f16"
+
+            cast_node = onnx.helper.make_node(
+                "Cast",
+                inputs=[input_name],
+                outputs=[cast_output_name],
+                to=onnx_dtype_map[precision_dtype],
+            )
+
+            graph.node.extend([cast_node])
+            node.input[i] = cast_output_name
+
+    def _get_precision_dtype() -> str:
+        precision_dtype = "Half"
+        for initializer in graph.initializer:
+            if initializer.data_type == onnx.TensorProto.BFLOAT16:
+                precision_dtype = "BFloat16"
+                break
+        return precision_dtype
+
+    if verbose:
+        logger.info("Post-processing TRT_FP4QDQ nodes for TRT deployment")
+    precision_dtype = _get_precision_dtype()
+    logger.debug(f"Using precision dtype: {precision_dtype}")
+    fp4_qdq_nodes = [node for node in graph.node if node.op_type == "TRT_FP4QDQ"]
+    logger.debug(f"Found {len(fp4_qdq_nodes)} FP4QDQ nodes to convert")
+
+    for node in fp4_qdq_nodes:
+        idx1 = initializer_indices.get(node.input[0], None)
+        assert idx1 is not None, f"Initializer for weight '{node.input[0]}' not found."
+        block_size_attr = next(
+            (attr for attr in node.attribute if attr.name == "block_size"), None
+        )
+        assert block_size_attr is not None, f"block_size attribute not found for {node.name}"
+        block_size = block_size_attr.i
+        initializers_to_delete.append(initializers[idx1].name)
+        logger.debug(
+            f"Processing FP4QDQ node for weight {node.input[0]} with block size {block_size}"
+        )
+
+        tensor = initializers[idx1]
+        w32 = read_f16_tensor_as_fp32(tensor)
+        sw_f32_per_tensor = get_weights_scaling_factor_2(w32)
+        sw_f32_per_block = get_weights_scaling_factor(w32, block_size, sw_f32_per_tensor)
+        w_f32 = quantize(w32, block_size, sw_f32_per_block, sw_f32_per_tensor)
+
+        # Cast to the real quantized types: FP4 weights and FP8 per-block scales
+        w_f4 = _cast_fp4(w_f32)
+        sw_f8_per_block = _cast_fp8(sw_f32_per_block)
+
+        _replace_fp4qdq_with_2dq(
+            graph,
+            node,
+            initializer_indices,
+            value_info_map,
+            graph_inputs,
+            w_f4,
+            sw_f32_per_tensor,
+            sw_f8_per_block,
+            block_size,
+        )
+
+        # Cast the inputs of the consuming MatMul (and similar nodes) to the weight precision dtype
+        next_node = tensor_consumers[node.output[0]][0]
+        _cast_input_dtypes(next_node, precision_dtype)
+
+        if verbose:
+            logger.debug(f"Replaced {node.name} with 2 DQ nodes")
+
+    new_initializers = [
+        init for init in graph.initializer if init.name not in initializers_to_delete
+    ]
+    graph.ClearField("initializer")
+    graph.initializer.extend(new_initializers)
+    logger.info(f"Removed {len(initializers_to_delete)} initializers")
+
+    return onnx_model
diff --git a/tests/unit/onnx/quantization/test_qdq_utils.py b/tests/unit/onnx/quantization/test_qdq_utils.py
index 8af5f560dd0..0ff3686a610 100644
--- a/tests/unit/onnx/quantization/test_qdq_utils.py
+++ b/tests/unit/onnx/quantization/test_qdq_utils.py
@@ -1108,3 +1108,96 @@ def test_constant_node_scale_path_still_patched(self):
         scale_arr = numpy_helper.to_array(value_attr.t)
         assert not (scale_arr == 0).any()
         assert (scale_arr > 0).all()
+
+
+class TestLegacyEdgeLLMShims:
+    """Smoke tests for the deprecated top-level shims kept for TensorRT-Edge-LLM 0.6.1.
+
+    These are the functions TensorRT-Edge-LLM 0.6.1 imports from
+    ``modelopt.onnx.quantization.qdq_utils`` directly (not via the staged exporters).
+    The tests verify that each shim runs end-to-end on the same fixtures used for the
+    staged exporters and emits a ``DeprecationWarning``.
+    """
+
+    def test_quantize_weights_to_int4_shim(self):
+        import warnings
+
+        from modelopt.onnx.quantization.qdq_utils import quantize_weights_to_int4
+
+        model = create_test_model_with_int4_dq_reshape_transpose_matmul()
+
+        with warnings.catch_warnings(record=True) as caught:
+            warnings.simplefilter("always")
+            quantized_model = quantize_weights_to_int4(model)
+
+        assert any(
+            issubclass(w.category, DeprecationWarning)
+            and "quantize_weights_to_int4" in str(w.message)
+            for w in caught
+        )
+
+        weight_tensor = next(
+            init for init in quantized_model.graph.initializer if init.name == "weight"
+        )
+        assert weight_tensor.data_type == TensorProto.INT4
+
+        node_types = [node.op_type for node in quantized_model.graph.node]
+        assert "Reshape" not in node_types
+        assert "Transpose" not in node_types
+
+    def test_quantize_weights_to_mxfp8_shim(self):
+        import warnings
+
+        from modelopt.onnx.quantization.qdq_utils import quantize_weights_to_mxfp8
+
+        model = create_test_model_with_mxfp8_dq()
+
+        with warnings.catch_warnings(record=True) as caught:
+            warnings.simplefilter("always")
+            quantized_model = quantize_weights_to_mxfp8(model)
+
+        assert any(
+            issubclass(w.category, DeprecationWarning)
+            and "quantize_weights_to_mxfp8" in str(w.message)
+            for w in caught
+        )
+
+        weight_tensor = next(
+            init for init in quantized_model.graph.initializer if init.name == "linear.weight"
+        )
+        assert weight_tensor.data_type == TensorProto.FLOAT8E4M3FN
+
+        gelu_node = next(node for node in quantized_model.graph.node if node.op_type == "Gelu")
+        approximate_attr = next(
+            attr for attr in gelu_node.attribute if attr.name == "approximate"
+        )
+        assert approximate_attr.s == b"tanh"
+
+    @pytest.mark.parametrize("with_transpose", [False, True])
+    def test_fp4qdq_to_2dq_shim(self, with_transpose):
+        import warnings
+
+        from modelopt.onnx.quantization.qdq_utils import fp4qdq_to_2dq
+
+        model = create_test_model_with_nvfp4_qdq(with_transpose=with_transpose)
+
+        with warnings.catch_warnings(record=True) as caught:
+            warnings.simplefilter("always")
+            converted_model = fp4qdq_to_2dq(model)
+
+        assert any(
+            issubclass(w.category, DeprecationWarning) and "fp4qdq_to_2dq" in str(w.message)
+            for w in caught
+        )
+
+        fp4qdq_nodes = [
+            node for node in converted_model.graph.node if node.op_type == "TRT_FP4QDQ"
+        ]
+        assert len(fp4qdq_nodes) == 0
+
+        dq_nodes = [
+            node for node in converted_model.graph.node if node.op_type == "DequantizeLinear"
+        ]
+        assert len(dq_nodes) == 2
+
+        initializer_names = {init.name for init in converted_model.graph.initializer}
+        assert "linear.weight_f4" in initializer_names
+        assert "linear.weight_f8_scale" in initializer_names
+        assert "linear.weight_f8_scale_f32_scale" in initializer_names
+        assert "linear.weight" not in initializer_names
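
For reference, a minimal usage sketch (not part of the patch) of how a downstream caller on the deprecated path would exercise one of these shims. The model path is a placeholder, and the warning handling mirrors the smoke tests above:

    # Illustrative only; "model_fp4qdq.onnx" is a hypothetical path.
    import warnings

    import onnx

    from modelopt.onnx.quantization.qdq_utils import fp4qdq_to_2dq

    model = onnx.load("model_fp4qdq.onnx")
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # Emits a DeprecationWarning pointing callers to modelopt.onnx.export / TensorRT-Edge-LLM.
        converted = fp4qdq_to_2dq(model)
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)
    onnx.save(converted, "model_2dq.onnx")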