From fe5e6e48861a811747669d8ace1b8b2a4eaee839 Mon Sep 17 00:00:00 2001 From: Tlopex <68688494+tlopex@users.noreply.github.com> Date: Sat, 16 Sep 2023 13:56:31 +0800 Subject: [PATCH] [TFLite][Frontend] Support quantized less (#15746) * Update tflite.py * Update test_forward.py * Update test_forward.py --- python/tvm/relay/frontend/tflite.py | 4 +--- tests/python/frontend/tflite/test_forward.py | 14 ++++++++++++-- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 524b80d091ff..98920b1e496a 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -1484,9 +1484,7 @@ def convert_greater_equal(self, op): def convert_less(self, op): """Convert TFLite LESS""" - if self.is_quantized(op): - raise tvm.error.OpNotImplemented("TFlite quantized LESS operator is not supported yet.") - return self._convert_elemwise(_op.less, op) + return self._convert_elemwise(_op.less, op, self.is_quantized(op), comparison_op=True) def convert_less_equal(self, op): """Convert TFLite LESS_EQUAL""" diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index c07f24612da4..1552580ff400 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -2682,9 +2682,17 @@ def _test_greater_equal(data): # ---- -def _test_less(data): +def _test_less(data, fused_activation_function=None, quantized=False, qnn_op=None): """One iteration of less""" - return _test_elemwise(math_ops.less, data) + return _test_elemwise( + math_ops.less, + data, + fused_activation_function, + quantized, + qnn_op, + same_qnn_params=True, + comparison_op=True, + ) ####################################################################### @@ -2823,6 +2831,7 @@ def _test_elemwise_qnn_out_range(qnn_op): _test_greater: (-150, 150), _test_squared_difference: (0, 65025), _test_floor_divide: (-150, 150), + _test_less: (-150, 150), _test_floor_mod: (-150, 150), } @@ -2859,6 +2868,7 @@ def test_all_elemwise(): _test_forward_elemwise_quantized(_test_squared_difference, np.int8) _test_forward_elemwise(_test_greater_equal) _test_forward_elemwise(_test_less) + _test_forward_elemwise_quantized(_test_less) _test_forward_elemwise(_test_less_equal) _test_forward_elemwise(_test_equal) _test_forward_elemwise_quantized(_test_equal)