From ece073a64018e9153ca04779bf583e0cd214147b Mon Sep 17 00:00:00 2001 From: Ashutosh Parkhi Date: Wed, 5 Jan 2022 16:38:03 +0000 Subject: [PATCH 1/2] [CMSIS-NN] Support for asymmetric padding in Convolutions Change-Id: Ife2cb6e4cd0e0a2438c95f55e5482b728b9df37a --- python/tvm/relay/op/contrib/cmsisnn.py | 2 -- .../python/contrib/test_cmsisnn/test_conv2d.py | 17 ++++++----------- 2 files changed, 6 insertions(+), 13 deletions(-) diff --git a/python/tvm/relay/op/contrib/cmsisnn.py b/python/tvm/relay/op/contrib/cmsisnn.py index 9ca9b979a44d..7af47c3a81a1 100644 --- a/python/tvm/relay/op/contrib/cmsisnn.py +++ b/python/tvm/relay/op/contrib/cmsisnn.py @@ -135,8 +135,6 @@ def check_qnn_conv2d(pattern): return ( conv2d.attrs.out_dtype == "int32" - and int(conv2d.attrs.padding[0]) == int(conv2d.attrs.padding[2]) - and int(conv2d.attrs.padding[1]) == int(conv2d.attrs.padding[3]) and conv2d_input.checked_type.dtype == "int8" and conv2d_weight.checked_type.dtype == "int8" and pattern.checked_type.dtype == "int8" diff --git a/tests/python/contrib/test_cmsisnn/test_conv2d.py b/tests/python/contrib/test_cmsisnn/test_conv2d.py index 7bbbc810894e..71463c3abe63 100644 --- a/tests/python/contrib/test_cmsisnn/test_conv2d.py +++ b/tests/python/contrib/test_cmsisnn/test_conv2d.py @@ -69,8 +69,6 @@ def make_model( kernel_w = kernel_shape[w_index] invar = relay.var("input", shape=shape, dtype=dtype) p = (0, 0, 0, 0) - if padding == "INVALID": - p = [1, 2, 2, 1] if padding == "SAME": p = get_same_padding((shape[1], shape[2]), (kernel_h, kernel_w), dilation, strides) invar = relay.nn.pad( @@ -126,10 +124,10 @@ def make_model( @tvm.testing.requires_cmsisnn -@pytest.mark.parametrize("ifm_shape", [(1, 28, 28, 12), (1, 64, 100, 4)]) -@pytest.mark.parametrize("kernel_size", [(3, 3)]) +@pytest.mark.parametrize("ifm_shape", [(1, 25, 25, 12), (1, 64, 100, 4)]) +@pytest.mark.parametrize("kernel_size", [(5, 5)]) @pytest.mark.parametrize("padding", ["SAME", "VALID"]) -@pytest.mark.parametrize("strides, dilation", [((1, 1), (1, 1))]) +@pytest.mark.parametrize("strides, dilation", [((2, 2), (1, 1))]) @pytest.mark.parametrize("relu_type", ["RELU"]) @pytest.mark.parametrize("enable_bias", [True, False]) @pytest.mark.parametrize( @@ -353,19 +351,17 @@ def parameterize_for_invalid_model(test): in_dtype = ["uint8", "int8"] kernel_dtype = ["uint8", "int8"] kernel_zero_point = [-33, 10, 0] - padding = ["SAME", "INVALID"] - all_combinations = itertools.product(in_dtype, kernel_dtype, kernel_zero_point, padding) + all_combinations = itertools.product(in_dtype, kernel_dtype, kernel_zero_point) all_combinations = filter( lambda parameters: not ( parameters[0] == "int8" and parameters[1] == "int8" and parameters[2] == 0 - and parameters[3] == "SAME" ), all_combinations, ) return pytest.mark.parametrize( - ["in_dtype", "kernel_dtype", "kernel_zero_point", "padding"], + ["in_dtype", "kernel_dtype", "kernel_zero_point"], all_combinations, )(test) @@ -376,7 +372,6 @@ def test_invalid_parameters( in_dtype, kernel_dtype, kernel_zero_point, - padding, ): ifm_shape = (1, 28, 28, 12) out_channels = 2 @@ -407,7 +402,7 @@ def test_invalid_parameters( kernel_scale=kernel_scale, output_zero_point=output_zero_point, output_scale=output_scale, - padding=padding, + padding="SAME", strides=(1, 1), dilation=(1, 1), groups=1, From dd699d632cfd8c051a8f311c3d753e4d8173ef0b Mon Sep 17 00:00:00 2001 From: Ashutosh Parkhi Date: Mon, 10 Jan 2022 14:16:16 +0000 Subject: [PATCH 2/2] Fixed lint error in the CMSIS-NN unit tests Change-Id: Ifb3e2efca4788ea13a030e860a8e1aa0cef43d9f --- tests/python/contrib/test_cmsisnn/test_conv2d.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/python/contrib/test_cmsisnn/test_conv2d.py b/tests/python/contrib/test_cmsisnn/test_conv2d.py index 71463c3abe63..d8c559cec6e0 100644 --- a/tests/python/contrib/test_cmsisnn/test_conv2d.py +++ b/tests/python/contrib/test_cmsisnn/test_conv2d.py @@ -354,9 +354,7 @@ def parameterize_for_invalid_model(test): all_combinations = itertools.product(in_dtype, kernel_dtype, kernel_zero_point) all_combinations = filter( lambda parameters: not ( - parameters[0] == "int8" - and parameters[1] == "int8" - and parameters[2] == 0 + parameters[0] == "int8" and parameters[1] == "int8" and parameters[2] == 0 ), all_combinations, )