From d8afafec04ceb888eb2d32ea2734940ad3685be9 Mon Sep 17 00:00:00 2001 From: anijain2305 Date: Fri, 17 Apr 2020 21:52:04 +0000 Subject: [PATCH] [Topi, ARM] Disbale Winograd for quantized tensors. --- python/tvm/relay/op/strategy/arm_cpu.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index bcef8ab43a24..520de39b240f 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -59,16 +59,21 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): wrap_compute_conv2d(topi.arm_cpu.conv2d_nchw_spatial_pack), wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nchw_spatial_pack), name="conv2d_nchw_spatial_pack.arm_cpu") + # Intel x86 conv2d schedule. strategy.add_implementation( wrap_compute_conv2d(topi.x86.conv2d_nchw), wrap_topi_schedule(topi.x86.schedule_conv2d_nchw), name="conv2d_nchw.x86") + # check if winograd algorithm is applicable _, _, kh, kw = get_const_tuple(kernel.shape) pt, pl, pb, pr = topi.nn.get_pad_tuple(padding, (kh, kw)) - if kh == 3 and kw == 3 and stride_h == 1 and stride_w == 1 and \ - dilation_h == 1 and dilation_w == 1: + is_dtype_fp32 = data.dtype == "float32" and kernel.dtype == "float32" + is_winograd_applicable = kh == 3 and kw == 3 and \ + stride_h == 1 and stride_w == 1 and \ + dilation_h == 1 and dilation_w == 1 + if is_dtype_fp32 and is_winograd_applicable: strategy.add_implementation( wrap_compute_conv2d(topi.arm_cpu.conv2d_nchw_winograd), wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nchw_winograd),