From 62521453a03d73c90900ba08ced6af06ee7f543a Mon Sep 17 00:00:00 2001 From: Yao Wang Date: Mon, 11 Nov 2019 15:46:29 -0800 Subject: [PATCH] Add More Shape Functions (#4179) * Add shape functions * Fix get_const_tuple * Fix cpplint * Fix pylint * Fix pylint * rebase and fix * Check Any for infer type * Fix expand_dim shape func for zero rank input * Fix pooling infer type * Address comment * Register layout transform attr --- python/tvm/autotvm/task/task.py | 6 +- python/tvm/autotvm/util.py | 14 +- python/tvm/relay/op/_reduce.py | 68 ++++++++ python/tvm/relay/op/_tensor.py | 25 ++- python/tvm/relay/op/_transform.py | 194 +++++++++++++++++++++- python/tvm/relay/op/nn/_nn.py | 170 ++++++++++++++++++- src/lang/data_layout.cc | 18 +- src/relay/op/nn/convolution.cc | 15 +- src/relay/op/nn/convolution.h | 17 +- src/relay/op/nn/nn.cc | 7 +- src/relay/op/nn/pad.cc | 8 +- src/relay/op/nn/pooling.cc | 25 ++- src/relay/op/tensor/reduce.cc | 15 +- src/relay/op/tensor/transform.cc | 16 +- tests/python/relay/test_any.py | 264 ++++++++++++++++++++++++++++++ topi/include/topi/nn/flatten.h | 4 +- topi/python/topi/util.py | 14 +- topi/python/topi/x86/conv2d.py | 7 + topi/python/topi/x86/dense.py | 35 +++- 19 files changed, 864 insertions(+), 58 deletions(-) diff --git a/python/tvm/autotvm/task/task.py b/python/tvm/autotvm/task/task.py index 4f3cc90b474e..7f36914eb0a6 100644 --- a/python/tvm/autotvm/task/task.py +++ b/python/tvm/autotvm/task/task.py @@ -219,15 +219,15 @@ def args_to_workload(x, topi_compute_func=None): workload = get_const_tuple(x.shape) + (x.dtype, ) elif isinstance(x, (tuple, list, container.Array)): workload = tuple([args_to_workload(a) for a in x]) - elif isinstance(x, (str, int, float, np.int, np.float)): + elif isinstance(x, (str, int, float, np.int, np.float, expr.Var)): workload = x elif isinstance(x, (expr.StringImm, expr.UIntImm, expr.IntImm, expr.FloatImm)): workload = x.value elif x is None: workload = 0 else: - raise RuntimeError('Do not support type "%s" in argument. Consider to use ' - 'primitive types only' % type(x)) + raise RuntimeError('Do not support type "%s" in argument. Consider to use' + 'primitive types or tvm.expr.Var only' % type(x)) return (get_func_name(topi_compute_func), ) + workload if topi_compute_func else workload def template(func): diff --git a/python/tvm/autotvm/util.py b/python/tvm/autotvm/util.py index 7c98acd317fa..3026914aed20 100644 --- a/python/tvm/autotvm/util.py +++ b/python/tvm/autotvm/util.py @@ -163,7 +163,7 @@ def get_const_int(exp): def get_const_tuple(in_tuple): - """Verifies input tuple is IntImm, returns tuple of int. + """Verifies input tuple is IntImm or Var, returns tuple of int or Var. Parameters ---------- @@ -175,4 +175,14 @@ def get_const_tuple(in_tuple): out_tuple : tuple of int The output. """ - return tuple(get_const_int(x) for x in in_tuple) + ret = [] + for elem in in_tuple: + if isinstance(elem, expr.Var): + ret.append(elem) + elif not isinstance(elem, (expr.IntImm, expr.UIntImm, int)): + elem = ir_pass.Simplify(elem) + if not isinstance(elem, (expr.IntImm, expr.UIntImm)): + ret.append(elem) + else: + ret.append(get_const_int(elem)) + return tuple(ret) diff --git a/python/tvm/relay/op/_reduce.py b/python/tvm/relay/op/_reduce.py index 06d0d66bdfb0..43f71c0aa679 100644 --- a/python/tvm/relay/op/_reduce.py +++ b/python/tvm/relay/op/_reduce.py @@ -18,7 +18,11 @@ from __future__ import absolute_import import topi + +from topi.util import get_const_int, get_const_tuple from . import op as _reg +from ...api import convert +from ...hybrid import script def _schedule_reduce(_, outs, target): @@ -39,3 +43,67 @@ def _schedule_reduce(_, outs, target): _reg.register_schedule("variance", _schedule_reduce) _reg.register_schedule("nn.cross_entropy", _schedule_reduce) _reg.register_schedule("nn.cross_entropy_with_logits", _schedule_reduce) + + +def _create_axis_record(attrs, inputs): + axes = attrs.axis if attrs.axis is None else list(get_const_tuple(attrs.axis)) + exclude = get_const_int(attrs.exclude) > 0 + keepdims = get_const_int(attrs.keepdims) > 0 + data_shape = inputs[0] + shape_size = data_shape.shape[0].value + axis_record = [-1] * shape_size + if axes is None: + axes = list(range(shape_size)) + + for i, axis in enumerate(axes): + if axis < 0: + axes[i] = shape_size + axis + + if exclude: + ex_axes = [] + for i in range(shape_size): + if i not in axes: + ex_axes.append(i) + axes = ex_axes + + for i in range(shape_size): + if i not in axes: + axis_record[i] = i + + if not keepdims: + tmp = [] + for i in axis_record: + if i >= 0: + tmp.append(i) + axis_record = tmp + + return axis_record + + +@script +def _reduce_shape_func(data_shape, axis_record): + out = output_tensor((len(axis_record),), "int64") + for i in const_range(len(axis_record)): + if axis_record[i] >= 0: + out[i] = data_shape[axis_record[i]] + else: + out[i] = int64(1) + + return out + +def reduce_shape_func(attrs, inputs, _): + """ + Shape function for reduce op. + """ + axis_record = _create_axis_record(attrs, inputs) + return [_reduce_shape_func(inputs[0], convert(axis_record))] + +_reg.register_shape_func("argmax", False, reduce_shape_func) +_reg.register_shape_func("argmin", False, reduce_shape_func) +_reg.register_shape_func("all", False, reduce_shape_func) +_reg.register_shape_func("sum", False, reduce_shape_func) +_reg.register_shape_func("max", False, reduce_shape_func) +_reg.register_shape_func("min", False, reduce_shape_func) +_reg.register_shape_func("prod", False, reduce_shape_func) +_reg.register_shape_func("mean", False, reduce_shape_func) +_reg.register_shape_func("variance", False, reduce_shape_func) diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py index 188b3bb15956..dcff0845aed6 100644 --- a/python/tvm/relay/op/_tensor.py +++ b/python/tvm/relay/op/_tensor.py @@ -119,18 +119,6 @@ def _cast_shape_function(x): def cast_shape_func(attrs, inputs, out_ndims): return [_cast_shape_function(*inputs)] -@script -def _expand_dims_shape_func(x): - ndim = len(x.shape) - out = output_tensor((ndim+1,), "int64") - out[0] = int64(1) - for i in const_range(0, ndim): - out[i+1] = int64(x.shape[i]) - return out - -def expand_dims_shape_func(attrs, inputs, out_ndims): - return [_expand_dims_shape_func(*inputs)] - # shape func @script def _broadcast_shape_func(x, y, ndim): @@ -161,9 +149,17 @@ def _broadcast_shape_func(x, y, ndim): return out def broadcast_shape_func(attrs, inputs, out_ndims): + """ + Shape function for broadcast op. + """ return [_broadcast_shape_func(*inputs, out_ndims[0])] -register_shape_func("expand_dims", False, expand_dims_shape_func) +def elemwise_shape_func(attrs, inputs, _): + """ + Shape function for elemwise op. + """ + return [topi.math.identity(inputs[0])] + register_shape_func("cast", False, cast_shape_func) register_shape_func("add", False, broadcast_shape_func) @@ -179,3 +175,6 @@ def broadcast_shape_func(attrs, inputs, out_ndims): register_shape_func("less_equal", False, broadcast_shape_func) register_shape_func("greater", False, broadcast_shape_func) register_shape_func("greater_equal", False, broadcast_shape_func) + +register_shape_func("sqrt", False, elemwise_shape_func) +register_shape_func("negative", False, elemwise_shape_func) diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 687d5b4c5b2c..13f41fc87001 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """Backend compiler related feature registration""" -# pylint: disable=invalid-name,unused-argument, len-as-condition, too-many-nested-blocks +# pylint: disable=invalid-name,unused-argument, len-as-condition, too-many-nested-blocks, too-many-local-variables, too-many-arguments from __future__ import absolute_import import tvm import topi @@ -303,3 +303,195 @@ def compute_argwhere(attrs, inputs, output_type, _): output_shape.append(tvm.var("any_dim", "int32")) new_output_type = tvm.relay.ty.TensorType(output_shape, "int32") return [topi.argwhere(new_output_type, inputs[0])] + +@script +def _layout_transform_shape_func(data_shape, + out_layout_len, + dst_equal_list, + dst_mul_list, + dst_div_list, + dst_mix_list): + out = output_tensor((out_layout_len,), "int64") + for i in const_range(len(dst_equal_list)): + out[dst_equal_list[i][0]] = data_shape[dst_equal_list[i][1]] + for i in const_range(len(dst_mul_list)): + out[dst_mul_list[i][0]] = data_shape[dst_mul_list[i][1]] * \ + data_shape[dst_mul_list[i][2]] + for i in const_range(len(dst_div_list)): + out[dst_div_list[i][0]] = data_shape[dst_div_list[i][1]] \ + // dst_div_list[i][3] + out[dst_div_list[i][2]] = int64(dst_div_list[i][3]) + for i in const_range(len(dst_mix_list)): + out[dst_mix_list[i][0]] = data_shape[dst_mix_list[i][1]] * \ + dst_mix_list[i][2] // dst_mix_list[i][4] + out[dst_mix_list[i][3]] = int64(dst_mix_list[i][4]) + + return out + +@_reg.register_shape_func("layout_transform", False) +def layout_transform_shape_func(attrs, inputs, _): + """ + Shape function for layout_transform op. + """ + def _fetch_axis(layout): + major_axes = [] + minor_axes = {} + num_start = -1 + for i, item in enumerate(layout): + if "A" <= item <= "Z": + major_axes.append(item) + elif "a" <= item <= "z": + last_num = int(layout[num_start:i]) + minor_axes[item] = last_num + num_start = -1 + elif num_start < 0: + num_start = i + return major_axes, minor_axes + + _, src_minor_axes = _fetch_axis(attrs.src_layout) + dst_major_axes, dst_minor_axes = _fetch_axis(attrs.dst_layout) + src_letter_list = [] + dst_letter_list = [] + for item in attrs.src_layout: + if "A" <= item <= "Z" or "a" <= item <= "z": + src_letter_list.append(item) + for item in attrs.dst_layout: + if "A" <= item <= "Z" or "a" <= item <= "z": + dst_letter_list.append(item) + out_layout_len = len(dst_major_axes) + len(dst_minor_axes) + dst_equal_list = [] + dst_mul_list = [] + dst_div_list = [] + dst_mix_list = [] + + for key in dst_major_axes: + if key.lower() not in dst_minor_axes: + if key.lower() not in src_minor_axes: + dst_equal_list.append((dst_letter_list.index(key), + src_letter_list.index(key))) + else: + dst_mul_list.append((dst_letter_list.index(key), + src_letter_list.index(key), + src_letter_list.index(key.lower()))) + else: + if key.lower() not in src_minor_axes: + dst_div_list.append((dst_letter_list.index(key), + src_letter_list.index(key), + dst_letter_list.index(key.lower()), + dst_minor_axes[key.lower()])) + else: + dst_mix_list.append((dst_letter_list.index(key), + src_letter_list.index(key), + src_minor_axes[key.lower()], + dst_letter_list.index(key.lower()), + dst_minor_axes[key.lower()])) + + return [_layout_transform_shape_func(inputs[0], + convert(out_layout_len), + convert(dst_equal_list), + convert(dst_mul_list), + convert(dst_div_list), + convert(dst_mix_list))] + +@script +def _expand_dim_shape_func(data_shape, ndim, axis, num_newaxis): + out = output_tensor((ndim + num_newaxis,), "int64") + for i in const_range(out.shape[0]): + if i < axis: + out[i] = data_shape[i] + elif i < axis + num_newaxis: + out[i] = int64(1) + else: + out[i] = data_shape[i - num_newaxis] + + return out + +@_reg.register_shape_func("expand_dims", False) +def expand_dim_shape_func(attrs, inputs, _): + """ + Shape function for expand_dim op. + """ + axis = get_const_int(attrs.axis) + num_newaxis = get_const_int(attrs.num_newaxis) + if axis < 0: + axis = inputs[0].shape[0] + axis + 1 + ndim = inputs[0].shape[0] if inputs[0].shape else 0 + return [_expand_dim_shape_func(inputs[0], + convert(ndim), + convert(axis), + convert(num_newaxis))] + +@script +def _transpose_shape_func(data_shape, axes): + out = output_tensor((data_shape.shape[0],), "int64") + for i in const_range(len(axes)): + out[i] = data_shape[axes[i]] + + return out + +@_reg.register_shape_func("transpose", False) +def transpose_shape_func(attrs, inputs, _): + """ + Shape function for transpose op. + """ + axes = attrs.axes if attrs.axes is None else get_const_tuple(attrs.axes) + if axes is None: + axes = list(range(inputs[0].shape[0].value)) + axes.reverse() + for i, axis in enumerate(axes): + if axis < 0: + axes[i] = inputs[0].shape[0] - axis + return [_transpose_shape_func(inputs[0], convert(axes))] + +@script +def _squeeze_shape_func(data_shape, keep_axes): + out = output_tensor((len(keep_axes),), "int64") + if len(keep_axes) == 0: + out_size = 0 + for i in const_range(data_shape.shape[0]): + if data_shape[i] != 1: + out_size += 1 + + if out_size == 0: + out_size = 1 + out = output_tensor((out_size,), "int64") + out[0] = int64(1) + pos = 0 + for i in const_range(data_shape.shape[0]): + if data_shape[i] != 1: + out[pos] = data_shape[i] + pos += 1 + else: + for i in const_range(len(keep_axes)): + out[i] = data_shape[keep_axes[i]] + + return out + +@_reg.register_shape_func("squeeze", False) +def squeeze_shape_func(attrs, inputs, _): + """ + Shape function for squeeze op. + """ + axis = attrs.axis if attrs.axis is None else get_const_tuple(attrs.axis) + keep_axes = [] + if axis is not None: + for i in range(inputs[0].shape[0].value): + if i not in axis: + keep_axes.append(i) + + return [_squeeze_shape_func(inputs[0], convert(keep_axes))] + +@script +def _reshape_like_shape_func(target_shape): + out = output_tensor((target_shape.shape[0],), "int64") + for i in const_range(target_shape.shape[0]): + out[i] = target_shape[i] + + return out + +@_reg.register_shape_func("reshape_like", False) +def reshape_like_shape_func(attrs, inputs, _): + """ + Shape function for reshape_like op. + """ + return [_reshape_like_shape_func(inputs[1])] diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 891548036017..54f13c688151 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name, unused-argument +# pylint: disable=invalid-name, unused-argument, too-many-arguments """Backend compiler related feature registration""" from __future__ import absolute_import @@ -22,6 +22,9 @@ from topi.util import get_const_tuple from .. import op as reg from ..op import OpPattern, schedule_injective +from .._tensor import elemwise_shape_func +from ....api import convert +from ....hybrid import script # relu reg.register_schedule("nn.relu", schedule_injective) @@ -766,7 +769,6 @@ def schedule_bitserial_dense(attrs, outputs, target): reg.register_pattern("nn.cross_entropy", OpPattern.OPAQUE) - @reg.register_compute("nn.cross_entropy") def compute_cross_entropy(attrs, inputs, out_dtype, target): x, y = inputs @@ -775,8 +777,170 @@ def compute_cross_entropy(attrs, inputs, out_dtype, target): reg.register_pattern("nn.cross_entropy_with_logits", OpPattern.OPAQUE) - @reg.register_compute("nn.cross_entropy_with_logits") def compute_cross_entropy_with_logits(attrs, inputs, out_dtype, target): x, y = inputs return [-topi.sum(x * y) / x.shape[0]] + +# shape func +@script +def _conv2d_NCHWc_shape_func(dshape, kshape, strides, padding, dilation, oc_bn): + out = output_tensor((dshape.shape[0],), "int64") + ic_chunk = dshape[1] + height = dshape[2] + width = dshape[3] + ic_bn = dshape[4] + kheight = kshape[2] + kwidth = kshape[3] + dilated_kh = (kheight - 1) * dilation[0] + 1 + dilated_kw = (kwidth - 1) * dilation[1] + 1 + kflatten = int64(1) + for i in const_range(kshape.shape[0]): + kflatten *= kshape[i] + + oc = kflatten // (kheight * kwidth * ic_chunk * ic_bn) + oc_chunk = oc // oc_bn + + out_height = (height + 2 * padding[0] - dilated_kh) // strides[0] + 1 + out_width = (width + 2 * padding[1] - dilated_kw) // strides[1] + 1 + + out[0] = dshape[0] + out[1] = oc_chunk + out[2] = out_height + out[3] = out_width + out[4] = int64(oc_bn) + return out + +@reg.register_shape_func("nn.contrib_conv2d_NCHWc", False) +def conv2d_NCHWc_shape_func(attrs, inputs, _): + """ + Shape function for contrib_conv2d_NCHWc op. + """ + strides = get_const_tuple(attrs.strides) + padding = get_const_tuple(attrs.padding) + dilation = get_const_tuple(attrs.dilation) + out_layout = attrs.out_layout + oc_bn = int(out_layout[4:-1]) + + return [_conv2d_NCHWc_shape_func(inputs[0], inputs[1], + convert(strides), convert(padding), + convert(dilation), convert(oc_bn))] + +@script +def _pool2d_shape_func(data_shape, pool_size, strides, + padding, height_axis, width_axis): + out = output_tensor((data_shape.shape[0],), "int64") + for i in const_range(data_shape.shape[0]): + if i == height_axis: + out[i] = (data_shape[i] + padding[0] + padding[2] - pool_size[0]) // strides[0] + 1 + elif i == width_axis: + out[i] = (data_shape[i] + padding[1] + padding[3] - pool_size[1]) // strides[1] + 1 + else: + out[i] = data_shape[i] + + return out + +def pool2d_shape_func(attrs, inputs, _): + """ + Shape function for pool2d op. + """ + pool_size = get_const_tuple(attrs.pool_size) + strides = get_const_tuple(attrs.strides) + padding = get_const_tuple(attrs.padding) + layout = attrs.layout + height_axis = layout.index("H") + width_axis = layout.index("W") + if len(padding) == 1: + padding = [padding[0]] * 4 + elif len(padding) == 2: + padding = [padding[0], padding[1], padding[0], padding[1]] + + return [_pool2d_shape_func(inputs[0], convert(pool_size), + convert(strides), convert(padding), + convert(height_axis), convert(width_axis))] + +reg.register_shape_func("nn.max_pool2d", False, pool2d_shape_func) +reg.register_shape_func("nn.avg_pool2d", False, pool2d_shape_func) + +@script +def _global_pool2d_shape_func(data_shape, height_axis, width_axis): + out = output_tensor((data_shape.shape[0],), "int64") + for i in const_range(out.shape[0]): + if i == height_axis or i == width_axis: + out[i] = int64(1) + else: + out[i] = data_shape[i] + + return out + +def global_pool2d_shape_func(attrs, inputs, _): + """ + Shape function for global pool2d op. + """ + layout = attrs.layout + height_axis = width_axis = 1 + for i, letter in enumerate(layout): + if letter == "H": + height_axis = i + if letter == "W": + width_axis = i + return [_global_pool2d_shape_func(inputs[0], convert(height_axis), convert(width_axis))] + +reg.register_shape_func("nn.global_max_pool2d", False, global_pool2d_shape_func) +reg.register_shape_func("nn.global_avg_pool2d", False, global_pool2d_shape_func) + +@script +def _batch_flatten_shape_func(data_shape): + out = output_tensor((2,), "int64") + out[0] = data_shape[0] + out[1] = int64(1) + for i in const_range(data_shape.shape[0] - 1): + out[1] *= data_shape[i + 1] + + return out + +@reg.register_shape_func("nn.batch_flatten", False) +def batch_flatten_shape_func(attrs, inputs, _): + """ + Shape function for batch_flatten op. + """ + return [_batch_flatten_shape_func(inputs[0])] + +@script +def _dense_shape_func(data_shape, weight_shape): + out = output_tensor((data_shape.shape[0],), "int64") + for i in const_range(out.shape[0] - 1): + out[i] = data_shape[i] + out[out.shape[0] - 1] = weight_shape[0] + + return out + +@reg.register_shape_func("nn.dense", False) +def dense_shape_func(attrs, inputs, _): + """ + Shape function for dense op. + """ + ret = [_dense_shape_func(inputs[0], inputs[1])] + return ret + +@script +def _pad_shape_func(data_shape, pad_width): + out = output_tensor((data_shape.shape[0],), "int64") + for i in const_range(out.shape[0]): + out[i] = data_shape[i] + pad_width[i][0] + pad_width[i][1] + + return out + +@reg.register_shape_func("nn.pad", False) +def pad_shape_func(attrs, inputs, _): + """ + Shape function for pad op. + """ + pad_width = [] + for pair in attrs.pad_width: + pad_width.append(get_const_tuple(pair)) + return [_pad_shape_func(inputs[0], convert(pad_width))] + +reg.register_shape_func("nn.bias_add", False, elemwise_shape_func) +reg.register_shape_func("nn.softmax", False, elemwise_shape_func) +reg.register_shape_func("nn.relu", False, elemwise_shape_func) diff --git a/src/lang/data_layout.cc b/src/lang/data_layout.cc index 7c76e40bf01c..35139bb2b87f 100644 --- a/src/lang/data_layout.cc +++ b/src/lang/data_layout.cc @@ -289,16 +289,22 @@ inline Array TransformShape(const Array& src_shape, // for minor-axis, simply bind it as 0, so that we can reuse forward/backward_rule, // e.g., (C * 16 + c) / 32 std::unordered_map bind_map; + std::unordered_set symbolic_var_set; for (size_t i = 0; i < src_shape.size(); ++i) { Expr orig_shape = src_shape[i]; IterVar orig_axis = src_axis[i]; + if (orig_shape.as()) { + symbolic_var_set.insert(i); + } if (!LayoutAxis::Get(orig_axis).IsPrimal()) { if (orig_shape.defined()) { const auto* orig_shape_const = orig_shape.as(); const auto* orig_axis_extent = orig_axis->dom->extent.as(); - CHECK_EQ(orig_shape_const->value, orig_axis_extent->value) - << "Input shape mismatch at index " << i << ". Expected " - << orig_axis->dom->extent << ", get " << orig_shape; + if (orig_shape_const) { + CHECK_EQ(orig_shape_const->value, orig_axis_extent->value) + << "Input shape mismatch at index " << i << ". Expected " + << orig_axis->dom->extent << ", get " << orig_shape; + } } bind_map[orig_axis->var.get()] = Expr(0); } else { @@ -316,7 +322,11 @@ inline Array TransformShape(const Array& src_shape, if (!LayoutAxis::Get(axis).IsPrimal()) { result.push_back(axis->dom->extent); } else { - result.push_back(ir::Simplify(ir::Substitute(rule, bind_map))); + if (symbolic_var_set.count(i)) { + result.push_back(ir::Any::make()); + } else { + result.push_back(ir::Simplify(ir::Substitute(rule, bind_map))); + } } } return result; diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc index bf4e54ba5ff0..002a246e210d 100644 --- a/src/relay/op/nn/convolution.cc +++ b/src/relay/op/nn/convolution.cc @@ -330,8 +330,19 @@ bool Conv2DWinogradRel(const Array& types, // dilation Array oshape({dshape_nchw[0], channels, 0, 0}); - oshape.Set(2, (dshape_nchw[2] + param->padding[0] * 2 - dilated_ksize_y) / param->strides[0] + 1); - oshape.Set(3, (dshape_nchw[3] + param->padding[1] * 2 - dilated_ksize_x) / param->strides[1] + 1); + if (!dshape_nchw[2].as()) { + oshape.Set(2, (dshape_nchw[2] + param->padding[0] * 2 + - dilated_ksize_y) / param->strides[0] + 1); + } else { + oshape.Set(2, dshape_nchw[2]); + } + if (!dshape_nchw[3].as()) { + oshape.Set(3, (dshape_nchw[3] + param->padding[1] * 2 + - dilated_ksize_x) / param->strides[1] + 1); + } else { + oshape.Set(3, dshape_nchw[3]); + } + DataType out_dtype = param->out_dtype; if (out_dtype.bits() == 0) { out_dtype = data->dtype; diff --git a/src/relay/op/nn/convolution.h b/src/relay/op/nn/convolution.h index 602a091d3b5d..19b84dd2fe31 100644 --- a/src/relay/op/nn/convolution.h +++ b/src/relay/op/nn/convolution.h @@ -116,10 +116,19 @@ bool Conv2DRel(const Array& types, int num_inputs, const Attrs& attrs, // dilation Array oshape({dshape_nchw[0], channels, 0, 0}); - oshape.Set(2, indexdiv(dshape_nchw[2] + param->padding[0] * 2 - dilated_ksize_y, - param->strides[0]) + 1); - oshape.Set(3, indexdiv(dshape_nchw[3] + param->padding[1] * 2 - dilated_ksize_x, - param->strides[1]) + 1); + if (!dshape_nchw[2].as()) { + oshape.Set(2, indexdiv(dshape_nchw[2] + param->padding[0] * 2 - dilated_ksize_y, + param->strides[0]) + 1); + } else { + oshape.Set(2, dshape_nchw[2]); + } + + if (!dshape_nchw[3].as()) { + oshape.Set(3, indexdiv(dshape_nchw[3] + param->padding[1] * 2 - dilated_ksize_x, + param->strides[1]) + 1); + } else { + oshape.Set(3, dshape_nchw[3]); + } DataType out_dtype = param->out_dtype; if (out_dtype.bits() == 0) { out_dtype = data->dtype; diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index 416a0d7b543f..d3a71787a837 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -408,7 +408,12 @@ bool BatchFlattenRel(const Array& types, auto target_dim = make_const(Int(32), 1); for (uint32_t i = 1; i < data->shape.size(); ++i) { - target_dim = target_dim * data->shape[i]; + if (!data->shape[i].as()) { + target_dim = target_dim * data->shape[i]; + } else { + target_dim = data->shape[i]; + break; + } } std::vector oshape({data->shape[0], target_dim}); diff --git a/src/relay/op/nn/pad.cc b/src/relay/op/nn/pad.cc index 2342880063ad..d625f1976941 100644 --- a/src/relay/op/nn/pad.cc +++ b/src/relay/op/nn/pad.cc @@ -148,8 +148,12 @@ bool PadRel(const Array& types, << "Param width elements should be positive but first pad width at " << "index " << i << " is " << *width2 << "."; - auto padding = make_const(data->shape[i].type(), *width1 + *width2); - oshape.push_back(data->shape[i] + padding); + if (!data->shape[i].as()) { + auto padding = make_const(data->shape[i].type(), *width1 + *width2); + oshape.push_back(data->shape[i] + padding); + } else { + oshape.push_back(data->shape[i]); + } } reporter->Assign(types[1], TensorTypeNode::make(Array(oshape), diff --git a/src/relay/op/nn/pooling.cc b/src/relay/op/nn/pooling.cc index 94f8a5442d6c..99d184fd8f65 100644 --- a/src/relay/op/nn/pooling.cc +++ b/src/relay/op/nn/pooling.cc @@ -102,14 +102,25 @@ bool Pool2DRel(const Array& types, oshape.push_back(e); } - if (param->ceil_mode) { - oshape[hidx] = ((dshape[hidx] + pad_h - param->pool_size[0] + - param->strides[0] - 1) / param->strides[0]) + 1; - oshape[widx] = ((dshape[widx] + pad_w - param->pool_size[1] + - param->strides[1] - 1) / param->strides[1]) + 1; + if (dshape[hidx].as()) { + oshape[hidx] = dshape[hidx]; } else { - oshape[hidx] = ((dshape[hidx] + pad_h - param->pool_size[0]) / param->strides[0]) + 1; - oshape[widx] = ((dshape[widx] + pad_w - param->pool_size[1]) / param->strides[1]) + 1; + if (param->ceil_mode) { + oshape[hidx] = ((dshape[hidx] + pad_h - param->pool_size[0] + + param->strides[0] - 1) / param->strides[0]) + 1; + } else { + oshape[hidx] = ((dshape[hidx] + pad_h - param->pool_size[0]) / param->strides[0]) + 1; + } + } + if (dshape[widx].as()) { + oshape[widx] = dshape[widx]; + } else { + if (param->ceil_mode) { + oshape[widx] = ((dshape[widx] + pad_w - param->pool_size[1] + + param->strides[1] - 1) / param->strides[1]) + 1; + } else { + oshape[widx] = ((dshape[widx] + pad_w - param->pool_size[1]) / param->strides[1]) + 1; + } } // assign output type diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc index 63524bc4e81d..3a2d4692f522 100644 --- a/src/relay/op/tensor/reduce.cc +++ b/src/relay/op/tensor/reduce.cc @@ -211,11 +211,20 @@ inline std::vector ReduceShapeImpl(const std::vector &in_s } auto max_shape = make_const(Int(64), 1); + bool is_dynamic_input = false; for (int64_t axis : r_axes) { - max_shape *= in_shape[axis]; + if (in_shape[axis].as()) { + max_shape *= in_shape[axis]; + } else { + is_dynamic_input = true; + break; + } + } + + if (is_dynamic_input) { + CHECK(reporter->Assert(max_shape < make_const(Int(64), std::numeric_limits::max()))) + << "The maximum possible index of reduced shape cannot be more than int32 max."; } - CHECK(reporter->Assert(max_shape < make_const(Int(64), std::numeric_limits::max()))) - << "The maximum possible index of reduced shape cannot be more than int32 max."; if (param->keepdims) { std::vector oshape(in_shape); diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index e1239ae5b9e2..203a0411d3c4 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -797,8 +797,18 @@ bool ReshapeLikeRel(const Array& types, if (reshape_like == nullptr) { return false; } - CHECK(reporter->AssertEQ(data->Size(), reshape_like->Size())) - << "Reshape inputs size should be compatible."; + // Only check When input data has static shape. + bool is_static_shape = true; + for (size_t i = 0; i < data->shape.size(); ++i) { + if (!data->shape[i].as()) { + is_static_shape = false; + break; + } + } + if (is_static_shape) { + CHECK(reporter->AssertEQ(data->Size(), reshape_like->Size())) + << "Reshape inputs size should be compatible."; + } reporter->Assign(types[2], TensorTypeNode::make(reshape_like->shape, data->dtype)); return true; } @@ -2292,6 +2302,8 @@ RELAY_REGISTER_OP("slice_like") .set_attr("TOpPattern", kInjective); // relay.layout_transform +TVM_REGISTER_NODE_TYPE(LayoutTransformAttrs); + Array LayoutTransformCompute(const Attrs& attrs, const Array& inputs, const Type& out_type, diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 891cfad4d1f4..75be88cbcb19 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -188,6 +188,257 @@ def test_any_shape_of(): result = ex.evaluate()(data) tvm.testing.assert_allclose(result.asnumpy(), np.array(3).astype("int64")) +def verify_any_reduce(reduce_op, data_shape, axis, exclude, keepdims, + static_data_shape, ref_out_shape): + mod = relay.Module() + dtype = "bool" if reduce_op == relay.all else "float32" + data = relay.var('data', shape=data_shape, dtype=dtype) + y = reduce_op(data, axis, keepdims, exclude) + mod["main"] = relay.Function([data], y) + data_np = np.random.uniform(size=static_data_shape).astype(dtype) + for kind in ["debug", "vm"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + result = ex.evaluate()(data_np) + assert result.asnumpy().shape == ref_out_shape, \ + "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape)) + +def test_any_reduce(): + verify_any_reduce(relay.argmax, any_dims(3), None, False, False, (3, 4, 5), ()) + verify_any_reduce(relay.argmin, any_dims(4), 1, False, True, (3, 4, 5, 6), (3, 1, 5, 6)) + verify_any_reduce(relay.all, any_dims(3), (1, 2), True, False, (3, 4, 5), (4, 5)) + verify_any_reduce(relay.max, any_dims(4), -1, True, True, (3, 4, 5, 6), (1, 1, 1, 6)) + verify_any_reduce(relay.min, any_dims(3), (0, 1), False, False, (4, 5, 6), (6,)) + verify_any_reduce(relay.prod, any_dims(4), 2, True, True, (3, 4, 5, 6), (1, 1, 5, 1)) + verify_any_reduce(relay.mean, any_dims(2), 0, False, False, (1, 2), (2,)) + verify_any_reduce(relay.variance, any_dims(5), (2, 4), False, False, (3, 4, 5, 6, 7), (3, 4, 6)) + +def verify_any_layout_transform(data_shape, src_layout, dst_layout, static_data_shape, ref_out_shape): + mod = relay.Module() + dtype = "float32" + data = relay.var('data', shape=data_shape, dtype=dtype) + y = relay.layout_transform(data, src_layout, dst_layout) + mod["main"] = relay.Function([data], y) + data_np = np.random.uniform(size=static_data_shape).astype(dtype) + for kind in ["debug", "vm"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + result = ex.evaluate()(data_np) + assert result.asnumpy().shape == ref_out_shape, \ + "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape)) + +def test_any_layout_transform(): + verify_any_layout_transform(any_dims(4), "NCHW", "NHWC", (3, 4, 5, 6), (3, 5, 6, 4)) + verify_any_layout_transform(any_dims(5), "NCHW16c", "NCHW2c", (1, 2, 8, 8, 16), (1, 16, 8, 8, 2)) + verify_any_layout_transform(any_dims(5), "NCHW6n", "NHWC", (3, 4, 5, 6, 6), (18, 5, 6, 4)) + verify_any_layout_transform(any_dims(4), "NCHW", "NCHW4c", (3, 4, 5, 6), (3, 1, 5, 6, 4)) + verify_any_layout_transform((16, 1), "CH", "C4cH", (16, 1), (4, 4, 1)) + +def verify_any_expand_dims(data_shape, axis, num_newaxis, static_data_shape, ref_out_shape): + mod = relay.Module() + dtype = "float32" + data = relay.var('data', shape=data_shape, dtype=dtype) + y = relay.expand_dims(data, axis=axis, num_newaxis=num_newaxis) + mod["main"] = relay.Function([data], y) + data_np = np.random.uniform(size=static_data_shape).astype(dtype) + for kind in ["debug", "vm"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + result = ex.evaluate()(data_np) + assert result.asnumpy().shape == ref_out_shape, \ + "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape)) + +def test_any_expand_dims(): + verify_any_expand_dims(any_dims(3), 1, 2, (1, 2, 3), (1, 1, 1, 2, 3)) + verify_any_expand_dims(any_dims(3), -1, 2, (1, 2, 3), (1, 2, 3, 1, 1)) + +def verify_any_transpose(data_shape, axes, static_data_shape): + mod = relay.Module() + dtype = "float32" + data = relay.var('data', shape=data_shape, dtype=dtype) + y = relay.transpose(data, axes=axes) + mod["main"] = relay.Function([data], y) + data_np = np.random.uniform(size=static_data_shape).astype(dtype) + ref_out = np.transpose(data_np, axes) + for kind in ["debug", "vm"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + result = ex.evaluate()(data_np) + tvm.testing.assert_allclose(result.asnumpy(), ref_out) + +def test_any_transpose(): + verify_any_transpose(any_dims(3), (1, 0, 2), (10, 3, 2)) + verify_any_transpose(any_dims(3), None, (2, 3, 4)) + verify_any_transpose(any_dims(6), (0, 1, 3, 2, 5, 4), (11, 12, 2, 1, 9, 17)) + +def verify_any_squeeze(data_shape, axis, static_data_shape): + mod = relay.Module() + dtype = "float32" + data = relay.var('data', shape=data_shape, dtype=dtype) + y = relay.squeeze(data, axis=axis) + mod["main"] = relay.Function([data], y) + data_np = np.random.uniform(size=static_data_shape).astype(dtype) + ref_out = np.squeeze(data_np, axis) + for kind in ["debug", "vm"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + result = ex.evaluate()(data_np) + tvm.testing.assert_allclose(result.asnumpy(), ref_out) + +def test_any_squeeze(): + verify_any_squeeze((1, relay.Any(), relay.Any()), (0,), (1, 9, 8)) + verify_any_squeeze((1, relay.Any(), relay.Any(), 1, relay.Any(), relay.Any()), (0, 3), (1, 12, 2, 1, 9, 17)) + +def test_any_reshape_like(): + mod = relay.Module() + dtype = "float32" + data = relay.var('data', shape=(relay.Any(), 3, 10), dtype=dtype) + shape_like = relay.var('data', shape=(relay.Any(), 5, 6), dtype=dtype) + y = relay.reshape_like(data, shape_like) + mod["main"] = relay.Function([data, shape_like], y) + data_np = np.random.uniform(size=(3, 3, 10)).astype(dtype) + shape_like_np = np.random.uniform(size=(3, 5, 6)).astype(dtype) + for kind in ["debug", "vm"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + result = ex.evaluate()(data_np, shape_like_np) + assert result.asnumpy().shape == shape_like_np.shape, \ + "Shape mismatch: expect %s but got %s." % (str(shape_like_np.shape), str(result.asnumpy().shape)) + +def verify_any_conv2d_NCHWc(data_shape, kernel_shape, strides, padding, dilation, + data_layout, kernel_layout, out_layout, + static_data_shape, ref_out_shape): + mod = relay.Module() + dtype = "float32" + data = relay.var('data', shape=data_shape, dtype=dtype) + kernel = relay.var('kernel', shape=kernel_shape, dtype=dtype) + y = relay.nn.contrib_conv2d_nchwc(data, kernel, strides, padding, dilation, + kernel_size=kernel_shape[2:4], + channels=kernel_shape[0]*kernel_shape[-1], + data_layout=data_layout, kernel_layout=kernel_layout, + out_layout=out_layout) + mod["main"] = relay.Function([data, kernel], y) + data_np = np.random.uniform(size=static_data_shape).astype(dtype) + kernel_np = np.random.uniform(size=kernel_shape).astype(dtype) + for kind in ["debug", "vm"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + result = ex.evaluate()(data_np, kernel_np) + assert result.asnumpy().shape == ref_out_shape, \ + "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape)) + +def test_any_conv2d_NCHWc(): + verify_any_conv2d_NCHWc((relay.Any(), 8, relay.Any(), relay.Any(), 8), (8, 8, 3, 3, 8, 8), (1, 1), (1, 1), (1, 1), + "NCHW8c", "OIHW8i8o", "NCHW8c", (1, 8, 224, 224, 8), (1, 8, 224, 224, 8)) + verify_any_conv2d_NCHWc((relay.Any(), 8, relay.Any(), relay.Any(), 8), (8, 8, 3, 3, 8, 8), (1, 1), (1, 1), (2, 2), + "NCHW8c", "OIHW8i8o", "NCHW8c", (1, 8, 224, 224, 8), (1, 8, 222, 222, 8)) + +def verify_any_pool2d(pool_type, data_shape, pool_size, strides, padding, + layout, static_data_shape, ref_out_shape): + mod = relay.Module() + dtype = "float32" + pool_func = relay.nn.max_pool2d if pool_type == "max" else relay.nn.avg_pool2d + data = relay.var('data', shape=data_shape, dtype=dtype) + y = pool_func(data, pool_size, strides, padding, layout) + mod["main"] = relay.Function([data], y) + data_np = np.random.uniform(size=static_data_shape).astype(dtype) + for kind in ["debug", "vm"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + result = ex.evaluate()(data_np) + assert result.asnumpy().shape == ref_out_shape, \ + "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape)) + +def test_any_pool2d(): + verify_any_pool2d("max", (relay.Any(), 3, relay.Any(), relay.Any()), + (3, 3), (1, 1), (1, 1), "NCHW", (2, 3, 220, 220), (2, 3, 220, 220)) + verify_any_pool2d("avg", (relay.Any(), relay.Any(), relay.Any(), 4), + (1, 1), (2, 2), (0, 0), "NHWC", (3, 220, 220, 4), (3, 110, 110, 4)) + verify_any_pool2d("max", (relay.Any(), 3, relay.Any(), relay.Any(), 4), + (3, 3), (2, 2), (1, 1), "NCHW4c", (2, 3, 220, 220, 4), (2, 3, 110, 110, 4)) + +def verify_any_global_pool2d(pool_type, data_shape, layout, static_data_shape, ref_out_shape): + mod = relay.Module() + dtype = "float32" + pool_func = relay.nn.global_max_pool2d if pool_type == "max" else relay.nn.global_avg_pool2d + data = relay.var('data', shape=data_shape, dtype=dtype) + y = pool_func(data, layout) + mod["main"] = relay.Function([data], y) + data_np = np.random.uniform(size=static_data_shape).astype(dtype) + for kind in ["debug", "vm"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + result = ex.evaluate()(data_np) + assert result.asnumpy().shape == ref_out_shape, \ + "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape)) + +def test_any_global_pool2d(): + verify_any_global_pool2d("max", (relay.Any(), 3, relay.Any(), relay.Any()), + "NCHW", (2, 3, 220, 220), (2, 3, 1, 1)) + verify_any_global_pool2d("avg", (relay.Any(), relay.Any(), relay.Any(), 4), + "NHWC", (3, 220, 220, 4), (3, 1, 1, 4)) + verify_any_global_pool2d("max", (relay.Any(), 3, relay.Any(), relay.Any(), 4), + "NCHW4c", (2, 3, 220, 220, 4), (2, 3, 1, 1, 4)) + +def test_any_batch_flatten(): + mod = relay.Module() + dtype = "float32" + data = relay.var('data', shape=any_dims(3), dtype=dtype) + y = relay.nn.batch_flatten(data) + mod["main"] = relay.Function([data], y) + data_np = np.random.uniform(size=(3, 3, 10)).astype(dtype) + ref_out_shape = (3, 30) + for kind in ["debug", "vm"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + result = ex.evaluate()(data_np) + assert result.asnumpy().shape == ref_out_shape, \ + "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape)) + +def verify_any_dense(data_shape, weight_shape, units, static_data_shape, + static_weight_shape, ref_out_shape): + mod = relay.Module() + dtype = "float32" + data = relay.var('data', shape=data_shape, dtype=dtype) + weight = relay.var('weight', shape=weight_shape, dtype=dtype) + y = relay.nn.dense(data, weight, units) + mod["main"] = relay.Function([data, weight], y) + data_np = np.random.uniform(size=static_data_shape).astype(dtype) + weight_np = np.random.uniform(size=static_weight_shape).astype(dtype) + for kind in ["debug", "vm"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + result = ex.evaluate()(data_np, weight_np) + assert result.asnumpy().shape == ref_out_shape, \ + "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape)) + +def test_any_dense(): + verify_any_dense(any_dims(2), any_dims(2), None, (4, 16), (8, 16), (4, 8)) + verify_any_dense(any_dims(2), (50, relay.Any()), 50, (4, 40), (50, 40), (4, 50)) + +def verify_any_pad(data_shape, pad_width, static_data_shape): + mod = relay.Module() + dtype = "float32" + data = relay.var('data', shape=data_shape, dtype=dtype) + y = relay.nn.pad(data, pad_width) + mod["main"] = relay.Function([data], y) + data_np = np.random.uniform(size=static_data_shape).astype(dtype) + for kind in ["debug", "vm"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + result = ex.evaluate()(data_np) + ref_out = np.pad(data_np, pad_width) + tvm.testing.assert_allclose(result.asnumpy(), ref_out) + +def test_any_pad(): + verify_any_pad(any_dims(3), ((0, 0), (1, 1), (2, 2)), (1, 2, 3)) + verify_any_pad(any_dims(4), ((1, 0), (1, 3), (0, 2), (9, 0)), (13, 11, 3, 1)) + +def verify_any_softmax(data_shape, axis, static_data_shape, ref_out_shape): + mod = relay.Module() + dtype = "float32" + data = relay.var('data', shape=data_shape, dtype=dtype) + y = relay.nn.softmax(data, axis) + mod["main"] = relay.Function([data], y) + data_np = np.random.uniform(size=static_data_shape).astype(dtype) + for kind in ["debug", "vm"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + result = ex.evaluate()(data_np) + assert result.asnumpy().shape == ref_out_shape, \ + "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape)) + +def test_any_softmax(): + verify_any_softmax(any_dims(3), -1, (1, 2, 3), (1, 2, 3)) + verify_any_softmax(any_dims(4), 2, (13, 11, 3, 1), (13, 11, 3, 1)) + def test_fused_ops(): x = relay.var('x', shape=(relay.Any(), relay.Any()), dtype='float32') y0 = x + relay.const(1.0, 'float32') @@ -308,6 +559,19 @@ def _body(i, st): test_any_reshape() test_any_take() test_any_shape_of() + test_any_reduce() + test_any_layout_transform() + test_any_expand_dims() + test_any_transpose() + test_any_squeeze() + test_any_reshape_like() + test_any_conv2d_NCHWc() + test_any_pool2d() + test_any_global_pool2d() + test_any_batch_flatten() + test_any_dense() + test_any_pad() + test_any_softmax() test_fused_ops() test_arange_with_dynamic_shape() test_recursive_concat() diff --git a/topi/include/topi/nn/flatten.h b/topi/include/topi/nn/flatten.h index ea4f62e9fc13..d04454701aec 100644 --- a/topi/include/topi/nn/flatten.h +++ b/topi/include/topi/nn/flatten.h @@ -52,9 +52,9 @@ inline Tensor flatten(const Tensor& x, std::string name = "tensor", std::string tag = kInjective) { auto ishape = x->shape; - int dim = 1; + Expr dim = 1; for (size_t i = 1; i < ishape.size(); ++i) { - dim = dim * static_cast(topi::detail::GetConstInt(ishape[i])); + dim = dim * ishape[i]; } Array oshape({ ishape[0], dim }); diff --git a/topi/python/topi/util.py b/topi/python/topi/util.py index 1bf3a102a88f..623d06ae07c8 100644 --- a/topi/python/topi/util.py +++ b/topi/python/topi/util.py @@ -144,7 +144,7 @@ def equal_const_int(expr, value): def get_const_tuple(in_tuple): - """Verifies input tuple is IntImm, returns tuple of int. + """Verifies input tuple is IntImm or Var, returns tuple of int or Var. Parameters ---------- @@ -156,7 +156,17 @@ def get_const_tuple(in_tuple): out_tuple : tuple of int The output. """ - return tuple(get_const_int(elem) for elem in in_tuple) + ret = [] + for elem in in_tuple: + if isinstance(elem, tvm.expr.Var): + ret.append(elem) + elif not isinstance(elem, (tvm.expr.IntImm, tvm.expr.UIntImm, int)): + elem = tvm.ir_pass.Simplify(elem) + if not isinstance(elem, (tvm.expr.IntImm, tvm.expr.UIntImm)): + ret.append(elem) + else: + ret.append(get_const_int(elem)) + return tuple(ret) def get_float_tuple(in_tuple): diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py index 9ea93cd6b647..0e284da17ee6 100644 --- a/topi/python/topi/x86/conv2d.py +++ b/topi/python/topi/x86/conv2d.py @@ -41,6 +41,13 @@ def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depth """ Get default schedule config for the workload """ + static_data_shape = [] + for dim in get_const_tuple(data.shape): + if isinstance(dim, tvm.expr.Var): + static_data_shape.append(1) + else: + static_data_shape.append(dim) + data = tvm.placeholder(static_data_shape, dtype=data.dtype) if is_depthwise: wkl = _get_depthwise_conv2d_workload(data, kernel, strides, padding, out_dtype) from .depthwise_conv2d import _fallback_schedule diff --git a/topi/python/topi/x86/dense.py b/topi/python/topi/x86/dense.py index ec401bf0dd03..2a739d5c5d8f 100644 --- a/topi/python/topi/x86/dense.py +++ b/topi/python/topi/x86/dense.py @@ -37,6 +37,12 @@ def _declaration_dense(cfg, data, weight, bias=None, out_dtype=None): return C M, _ = get_const_tuple(data.shape) + # Always use dense_nopack for dynamic input. + # This is a temporary for CV models. + # TODO(kevinthesun): use kernel dispatcher instead. + if isinstance(M, tvm.expr.Var): + return _declaration_dense_nopack(cfg, data, weight, bias, out_dtype) + # For small batch sizes, don't pack weight into cache-friendly layout # because of overhead in packing and limited reuse from batch dimension # TODO(icemelon9): use a more systematic way to determine which schedule to use @@ -53,9 +59,9 @@ def _declaration_dense_pack(cfg, data, weight, bias=None, out_dtype=None): M, K = get_const_tuple(data.shape) # batch, in_dim N, _ = get_const_tuple(weight.shape) # out_dim # create tuning space - cfg.define_split("tile_y", M, num_outputs=3) - cfg.define_split("tile_x", N, num_outputs=3) - cfg.define_split("tile_k", K, num_outputs=2) + cfg.define_split("tile_y", 32 if isinstance(M, tvm.expr.Var) else M, num_outputs=2) + cfg.define_split("tile_x", 32 if isinstance(N, tvm.expr.Var) else N, num_outputs=2) + cfg.define_split("tile_k", 32 if isinstance(K, tvm.expr.Var) else K, num_outputs=2) if cfg.is_fallback: _default_dense_pack_config(cfg, M, N, K) @@ -87,9 +93,9 @@ def _declaration_dense_nopack(cfg, data, weight, bias=None, out_dtype=None): M, K = get_const_tuple(data.shape) N, _ = get_const_tuple(weight.shape) # create tuning space - cfg.define_split("tile_y", M, num_outputs=2) - cfg.define_split("tile_x", N, num_outputs=2) - cfg.define_split("tile_k", K, num_outputs=2) + cfg.define_split("tile_y", 32 if isinstance(M, tvm.expr.Var) else M, num_outputs=2) + cfg.define_split("tile_x", 32 if isinstance(N, tvm.expr.Var) else N, num_outputs=2) + cfg.define_split("tile_k", 32 if isinstance(K, tvm.expr.Var) else K, num_outputs=2) if cfg.is_fallback: _default_dense_nopack_config(cfg, M, N, K) @@ -211,8 +217,15 @@ def _schedule_dense_nopack_template(cfg, s, C): def _default_dense_pack_config(cfg, M, N, K): - vec_width = get_fp32_len() + # Generate default schedule for dynamic shape. + if isinstance(M, tvm.expr.Var): + M = 16 + if isinstance(N, tvm.expr.Var): + N = 16 + if isinstance(K, tvm.expr.Var): + K = 16 + vec_width = get_fp32_len() tilex_ii = 1 for bn in range(vec_width*2, 0, -1): if N % bn == 0: @@ -241,6 +254,14 @@ def _default_dense_pack_config(cfg, M, N, K): def _default_dense_nopack_config(cfg, M, N, K): + # Generate default schedule for dynamic shape. + if isinstance(M, tvm.expr.Var): + M = 16 + if isinstance(N, tvm.expr.Var): + N = 16 + if isinstance(K, tvm.expr.Var): + K = 16 + vec_width = get_fp32_len() tilek_bn = 1 for bn in range(vec_width*2, 0, -1):