
Improve copy sparse tensors #7003

Merged: 2 commits, Jul 15, 2017
src/ndarray/ndarray.cc: 52 changes (32 additions & 20 deletions)
@@ -398,29 +398,41 @@ void CopyFromToImpl(const NDArray from, NDArray *to, RunContext ctx) {
   // if storage type doesn't match, cast the storage first
   auto from_stype = from.storage_type();
   auto to_stype = to->storage_type();
-  NDArray casted_nd;
-  if (from_stype != to_stype) {
-    TShape shape = from.shape();
-    auto from_ctx = from.ctx();
-    auto s = ctx.get_stream<from_xpu>();
-    // TODO(haibin) inplace conversion
+  CHECK(from_stype == kDefaultStorage
+        || to_stype == kDefaultStorage
+        || from_stype == to_stype)
+    << "Copying ndarray of stype = " << from_stype
+    << " to stype = " << to_stype << " is not supported";
+  const auto from_ctx = from.ctx();
+  const auto to_ctx = to->ctx();
+  auto s = ctx.get_stream<from_xpu>();
+  if (from_ctx == to_ctx && from_stype != to_stype) {
+    // same ctx, different stypes, use cast op directly without copying
+    common::CastStorageDispatch<from_xpu>(s, from, *to);
+  } else {
+    NDArray casted_nd;  // an intermediate result before copying from to to
+    if (from_stype == to_stype) {
+      casted_nd = from;  // same stype, no need to cast from
+    } else {  // different stypes on different ctx needs an temporary casted_nd
+      TShape shape = from.shape();
+      if (to_stype == kDefaultStorage) {
+        casted_nd = NDArray(shape, from_ctx);
+      } else {
+        casted_nd = NDArray(to_stype, shape, from_ctx);
+      }
+      // convert from_nd to the same stype as to_nd
+      common::CastStorageDispatch<from_xpu>(s, from, casted_nd);
+    }
+
     if (to_stype == kDefaultStorage) {
-      casted_nd = NDArray(shape, from_ctx);
+      CopyFromToDnsImpl<from_xpu, to_xpu>(casted_nd, to, ctx);
+    } else if (to_stype == kRowSparseStorage) {
+      CopyFromToRspImpl<from_xpu, to_xpu>(casted_nd, to, ctx);
+    } else if (to_stype == kCSRStorage) {
+      CopyFromToCsrImpl<from_xpu, to_xpu>(casted_nd, to, ctx);
     } else {
-      casted_nd = NDArray(to_stype, shape, from_ctx);
+      LOG(FATAL) << "unknown storage type" << to_stype;
     }
-    common::CastStorageDispatch<from_xpu>(s, from, casted_nd);
-  } else {
-    casted_nd = from;
-  }
-  if (to_stype == kDefaultStorage) {
-    CopyFromToDnsImpl<from_xpu, to_xpu>(casted_nd, to, ctx);
-  } else if (to_stype == kRowSparseStorage) {
-    CopyFromToRspImpl<from_xpu, to_xpu>(casted_nd, to, ctx);
-  } else if (to_stype == kCSRStorage) {
-    CopyFromToCsrImpl<from_xpu, to_xpu>(casted_nd, to, ctx);
-  } else {
-    LOG(FATAL) << "unknown storage type" << to_stype;
   }
   if (is_same<from_xpu, mshadow::gpu>::value || is_same<to_xpu, mshadow::gpu>::value) {
     // Wait GPU kernel to complete
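
To make the new control flow easier to follow, here is a minimal standalone sketch of the dispatch order this hunk introduces. It is not MXNet's actual API: StorageType, Array, CastInto, and the Copy* helpers below are hypothetical stand-ins that only mirror the branch structure of CopyFromToImpl (a guard on supported stype pairs, a fast path that casts directly into the destination when source and destination share a context, and an intermediate casted_nd plus a copy dispatched on the destination stype otherwise).

// Hypothetical stand-ins, not MXNet types: they only illustrate the new branch order.
#include <cassert>
#include <iostream>

enum StorageType { kDefault, kRowSparse, kCSR };

struct Array {
  StorageType stype;
  int device;  // stand-in for the MXNet Context
};

void CastInto(const Array &src, Array *dst) {  // stands in for CastStorageDispatch
  std::cout << "cast " << src.stype << " -> " << dst->stype << "\n";
}
void CopyDense(const Array &src, Array *dst) { std::cout << "dense copy\n"; }
void CopyRsp(const Array &src, Array *dst)   { std::cout << "row-sparse copy\n"; }
void CopyCsr(const Array &src, Array *dst)   { std::cout << "csr copy\n"; }

void CopyFromToSketch(const Array &from, Array *to) {
  // Mirror of the new CHECK: only dense<->sparse or same-stype copies are supported.
  assert(from.stype == kDefault || to->stype == kDefault || from.stype == to->stype);

  if (from.device == to->device && from.stype != to->stype) {
    // Same context, different stypes: cast straight into the destination,
    // skipping the intermediate array the old code always built.
    CastInto(from, to);
    return;
  }

  // Cross-context copy (or same stype): pick or build an intermediate on the
  // source device, casting only when the storage types actually differ.
  Array casted = from;
  if (from.stype != to->stype) {
    casted = Array{to->stype, from.device};
    CastInto(from, &casted);
  }

  // Dispatch the copy on the destination storage type, as the hunk does with
  // CopyFromToDnsImpl / CopyFromToRspImpl / CopyFromToCsrImpl.
  switch (to->stype) {
    case kDefault:   CopyDense(casted, to); break;
    case kRowSparse: CopyRsp(casted, to);   break;
    case kCSR:       CopyCsr(casted, to);   break;
  }
}

int main() {
  Array a{kRowSparse, 0}, b{kDefault, 0};
  CopyFromToSketch(a, &b);  // same device, different stype: direct cast path
  Array c{kCSR, 0}, d{kCSR, 1};
  CopyFromToSketch(c, &d);  // cross-device, same stype: plain copy path
  return 0;
}

In the patch itself the same idea is expressed with common::CastStorageDispatch and the CopyFromToDnsImpl / CopyFromToRspImpl / CopyFromToCsrImpl helpers; the practical gain is that a same-context copy between different storage types no longer allocates a temporary NDArray before writing the destination.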