[topi][relay] Add operation gather to relay. #5716

Merged
merged 3 commits on Jun 12, 2020
Changes from all commits
2 changes: 2 additions & 0 deletions docs/api/python/topi.rst
@@ -55,6 +55,7 @@ List of operators
topi.concatenate
topi.split
topi.take
topi.gather
topi.gather_nd
topi.full
topi.full_like
@@ -160,6 +161,7 @@ topi
.. autofunction:: topi.concatenate
.. autofunction:: topi.split
.. autofunction:: topi.take
.. autofunction:: topi.gather
.. autofunction:: topi.gather_nd
.. autofunction:: topi.full
.. autofunction:: topi.full_like
1 change: 1 addition & 0 deletions docs/langref/relay_op.rst
@@ -118,6 +118,7 @@ This level enables additional math and transform operators.
tvm.relay.zeros_like
tvm.relay.ones
tvm.relay.ones_like
tvm.relay.gather
tvm.relay.gather_nd
tvm.relay.full
tvm.relay.full_like
10 changes: 10 additions & 0 deletions include/tvm/relay/attrs/transform.h
@@ -101,6 +101,16 @@ struct ScatterAttrs : public tvm::AttrsNode<ScatterAttrs> {
}
};

struct GatherAttrs : public tvm::AttrsNode<GatherAttrs> {
Integer axis;

TVM_DECLARE_ATTRS(GatherAttrs, "relay.attrs.GatherAttrs") {
TVM_ATTR_FIELD(axis)
.set_default(NullValue<Integer>())
.describe("The axis over which to select values.");
}
};

struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
Integer axis;
std::string mode;
1 change: 1 addition & 0 deletions python/tvm/relay/op/_transform.py
@@ -51,6 +51,7 @@
_reg.register_injective_schedule("transpose")
_reg.register_injective_schedule("stack")
_reg.register_injective_schedule("_contrib_reverse_reshape")
_reg.register_injective_schedule("gather")
_reg.register_injective_schedule("gather_nd")
_reg.register_injective_schedule("sequence_mask")
_reg.register_injective_schedule("one_hot")
4 changes: 4 additions & 0 deletions python/tvm/relay/op/op_attrs.py
@@ -189,6 +189,10 @@ class TransposeAttrs(Attrs):
class ReshapeAttrs(Attrs):
"""Attributes for transform.reshape"""

@tvm._ffi.register_object("relay.attrs.GatherAttrs")
class GatherAttrs(Attrs):
"""Attributes for transform.gather"""

@tvm._ffi.register_object("relay.attrs.TakeAttrs")
class TakeAttrs(Attrs):
"""Attributes for transform.take"""
37 changes: 37 additions & 0 deletions python/tvm/relay/op/transform.py
@@ -800,6 +800,43 @@ def reverse_reshape(data, newshape):
return _make._contrib_reverse_reshape(data, list(newshape))


def gather(data, axis, indices):
"""Gather values along given axis from given indices.

E.g., for a 3D tensor, the output is computed as:

.. code-block:: python

out[i][j][k] = data[indices[i][j][k]][j][k] # if axis == 0
out[i][j][k] = data[i][indices[i][j][k]][k] # if axis == 1
out[i][j][k] = data[i][j][indices[i][j][k]] # if axis == 2

``indices`` must have the same shape as ``data``, except at dimension ``axis``,
where any non-empty size is allowed. The output will have the same shape as ``indices``.

Parameters
----------
data: relay.Expr
The input data to the operator.

axis: int
The axis along which to index.

indices: relay.Expr
The indices of values to gather.
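
Returns
-------
ret : relay.Expr
The computed result.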

Examples
--------
.. code-block:: python

data = [[1, 2], [3, 4]]
axis = 1
indices = [[0, 0], [1, 0]]
relay.gather(data, axis, indices) = [[1, 1], [4, 3]]
"""
return _make.gather(data, axis, indices)


def gather_nd(data, indices):
"""Gather elements or slices from data and store to a tensor whose shape is
defined by indices.
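For intuition, the per-element rule in the ``relay.gather`` docstring above is essentially that of NumPy's ``take_along_axis``. A minimal NumPy-only sketch (illustrative, not part of the patch) reproducing the docstring example:

import numpy as np

data = np.array([[1, 2], [3, 4]])
indices = np.array([[0, 0], [1, 0]])
# axis == 1: out[i][j] = data[i][indices[i][j]]
print(np.take_along_axis(data, indices, axis=1))
# [[1 1]
#  [4 3]]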
82 changes: 82 additions & 0 deletions src/relay/op/tensor/transform.cc
@@ -2385,6 +2385,88 @@ example below::
.set_attr<FTVMCompute>("FTVMCompute", ReshapeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

// gather operator
TVM_REGISTER_NODE_TYPE(GatherAttrs);

bool GatherRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// `types` contains: [data, indices, result]
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
const auto* indices = types[1].as<TensorTypeNode>();
if (data == nullptr) {
CHECK(types[0].as<IncompleteTypeNode>())
<< "Gather: expect input data type to be TensorType but get " << types[0];
return false;
}
if (indices == nullptr) {
CHECK(types[1].as<IncompleteTypeNode>())
<< "Gather: expect indices type to be TensorType but get " << types[1];
return false;
}
CHECK(indices->dtype.is_int()) << "indices of gather must be a tensor of integers";
const auto param = attrs.as<GatherAttrs>();
CHECK(param != nullptr);
CHECK(param->axis.defined());

const auto ndim_data = data->shape.size();
const auto ndim_indices = indices->shape.size();
int axis = param->axis->value;
CHECK_EQ(ndim_data, ndim_indices);
CHECK_GE(axis, 0);
CHECK_LT(axis, ndim_data);

std::vector<IndexExpr> oshape;
oshape.reserve(ndim_data);
for (size_t i = 0; i < ndim_data; ++i) {
if (i == (size_t)axis) {
const int64_t* indice_shape_i = tir::as_const_int(indices->shape[i]);
CHECK(indice_shape_i != nullptr) << "indices size along the gather axis must be a constant";
CHECK_GE(*indice_shape_i, 1);
} else {
CHECK(reporter->AssertEQ(indices->shape[i], data->shape[i]));
}
oshape.emplace_back(indices->shape[i]);
}
reporter->Assign(types[2], TensorType(oshape, data->dtype));
return true;
}

Array<te::Tensor> GatherCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
const auto* param = attrs.as<GatherAttrs>();
return {topi::gather(inputs[0], param->axis, inputs[1])};
}

Expr MakeGather(Expr data, Integer axis, Expr indices) {
auto attrs = make_object<GatherAttrs>();
attrs->axis = std::move(axis);
static const Op& op = Op::Get("gather");
return Call(op, {data, indices}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op._make.gather").set_body_typed(MakeGather);

RELAY_REGISTER_OP("gather")
.describe(R"code(Gather values along given axis from given indices.

E.g., for a 3D tensor, the output is computed as:

out[i][j][k] = data[indices[i][j][k]][j][k] # if axis == 0
out[i][j][k] = data[i][indices[i][j][k]][k] # if axis == 1
out[i][j][k] = data[i][j][indices[i][j][k]] # if axis == 2

``indices`` must have the same shape as ``data``, except at dimension ``axis``,
where any non-empty size is allowed. The output will have the same shape as ``indices``.
)code" TVM_ADD_FILELINE)
.set_attrs_type<GatherAttrs>()
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input data to the operator.")
.add_argument("indices", "Tensor", "The indices of values to gather.")
.set_support_level(3)
.add_type_rel("Gather", GatherRel)
.set_attr<FTVMCompute>("FTVMCompute", GatherCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

// gather_nd operator
bool GatherNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
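A short sketch of what the new ``Gather`` type relation implies for shapes, using the Python API added by this patch (illustrative only; assumes a TVM build that already contains this change):

import tvm
from tvm import relay

# The result type is taken from `indices`: same shape, `data`'s dtype.
# Every non-axis dimension of `indices` must match `data`; the axis
# dimension may be any constant size >= 1.
data = relay.var("data", shape=(4, 5, 6), dtype="float32")
indices = relay.var("indices", shape=(4, 2, 6), dtype="int32")
out = relay.gather(data, 1, indices)

mod = tvm.IRModule.from_expr(relay.Function([data, indices], out))
mod = relay.transform.InferType()(mod)
print(mod["main"].ret_type)  # Tensor[(4, 2, 6), float32]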
52 changes: 52 additions & 0 deletions tests/python/relay/test_op_level3.py
@@ -711,6 +711,58 @@ def verify_scatter(dshape, ishape, axis=0):
verify_scatter((16, 16, 4, 5), (16, 16, 4, 5), 3)


def test_gather():
def verify_gather(data, axis, indices, ref_res):
data = np.asarray(data, dtype='float32')
indices = np.asarray(indices, dtype='int32')
ref_res = np.asarray(ref_res)

d = relay.var("x", relay.TensorType(data.shape, "float32"))
i = relay.var("y", relay.TensorType(indices.shape, "int32"))
z = relay.gather(d, axis, i)

func = relay.Function([d, i], z)

for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(data, indices)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res,
rtol=1e-5)

verify_gather([[1, 2], [3, 4]],
1,
[[0, 0], [1, 0]],
[[1, 1], [4, 3]])
verify_gather([[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]]],
0,
[[[1, 0, 1], [1, 1, 0]]],
[[[6, 1, 8], [9, 10, 5]]])
verify_gather([[[-0.2321, -0.2024, -1.7624], [-0.3829, -0.4246, 0.2448],
[0.1822, 0.2360, -0.8965], [0.4497, -0.2224, 0.6103]],
[[0.0408, -0.7667, -0.4303], [-0.3216, 0.7489, -0.1502],
[0.0144, -0.4699, -0.0064], [-0.0768, -1.6064, 1.3390]]],
1,
[[[2, 2, 0], [1, 0, 3]], [[3, 2, 0], [1, 0, 0]]],
[[[0.1822, 0.2360, -1.7624], [-0.3829, -0.2024, 0.6103]],
[[-0.0768, -0.4699, -0.4303], [-0.3216, -0.7667, -0.4303]]])
verify_gather([[[0.3050, 1.6986, 1.1034], [0.7020, -0.6960, -2.1818],
[0.3116, -0.5773, -0.9912], [0.0835, -1.3915, -1.0720]],
[[0.1694, -0.6091, -0.6539], [-0.5234, -0.1218, 0.5084],
[0.2374, -1.9537, -2.0078], [-0.5700, -1.0302, 0.1558]]],
2,
[[[1, 1, 0, 1], [0, 0, 2, 2], [1, 2, 1, 2], [2, 2, 1, 0]],
[[0, 0, 1, 2], [2, 2, 1, 0], [1, 2, 0, 0], [0, 2, 0, 2]]],
[[[1.6986, 1.6986, 0.3050, 1.6986],
[0.7020, 0.7020, -2.1818, -2.1818],
[-0.5773, -0.9912, -0.5773, -0.9912],
[-1.0720, -1.0720, -1.3915, 0.0835]],
[[0.1694, 0.1694, -0.6091, -0.6539],
[0.5084, 0.5084, -0.1218, -0.5234],
[-1.9537, -2.0078, 0.2374, 0.2374],
[-0.5700, 0.1558, -0.5700, 0.1558]]])


def test_gather_nd():
def verify_gather_nd(xshape, yshape, y_data):
x = relay.var("x", relay.TensorType(xshape, "float32"))
48 changes: 48 additions & 0 deletions topi/include/topi/transform.h
@@ -988,6 +988,54 @@ inline Tensor tile(const Tensor& x, Array<Integer> reps, std::string name = "T_t
}
}

/*!
* \brief Gather values along given axis from given indices.
*
* \param data The input data to the operator.
* \param axis The axis along which to index.
* \param indices The indices of values to gather.
* \param name The name of the operation.
* \param tag The tag to mark the operation.
*
* \return A Tensor whose op member is the gather operation
*/
inline Tensor gather(const Tensor& data, int axis, const Tensor& indices,
std::string name = "T_gather", std::string tag = kInjective) {
size_t ndim_d = data->shape.size();
size_t ndim_i = indices->shape.size();
CHECK_GE(ndim_d, 1) << "Cannot gather from a scalar.";
CHECK_EQ(ndim_d, ndim_i);
CHECK_GE(axis, 0);
CHECK_LT(axis, ndim_d);
size_t indices_dim_i = static_cast<size_t>(GetConstInt(indices->shape[axis]));
CHECK_GE(indices_dim_i, 1);
CHECK(indices->dtype.is_int());

Array<PrimExpr> out_shape;
for (size_t i = 0; i < ndim_i; ++i) {
out_shape.push_back(indices->shape[i]);
}

return compute(
out_shape,
[&](const Array<Var>& out_index) {
Array<PrimExpr> indices_position;
for (size_t i = 0; i < ndim_i; ++i) {
indices_position.push_back(out_index[i]);
}
Array<PrimExpr> real_indices;
for (size_t i = 0; i < ndim_i; ++i) {
if (i == (size_t)axis) {
real_indices.push_back(indices(indices_position));
} else {
real_indices.push_back(indices_position[i]);
}
}
return data(real_indices);
},
name, tag);
}

/*!
* \brief Gather elements from a n-dimension array.
*
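A minimal sketch of exercising the TOPI kernel under the generic schedule (illustrative only; it assumes the Python-side ``topi.gather`` wrapper that accompanies this C++ implementation and an LLVM-enabled build):

import numpy as np
import tvm
from tvm import te
import topi

# Build the gather compute for a 2x2 case and run it on CPU.
data = te.placeholder((2, 2), name="data", dtype="float32")
indices = te.placeholder((2, 2), name="indices", dtype="int32")
out = topi.gather(data, 1, indices)

s = te.create_schedule(out.op)
f = tvm.build(s, [data, indices, out], target="llvm")

ctx = tvm.cpu(0)
a = tvm.nd.array(np.array([[1, 2], [3, 4]], dtype="float32"), ctx)
b = tvm.nd.array(np.array([[0, 0], [1, 0]], dtype="int32"), ctx)
c = tvm.nd.empty((2, 2), "float32", ctx)
f(a, b, c)
print(c.asnumpy())  # [[1. 1.]
                    #  [4. 3.]]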
1 change: 1 addition & 0 deletions topi/python/topi/testing/__init__.py
@@ -42,6 +42,7 @@
from .roi_pool_python import roi_pool_nchw_python
from .lrn_python import lrn_python
from .l2_normalize_python import l2_normalize_python
from .gather_python import gather_python
from .gather_nd_python import gather_nd_python
from .strided_slice_python import strided_slice_python, strided_set_python
from .batch_matmul import batch_matmul
46 changes: 46 additions & 0 deletions topi/python/topi/testing/gather_python.py
@@ -0,0 +1,46 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
"""gather in python"""
import numpy as np

def gather_python(data, axis, indices):
""" Python version of Gather operator

Parameters
----------
data : numpy.ndarray
Numpy array

axis: int
integer

indices : numpy.ndarray
Numpy array

Returns
-------
b_np : numpy.ndarray
Numpy array
"""
shape_indices = indices.shape
out = np.zeros(shape_indices, dtype=data.dtype)
for index in np.ndindex(*shape_indices):
new_index = list(index)
new_index[axis] = indices[index]
out[index] = data[tuple(new_index)]
return out
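
A small usage sketch for the reference helper above, reusing the axis-0 case from the relay unit test (expected values come from that test):

import numpy as np

# out[i][j][k] = data[indices[i][j][k]][j][k] for axis == 0
data = np.array([[[0, 1, 2], [3, 4, 5]],
                 [[6, 7, 8], [9, 10, 11]]], dtype="float32")
indices = np.array([[[1, 0, 1], [1, 1, 0]]], dtype="int32")
print(gather_python(data, 0, indices))
# [[[ 6.  1.  8.]
#   [ 9. 10.  5.]]]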