Skip to content

Commit

Permalink
[Relay] Fix shape func for strided slice (#10418)
Browse files Browse the repository at this point in the history
* fix dyn strided slice

* add tests

* remove stuff

* jostle ci

* jostle ci

* jostle
  • Loading branch information
AndrewZhaoLuo committed Mar 2, 2022
1 parent fdbb88f commit a5cb76a
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 7 deletions.
2 changes: 1 addition & 1 deletion python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ def _strided_slice_shape_func_with_axes(data_shape, begin, end, strides, slice_m
else:
cend = cbegin + int64(end[i])
else:
if end[i] > data_shape[i]:
if end[i] > data_shape[axes[i]]:
cend = dim_size
else:
cend = int64(end[i])
Expand Down
30 changes: 24 additions & 6 deletions tests/python/relay/test_op_level4.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,9 @@
import numpy as np
import numpy.random
import pytest

import tvm
import tvm.testing
import tvm.topi.testing

from tvm import relay, te
from tvm.relay import transform
from tvm.relay.testing import run_infer_type
Expand Down Expand Up @@ -448,14 +446,20 @@ def verify(
slice_mode="end",
test_ref=True,
dtype="int32",
unknown_dim_value=10,
):
x = relay.var("x", relay.TensorType(dshape, "float32"))
ndim = len(dshape)
begin = begin if begin else [0] * ndim
end = end if end else list(dshape)

# target numpy result
# Resolve unknown dimensions to create test case:
dshape = list(dshape)
for i, d in enumerate(dshape):
if not isinstance(d, int):
dshape[i] = unknown_dim_value
x_data = np.random.uniform(size=dshape).astype("float32")

ref_res = tvm.topi.testing.strided_slice_python(
x_data,
begin,
Expand Down Expand Up @@ -484,9 +488,8 @@ def verify(
if not test_ref:
return
for target, dev in tvm.testing.enabled_targets():
op_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)(
x_data
)
# Need VM to run tests with non-static dimensions
op_res = relay.create_executor("vm", device=dev, target=target).evaluate(func)(x_data)
tvm.testing.assert_allclose(op_res.numpy(), ref_res)

verify((1, 3, 10, 10), [0, 0, 0, 0], [-1, 3, 10, 10], [1], (0, 3, 10, 10), dtype="int64")
Expand All @@ -498,6 +501,7 @@ def verify(
(1, 120, 120, 3),
dtype="int64",
)

verify((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], (1, 3, 3), dtype="int16")
verify((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2], (3, 1, 2))
verify((3, 4, 3), [1, 1, 0], [4, 4, 3], None, (2, 3, 3))
Expand All @@ -506,6 +510,7 @@ def verify(
verify((3, 4, 3), [1, 1], [4, 4, 3], None, (2, 3, 3))
verify((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], (1, 4, 3))
verify((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1], (1, 2, 3))

# Test backwards slicing.
verify((3, 4, 3), [-1, -1, -1], [-5, -5, -5], [-1, -1, -1], (3, 4, 3))
# Test slicing with overlarge indices.
Expand All @@ -514,9 +519,22 @@ def verify(
verify(
(3, 4, 3), [1, 0, 0], [3, -1, 3], [1, 1, 1], (2, 4, 3), slice_mode="size", test_ref=False
)

verify((3, 4, 3), [1, 0, 0], [-1, 2, 3], [1, 1, 1], (2, 2, 3), slice_mode="size", test_ref=True)
verify((3, 4, 3), [1], [4], None, None, axes=[1])

# Test Any dims for simple cases
verify((3, relay.Any()), [0], [1], [1], None, axes=[1], unknown_dim_value=10)
verify((relay.Any(), 3), [0], [1], [1], None, axes=[1], unknown_dim_value=10)
verify(
(relay.Any(), relay.Any(), relay.Any()),
[0, 1, 2],
[5, 5, 5],
[1, 2, 1],
None,
unknown_dim_value=10,
)


@tvm.testing.uses_gpu
def test_dyn_strided_slice():
Expand Down

0 comments on commit a5cb76a

Please sign in to comment.