Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Topi, ARM] Disbale Winograd for quantized tensors. #5363

Merged
merged 2 commits into from Apr 21, 2020
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 7 additions & 2 deletions python/tvm/relay/op/strategy/arm_cpu.py
Expand Up @@ -59,16 +59,21 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
wrap_compute_conv2d(topi.arm_cpu.conv2d_nchw_spatial_pack),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nchw_spatial_pack),
name="conv2d_nchw_spatial_pack.arm_cpu")

# Intel x86 conv2d schedule.
strategy.add_implementation(
wrap_compute_conv2d(topi.x86.conv2d_nchw),
wrap_topi_schedule(topi.x86.schedule_conv2d_nchw),
name="conv2d_nchw.x86")

# check if winograd algorithm is applicable
_, _, kh, kw = get_const_tuple(kernel.shape)
pt, pl, pb, pr = topi.nn.get_pad_tuple(padding, (kh, kw))
if kh == 3 and kw == 3 and stride_h == 1 and stride_w == 1 and \
dilation_h == 1 and dilation_w == 1:
is_dtype_fp32 = data.dtype == "float32" and kernel.dtype == "float32"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Float16 Too ? (perhaps double aka float64 too)
I get good results with it (e.g yolov3-tiny on mali).
Wouldn't be better to deny int8/int16/int32 ?

is_winograd_applicable = kh == 3 and kw == 3 and \
stride_h == 1 and stride_w == 1 and \
dilation_h == 1 and dilation_w == 1
if is_dtype_fp32 and is_winograd_applicable:
strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.conv2d_nchw_winograd),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nchw_winograd),
Expand Down