Sparse dot enhancement (#6842)
* Initial checkin

Initial checkin

Fix sparse dot test

Fix unit test and add fallback for sparse dot

* Add benchmark code

* Revert "Add benchmark code"

This reverts commit be009fe.

* Fix bug

* Fix storage shape

* Remove unnecessary test code

* Use idx type switch
reminisce authored and piiswrong committed Jun 30, 2017
1 parent 1e804c1 commit 8ed829f
Showing 3 changed files with 142 additions and 12 deletions.
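In user-facing terms, this commit makes dot(csr.T, dns) produce a row_sparse output on CPU instead of a dense one, and routes storage-type combinations that lack a dedicated sparse kernel through the dense fallback instead of aborting with a fatal error. A minimal usage sketch, modeled on the updated unit test below (the storage_type/cast_storage API shown is the one used on the sparse development branch at the time of this commit):

import mxnet as mx

# A small CSR lhs (2x3) and a dense rhs (2x2).
lhs_dns = mx.nd.array([[0, 1, 0],
                       [2, 0, 3]])
lhs_csr = mx.nd.cast_storage(lhs_dns, storage_type='csr')
rhs_dns = mx.nd.array([[1, 2],
                       [3, 4]])

# dot(csr.T, dns): forward storage-type inference now picks row_sparse.
out = mx.nd.dot(lhs_csr, rhs_dns, transpose_a=True)
assert out.storage_type == 'row_sparse'
dense_view = out.todense()        # convert back for inspection

# Without transpose_a the result stays dense: dot(csr, dns) = dns.
out_dns = mx.nd.dot(lhs_csr, mx.nd.ones((3, 2)))
assert out_dns.storage_type == 'default'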
145 changes: 135 additions & 10 deletions src/operator/tensor/matrix_op-inl.h
@@ -483,7 +483,13 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), 2U);
CHECK_EQ(out_attrs->size(), 1U);
out_attrs->at(0) = kDefaultStorage;
const DotParam& param = nnvm::get<DotParam>(attrs.parsed);
if (param.transpose_a && kCSRStorage == (*in_attrs)[0]
&& kDefaultStorage == (*in_attrs)[1]) {
STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 0, kRowSparseStorage);
} else {
STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 0, kDefaultStorage);
}
return true;
}
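For readers skimming the diff, the new forward inference rule condenses to a few lines. The helper below is illustrative pseudocode mirroring the C++ above, not an MXNet API:

def infer_forward_out_storage(transpose_a, lhs_stype, rhs_stype):
    # Only dot(csr.T, dns) gets a row_sparse output; every other
    # combination of input storage types keeps the dense (default) output.
    if transpose_a and lhs_stype == 'csr' and rhs_stype == 'default':
        return 'row_sparse'
    return 'default'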

@@ -493,8 +499,14 @@ inline bool DotBackwardInferStorageType(const nnvm::NodeAttrs& attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), 3U);
CHECK_EQ(out_attrs->size(), 2U);
out_attrs->at(0) = kDefaultStorage;
out_attrs->at(1) = kDefaultStorage;
const DotParam& param = nnvm::get<DotParam>(attrs.parsed);
STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 0, kDefaultStorage);
if (!param.transpose_a && kDefaultStorage == (*in_attrs)[0]
&& kCSRStorage == (*in_attrs)[1] && kDefaultStorage == (*in_attrs)[2]) {
STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 1, kRowSparseStorage);
} else {
STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 1, kDefaultStorage);
}
return true;
}

@@ -642,6 +654,45 @@ struct DotCsrTransDnsDnsByRowBlocks {
}
};

/*!
* \brief Kernel of dot(csr.T(), dns) = rsp
* Parallelization by row blocks.
* This kernel fills up the row_idx array
* of the rsp with 1 for nonzero rows and 0
* for zero rows.
* The matrix will be compacted after this kernel call.
*/
struct DotCsrTransDnsRspByRowBlocks {
/*!
* \brief
* \param i the i-th thread
*/
template<typename DType, typename RType, typename IType, typename CType>
MSHADOW_XINLINE static void Map(int i, DType* out, RType* row_idx, const DType* data_l,
const IType* indptr_l, const CType* col_idx_l,
const DType* data_r, const size_t seg_len,
const size_t num_rows_l, const size_t num_rows,
const size_t num_cols) {
const size_t seg_start = i * seg_len;
if (seg_start >= num_rows) return;
const size_t seg_end = (i + 1) * seg_len;
for (size_t j = 0; j < num_rows_l; ++j) {
if (indptr_l[j] == indptr_l[j+1]) continue;
const size_t offset_r = j * num_cols;
for (auto k = indptr_l[j]; k < indptr_l[j+1]; ++k) {
const auto col_idx = col_idx_l[k];
if (col_idx < seg_start || col_idx >= seg_end) continue;
const size_t offset_out = col_idx * num_cols;
row_idx[col_idx] = 1;
const auto val = data_l[k];
for (size_t l = 0; l < num_cols; ++l) {
out[offset_out+l] += data_r[offset_r+l] * val;
}
}
}
}
};
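The parallelization strategy of the kernel above: output rows (one per column of the CSR lhs) are split into contiguous segments of seg_len rows, one segment per thread; every thread scans the whole CSR matrix but only accumulates into output rows that fall inside its own segment, so no two threads write the same row. A rough NumPy sketch of the per-segment logic, with hypothetical names, purely to illustrate the access pattern (not the actual implementation):

import numpy as np

def csr_T_dot_dns_segment(data, indptr, col_idx, dns, out, row_flag,
                          seg_start, seg_end):
    """Accumulate this thread's segment of out = csr.T.dot(dns)."""
    num_rows_l = len(indptr) - 1
    for j in range(num_rows_l):                      # rows of the CSR lhs
        for k in range(indptr[j], indptr[j + 1]):    # nonzeros of row j
            c = col_idx[k]
            if c < seg_start or c >= seg_end:        # not owned by this thread
                continue
            row_flag[c] = 1                          # mark output row c as nonzero
            out[c, :] += data[k] * dns[j, :]         # (csr.T)[c, j] * dns[j, :]

# One "thread" covering all output rows reproduces the full product.
# The CSR triple below encodes [[0, 1, 0], [2, 0, 3]].
data, indptr, col_idx = [1., 2., 3.], [0, 1, 3], [1, 0, 2]
dns = np.array([[1., 2.], [3., 4.]])
out = np.zeros((3, 2))
row_flag = np.zeros(3, dtype=np.int64)
csr_T_dot_dns_segment(data, indptr, col_idx, dns, out, row_flag, 0, 3)
# out == csr.T.dot(dns); rows where row_flag stays 0 remain all-zero.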

template<typename xpu>
void DotCsrDnsDnsImpl(const OpContext& ctx,
const NDArray& lhs,
@@ -702,6 +753,75 @@ void DotCsrDnsDnsImpl(const OpContext& ctx,
});
}

template<typename xpu>
void DotCsrDnsRspImpl(const OpContext& ctx,
const NDArray& lhs,
const TBlob& rhs,
const OpReqType req,
const bool trans_lhs,
NDArray* ret) {
if (kNullOp == req) return;
CHECK_EQ(lhs.storage_type(), kCSRStorage);
CHECK_EQ(ret->storage_type(), kRowSparseStorage);
if (!lhs.storage_initialized()) return;

mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
const TBlob data_l = lhs.data();
const TBlob indptr_l = lhs.aux_data(csr::kIndPtr);
const TBlob col_idx_l = lhs.aux_data(csr::kIdx);
const TBlob& data_r = rhs;

// pre-allocate spaces for ret using the dense dimension size
ret->CheckAndAlloc({mshadow::Shape1(lhs.shape()[1])});
const TBlob data_out = ret->data();
const TBlob row_idx_out = ret->aux_data(rowsparse::kIdx);

MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, { // data type
MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type
MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type
MSHADOW_IDX_TYPE_SWITCH(row_idx_out.type_flag_, RType, { // row idx type
if (std::is_same<xpu, cpu>::value) { // cpu parallelization by row blocks
if (kWriteTo == req) {
mxnet_op::Kernel<mxnet_op::set_zero, xpu>::Launch(
s, data_out.Size(), data_out.dptr<DType>());
}
RType* row_idx = row_idx_out.dptr<RType>();
mxnet_op::Kernel<mxnet_op::set_zero, xpu>::Launch(
s, row_idx_out.Size(), row_idx);
int num_threads = mxnet_op::get_num_threads<xpu>(data_out.shape_[0]);
size_t seg_len = (data_out.shape_[0] + num_threads - 1) / num_threads;
if (trans_lhs) {
mxnet_op::Kernel<DotCsrTransDnsRspByRowBlocks, xpu>::Launch(s, num_threads,
data_out.dptr<DType>(), row_idx, data_l.dptr<DType>(),
indptr_l.dptr<IType>(), col_idx_l.dptr<CType>(), data_r.dptr<DType>(),
seg_len, lhs.shape()[0], data_out.shape_[0], data_out.shape_[1]);
index_t nnr = 0;
nnr = mxnet::common::ParallelAccumulate(row_idx, ret->shape()[0], nnr);
ret->set_aux_shape(rowsparse::kIdx, mshadow::Shape1(nnr));
ret->set_storage_shape(mshadow::Shape2(nnr, ret->shape()[1]));
if (0 == nnr) return;
mshadow::Tensor<xpu, 2, DType> rsp_data = data_out.FlatTo2D<xpu, DType>(s);
size_t idx = 0;
for (index_t i = 0; i < ret->shape()[0]; ++i) {
if (row_idx[i] > 0) {
row_idx[idx] = i;
mshadow::Copy(rsp_data[idx], rsp_data[i], s);
++idx;
}
}
} else {
LOG(FATAL) << "DotCsrDnsRspImpl has not implemented dot(csr, dns)=rsp yet."
" Only the cpu version of dot(csr.T, dns)=rsp is supported now";
}
} else {
LOG(FATAL) << "DotCsrDnsRspImpl has not implemented GPU version yet.";
}
});
});
});
});
}
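Note how the row-sparse result is produced in two passes: the kernel writes into a dense, full-height buffer and flags nonzero rows in row_idx; the host code then counts the flagged rows (nnr), shrinks the aux and storage shapes to nnr, and compacts the flagged rows to the front while rewriting row_idx to hold the actual row indices. The implementation compacts in place within the same buffer; the NumPy sketch below copies into a fresh one for clarity and uses hypothetical names:

import numpy as np

def compact_row_sparse(dense_out, row_flag):
    """Turn (full-height dense buffer, 0/1 row flags) into (values, row indices)."""
    nnr = int(row_flag.sum())                 # number of nonzero rows
    row_idx = np.zeros(nnr, dtype=np.int64)
    values = np.zeros((nnr, dense_out.shape[1]), dtype=dense_out.dtype)
    idx = 0
    for i in range(dense_out.shape[0]):
        if row_flag[i] > 0:
            row_idx[idx] = i                  # remember the original row index
            values[idx] = dense_out[i]        # move row i up to slot idx
            idx += 1
    return values, row_idx

dense_out = np.array([[0., 0.], [1., 2.], [0., 0.], [9., 12.]])
row_flag = np.array([0, 1, 0, 1])
values, row_idx = compact_row_sparse(dense_out, row_flag)
# values == [[1, 2], [9, 12]], row_idx == [1, 3]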

template<typename xpu>
void DotCsrRspDnsImpl(const OpContext& ctx,
const NDArray& lhs,
@@ -803,10 +923,12 @@ void DotForwardEx(const nnvm::NodeAttrs& attrs,
out_stype == kDefaultStorage) {
TBlob ret = outputs[0].data();
DotCsrRspDnsImpl<xpu>(ctx, inputs[0], inputs[1], req[0], param.transpose_a, &ret);
} else { // TODO(junwu): add fallback
LOG(FATAL) << "Not supported dot operation for lhs.storage_type = "
<< inputs[0].storage_type() << ", rhs.storage_type = " << inputs[1].storage_type()
<< ", out.storage_type = " << outputs[0].storage_type();
} else if (lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage
&& out_stype == kRowSparseStorage) {
NDArray out = outputs[0];
DotCsrDnsRspImpl<xpu>(ctx, inputs[0], inputs[1].data(), req[0], param.transpose_a, &out);
} else {
FCompExFallback<xpu>(attrs, ctx, inputs, req, outputs, DotForward_<xpu>, "DotForward_");
}
}

@@ -823,7 +945,6 @@ void DotBackwardEx(const nnvm::NodeAttrs& attrs,
<< "sparse dot does not support computing the gradient of the csr/lhs";
CHECK_NE(req[1], kWriteInplace) << "DotBackwardEx does not support WriteInplace";

// TODO(junwu): check whether this CHECK is reasonable
const DotParam& param = nnvm::get<DotParam>(attrs.parsed);
CHECK(!param.transpose_b) << "sparse dot only supports dot(A, X) and dot(A.T(), X)";
auto ograd_stype = inputs[0].storage_type();
@@ -836,11 +957,15 @@
// dns, csr, dns => *, dns
DotBackwardCsrDnsDns<xpu>(attrs, ctx, inputs, req, outputs);
} else if (ograd_stype == kDefaultStorage && lhs_stype == kCSRStorage &&
rhs_stype == kRowSparseStorage && outputs[1].storage_type() == kDefaultStorage) {
// dns, csr, rsp => *, dns
DotBackwardCsrRspDns<xpu>(attrs, ctx, inputs, req, outputs);
} else if (ograd_stype == kDefaultStorage && lhs_stype == kCSRStorage &&
rhs_stype == kDefaultStorage && outputs[1].storage_type() == kRowSparseStorage) {
NDArray grad_rhs = outputs[1];
DotCsrDnsRspImpl<xpu>(ctx, inputs[1], inputs[2].data(), req[1], !param.transpose_a, &grad_rhs);
} else {
LOG(FATAL) << "Not supported dot backward for sparse input(s) with sparse gradients";
FCompExFallback<xpu>(attrs, ctx, inputs, req, outputs, DotBackward_<xpu>, "DotBackward_");
}
}

2 changes: 1 addition & 1 deletion tests/python/unittest/test_module.py
@@ -383,7 +383,7 @@ def fm_model(k, feature_dim):
x = mx.symbol.Variable("data", storage_type='csr')
v = mx.symbol.Variable("v", shape=(feature_dim, k), init=norm, storage_type='row_sparse')

w1_weight = mx.symbol.var('w1_weight', shape=(feature_dim, 1), init=norm)
w1_weight = mx.symbol.var('w1_weight', shape=(feature_dim, 1), init=norm, storage_type='row_sparse')
w1 = mx.symbol.dot(x, w1_weight)

v_s = mx.symbol.sum(data=mx.symbol.square(data=v), axis=1)
7 changes: 6 additions & 1 deletion tests/python/unittest/test_sparse_operator.py
@@ -100,14 +100,18 @@ def test_dns_to_csr(dns_in):
test_csr_to_dns((4, 4))
test_dns_to_csr([[0, 1, 0], [0, 2, 0], [3, 0, 0], [0, 0, 4], [5, 6, 0], [0, 0, 7]])


def test_sparse_dot():
def test_dot_csr(lhs_shape, rhs_shape, rhs_stype, trans_lhs):
lhs_dns = rand_ndarray(lhs_shape, 'default')
lhs_nd = mx.nd.cast_storage(lhs_dns, storage_type='csr')
rhs_nd = rand_ndarray(rhs_shape, rhs_stype, density=1)
rhs_dns = rhs_nd if rhs_stype == 'default' else rhs_nd.todense()
out = mx.nd.dot(lhs_nd, rhs_dns, transpose_a=trans_lhs)
assert out.storage_type == 'default'
if trans_lhs:
assert out.storage_type == 'row_sparse'
else:
assert out.storage_type == 'default'
out_expected = mx.nd.dot(lhs_dns, rhs_dns, transpose_a=trans_lhs)
out_np = out_expected.asnumpy()
backward_trans = not trans_lhs
@@ -132,6 +136,7 @@ def test_dot_csr(lhs_shape, rhs_shape, rhs_stype, trans_lhs):
test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), 'row_sparse', False)
test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), 'row_sparse', True)


def test_sparse_embedding():
in_dim = 10
out_dim = 4
