Skip to content

Commit

Permalink
[TFLite][Frontend] Support quantized floor_div (#15724)
Browse files Browse the repository at this point in the history
As per #15148, we are adding support for quantized operations one by one. This PR adds support for floor_div.
  • Loading branch information
p3achyjr committed Sep 12, 2023
1 parent d8136fb commit e3055c1
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 7 deletions.
6 changes: 1 addition & 5 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -2686,11 +2686,7 @@ def convert_pad(self, op):

def convert_floor_div(self, op):
"""Convert TFLite FLOOR_DIV"""
if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
"TFlite quantized FLOOR DIV operator is not supported yet."
)
return self._convert_elemwise(_op.floor_divide, op)
return self._convert_elemwise(_op.floor_divide, op, self.is_quantized(op))

def convert_floor_mod(self, op):
"""Convert TFLite FLOOR_MOD"""
Expand Down
13 changes: 11 additions & 2 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -2747,9 +2747,16 @@ def _test_squared_difference(data, fused_activation_function=None, quantized=Fal
# ------------


def _test_floor_divide(data):
def _test_floor_divide(data, fused_activation_function=None, quantized=False, qnn_op=None):
"""One iteration of floor_div"""
return _test_elemwise(math_ops.floordiv, data)
return _test_elemwise(
math_ops.floordiv,
data,
fused_activation_function,
quantized,
qnn_op,
same_qnn_params=True,
)


#######################################################################
Expand Down Expand Up @@ -2808,6 +2815,7 @@ def _test_elemwise_qnn_out_range(qnn_op):
_test_equal: (-150, 150),
_test_greater: (-150, 150),
_test_squared_difference: (0, 65025),
_test_floor_divide: (-150, 150),
}

return qnn_out_range[qnn_op]
Expand Down Expand Up @@ -2849,6 +2857,7 @@ def test_all_elemwise():
_test_forward_elemwise(_test_not_equal)
if package_version.parse(tf.VERSION) >= package_version.parse("1.14.0"):
_test_forward_elemwise(_test_floor_divide)
_test_forward_elemwise_quantized(_test_floor_divide)
_test_forward_elemwise(_test_floor_mod)


Expand Down

0 comments on commit e3055c1

Please sign in to comment.