From 2f58db5a92b354edaef8e854ca64bad1c670af19 Mon Sep 17 00:00:00 2001 From: Matthew Date: Tue, 6 Apr 2021 10:47:40 -0600 Subject: [PATCH] ONNX bitshfit --- python/tvm/relay/frontend/onnx.py | 19 +++++++++++++++++++ tests/python/frontend/onnx/test_forward.py | 8 -------- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 669eab8cc250..b68ebc3084c3 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2782,6 +2782,24 @@ def _impl_v1(cls, inputs, attr, params): return cls._op_dispatch(operator, inputs, attr, params) +class BitShift(OnnxOpConverter): + """Operator converter for NonZero""" + + @classmethod + def _impl_v11(cls, inputs, attr, params): + if len(inputs) != 2: + raise ValueError("Bitshift expects 2 inputs") + + direction = attr.get("direction", "LEFT").decode("ascii") + if direction == "LEFT": + out = _op.left_shift(*inputs) + elif direction == "RIGHT": + out = _op.right_shift(*inputs) + else: + raise ValueError("Unsupported Shift Direction: " + direction) + return out + + # compatible operators that do NOT require any conversion. _identity_list = [] @@ -2796,6 +2814,7 @@ def _get_convert_map(opset): # defs/experimental "Identity": Renamer("copy"), "Affine": Affine.get_converter(opset), + "BitShift": BitShift.get_converter(opset), "ThresholdedRelu": ThresholdedRelu.get_converter(opset), "ScaledTanh": ScaledTanh.get_converter(opset), "ParametricSoftplus": ParametricSoftPlus.get_converter(opset), diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 7ed330dd47a9..5c6a735f901e 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -4137,14 +4137,6 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"): unsupported_onnx_tests = [ "test_basic_convinteger/", - "test_bitshift_left_uint16/", - "test_bitshift_left_uint32/", - "test_bitshift_left_uint64/", - "test_bitshift_left_uint8/", - "test_bitshift_right_uint16/", - "test_bitshift_right_uint32/", - "test_bitshift_right_uint64/", - "test_bitshift_right_uint8/", "test_cast_DOUBLE_to_FLOAT16/", "test_cast_FLOAT16_to_DOUBLE/", "test_cast_FLOAT16_to_FLOAT/",