From 85226ca56d975fb800246f95fce7a2ae4cc11225 Mon Sep 17 00:00:00 2001
From: reminisce
Date: Thu, 27 Jul 2017 22:40:23 -0700
Subject: [PATCH 1/4] Add add_n op for row-sparse ndarrays and identity FComputeEx

---
 include/mxnet/ndarray.h                       |  5 +++
 src/ndarray/ndarray_function.cc               | 16 +++++----
 src/operator/tensor/elemwise_sum.cc           | 36 +++++++++++++++++--
 src/operator/tensor/elemwise_unary_op.cc      |  6 +++-
 src/operator/tensor/elemwise_unary_op.h       | 32 +++++++++++++++++
 tests/python/unittest/test_sparse_operator.py | 33 +++++++++++++++++
 6 files changed, 119 insertions(+), 9 deletions(-)

diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index cbe815eb2c24..d34ad8c038e3 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -199,6 +199,11 @@ class NDArray {
     return ptr_->aux_shapes[i];
   }
 
+  const std::vector<TShape>& aux_shapes() const {
+    CHECK(storage_type() != kDefaultStorage);
+    return ptr_->aux_shapes;
+  }
+
   /*!
    * \brief For a sparse operation on a csr matrix for example,
    *        the size of the column index array
diff --git a/src/ndarray/ndarray_function.cc b/src/ndarray/ndarray_function.cc
index b03166f4d834..89048cfb4e5d 100644
--- a/src/ndarray/ndarray_function.cc
+++ b/src/ndarray/ndarray_function.cc
@@ -29,7 +29,8 @@ void Copy<cpu, cpu>(const TBlob &from, TBlob *to,
 }
 
 template<typename DType, typename IType>
-void ElementwiseSumRspImpl(const std::vector<NDArray>& nds,
+void ElementwiseSumRspImpl(mshadow::Stream<cpu>* s,
+                           const std::vector<NDArray>& nds,
                            const std::vector<IType>& uniq_row_idx,
                            NDArray* out,
                            const int nthreads = 4) {
@@ -42,7 +43,9 @@ void ElementwiseSumRspImpl(const std::vector<NDArray>& nds,
     if (row_block_start < nnr) {
       const size_t row_block_end = std::min(row_block_start+row_block_len, nnr);
 
-      auto out_values = out->data().FlatTo2D<cpu, DType>();
+      const size_t num_cols = out->data().shape_.ProdShape(1, out->data().shape_.ndim());
+      auto out_values = out->data().get_with_shape<cpu, 2, DType>(
+          mshadow::Shape2(out->storage_shape()[0], num_cols), s);
       auto out_indices = out->aux_data(rowsparse::kIdx).FlatTo1D<cpu, IType>();
       for (size_t i = row_block_start; i < row_block_end; ++i) {
         out_indices[i] = uniq_row_idx[i];
@@ -50,7 +53,8 @@ void ElementwiseSumRspImpl(const std::vector<NDArray>& nds,
       for (const auto& nd : nds) {
         if (nd.storage_initialized()) {
           const auto nd_indices = nd.aux_data(rowsparse::kIdx).FlatTo1D<cpu, IType>();
-          const auto nd_values = nd.data().FlatTo2D<cpu, DType>();
+          const auto nd_values = nd.data().get_with_shape<cpu, 2, DType>(
+              mshadow::Shape2(nd.storage_shape()[0], num_cols), s);
           const auto nd_num_rows = nd.aux_shape(rowsparse::kIdx).Size();
           const IType* nd_indices_start = &nd_indices[0];
           const IType* nd_indices_end = nd_indices_start + nd_num_rows;
@@ -120,7 +124,7 @@ void GetUniqueRspRowIdx(const std::vector<NDArray>& nds,
   uniq_row_idx->resize(it - uniq_row_idx->begin());
 }
 
-void ElementwiseSumRsp(const std::vector<NDArray>& nds, NDArray* out) {
+void ElementwiseSumRsp(mshadow::Stream<cpu>* s, const std::vector<NDArray>& nds, NDArray* out) {
   if (nds.empty()) return;
   using namespace rowsparse;
   CHECK_EQ(out->storage_type(), kRowSparseStorage)
@@ -133,7 +137,7 @@ void ElementwiseSumRsp(const std::vector<NDArray>& nds, NDArray* out) {
       GetUniqueRspRowIdx(nds, &uniq_row_idx);
       out->CheckAndAlloc({mshadow::Shape1(uniq_row_idx.size())});
       out->data().FlatTo2D<cpu, DType>() = static_cast<DType>(0);
-      ElementwiseSumRspImpl(nds, uniq_row_idx, out, omp_get_max_threads());
+      ElementwiseSumRspImpl(s, nds, uniq_row_idx, out, omp_get_max_threads());
     });
   });
 }
@@ -149,7 +153,7 @@ void ElementwiseSum<cpu>(mshadow::Stream<cpu>* s,
   if (nds.empty()) return;
 
   if (nds[0].storage_type() == kRowSparseStorage) {
-    ElementwiseSumRsp(nds, out);
+    ElementwiseSumRsp(s, nds, out);
   } else {
     LOG(FATAL) << "ElementwiseSum has not been implemented for storage_type = "
               << nds[0].storage_type();
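
Note on the kernel above: ElementwiseSumRspImpl first materializes the union of
all stored row indices (uniq_row_idx), then each thread fills one block of
output rows, locating every input's contribution with std::lower_bound. A
minimal NumPy sketch of the same row-union idea (illustrative only, not the
MXNet API; np.searchsorted plays the role of std::lower_bound):

    import numpy as np

    def add_n_rsp(arrays):
        # arrays: list of (idx, val) pairs; idx = sorted row indices, val = 2-D values
        idxs = [idx for idx, _ in arrays if idx.size]
        uniq = np.unique(np.concatenate(idxs)) if idxs else np.empty(0, np.int64)
        out = np.zeros((uniq.size, arrays[0][1].shape[1]), dtype=arrays[0][1].dtype)
        for idx, val in arrays:
            if idx.size:
                out[np.searchsorted(uniq, idx)] += val  # accumulate into output slots
        return uniq, out

    a = (np.array([1, 3]), np.array([[1., 1.], [2., 2.]]))
    b = (np.array([3, 4]), np.array([[10., 10.], [20., 20.]]))
    print(add_n_rsp([a, b]))  # rows {1, 3, 4}; row 3 sums both inputs
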
<< "ElementwiseSum has not been implemented for storage_type = << " << nds[0].storage_type(); diff --git a/src/operator/tensor/elemwise_sum.cc b/src/operator/tensor/elemwise_sum.cc index 06ec01e8ebd0..e6930cc9993b 100644 --- a/src/operator/tensor/elemwise_sum.cc +++ b/src/operator/tensor/elemwise_sum.cc @@ -4,6 +4,7 @@ * \brief elementwise sum operator */ #include "./elemwise_sum.h" +#include "../../ndarray/ndarray_function.h" namespace mxnet { namespace op { @@ -52,6 +53,37 @@ bool ElementWiseSumType(const nnvm::NodeAttrs& attrs, attrs, in_attrs, out_attrs, -1); } +bool ElementWiseSumForwardInferStorageType(const nnvm::NodeAttrs& attrs, + const Context& ctx, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK(!in_attrs->empty()); + CHECK_EQ(out_attrs->size(), 1U); + return ElemwiseStorageAttr( + attrs, in_attrs, out_attrs); +} + +void ElementWiseSumComputeExCPU(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK(!inputs.empty()); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + if (req[0] == kNullOp) return; + CHECK_EQ(req[0], kWriteTo) << "ElementWiseSumComputeExCPU only supports req = kWriteTo"; + using namespace mshadow; + Stream* s = ctx.get_stream(); + NDArray out_nd = outputs[0]; + if (inputs[0].storage_type() == kRowSparseStorage) { + mxnet::ndarray::ElementwiseSum(s, inputs, &out_nd); + } else { + FCompExFallback(attrs, ctx, inputs, req, outputs, + ElementWiseSumCompute, "ElementWiseSumCompute"); + } +} + NNVM_REGISTER_OP(add_n) .add_alias("ElementWiseSum") .describe(R"doc(Adds all input arguments element-wise. @@ -77,16 +109,16 @@ NNVM_REGISTER_OP(add_n) }) .set_attr("key_var_num_args", "num_args") .set_attr("FCompute", ElementWiseSumCompute) +.set_attr("FComputeEx", ElementWiseSumComputeExCPU) .set_attr( "FInplaceOption", [](const NodeAttrs& attrs) { return std::vector >{{0, 0}}; }) .set_attr("FInferShape", ElementWiseSumShape) .set_attr("FInferType", ElementWiseSumType) +.set_attr("FInferStorageType", ElementWiseSumForwardInferStorageType) .set_attr("FGradient", ElementWiseSumGrad) .add_argument("args", "NDArray-or-Symbol[]", "Positional input arguments"); - - } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/elemwise_unary_op.cc b/src/operator/tensor/elemwise_unary_op.cc index 078d62b5f96e..37e022c9bdb3 100644 --- a/src/operator/tensor/elemwise_unary_op.cc +++ b/src/operator/tensor/elemwise_unary_op.cc @@ -48,7 +48,9 @@ MXNET_OPERATOR_REGISTER_BINARY(_backward_sigmoid) MXNET_OPERATOR_REGISTER_UNARY(_copy) .MXNET_DESCRIBE("Returns a copy of the input.") .add_alias("identity") +.set_attr("FInferStorageType", ElemwiseStorageType<1, 1>) .set_attr("FCompute", IdentityCompute) +.set_attr("FComputeEx", IdentityComputeEx) .set_attr("FGradient", ElemwiseGradUseNone{"_copy"}); NNVM_REGISTER_OP(_backward_copy) @@ -59,7 +61,9 @@ NNVM_REGISTER_OP(_backward_copy) [](const NodeAttrs& attrs){ return std::vector >{{0, 0}}; }) -.set_attr("FCompute", IdentityCompute); +.set_attr("FInferStorageType", ElemwiseStorageType<1, 1>) +.set_attr("FCompute", IdentityCompute) +.set_attr("FComputeEx", IdentityComputeEx); MXNET_OPERATOR_REGISTER_UNARY(BlockGrad) .add_alias("stop_gradient") diff --git a/src/operator/tensor/elemwise_unary_op.h b/src/operator/tensor/elemwise_unary_op.h index 40434a1e40f6..03d4ec32c831 100644 --- a/src/operator/tensor/elemwise_unary_op.h +++ b/src/operator/tensor/elemwise_unary_op.h @@ -104,6 +104,38 @@ void IdentityComputeRspRspImpl(const 
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index 23fbb8fa2b05..42965688d819 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -230,6 +230,39 @@ def test_sparse_square_sum():
                             atol=1e-2, rtol=0.1)
 
 
+def test_sparse_elementwise_sum():
+    def check_sparse_elementwise_sum_with_shape(stype, shape, n):
+        # forward
+        inputs = [mx.symbol.Variable('arg%d' % i) for i in range(n)]
+        out = mx.symbol.add_n(*inputs, name='esum')
+        arr = []
+        arr_grad = [mx.nd.empty(shape) for _ in range(n)]
+        densities = [0, 0.01, 0.1, 0.2, 0.3, 0.4, 0.5]
+        for i in range(n):
+            arr.append(rand_ndarray(shape, stype, densities[np.random.randint(0, len(densities))]))
+
+        exec1 = out.bind(default_context(),
+                         args=arr,
+                         args_grad=arr_grad)
+        exec1.forward(is_train=True)
+        out1 = exec1.outputs[0].asnumpy()
+        out = sum(a.asnumpy() for a in arr)
+        assert_almost_equal(out, out1)
+
+        out_grad = mx.nd.empty(shape)
+        out_grad[:] = np.random.uniform(-10, 10, shape)
+        # backward
+        exec1.backward([out_grad])
+        for a in arr_grad:
+            assert_almost_equal(a.asnumpy(), out_grad.asnumpy())
+
+    maxdim = 5
+    for dim in range(2, maxdim):
+        shape = tuple(np.random.randint(5, 10, size=dim))
+        print(shape)
+        check_sparse_elementwise_sum_with_shape('row_sparse', shape, np.random.randint(1, 9))
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()

From acc4bdad41adb82fc6d60ea53fd218ac9a2c44b2 Mon Sep 17 00:00:00 2001
From: reminisce
Date: Thu, 27 Jul 2017 22:57:36 -0700
Subject: [PATCH 2/4] Fix bug in square_sum

---
 src/operator/tensor/square_sum-inl.h | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/src/operator/tensor/square_sum-inl.h b/src/operator/tensor/square_sum-inl.h
index 8271025197a7..3fa50153d356 100644
--- a/src/operator/tensor/square_sum-inl.h
+++ b/src/operator/tensor/square_sum-inl.h
@@ -195,10 +195,16 @@ void SquareSumRspImpl(const nnvm::NodeAttrs& attrs,
   using namespace mxnet_op;
 
   if (!input.storage_initialized()) {
-    if (req == kWriteTo && output->storage_type() == kDefaultStorage) {
-      MSHADOW_TYPE_SWITCH(output->data().type_flag_, DType, {
-        Kernel<set_zero, xpu>::Launch(s, out_data_size, output->data().dptr<DType>());
-      })
+    if (req == kWriteTo) {
+      if (output->storage_type() == kDefaultStorage) {
+        MSHADOW_TYPE_SWITCH(output->data().type_flag_, DType, {
+          Kernel<set_zero, xpu>::Launch(s, out_data_size, output->data().dptr<DType>());
+        })
+      } else if (output->storage_type() == kRowSparseStorage) {
+        FillZerosRspImpl(s, output);
+      } else {
+        LOG(FATAL) << "SquareSumRspImpl only supports row-sparse/dense output storage type";
+      }
     }
     return;
   }
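
The fix above closes a hole in SquareSumRspImpl: with an uninitialized (all
zero) row-sparse input and req == kWriteTo, only a dense output used to be
zeroed, so a row-sparse output could keep stale values. A NumPy mirror of the
intended forward semantics, assuming axis=1 and keepdims=False (hypothetical
sketch, not MXNet code):

    import numpy as np

    def square_sum_rsp(idx, val, num_rows, out_stype='default'):
        # out[i] = sum(x[i, :] ** 2) for row-sparse x with stored rows (idx, val)
        if len(idx) == 0:  # all-zero input: the output must be all zeros too
            if out_stype == 'default':
                return np.zeros(num_rows)
            return np.empty(0, np.int64), np.empty(0)  # empty row-sparse == zeros
        row_sums = (val ** 2).sum(axis=1)
        if out_stype == 'default':
            out = np.zeros(num_rows)
            out[idx] = row_sums
            return out
        return idx.copy(), row_sums  # row-sparse output reuses the row indices

    # x = [[0, 0], [3, 4], [0, 0]] stored as idx=[1], val=[[3, 4]]
    print(square_sum_rsp(np.array([1]), np.array([[3., 4.]]), 3))  # [ 0. 25.  0.]
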
From 0ef64cf9b7222d0d9db4081d64337999ef9b0b8f Mon Sep 17 00:00:00 2001
From: reminisce
Date: Fri, 28 Jul 2017 14:57:39 -0700
Subject: [PATCH 3/4] Remove test_cast_storage_ex from gpu test since it's not
 implemented yet

---
 tests/python/gpu/test_operator_gpu.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index adc4c3b903bf..ac59811df6a0 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -5,7 +5,7 @@
 from test_operator import *
 from test_optimizer import *
 from test_random import *
-from test_sparse_operator import test_cast_storage_ex, test_sparse_dot, test_sparse_nd_zeros
+from test_sparse_operator import test_sparse_dot, test_sparse_nd_zeros
 from test_sparse_ndarray import test_create_csr, test_create_row_sparse
 import mxnet as mx
 import numpy as np

From 6fc7ea71c0790e291851b50037af98b0c650b136 Mon Sep 17 00:00:00 2001
From: reminisce
Date: Sat, 29 Jul 2017 21:01:39 -0700
Subject: [PATCH 4/4] Fix according to the code review

---
 src/ndarray/ndarray_function.cc         | 6 +++---
 src/operator/tensor/elemwise_unary_op.h | 4 +++-
 src/operator/tensor/square_sum-inl.h    | 4 ++++
 3 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/src/ndarray/ndarray_function.cc b/src/ndarray/ndarray_function.cc
index 89048cfb4e5d..733870faa7fd 100644
--- a/src/ndarray/ndarray_function.cc
+++ b/src/ndarray/ndarray_function.cc
@@ -43,9 +43,9 @@ void ElementwiseSumRspImpl(mshadow::Stream<cpu>* s,
     if (row_block_start < nnr) {
       const size_t row_block_end = std::min(row_block_start+row_block_len, nnr);
 
-      const size_t num_cols = out->data().shape_.ProdShape(1, out->data().shape_.ndim());
+      const size_t row_length = out->data().shape_.ProdShape(1, out->data().shape_.ndim());
       auto out_values = out->data().get_with_shape<cpu, 2, DType>(
-          mshadow::Shape2(out->storage_shape()[0], num_cols), s);
+          mshadow::Shape2(out->storage_shape()[0], row_length), s);
       auto out_indices = out->aux_data(rowsparse::kIdx).FlatTo1D<cpu, IType>();
       for (size_t i = row_block_start; i < row_block_end; ++i) {
         out_indices[i] = uniq_row_idx[i];
@@ -54,7 +54,7 @@ void ElementwiseSumRspImpl(mshadow::Stream<cpu>* s,
       for (const auto& nd : nds) {
         if (nd.storage_initialized()) {
           const auto nd_indices = nd.aux_data(rowsparse::kIdx).FlatTo1D<cpu, IType>();
           const auto nd_values = nd.data().get_with_shape<cpu, 2, DType>(
-              mshadow::Shape2(nd.storage_shape()[0], num_cols), s);
+              mshadow::Shape2(nd.storage_shape()[0], row_length), s);
           const auto nd_num_rows = nd.aux_shape(rowsparse::kIdx).Size();
           const IType* nd_indices_start = &nd_indices[0];
           const IType* nd_indices_end = nd_indices_start + nd_num_rows;
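
The num_cols -> row_length rename reflects that for an N-D row-sparse array the
flattened value tensor's second dimension is the length of one stored row,
prod(shape[1:]), which is exactly what ProdShape(1, ndim) computes; "columns"
only describes the 2-D case. A small NumPy illustration of the assumed storage
layout:

    import numpy as np

    shape = (6, 2, 3)       # logical shape of the row-sparse array
    idx = np.array([0, 4])  # stored row indices (nnr = 2)
    val = np.zeros((len(idx),) + shape[1:], np.float32)  # storage shape (2, 2, 3)
    row_length = int(np.prod(shape[1:]))       # ProdShape(1, ndim) == 6
    val2d = val.reshape(len(idx), row_length)  # Shape2(storage_shape()[0], row_length)
    print(val2d.shape)      # (2, 6)
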
diff --git a/src/operator/tensor/elemwise_unary_op.h b/src/operator/tensor/elemwise_unary_op.h
index 03d4ec32c831..dc1800d4a322 100644
--- a/src/operator/tensor/elemwise_unary_op.h
+++ b/src/operator/tensor/elemwise_unary_op.h
@@ -119,7 +119,7 @@ void IdentityComputeEx(const nnvm::NodeAttrs& attrs,
   if (in_stype == out_stype) {
     if (in_stype == kDefaultStorage) {  // dense ndarray
       IdentityCompute<xpu>(attrs, ctx, {inputs[0].data()}, req, {outputs[0].data()});
-    } else {  // sparse ndarray
+    } else if (in_stype == kRowSparseStorage || in_stype == kCSRStorage) {  // sparse ndarray
       if (!inputs[0].storage_initialized()) {
         FillComputeZerosEx<xpu>(attrs, ctx, inputs, req, outputs);
         return;
@@ -130,6 +130,8 @@ void IdentityComputeEx(const nnvm::NodeAttrs& attrs,
       for (size_t i = 0; i < n; ++i) {
         IdentityCompute<xpu>(attrs, ctx, {inputs[0].aux_data(i)}, req, {outputs[0].aux_data(i)});
       }
+    } else {
+      LOG(FATAL) << "IdentityComputeEx does not support input stype = " << in_stype;
     }
   } else {
     FCompExFallback<xpu>(attrs, ctx, inputs, req, outputs, IdentityCompute<xpu>, "IdentityCompute");
diff --git a/src/operator/tensor/square_sum-inl.h b/src/operator/tensor/square_sum-inl.h
index 3fa50153d356..4aa1557bbfb3 100644
--- a/src/operator/tensor/square_sum-inl.h
+++ b/src/operator/tensor/square_sum-inl.h
@@ -308,6 +308,8 @@ void SquareSumOpForwardEx(const nnvm::NodeAttrs& attrs,
   mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
   const NDArrayStorageType istype = inputs[0].storage_type();
   if (istype == kRowSparseStorage) {
+    CHECK_EQ(inputs[0].shape().ndim(), 2U) << "_square_sum op only supports"
+                                              " 2D ndarray as input";
     NDArray output = outputs[0];
     SquareSumRspImpl(attrs, s, inputs[0], req[0], &output);
   } else {
@@ -330,6 +332,8 @@ void SquareSumOpBackwardEx(const nnvm::NodeAttrs& attrs,
   const NDArrayStorageType ograd_stype = inputs[0].storage_type();
   const NDArrayStorageType input_stype = inputs[1].storage_type();
   if (input_stype == kRowSparseStorage && ograd_stype == kDefaultStorage) {
+    CHECK_EQ(inputs[1].shape().ndim(), 2U) << "_square_sum op only supports"
+                                              " 2D ndarray as input";
     NDArray output = outputs[0];
     SquareSumRspGradImpl(attrs, s, inputs[0], inputs[1], req[0], &output);
   } else {
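
Taken together, the series gives add_n a native CPU kernel for row-sparse
inputs, with identity/_copy and square_sum brought along. A quick interactive
check in the spirit of the new unit test (hedged: rand_ndarray's exact
signature and the sparse helpers on this branch are assumed, not verified):

    import mxnet as mx
    from mxnet.test_utils import rand_ndarray, assert_almost_equal

    arrs = [rand_ndarray((5, 4), 'row_sparse', 0.3) for _ in range(3)]
    dense_sum = sum(a.asnumpy() for a in arrs)  # dense reference result
    sparse_sum = mx.nd.add_n(*arrs)             # dispatches to ElementWiseSumComputeExCPU
    assert_almost_equal(sparse_sum.asnumpy(), dense_sum)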