
[TOPI]Add where operator #1416

Merged: 4 commits, Jul 13, 2018
4 changes: 4 additions & 0 deletions nnvm/python/nnvm/top/transform.py
@@ -72,3 +72,7 @@ def compute_reshape_like(attrs, inputs, out_info):
# slice_like
reg.register_pattern("slice_like", OpPattern.INJECTIVE)
reg.register_schedule("slice_like", _fschedule_injective)

# where
reg.register_pattern("where", OpPattern.INJECTIVE)
reg.register_schedule("where", _fschedule_injective)
97 changes: 95 additions & 2 deletions nnvm/src/top/tensor/transform.cc
@@ -1125,8 +1125,8 @@ Examples::
DMLC_REGISTER_PARAMETER(SliceLikeParam);

inline bool SliceLikeShape(const nnvm::NodeAttrs& attrs,
                           std::vector<TShape>* in_attrs,
                           std::vector<TShape>* out_attrs) {
CHECK_EQ(in_attrs->size(), 2U);
CHECK_EQ(out_attrs->size(), 1U);
const SliceLikeParam& param = nnvm::get<SliceLikeParam>(attrs.parsed);
@@ -1221,5 +1221,98 @@ NNVM_REGISTER_OP(slice_like)
})
.set_support_level(4);

// where
inline bool WhereShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape>* in_attrs,
std::vector<TShape>* out_attrs) {
CHECK_EQ(in_attrs->size(), 3U);
CHECK_EQ(out_attrs->size(), 1U);
const TShape& cond_shape = in_attrs->at(0);
const TShape& x_shape = in_attrs->at(1);
const TShape& y_shape = in_attrs->at(2);
CHECK_EQ(x_shape, y_shape) << "x and y must have the same shape: "
<< x_shape << " vs " << y_shape;
if (cond_shape != x_shape) {
Review comment (Contributor): How about proper broadcasting to allow an arbitrary shape for condition?

Author reply: The current implementation aligns with numpy in allowing the condition shape to be either the same as x or 1-D. We can add more broadcasting later if necessary.

CHECK_EQ(cond_shape.ndim(), 1)
<< "Shape of condition " << cond_shape
<< " must either be equal to the shape of x or be 1-D.";
}
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, x_shape);
return true;
}
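
A quick NumPy sketch of the semantics WhereShape admits, mirroring the reference used in this PR's tests: a condition with the same shape as x selects elementwise like np.where, while a 1-D condition selects whole rows::

    import numpy as np

    x = np.array([[1, 2], [3, 4]])
    y = np.array([[5, 6], [7, 8]])

    cond = np.array([[0, 1], [-1, 0]])   # same shape as x
    print(np.where(cond != 0, x, y))     # [[5 2] [3 8]]

    cond1d = np.array([1, 0])            # 1-D, length == x.shape[0]
    rows = [xr if c else yr for c, xr, yr in zip(cond1d, x, y)]
    print(np.array(rows))                # [[1 2] [7 8]]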

inline bool WhereInferType(const NodeAttrs &attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
DTYPE_ASSIGN(out_attrs->at(0), in_attrs->at(1));
Review comment (Contributor): Also check that the dtypes of x and y are the same.

Author reply: Added.

return true;
}

inline bool WhereCorrectLayout(const NodeAttrs& attrs,
std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts) {
CHECK_EQ(ilayouts->size(), last_ilayouts->size());
CHECK_EQ(olayouts->size(), 1U);

for (size_t i = 0; i < ilayouts->size(); ++i) {
const Layout& input = last_ilayouts->at(i).defined() ?
last_ilayouts->at(i) : ilayouts->at(i);
NNVM_ASSIGN_LAYOUT(*ilayouts, i, input);
}

return true;
}

NNVM_REGISTER_OP(where)
.describe(R"code(
Return the elements, either from x or y, depending on the condition.

Given three ndarrays, condition, x, and y, return an ndarray whose elements
are taken from x or y depending on whether the corresponding elements of
condition are true or false. x and y must have the same shape. If condition
has the same shape as x, each element in the output array is from x if the
corresponding element in condition is true, and from y if false.

If condition does not have the same shape as x, it must be a 1-D array whose
size matches the size of x's first dimension. Each row of the output array
is from x's row if the corresponding element of condition is true, and from
y's row if false.

Review comment (Contributor): Suggest adding some basic examples.

Author reply: Added.

Note that all non-zero values are interpreted as True in condition.

Examples::

x = [[1, 2], [3, 4]]
y = [[5, 6], [7, 8]]
cond = [[0, 1], [-1, 0]]
where(cond, x, y) = [[5, 2], [3, 8]]


cond = [1, 0]
where(cond, x, y) = [[1, 2], [7, 8]]

)code" NNVM_ADD_FILELINE)
.add_argument("condition", "Tensor", "Condition array")
.add_argument("x", "Tensor", "First array to be selected")
.add_argument("y", "Tensor", "Second array to be selected")
.set_num_inputs(3)
.set_num_outputs(1)
.set_attr<FInferShape>("FInferShape", WhereShape)
.set_attr<FInferType>("FInferType", WhereInferType)
.set_attr<FCorrectLayout>("FCorrectLayout", WhereCorrectLayout)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor>{
topi::where(inputs[0], inputs[1], inputs[2])
};
})
.set_attr<FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
return std::vector<std::string>{"condition", "x", "y"};
})
.set_support_level(4);

} // namespace top
} // namespace nnvm
31 changes: 31 additions & 0 deletions nnvm/tests/python/compiler/test_top_level4.py
@@ -645,6 +645,36 @@ def test_slice_like():
axis = (2, 3)
verify_slice_like(np_data, np_shape_like, axis)

def verify_where(condition, x, y):
dtype = "float32"
if len(condition.shape) == 1:
np_out = np.array([xv if c else yv for c, xv, yv in zip(condition, x, y)])
else:
np_out = np.where(condition, x, y)
cond_var = sym.Variable("condition")
x_var = sym.Variable("x")
y_var = sym.Variable("y")
net = sym.where(cond_var, x_var, y_var)
for target, ctx in ctx_list():
graph, lib, _ = nnvm.compiler.build(net, target, {"condition": condition.shape,
"x": x.shape, "y": y.shape})
m = graph_runtime.create(graph, lib, ctx)
m.set_input(**{"condition": condition, "x": x, "y": y})
m.run()
out = m.get_output(0, tvm.nd.empty(x.shape, dtype))
np.testing.assert_allclose(out.asnumpy(), np_out, atol=1e-5, rtol=1e-5)

def test_where():
shape = (13, 8, 224, 224, 6)
condition = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
x = np.random.uniform(size=shape).astype("float32")
y = np.random.uniform(size=shape).astype("float32")
verify_where(condition, x, y)
condition = np.random.uniform(low=-1, high=1, size=(shape[0],)).astype("float32")
x = np.random.uniform(size=shape).astype("float32")
y = np.random.uniform(size=shape).astype("float32")
verify_where(condition, x, y)


if __name__ == "__main__":
test_reshape()
@@ -665,4 +695,5 @@ def test_slice_like():
test_multibox_transform_loc()
test_nms()
test_slice_like()
test_where()
print(nnvm.compiler.engine.dump())
48 changes: 48 additions & 0 deletions topi/include/topi/transform.h
@@ -575,5 +575,53 @@ inline Tensor take(const Tensor& a,
}, name, tag);
}

/*!
* \brief Return the elements, either from x or y, depending on the condition.
*
* \param condition The condition array.
* \param x First array to be selected.
* \param y Second array to be selected.
* \param name The name of the operation.
* \param tag The tag to mark the operation.
*
* \return A Tensor selected from x or y depending on condition.
*/
inline Tensor where(const Tensor& condition,
const Tensor& x,
const Tensor& y,
std::string name = "tensor",
std::string tag = kInjective) {
CHECK_EQ(x->shape.size(), y->shape.size())
<< "x and y must have the same shape.Got different number of dimension: "
<< x->shape.size() << " vs " << y->shape.size();
CHECK_EQ(x->dtype, y->dtype) << "x and y must have the same dtype: "
<< x->dtype << " vs " << y->dtype;
Array<Expr> oshape = x->shape;
Tensor out;

if (condition->shape.size() != 1) {
CHECK_EQ(condition->shape.size(), x->shape.size())
<< "condition array must be either have the same shape as x or to be a "
"1-D array.Got different number of dimension: "
<< condition->shape.size() << " vs " << x->shape.size();
out = compute(
oshape, [&](const Array<Var>& indices) {
return tvm::select(condition(indices) != 0, x(indices), y(indices));
}, name, tag);
} else {
CHECK_EQ(topi::GetConstInt(condition->shape[0]), topi::GetConstInt(x->shape[0]))
<< "If condition is 1-D, the first dimension must be the same as x: "
<< condition->shape[0] << " vs " << x->shape[0];
out = compute(
oshape, [&](const Array<Var>& indices) {
Array<Expr> condition_idx{indices[0]};
return tvm::select(condition(condition_idx) != 0,
x(indices), y(indices));
}, name, tag);
}
return out;
}
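
For intuition, a hedged Python sketch of the same select-based compute, assuming the tvm 0.x API used elsewhere in this PR (tvm.placeholder, tvm.compute, tvm.select)::

    import tvm

    cond = tvm.placeholder((2, 2), name="condition")
    x = tvm.placeholder((2, 2), name="x")
    y = tvm.placeholder((2, 2), name="y")

    # Same-shape condition: pick x where condition is non-zero, else y.
    out = tvm.compute(
        x.shape,
        lambda i, j: tvm.select(cond[i, j] != 0, x[i, j], y[i, j]),
        name="where")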


} // namespace topi
#endif // TOPI_TRANSFORM_H_
5 changes: 5 additions & 0 deletions topi/src/topi.cc
@@ -280,6 +280,11 @@ TVM_REGISTER_GLOBAL("topi.take")
}
});

TVM_REGISTER_GLOBAL("topi.where")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = where(args[0], args[1], args[2]);
});

TVM_REGISTER_GLOBAL("topi.strided_slice")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = strided_slice(args[0], args[1], args[2], args[3]);
42 changes: 42 additions & 0 deletions topi/tests/python_cpp/test_topi_transform.py
@@ -206,6 +206,35 @@ def check_device(device):
for device in ["llvm", "opencl"]:
check_device(device)

def verify_where(condition, x, y):
dtype = "float32"
if len(condition.shape) == 1:
np_out = np.array([xv if c else yv for c, xv, yv in zip(condition, x, y)])
else:
np_out = np.where(condition, x, y)
A = tvm.placeholder(shape=condition.shape, dtype=dtype, name="condition")
B = tvm.placeholder(shape=x.shape, dtype=dtype, name="x")
C = tvm.placeholder(shape=y.shape, dtype=dtype, name="y")
out_tensor = topi.cpp.where(A, B, C)

def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_injective(out_tensor)

foo = tvm.build(s, [A, B, C, out_tensor], device, name="where")
tvm_out = tvm.nd.empty(x.shape, ctx=ctx, dtype=dtype)
foo(tvm.nd.array(condition, ctx), tvm.nd.array(x, ctx),
tvm.nd.array(y, ctx), tvm_out)
np.testing.assert_allclose(tvm_out.asnumpy(), np_out)

for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
check_device(device)

def verify_concatenate_split(shapes, axis, indices_or_sections):
tensor_l_concatenate = []
for i, shape in enumerate(shapes):
@@ -324,6 +353,18 @@ def test_take():
verify_take((2,2), [[[1,0],[0,1]]], 1)
verify_take((4,3,5,6), [[2,1,0,0]], -2)

def test_where():
shape = (10, 3, 7, 13)
condition = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
x = np.random.uniform(size=shape).astype("float32")
y = np.random.uniform(size=shape).astype("float32")
verify_where(condition, x, y)
condition = np.random.uniform(low=-1, high=1, size=(shape[0],)).astype("float32")
x = np.random.uniform(size=shape).astype("float32")
y = np.random.uniform(size=shape).astype("float32")
verify_where(condition, x, y)


def test_regression_1():
verify_concatenate_split([(2, 3, 4), (2, 2, 4), (2, 5, 4)], 1, [3, 7])
verify_concatenate_split([(3, 4), (2, 4), (3, 4)], 0, [1, 2, 3, 4])
@@ -340,5 +381,6 @@ def test_regression_2():
test_squeeze()
test_split()
test_take()
test_where()
test_regression_1()
test_regression_2()