Skip to content

Commit

Permalink
[TFLITE]Hard Swish & MobilnetV3 model testing (#5239)
Browse files Browse the repository at this point in the history
* [TFLITE]Hard Swish & MobilnetV3 model testing

* CI Failure addressed
  • Loading branch information
siju-samuel committed Apr 7, 2020
1 parent 00a8481 commit 608e945
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 0 deletions.
37 changes: 37 additions & 0 deletions python/tvm/relay/frontend/tflite.py
Expand Up @@ -84,6 +84,7 @@ def __init__(self, model, subgraph, exp_tab):
'FULLY_CONNECTED': self.convert_fully_connected,
'GREATER_EQUAL': self.convert_greater_equal,
'GREATER': self.convert_greater,
'HARD_SWISH': self.convert_hard_swish,
'L2_NORMALIZATION': self.convert_l2_normalization,
'LESS_EQUAL': self.convert_less_equal,
'LESS': self.convert_less,
Expand Down Expand Up @@ -595,6 +596,42 @@ def convert_relu(self, op):

return out

def convert_hard_swish(self, op):
"""Convert TFLite Hard swish"""
try:
from tflite.Operator import Operator
except ImportError:
raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)

input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1"
input_tensor = input_tensors[0]
in_expr = self.get_expr(input_tensor.tensor_idx)

output_tensors = self.get_output_tensors(op)
assert len(output_tensors) == 1, "output tensors length should be 1"
output_tensor = output_tensors[0]

def _relu6(data):
return _op.tensor.clip(data, 0.0, 6.0)

def _hard_swish(data):
return data * _relu6(data + relay.const(3.0)) / relay.const(6.0)

# Dequantize if the input is quantized.
if input_tensor.qnn_params:
in_expr = self.dequantize(in_expr, input_tensor)

# Perform hardswish
out = _hard_swish(in_expr)

# Go back to integer dataype if the original operator was quantized.
if output_tensor.qnn_params:
out = self.quantize(out, output_tensor)

return out

def convert_concatenation(self, op):
"""Convert TFLite concatenation"""
try:
Expand Down
51 changes: 51 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Expand Up @@ -1625,6 +1625,26 @@ def test_forward_mobilenet_v2():
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
rtol=1e-5, atol=1e-5)

#######################################################################
# Mobilenet V3
# ------------

def test_forward_mobilenet_v3():
"""Test the Mobilenet V3 TF Lite model."""
# In MobilenetV3, some ops are not supported before tf 1.15 fbs schema
if package_version.parse(tf.VERSION) < package_version.parse('1.15.0'):
return
tflite_model_file = tf_testing.get_workload_official(
"https://storage.googleapis.com/mobilenet_v3/checkpoints/v3-large_224_1.0_float.tgz",
"v3-large_224_1.0_float/v3-large_224_1.0_float.tflite")
with open(tflite_model_file, "rb") as f:
tflite_model_buf = f.read()
data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32')
tflite_output = run_tflite_graph(tflite_model_buf, data)
tvm_output = run_tvm_graph(tflite_model_buf, data, 'input')
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
rtol=1e-5, atol=1e-5)

#######################################################################
# Inception
# ---------
Expand Down Expand Up @@ -1723,6 +1743,35 @@ def test_forward_qnn_mobilenet_v2_net():
tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)

#######################################################################
# Mobilenet V3 Quantized
# ----------------------

def test_forward_qnn_mobilenet_v3_net():
"""Test the Quantized TFLite Mobilenet V3 model."""
# In MobilenetV3, some ops are not supported before tf 1.15 fbs schema
if package_version.parse(tf.VERSION) < package_version.parse('1.15.0'):
return

tflite_model_file = tf_testing.get_workload_official(
"https://storage.googleapis.com/mobilenet_v3/checkpoints/v3-large_224_1.0_uint8.tgz",
"v3-large_224_1.0_uint8/v3-large_224_1.0_uint8.tflite")
with open(tflite_model_file, "rb") as f:
tflite_model_buf = f.read()

# Test image. Checking the labels because the requantize implementation is different between
# TFLite and Relay. This cause final output numbers to mismatch. So, testing accuracy via
# labels. Also, giving a real image, instead of random inputs.
data = get_real_image(224, 224)

tflite_output = run_tflite_graph(tflite_model_buf, data)
tflite_predictions = np.squeeze(tflite_output)
tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
tvm_output = run_tvm_graph(tflite_model_buf, data, 'input')
tvm_predictions = np.squeeze(tvm_output)
tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)

#######################################################################
# SSD Mobilenet
# -------------
Expand Down Expand Up @@ -1831,6 +1880,7 @@ def test_forward_mediapipe_hand_landmark():
# End to End
test_forward_mobilenet_v1()
test_forward_mobilenet_v2()
test_forward_mobilenet_v3()
test_forward_inception_v3_net()
test_forward_inception_v4_net()
test_forward_ssd_mobilenet_v1()
Expand All @@ -1840,3 +1890,4 @@ def test_forward_mediapipe_hand_landmark():
test_forward_qnn_inception_v1_net()
test_forward_qnn_mobilenet_v1_net()
test_forward_qnn_mobilenet_v2_net()
test_forward_qnn_mobilenet_v3_net()

0 comments on commit 608e945

Please sign in to comment.