diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 2dec7e2e1ede..52211f89221a 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -66,6 +66,17 @@ def revert_caffe2_pad(pads):
     return pads
 
 
+def get_pad_pair(input1d, kernel1d, stride1d):
+    """infer pad size"""
+    if input1d % stride1d == 0:
+        pad = max(kernel1d - stride1d, 0)
+    else:
+        pad = max(kernel1d - (input1d % stride1d), 0)
+    pad_before = pad // 2
+    pad_after = pad - pad_before
+    return [pad_before, pad_after]
+
+
 def onnx_storage_order2layout(storage_order):
     """converter of onnx storage order parameter to tvm storage order format"""
     if storage_order not in (0, 1):
@@ -202,14 +213,37 @@ class Conv(OnnxOpConverter):
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        out = AttrCvt(op_name=dimension_picker('conv'),
-                      transforms={
-                          'kernel_shape': 'kernel_size',
-                          'dilations': ('dilation', (0, 0)),
-                          'pads': ('padding', (0, 0), revert_caffe2_pad),
-                          'group': ('groups', 1)},
-                      ignores=['auto_pad'],
-                      custom_check=dimension_constraint())(inputs[:2], attr, params)
+        # infer pads for auto_pad
+        if 'auto_pad' in attr:
+            attr['auto_pad'] = attr['auto_pad'].decode('utf-8')
+            if attr['auto_pad'] in ('SAME_UPPER', 'SAME_LOWER'):
+                input_shape = infer_shape(inputs[0])
+                in_h, in_w = input_shape[2], input_shape[3]
+                stride_h, stride_w = attr['strides']
+                kernel_h, kernel_w = attr['kernel_shape']
+                dilation_h, dilation_w = attr['dilations']
+                dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
+                dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
+                pad_v = get_pad_pair(in_h, dilated_kernel_h, stride_h)
+                pad_h = get_pad_pair(in_w, dilated_kernel_w, stride_w)
+                attr['pads'] = (pad_v[0], pad_h[0], pad_v[1], pad_h[1])
+            elif attr['auto_pad'] == 'VALID':
+                attr['pads'] = (0, 0)
+            elif attr['auto_pad'] == 'NOTSET':
+                pass
+            else:
+                msg = 'Value {} in attribute "auto_pad" of operator Conv is invalid.'
+                raise tvm.error.OpAttributeInvalid(msg.format(attr['auto_pad']))
+            attr.pop('auto_pad')
+
+        out = AttrCvt(
+            op_name=dimension_picker('conv'),
+            transforms={
+                'kernel_shape': 'kernel_size',
+                'dilations': ('dilation', (0, 0)),
+                'pads': ('padding', (0, 0), revert_caffe2_pad),
+                'group': ('groups', 1)},
+            custom_check=dimension_constraint())(inputs[:2], attr, params)
         use_bias = len(inputs) == 3
         if use_bias:
             out = _op.nn.bias_add(out, inputs[2])
@@ -226,6 +260,29 @@ def _impl_v1(cls, inputs, attr, params):
         attr['channels'] = channels
         groups = attr.pop('group')
         attr['groups'] = groups
+        # infer pads for auto_pad
+        if 'auto_pad' in attr:
+            attr['auto_pad'] = attr['auto_pad'].decode('utf-8')
+            if attr['auto_pad'] in ('SAME_UPPER', 'SAME_LOWER'):
+                input_shape = infer_shape(inputs[0])
+                in_h, in_w = input_shape[2], input_shape[3]
+                stride_h, stride_w = attr['strides']
+                kernel_h, kernel_w = attr['kernel_shape']
+                dilation_h, dilation_w = attr['dilations']
+                dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
+                dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
+                pad_v = get_pad_pair(in_h, dilated_kernel_h, stride_h)
+                pad_h = get_pad_pair(in_w, dilated_kernel_w, stride_w)
+                attr['pads'] = (pad_v[0], pad_h[0], pad_v[1], pad_h[1])
+            elif attr['auto_pad'] == 'VALID':
+                attr['pads'] = (0, 0)
+            elif attr['auto_pad'] == 'NOTSET':
+                pass
+            else:
+                msg = 'Value {} in attribute "auto_pad" of operator Conv is invalid.'
+                raise tvm.error.OpAttributeInvalid(msg.format(attr['auto_pad']))
+            attr.pop('auto_pad')
+
         out = AttrCvt(
             op_name=dimension_picker('conv', '_transpose'),
             transforms={
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index cdd2df4ea14f..12dee0ff534d 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -77,11 +77,14 @@ def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output
     return tvm_output.asnumpy()
 
 
-def get_onnxruntime_output(model, x, dtype='float32'):
+def get_onnxruntime_output(model, inputs, dtype='float32'):
     import onnxruntime.backend
     rep = onnxruntime.backend.prepare(model, 'CPU')
-    x = x.astype(dtype)
-    ort_out = rep.run(x)[0]
+    if isinstance(inputs, list) and len(inputs) > 1:
+        ort_out = rep.run(inputs)
+    else:
+        x = inputs.astype(dtype)
+        ort_out = rep.run(x)[0]
     return ort_out
 
 
@@ -1746,6 +1749,83 @@ def test_or():
     verify_or(indata=[x, y], dtype=bool)
 
 
+def verify_conv(x_shape, w_shape, y_shape, p):
+    node = helper.make_node('Conv',
+                            inputs=['x', 'W'],
+                            outputs=['y'],
+                            kernel_shape=[3, 3],
+                            # Default values for other attributes:
+                            # strides=[1, 1],
+                            # dilations=[1, 1],
+                            # groups=1
+                            pads=p,)
+
+    graph = helper.make_graph([node],
+                              'conv_test',
+                              inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x_shape)),
+                                      helper.make_tensor_value_info("W", TensorProto.FLOAT, list(w_shape))],
+                              outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(y_shape))])
+
+    model = helper.make_model(graph, producer_name='conv_test')
+
+    for target, ctx in ctx_list():
+        x = np.random.uniform(size=x_shape).astype('float32')
+        W = np.random.uniform(size=w_shape).astype('float32')
+        tvm_out = get_tvm_output(model, [x, W], target, ctx, y_shape)
+        onnx_out = get_onnxruntime_output(model, [x, W], 'float32')[0]
+        tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5)
+
+
+def test_conv():
+    # Convolution with padding
+    # (1, 1, 5, 5) input tensor
+    # (1, 1, 3, 3) tensor for convolution weights
+    # (1, 1, 5, 5) output tensor
+    # [1, 1, 1, 1] list for pads
+    verify_conv((1, 1, 5, 5), (1, 1, 3, 3), (1, 1, 5, 5), [1, 1, 1, 1])
+
+    # Convolution without padding
+    # (1, 1, 5, 5) input tensor
+    # (1, 1, 3, 3) tensor for convolution weights
+    # (1, 1, 3, 3) output tensor
+    # [0, 0, 0, 0] list for pads
+    verify_conv((1, 1, 5, 5), (1, 1, 3, 3), (1, 1, 3, 3), [0, 0, 0, 0])
+
+
+def verify_convtranspose(x_shape, w_shape, y_shape, p):
+    node = onnx.helper.make_node("ConvTranspose",
+                                 inputs=["x", "W"],
+                                 outputs=['y'],
+                                 strides=[3, 2],
+                                 group=1,
+                                 kernel_shape=[3, 3],
+                                 pads=p)
+
+    graph = helper.make_graph([node],
+                              'verify_convtranspose_test',
+                              inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x_shape)),
+                                      helper.make_tensor_value_info("W", TensorProto.FLOAT, list(w_shape))],
+                              outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(y_shape))])
+
+    model = helper.make_model(graph, producer_name='convtranspose_trest')
+
+    for target, ctx in ctx_list():
+        x = np.random.uniform(size=x_shape).astype('float32')
+        W = np.random.uniform(size=w_shape).astype('float32')
+        tvm_out = get_tvm_output(model, [x, W], target, ctx, y_shape)
+        onnx_out = get_onnxruntime_output(model, [x, W], 'float32')[0]
+        tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5)
+
+
+def test_convtranspose():
+    # Convolution Transpose with padding
+    # (1, 1, 3, 3) input tensor
+    # (1, 2, 3, 3) tensor for convolution weights
+    # (1, 2, 7, 3) output tensor
+    # [1, 2, 1, 2] list for pads
+    verify_convtranspose((1, 1, 3, 3), (1, 2, 3, 3), (1, 2, 7, 3), [1, 2, 1, 2])
+
+
 if __name__ == '__main__':
     test_flatten()
     test_reshape()
@@ -1800,3 +1880,5 @@ def test_or():
     test_or()
     test_depth_to_space()
     test_space_to_depth()
+    test_conv()
+    test_convtranspose()
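
For reference, a minimal standalone sketch of the SAME-padding arithmetic that the new get_pad_pair helper and the auto_pad branch above perform. The concrete shapes (5x5 input, 3x3 kernel, stride 2, dilation 1) are illustrative values chosen here, not taken from the patch:

    def get_pad_pair(input1d, kernel1d, stride1d):
        """Split the total SAME padding into [pad_before, pad_after]."""
        if input1d % stride1d == 0:
            pad = max(kernel1d - stride1d, 0)
        else:
            pad = max(kernel1d - (input1d % stride1d), 0)
        pad_before = pad // 2
        pad_after = pad - pad_before
        return [pad_before, pad_after]

    # Illustrative values: 5x5 input, 3x3 kernel, stride 2, dilation 1
    # (with dilation 1 the dilated kernel size stays 3).
    in_h, in_w = 5, 5
    stride_h, stride_w = 2, 2
    dilated_kernel_h = dilated_kernel_w = 3
    pad_v = get_pad_pair(in_h, dilated_kernel_h, stride_h)   # [1, 1]
    pad_h = get_pad_pair(in_w, dilated_kernel_w, stride_w)   # [1, 1]
    pads = (pad_v[0], pad_h[0], pad_v[1], pad_h[1])
    print(pads)  # (1, 1, 1, 1), the order the handler assigns to attr['pads']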