
Operator add_n for row sparse ndarrays #7244

Merged · 4 commits · Aug 1, 2017
5 changes: 5 additions & 0 deletions include/mxnet/ndarray.h
@@ -199,6 +199,11 @@ class NDArray {
return ptr_->aux_shapes[i];
}

const std::vector<TShape>& aux_shapes() const {
CHECK(storage_type() != kDefaultStorage);
return ptr_->aux_shapes;
}

Member: make it inline?

Contributor Author: Member functions defined inside a class are implicitly inline, so there is no need to put the keyword inline in front of them.

/*!
* \brief For a sparse operation on a csr matrix for example,
* the size of the column index array
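For context on what the new accessor exposes: a row-sparse NDArray stores only its non-zero rows in data and keeps their row numbers in one auxiliary array (rowsparse::kIdx), so aux_shapes() returns the shapes of all auxiliary arrays at once. A rough Python illustration — assuming the .data/.indices accessors of MXNet's RowSparseNDArray and the cast_storage operator from the same sparse effort:

import mxnet as mx
import numpy as np

dense = mx.nd.array(np.array([[0., 0.], [1., 2.], [0., 0.], [3., 4.]]))
x = mx.nd.cast_storage(dense, 'row_sparse')
print(x.indices.asnumpy())  # [1 3]             -- the kIdx auxiliary array
print(x.data.asnumpy())     # [[1. 2.] [3. 4.]] -- only the stored rows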
16 changes: 10 additions & 6 deletions src/ndarray/ndarray_function.cc
@@ -29,7 +29,8 @@ void Copy<cpu, cpu>(const TBlob &from, TBlob *to,
}

template<typename DType, typename IType>
-void ElementwiseSumRspImpl(const std::vector<NDArray>& nds,
+void ElementwiseSumRspImpl(mshadow::Stream<cpu>* s,
+                           const std::vector<NDArray>& nds,
const std::vector<IType>& uniq_row_idx,
NDArray* out,
const int nthreads = 4) {
@@ -42,15 +43,18 @@ void ElementwiseSumRspImpl(const std::vector<NDArray>& nds,
if (row_block_start < nnr) {
const size_t row_block_end = std::min(row_block_start+row_block_len, nnr);

-    auto out_values = out->data().FlatTo2D<cpu, DType>();
+    const size_t row_length = out->data().shape_.ProdShape(1, out->data().shape_.ndim());
+    auto out_values = out->data().get_with_shape<cpu, 2, DType>(
+        mshadow::Shape2(out->storage_shape()[0], row_length), s);
auto out_indices = out->aux_data(rowsparse::kIdx).FlatTo1D<cpu, IType>();
for (size_t i = row_block_start; i < row_block_end; ++i) {
out_indices[i] = uniq_row_idx[i];
}
for (const auto& nd : nds) {
if (nd.storage_initialized()) {
const auto nd_indices = nd.aux_data(rowsparse::kIdx).FlatTo1D<cpu, IType>();
-        const auto nd_values = nd.data().FlatTo2D<cpu, DType>();
+        const auto nd_values = nd.data().get_with_shape<cpu, 2, DType>(
+            mshadow::Shape2(nd.storage_shape()[0], row_length), s);

Member: Did you discuss the new FlatTo2D method with Eric?

Contributor Author: He suggested not adding a bool argument but calculating the desired shape explicitly.
const auto nd_num_rows = nd.aux_shape(rowsparse::kIdx).Size();
const IType* nd_indices_start = &nd_indices[0];
const IType* nd_indices_end = nd_indices_start + nd_num_rows;
@@ -120,7 +124,7 @@ void GetUniqueRspRowIdx(const std::vector<NDArray>& nds,
uniq_row_idx->resize(it - uniq_row_idx->begin());
}

-void ElementwiseSumRsp(const std::vector<NDArray>& nds, NDArray* out) {
+void ElementwiseSumRsp(mshadow::Stream<cpu>* s, const std::vector<NDArray>& nds, NDArray* out) {
if (nds.empty()) return;
using namespace rowsparse;
CHECK_EQ(out->storage_type(), kRowSparseStorage)
@@ -133,7 +137,7 @@ void ElementwiseSumRsp(const std::vector<NDArray>& nds, NDArray* out) {
GetUniqueRspRowIdx(nds, &uniq_row_idx);
out->CheckAndAlloc({mshadow::Shape1(uniq_row_idx.size())});
out->data().FlatTo2D<cpu, DType>() = static_cast<DType>(0);
-    ElementwiseSumRspImpl<DType, IType>(nds, uniq_row_idx, out, omp_get_max_threads());
+    ElementwiseSumRspImpl<DType, IType>(s, nds, uniq_row_idx, out, omp_get_max_threads());
});
});
}
@@ -149,7 +153,7 @@ void ElementwiseSum<cpu>(mshadow::Stream<cpu>* s,
if (nds.empty()) return;

if (nds[0].storage_type() == kRowSparseStorage) {
-    ElementwiseSumRsp(nds, out);
+    ElementwiseSumRsp(s, nds, out);
} else {
LOG(FATAL) << "ElementwiseSum<cpu> has not been implemented for storage_type = << "
<< nds[0].storage_type();
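To make the merge-and-accumulate structure of the kernel above concrete, here is a rough NumPy sketch (illustrative only, not the PR's code; elementwise_sum_rsp is a hypothetical helper): GetUniqueRspRowIdx merges the inputs' sorted row indices, and ElementwiseSumRspImpl then accumulates each input's stored rows into the matching output slots.

import numpy as np

def elementwise_sum_rsp(indices_list, values_list, num_cols):
    # Merge all inputs' row indices into one sorted, de-duplicated list
    # (the role of GetUniqueRspRowIdx).
    uniq = np.unique(np.concatenate(indices_list))
    out = np.zeros((uniq.size, num_cols))
    for idx, vals in zip(indices_list, values_list):
        # Map each input's rows to their slots in the merged list and add.
        out[np.searchsorted(uniq, idx)] += vals
    return uniq, out

uniq, summed = elementwise_sum_rsp(
    [np.array([1, 3]), np.array([0, 3])],
    [np.ones((2, 2)), np.ones((2, 2))],
    num_cols=2)  # uniq = [0 1 3]; row 3 accumulates contributions from both inputs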
36 changes: 34 additions & 2 deletions src/operator/tensor/elemwise_sum.cc
@@ -4,6 +4,7 @@
* \brief elementwise sum operator
*/
#include "./elemwise_sum.h"
#include "../../ndarray/ndarray_function.h"

namespace mxnet {
namespace op {
@@ -52,6 +53,37 @@ bool ElementWiseSumType(const nnvm::NodeAttrs& attrs,
attrs, in_attrs, out_attrs, -1);
}

bool ElementWiseSumForwardInferStorageType(const nnvm::NodeAttrs& attrs,
const Context& ctx,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
CHECK(!in_attrs->empty());
CHECK_EQ(out_attrs->size(), 1U);
return ElemwiseStorageAttr<int, type_is_none, type_assign, false, true>(
attrs, in_attrs, out_attrs);
}

void ElementWiseSumComputeExCPU(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
CHECK(!inputs.empty());
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);
if (req[0] == kNullOp) return;
CHECK_EQ(req[0], kWriteTo) << "ElementWiseSumComputeExCPU only supports req = kWriteTo";
using namespace mshadow;
Stream<cpu>* s = ctx.get_stream<cpu>();
NDArray out_nd = outputs[0];
if (inputs[0].storage_type() == kRowSparseStorage) {
mxnet::ndarray::ElementwiseSum<cpu>(s, inputs, &out_nd);
} else {
FCompExFallback<cpu>(attrs, ctx, inputs, req, outputs,
ElementWiseSumCompute<cpu>, "ElementWiseSumCompute<cpu>");
}
}

NNVM_REGISTER_OP(add_n)
.add_alias("ElementWiseSum")
.describe(R"doc(Adds all input arguments element-wise.
@@ -77,16 +109,16 @@ NNVM_REGISTER_OP(add_n)
})
.set_attr<std::string>("key_var_num_args", "num_args")
.set_attr<FCompute>("FCompute<cpu>", ElementWiseSumCompute<cpu>)
.set_attr<FComputeEx>("FComputeEx<cpu>", ElementWiseSumComputeExCPU)
.set_attr<nnvm::FInplaceOption>(
"FInplaceOption", [](const NodeAttrs& attrs) {
return std::vector<std::pair<int, int> >{{0, 0}};
})
.set_attr<nnvm::FInferShape>("FInferShape", ElementWiseSumShape)
.set_attr<nnvm::FInferType>("FInferType", ElementWiseSumType)
.set_attr<FInferStorageType>("FInferStorageType", ElementWiseSumForwardInferStorageType)
.set_attr<nnvm::FGradient>("FGradient", ElementWiseSumGrad)
.add_argument("args", "NDArray-or-Symbol[]", "Positional input arguments");



} // namespace op
} // namespace mxnet
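A usage-level sketch of what the new registrations enable (assuming the imperative add_n and cast_storage bindings pick up the new storage inference; not taken from the PR itself):

import mxnet as mx
import numpy as np

a = mx.nd.cast_storage(mx.nd.array(np.array([[0., 0.], [1., 2.]])), 'row_sparse')
b = mx.nd.cast_storage(mx.nd.array(np.array([[3., 0.], [0., 0.]])), 'row_sparse')

# FInferStorageType keeps the result row_sparse, and FComputeEx<cpu>
# dispatches to the sparse ElementwiseSum kernel instead of densifying.
c = mx.nd.add_n(a, b)
print(c.asnumpy())  # [[3. 0.] [1. 2.]]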
6 changes: 5 additions & 1 deletion src/operator/tensor/elemwise_unary_op.cc
@@ -48,7 +48,9 @@ MXNET_OPERATOR_REGISTER_BINARY(_backward_sigmoid)
MXNET_OPERATOR_REGISTER_UNARY(_copy)
.MXNET_DESCRIBE("Returns a copy of the input.")
.add_alias("identity")
.set_attr<FInferStorageType>("FInferStorageType", ElemwiseStorageType<1, 1>)
.set_attr<FCompute>("FCompute<cpu>", IdentityCompute<cpu>)
.set_attr<FComputeEx>("FComputeEx<cpu>", IdentityComputeEx<cpu>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_copy"});

NNVM_REGISTER_OP(_backward_copy)
@@ -59,7 +61,9 @@ NNVM_REGISTER_OP(_backward_copy)
[](const NodeAttrs& attrs){
return std::vector<std::pair<int, int> >{{0, 0}};
})
.set_attr<FCompute>("FCompute<cpu>", IdentityCompute<cpu>);
.set_attr<FInferStorageType>("FInferStorageType", ElemwiseStorageType<1, 1>)
.set_attr<FCompute>("FCompute<cpu>", IdentityCompute<cpu>)
.set_attr<FComputeEx>("FComputeEx<cpu>", IdentityComputeEx<cpu>);

MXNET_OPERATOR_REGISTER_UNARY(BlockGrad)
.add_alias("stop_gradient")
34 changes: 34 additions & 0 deletions src/operator/tensor/elemwise_unary_op.h
@@ -104,6 +104,40 @@ void IdentityComputeRspRspImpl(const nnvm::NodeAttrs& attrs,
});
}

template<typename xpu>
void IdentityComputeEx(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);
const auto in_stype = inputs[0].storage_type();
const auto out_stype = outputs[0].storage_type();
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
if (in_stype == out_stype) {
if (in_stype == kDefaultStorage) { // dense ndarray
IdentityCompute<xpu>(attrs, ctx, {inputs[0].data()}, req, {outputs[0].data()});
} else if (in_stype == kRowSparseStorage || in_stype == kCSRStorage) { // sparse ndarray
if (!inputs[0].storage_initialized()) {
FillComputeZerosEx<xpu>(attrs, ctx, inputs, req, outputs);
return;
}
const size_t n = mxnet::num_aux_data(out_stype);
outputs[0].CheckAndAlloc(inputs[0].aux_shapes());
IdentityCompute<xpu>(attrs, ctx, {inputs[0].data()}, req, {outputs[0].data()});
for (size_t i = 0; i < n; ++i) {
IdentityCompute<xpu>(attrs, ctx, {inputs[0].aux_data(i)}, req, {outputs[0].aux_data(i)});
}
} else {
LOG(FATAL) << "IdentityComputeEx does not support input stype = " << in_stype;
}
} else {
FCompExFallback<xpu>(attrs, ctx, inputs, req, outputs, IdentityCompute<xpu>, "IdentityCompute");
}
}

template<typename xpu>
void IdentityLikeRhsComputeEx(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
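As a usage-level illustration of IdentityComputeEx (assuming the identity alias of _copy is exposed imperatively, as the registration above suggests), copying a row-sparse array now keeps its storage type instead of falling back to dense:

import mxnet as mx

x = mx.nd.cast_storage(mx.nd.ones((4, 2)), 'row_sparse')
y = mx.nd.identity(x)  # same-stype path: aux arrays and data are copied as-is
print(y.asnumpy())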
18 changes: 14 additions & 4 deletions src/operator/tensor/square_sum-inl.h
@@ -195,10 +195,16 @@ void SquareSumRspImpl(const nnvm::NodeAttrs& attrs,

using namespace mxnet_op;
if (!input.storage_initialized()) {
-    if (req == kWriteTo && output->storage_type() == kDefaultStorage) {
-      MSHADOW_TYPE_SWITCH(output->data().type_flag_, DType, {
-        Kernel<set_zero, xpu>::Launch(s, out_data_size, output->data().dptr<DType>());
-      })
+    if (req == kWriteTo) {
+      if (output->storage_type() == kDefaultStorage) {
+        MSHADOW_TYPE_SWITCH(output->data().type_flag_, DType, {
+          Kernel<set_zero, xpu>::Launch(s, out_data_size, output->data().dptr<DType>());
+        })
+      } else if (output->storage_type() == kRowSparseStorage) {
+        FillZerosRspImpl<xpu>(s, output);
+      } else {
+        LOG(FATAL) << "SquareSumRspImpl only supports row-sparse/dense output storage type";
+      }
}
return;
}
@@ -302,6 +308,8 @@ void SquareSumOpForwardEx(const nnvm::NodeAttrs& attrs,
mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
const NDArrayStorageType istype = inputs[0].storage_type();
if (istype == kRowSparseStorage) {
CHECK_EQ(inputs[0].shape().ndim(), 2U) << "_square_sum op only supports"
" 2D ndarray as input";
NDArray output = outputs[0];
SquareSumRspImpl(attrs, s, inputs[0], req[0], &output);
} else {
@@ -324,6 +332,8 @@ void SquareSumOpBackwardEx(const nnvm::NodeAttrs& attrs,
const NDArrayStorageType ograd_stype = inputs[0].storage_type();
const NDArrayStorageType input_stype = inputs[1].storage_type();
if (input_stype == kRowSparseStorage && ograd_stype == kDefaultStorage) {
CHECK_EQ(inputs[1].shape().ndim(), 2U) << "_square_sum op only supports"
" 2D ndarray as input";
NDArray output = outputs[0];
SquareSumRspGradImpl(attrs, s, inputs[0], inputs[1], req[0], &output);
} else {
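For intuition, a NumPy sketch of the forward computation on a 2-D row-sparse input (assuming axis=1 with a dense output; square_sum_rsp is a hypothetical stand-in for the kernel): only the stored rows contribute, which is why an uninitialized input can simply write zeros, as the new branch above does.

import numpy as np

def square_sum_rsp(row_idx, rows, num_rows):
    # sum(x**2) over each stored row; rows absent from the input stay zero
    out = np.zeros(num_rows)
    out[row_idx] = (rows ** 2).sum(axis=1)
    return out

print(square_sum_rsp(np.array([1, 3]), np.array([[1., 2.], [3., 4.]]), 4))
# [ 0.  5.  0. 25.]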
2 changes: 1 addition & 1 deletion tests/python/gpu/test_operator_gpu.py
@@ -5,7 +5,7 @@
from test_operator import *
from test_optimizer import *
from test_random import *
-from test_sparse_operator import test_cast_storage_ex, test_sparse_dot, test_sparse_nd_zeros
+from test_sparse_operator import test_sparse_dot, test_sparse_nd_zeros
from test_sparse_ndarray import test_create_csr, test_create_row_sparse
import mxnet as mx
import numpy as np
33 changes: 33 additions & 0 deletions tests/python/unittest/test_sparse_operator.py
@@ -230,6 +230,39 @@ def test_sparse_square_sum():
atol=1e-2, rtol=0.1)


def test_sparse_elementwise_sum():
def check_sparse_elementwise_sum_with_shape(stype, shape, n):
# forward
inputs = [mx.symbol.Variable('arg%d' % i) for i in range(n)]
out = mx.symbol.add_n(*inputs, name='esum')
arr = []
arr_grad = [mx.nd.empty(shape) for _ in range(n)]
densities = [0, 0.01, 0.1, 0.2, 0.3, 0.4, 0.5]
for i in range(n):
arr.append(rand_ndarray(shape, stype, densities[np.random.randint(0, len(densities))]))

exec1 = out.bind(default_context(),
args=arr,
args_grad=arr_grad)
exec1.forward(is_train=True)
out1 = exec1.outputs[0].asnumpy()
out = sum(a.asnumpy() for a in arr)
assert_almost_equal(out, out1)

out_grad = mx.nd.empty(shape)
out_grad[:] = np.random.uniform(-10, 10, shape)
# backward
exec1.backward([out_grad])
for a in arr_grad:
assert_almost_equal(a.asnumpy(), out_grad.asnumpy())

maxdim = 5
for dim in range(2, maxdim):
shape = tuple(np.random.randint(5, 10, size=dim))
print(shape)
check_sparse_elementwise_sum_with_shape('row_sparse', shape, np.random.randint(1, 9))


if __name__ == '__main__':
import nose
nose.runmodule()