Skip to content

Commit

Permalink
Properly test axes argument in relay tests
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Jun 2, 2021
1 parent bd5ae6c commit 88c79f3
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 24 deletions.
15 changes: 7 additions & 8 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,33 +252,32 @@ def _strided_slice_shape_func_with_axes(data_shape, begin, end, strides, slice_m
out[i] = data_shape[i]

for i in const_range(len(axes)):
axis = int64(axes[i])
cbegin = int64(0)
cend = int64(data_shape[axis])
cend = int64(data_shape[axes[i]])
cstride = int64(1)
if len(strides) > i:
cstride = int64(strides[i])
if len(begin) > i:
cbegin = int64(begin[i])
if cbegin < 0:
cbegin += int64(data_shape[axis])
cbegin += int64(data_shape[axes[i]])
if len(end) <= i:
cend = int64(data_shape[axis])
cend = int64(data_shape[axes[i]])
elif slice_mode != 0:
cstride = int64(1)
if end[i] < 0:
cend = int64(data_shape[axis])
cend = int64(data_shape[axes[i]])
else:
cend = cbegin + int64(end[i])
else:
if end[i] > data_shape[i]:
cend = int64(data_shape[axis])
cend = int64(data_shape[axes[i]])
elif end[i] < -data_shape[i]:
cend = int64(-1)
else:
cend = int64(end[i])
if cend < 0:
cend += int64(data_shape[axis])
cend += int64(data_shape[axes[i]])
assert cstride != 0, "Strides can't be zero."
if cstride < 0:
slice_range = cbegin - cend
Expand All @@ -287,7 +286,7 @@ def _strided_slice_shape_func_with_axes(data_shape, begin, end, strides, slice_m
slice_range = cend - cbegin
step = cstride

out[axis] = int64(ceil_div(slice_range, step))
out[axes[i]] = int64(ceil_div(slice_range, step))
return out


Expand Down
15 changes: 9 additions & 6 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -867,7 +867,7 @@ def split(data, indices_or_sections, axis=0):
return TupleWrapper(_make.split(data, indices_or_sections, axis), ret_size)


def strided_slice(data, begin, end, strides=None, slice_mode="end", axes=None):
def strided_slice(data, begin, end, strides=None, axes=None, slice_mode="end"):
"""Strided slice of an array.
Parameters
Expand All @@ -885,18 +885,19 @@ def strided_slice(data, begin, end, strides=None, slice_mode="end", axes=None):
Specifies the stride values, it can be negative in that case,
the input tensor will be reversed in that particular axis.
axes : Tuple[int] or List[int], optional
Axes along which slicing is applied. When it is specified, the length of begin, end,
strides, and axes must be equal. Moreover, begin, end, strides, and axes must be
static (cannot be relay.Expr). Axes argument for dynamic parameter slicing is
not supported yet.
slice_mode : str, optional
The slice mode [end, size].
end: The ending indices for the slice [default].
size: The input strides will be ignored, input end in this mode indicates
the size of a slice starting at the location specified by begin. If end[i]
is -1, all remaining elements in that dimension are included in the slice.
axes : Tuple[int] or List[int], optional
Axes along which slicing is applied. When it is specified, the length of begin, end,
strides, and axes must be equal. Moreover, begin, end, strides, and axes must be
static (cannot be relay.Expr).
Returns
-------
ret : relay.Expr
Expand All @@ -921,6 +922,8 @@ def strided_slice(data, begin, end, strides=None, slice_mode="end", axes=None):
ishape_slice = slice_like(ishape, begin)
begin = _make.where(begin < cast_like(const(0), begin), begin + ishape_slice, begin)
begin = _make.where(begin >= ishape_slice, ishape_slice, begin)
# TODO(masahi): Support axes argument in dynamic strided slice
assert axes is None, "Axes argument for dynamic parameter slicing is not supported yet."
return _dyn_make.strided_slice(data, begin, end, strides, slice_mode)
return _make.strided_slice(data, begin, end, strides, slice_mode, axes)

Expand Down
22 changes: 21 additions & 1 deletion python/tvm/topi/testing/strided_slice_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"""strided_slice/set in python"""


def strided_slice_python(data, begin, end, strides, slice_mode="end"):
def strided_slice_python(data, begin, end, strides, axes=None, slice_mode="end"):
"""Python version of strided slice operator.
Parameters
Expand All @@ -34,6 +34,9 @@ def strided_slice_python(data, begin, end, strides, slice_mode="end"):
strides : list
The stride of each slice.
axes : list, optional
Axes along which slicing is applied
slice_mode : str, optional
The slice mode [end, size].
end: The default slice mode, ending indices for the slice.
Expand All @@ -48,6 +51,22 @@ def strided_slice_python(data, begin, end, strides, slice_mode="end"):
The sliced result.
"""
strides = [] if strides is None else strides
if axes is not None:
rank = len(data.shape)
new_begin = [0] * rank
new_end = [data.shape[i] for i in range(rank)]
new_strides = [1] * rank

for i, axis in enumerate(axes):
new_begin[axis] = begin[i]
new_end[axis] = end[i]
if len(strides) > i:
new_strides[axis] = strides[i]

begin = new_begin
end = new_end
strides = new_strides

slices = []
for i in range(len(data.shape)):
new_stride = None
Expand All @@ -66,6 +85,7 @@ def strided_slice_python(data, begin, end, strides, slice_mode="end"):
new_end = end[i]

slices.append(slice(new_begin, new_end, new_stride))

return data[tuple(slices)]


Expand Down
26 changes: 17 additions & 9 deletions tests/python/relay/test_op_level4.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ def verify(
end,
strides,
output,
axis=None,
axes=None,
slice_mode="end",
test_ref=True,
dtype="int32",
Expand All @@ -397,12 +397,16 @@ def verify(

# target numpy result
x_data = np.random.uniform(size=dshape).astype("float32")
ref_res = tvm.topi.testing.strided_slice_python(x_data, begin, end, strides, slice_mode)
ref_res = tvm.topi.testing.strided_slice_python(
x_data, begin, end, strides, axes, slice_mode
)

if strides:
z = relay.strided_slice(x, begin=begin, end=end, strides=strides, slice_mode=slice_mode)
z = relay.strided_slice(
x, begin=begin, end=end, strides=strides, axes=axes, slice_mode=slice_mode
)
else:
z = relay.strided_slice(x, begin=begin, end=end, slice_mode=slice_mode)
z = relay.strided_slice(x, begin=begin, end=end, axes=axes, slice_mode=slice_mode)
func = relay.Function([x], z)

func = run_infer_type(func)
Expand Down Expand Up @@ -446,7 +450,7 @@ def verify(
(3, 4, 3), [1, 0, 0], [3, -1, 3], [1, 1, 1], (2, 4, 3), slice_mode="size", test_ref=False
)
verify((3, 4, 3), [1, 0, 0], [-1, 2, 3], [1, 1, 1], (2, 2, 3), slice_mode="size", test_ref=True)
verify((3, 4, 3), [1], [4], None, None, axis=[1])
verify((3, 4, 3), [1], [4], None, None, axes=[1])


@tvm.testing.uses_gpu
Expand All @@ -469,16 +473,20 @@ def verify(

# target numpy result
x_data = np.random.uniform(size=dshape).astype("float32")
ref_res = tvm.topi.testing.strided_slice_python(x_data, begin, end, strides, slice_mode)
ref_res = tvm.topi.testing.strided_slice_python(
x_data, begin, end, strides, axes, slice_mode
)

if ishape is None:
ishape = (relay.Any(),) * ndim

x = relay.var("x", relay.TensorType(ishape, "float32"))
if strides:
z = relay.strided_slice(x, begin=begin, end=end, strides=strides, slice_mode=slice_mode)
z = relay.strided_slice(
x, begin=begin, end=end, strides=strides, axes=axes, slice_mode=slice_mode
)
else:
z = relay.strided_slice(x, begin=begin, end=end, slice_mode=slice_mode)
z = relay.strided_slice(x, begin=begin, end=end, axes=axes, slice_mode=slice_mode)
func = relay.Function([x], z)

func = run_infer_type(func)
Expand Down Expand Up @@ -518,7 +526,7 @@ def verify(
(3, 4, 3, 2),
[1, 0],
[3, 1],
None,
[1, 1],
None,
axes=[1, 3],
ishape=(relay.Any(), 4, relay.Any(), 2),
Expand Down

0 comments on commit 88c79f3

Please sign in to comment.