Change idx type switch for aux data #6860

Merged · 2 commits · Jun 29, 2017
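
This PR replaces MSHADOW_INT_TYPE_SWITCH with MSHADOW_IDX_TYPE_SWITCH everywhere an index (aux data) dtype is dispatched: row-sparse row indices and CSR indptr/column indices. The dedicated macro comes from the mshadow submodule, whose pointer this PR also bumps. A minimal sketch of the pattern follows — not the authoritative definition, which lives in mshadow/base.h — assuming, as the sparse NDArray code of this era did, that aux indices are int64 only:

// Sketch of an index-dtype dispatch macro in the style of mshadow's
// *_TYPE_SWITCH family. It binds the typedef `DType` to the C++ type
// matching the runtime dtype flag `type` and expands the body once
// for that case; unsupported index dtypes abort via LOG(FATAL).
#define MSHADOW_IDX_TYPE_SWITCH(type, DType, ...)     \
  switch (type) {                                     \
    case mshadow::kInt64:                             \
      {                                               \
        typedef int64_t DType;                        \
        {__VA_ARGS__}                                 \
      }                                               \
      break;                                          \
    default:                                          \
      LOG(FATAL) << "Unknown idx type enum " << type; \
  }

Compared with MSHADOW_INT_TYPE_SWITCH, which dispatches over the general integer dtypes, the idx switch restricts aux data to the supported index type(s), so an unsupported index dtype fails loudly at the switch instead of deep inside a kernel.
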
mshadow — 2 changes: 1 addition & 1 deletion (submodule pointer update)

src/kvstore/comm.h — 4 changes: 2 additions & 2 deletions
@@ -191,7 +191,7 @@ class CommCPU : public Comm {
     std::vector<bool> skip(num_in, false);
     // the values tensor of the inputs
     MSHADOW_TYPE_SWITCH(out->dtype(), DType, {
-      MSHADOW_INT_TYPE_SWITCH(out->aux_type(kIdx), IType, {
+      MSHADOW_IDX_TYPE_SWITCH(out->aux_type(kIdx), IType, {
         std::vector<Tensor<cpu, 2, DType>> in_vals(num_in);
         std::vector<Tensor<cpu, 1, IType>> in_indices(num_in);
         // offset to the values tensor of all inputs
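
(Note: each *_TYPE_SWITCH macro binds a typedef and instantiates its body per dtype case. For illustration only — assuming out->dtype() is float32 and the aux index type is int64 — the block above reduces to roughly:

  {  // DType bound by MSHADOW_TYPE_SWITCH
    typedef float DType;
    {  // IType bound by MSHADOW_IDX_TYPE_SWITCH
      typedef int64_t IType;
      std::vector<Tensor<cpu, 2, DType>> in_vals(num_in);
      std::vector<Tensor<cpu, 1, IType>> in_indices(num_in);
      // ... remainder of the reduce body ...
    }
  }
)
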
@@ -350,7 +350,7 @@ class CommCPU : public Comm {
       << out->storage_type() << " given)";
 
     MSHADOW_TYPE_SWITCH(out->dtype(), DType, {
-      MSHADOW_INT_TYPE_SWITCH(out->aux_type(kIdx), IType, {
+      MSHADOW_IDX_TYPE_SWITCH(out->aux_type(kIdx), IType, {
         std::vector<IType> uniq_row_idx;
         GetUniqueRspRowIdx(nds, &uniq_row_idx);
         out->CheckAndAlloc({mshadow::Shape1(uniq_row_idx.size())});

src/operator/nn/cast_storage-inl.h — 12 changes: 6 additions & 6 deletions
@@ -51,7 +51,7 @@ inline void CastStorageDnsRspImpl(mshadow::Stream<cpu>* s, const TBlob& dns, NDA
   CHECK_EQ(rsp->storage_type(), kRowSparseStorage);
   CHECK_EQ(dns.shape_, rsp->shape());
   MSHADOW_TYPE_SWITCH(dns.type_flag_, DType, {  // data type
-    MSHADOW_INT_TYPE_SWITCH(rsp->aux_type(rowsparse::kIdx), RType, {  // row idx type
+    MSHADOW_IDX_TYPE_SWITCH(rsp->aux_type(rowsparse::kIdx), RType, {  // row idx type
       const index_t num_rows = dns.shape_[0];
       const index_t num_cols = dns.shape_[1];
       rsp->CheckAndAllocAuxData(rowsparse::kIdx, mshadow::Shape1(num_rows));
@@ -102,7 +102,7 @@ void CastStorageRspDnsImpl(mshadow::Stream<xpu>* s, const NDArray& rsp, TBlob* d
   using namespace mshadow::expr;
   CHECK_EQ(rsp.storage_type(), kRowSparseStorage);
   MSHADOW_TYPE_SWITCH(dns->type_flag_, DType, {
-    MSHADOW_INT_TYPE_SWITCH(rsp.aux_type(rowsparse::kIdx), IType, {
+    MSHADOW_IDX_TYPE_SWITCH(rsp.aux_type(rowsparse::kIdx), IType, {
       // assign zeros
       mxnet_op::Kernel<mxnet_op::set_zero, xpu>::Launch(s, dns->Size(), dns->dptr<DType>());
       if (rsp.storage_initialized()) {
@@ -186,8 +186,8 @@ inline void CastStorageDnsCsrImpl(mshadow::Stream<cpu>* s, const TBlob& dns, NDA
   CHECK_EQ(dns.shape_.ndim(), 2);
   CHECK_EQ(dns.shape_, csr->shape());
   MSHADOW_TYPE_SWITCH(dns.type_flag_, DType, {  // data type
-    MSHADOW_INT_TYPE_SWITCH(csr->aux_type(csr::kIndPtr), IType, {  // indptr type
-      MSHADOW_INT_TYPE_SWITCH(csr->aux_type(csr::kIdx), CType, {  // col idx type
+    MSHADOW_IDX_TYPE_SWITCH(csr->aux_type(csr::kIndPtr), IType, {  // indptr type
+      MSHADOW_IDX_TYPE_SWITCH(csr->aux_type(csr::kIdx), CType, {  // col idx type
         const index_t num_rows = dns.shape_[0];
         const index_t num_cols = dns.shape_[1];
         csr->CheckAndAllocAuxData(csr::kIndPtr, mshadow::Shape1(num_rows+1));
@@ -248,8 +248,8 @@ void CastStorageCsrDnsImpl(mshadow::Stream<xpu>* s, const NDArray& csr, TBlob* d
   CHECK_EQ(dns->shape_.ndim(), 2);
   CHECK_EQ(dns->shape_, csr.shape());
   MSHADOW_TYPE_SWITCH(dns->type_flag_, DType, {  // data type
-    MSHADOW_INT_TYPE_SWITCH(csr.aux_type(csr::kIndPtr), IType, {  // indptr type
-      MSHADOW_INT_TYPE_SWITCH(csr.aux_type(csr::kIdx), CType, {  // col idx type
+    MSHADOW_IDX_TYPE_SWITCH(csr.aux_type(csr::kIndPtr), IType, {  // indptr type
+      MSHADOW_IDX_TYPE_SWITCH(csr.aux_type(csr::kIdx), CType, {  // col idx type
         const index_t num_rows = dns->shape_[0];
         const index_t num_cols = dns->shape_[1];
         DType* dns_data = dns->dptr<DType>();

src/operator/optimizer_op-inl.h — 4 changes: 2 additions & 2 deletions
@@ -129,7 +129,7 @@ inline void SGDUpdateDnsRspImpl(const SGDParam& param,
   CHECK_GT(weight.shape_.Size(), 0);
 
   MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
-    MSHADOW_INT_TYPE_SWITCH(grad.aux_type(rowsparse::kIdx), IType, {
+    MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(rowsparse::kIdx), IType, {
       MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
         auto weight_data = weight.dptr<DType>();
         auto grad_idx = grad.aux_data(rowsparse::kIdx).dptr<IType>();
@@ -364,7 +364,7 @@ inline void SGDMomUpdateDnsRspDnsImpl(const SGDMomParam& param,
   CHECK_GT(mom.shape_.Size(), 0);
 
   MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
-    MSHADOW_INT_TYPE_SWITCH(grad.aux_type(kIdx), IType, {
+    MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(kIdx), IType, {
       MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
         auto weight_data = weight.dptr<DType>();
         auto grad_idx = grad.aux_data(kIdx).dptr<IType>();

src/operator/tensor/indexing_op.h — 4 changes: 2 additions & 2 deletions
@@ -891,7 +891,7 @@ void SparseRetainOpForwardEx(const nnvm::NodeAttrs& attrs,
   using namespace mxnet_op;
   Stream<xpu> *s = ctx.get_stream<xpu>();
   MSHADOW_TYPE_SWITCH(output_data.type_flag_, DType, {  // output data type
-    MSHADOW_INT_TYPE_SWITCH(output_idx.type_flag_, RType, {  // row index data type
+    MSHADOW_IDX_TYPE_SWITCH(output_idx.type_flag_, RType, {  // row index data type
       MSHADOW_TYPE_SWITCH(idx_data.type_flag_, IType, {  // index array data type
         Kernel<set_zero, xpu>::Launch(s, output_data.Size(), output_data.dptr<DType>());
         Kernel<SparseRetainRspForward, xpu>::Launch(s, idx_data.Size(), output_data.dptr<DType>(),
@@ -949,7 +949,7 @@ void SparseRetainOpBackwardEx(const nnvm::NodeAttrs& attrs,
   using namespace mxnet_op;
   Stream<xpu> *s = ctx.get_stream<xpu>();
   MSHADOW_TYPE_SWITCH(out_grad_data.type_flag_, DType, {  // output data type
-    MSHADOW_INT_TYPE_SWITCH(in_grad_idx.type_flag_, RType, {  // row index data type
+    MSHADOW_IDX_TYPE_SWITCH(in_grad_idx.type_flag_, RType, {  // row index data type
       MSHADOW_TYPE_SWITCH(idx_data.type_flag_, IType, {  // index array data type
         MXNET_ASSIGN_REQ_SWITCH(req[sr::kArr], req_type, {
           Kernel<SparseRetainRspBackward<req_type>, xpu>::Launch(

src/operator/tensor/init_op.h — 2 changes: 1 addition & 1 deletion
@@ -138,7 +138,7 @@ inline void FillDnsZerosRspImpl(mshadow::Stream<xpu> *s, NDArray *dst) {
   using namespace mxnet_op;
   CHECK_EQ(dst->storage_type(), kRowSparseStorage);
   MSHADOW_REAL_TYPE_SWITCH(dst->dtype(), DType, {
-    MSHADOW_INT_TYPE_SWITCH(dst->aux_type(kIdx), IType, {
+    MSHADOW_IDX_TYPE_SWITCH(dst->aux_type(kIdx), IType, {
       auto num_rows = dst->shape()[0];
       dst->CheckAndAlloc({Shape1(num_rows)});
       auto idx = dst->aux_data(kIdx).FlatTo1D<xpu, IType>(s);

src/operator/tensor/matrix_op-inl.h — 8 changes: 4 additions & 4 deletions
@@ -661,8 +661,8 @@ void DotCsrDnsDnsImpl(const OpContext& ctx,
   const TBlob data_out = *ret;
 
   MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, {  // data type
-    MSHADOW_INT_TYPE_SWITCH(indptr_l.type_flag_, IType, {  // indptr type
-      MSHADOW_INT_TYPE_SWITCH(col_idx_l.type_flag_, CType, {  // col idx type
+    MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, {  // indptr type
+      MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, {  // col idx type
         if (std::is_same<xpu, cpu>::value) {  // cpu parallelization by row blocks
           if (kWriteTo == req) {
             mxnet_op::Kernel<mxnet_op::set_zero, xpu>::Launch(
@@ -1157,8 +1157,8 @@ void SliceCsrImpl(const SliceParam &param, const OpContext& ctx,
     return;
   }
   // assume idx indptr share the same type
-  MSHADOW_INT_TYPE_SWITCH(in.aux_type(kIndPtr), RType, {
-    MSHADOW_INT_TYPE_SWITCH(in.aux_type(kIdx), IType, {
+  MSHADOW_IDX_TYPE_SWITCH(in.aux_type(kIndPtr), RType, {
+    MSHADOW_IDX_TYPE_SWITCH(in.aux_type(kIdx), IType, {
      MSHADOW_TYPE_SWITCH(in.dtype(), DType, {
       auto in_indptr = in.aux_data(kIndPtr).dptr<RType>();
       auto out_indptr = out.aux_data(kIndPtr).dptr<RType>();