Skip to content

Commit

Permalink
conv3d_ndhwc schedule
Browse files Browse the repository at this point in the history
  • Loading branch information
alexgl-github committed Jan 24, 2020
1 parent 9bd2c7b commit 57f4722
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 7 deletions.
14 changes: 7 additions & 7 deletions topi/python/topi/nn/conv3d.py
Expand Up @@ -186,15 +186,15 @@ def conv3d_ndhwc(Input, Filter, stride, padding, dilation, out_dtype='float32'):
pad_before = [0, pad_front, pad_top, pad_left, 0]
pad_after = [0, pad_back, pad_down, pad_right, 0]
PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput")
rd = tvm.reduce_axis((0, kernel_d), name='rd')
rh = tvm.reduce_axis((0, kernel_h), name='rh')
rw = tvm.reduce_axis((0, kernel_w), name='rw')
rc = tvm.reduce_axis((0, in_channel), name='rc')
rz = tvm.reduce_axis((0, kernel_d), name='rz')
ry = tvm.reduce_axis((0, kernel_h), name='ry')
rx = tvm.reduce_axis((0, kernel_w), name='rx')
Output = tvm.compute(
(batch, out_depth, out_height, out_width, out_channel),
lambda nn, zz, yy, xx, ff: tvm.sum(
PaddedInput[nn, zz * stride_d + rz * dilation_d, yy * stride_h + ry * dilation_h,
xx * stride_w + rx * dilation_w, rc].astype(out_dtype) *
Filter[rz, ry, rx, rc, ff].astype(out_dtype), axis=[rz, ry, rx, rc]),
lambda nn, dd, hh, ww, cc: tvm.sum(
PaddedInput[nn, dd * stride_d + rd * dilation_d, hh * stride_h + rh * dilation_h,
ww * stride_w + rw * dilation_w, rc].astype(out_dtype) *
Filter[rd, rh, rw, rc, cc].astype(out_dtype), axis=[rd, rh, rw, rc]),
name="Conv3dOutput", tag="conv3d_ndhwc")
return Output
1 change: 1 addition & 0 deletions topi/python/topi/x86/__init__.py
Expand Up @@ -21,6 +21,7 @@

from .conv1d import schedule_conv1d_nwc
from .conv2d import schedule_conv2d, schedule_conv2d_nhwc
from .conv3d import schedule_conv3d_ndhwc
from .binarize_pack import schedule_binarize_pack
from .binary_dense import schedule_binary_dense
from .nn import *
Expand Down
93 changes: 93 additions & 0 deletions topi/python/topi/x86/conv3d.py
@@ -0,0 +1,93 @@
import tvm
from tvm import autotvm
from .. import generic, tag
from ..nn.conv3d import conv3d, conv3d_ndhwc, conv3d_ncdhw
from ..generic.nn import schedule_conv3d_ndhwc

@autotvm.register_topi_compute(conv3d, 'cpu', ['direct'])
def conv3d_x86(cfg, input, filter, strides, padding, dilation, layout='NCDHW', out_dtype=None):
if layout == 'NCDHW':
return conv3d_ncdhw(input, filter, strides, padding, dilation, out_dtype)
elif layout == 'NDHWC':
return conv3d_ndhwc(input, filter, strides, padding, dilation, out_dtype)

@autotvm.register_topi_schedule(schedule_conv3d_ndhwc, 'cpu', ['direct'])
def schedule_conv3d_ndhwc_x86(cfg, outs):
"""TOPI schedule callback for conv2d
Parameters
----------
cfg: ConfigEntity
The config for this template
outs: Array of Tensor
The computation graph description of conv2d
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for conv2d.
"""
s = tvm.create_schedule([x.op for x in outs])
output_op = outs[0].op
scheduled_ops = []

def traverse(op):
"""Traverse operators from computation graph"""
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_broadcast(op.tag):
if op not in s.outputs:
s[op].compute_inline()
else: # inject custom schedule
if len(op.axis) == 5: # schedule bias + bn + relu
n, d, h, w, c = op.axis
fused = s[op].fuse(n, d, h, w)
s[op].parallel(fused)
s[op].vectorize(c)
for tensor in op.input_tensors:
if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
traverse(tensor.op)

if 'conv3d_ndhwc' in op.tag:
conv = op.output(0)
kernel = op.input_tensors[1]
if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
s[kernel].compute_inline()

data = op.input_tensors[0]
data_pad = None
if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
data_pad = data
data = data_pad.op.input_tensors[0]
n_pad, d_pad, h_pad, w_pad, c_pad = data_pad.op.axis
pad_fused = s[data_pad].fuse(h_pad, w_pad)
s[data_pad].parallel(pad_fused)

C = conv
# data axes
n, d, h, w, c = s[C].op.axis

if True:
# tile data h and w
ho, wo, hi, wi = s[C].tile(h, w, 2, 2)
# kernel axes
kd, ky, kx, kc = s[C].op.reduce_axis
kxi, kxo = s[C].split(kx, factor=2)
kci, kco = s[C].split(kc, factor=2)
#
s[C].reorder(n, d, ho, wo, hi, wi, c, kxo, kco, kxi, kci)
s[C].unroll(kci)

s[C].vectorize(c)
if op != output_op:
_, _, _, _, c_out = output_op.axis
s[C].compute_at(s[output_op], c_out)
else:
fused = s[C].fuse(n, d)
s[C].parallel(fused)

scheduled_ops.append(op)

traverse(output_op)
return s

0 comments on commit 57f4722

Please sign in to comment.