Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion python/tvm/relax/op/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@

def take(x: Expr, indices: Expr, axis: Optional[int] = None) -> Expr:
"""Take elements from a tensor along an axis.
Its semantics are mostly similar to `numpy.take`
(https://numpy.org/doc/stable/reference/generated/numpy.take.html),
which can cover `torch.take` (https://pytorch.org/docs/stable/generated/torch.take.html) and
`onnx.gather` (https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gather-13).

Parameters
----------
Expand All @@ -35,7 +39,6 @@ def take(x: Expr, indices: Expr, axis: Optional[int] = None) -> Expr:

indices : relax.Expr
The indices of the values to extract.
It is required to be a one-dimensional tensor which has integer dtype.

axis : Optional[int]
The axis over which to select values.
Expand Down
27 changes: 16 additions & 11 deletions src/relax/op/tensor/index.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,11 @@ StructInfo InferStructInfoTake(const Call& call, const BlockBuilder& ctx) {
Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
TensorStructInfo data_sinfo = input_sinfo[0];
TensorStructInfo indices_sinfo = input_sinfo[1];
if (indices_sinfo->ndim != 1) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "Take op requires the input indices to be 1-dimensional tensor. However, "
"the given indices ndim is "
<< indices_sinfo->ndim);
} else if (!indices_sinfo->IsUnknownDtype() &&
!(indices_sinfo->dtype.is_int() || indices_sinfo->dtype.is_uint())) {

if (indices_sinfo->IsUnknownDtype()) {
// TODO(tvm-team): Do we have an equivalent of `ctx->ReportFatal` for warning?
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Leaving as LOG(WARNING) looks fine to me.

LOG(WARNING) << "Data type of indice has not been specified. Assume it has an integer type.";
} else if (!(indices_sinfo->dtype.is_int() || indices_sinfo->dtype.is_uint())) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "Take op requires the input indices to have integer dtype. However, the "
"given indices dtype is "
Expand All @@ -67,7 +65,7 @@ StructInfo InferStructInfoTake(const Call& call, const BlockBuilder& ctx) {
"is not specified. However, the given data tensor has ndim "
<< data_sinfo->ndim);
}
if (data_sinfo->IsUnknownNdim()) {
if (data_sinfo->IsUnknownNdim() || indices_sinfo->IsUnknownNdim()) {
return TensorStructInfo(data_sinfo->dtype, kUnknownNDim);
}

Expand All @@ -77,11 +75,18 @@ StructInfo InferStructInfoTake(const Call& call, const BlockBuilder& ctx) {
const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
const auto* indices_shape = indices_sinfo->shape.as<ShapeExprNode>();
if (data_shape == nullptr || indices_shape == nullptr) {
return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim);
return TensorStructInfo(data_sinfo->dtype, indices_sinfo->ndim + data_sinfo->ndim - 1);
}

Array<PrimExpr> output_shape = data_shape->values;
output_shape.Set(axis, indices_shape->values[0]);
Array<PrimExpr> output_shape;
for (int i = 0; i < data_sinfo->ndim; i++) {
if (i == axis) {
for (int j = 0; j < indices_sinfo->ndim; j++)
output_shape.push_back(indices_shape->values[j]);
} else {
output_shape.push_back(data_shape->values[i]);
}
}
return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype);
}

Expand Down
133 changes: 92 additions & 41 deletions tests/python/relax/test_op_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ def test_take_infer_struct_info():
idx1 = relax.Var("idx", R.Tensor("int64", ndim=1))
idx2 = relax.Var("idx", R.Tensor((6,)))
idx3 = relax.Var("idx", R.Tensor(ndim=1))
idx4 = relax.Var("idx", R.Tensor((6, 4), "int64"))
idx5 = relax.Var("idx", R.Tensor("int64", ndim=2))
idx6 = relax.Var("idx", R.Tensor((6, 4)))
idx7 = relax.Var("idx", R.Tensor(ndim=2))

_check_inference(bb, relax.op.take(x0, idx0, axis=1), relax.TensorStructInfo((4, 6), "float32"))
_check_inference(
Expand Down Expand Up @@ -93,6 +97,62 @@ def test_take_infer_struct_info():
_check_inference(bb, relax.op.take(x3, idx3, axis=1), relax.TensorStructInfo(dtype="", ndim=2))
_check_inference(bb, relax.op.take(x4, idx3, axis=1), relax.TensorStructInfo(dtype="", ndim=2))
_check_inference(bb, relax.op.take(x5, idx3, axis=1), relax.TensorStructInfo(dtype=""))
_check_inference(
bb, relax.op.take(x0, idx4, axis=0), relax.TensorStructInfo((6, 4, 10), dtype="float32")
)
_check_inference(
bb, relax.op.take(x0, idx4, axis=1), relax.TensorStructInfo((4, 6, 4), dtype="float32")
)
_check_inference(
bb, relax.op.take(x1, idx4, axis=1), relax.TensorStructInfo(dtype="float32", ndim=3)
)
_check_inference(bb, relax.op.take(x2, idx4, axis=1), relax.TensorStructInfo(dtype="float32"))
_check_inference(
bb, relax.op.take(x3, idx4, axis=1), relax.TensorStructInfo((4, 6, 4), dtype="")
)
_check_inference(bb, relax.op.take(x4, idx4, axis=1), relax.TensorStructInfo(dtype="", ndim=3))
_check_inference(bb, relax.op.take(x5, idx4, axis=1), relax.TensorStructInfo(dtype=""))
_check_inference(
bb, relax.op.take(x0, idx5, axis=0), relax.TensorStructInfo(dtype="float32", ndim=3)
)
_check_inference(
bb, relax.op.take(x0, idx5, axis=1), relax.TensorStructInfo(dtype="float32", ndim=3)
)
_check_inference(
bb, relax.op.take(x1, idx5, axis=1), relax.TensorStructInfo(dtype="float32", ndim=3)
)
_check_inference(bb, relax.op.take(x2, idx5, axis=1), relax.TensorStructInfo(dtype="float32"))
_check_inference(bb, relax.op.take(x3, idx5, axis=1), relax.TensorStructInfo(dtype="", ndim=3))
_check_inference(bb, relax.op.take(x4, idx5, axis=1), relax.TensorStructInfo(dtype="", ndim=3))
_check_inference(bb, relax.op.take(x5, idx5, axis=1), relax.TensorStructInfo(dtype=""))
_check_inference(
bb, relax.op.take(x0, idx6, axis=0), relax.TensorStructInfo((6, 4, 10), dtype="float32")
)
_check_inference(
bb, relax.op.take(x0, idx6, axis=1), relax.TensorStructInfo((4, 6, 4), dtype="float32")
)
_check_inference(
bb, relax.op.take(x1, idx6, axis=1), relax.TensorStructInfo(dtype="float32", ndim=3)
)
_check_inference(bb, relax.op.take(x2, idx6, axis=1), relax.TensorStructInfo(dtype="float32"))
_check_inference(
bb, relax.op.take(x3, idx6, axis=1), relax.TensorStructInfo((4, 6, 4), dtype="")
)
_check_inference(bb, relax.op.take(x4, idx6, axis=1), relax.TensorStructInfo(dtype="", ndim=3))
_check_inference(bb, relax.op.take(x5, idx6, axis=1), relax.TensorStructInfo(dtype=""))
_check_inference(
bb, relax.op.take(x0, idx7, axis=0), relax.TensorStructInfo(dtype="float32", ndim=3)
)
_check_inference(
bb, relax.op.take(x0, idx7, axis=1), relax.TensorStructInfo(dtype="float32", ndim=3)
)
_check_inference(
bb, relax.op.take(x1, idx7, axis=1), relax.TensorStructInfo(dtype="float32", ndim=3)
)
_check_inference(bb, relax.op.take(x2, idx7, axis=1), relax.TensorStructInfo(dtype="float32"))
_check_inference(bb, relax.op.take(x3, idx7, axis=1), relax.TensorStructInfo(dtype="", ndim=3))
_check_inference(bb, relax.op.take(x4, idx7, axis=1), relax.TensorStructInfo(dtype="", ndim=3))
_check_inference(bb, relax.op.take(x5, idx7, axis=1), relax.TensorStructInfo(dtype=""))
_check_inference(bb, relax.op.take(y0, idx0), relax.TensorStructInfo((6,), "float32"))
_check_inference(bb, relax.op.take(y1, idx0), relax.TensorStructInfo(dtype="float32", ndim=1))
_check_inference(bb, relax.op.take(y2, idx0), relax.TensorStructInfo((6,), dtype=""))
Expand All @@ -109,13 +169,31 @@ def test_take_infer_struct_info():
_check_inference(bb, relax.op.take(y1, idx3), relax.TensorStructInfo(dtype="float32", ndim=1))
_check_inference(bb, relax.op.take(y2, idx3), relax.TensorStructInfo(dtype="", ndim=1))
_check_inference(bb, relax.op.take(y3, idx3), relax.TensorStructInfo(dtype="", ndim=1))
_check_inference(bb, relax.op.take(y0, idx4), relax.TensorStructInfo((6, 4), "float32"))
_check_inference(bb, relax.op.take(y1, idx4), relax.TensorStructInfo(dtype="float32", ndim=2))
_check_inference(bb, relax.op.take(y2, idx4), relax.TensorStructInfo((6, 4), dtype=""))
_check_inference(bb, relax.op.take(y3, idx4), relax.TensorStructInfo(dtype="", ndim=2))
_check_inference(bb, relax.op.take(y0, idx5), relax.TensorStructInfo(dtype="float32", ndim=2))
_check_inference(bb, relax.op.take(y1, idx5), relax.TensorStructInfo(dtype="float32", ndim=2))
_check_inference(bb, relax.op.take(y2, idx5), relax.TensorStructInfo(dtype="", ndim=2))
_check_inference(bb, relax.op.take(y3, idx5), relax.TensorStructInfo(dtype="", ndim=2))
_check_inference(bb, relax.op.take(y0, idx6), relax.TensorStructInfo((6, 4), "float32"))
_check_inference(bb, relax.op.take(y1, idx6), relax.TensorStructInfo(dtype="float32", ndim=2))
_check_inference(bb, relax.op.take(y2, idx6), relax.TensorStructInfo((6, 4), dtype=""))
_check_inference(bb, relax.op.take(y3, idx6), relax.TensorStructInfo(dtype="", ndim=2))
_check_inference(bb, relax.op.take(y0, idx7), relax.TensorStructInfo(dtype="float32", ndim=2))
_check_inference(bb, relax.op.take(y1, idx7), relax.TensorStructInfo(dtype="float32", ndim=2))
_check_inference(bb, relax.op.take(y2, idx7), relax.TensorStructInfo(dtype="", ndim=2))
_check_inference(bb, relax.op.take(y3, idx7), relax.TensorStructInfo(dtype="", ndim=2))


def test_take_infer_struct_info_shape_symbolic():
bb = relax.BlockBuilder()
m = tir.Var("m", "int64")
n = tir.Var("n", "int64")
i = tir.Var("i", "int64")
j = tir.Var("j", "int64")
k = tir.Var("k", "int64")
x0 = relax.Var("x", R.Tensor((m, n), "float32"))
x1 = relax.Var("x", R.Tensor((m, n)))
y0 = relax.Var("y", R.Tensor((n,), "float32"))
Expand All @@ -127,15 +205,29 @@ def test_take_infer_struct_info_shape_symbolic():
(i,),
),
)
idx2 = relax.Var(
"idx",
R.Tensor(
(i, j, k),
),
)

_check_inference(bb, relax.op.take(x0, idx0, axis=1), relax.TensorStructInfo((m, i), "float32"))
_check_inference(bb, relax.op.take(x1, idx0, axis=1), relax.TensorStructInfo((m, i), dtype=""))
_check_inference(bb, relax.op.take(x0, idx1, axis=1), relax.TensorStructInfo((m, i), "float32"))
_check_inference(bb, relax.op.take(x1, idx1, axis=1), relax.TensorStructInfo((m, i), dtype=""))
_check_inference(
bb, relax.op.take(x1, idx2, axis=1), relax.TensorStructInfo((m, i, j, k), dtype="")
)
_check_inference(
bb, relax.op.take(x1, idx2, axis=1), relax.TensorStructInfo((m, i, j, k), dtype="")
)
_check_inference(bb, relax.op.take(y0, idx0), relax.TensorStructInfo((i,), "float32"))
_check_inference(bb, relax.op.take(y1, idx0), relax.TensorStructInfo((i,), dtype=""))
_check_inference(bb, relax.op.take(y0, idx1), relax.TensorStructInfo((i,), "float32"))
_check_inference(bb, relax.op.take(y1, idx1), relax.TensorStructInfo((i,), dtype=""))
_check_inference(bb, relax.op.take(y0, idx2), relax.TensorStructInfo((i, j, k), "float32"))
_check_inference(bb, relax.op.take(y1, idx2), relax.TensorStructInfo((i, j, k), dtype=""))


def test_take_infer_struct_info_shape_var():
Expand Down Expand Up @@ -202,47 +294,6 @@ def test_take_infer_struct_info_more_input_dtype():
_check_inference(bb, relax.op.take(x2, idx2, axis=1), relax.TensorStructInfo((4, 6), "int32"))


def test_take_infer_struct_info_indices_not_one_dimensional():
    """take() must reject an indices tensor whose ndim is not exactly 1.

    Covers indices declared with a concrete non-1D shape, a known non-1
    ndim, an unknown ndim, and shapes supplied indirectly through a
    symbolic shape variable.
    """
    bb = relax.BlockBuilder()
    x = relax.Var("x", R.Tensor((4, 10), "float32"))

    # Indices whose struct info names the (non-1D) shape/ndim directly.
    direct_sinfos = [
        R.Tensor((6, 6), "int64"),
        R.Tensor((), "int64"),
        R.Tensor("int64", ndim=2),
        R.Tensor("int64", ndim=0),
        R.Tensor("int64"),
    ]
    # Indices whose shape is carried by a separate shape-struct-info var.
    shape_sinfos = [
        relax.ShapeStructInfo((6, 6)),
        relax.ShapeStructInfo(()),
        relax.ShapeStructInfo(ndim=2),
        relax.ShapeStructInfo(ndim=0),
        relax.ShapeStructInfo(),
    ]

    bad_indices = [relax.Var("idx", sinfo) for sinfo in direct_sinfos]
    bad_indices += [
        relax.Var("idx", relax.TensorStructInfo(relax.Var("sidx", s), "int64"))
        for s in shape_sinfos
    ]

    # Every non-1D indices variant must fail struct-info inference.
    for idx in bad_indices:
        with pytest.raises(TVMError):
            bb.normalize(relax.op.take(x, idx, axis=1))


def test_take_infer_struct_info_indices_not_integer_dtype():
bb = relax.BlockBuilder()
x = relax.Var("x", R.Tensor((4, 10), "float32"))
Expand Down
3 changes: 3 additions & 0 deletions tests/python/topi/python/test_topi_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1026,6 +1026,9 @@ def test_take():
verify_take((3, 3, 3), [[11, 25]], mode="fast")
verify_take((3, 4), [0, 2], axis=0, mode="fast")
verify_take((3, 4), [0, 2], axis=1, mode="fast")
verify_take((3, 5, 7), [[0, 2], [0, 2], [0, 2], [0, 2]], axis=0, mode="fast")
verify_take((3, 5, 7), [[0, 2], [0, 2], [0, 2], [0, 2]], axis=1, mode="fast")
verify_take((3, 5, 7), [[0, 2], [0, 2], [0, 2], [0, 2]], axis=2, mode="fast")
verify_take((3, 4), [1, 2], axis=1, indices_dtype="uint32")
verify_take((3, 4), [1, 2], axis=1, mode="wrap", indices_dtype="uint16")
verify_take((3, 3, 3), [[11, 20]], mode="fast", indices_dtype="uint8")
Expand Down