Add More Shape Functions (#4179)
* Add shape functions

* Fix get_const_tuple

* Fix cpplint

* Fix pylint

* Fix pylint

* rebase and fix

* Check Any for infer type

* Fix expand_dim shape func for zero rank input

* Fix pooling infer type

* Address comment

* Register layout transform attr
kevinthesun authored and icemelon committed Nov 11, 2019
1 parent 10b77ef commit 6252145
Showing 19 changed files with 864 additions and 58 deletions.
6 changes: 3 additions & 3 deletions python/tvm/autotvm/task/task.py
@@ -219,15 +219,15 @@ def args_to_workload(x, topi_compute_func=None):
        workload = get_const_tuple(x.shape) + (x.dtype, )
    elif isinstance(x, (tuple, list, container.Array)):
        workload = tuple([args_to_workload(a) for a in x])
-   elif isinstance(x, (str, int, float, np.int, np.float)):
    elif isinstance(x, (str, int, float, np.int, np.float, expr.Var)):
        workload = x
    elif isinstance(x, (expr.StringImm, expr.UIntImm, expr.IntImm, expr.FloatImm)):
        workload = x.value
    elif x is None:
        workload = 0
    else:
-       raise RuntimeError('Do not support type "%s" in argument. Consider to use '
-                          'primitive types only' % type(x))
        raise RuntimeError('Do not support type "%s" in argument. Consider to use'
                           'primitive types or tvm.expr.Var only' % type(x))
    return (get_func_name(topi_compute_func), ) + workload if topi_compute_func else workload

def template(func):
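With expr.Var accepted above, an AutoTVM workload can now carry a symbolic dimension instead of triggering the RuntimeError. A minimal usage sketch, assuming the 2019-era tvm.var/tvm.placeholder API (the name "n" and the (n, 512) placeholder are illustrative, not part of the commit):

import tvm
from tvm.autotvm.task.task import args_to_workload

n = tvm.var("n")                                   # symbolic batch size
data = tvm.placeholder((n, 512), name="data", dtype="float32")
# Previously the Var raised RuntimeError; now it is kept inside the workload tuple.
workload = args_to_workload([data])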
14 changes: 12 additions & 2 deletions python/tvm/autotvm/util.py
@@ -163,7 +163,7 @@ def get_const_int(exp):


def get_const_tuple(in_tuple):
"""Verifies input tuple is IntImm, returns tuple of int.
"""Verifies input tuple is IntImm or Var, returns tuple of int or Var.
Parameters
----------
@@ -175,4 +175,14 @@ def get_const_tuple(in_tuple):
    out_tuple : tuple of int
        The output.
    """
-   return tuple(get_const_int(x) for x in in_tuple)
    ret = []
    for elem in in_tuple:
        if isinstance(elem, expr.Var):
            ret.append(elem)
        elif not isinstance(elem, (expr.IntImm, expr.UIntImm, int)):
            elem = ir_pass.Simplify(elem)
            if not isinstance(elem, (expr.IntImm, expr.UIntImm)):
                ret.append(elem)
        else:
            ret.append(get_const_int(elem))
    return tuple(ret)
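A short usage sketch of the behaviour added here (illustrative only, assuming the 2019-era tvm.var/tvm.placeholder API): symbolic dimensions pass through as Vars while constant dimensions are still folded to Python ints.

import tvm
from tvm.autotvm.util import get_const_tuple

n = tvm.var("n")                                   # symbolic batch dimension
data = tvm.placeholder((n, 3, 224, 224), name="data")
# n survives as a Var, the constant dimensions come back as plain ints
shape = get_const_tuple(data.shape)                # (n, 3, 224, 224)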
68 changes: 68 additions & 0 deletions python/tvm/relay/op/_reduce.py
@@ -18,7 +18,11 @@
from __future__ import absolute_import

import topi

from topi.util import get_const_int, get_const_tuple
from . import op as _reg
from ...api import convert
from ...hybrid import script


def _schedule_reduce(_, outs, target):
@@ -39,3 +43,67 @@ def _schedule_reduce(_, outs, target):
_reg.register_schedule("variance", _schedule_reduce)
_reg.register_schedule("nn.cross_entropy", _schedule_reduce)
_reg.register_schedule("nn.cross_entropy_with_logits", _schedule_reduce)


def _create_axis_record(attrs, inputs):
    axes = attrs.axis if attrs.axis is None else list(get_const_tuple(attrs.axis))
    exclude = get_const_int(attrs.exclude) > 0
    keepdims = get_const_int(attrs.keepdims) > 0
    data_shape = inputs[0]
    shape_size = data_shape.shape[0].value
    axis_record = [-1] * shape_size
    if axes is None:
        axes = list(range(shape_size))

    for i, axis in enumerate(axes):
        if axis < 0:
            axes[i] = shape_size + axis

    if exclude:
        ex_axes = []
        for i in range(shape_size):
            if i not in axes:
                ex_axes.append(i)
        axes = ex_axes

    for i in range(shape_size):
        if i not in axes:
            axis_record[i] = i

    if not keepdims:
        tmp = []
        for i in axis_record:
            if i >= 0:
                tmp.append(i)
        axis_record = tmp

    return axis_record


@script
def _reduce_shape_func(data_shape, axis_record):
    out = output_tensor((len(axis_record),), "int64")
    for i in const_range(len(axis_record)):
        if axis_record[i] >= 0:
            out[i] = data_shape[axis_record[i]]
        else:
            out[i] = int64(1)

    return out

def reduce_shape_func(attrs, inputs, _):
    """
    Shape function for reduce op.
    """
    axis_record = _create_axis_record(attrs, inputs)
    return [_reduce_shape_func(inputs[0], convert(axis_record))]

_reg.register_shape_func("argmax", False, reduce_shape_func)
_reg.register_shape_func("argmin", False, reduce_shape_func)
_reg.register_shape_func("all", False, reduce_shape_func)
_reg.register_shape_func("sum", False, reduce_shape_func)
_reg.register_shape_func("max", False, reduce_shape_func)
_reg.register_shape_func("min", False, reduce_shape_func)
_reg.register_shape_func("prod", False, reduce_shape_func)
_reg.register_shape_func("mean", False, reduce_shape_func)
_reg.register_shape_func("variance", False, reduce_shape_func)
25 changes: 12 additions & 13 deletions python/tvm/relay/op/_tensor.py
@@ -119,18 +119,6 @@ def _cast_shape_function(x):
def cast_shape_func(attrs, inputs, out_ndims):
    return [_cast_shape_function(*inputs)]

-@script
-def _expand_dims_shape_func(x):
-    ndim = len(x.shape)
-    out = output_tensor((ndim+1,), "int64")
-    out[0] = int64(1)
-    for i in const_range(0, ndim):
-        out[i+1] = int64(x.shape[i])
-    return out
-
-def expand_dims_shape_func(attrs, inputs, out_ndims):
-    return [_expand_dims_shape_func(*inputs)]

# shape func
@script
def _broadcast_shape_func(x, y, ndim):
@@ -161,9 +149,17 @@ def _broadcast_shape_func(x, y, ndim):
    return out

def broadcast_shape_func(attrs, inputs, out_ndims):
"""
Shape function for broadcast op.
"""
return [_broadcast_shape_func(*inputs, out_ndims[0])]

register_shape_func("expand_dims", False, expand_dims_shape_func)
def elemwise_shape_func(attrs, inputs, _):
"""
Shape function for elemwise op.
"""
return [topi.math.identity(inputs[0])]

register_shape_func("cast", False, cast_shape_func)

register_shape_func("add", False, broadcast_shape_func)
@@ -179,3 +175,6 @@ def broadcast_shape_func(attrs, inputs, out_ndims):
register_shape_func("less_equal", False, broadcast_shape_func)
register_shape_func("greater", False, broadcast_shape_func)
register_shape_func("greater_equal", False, broadcast_shape_func)

register_shape_func("sqrt", False, elemwise_shape_func)
register_shape_func("negative", False, elemwise_shape_func)
194 changes: 193 additions & 1 deletion python/tvm/relay/op/_transform.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Backend compiler related feature registration"""
-# pylint: disable=invalid-name,unused-argument, len-as-condition, too-many-nested-blocks
# pylint: disable=invalid-name,unused-argument, len-as-condition, too-many-nested-blocks, too-many-local-variables, too-many-arguments
from __future__ import absolute_import
import tvm
import topi
@@ -303,3 +303,195 @@ def compute_argwhere(attrs, inputs, output_type, _):
            output_shape.append(tvm.var("any_dim", "int32"))
    new_output_type = tvm.relay.ty.TensorType(output_shape, "int32")
    return [topi.argwhere(new_output_type, inputs[0])]

@script
def _layout_transform_shape_func(data_shape,
                                 out_layout_len,
                                 dst_equal_list,
                                 dst_mul_list,
                                 dst_div_list,
                                 dst_mix_list):
    out = output_tensor((out_layout_len,), "int64")
    for i in const_range(len(dst_equal_list)):
        out[dst_equal_list[i][0]] = data_shape[dst_equal_list[i][1]]
    for i in const_range(len(dst_mul_list)):
        out[dst_mul_list[i][0]] = data_shape[dst_mul_list[i][1]] * \
                                  data_shape[dst_mul_list[i][2]]
    for i in const_range(len(dst_div_list)):
        out[dst_div_list[i][0]] = data_shape[dst_div_list[i][1]] \
                                  // dst_div_list[i][3]
        out[dst_div_list[i][2]] = int64(dst_div_list[i][3])
    for i in const_range(len(dst_mix_list)):
        out[dst_mix_list[i][0]] = data_shape[dst_mix_list[i][1]] * \
                                  dst_mix_list[i][2] // dst_mix_list[i][4]
        out[dst_mix_list[i][3]] = int64(dst_mix_list[i][4])

    return out

@_reg.register_shape_func("layout_transform", False)
def layout_transform_shape_func(attrs, inputs, _):
"""
Shape function for layout_transform op.
"""
def _fetch_axis(layout):
major_axes = []
minor_axes = {}
num_start = -1
for i, item in enumerate(layout):
if "A" <= item <= "Z":
major_axes.append(item)
elif "a" <= item <= "z":
last_num = int(layout[num_start:i])
minor_axes[item] = last_num
num_start = -1
elif num_start < 0:
num_start = i
return major_axes, minor_axes

_, src_minor_axes = _fetch_axis(attrs.src_layout)
dst_major_axes, dst_minor_axes = _fetch_axis(attrs.dst_layout)
src_letter_list = []
dst_letter_list = []
for item in attrs.src_layout:
if "A" <= item <= "Z" or "a" <= item <= "z":
src_letter_list.append(item)
for item in attrs.dst_layout:
if "A" <= item <= "Z" or "a" <= item <= "z":
dst_letter_list.append(item)
out_layout_len = len(dst_major_axes) + len(dst_minor_axes)
dst_equal_list = []
dst_mul_list = []
dst_div_list = []
dst_mix_list = []

for key in dst_major_axes:
if key.lower() not in dst_minor_axes:
if key.lower() not in src_minor_axes:
dst_equal_list.append((dst_letter_list.index(key),
src_letter_list.index(key)))
else:
dst_mul_list.append((dst_letter_list.index(key),
src_letter_list.index(key),
src_letter_list.index(key.lower())))
else:
if key.lower() not in src_minor_axes:
dst_div_list.append((dst_letter_list.index(key),
src_letter_list.index(key),
dst_letter_list.index(key.lower()),
dst_minor_axes[key.lower()]))
else:
dst_mix_list.append((dst_letter_list.index(key),
src_letter_list.index(key),
src_minor_axes[key.lower()],
dst_letter_list.index(key.lower()),
dst_minor_axes[key.lower()]))

return [_layout_transform_shape_func(inputs[0],
convert(out_layout_len),
convert(dst_equal_list),
convert(dst_mul_list),
convert(dst_div_list),
convert(dst_mix_list))]
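A worked example of the bookkeeping above for the common NCHW -> NCHW4c case (the block size 4 and the concrete shape are illustrative): _fetch_axis("NCHW4c") yields major axes ['N', 'C', 'H', 'W'] and minor axes {'c': 4}; 'C' has a minor counterpart only in the destination, so it lands in dst_div_list and the hybrid script splits the channel dimension.

# Pure-Python sketch of the NCHW -> NCHW4c case handled via dst_div_list above.
def nchw_to_nchwxc_shape(shape, block=4):
    n, c, h, w = shape
    return (n, c // block, h, w, block)

assert nchw_to_nchwxc_shape((1, 16, 28, 28)) == (1, 4, 28, 28, 4)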

@script
def _expand_dim_shape_func(data_shape, ndim, axis, num_newaxis):
    out = output_tensor((ndim + num_newaxis,), "int64")
    for i in const_range(out.shape[0]):
        if i < axis:
            out[i] = data_shape[i]
        elif i < axis + num_newaxis:
            out[i] = int64(1)
        else:
            out[i] = data_shape[i - num_newaxis]

    return out

@_reg.register_shape_func("expand_dims", False)
def expand_dim_shape_func(attrs, inputs, _):
"""
Shape function for expand_dim op.
"""
axis = get_const_int(attrs.axis)
num_newaxis = get_const_int(attrs.num_newaxis)
if axis < 0:
axis = inputs[0].shape[0] + axis + 1
ndim = inputs[0].shape[0] if inputs[0].shape else 0
return [_expand_dim_shape_func(inputs[0],
convert(ndim),
convert(axis),
convert(num_newaxis))]
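A plain-Python sketch of the shape rule implemented by the hybrid script above (illustrative only); the second assert is the zero-rank case mentioned in the commit log, which the ndim fallback to 0 handles.

def expand_dims_shape(shape, axis, num_newaxis):
    return shape[:axis] + (1,) * num_newaxis + shape[axis:]

assert expand_dims_shape((3, 4), 1, 2) == (3, 1, 1, 4)
assert expand_dims_shape((), 0, 1) == (1,)   # zero-rank input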

@script
def _transpose_shape_func(data_shape, axes):
    out = output_tensor((data_shape.shape[0],), "int64")
    for i in const_range(len(axes)):
        out[i] = data_shape[axes[i]]

    return out

@_reg.register_shape_func("transpose", False)
def transpose_shape_func(attrs, inputs, _):
"""
Shape function for transpose op.
"""
axes = attrs.axes if attrs.axes is None else get_const_tuple(attrs.axes)
if axes is None:
axes = list(range(inputs[0].shape[0].value))
axes.reverse()
for i, axis in enumerate(axes):
if axis < 0:
axes[i] = inputs[0].shape[0] - axis
return [_transpose_shape_func(inputs[0], convert(axes))]
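A plain-Python equivalent of the rule above (illustrative only): with axes omitted the shape is simply reversed, otherwise it is permuted.

def transpose_shape(shape, axes=None):
    axes = list(range(len(shape)))[::-1] if axes is None else list(axes)
    return tuple(shape[a] for a in axes)

assert transpose_shape((2, 3, 4)) == (4, 3, 2)
assert transpose_shape((2, 3, 4), axes=(1, 0, 2)) == (3, 2, 4)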

@script
def _squeeze_shape_func(data_shape, keep_axes):
    out = output_tensor((len(keep_axes),), "int64")
    if len(keep_axes) == 0:
        out_size = 0
        for i in const_range(data_shape.shape[0]):
            if data_shape[i] != 1:
                out_size += 1

        if out_size == 0:
            out_size = 1
        out = output_tensor((out_size,), "int64")
        out[0] = int64(1)
        pos = 0
        for i in const_range(data_shape.shape[0]):
            if data_shape[i] != 1:
                out[pos] = data_shape[i]
                pos += 1
    else:
        for i in const_range(len(keep_axes)):
            out[i] = data_shape[keep_axes[i]]

    return out

@_reg.register_shape_func("squeeze", False)
def squeeze_shape_func(attrs, inputs, _):
"""
Shape function for squeeze op.
"""
axis = attrs.axis if attrs.axis is None else get_const_tuple(attrs.axis)
keep_axes = []
if axis is not None:
for i in range(inputs[0].shape[0].value):
if i not in axis:
keep_axes.append(i)

return [_squeeze_shape_func(inputs[0], convert(keep_axes))]
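A plain-Python sketch of the two branches above (illustrative only): with no axis given, every unit dimension is dropped, falling back to a single 1 so the result is never empty; otherwise only the listed axes are removed.

def squeeze_shape(shape, axis=None):
    if axis is None:
        out = tuple(d for d in shape if d != 1)
        return out if out else (1,)
    return tuple(d for i, d in enumerate(shape) if i not in axis)

assert squeeze_shape((1, 3, 1, 4)) == (3, 4)
assert squeeze_shape((1, 3, 1, 4), axis=(0,)) == (3, 1, 4)
assert squeeze_shape((1, 1)) == (1,)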

@script
def _reshape_like_shape_func(target_shape):
    out = output_tensor((target_shape.shape[0],), "int64")
    for i in const_range(target_shape.shape[0]):
        out[i] = target_shape[i]

    return out

@_reg.register_shape_func("reshape_like", False)
def reshape_like_shape_func(attrs, inputs, _):
"""
Shape function for reshape_like op.
"""
return [_reshape_like_shape_func(inputs[1])]
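All of the registrations in this commit pass False as the second argument to register_shape_func, marking the shape functions as data independent: they receive the shapes of the inputs (1-D int64 tensors) rather than the input values, which is why reshape_like above only needs inputs[1]. A hedged sketch of the pattern for a hypothetical op (the op name is made up and would have to exist for the registration to succeed):

@_reg.register_shape_func("my_elemwise_like_op", False)   # hypothetical op name
def my_shape_func(attrs, inputs, _):
    # inputs[0] is the shape of the first argument; the output shape is identical.
    return [topi.math.identity(inputs[0])]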
