71 changes: 71 additions & 0 deletions python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -311,6 +311,73 @@ def get_converter(cls, opset):
return getattr(cls, f"_impl_v{version}")
raise NotImplementedError(f"opset version {version} of {cls.__name__} not implemented")


class QuantizeLinear(OnnxOpConverter):
    """Converts an onnx QuantizeLinear node into an equivalent Relax expression."""

@classmethod
def _impl_v10(cls, bb, inputs, attr, params):
x, scale = inputs[0], inputs[1]
zp = inputs[2] if len(inputs) > 2 and inputs[2] is not None else None
axis = attr.get("axis", 1)
if hasattr(x.struct_info, "ndim") and x.struct_info.ndim <= 1 and axis == 1:
axis = 0
out_dtype = "uint8" if zp is None else zp.struct_info.dtype
if zp is None:
zp = relax.const(0, out_dtype)
return relax.op.quantize(x, scale, zp, axis=axis, out_dtype=out_dtype)

@classmethod
def _impl_v13(cls, bb, inputs, attr, params):
x, scale = inputs[0], inputs[1]
zp = inputs[2] if len(inputs) > 2 and inputs[2] is not None else None
axis = attr.get("axis", 1)
if hasattr(x.struct_info, "ndim") and x.struct_info.ndim <= 1 and axis == 1:
axis = 0
out_dtype = "uint8" if zp is None else zp.struct_info.dtype
if zp is None:
zp = relax.const(0, out_dtype)
return relax.op.quantize(x, scale, zp, axis=axis, out_dtype=out_dtype)


class DequantizeLinear(OnnxOpConverter):
    """Converts an onnx DequantizeLinear node into an equivalent Relax expression."""

@classmethod
def _impl_v10(cls, bb, inputs, attr, params):
x, scale = inputs[0], inputs[1]
zp = inputs[2] if len(inputs) > 2 and inputs[2] is not None else None
axis = attr.get("axis", 1)
if hasattr(x.struct_info, "ndim") and x.struct_info.ndim <= 1 and axis == 1:
axis = 0
if zp is None:
zp = relax.const(0, x.struct_info.dtype)
return relax.op.dequantize(x, scale, zp, axis=axis, out_dtype="float32")

@classmethod
def _impl_v13(cls, bb, inputs, attr, params):
x, scale = inputs[0], inputs[1]
zp = inputs[2] if len(inputs) > 2 and inputs[2] is not None else None
axis = attr.get("axis", 1)
if hasattr(x.struct_info, "ndim") and x.struct_info.ndim <= 1 and axis == 1:
axis = 0
if zp is None:
zp = relax.const(0, x.struct_info.dtype)
return relax.op.dequantize(x, scale, zp, axis=axis, out_dtype="float32")


class DynamicQuantizeLinear(OnnxOpConverter):
    """Converts an onnx DynamicQuantizeLinear node into an equivalent Relax expression."""

@classmethod
def _impl_v11(cls, bb, inputs, attr, params):
x = inputs[0]
x_dtype = x.struct_info.dtype
qmin = relax.const(0, x_dtype)
qmax = relax.const(255, x_dtype)

x_max = relax.op.maximum(qmin, relax.op.max(x))
x_min = relax.op.minimum(qmin, relax.op.min(x))
y_scale = relax.op.divide(relax.op.subtract(x_max, x_min), qmax)

zp_fp = relax.op.subtract(qmin, relax.op.divide(x_min, y_scale))
y_zero_point = relax.op.astype(relax.op.round(relax.op.clip(zp_fp, 0, 255)), "uint8")

y = relax.op.quantize(x, y_scale, y_zero_point, axis=0, out_dtype="uint8")
return relax.Tuple([y, y_scale, y_zero_point])


class MatMul(OnnxOpConverter):
"""Converts an onnx MatMul node into an equivalent Relax expression."""
@@ -4812,6 +4879,10 @@ def _get_convert_map():
"ConcatFromSequence": ConcatFromSequence,
"SplitToSequence": SplitToSequence,
"SequenceAt": SequenceAt,
# Quantization
"QuantizeLinear": QuantizeLinear,
"DequantizeLinear": DequantizeLinear,
"DynamicQuantizeLinear": DynamicQuantizeLinear,
}


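For reference, the DynamicQuantizeLinear arithmetic implemented above follows the ONNX specification: the input range is widened to include zero, the scale is (max - min) / 255, and the zero point is the saturated, rounded shift. A minimal NumPy sketch of the same math, assuming a nonzero input range (the helper name dynamic_quantize_ref is ours, for illustration only):

import numpy as np

def dynamic_quantize_ref(x):
    qmin, qmax = 0.0, 255.0  # uint8 range targeted by DynamicQuantizeLinear
    # Widen the range to include zero so that 0.0 is exactly representable.
    x_max = max(0.0, float(x.max()))
    x_min = min(0.0, float(x.min()))
    y_scale = (x_max - x_min) / qmax
    # Shift so x_min maps to qmin, then saturate and round, as in the spec.
    zp = np.uint8(np.round(np.clip(qmin - x_min / y_scale, qmin, qmax)))
    y = np.clip(np.round(x / y_scale) + zp, qmin, qmax).astype(np.uint8)
    return y, np.float32(y_scale), zp
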
57 changes: 49 additions & 8 deletions python/tvm/relax/transform/legalize_ops/qdq.py
@@ -17,6 +17,7 @@
# pylint: disable=invalid-name
"""Default legalization function for quantize/dequantize operators."""

from typing import Union
import tvm
from tvm import te, tir

@@ -35,6 +36,18 @@ def is_const_scalar(x):
    return isinstance(x, (tvm.tir.IntImm, tvm.tir.FloatImm))


def _is_singleton_qparam(qparam: te.Tensor) -> bool:
"""Return True if qparam is a tensor with all dimensions equal to 1."""
if not isinstance(qparam, te.Tensor):
return False
if len(qparam.shape) == 0:
return True
for dim in qparam.shape:
        if not isinstance(dim, tir.IntImm) or dim.value != 1:
return False
return True


@register_legalize("relax.quantize")
def _quantize(bb: BlockBuilder, call: Call) -> Expr:
"""
@@ -46,12 +59,26 @@

def te_quantize(
data: te.Tensor,
        scale: te.Tensor | tir.IntImm | tir.FloatImm,
        zp: te.Tensor | tir.IntImm | tir.FloatImm,
        scale: Union[te.Tensor, tir.IntImm, tir.FloatImm],
        zp: Union[te.Tensor, tir.IntImm, tir.FloatImm],
):
scale_singleton = _is_singleton_qparam(scale) if isinstance(scale, te.Tensor) else False
zp_singleton = _is_singleton_qparam(zp) if isinstance(zp, te.Tensor) else False

def quantize_compute(*indices):
scale_value = scale if is_const_scalar(scale) else scale[indices[axis]]
zp_value = zp if is_const_scalar(zp) else zp[indices[axis]]
if is_const_scalar(scale):
scale_value = scale
elif scale_singleton:
scale_value = scale[(0,) * len(scale.shape)]
else:
scale_value = scale[indices[axis]]

if is_const_scalar(zp):
zp_value = zp
elif zp_singleton:
zp_value = zp[(0,) * len(zp.shape)]
else:
zp_value = zp[indices[axis]]
scaled = data[indices] / scale_value
round_val = (te.round(scaled) if "int" in out_dtype else scaled) + zp_value
return clip_cast(round_val, out_dtype)
@@ -94,12 +121,26 @@ def _dequantize(bb: BlockBuilder, call: Call) -> Expr:

def te_dequantize(
data: te.Tensor,
        scale: te.Tensor | tir.IntImm | tir.FloatImm,
        zp: te.Tensor | tir.IntImm | tir.FloatImm,
        scale: Union[te.Tensor, tir.IntImm, tir.FloatImm],
        zp: Union[te.Tensor, tir.IntImm, tir.FloatImm],
):
scale_singleton = _is_singleton_qparam(scale) if isinstance(scale, te.Tensor) else False
zp_singleton = _is_singleton_qparam(zp) if isinstance(zp, te.Tensor) else False

def dequantize_compute(*indices):
scale_value = scale if is_const_scalar(scale) else scale[indices[axis]]
zp_value = zp if is_const_scalar(zp) else zp[indices[axis]]
if is_const_scalar(scale):
scale_value = scale
elif scale_singleton:
scale_value = scale[(0,) * len(scale.shape)]
else:
scale_value = scale[indices[axis]]

if is_const_scalar(zp):
zp_value = zp
elif zp_singleton:
zp_value = zp[(0,) * len(zp.shape)]
else:
zp_value = zp[indices[axis]]
dtype = "float32" if "float" in data.dtype else "int32"
sub = te.subtract(data[indices].astype(dtype), zp_value)
out = te.multiply(sub, scale_value.astype("float32"))
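The singleton-qparam path above exists because a per-tensor scale or zero point stored with shape [1] used to be indexed at indices[axis], which runs out of bounds as soon as the axis dimension has more than one element. A rough NumPy sketch of the fixed indexing, with plain arrays standing in for te.Tensors:

import numpy as np

x = np.arange(6, dtype="float32").reshape(2, 3)
scale = np.array([0.5], dtype="float32")  # shape [1]: per-tensor, not per-axis
axis = 1

# Old behavior: scale[indices[axis]] fails once indices[axis] > 0 (e.g. scale[2]).
# New behavior: a singleton qparam is always read at the all-zero coordinate.
scale_value = scale[(0,) * scale.ndim]  # (0,) for shape [1], (0, 0) for [1, 1], ...
y = np.clip(np.round(x / scale_value), 0, 255).astype("uint8")
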
54 changes: 44 additions & 10 deletions src/relax/op/tensor/qdq.cc
@@ -79,10 +79,14 @@ StructInfo InferStructInfoQuantize(const Call& call, const BlockBuilder& ctx) {
}

// Check datatype of zero_point param:
if (zp_sinfo->dtype != DataType::Int(8) && zp_sinfo->dtype != DataType::Float(16)) {
if (zp_sinfo->dtype != DataType::Int(8) && zp_sinfo->dtype != DataType::UInt(8) &&
zp_sinfo->dtype != DataType::Int(16) && zp_sinfo->dtype != DataType::UInt(16) &&
zp_sinfo->dtype != DataType::Int(32) && zp_sinfo->dtype != DataType::UInt(32) &&
zp_sinfo->dtype != DataType::Float(16)) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "zero_point param datatype should be 'int8' or 'float16', but got "
<< zp_sinfo->dtype);
<< "zero_point param datatype should be one of "
<< "['int8', 'uint8', 'int16', 'uint16', 'int32', 'uint32', 'float16'], "
<< "but got " << zp_sinfo->dtype);
}

// Check that "axis" attribute is not out of range:
@@ -104,9 +108,22 @@ StructInfo InferStructInfoQuantize(const Call& call, const BlockBuilder& ctx) {
}
};

auto is_scalar_or_singleton_vector = [&](const TensorStructInfo& param_sinfo) {
if (IsScalarTensor(param_sinfo)) return true;
if (param_sinfo->shape.defined() && param_sinfo->shape->IsInstance<ShapeExprNode>()) {
const auto& values = param_sinfo->shape.as<ShapeExprNode>()->values;
if (!values.empty()) {
return std::all_of(values.begin(), values.end(), [&](const PrimExpr& dim) {
return ctx->GetAnalyzer()->CanProveEqual(dim, 1);
});
}
}
return false;
};

// Check size matching of scale/zp params with input shape at dim = attrs->axis.
if (!IsScalarTensor(scale_sinfo)) check_param_size(scale_sinfo, input_sinfo, "scale");
if (!IsScalarTensor(zp_sinfo)) check_param_size(zp_sinfo, input_sinfo, "zero_point");
if (!is_scalar_or_singleton_vector(scale_sinfo)) check_param_size(scale_sinfo, input_sinfo, "scale");
if (!is_scalar_or_singleton_vector(zp_sinfo)) check_param_size(zp_sinfo, input_sinfo, "zero_point");

auto output_sinfo = ffi::make_object<TensorStructInfoNode>(*input_sinfo.get());
output_sinfo->dtype = attrs->out_dtype;
@@ -167,10 +184,14 @@ StructInfo InferStructInfoDequantize(const Call& call, const BlockBuilder& ctx)
}

// Check datatype of zero_point param:
if (zp_sinfo->dtype != DataType::Int(8) && zp_sinfo->dtype != DataType::Float(16)) {
if (zp_sinfo->dtype != DataType::Int(8) && zp_sinfo->dtype != DataType::UInt(8) &&
zp_sinfo->dtype != DataType::Int(16) && zp_sinfo->dtype != DataType::UInt(16) &&
zp_sinfo->dtype != DataType::Int(32) && zp_sinfo->dtype != DataType::UInt(32) &&
zp_sinfo->dtype != DataType::Float(16)) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "zero_point param datatype should be 'int8' or 'float16', but got "
<< zp_sinfo->dtype);
<< "zero_point param datatype should be one of "
<< "['int8', 'uint8', 'int16', 'uint16', 'int32', 'uint32', 'float16'], "
<< "but got " << zp_sinfo->dtype);
}

// Check that "axis" attribute is not out of range:
@@ -192,9 +213,22 @@ StructInfo InferStructInfoDequantize(const Call& call, const BlockBuilder& ctx)
}
};

auto is_scalar_or_singleton_vector = [&](const TensorStructInfo& param_sinfo) {
if (IsScalarTensor(param_sinfo)) return true;
if (param_sinfo->shape.defined() && param_sinfo->shape->IsInstance<ShapeExprNode>()) {
const auto& values = param_sinfo->shape.as<ShapeExprNode>()->values;
if (!values.empty()) {
return std::all_of(values.begin(), values.end(), [&](const PrimExpr& dim) {
return ctx->GetAnalyzer()->CanProveEqual(dim, 1);
});
}
}
return false;
};

// Check size matching of scale/zp params with input shape at dim = attrs->axis.
if (!IsScalarTensor(scale_sinfo)) check_param_size(scale_sinfo, input_sinfo, "scale");
if (!IsScalarTensor(zp_sinfo)) check_param_size(zp_sinfo, input_sinfo, "zero_point");
if (!is_scalar_or_singleton_vector(scale_sinfo)) check_param_size(scale_sinfo, input_sinfo, "scale");
if (!is_scalar_or_singleton_vector(zp_sinfo)) check_param_size(zp_sinfo, input_sinfo, "zero_point");

auto output_sinfo = ffi::make_object<TensorStructInfoNode>(*input_sinfo.get());
output_sinfo->dtype = attrs->out_dtype;
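With the widened dtype check above, a uint8 zero point (the ONNX default) now passes struct-info inference instead of being rejected as it was when only int8 and float16 were accepted. A minimal sketch of a call the relaxed check admits, assuming the usual relax Python API:

from tvm import relax

x = relax.Var("x", relax.TensorStructInfo((2, 3), "float32"))
scale = relax.const(0.05, "float32")
zp = relax.const(127, "uint8")  # previously rejected by InferStructInfoQuantize
y = relax.op.quantize(x, scale, zp, axis=1, out_dtype="uint8")
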
108 changes: 108 additions & 0 deletions tests/python/relax/test_frontend_onnx.py
@@ -5599,6 +5599,114 @@ def test_split_to_sequence_uneven_last_chunk(axis: int):
model = helper.make_model(graph, producer_name="test_split_to_sequence_uneven")
check_correctness(model)


def test_quantizelinear_singleton_qparams_opset10():
"""QuantizeLinear must treat shape-[1] scale/zp as scalar in opset10."""
node = helper.make_node("QuantizeLinear", ["x", "scale", "zero_point"], ["y"])
graph = helper.make_graph(
[node],
"quantizelinear_singleton_qparams_opset10",
[helper.make_tensor_value_info("x", TensorProto.FLOAT, [4, 3, 2, 2])],
[helper.make_tensor_value_info("y", TensorProto.UINT8, [4, 3, 2, 2])],
initializer=[
helper.make_tensor("scale", TensorProto.FLOAT, [1], [0.03125]),
helper.make_tensor("zero_point", TensorProto.UINT8, [1], [127]),
],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 10)])

x = rg.standard_normal((4, 3, 2, 2)).astype("float32")
check_correctness(model, inputs={"x": x}, opset=10, check_dtypes=True)


def test_dequantizelinear_singleton_qparams_opset10():
"""DequantizeLinear must treat shape-[1] scale/zp as scalar in opset10."""
node = helper.make_node("DequantizeLinear", ["x", "scale", "zero_point"], ["y"])
graph = helper.make_graph(
[node],
"dequantizelinear_singleton_qparams_opset10",
[helper.make_tensor_value_info("x", TensorProto.UINT8, [64])],
[helper.make_tensor_value_info("y", TensorProto.FLOAT, [64])],
initializer=[
helper.make_tensor("scale", TensorProto.FLOAT, [1], [0.125]),
helper.make_tensor("zero_point", TensorProto.UINT8, [1], [1]),
],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 10)])

x = rg.integers(low=0, high=255, size=(64,), dtype=np.uint8)
check_correctness(model, inputs={"x": x}, opset=10, check_dtypes=True)


def test_quantizelinear_optional_zero_point_opset13():
"""ONNX allows missing zero_point input; importer should default it to 0 (uint8)."""
node = helper.make_node("QuantizeLinear", ["x", "scale"], ["y"])
graph = helper.make_graph(
[node],
"quantizelinear_optional_zero_point_opset13",
[helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 5])],
[helper.make_tensor_value_info("y", TensorProto.UINT8, [2, 5])],
initializer=[helper.make_tensor("scale", TensorProto.FLOAT, [], [0.2])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])

x = rg.standard_normal((2, 5)).astype("float32")
check_correctness(model, inputs={"x": x}, opset=13, check_dtypes=True)


def test_dynamicquantizelinear_opset11():
"""DynamicQuantizeLinear returns (y, y_scale, y_zero_point) with ORT parity."""
node = helper.make_node("DynamicQuantizeLinear", ["x"], ["y", "y_scale", "y_zero_point"])
graph = helper.make_graph(
[node],
"dynamicquantizelinear_opset11",
[helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 3, 4])],
[
helper.make_tensor_value_info("y", TensorProto.UINT8, [2, 3, 4]),
helper.make_tensor_value_info("y_scale", TensorProto.FLOAT, []),
helper.make_tensor_value_info("y_zero_point", TensorProto.UINT8, []),
],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 11)])

x = rg.standard_normal((2, 3, 4)).astype("float32")
check_correctness(model, inputs={"x": x}, opset=11, atol=1e-5, rtol=1e-5, check_dtypes=True)


def test_quantizelinear_default_axis_opset10():
"""opset10 QuantizeLinear should honor default axis=1 (not hardcode axis=0)."""
node = helper.make_node("QuantizeLinear", ["x", "scale", "zero_point"], ["y"])
graph = helper.make_graph(
[node],
"quantizelinear_axis_opset10",
[helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 3, 4])],
[helper.make_tensor_value_info("y", TensorProto.UINT8, [2, 3, 4])],
initializer=[
helper.make_tensor("scale", TensorProto.FLOAT, [3], [0.05, 0.1, 0.2]),
helper.make_tensor("zero_point", TensorProto.UINT8, [3], [1, 127, 250]),
],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 10)])

x = rg.standard_normal((2, 3, 4)).astype("float32")
check_correctness(model, inputs={"x": x}, opset=10, check_dtypes=True)


def test_dequantizelinear_default_axis_opset10():
"""opset10 DequantizeLinear should honor default axis=1 (not hardcode axis=0)."""
node = helper.make_node("DequantizeLinear", ["x", "scale", "zero_point"], ["y"])
graph = helper.make_graph(
[node],
"dequantizelinear_axis_opset10",
[helper.make_tensor_value_info("x", TensorProto.UINT8, [2, 3, 4])],
[helper.make_tensor_value_info("y", TensorProto.FLOAT, [2, 3, 4])],
initializer=[
helper.make_tensor("scale", TensorProto.FLOAT, [3], [0.05, 0.1, 0.2]),
helper.make_tensor("zero_point", TensorProto.UINT8, [3], [1, 127, 250]),
],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 10)])

x = rg.integers(low=0, high=255, size=(2, 3, 4), dtype=np.uint8)
check_correctness(model, inputs={"x": x}, opset=10, check_dtypes=True)


if __name__ == "__main__":
tvm.testing.main()