diff --git a/python/tvm/relay/op/contrib/arm_compute_lib.py b/python/tvm/relay/op/contrib/arm_compute_lib.py index d63cd8c83a93..6e66939ee2f0 100644 --- a/python/tvm/relay/op/contrib/arm_compute_lib.py +++ b/python/tvm/relay/op/contrib/arm_compute_lib.py @@ -285,7 +285,7 @@ def check_concatenate(expr): return False attrs, type_args = expr.attrs, expr.type_args for idx in range(len(type_args[0].fields)): - if type_args[0].fields[idx].dtype not in ["float32", "uint8"]: + if type_args[0].fields[idx].dtype not in ["float32", "uint8", "int8"]: return False # ACL concatenate only supports maximum 4 dimensions input tensor if attrs.axis not in [-4, -3, -2, -1, 0, 1, 2, 3]: diff --git a/tests/python/contrib/test_arm_compute_lib/test_concatenate.py b/tests/python/contrib/test_arm_compute_lib/test_concatenate.py index deba26a0db56..55072f37c2bf 100644 --- a/tests/python/contrib/test_arm_compute_lib/test_concatenate.py +++ b/tests/python/contrib/test_arm_compute_lib/test_concatenate.py @@ -17,6 +17,7 @@ """Arm Compute Library integration concatenate tests.""" import numpy as np +import pytest import tvm from tvm import relay @@ -88,16 +89,9 @@ def _get_expected_codegen(input_shape_a, input_shape_b, input_shape_c, axis, dty return [input_a, input_b, input_c, node] -def test_concatenate(): - Device.load("test_config.json") - - if skip_runtime_test(): - return - - device = Device() - np.random.seed(0) - - for input_shape_a, input_shape_b, input_shape_c, axis, dtype in [ +@pytest.mark.parametrize( + "input_shape_a, input_shape_b, input_shape_c, axis, dtype", + [ ([1, 234, 234, 256], [2, 234, 234, 256], [3, 234, 234, 256], 0, "float32"), ([1, 1, 234, 256], [1, 2, 234, 256], [1, 3, 234, 256], 1, "float32"), ([1, 234, 234, 1], [1, 234, 234, 2], [1, 234, 234, 3], -1, "float32"), @@ -106,29 +100,43 @@ def test_concatenate(): ([1, 1, 234, 256], [1, 2, 234, 256], [1, 3, 234, 256], 1, "uint8"), ([1, 234, 234, 1], [1, 234, 234, 2], [1, 234, 234, 3], -1, "uint8"), ([1, 234, 234, 256], [2, 234, 234, 256], [3, 234, 234, 256], -4, "uint8"), - ]: - outputs = [] - inputs = { - "a": tvm.nd.array(np.random.randn(*input_shape_a).astype(dtype)), - "b": tvm.nd.array(np.random.randn(*input_shape_b).astype(dtype)), - "c": tvm.nd.array(np.random.randn(*input_shape_c).astype(dtype)), - } - func = _get_model( - inputs["a"].shape, inputs["b"].shape, inputs["c"].shape, axis, dtype, iter(inputs) + ([1, 234, 234, 256], [2, 234, 234, 256], [3, 234, 234, 256], 0, "int8"), + ([1, 1, 234, 256], [1, 2, 234, 256], [1, 3, 234, 256], 1, "int8"), + ([1, 234, 234, 1], [1, 234, 234, 2], [1, 234, 234, 3], -1, "int8"), + ([1, 234, 234, 256], [2, 234, 234, 256], [3, 234, 234, 256], -4, "int8"), + ], +) +def test_concatenate(input_shape_a, input_shape_b, input_shape_c, axis, dtype): + Device.load("test_config.json") + + if skip_runtime_test(): + return + + device = Device() + np.random.seed(0) + + outputs = [] + inputs = { + "a": tvm.nd.array(np.random.randn(*input_shape_a).astype(dtype)), + "b": tvm.nd.array(np.random.randn(*input_shape_b).astype(dtype)), + "c": tvm.nd.array(np.random.randn(*input_shape_c).astype(dtype)), + } + func = _get_model( + inputs["a"].shape, inputs["b"].shape, inputs["c"].shape, axis, dtype, iter(inputs) + ) + for acl in [False, True]: + outputs.append( + build_and_run(func, inputs, 1, None, device, enable_acl=acl, disabled_ops=[])[0] ) - for acl in [False, True]: - outputs.append( - build_and_run(func, inputs, 1, None, device, enable_acl=acl, disabled_ops=[])[0] - ) - - config = { - "input_shape_a": input_shape_a, - "input_shape_b": input_shape_b, - "input_shape_c": input_shape_c, - "axis": axis, - "dtype": dtype, - } - verify(outputs, atol=1e-7, rtol=1e-7, config=config) + + config = { + "input_shape_a": input_shape_a, + "input_shape_b": input_shape_b, + "input_shape_c": input_shape_c, + "axis": axis, + "dtype": dtype, + } + verify(outputs, atol=1e-7, rtol=1e-7, config=config) def test_codegen_concatenate():