[TOPI][AutoTVM] NHWC conv2d templates for ARM #3859

Merged
2 commits merged on Dec 26, 2019
7 changes: 5 additions & 2 deletions python/tvm/autotvm/task/topi_integration.py
@@ -179,12 +179,15 @@ def _topi_nn_conv2d(*args, **kwargs):
    args = deserialize_args(args)
    A, W = args[:2]
    layout = args[-2]
-    assert layout == 'NCHW' or layout == 'HWCN', "only support NCHW/HWCN currently"
    C = topi.nn.conv2d(*args, **kwargs)
    if layout == 'NCHW':
        s = topi.generic.schedule_conv2d_nchw([C])
-    else:
+    elif layout == 'HWCN':
        s = topi.generic.schedule_conv2d_hwcn([C])
+    elif layout == 'NHWC':
+        s = topi.generic.schedule_conv2d_nhwc([C])
+    else:
+        raise ValueError("Unsupported layout {}".format(layout))
    return s, [A, W, C]

@register("topi_nn_depthwise_conv2d_nchw")
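With this change, the `topi_nn_conv2d` AutoTVM template dispatches on the layout argument, so an NHWC task can now be created by name. A minimal sketch of doing so, assuming the TVM 0.6-era AutoTVM API; the shapes, the `rasp3b` target preset, and the `direct` template key are illustrative choices, not taken from this PR:

import tvm
from tvm import autotvm

# Hypothetical NHWC workload: 1x56x56x64 input, 3x3 HWIO kernel (example shapes only).
data = ('TENSOR', (1, 56, 56, 64), 'float32')    # serialized placeholder, as deserialize_args expects
kernel = ('TENSOR', (3, 3, 64, 64), 'float32')

# Argument order follows topi.nn.conv2d: data, kernel, strides, padding, dilation, layout, out_dtype,
# so layout is args[-2], which is what the template above reads.
task = autotvm.task.create('topi_nn_conv2d',
                           args=(data, kernel, (1, 1), (1, 1), (1, 1), 'NHWC', 'float32'),
                           target=tvm.target.arm_cpu('rasp3b'),
                           template_key='direct')
print(task.config_space)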
38 changes: 36 additions & 2 deletions topi/python/topi/arm_cpu/conv2d.py
@@ -24,7 +24,8 @@
from tvm import autotvm
import tvm.contrib.nnpack

-from ..generic import schedule_conv2d_nchw, schedule_conv2d_winograd_without_weight_transform, \
+from ..generic import schedule_conv2d_nchw, schedule_conv2d_nhwc, \
+    schedule_conv2d_winograd_without_weight_transform, \
    schedule_conv2d_winograd_nnpack_without_weight_transform
from ..util import traverse_inline, get_const_tuple
from ..nn import dilate, pad, conv2d, conv2d_alter_layout, \
@@ -34,7 +35,9 @@
from ..nn.util import get_const_int, get_pad_tuple
from ..nn.winograd_util import winograd_transform_matrices
from .conv2d_spatial_pack import conv2d_spatial_pack_nchw, \
-    schedule_conv2d_spatial_pack_nchw
+    conv2d_spatial_pack_nhwc, \
+    schedule_conv2d_spatial_pack_nchw, \
+    schedule_conv2d_spatial_pack_nhwc

logger = logging.getLogger('topi')

@@ -78,6 +81,9 @@ def conv2d_arm_cpu(cfg, data, kernel, strides, padding, dilation, layout, out_dt
    if layout == 'NCHW':
        return conv2d_spatial_pack_nchw(cfg, data, kernel, strides, padding,
                                        dilation, out_dtype, num_tile=2)
+    elif layout == 'NHWC':
+        return conv2d_spatial_pack_nhwc(cfg, data, kernel, strides, padding,
+                                        dilation, out_dtype)
    else:
        raise ValueError("Unsupported layout {}".format(layout))

@@ -136,6 +142,34 @@ def _callback(op):
    traverse_inline(s, outs[0].op, _callback)
    return s

@autotvm.register_topi_schedule(schedule_conv2d_nhwc, 'arm_cpu', ['direct'])
def schedule_conv2d_nhwc_arm_cpu(cfg, outs):
    """TOPI schedule callback for conv2d

    Parameters
    ----------
    cfg: ConfigEntity
        The config for this template

    outs: Array of Tensor
        The computation graph description of conv2d
        in the format of an array of tensors.

    Returns
    -------
    s: Schedule
        The computation schedule for conv2d.
    """
    s = tvm.create_schedule([x.op for x in outs])

    def _callback(op):
        if 'spatial_conv_output_NHWC' in op.tag:
            schedule_conv2d_spatial_pack_nhwc(cfg, s, op, outs[0])

    traverse_inline(s, outs[0].op, _callback)
    return s


@autotvm.register_topi_compute(conv2d, 'arm_cpu', ['winograd'])
def conv2d_arm_cpu_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
""" TOPI compute callback. Use winograd template """
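Once the compute and schedule are registered for `arm_cpu`, calling the generic `topi` entry points under an ARM target picks up the NHWC spatial-pack template. A minimal sketch, assuming the TVM 0.6-era API (`tvm.placeholder`, target-scoped dispatch) and the AutoTVM fallback config outside of a tuning session; shapes and the target preset are illustrative:

import tvm
import topi

# Lower a single NHWC conv2d with the newly registered arm_cpu compute/schedule.
target = tvm.target.arm_cpu('rasp3b')          # assumed target preset
with target:
    A = tvm.placeholder((1, 56, 56, 64), name='data', dtype='float32')    # NHWC input
    W = tvm.placeholder((3, 3, 64, 64), name='kernel', dtype='float32')   # HWIO kernel
    C = topi.nn.conv2d(A, W, strides=1, padding=1, dilation=1,
                       layout='NHWC', out_dtype='float32')
    s = topi.generic.schedule_conv2d_nhwc([C])   # dispatches to schedule_conv2d_nhwc_arm_cpu
    print(tvm.lower(s, [A, W, C], simple_mode=True))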
160 changes: 160 additions & 0 deletions topi/python/topi/arm_cpu/conv2d_spatial_pack.py
@@ -196,3 +196,163 @@ def schedule_conv2d_spatial_pack_nchw(cfg, s, data_vec, kernel_vec,
            s[kernel_vec].parallel(co)

    return s

def conv2d_spatial_pack_nhwc(cfg, data, kernel, strides, padding, dilation, out_dtype):
    """Spatial pack compute for Conv2d NHWC"""
    out_dtype = out_dtype or data.dtype

    N, IH, IW, IC = get_const_tuple(data.shape)
    assert len(kernel.shape) == 4, "AlterOpLayout not enabled for NHWC yet"
    KH, KW, _, OC = get_const_tuple(kernel.shape)

    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
    else:
        dilation_h, dilation_w = dilation

    dilated_kernel_h = (KH - 1) * dilation_h + 1
    dilated_kernel_w = (KW - 1) * dilation_w + 1

    pad_top, pad_left, pad_down, pad_right = \
        get_pad_tuple(padding, (dilated_kernel_h, dilated_kernel_w))
    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)

    OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
    OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
    data_pad = nn.pad(data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0])

    # ==================== define configuration space ====================
    n, oc, oh, ow = cfg.axis(N), cfg.axis(OC), cfg.axis(OH), cfg.axis(OW)
    ic, kh, kw = cfg.reduce_axis(IC), cfg.reduce_axis(KH), cfg.reduce_axis(KW)

    oco, oci = cfg.define_split('tile_co', oc, num_outputs=2)
    oho, ohi = cfg.define_split('tile_oh', oh, num_outputs=2)
    owo, owi = cfg.define_split('tile_ow', ow, num_outputs=2)

    cfg.define_reorder('reorder_conv',
                       [n, oho, owo, oco, kh, kw, ic, ohi, owi, oci],
                       policy='candidate', candidate=[
                           [n, oho, owo, oco, kh, kw, ic, ohi, owi, oci],
                           [n, oho, owo, oco, ohi, kh, kw, ic, owi, oci],
                           [n, oho, owo, oco, ohi, kh, kw, owi, ic, oci],
                           [n, oho, owo, ohi, oco, kh, kw, owi, ic, oci]])

    cfg.define_annotate("ann_reduce", [kh, kw], policy='try_unroll')
    cfg.define_annotate("ann_spatial", [ohi, owi, oci], policy='try_unroll_vec')
    # ====================================================================

    OCI = cfg['tile_co'].size[-1]
    OHI = cfg['tile_oh'].size[-1]
    OWI = cfg['tile_ow'].size[-1]
    OCO = OC // OCI
    OHO = OH // OHI
    OWO = OW // OWI

    kvshape = (OCO, KH, KW, IC, OCI)
    ovshape = (N, OHO, OWO, OCO, OHI, OWI, OCI)
    oshape = (N, OH, OW, OC)

    if dilation_h != 1 or dilation_w != 1:
        # undilate input data
        dvshape = (N, OHO, OWO, KH, KW, IC, OHI, OWI)
        data_vec = tvm.compute(dvshape, lambda n, oho, owo, kh, kw, ic, ohi, owi:
                               data_pad[n][(oho*OHI+ohi)*HSTR+kh*dilation_h]
                               [(owo*OWI+owi)*WSTR+kw*dilation_w][ic],
                               name='data_vec_undilated')
    else:
        dvshape = (N, OHO, OWO, KH + (OHI-1)*HSTR, KW + (OWI-1)*WSTR, IC)
        data_vec = tvm.compute(dvshape, lambda n, oho, owo, ohi, owi, ic:
                               data_pad[n][oho*OHI*HSTR+ohi][owo*OWI*WSTR+owi][ic],
                               name='data_vec')
    kernel_vec = tvm.compute(kvshape, lambda oco, kh, kw, ic, oci: \
                             kernel[kh][kw][ic][oco*OCI+oci],
                             name='kernel_vec')

    ic = tvm.reduce_axis((0, IC), name='ic')
    kh = tvm.reduce_axis((0, KH), name='kh')
    kw = tvm.reduce_axis((0, KW), name='kw')

    if dilation_h != 1 or dilation_w != 1:
        # index data_vec_undilated in its definition order (n, oho, owo, kh, kw, ic, ohi, owi)
        conv = tvm.compute(ovshape, lambda n, oho, owo, oco, ohi, owi, oci: \
            tvm.sum(data_vec[n, oho, owo, kh, kw, ic, ohi, owi].astype(out_dtype) *
                    kernel_vec[oco, kh, kw, ic, oci].astype(out_dtype),
                    axis=[ic, kh, kw]), name='conv')
    else:
        conv = tvm.compute(ovshape, lambda n, oho, owo, oco, ohi, owi, oci: \
            tvm.sum(data_vec[n, oho, owo, ohi*HSTR+kh, owi*WSTR+kw, ic].astype(out_dtype) *
                    kernel_vec[oco, kh, kw, ic, oci].astype(out_dtype),
                    axis=[ic, kh, kw]), name='conv')

    idiv = tvm.indexdiv
    imod = tvm.indexmod
    output = tvm.compute(oshape, lambda n, oho, owo, oc:
                         conv[n][idiv(oho, OHI)][idiv(owo, OWI)][idiv(oc, OCI)]\
                         [imod(oho, OHI)][imod(owo, OWI)][imod(oc, OCI)],
                         name='output_unpack', tag='spatial_conv_output_NHWC')
    return output

def schedule_conv2d_spatial_pack_nhwc(cfg, s, op, output):
    """Spatial Pack schedule for Conv2d NHWC"""
    unpack = op.output(0)
    conv = unpack.op.input_tensors[0]
    data_vec = conv.op.input_tensors[0]
    kernel_vec = conv.op.input_tensors[1]
    data_pad = data_vec.op.input_tensors[0]
    OHI = cfg['tile_oh'].size[-1]
    OWI = cfg['tile_ow'].size[-1]
    OCI = cfg['tile_co'].size[-1]

    # schedule unpack/output
    if output != unpack:
        s[unpack].compute_inline()
    n, oh, ow, oc = s[output].op.axis
    oco, oci = cfg['tile_co'].apply(s, output, oc)
    oho, ohi = cfg['tile_oh'].apply(s, output, oh)
    owo, owi = cfg['tile_ow'].apply(s, output, ow)
    s[output].reorder(n, oho, owo, oco, ohi, owi, oci)
    cfg['ann_spatial'].apply(s, output, [ohi, owi, oci], axis_lens=[OHI, OWI, OCI],
                             max_unroll=16, cfg=cfg)
    cfg.define_knob('compat', [0, 1, 2])
    if cfg['compat'].val < 2:
        compat_axis = [owo, oco][cfg['compat'].val]  # pylint: disable=R1706
        s[conv].compute_at(s[output], compat_axis)
    paxis = s[output].fuse(n, oho)
    s[output].parallel(paxis)

    # schedule conv
    n, oho, owo, oco, ohi, owi, oci = s[conv].op.axis
    ic, kh, kw = s[conv].op.reduce_axis
    cfg['reorder_conv'].apply(s, conv, [n, oho, owo, oco, kh, kw, ohi, owi, ic, oci])
    cfg['ann_reduce'].apply(s, conv, [kh, kw],
                            axis_lens=[get_const_int(kh.dom.extent),
                                       get_const_int(kw.dom.extent)],
                            max_unroll=16,
                            cfg=cfg)
    cfg['ann_spatial'].apply(s, conv, [ohi, owi, oci], axis_lens=[OHI, OWI, OCI],
                             max_unroll=16, cfg=cfg)
    if cfg['compat'].val < 2:
        compat_axis = [owo, oco][cfg['compat'].val]  # pylint: disable=R1706
        s[kernel_vec].compute_at(s[conv], compat_axis)
        s[data_vec].compute_at(s[conv], compat_axis)

    # schedule kernel pack
    oco, kh, kw, ic, oci = kernel_vec.op.axis
    s[kernel_vec].vectorize(oci)
    s[kernel_vec].unroll(ic)
    if cfg['compat'].val == 2:
        s[kernel_vec].parallel(oco)

    # schedule data pack
    if data_vec.op.name == 'data_vec_undilated':
        n, oho, owo, kh, kw, ic, ohi, owi = s[data_vec].op.axis
        s[data_vec].vectorize(owi)
        s[data_vec].unroll(ohi)
    else:
        n, oho, owo, ohi, owi, ic = s[data_vec].op.axis
        s[data_vec].vectorize(ic)
        s[data_vec].unroll(owi)
    if cfg['compat'].val == 2:
        paxis = s[data_vec].fuse(n, oho)
        s[data_vec].parallel(paxis)

    return s
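To make the packing shapes above concrete, here is the same arithmetic worked through for one illustrative workload and tile choice (plain Python; the numbers are assumptions for the example, not taken from the PR):

# NHWC input 1x56x56x64, 3x3 HWIO kernel with 64 output channels, stride 1, pad 1, no dilation.
N, IH, IW, IC = 1, 56, 56, 64
KH, KW, OC = 3, 3, 64
HSTR = WSTR = 1
dilation_h = dilation_w = 1
pad_top = pad_left = pad_down = pad_right = 1

dilated_kernel_h = (KH - 1) * dilation_h + 1                      # 3
dilated_kernel_w = (KW - 1) * dilation_w + 1                      # 3
OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1     # 56
OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1   # 56

# Suppose the tuner picked these inner tile sizes (the last factor of tile_co/tile_oh/tile_ow).
OCI, OHI, OWI = 8, 2, 8
OCO, OHO, OWO = OC // OCI, OH // OHI, OW // OWI                   # 8, 28, 7

kvshape = (OCO, KH, KW, IC, OCI)                                  # (8, 3, 3, 64, 8)
ovshape = (N, OHO, OWO, OCO, OHI, OWI, OCI)                       # (1, 28, 7, 8, 2, 8, 8)
# No dilation, so each packed data tile holds the receptive field of one 2x8 output tile:
dvshape = (N, OHO, OWO, KH + (OHI - 1) * HSTR, KW + (OWI - 1) * WSTR, IC)   # (1, 28, 7, 4, 10, 64)
print(OH, OW, kvshape, ovshape, dvshape)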