diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 2707f6ff1c59..5397f2c30928 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -311,6 +311,73 @@ def get_converter(cls, opset):
             return getattr(cls, f"_impl_v{version}")
         raise NotImplementedError(f"opset version {version} of {cls.__name__} not implemented")
 
+
+class QuantizeLinear(OnnxOpConverter):
+    """Converts an onnx QuantizeLinear node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v10(cls, bb, inputs, attr, params):
+        x, scale = inputs[0], inputs[1]
+        zp = inputs[2] if len(inputs) > 2 and inputs[2] is not None else None
+        axis = attr.get("axis", 1)
+        # For a 0-d or 1-d input the default axis=1 would be out of range,
+        # so fall back to axis 0.
+        if hasattr(x.struct_info, "ndim") and x.struct_info.ndim <= 1 and axis == 1:
+            axis = 0
+        out_dtype = "uint8" if zp is None else zp.struct_info.dtype
+        if zp is None:
+            zp = relax.const(0, out_dtype)
+        return relax.op.quantize(x, scale, zp, axis=axis, out_dtype=out_dtype)
+
+    # Opset 13 only makes the (already handled) "axis" attribute official.
+    _impl_v13 = _impl_v10
+
+
+class DequantizeLinear(OnnxOpConverter):
+    """Converts an onnx DequantizeLinear node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v10(cls, bb, inputs, attr, params):
+        x, scale = inputs[0], inputs[1]
+        zp = inputs[2] if len(inputs) > 2 and inputs[2] is not None else None
+        axis = attr.get("axis", 1)
+        if hasattr(x.struct_info, "ndim") and x.struct_info.ndim <= 1 and axis == 1:
+            axis = 0
+        if zp is None:
+            zp = relax.const(0, x.struct_info.dtype)
+        return relax.op.dequantize(x, scale, zp, axis=axis, out_dtype="float32")
+
+    _impl_v13 = _impl_v10
+
+
+class DynamicQuantizeLinear(OnnxOpConverter):
+    """Converts an onnx DynamicQuantizeLinear node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v11(cls, bb, inputs, attr, params):
+        x = inputs[0]
+        x_dtype = x.struct_info.dtype
+        qmin = relax.const(0, x_dtype)
+        qmax = relax.const(255, x_dtype)
+
+        # Widen the data range to include 0, as required by the ONNX spec,
+        # so that 0.0 is exactly representable in the quantized domain.
+        x_max = relax.op.maximum(qmin, relax.op.max(x))
+        x_min = relax.op.minimum(qmin, relax.op.min(x))
+        y_scale = relax.op.divide(relax.op.subtract(x_max, x_min), qmax)
+
+        # zero_point = saturate(round(qmin - x_min / y_scale)) as uint8.
+        zp_fp = relax.op.subtract(qmin, relax.op.divide(x_min, y_scale))
+        y_zero_point = relax.op.astype(relax.op.round(relax.op.clip(zp_fp, 0, 255)), "uint8")
+
+        y = relax.op.quantize(x, y_scale, y_zero_point, axis=0, out_dtype="uint8")
+        return relax.Tuple([y, y_scale, y_zero_point])
+
 
 class MatMul(OnnxOpConverter):
     """Converts an onnx MatMul node into an equivalent Relax expression."""
@@ -4812,6 +4879,10 @@ def _get_convert_map():
         "ConcatFromSequence": ConcatFromSequence,
         "SplitToSequence": SplitToSequence,
         "SequenceAt": SequenceAt,
+        # Quantization
+        "QuantizeLinear": QuantizeLinear,
+        "DequantizeLinear": DequantizeLinear,
+        "DynamicQuantizeLinear": DynamicQuantizeLinear,
     }
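The DynamicQuantizeLinear lowering above follows the computation given in the ONNX operator spec: the data range is widened to include zero, the scale is that range divided by the uint8 span, and the zero point is the rounded, saturated offset. For cross-checking the formula, here is a minimal NumPy sketch of the same semantics (the function name is illustrative, not part of the patch; it assumes a non-constant input so `y_scale` is nonzero):

```python
import numpy as np

def dynamic_quantize_reference(x: np.ndarray):
    """NumPy mirror of the opset-11 DynamicQuantizeLinear computation."""
    qmin, qmax = 0.0, 255.0
    # Widen the range to include 0 so that 0.0 is exactly representable.
    x_max = max(qmin, float(x.max()))
    x_min = min(qmin, float(x.min()))
    y_scale = (x_max - x_min) / (qmax - qmin)  # qmin == 0, so this equals dividing by qmax
    y_zero_point = np.uint8(np.round(np.clip(qmin - x_min / y_scale, qmin, qmax)))
    y = np.clip(np.round(x / y_scale) + y_zero_point, qmin, qmax).astype(np.uint8)
    return y, np.float32(y_scale), y_zero_point
```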
diff --git a/python/tvm/relax/transform/legalize_ops/qdq.py b/python/tvm/relax/transform/legalize_ops/qdq.py
index caec63ffa8c0..5e28d1b29105 100644
--- a/python/tvm/relax/transform/legalize_ops/qdq.py
+++ b/python/tvm/relax/transform/legalize_ops/qdq.py
@@ -17,6 +17,7 @@
 # pylint: disable=invalid-name
 """Default legalization function for quantize/dequantize operators."""
+from typing import Union
 import tvm
 from tvm import te, tir
@@ -35,6 +36,18 @@ def is_const_scalar(x):
     return isinstance(x, (tvm.tir.IntImm, tvm.tir.FloatImm))
 
 
+def _is_singleton_qparam(qparam: te.Tensor) -> bool:
+    """Return True if qparam is a tensor whose dimensions are all equal to 1."""
+    if not isinstance(qparam, te.Tensor):
+        return False
+    if len(qparam.shape) == 0:
+        return True
+    for dim in qparam.shape:
+        if not isinstance(dim, tir.IntImm) or dim.value != 1:
+            return False
+    return True
+
+
 @register_legalize("relax.quantize")
 def _quantize(bb: BlockBuilder, call: Call) -> Expr:
     """
@@ -46,12 +59,26 @@ def _quantize(bb: BlockBuilder, call: Call) -> Expr:
 
     def te_quantize(
         data: te.Tensor,
-        scale: te.Tensor | tir.IntImm | tir.FloatImm,
-        zp: te.Tensor | tir.IntImm | tir.FloatImm,
+        scale: Union[te.Tensor, tir.IntImm, tir.FloatImm],
+        zp: Union[te.Tensor, tir.IntImm, tir.FloatImm],
     ):
+        # A singleton (all-ones-shaped) qparam is broadcast like a scalar
+        # instead of being indexed along `axis`.
+        scale_singleton = _is_singleton_qparam(scale)
+        zp_singleton = _is_singleton_qparam(zp)
+
         def quantize_compute(*indices):
-            scale_value = scale if is_const_scalar(scale) else scale[indices[axis]]
-            zp_value = zp if is_const_scalar(zp) else zp[indices[axis]]
+            if is_const_scalar(scale):
+                scale_value = scale
+            elif scale_singleton:
+                scale_value = scale[(0,) * len(scale.shape)]
+            else:
+                scale_value = scale[indices[axis]]
+
+            if is_const_scalar(zp):
+                zp_value = zp
+            elif zp_singleton:
+                zp_value = zp[(0,) * len(zp.shape)]
+            else:
+                zp_value = zp[indices[axis]]
             scaled = data[indices] / scale_value
             round_val = (te.round(scaled) if "int" in out_dtype else scaled) + zp_value
             return clip_cast(round_val, out_dtype)
@@ -94,12 +121,26 @@ def _dequantize(bb: BlockBuilder, call: Call) -> Expr:
 
     def te_dequantize(
         data: te.Tensor,
-        scale: te.Tensor | tir.IntImm | tir.FloatImm,
-        zp: te.Tensor | tir.IntImm | tir.FloatImm,
+        scale: Union[te.Tensor, tir.IntImm, tir.FloatImm],
+        zp: Union[te.Tensor, tir.IntImm, tir.FloatImm],
     ):
+        scale_singleton = _is_singleton_qparam(scale)
+        zp_singleton = _is_singleton_qparam(zp)
+
         def dequantize_compute(*indices):
-            scale_value = scale if is_const_scalar(scale) else scale[indices[axis]]
-            zp_value = zp if is_const_scalar(zp) else zp[indices[axis]]
+            if is_const_scalar(scale):
+                scale_value = scale
+            elif scale_singleton:
+                scale_value = scale[(0,) * len(scale.shape)]
+            else:
+                scale_value = scale[indices[axis]]
+
+            if is_const_scalar(zp):
+                zp_value = zp
+            elif zp_singleton:
+                zp_value = zp[(0,) * len(zp.shape)]
+            else:
+                zp_value = zp[indices[axis]]
             dtype = "float32" if "float" in data.dtype else "int32"
             sub = te.subtract(data[indices].astype(dtype), zp_value)
             out = te.multiply(sub, scale_value.astype("float32"))
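The singleton fast path above matters because a shape-[1] scale or zero point was previously indexed as `scale[indices[axis]]`, which reads out of bounds whenever the extent of `axis` exceeds 1. A standalone TE sketch contrasting the two indexing modes (placeholder names are illustrative):

```python
from tvm import te

data = te.placeholder((4, 3), "float32", name="data")
scale1 = te.placeholder((1,), "float32", name="scale1")  # singleton qparam
scale3 = te.placeholder((3,), "float32", name="scale3")  # per-axis qparam for axis=1

# Singleton: broadcast the single value; only index 0 is ever read.
out_singleton = te.compute(data.shape, lambda i, j: data[i, j] / scale1[0])
# Per-axis: one scale per slice along axis=1; scale3[j] stays in bounds.
out_per_axis = te.compute(data.shape, lambda i, j: data[i, j] / scale3[j])
```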
diff --git a/src/relax/op/tensor/qdq.cc b/src/relax/op/tensor/qdq.cc
index 406868ab4bfc..3a7a9f164a74 100644
--- a/src/relax/op/tensor/qdq.cc
+++ b/src/relax/op/tensor/qdq.cc
@@ -79,10 +79,14 @@ StructInfo InferStructInfoQuantize(const Call& call, const BlockBuilder& ctx) {
   }
 
   // Check datatype of zero_point param:
-  if (zp_sinfo->dtype != DataType::Int(8) && zp_sinfo->dtype != DataType::Float(16)) {
+  if (zp_sinfo->dtype != DataType::Int(8) && zp_sinfo->dtype != DataType::UInt(8) &&
+      zp_sinfo->dtype != DataType::Int(16) && zp_sinfo->dtype != DataType::UInt(16) &&
+      zp_sinfo->dtype != DataType::Int(32) && zp_sinfo->dtype != DataType::UInt(32) &&
+      zp_sinfo->dtype != DataType::Float(16)) {
     ctx->ReportFatal(Diagnostic::Error(call)
-                     << "zero_point param datatype should be 'int8' or 'float16', but got "
-                     << zp_sinfo->dtype);
+                     << "zero_point param datatype should be one of "
+                     << "['int8', 'uint8', 'int16', 'uint16', 'int32', 'uint32', 'float16'], "
+                     << "but got " << zp_sinfo->dtype);
   }
 
   // Check that "axis" attribute is not out of range:
@@ -104,9 +108,22 @@ StructInfo InferStructInfoQuantize(const Call& call, const BlockBuilder& ctx) {
     }
   };
 
+  // Treat a qparam whose every dimension provably equals 1 the same as a scalar.
+  auto is_scalar_or_singleton_vector = [&](const TensorStructInfo& param_sinfo) {
+    if (IsScalarTensor(param_sinfo)) return true;
+    if (param_sinfo->shape.defined() && param_sinfo->shape->IsInstance<ShapeExprNode>()) {
+      const auto& values = param_sinfo->shape.as<ShapeExprNode>()->values;
+      if (!values.empty()) {
+        return std::all_of(values.begin(), values.end(), [&](const PrimExpr& dim) {
+          return ctx->GetAnalyzer()->CanProveEqual(dim, 1);
+        });
+      }
+    }
+    return false;
+  };
+
   // Check size matching of scale/zp params with input shape at dim = attrs->axis.
-  if (!IsScalarTensor(scale_sinfo)) check_param_size(scale_sinfo, input_sinfo, "scale");
-  if (!IsScalarTensor(zp_sinfo)) check_param_size(zp_sinfo, input_sinfo, "zero_point");
+  if (!is_scalar_or_singleton_vector(scale_sinfo)) check_param_size(scale_sinfo, input_sinfo, "scale");
+  if (!is_scalar_or_singleton_vector(zp_sinfo)) check_param_size(zp_sinfo, input_sinfo, "zero_point");
 
   auto output_sinfo = ffi::make_object<TensorStructInfoNode>(*input_sinfo.get());
   output_sinfo->dtype = attrs->out_dtype;
@@ -167,10 +184,14 @@ StructInfo InferStructInfoDequantize(const Call& call, const BlockBuilder& ctx)
   }
 
   // Check datatype of zero_point param:
-  if (zp_sinfo->dtype != DataType::Int(8) && zp_sinfo->dtype != DataType::Float(16)) {
+  if (zp_sinfo->dtype != DataType::Int(8) && zp_sinfo->dtype != DataType::UInt(8) &&
+      zp_sinfo->dtype != DataType::Int(16) && zp_sinfo->dtype != DataType::UInt(16) &&
+      zp_sinfo->dtype != DataType::Int(32) && zp_sinfo->dtype != DataType::UInt(32) &&
+      zp_sinfo->dtype != DataType::Float(16)) {
     ctx->ReportFatal(Diagnostic::Error(call)
-                     << "zero_point param datatype should be 'int8' or 'float16', but got "
-                     << zp_sinfo->dtype);
+                     << "zero_point param datatype should be one of "
+                     << "['int8', 'uint8', 'int16', 'uint16', 'int32', 'uint32', 'float16'], "
+                     << "but got " << zp_sinfo->dtype);
   }
 
   // Check that "axis" attribute is not out of range:
@@ -192,9 +213,22 @@ StructInfo InferStructInfoDequantize(const Call& call, const BlockBuilder& ctx)
     }
   };
 
+  auto is_scalar_or_singleton_vector = [&](const TensorStructInfo& param_sinfo) {
+    if (IsScalarTensor(param_sinfo)) return true;
+    if (param_sinfo->shape.defined() && param_sinfo->shape->IsInstance<ShapeExprNode>()) {
+      const auto& values = param_sinfo->shape.as<ShapeExprNode>()->values;
+      if (!values.empty()) {
+        return std::all_of(values.begin(), values.end(), [&](const PrimExpr& dim) {
+          return ctx->GetAnalyzer()->CanProveEqual(dim, 1);
+        });
+      }
+    }
+    return false;
+  };
+
   // Check size matching of scale/zp params with input shape at dim = attrs->axis.
-  if (!IsScalarTensor(scale_sinfo)) check_param_size(scale_sinfo, input_sinfo, "scale");
-  if (!IsScalarTensor(zp_sinfo)) check_param_size(zp_sinfo, input_sinfo, "zero_point");
+  if (!is_scalar_or_singleton_vector(scale_sinfo)) check_param_size(scale_sinfo, input_sinfo, "scale");
+  if (!is_scalar_or_singleton_vector(zp_sinfo)) check_param_size(zp_sinfo, input_sinfo, "zero_point");
 
   auto output_sinfo = ffi::make_object<TensorStructInfoNode>(*input_sinfo.get());
   output_sinfo->dtype = attrs->out_dtype;
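With the relaxed struct-info checks, `relax.op.quantize` now accepts a shape-[1] scale/zero_point (and the wider set of zero-point dtypes) where it previously reported a fatal diagnostic. A hedged sketch of what becomes legal, assuming this patch is applied:

```python
import numpy as np
from tvm import relax

bb = relax.BlockBuilder()
x = relax.Var("x", relax.TensorStructInfo((4, 3, 2, 2), "float32"))
scale = relax.const(np.array([0.03125], dtype="float32"))  # shape [1], not rank-0
zp = relax.const(np.array([127], dtype="uint8"))           # uint8 now passes the dtype check
with bb.function("main", [x]):
    y = bb.emit(relax.op.quantize(x, scale, zp, axis=1, out_dtype="uint8"))
    bb.emit_func_output(y)
mod = bb.get()  # inference succeeds even though axis=1 has extent 3 and scale has extent 1
```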
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index db7c3da25a48..7e434d2659bd 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -5599,6 +5599,114 @@ def test_split_to_sequence_uneven_last_chunk(axis: int):
     model = helper.make_model(graph, producer_name="test_split_to_sequence_uneven")
     check_correctness(model)
 
+
+def test_quantizelinear_singleton_qparams_opset10():
+    """QuantizeLinear must treat shape-[1] scale/zp as scalar in opset10."""
+    node = helper.make_node("QuantizeLinear", ["x", "scale", "zero_point"], ["y"])
+    graph = helper.make_graph(
+        [node],
+        "quantizelinear_singleton_qparams_opset10",
+        [helper.make_tensor_value_info("x", TensorProto.FLOAT, [4, 3, 2, 2])],
+        [helper.make_tensor_value_info("y", TensorProto.UINT8, [4, 3, 2, 2])],
+        initializer=[
+            helper.make_tensor("scale", TensorProto.FLOAT, [1], [0.03125]),
+            helper.make_tensor("zero_point", TensorProto.UINT8, [1], [127]),
+        ],
+    )
+    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 10)])
+
+    x = rg.standard_normal((4, 3, 2, 2)).astype("float32")
+    check_correctness(model, inputs={"x": x}, opset=10, check_dtypes=True)
+
+
+def test_dequantizelinear_singleton_qparams_opset10():
+    """DequantizeLinear must treat shape-[1] scale/zp as scalar in opset10."""
+    node = helper.make_node("DequantizeLinear", ["x", "scale", "zero_point"], ["y"])
+    graph = helper.make_graph(
+        [node],
+        "dequantizelinear_singleton_qparams_opset10",
+        [helper.make_tensor_value_info("x", TensorProto.UINT8, [64])],
+        [helper.make_tensor_value_info("y", TensorProto.FLOAT, [64])],
+        initializer=[
+            helper.make_tensor("scale", TensorProto.FLOAT, [1], [0.125]),
+            helper.make_tensor("zero_point", TensorProto.UINT8, [1], [1]),
+        ],
+    )
+    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 10)])
+
+    x = rg.integers(low=0, high=255, size=(64,), dtype=np.uint8)
+    check_correctness(model, inputs={"x": x}, opset=10, check_dtypes=True)
+
+
+def test_quantizelinear_optional_zero_point_opset13():
+    """ONNX allows missing zero_point input; importer should default it to 0 (uint8)."""
+    node = helper.make_node("QuantizeLinear", ["x", "scale"], ["y"])
+    graph = helper.make_graph(
+        [node],
+        "quantizelinear_optional_zero_point_opset13",
+        [helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 5])],
+        [helper.make_tensor_value_info("y", TensorProto.UINT8, [2, 5])],
+        initializer=[helper.make_tensor("scale", TensorProto.FLOAT, [], [0.2])],
+    )
+    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
+
+    x = rg.standard_normal((2, 5)).astype("float32")
+    check_correctness(model, inputs={"x": x}, opset=13, check_dtypes=True)
+
+
+def test_dynamicquantizelinear_opset11():
+    """DynamicQuantizeLinear returns (y, y_scale, y_zero_point) with ORT parity."""
+    node = helper.make_node("DynamicQuantizeLinear", ["x"], ["y", "y_scale", "y_zero_point"])
+    graph = helper.make_graph(
+        [node],
+        "dynamicquantizelinear_opset11",
+        [helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 3, 4])],
+        [
+            helper.make_tensor_value_info("y", TensorProto.UINT8, [2, 3, 4]),
+            helper.make_tensor_value_info("y_scale", TensorProto.FLOAT, []),
+            helper.make_tensor_value_info("y_zero_point", TensorProto.UINT8, []),
+        ],
+    )
+    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 11)])
+
+    x = rg.standard_normal((2, 3, 4)).astype("float32")
+    check_correctness(model, inputs={"x": x}, opset=11, atol=1e-5, rtol=1e-5, check_dtypes=True)
+
+
+def test_quantizelinear_default_axis_opset10():
+    """opset10 QuantizeLinear should honor default axis=1 (not hardcode axis=0)."""
+    node = helper.make_node("QuantizeLinear", ["x", "scale", "zero_point"], ["y"])
+    graph = helper.make_graph(
+        [node],
+        "quantizelinear_axis_opset10",
+        [helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 3, 4])],
+        [helper.make_tensor_value_info("y", TensorProto.UINT8, [2, 3, 4])],
+        initializer=[
+            helper.make_tensor("scale", TensorProto.FLOAT, [3], [0.05, 0.1, 0.2]),
+            helper.make_tensor("zero_point", TensorProto.UINT8, [3], [1, 127, 250]),
+        ],
+    )
+    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 10)])
+
+    x = rg.standard_normal((2, 3, 4)).astype("float32")
+    check_correctness(model, inputs={"x": x}, opset=10, check_dtypes=True)
+
+
+def test_dequantizelinear_default_axis_opset10():
+    """opset10 DequantizeLinear should honor default axis=1 (not hardcode axis=0)."""
+    node = helper.make_node("DequantizeLinear", ["x", "scale", "zero_point"], ["y"])
+    graph = helper.make_graph(
+        [node],
+        "dequantizelinear_axis_opset10",
+        [helper.make_tensor_value_info("x", TensorProto.UINT8, [2, 3, 4])],
+        [helper.make_tensor_value_info("y", TensorProto.FLOAT, [2, 3, 4])],
+        initializer=[
+            helper.make_tensor("scale", TensorProto.FLOAT, [3], [0.05, 0.1, 0.2]),
+            helper.make_tensor("zero_point", TensorProto.UINT8, [3], [1, 127, 250]),
+        ],
+    )
+    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 10)])
+
+    x = rg.integers(low=0, high=255, size=(2, 3, 4), dtype=np.uint8)
+    check_correctness(model, inputs={"x": x}, opset=10, check_dtypes=True)
 
 
 if __name__ == "__main__":
     tvm.testing.main()
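Taken together, the new converters and the relaxed checks let a standard QDQ graph import directly. A minimal end-to-end sketch, assuming this patch is applied (`from_onnx` is the existing Relax importer entry point):

```python
from onnx import TensorProto, helper
from tvm.relax.frontend.onnx import from_onnx

q = helper.make_node("QuantizeLinear", ["x", "scale", "zp"], ["q"])
dq = helper.make_node("DequantizeLinear", ["q", "scale", "zp"], ["y"])
graph = helper.make_graph(
    [q, dq],
    "qdq_roundtrip",
    [helper.make_tensor_value_info("x", TensorProto.FLOAT, [8])],
    [helper.make_tensor_value_info("y", TensorProto.FLOAT, [8])],
    initializer=[
        helper.make_tensor("scale", TensorProto.FLOAT, [], [0.1]),
        helper.make_tensor("zp", TensorProto.UINT8, [], [128]),
    ],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
mod = from_onnx(model)  # QuantizeLinear/DequantizeLinear resolve via the new convert-map entries
mod.show()
```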