This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Sparse dot enhancement #6842

Merged: 7 commits, Jun 30, 2017
145 changes: 135 additions & 10 deletions src/operator/tensor/matrix_op-inl.h
@@ -483,7 +483,13 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), 2U);
CHECK_EQ(out_attrs->size(), 1U);
-  out_attrs->at(0) = kDefaultStorage;
+  const DotParam& param = nnvm::get<DotParam>(attrs.parsed);
+  if (param.transpose_a && kCSRStorage == (*in_attrs)[0]
+      && kDefaultStorage == (*in_attrs)[1]) {
+    STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 0, kRowSparseStorage);
+  } else {
+    STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 0, kDefaultStorage);
+  }
return true;
}
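A minimal sketch of what this inference rule means on the Python side, using the same API as this PR's unit tests (shapes and values are hypothetical, for illustration only): dot(csr.T, dns) now infers a row_sparse output, while dot(csr, dns) still infers a dense one.

    import mxnet as mx

    csr = mx.nd.cast_storage(mx.nd.ones((4, 3)), storage_type='csr')

    out = mx.nd.dot(csr, mx.nd.ones((4, 2)), transpose_a=True)  # dot(csr.T, dns)
    assert out.storage_type == 'row_sparse'

    out = mx.nd.dot(csr, mx.nd.ones((3, 2)))                    # dot(csr, dns)
    assert out.storage_type == 'default'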

@@ -493,8 +499,14 @@ inline bool DotBackwardInferStorageType(const nnvm::NodeAttrs& attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), 3U);
CHECK_EQ(out_attrs->size(), 2U);
-  out_attrs->at(0) = kDefaultStorage;
-  out_attrs->at(1) = kDefaultStorage;
+  const DotParam& param = nnvm::get<DotParam>(attrs.parsed);
+  STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 0, kDefaultStorage);
+  if (!param.transpose_a && kDefaultStorage == (*in_attrs)[0]
+      && kCSRStorage == (*in_attrs)[1] && kDefaultStorage == (*in_attrs)[2]) {
+    STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 1, kRowSparseStorage);
+  } else {
+    STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 1, kDefaultStorage);
+  }
return true;
}
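The rule follows from the chain rule: for out = dot(A, X) with A stored as CSR, grad_X = dot(A.T, ograd), and only the columns of A that contain nonzeros produce nonzero rows in that product, so grad_X is naturally row_sparse. A small numpy check of this sparsity pattern (illustrative only):

    import numpy as np

    A = np.array([[0., 2., 0.],
                  [0., 0., 0.]])               # csr lhs; only column 1 is nonzero
    ograd = np.ones((2, 4))                    # incoming gradient of out
    grad_X = A.T @ ograd                       # same pattern as dot(csr.T, dns)
    print(np.flatnonzero(grad_X.any(axis=1)))  # -> [1]: a single nonzero row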

@@ -642,6 +654,45 @@ struct DotCsrTransDnsDnsByRowBlocks {
}
};

/*!
* \brief Kernel of dot(csr.T(), dns) = rsp
* Parallelization by row blocks.
* This kernel fills up the row_idx array
* of the rsp with 1 for nonzero rows and 0
* for zero rows.
* The matrix will be compacted after this kernel call.
*/
struct DotCsrTransDnsRspByRowBlocks {
/*!
* \brief
* \param i the i-th thread
*/
template<typename DType, typename RType, typename IType, typename CType>
MSHADOW_XINLINE static void Map(int i, DType* out, RType* row_idx, const DType* data_l,
const IType* indptr_l, const CType* col_idx_l,
const DType* data_r, const size_t seg_len,
const size_t num_rows_l, const size_t num_rows,
const size_t num_cols) {
const size_t seg_start = i * seg_len;
if (seg_start >= num_rows) return;
const size_t seg_end = (i + 1) * seg_len;
for (size_t j = 0; j < num_rows_l; ++j) {
if (indptr_l[j] == indptr_l[j+1]) continue;
const size_t offset_r = j * num_cols;
for (auto k = indptr_l[j]; k < indptr_l[j+1]; ++k) {
const auto col_idx = col_idx_l[k];
if (col_idx < seg_start || col_idx >= seg_end) continue;
const size_t offset_out = col_idx * num_cols;
row_idx[col_idx] = 1;
const auto val = data_l[k];
for (size_t l = 0; l < num_cols; ++l) {
out[offset_out+l] += data_r[offset_r+l] * val;
}
}
}
}
};
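For reference, a numpy/scipy sketch of the product this kernel computes, with the per-thread seg_start/seg_end column filtering omitted (each thread owns one contiguous block of output rows, i.e. of lhs columns):

    import numpy as np
    import scipy.sparse as sp

    lhs = sp.random(5, 4, density=0.3, format='csr')
    rhs = np.random.rand(5, 3)

    out = np.zeros((4, 3))                    # dense buffer backing the rsp
    row_idx = np.zeros(4, dtype=np.int64)     # 1 marks a nonzero output row
    for j in range(lhs.shape[0]):             # rows of the csr lhs
        for k in range(lhs.indptr[j], lhs.indptr[j + 1]):
            c = lhs.indices[k]                # lhs column = output row
            row_idx[c] = 1
            out[c, :] += rhs[j, :] * lhs.data[k]

    assert np.allclose(out, lhs.T @ rhs)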

template<typename xpu>
void DotCsrDnsDnsImpl(const OpContext& ctx,
const NDArray& lhs,
@@ -702,6 +753,75 @@ void DotCsrDnsDnsImpl(const OpContext& ctx,
});
}

template<typename xpu>
void DotCsrDnsRspImpl(const OpContext& ctx,
const NDArray& lhs,
const TBlob& rhs,
const OpReqType req,
const bool trans_lhs,
NDArray* ret) {
if (kNullOp == req) return;
CHECK_EQ(lhs.storage_type(), kCSRStorage);
CHECK_EQ(ret->storage_type(), kRowSparseStorage);
if (!lhs.storage_initialized()) return;

mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
const TBlob data_l = lhs.data();
const TBlob indptr_l = lhs.aux_data(csr::kIndPtr);
const TBlob col_idx_l = lhs.aux_data(csr::kIdx);
const TBlob& data_r = rhs;

// pre-allocate space for ret using the dense dimension size
ret->CheckAndAlloc({mshadow::Shape1(lhs.shape()[1])});
const TBlob data_out = ret->data();
const TBlob row_idx_out = ret->aux_data(rowsparse::kIdx);

MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, { // data type
MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type
MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type
MSHADOW_IDX_TYPE_SWITCH(row_idx_out.type_flag_, RType, {  // row idx type
if (std::is_same<xpu, cpu>::value) { // cpu parallelization by row blocks
if (kWriteTo == req) {
mxnet_op::Kernel<mxnet_op::set_zero, xpu>::Launch(
s, data_out.Size(), data_out.dptr<DType>());
}
RType* row_idx = row_idx_out.dptr<RType>();
mxnet_op::Kernel<mxnet_op::set_zero, xpu>::Launch(
s, row_idx_out.Size(), row_idx);
int num_threads = mxnet_op::get_num_threads<xpu>(data_out.shape_[0]);
size_t seg_len = (data_out.shape_[0] + num_threads - 1) / num_threads;
if (trans_lhs) {
mxnet_op::Kernel<DotCsrTransDnsRspByRowBlocks, xpu>::Launch(s, num_threads,
data_out.dptr<DType>(), row_idx, data_l.dptr<DType>(),
indptr_l.dptr<IType>(), col_idx_l.dptr<CType>(), data_r.dptr<DType>(),
seg_len, lhs.shape()[0], data_out.shape_[0], data_out.shape_[1]);
index_t nnr = 0;
nnr = mxnet::common::ParallelAccumulate(row_idx, ret->shape()[0], nnr);
ret->set_aux_shape(rowsparse::kIdx, mshadow::Shape1(nnr));
ret->set_storage_shape(mshadow::Shape2(nnr, ret->shape()[1]));
if (0 == nnr) return;
mshadow::Tensor<xpu, 2, DType> rsp_data = data_out.FlatTo2D<xpu, DType>(s);
size_t idx = 0;
for (index_t i = 0; i < ret->shape()[0]; ++i) {
if (row_idx[i] > 0) {
row_idx[idx] = i;
mshadow::Copy(rsp_data[idx], rsp_data[i], s);
++idx;
}
}
} else {
LOG(FATAL) << "DotCsrDnsRspImpl has not implemented dot(csr, dns)=rsp yet."
" Only the cpu version of dot(csr.T, dns)=rsp is supported now";
}
} else {
LOG(FATAL) << "DotCsrDnsRspImpl has not implemented GPU version yet.";
}
});
});
});
});
}
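A short Python sketch of the compaction pass at the end of the trans_lhs branch above: row_idx starts out as 0/1 flags, and nonzero rows are shifted up in place (safe because the destination slot never passes the source row), which turns the flags into the rsp row indices. Names here are illustrative, not part of the codebase:

    import numpy as np

    def compact_rows(data, row_idx):
        idx = 0
        for i in range(data.shape[0]):
            if row_idx[i] > 0:
                row_idx[idx] = i          # 0/1 flag array becomes row indices
                data[idx] = data[i]       # like the mshadow::Copy above
                idx += 1
        return data[:idx], row_idx[:idx]  # nnr rows; storage_shape = (nnr, num_cols)

    dense = np.array([[0., 0.], [1., 2.], [0., 0.], [3., 4.]])
    flags = dense.any(axis=1).astype(np.int64)
    values, indices = compact_rows(dense, flags)
    # values -> [[1., 2.], [3., 4.]], indices -> [1, 3]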

template<typename xpu>
void DotCsrRspDnsImpl(const OpContext& ctx,
const NDArray& lhs,
@@ -803,10 +923,12 @@ void DotForwardEx(const nnvm::NodeAttrs& attrs,
out_stype == kDefaultStorage) {
TBlob ret = outputs[0].data();
DotCsrRspDnsImpl<xpu>(ctx, inputs[0], inputs[1], req[0], param.transpose_a, &ret);
-  } else {  // TODO(junwu): add fallback
-    LOG(FATAL) << "Not supported dot operation for lhs.storage_type = "
-               << inputs[0].storage_type() << ", rhs.storage_type = " << inputs[1].storage_type()
-               << ", out.storage_type = " << outputs[0].storage_type();
+  } else if (lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage
+             && out_stype == kRowSparseStorage) {
+    NDArray out = outputs[0];
+    DotCsrDnsRspImpl<xpu>(ctx, inputs[0], inputs[1].data(), req[0], param.transpose_a, &out);
+  } else {
+    FCompExFallback<xpu>(attrs, ctx, inputs, req, outputs, DotForward_<xpu>, "DotForward_");
}
}

@@ -823,7 +945,6 @@ void DotBackwardEx(const nnvm::NodeAttrs& attrs,
<< "sparse dot does not support computing the gradient of the csr/lhs";
CHECK_NE(req[1], kWriteInplace) << "DotBackwardEx does not support WriteInplace";

-  // TODO(junwu): check whether this CHECK is reasonable
const DotParam& param = nnvm::get<DotParam>(attrs.parsed);
CHECK(!param.transpose_b) << "sparse dot only supports dot(A, X) and dot(A.T(), X)";
auto ograd_stype = inputs[0].storage_type();
Expand All @@ -836,11 +957,15 @@ void DotBackwardEx(const nnvm::NodeAttrs& attrs,
// dns, csr, dns => *, dns
DotBackwardCsrDnsDns<xpu>(attrs, ctx, inputs, req, outputs);
} else if (ograd_stype == kDefaultStorage && lhs_stype == kCSRStorage &&
rhs_stype == kRowSparseStorage && outputs[1].storage_type() == kDefaultStorage) {
// dns, csr, rsp => *, dns
DotBackwardCsrRspDns<xpu>(attrs, ctx, inputs, req, outputs);
+  } else if (ograd_stype == kDefaultStorage && lhs_stype == kCSRStorage &&
+             rhs_stype == kDefaultStorage && outputs[1].storage_type() == kRowSparseStorage) {
+    NDArray grad_rhs = outputs[1];
+    DotCsrDnsRspImpl<xpu>(ctx, inputs[1], inputs[2].data(), req[1], !param.transpose_a, &grad_rhs);
  } else {
-    LOG(FATAL) << "Not supported dot backward for sparse input(s) with sparse gradients";
+    FCompExFallback<xpu>(attrs, ctx, inputs, req, outputs, DotBackward_<xpu>, "DotBackward_");
}
}

2 changes: 1 addition & 1 deletion tests/python/unittest/test_module.py
@@ -383,7 +383,7 @@ def fm_model(k, feature_dim):
x = mx.symbol.Variable("data", storage_type='csr')
v = mx.symbol.Variable("v", shape=(feature_dim, k), init=norm, storage_type='row_sparse')

-    w1_weight = mx.symbol.var('w1_weight', shape=(feature_dim, 1), init=norm)
+    w1_weight = mx.symbol.var('w1_weight', shape=(feature_dim, 1), init=norm, storage_type='row_sparse')
w1 = mx.symbol.dot(x, w1_weight)

v_s = mx.symbol.sum(data=mx.symbol.square(data=v), axis=1)
7 changes: 6 additions & 1 deletion tests/python/unittest/test_sparse_operator.py
@@ -100,14 +100,18 @@ def test_dns_to_csr(dns_in):
test_csr_to_dns((4, 4))
test_dns_to_csr([[0, 1, 0], [0, 2, 0], [3, 0, 0], [0, 0, 4], [5, 6, 0], [0, 0, 7]])


def test_sparse_dot():
def test_dot_csr(lhs_shape, rhs_shape, rhs_stype, trans_lhs):
lhs_dns = rand_ndarray(lhs_shape, 'default')
lhs_nd = mx.nd.cast_storage(lhs_dns, storage_type='csr')
rhs_nd = rand_ndarray(rhs_shape, rhs_stype, density=1)
rhs_dns = rhs_nd if rhs_stype == 'default' else rhs_nd.todense()
out = mx.nd.dot(lhs_nd, rhs_dns, transpose_a=trans_lhs)
-        assert out.storage_type == 'default'
+        if trans_lhs:
+            assert out.storage_type == 'row_sparse'
+        else:
+            assert out.storage_type == 'default'
out_expected = mx.nd.dot(lhs_dns, rhs_dns, transpose_a=trans_lhs)
out_np = out_expected.asnumpy()
backward_trans = not trans_lhs
@@ -132,6 +136,7 @@ def test_dot_csr(lhs_shape, rhs_shape, rhs_stype, trans_lhs):
test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), 'row_sparse', False)
test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), 'row_sparse', True)


def test_sparse_embedding():
in_dim = 10
out_dim = 4