diff --git a/nnvm/python/nnvm/top/transform.py b/nnvm/python/nnvm/top/transform.py index b5e00f012996..5f136365c898 100644 --- a/nnvm/python/nnvm/top/transform.py +++ b/nnvm/python/nnvm/top/transform.py @@ -72,3 +72,7 @@ def compute_reshape_like(attrs, inputs, out_info): # slice_like reg.register_pattern("slice_like", OpPattern.INJECTIVE) reg.register_schedule("slice_like", _fschedule_injective) + +# where +reg.register_pattern("where", OpPattern.INJECTIVE) +reg.register_schedule("where", _fschedule_injective) diff --git a/nnvm/src/top/tensor/transform.cc b/nnvm/src/top/tensor/transform.cc index 5bb2ec137594..05adc06a801b 100644 --- a/nnvm/src/top/tensor/transform.cc +++ b/nnvm/src/top/tensor/transform.cc @@ -1125,8 +1125,8 @@ Examples:: DMLC_REGISTER_PARAMETER(SliceLikeParam); inline bool SliceLikeShape(const nnvm::NodeAttrs& attrs, - std::vector* in_attrs, - std::vector* out_attrs) { + std::vector* in_attrs, + std::vector* out_attrs) { CHECK_EQ(in_attrs->size(), 2U); CHECK_EQ(out_attrs->size(), 1U); const SliceLikeParam& param = nnvm::get(attrs.parsed); @@ -1221,5 +1221,98 @@ NNVM_REGISTER_OP(slice_like) }) .set_support_level(4); +// where +inline bool WhereShape(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 3U); + CHECK_EQ(out_attrs->size(), 1U); + const TShape& cond_shape = in_attrs->at(0); + const TShape& x_shape = in_attrs->at(1); + const TShape& y_shape = in_attrs->at(2); + CHECK_EQ(x_shape, y_shape) << "x and y must have the same shape: " + << x_shape << " vs " << y_shape; + if (cond_shape != x_shape) { + CHECK_EQ(cond_shape.ndim(), 1) + << "Shape of condition " << cond_shape + << " must be either equal to x or has dimension of 1."; + } + NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, x_shape); + return true; +} + +inline bool WhereInferType(const NodeAttrs &attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + DTYPE_ASSIGN(out_attrs->at(0), in_attrs->at(1)); + return true; +} + +inline bool WhereCorrectLayout(const NodeAttrs& attrs, + std::vector *ilayouts, + const std::vector *last_ilayouts, + std::vector *olayouts) { + CHECK_EQ(ilayouts->size(), last_ilayouts->size()); + CHECK_EQ(olayouts->size(), 1U); + + for (size_t i = 0; i < ilayouts->size(); ++i) { + const Layout& input = last_ilayouts->at(i).defined() ? + last_ilayouts->at(i) : ilayouts->at(i); + NNVM_ASSIGN_LAYOUT(*ilayouts, i, input); + } + + return true; +} + +NNVM_REGISTER_OP(where) +.describe(R"code( +Return the elements, either from x or y, depending on the condition. + +Given three ndarrays, condition, x, and y, return an ndarray with the elements +from x or y, depending on the elements from condition are true or false. +x and y must have the same shape. If condition has the same shape as x, +each element in the output array is from x if the corresponding element +in the condition is true, and from y if false. + +If condition does not have the same shape as x, it must be a 1D array whose +size is the same as x’s first dimension size. Each row of the output array +is from x’s row if the corresponding element from condition is true, and +from y’s row if false. + +Note that all non-zero values are interpreted as True in condition. + +Examples:: + + x = [[1, 2], [3, 4]] + y = [[5, 6], [7, 8]] + cond = [[0, 1], [-1, 0]] + where(cond, x, y) = [[5, 2], [3, 8]] + + + cond = [1, 0] + where(cond, x, y) = [[1, 2], [7, 8]] + +)code" NNVM_ADD_FILELINE) +.add_argument("condition", "Tensor", "Condition array") +.add_argument("x", "Tensor", "First array to be selected") +.add_argument("y", "Tensor", "Second array to be selected") +.set_num_inputs(3) +.set_num_outputs(1) +.set_attr("FInferShape", WhereShape) +.set_attr("FInferType", WhereInferType) +.set_attr("FCorrectLayout", WhereCorrectLayout) +.set_attr( + "FTVMCompute", [](const NodeAttrs& attrs, + const Array& inputs, + const Array& out_info) { + return Array{ + topi::where(inputs[0], inputs[1], inputs[2]) + }; + }) +.set_attr("FListInputNames", [](const NodeAttrs& attrs) { + return std::vector{"condition", "x", "y"}; +}) +.set_support_level(4); + } // namespace top } // namespace nnvm diff --git a/nnvm/tests/python/compiler/test_top_level4.py b/nnvm/tests/python/compiler/test_top_level4.py index de6a3fa331bc..236ac8e82578 100644 --- a/nnvm/tests/python/compiler/test_top_level4.py +++ b/nnvm/tests/python/compiler/test_top_level4.py @@ -645,6 +645,36 @@ def test_slice_like(): axis = (2, 3) verify_slice_like(np_data, np_shape_like, axis) +def verify_where(condition, x, y): + dtype = "float32" + if len(condition.shape) == 1: + np_out = np.array([xv if c else yv for (c,xv,yv) in zip(condition,x,y)]) + else: + np_out = np.where(condition, x, y) + cond_var = sym.Variable("condition") + x_var = sym.Variable("x") + y_var = sym.Variable("y") + net = sym.where(cond_var, x_var, y_var) + for target, ctx in ctx_list(): + graph, lib, _ = nnvm.compiler.build(net, target, {"condition": condition.shape, + "x": x.shape, "y": y.shape}) + m = graph_runtime.create(graph, lib, ctx) + m.set_input(**{"condition": condition, "x": x, "y": y}) + m.run() + out = m.get_output(0, tvm.nd.empty(x.shape, dtype)) + np.testing.assert_allclose(out.asnumpy(), np_out, atol=1e-5, rtol=1e-5) + +def test_where(): + shape = (13, 8, 224, 224, 6) + condition = np.random.uniform(low=-1, high=1, size=shape).astype("float32") + x = np.random.uniform(size=shape).astype("float32") + y = np.random.uniform(size=shape).astype("float32") + verify_where(condition, x, y) + condition = np.random.uniform(low=-1, high=1, size=(shape[0],)).astype("float32") + x = np.random.uniform(size=shape).astype("float32") + y = np.random.uniform(size=shape).astype("float32") + verify_where(condition, x, y) + if __name__ == "__main__": test_reshape() @@ -665,4 +695,5 @@ def test_slice_like(): test_multibox_transform_loc() test_nms() test_slice_like() + test_where() print(nnvm.compiler.engine.dump()) diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h index 10fb5bc478cb..09af612b957b 100644 --- a/topi/include/topi/transform.h +++ b/topi/include/topi/transform.h @@ -575,5 +575,53 @@ inline Tensor take(const Tensor& a, }, name, tag); } +/*! +* \brief Return the elements, either from x or y, depending on the condition. +* +* \param condition The condition array. +* \param x First array to be selected. +* \param y Second array to be selected. +* \param name The name of the operation. +* \param tag The tag to mark the operation. +* +* \return A Tensor selected from x or y depending on condition. +*/ +inline Tensor where(const Tensor& condition, + const Tensor& x, + const Tensor& y, + std::string name = "tensor", + std::string tag = kInjective) { + CHECK_EQ(x->shape.size(), y->shape.size()) + << "x and y must have the same shape.Got different number of dimension: " + << x->shape.size() << " vs " << y->shape.size(); + CHECK_EQ(x->dtype, y->dtype) << "x and y must have the same dtype: " + << x->dtype << " vs " << y->dtype; + Array oshape = x->shape; + Tensor out; + + if (condition->shape.size() != 1) { + CHECK_EQ(condition->shape.size(), x->shape.size()) + << "condition array must be either have the same shape as x or to be a " + "1-D array.Got different number of dimension: " + << condition->shape.size() << " vs " << x->shape.size(); + out = compute( + oshape, [&](const Array& indices) { + return tvm::select(condition(indices) != 0, x(indices), y(indices)); + }, name, tag); + } else { + CHECK_EQ(topi::GetConstInt(condition->shape[0]), topi::GetConstInt(x->shape[0])) + << "If condition is 1-D, the first dimension must be the same as x: " + << condition->shape[0] << " vs " << x->shape[0]; + out = compute( + oshape, [&](const Array& indices) { + Array condition_idx{indices[0]}; + return tvm::select(condition(condition_idx) != 0, + x(indices), y(indices)); + }, name, tag); + } + return out; +} + + } // namespace topi #endif // TOPI_TRANSFORM_H_ diff --git a/topi/src/topi.cc b/topi/src/topi.cc index c08bd5f565d2..4cdab4401459 100644 --- a/topi/src/topi.cc +++ b/topi/src/topi.cc @@ -280,6 +280,11 @@ TVM_REGISTER_GLOBAL("topi.take") } }); +TVM_REGISTER_GLOBAL("topi.where") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = where(args[0], args[1], args[2]); +}); + TVM_REGISTER_GLOBAL("topi.strided_slice") .set_body([](TVMArgs args, TVMRetValue *rv) { *rv = strided_slice(args[0], args[1], args[2], args[3]); diff --git a/topi/tests/python_cpp/test_topi_transform.py b/topi/tests/python_cpp/test_topi_transform.py index a94fc89328af..c8b7c3906caa 100644 --- a/topi/tests/python_cpp/test_topi_transform.py +++ b/topi/tests/python_cpp/test_topi_transform.py @@ -206,6 +206,35 @@ def check_device(device): for device in ["llvm", "opencl"]: check_device(device) +def verify_where(condition, x, y): + dtype = "float32" + if len(condition.shape) == 1: + np_out = np.array([xv if c else yv for (c,xv,yv) in zip(condition,x,y)]) + else: + np_out = np.where(condition, x, y) + A = tvm.placeholder(shape=condition.shape, dtype=dtype, name="condition") + B = tvm.placeholder(shape=x.shape, dtype=dtype, name="x") + C = tvm.placeholder(shape=y.shape, dtype=dtype, name="y") + out_tensor = topi.cpp.where(A, B, C) + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + with tvm.target.create(device): + s = topi.generic.schedule_injective(out_tensor) + + foo = tvm.build(s, [A, B, C, out_tensor], device, name="where") + tvm_out = tvm.nd.empty(x.shape, ctx=ctx, dtype=dtype) + foo(tvm.nd.array(condition, ctx), tvm.nd.array(x, ctx), + tvm.nd.array(y, ctx), tvm_out) + np.testing.assert_allclose(tvm_out.asnumpy(), np_out) + + for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]: + check_device(device) + def verify_concatenate_split(shapes, axis, indices_or_sections): tensor_l_concatenate = [] for i, shape in enumerate(shapes): @@ -324,6 +353,18 @@ def test_take(): verify_take((2,2), [[[1,0],[0,1]]], 1) verify_take((4,3,5,6), [[2,1,0,0]], -2) +def test_where(): + shape = (10, 3, 7, 13) + condition = np.random.uniform(low=-1, high=1, size=shape).astype("float32") + x = np.random.uniform(size=shape).astype("float32") + y = np.random.uniform(size=shape).astype("float32") + verify_where(condition, x, y) + condition = np.random.uniform(low=-1, high=1, size=(shape[0],)).astype("float32") + x = np.random.uniform(size=shape).astype("float32") + y = np.random.uniform(size=shape).astype("float32") + verify_where(condition, x, y) + + def test_regression_1(): verify_concatenate_split([(2, 3, 4), (2, 2, 4), (2, 5, 4)], 1, [3, 7]) verify_concatenate_split([(3, 4), (2, 4), (3, 4)], 0, [1, 2, 3, 4]) @@ -340,5 +381,6 @@ def test_regression_2(): test_squeeze() test_split() test_take() + test_where() test_regression_1() test_regression_2()