From c1a77f658ba6a81bb2d0cafcc40049fdcc238c5b Mon Sep 17 00:00:00 2001
From: liuliang
Date: Sat, 8 May 2021 15:11:35 +0800
Subject: [PATCH 1/4] Add QLinearConv for onnx frontend

---
 python/tvm/relay/frontend/onnx.py          |  93 ++++++++++
 tests/python/frontend/onnx/test_forward.py | 189 ++++++++++++++++++++-
 2 files changed, 279 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index f3282f03c813..e9a4ef6a21e0 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -2827,6 +2827,98 @@ def _impl_v11(cls, inputs, attr, params):
         )
 
 
+class QLinearConv(OnnxOpConverter):
+    """Operator converter for QLinearConv."""
+
+    @classmethod
+    def _impl_v10(cls, inputs, attr, params):
+
+        def get_scalar(input):
+            if isinstance(input, _expr.Var) and input.name_hint in params:
+                return params[input.name_hint].asnumpy()
+            return input
+
+        data = inputs[0]
+        x_scale = get_scalar(inputs[1])
+        x_zero_point = get_scalar(inputs[2])
+        weight = inputs[3]
+        w_scale = get_scalar(inputs[4])
+        w_zero_point = get_scalar(inputs[5])
+        y_scale = get_scalar(inputs[6])
+        y_zero_point = get_scalar(inputs[7])
+
+        input_shape = infer_shape(data)
+        ndim = len(input_shape)
+        kernel_type = infer_type(weight)
+        kernel_shapes = [get_const_tuple(kernel_type.checked_type.shape)]
+        if "kernel_shape" not in attr:
+            attr["kernel_shape"] = kernel_shapes[0][2:]
+
+        if "auto_pad" in attr:
+            attr["auto_pad"] = attr["auto_pad"].decode("utf-8")
+            if attr["auto_pad"] in ("SAME_UPPER", "SAME_LOWER"):
+                # Warning: Convolution does not yet support dynamic shapes,
+                # one will need to run dynamic_to_static on this model after import
+                data = autopad(
+                    data,
+                    attr.get("strides", [1] * (ndim - 2)),
+                    attr["kernel_shape"],
+                    attr.get("dilations", [1] * (ndim - 2)),
+                    ndim,
+                    pad_value=x_zero_point,
+                    mode=attr["auto_pad"],
+                )
+            elif attr["auto_pad"] == "VALID":
+                attr["pads"] = tuple([0 for i in range(ndim - 2)])
+            elif attr["auto_pad"] == "NOTSET":
+                pass
+            else:
+                msg = 'Value {} in attribute "auto_pad" of operator QLinearConv is invalid.'
+                raise tvm.error.OpAttributeInvalid(msg.format(attr["auto_pad"]))
+            attr.pop("auto_pad")
+
+        out_channels = kernel_shapes[0][0]
+        dilation = attr.get("dilations", [1] * (ndim - 2))
+        strides = attr.get("strides", [1] * (ndim - 2))
+        padding = attr["pads"] if "pads" in attr else 0
+        groups = attr["group"] if "group" in attr else 1
+
+        if ndim != 4:
+            raise tvm.error.OpAttributeInvalid("Only 2D kernels are supported for operator QLinearConv.")
+
+        out = _qnn.op.conv2d(
+            data,
+            weight,
+            _expr.const(x_zero_point, "int32"),
+            _expr.const(w_zero_point, "int32"),
+            _expr.const(x_scale),
+            _expr.const(w_scale),
+            kernel_size=attr["kernel_shape"],
+            channels=out_channels,
+            strides=strides,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+        )
+        use_bias = len(inputs) == 9
+        if use_bias:
+            out = _op.nn.bias_add(out, inputs[8])
+
+        requantize_scale = x_scale * w_scale
+        out_dtype = infer_type(inputs[7]).checked_type.dtype
+        requantized = _qnn.op.requantize(
+            out,
+            _expr.const(requantize_scale),
+            tvm.relay.const(0, "int32"),
+            _expr.const(y_scale),
+            _expr.const(y_zero_point, "int32"),
+            out_dtype=out_dtype,
+            axis=0,
+        )
+
+        return requantized
+
+
 class BitShift(OnnxOpConverter):
     """Operator converter for NonZero"""
 
 
@@ -3018,6 +3110,7 @@ def _get_convert_map(opset):
         "DequantizeLinear": DequantizeLinear.get_converter(opset),
         "DynamicQuantizeLinear": DynamicQuantizeLinear.get_converter(opset),
         "ReverseSequence": ReverseSequence.get_converter(opset),
+        "QLinearConv": QLinearConv.get_converter(opset),
     }
 
 
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 3ffeb3e4f788..0f0ae0a2a904 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -66,7 +66,8 @@ def get_tvm_output_with_vm(
 
 
 def get_tvm_output(
-    graph_def, input_data, target, device, output_shape=None, output_dtype="float32", opset=None
+    graph_def, input_data, target, device, output_shape=None,
+    output_dtype="float32", opset=None, opt_level=1,
 ):
     """Generic function to execute and get tvm output"""
     # TODO: Resolve the issues and remove the following lines
@@ -76,7 +77,8 @@ def get_tvm_output(
     input_names, shape_dict = get_input_data_shape_dict(graph_def, input_data)
 
     mod, params = relay.frontend.from_onnx(graph_def, shape_dict, opset=opset)
-    with tvm.transform.PassContext(opt_level=1):
+
+    with tvm.transform.PassContext(opt_level=opt_level):
         graph, lib, params = relay.build(mod, target, params=params)
 
     m = graph_executor.create(graph, lib, device)
@@ -135,6 +137,7 @@ def verify_with_ort_with_inputs(
     rtol=1e-5,
     atol=1e-5,
     apply_softmax=False,
+    opt_level=1,
 ):
     if opset is not None:
         model.opset_import[0].version = opset
@@ -156,7 +159,9 @@ def verify_with_ort_with_inputs(
             convert_to_static=convert_to_static,
         )
     else:
-        tvm_out = get_tvm_output(model, inputs, target, dev, out_shape, dtype, opset=opset)
+        tvm_out = get_tvm_output(
+            model, inputs, target, dev, out_shape, dtype, opset=opset, opt_level=opt_level
+        )
     if not isinstance(tvm_out, list):
         tvm_out = [tvm_out]
     if not isinstance(ort_out, list):
@@ -4387,6 +4392,183 @@ def test_reverse_sequence():
         verify_reverse_sequence(x, sequence_lens, 1, 0)
 
 
+def verify_qlinearconv(
+    x_shape,
+    w_shape,
+    y_shape,
+    padding,
+    kernel_shape,
+    strides,
+    dilations,
+    auto_pad="NOTSET",
+    bias=False,
+):
+
+    x_array = np.random.randint(low=0, high=255, size=x_shape).astype("uint8")
+    w_array = np.random.uniform(low=0, high=255, size=w_shape).astype("uint8")
+
+    initializer = [
helper.make_tensor("x_scale", TensorProto.FLOAT, (), [np.random.rand()]), + helper.make_tensor("x_zero_point", TensorProto.UINT8, (), [np.random.randint(0, 255)]), + helper.make_tensor("w_scale", TensorProto.FLOAT, (), [np.random.rand()]), + helper.make_tensor("w_zero_point", TensorProto.UINT8, (), [np.random.randint(0, 255)]), + helper.make_tensor("y_scale", TensorProto.FLOAT, (), [np.random.rand()]), + helper.make_tensor("y_zero_point", TensorProto.UINT8, (), [np.random.randint(0, 255)]), + ] + + input_nodes = [ + helper.make_tensor_value_info("x", TensorProto.UINT8, list(x_shape)), + helper.make_tensor_value_info("w", TensorProto.UINT8, list(w_shape)), + ] + input_names = [ + "x", "x_scale", "x_zero_point", "w", "w_scale", "w_zero_point", "y_scale", "y_zero_point" + ] + input_values = [x_array, w_array] + + if bias is True: + b_shape = w_shape[0:1] + b_array = np.random.randint(low=0, high=65536, size=b_shape).astype("int32") + input_nodes.append( + helper.make_tensor_value_info("B", TensorProto.INT32, list(b_shape)) + ) + input_names.append("B") + input_values.append(b_array) + + if padding is None: + ## autopadding with unset default attributes + kwargs = {} + if not all([s == 1 for s in strides]): + kwargs["strides"] = strides + if not all([d == 1 for d in dilations]): + kwargs["dilations"] = dilations + + node = helper.make_node( + "QLinearConv", + inputs=input_names, + outputs=["y"], + # Default values for other attributes: + auto_pad=auto_pad, + **kwargs, + ) + else: + node = helper.make_node( + "QLinearConv", + inputs=input_names, + outputs=["y"], + kernel_shape=kernel_shape, + # Default values for other attributes: + strides=strides, + dilations=dilations, + # groups=1 + pads=padding, + ) + + graph = helper.make_graph( + [node], + "conv_test", + inputs=input_nodes, + outputs=[helper.make_tensor_value_info("y", TensorProto.UINT8, list(y_shape))], + initializer=initializer, + ) + model = helper.make_model(graph, producer_name="qlinearconv_test") + # opt_level=1 will cause error + verify_with_ort_with_inputs(model, input_values, opt_level=2) + + +def test_qlinearconv(): + def repeat(N, D): + return tuple([N for _ in range(D)]) + # only support QLinearConv2d because only support qnn.conv2d + D = 2 + + # Convolution with padding + verify_qlinearconv( + (1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(5, D), + 2 * repeat(1, D), + repeat(3, D), + repeat(1, D), + repeat(1, D), + ) + + # Convolution with bias + verify_qlinearconv( + (1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(5, D), + 2 * repeat(1, D), + repeat(3, D), + repeat(1, D), + repeat(1, D), + bias=True, + ) + + # Convolution with assymetric padding + verify_qlinearconv( + (1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(4, D), + repeat(0, D) + repeat(1, D), + repeat(3, D), + repeat(1, D), + repeat(1, D), + ) + # Convolution without padding + verify_qlinearconv( + (1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(3, D), + 2 * repeat(0, D), + repeat(3, D), + repeat(1, D), + repeat(1, D), + ) + # Convolution with autopadding + verify_qlinearconv( + (1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(5, D), + None, + repeat(3, D), + repeat(1, D), + repeat(1, D), + auto_pad="SAME_UPPER", + ) + # Convolution with valid autopadding + verify_qlinearconv( + (1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(3, D), + None, + repeat(3, D), + repeat(1, D), + repeat(1, D), + auto_pad="VALID", + ) + # Convolution with non uniform stride + 
+    verify_qlinearconv(
+        (1, 1) + repeat(5, D),
+        (1, 1) + repeat(3, D),
+        (1, 1) + repeat(3, D),
+        None,
+        repeat(3, D),
+        repeat(2, D),
+        repeat(1, D),
+        auto_pad="SAME_UPPER",
+    )
+    # Convolution with dilation
+    verify_qlinearconv(
+        (1, 1) + repeat(5, D),
+        (1, 1) + repeat(3, D),
+        (1, 1) + repeat(5, D),
+        2 * repeat(2, D),
+        repeat(3, D),
+        repeat(1, D),
+        repeat(2, D),
+    )
+
+
 if __name__ == "__main__":
     test_flatten()
     test_reshape()
@@ -4468,3 +4650,4 @@ def test_reverse_sequence():
     test_wrong_input()
     test_aten()
     test_reverse_sequence()
+    test_qlinearconv()

From 7e40042bb452e567e5568f563f65d05d0acf81fe Mon Sep 17 00:00:00 2001
From: liuliang
Date: Sat, 8 May 2021 15:37:59 +0800
Subject: [PATCH 2/4] Reformat

---
 python/tvm/relay/frontend/onnx.py          | 13 ++++++-----
 tests/python/frontend/onnx/test_forward.py | 26 +++++++++++++++++-----
 2 files changed, 27 insertions(+), 12 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index e9a4ef6a21e0..c7c4ec48458d 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -2832,11 +2832,10 @@ class QLinearConv(OnnxOpConverter):
 
     @classmethod
     def _impl_v10(cls, inputs, attr, params):
-
-        def get_scalar(input):
-            if isinstance(input, _expr.Var) and input.name_hint in params:
-                return params[input.name_hint].asnumpy()
-            return input
+        def get_scalar(x):
+            if isinstance(x, _expr.Var) and x.name_hint in params:
+                return params[x.name_hint].asnumpy()
+            return x
 
         data = inputs[0]
         x_scale = get_scalar(inputs[1])
@@ -2884,7 +2883,9 @@ def get_scalar(input):
         groups = attr["group"] if "group" in attr else 1
 
         if ndim != 4:
-            raise tvm.error.OpAttributeInvalid("Only 2D kernels are supported for operator QLinearConv.")
+            raise tvm.error.OpAttributeInvalid(
+                "Only 2D kernels are supported for operator QLinearConv."
+            )
 
         out = _qnn.op.conv2d(
             data,
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 0f0ae0a2a904..19a72d4f826d 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -66,8 +66,14 @@ def get_tvm_output_with_vm(
 
 
 def get_tvm_output(
-    graph_def, input_data, target, device, output_shape=None,
-    output_dtype="float32", opset=None, opt_level=1,
+    graph_def,
+    input_data,
+    target,
+    device,
+    output_shape=None,
+    output_dtype="float32",
+    opset=None,
+    opt_level=1,
 ):
     """Generic function to execute and get tvm output"""
     # TODO: Resolve the issues and remove the following lines
@@ -4224,6 +4230,8 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"):
     "test_maxpool_with_argmax_2d_precomputed_strides/",
     "test_maxunpool_export_with_output_shape/",
     "test_mvn/",
+    # For QLinearConv tests in ONNX, the scale and zero_point inputs are placeholders,
+    # but qnn.conv2d requires the scale and zero_point inputs to be scalars
     "test_qlinearconv/",
     "test_qlinearmatmul_2D/",
     "test_qlinearmatmul_3D/",
@@ -4421,16 +4429,21 @@ def verify_qlinearconv(
         helper.make_tensor_value_info("w", TensorProto.UINT8, list(w_shape)),
     ]
     input_names = [
-        "x", "x_scale", "x_zero_point", "w", "w_scale", "w_zero_point", "y_scale", "y_zero_point"
+        "x",
+        "x_scale",
+        "x_zero_point",
+        "w",
+        "w_scale",
+        "w_zero_point",
+        "y_scale",
+        "y_zero_point",
     ]
     input_values = [x_array, w_array]
 
     if bias is True:
         b_shape = w_shape[0:1]
         b_array = np.random.randint(low=0, high=65536, size=b_shape).astype("int32")
-        input_nodes.append(
-            helper.make_tensor_value_info("B", TensorProto.INT32, list(b_shape))
-        )
+        input_nodes.append(helper.make_tensor_value_info("B", TensorProto.INT32, list(b_shape)))
         input_names.append("B")
         input_values.append(b_array)
 
@@ -4478,6 +4491,7 @@ def verify_qlinearconv(
 def test_qlinearconv():
     def repeat(N, D):
         return tuple([N for _ in range(D)])
+
     # only QLinearConv2d is supported because only qnn.conv2d is supported
     D = 2
 

From 21be6ef6be5670c47dc1394f10982add361efc7f Mon Sep 17 00:00:00 2001
From: liuliang
Date: Tue, 11 May 2021 11:49:18 +0800
Subject: [PATCH 3/4] Squeeze 1D tensor for weight_scale & weight_zero_point

---
 python/tvm/relay/frontend/onnx.py          | 33 ++++++++++++----------
 tests/python/frontend/onnx/test_forward.py |  4 +--
 2 files changed, 20 insertions(+), 17 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index c7c4ec48458d..a72ac727df24 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -2832,19 +2832,23 @@ class QLinearConv(OnnxOpConverter):
 
     @classmethod
     def _impl_v10(cls, inputs, attr, params):
-        def get_scalar(x):
+        def get_scalar(x, dtype="float32"):
             if isinstance(x, _expr.Var) and x.name_hint in params:
-                return params[x.name_hint].asnumpy()
-            return x
+                return _op.const(params[x.name_hint].asnumpy(), dtype)
+            rank = len(infer_shape(x))
+            assert rank <= 1, "QLinearConv scale and zero_point inputs must be scalars"
+            if rank == 1:
+                x = _op.squeeze(x, [0])
+            return _op.cast(x, dtype)
 
         data = inputs[0]
         x_scale = get_scalar(inputs[1])
-        x_zero_point = get_scalar(inputs[2])
+        x_zero_point = get_scalar(inputs[2], "int32")
         weight = inputs[3]
         w_scale = get_scalar(inputs[4])
-        w_zero_point = get_scalar(inputs[5])
+        w_zero_point = get_scalar(inputs[5], "int32")
         y_scale = get_scalar(inputs[6])
-        y_zero_point = get_scalar(inputs[7])
+        y_zero_point = get_scalar(inputs[7], "int32")
 
         input_shape = infer_shape(data)
         ndim = len(input_shape)
@@ -2864,7 +2868,7 @@ def get_scalar(x):
                 attr["kernel_shape"],
                 attr.get("dilations", [1] * (ndim - 2)),
                 ndim,
-                pad_value=x_zero_point,
+                pad_value=x_zero_point.data,
                 mode=attr["auto_pad"],
             )
         elif attr["auto_pad"] == "VALID":
@@ -2890,10 +2894,10 @@ def get_scalar(x):
         out = _qnn.op.conv2d(
             data,
             weight,
-            _expr.const(x_zero_point, "int32"),
-            _expr.const(w_zero_point, "int32"),
-            _expr.const(x_scale),
-            _expr.const(w_scale),
+            x_zero_point,
+            w_zero_point,
+            x_scale,
+            w_scale,
             kernel_size=attr["kernel_shape"],
             channels=out_channels,
             strides=strides,
@@ -2905,14 +2909,13 @@ def get_scalar(x):
         if use_bias:
             out = _op.nn.bias_add(out, inputs[8])
 
-        requantize_scale = x_scale * w_scale
         out_dtype = infer_type(inputs[7]).checked_type.dtype
         requantized = _qnn.op.requantize(
             out,
-            _expr.const(requantize_scale),
+            _op.multiply(x_scale, w_scale),
             tvm.relay.const(0, "int32"),
-            _expr.const(y_scale),
-            _expr.const(y_zero_point, "int32"),
+            y_scale,
+            y_zero_point,
             out_dtype=out_dtype,
             axis=0,
         )
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 19a72d4f826d..a1316f93903b 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -4230,8 +4230,8 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"):
     "test_maxpool_with_argmax_2d_precomputed_strides/",
     "test_maxunpool_export_with_output_shape/",
     "test_mvn/",
-    # For QLinearConv tests in ONNX, the scale and zero_point inputs are placeholders,
-    # but qnn.conv2d requires the scale and zero_point inputs to be scalars
+    # For the ONNX test model of QLinearConv, y_scale is an input node and will be parsed to a
+    # free_var in relay. It will cause an error while lowering requantize in src/relay/qnn/op/requantize.cc:148
     "test_qlinearconv/",
     "test_qlinearmatmul_2D/",
     "test_qlinearmatmul_3D/",

From 6f18bf9dac826b7f59971ec2a6a7580ff12d2e2f Mon Sep 17 00:00:00 2001
From: liuliang
Date: Wed, 12 May 2021 09:48:02 +0800
Subject: [PATCH 4/4] Do dequantize -> quantize if y_scale is not constant

---
 python/tvm/relay/frontend/onnx.py          | 29 ++++++++++++--------
 tests/python/frontend/onnx/test_forward.py |  3 ---
 2 files changed, 18 insertions(+), 14 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index a72ac727df24..2a57cba53cd2 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -2910,17 +2910,24 @@ def get_scalar(x, dtype="float32"):
             out = _op.nn.bias_add(out, inputs[8])
 
         out_dtype = infer_type(inputs[7]).checked_type.dtype
-        requantized = _qnn.op.requantize(
-            out,
-            _op.multiply(x_scale, w_scale),
-            tvm.relay.const(0, "int32"),
-            y_scale,
-            y_zero_point,
-            out_dtype=out_dtype,
-            axis=0,
-        )
-
-        return requantized
+        requantize_scale = _op.multiply(x_scale, w_scale)
+
+        # requantize requires y_scale to be constant;
+        # if y_scale is not constant, do dequantize -> quantize instead
+        if isinstance(y_scale, _expr.Constant):
+            out = _qnn.op.requantize(
+                out,
+                requantize_scale,
+                _op.const(0, dtype="int32"),
+                y_scale,
+                y_zero_point,
+                out_dtype=out_dtype,
+                axis=0,
+            )
+        else:
+            out = _qnn.op.dequantize(out, requantize_scale, _op.const(0, dtype="int32"), axis=0)
+            out = _qnn.op.quantize(out, y_scale, y_zero_point, axis=0, out_dtype=out_dtype)
+        return out
 
 
 class BitShift(OnnxOpConverter):
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index a1316f93903b..fdb8d205a244 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -4230,9 +4230,6 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"):
     "test_maxpool_with_argmax_2d_precomputed_strides/",
     "test_maxunpool_export_with_output_shape/",
     "test_mvn/",
-    # For the ONNX test model of QLinearConv, y_scale is an input node and will be parsed to a
-    # free_var in relay. It will cause an error while lowering requantize in src/relay/qnn/op/requantize.cc:148
-    "test_qlinearconv/",
     "test_qlinearmatmul_2D/",
     "test_qlinearmatmul_3D/",
     "test_resize_tf_crop_and_resize/",
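
Reviewer note: a minimal usage sketch (not part of the patch series) showing roughly how the converter added above is exercised end to end. The model file name and the "x" input shape are hypothetical; the relay/graph_executor calls are the same stock flow the test harness already uses.

import numpy as np
import onnx
import tvm
from tvm import relay
from tvm.contrib import graph_executor

# Hypothetical ONNX model containing a QLinearConv node; the shape must match its "x" input.
model = onnx.load("qlinearconv_model.onnx")
shape_dict = {"x": (1, 1, 5, 5)}

mod, params = relay.frontend.from_onnx(model, shape_dict)

# opt_level=2 mirrors the workaround noted in verify_qlinearconv above.
with tvm.transform.PassContext(opt_level=2):
    lib = relay.build(mod, target="llvm", params=params)

dev = tvm.cpu(0)
m = graph_executor.GraphModule(lib["default"](dev))
m.set_input("x", np.random.randint(0, 255, size=(1, 1, 5, 5)).astype("uint8"))
m.run()
print(m.get_output(0).asnumpy())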
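On the PATCH 4 fallback: requantizing the int32 accumulator with input scale x_scale * w_scale and zero point 0 is, up to rounding-mode details, the same computation as dequantize followed by quantize, which is why the non-constant-y_scale path is a valid substitute. A numeric sketch with made-up values:

import numpy as np

# Made-up int32 accumulator values from a quantized convolution, plus made-up scales.
acc = np.array([-120, 0, 345, 7000], dtype=np.int64)
x_scale, w_scale = 0.02, 0.05
y_scale, y_zero_point = 0.1, 128

# Constant-y_scale path: requantize the accumulator directly.
requantized = np.clip(
    np.round(acc * (x_scale * w_scale) / y_scale) + y_zero_point, 0, 255
).astype("uint8")

# Fallback path: dequantize to float, then quantize to the output dtype.
dequantized = acc * (x_scale * w_scale)
quantized = np.clip(np.round(dequantized / y_scale) + y_zero_point, 0, 255).astype("uint8")

assert (requantized == quantized).all()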