diff --git a/src/common/exec_utils.h b/src/common/exec_utils.h
index 731d03d9be2b..816599b955c1 100644
--- a/src/common/exec_utils.h
+++ b/src/common/exec_utils.h
@@ -98,10 +98,21 @@ inline bool SetupDefaultBlobsOut(const std::vector<NDArray>& src,
     is_default = nd.IsDefaultData();
 #endif
     if (!is_default) {
-      NDArray temp = bufs != nullptr ? bufs->at(i) : NDArray(nd.shape(), nd.ctx(),
-                                                             true, nd.dtype());
 #if MXNET_USE_MKLDNN == 1
+      NDArray temp;
+      if (bufs != nullptr) {
+        temp = bufs->at(i);
+      } else if (kAddTo == req->at(i) && nd.IsMKLDNNData()) {
+        temp = nd.Reorder2Default();
+      } else if (kAddTo == req->at(i)) {
+        temp = nd;
+      } else {
+        temp = NDArray(nd.shape(), nd.ctx(), true, nd.dtype());
+      }
       CHECK(temp.IsDefaultData());
+#else
+      NDArray temp = bufs != nullptr ? bufs->at(i) : NDArray(nd.shape(), nd.ctx(),
+                                                             true, nd.dtype());
 #endif
       temp_src->emplace_back(nd);
       temp_dst->emplace_back(temp);
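A note on the exec_utils.h hunk: for kWriteTo the fallback path may hand the kernel an empty scratch buffer, but kAddTo accumulates into the existing output, so the temporary must start from the output's current values, reordered to the default layout when the array holds an MKLDNN format. A minimal standalone sketch of that branch structure, with Array and PickTemp as simplified stand-ins rather than MXNet's NDArray:

    #include <cassert>
    #include <vector>

    enum OpReqType { kNullOp, kWriteTo, kWriteInplace, kAddTo };

    // Simplified stand-in for NDArray: values plus a layout flag.
    struct Array {
      std::vector<float> data;
      bool default_layout = true;          // models IsDefaultData()
      Array Reorder2Default() const {      // models NDArray::Reorder2Default()
        Array a = *this;
        a.default_layout = true;
        return a;
      }
    };

    // Mirrors the branches added to SetupDefaultBlobsOut: kAddTo must see the
    // existing output values; only plain writes may get empty scratch space.
    Array PickTemp(const Array &out, OpReqType req) {
      if (req == kAddTo && !out.default_layout)
        return out.Reorder2Default();      // keep values, normalize layout
      if (req == kAddTo)
        return out;                        // values already usable as-is
      return Array{std::vector<float>(out.data.size()), true};  // scratch
    }

    int main() {
      Array out{{1.f, 2.f}, /*default_layout=*/false};
      Array temp = PickTemp(out, kAddTo);
      assert(temp.default_layout);
      assert(temp.data[0] == 1.f);         // values survived the reorder
      return 0;
    }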
diff --git a/src/operator/nn/mkldnn/mkldnn_act.cc b/src/operator/nn/mkldnn/mkldnn_act.cc
index b21d1238f7aa..744fed2c299f 100644
--- a/src/operator/nn/mkldnn/mkldnn_act.cc
+++ b/src/operator/nn/mkldnn/mkldnn_act.cc
@@ -168,10 +168,10 @@ void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
   auto input_mem = in_buffer.GetMKLDNNData();
   MKLDNNActForward &fwd = GetActForward(param, ctx, in_buffer, *input_mem);
 
-  auto out_mem = CreateMKLDNNMem(out_data, fwd.fwd_pd.dst_primitive_desc(), req, &in_buffer);
-  fwd.SetNewMem(*input_mem, *out_mem.second);
+  auto out_mem_t = CreateMKLDNNMem(out_data, fwd.fwd_pd.dst_primitive_desc(), req, &in_buffer);
+  fwd.SetNewMem(*input_mem, *out_mem_t.second);
   stream->RegisterPrim(fwd.GetFwd());
-  CommitOutput(out_data, out_mem);
+  CommitOutput(out_data, out_mem_t);
   stream->Submit();
 }

diff --git a/src/operator/nn/mkldnn/mkldnn_base.cc b/src/operator/nn/mkldnn/mkldnn_base.cc
index 2c8dea895823..4e4982e96ee5 100644
--- a/src/operator/nn/mkldnn/mkldnn_base.cc
+++ b/src/operator/nn/mkldnn/mkldnn_base.cc
@@ -22,6 +22,7 @@
 #include <atomic>
 #include "./mkldnn_base-inl.h"
 #include "./mkldnn_ops-inl.h"
+#include "../../../common/exec_utils.h"
 #include "../../operator_common.h"

 namespace mxnet {
@@ -393,18 +394,28 @@ void FallBackCompute(FCompute fn, const nnvm::NodeAttrs &attrs,
   MKLDNNStream::Get()->Submit();

   std::vector<TBlob> out_blobs(outputs.size());
+  std::vector<NDArray> temp_src, temp_dst;
   for (size_t i = 0; i < out_blobs.size(); i++) {
     NDArray output = outputs[i];
     // ensure output does not use mkldnn mem.
     // for inplace, we already converted & copied input above.
-    if ((req[i] == kWriteTo) || (req[i] == kWriteInplace))
+    if ((req[i] == kWriteTo) || (req[i] == kWriteInplace)) {
       const_cast<NDArray &>(output).InvalidateMKLDNNData();
-    else if (req[i] == kAddTo)
-      output = outputs[i].Reorder2Default();
+    } else if (req[i] == kAddTo && output.IsMKLDNNData()) {
+      NDArray temp = outputs[i].Reorder2Default();
+      temp_src.emplace_back(temp);
+      temp_dst.emplace_back(outputs[i]);
+      output = temp;
+    }
     CHECK(output.IsDefaultData());
     out_blobs[i] = output.data();
   }
+  fn(attrs, ctx, in_blobs, req, out_blobs);
+  for (size_t i = 0; i < out_blobs.size(); i++) {
+    if (req[i] == kAddTo && outputs[i].IsMKLDNNData())
+      mxnet::common::CastNonDefaultStorage(temp_src, temp_dst, ctx, false);
+  }
 }

 template<typename DType>
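The FallBackCompute hunk is the heart of the fix: each kAddTo output that holds an MKLDNN layout is reordered into a default-layout temporary, the plain CPU kernel accumulates into that temporary, and CastNonDefaultStorage writes the accumulated values back (the mkldnn_act.cc hunk above is only a rename of the CreateMKLDNNMem result pair). A simplified standalone model of that round-trip, using stand-in types instead of NDArray and TBlob:

    #include <cassert>
    #include <cstddef>
    #include <functional>
    #include <vector>

    // Stand-in for an NDArray that may carry an MKLDNN (blocked) layout.
    struct Array {
      std::vector<float> data;
      bool mkldnn_layout = false;          // models IsMKLDNNData()
    };

    Array Reorder2Default(const Array &a) {  // same values, default layout
      return Array{a.data, false};
    }

    void CopyBack(const Array &src, Array *dst) {  // models the final cast/copy
      dst->data = src.data;
    }

    // The fallback: give the plain CPU kernel a default-layout view of the
    // output, let it accumulate, then restore the MKLDNN-layout array.
    void FallBack(const std::function<void(const Array &, Array *)> &kernel,
                  const Array &in, Array *out) {
      Array temp = out->mkldnn_layout ? Reorder2Default(*out) : *out;
      kernel(in, &temp);                   // kernel performs: temp += f(in)
      if (out->mkldnn_layout)
        CopyBack(temp, out);               // write accumulated values back
      else
        *out = temp;
    }

    int main() {
      auto add_kernel = [](const Array &in, Array *out) {  // kAddTo identity op
        for (std::size_t i = 0; i < in.data.size(); i++)
          out->data[i] += in.data[i];
      };
      Array in{{1.f, 2.f}, false};
      Array out{{10.f, 20.f}, true};
      FallBack(add_kernel, in, &out);
      assert(out.data[0] == 11.f && out.data[1] == 22.f);
      return 0;
    }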
diff --git a/src/operator/nn/mkldnn/mkldnn_copy.cc b/src/operator/nn/mkldnn/mkldnn_copy.cc
index 75e51aff0066..a7c280e1e713 100644
--- a/src/operator/nn/mkldnn/mkldnn_copy.cc
+++ b/src/operator/nn/mkldnn/mkldnn_copy.cc
@@ -44,14 +44,13 @@ void MKLDNNCopy(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
   auto in_mem = data.GetMKLDNNData();
   if (req == kAddTo) {
     TmpMemMgr::Get()->Init(ctx.requested[0]);
-    // We should try and force the output memory has the same format
-    // as the input memory. If not, we'll have to reorder memory.
-    auto out_mem = out_data.GetMKLDNNData(in_mem->get_primitive_desc());
-    if (out_mem == nullptr)
-      out_mem = out_data.GetMKLDNNData();
-    auto sum_res = TmpMemMgr::Get()->Alloc(out_mem->get_primitive_desc());
-    MKLDNNSum(*in_mem, *out_mem, *sum_res);
-    const_cast<NDArray &>(out_data).CopyFrom(*sum_res);
+    // We should try to force the input memory to have the same format
+    // as the output memory. If not, we'll have to reorder memory.
+    auto out_mem = out_data.GetMKLDNNData();
+    in_mem = data.GetMKLDNNData(out_mem->get_primitive_desc());
+    if (in_mem == nullptr)
+      in_mem = data.GetMKLDNNDataReorder(out_mem->get_primitive_desc());
+    MKLDNNSum(*out_mem, *in_mem, *out_mem);
   } else {
     const_cast<NDArray &>(out_data).CopyFrom(*in_mem);
   }

diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc
index 5b89d49f4304..30a0035acb90 100644
--- a/src/operator/tensor/elemwise_unary_op_basic.cc
+++ b/src/operator/tensor/elemwise_unary_op_basic.cc
@@ -171,9 +171,7 @@ static void CopyEx(const nnvm::NodeAttrs& attrs,
     // This happens if inputs are supposed to be in MKLDNN format
     // but MKLDNN doesn't support the data type or the shape. We're
     // forced to convert it to the default format.
-    std::vector<TBlob> in_blobs {inputs[0].data()};
-    std::vector<TBlob> out_blobs {outputs[0].data()};
-    UnaryOp::IdentityCompute<cpu>(attrs, ctx, in_blobs, req, out_blobs);
+    FallBackCompute(UnaryOp::IdentityCompute<cpu>, attrs, ctx, inputs, req, outputs);
     return;
   }
 #endif
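The MKLDNNCopy hunk inverts the old strategy: the input is reordered to the output's format rather than the other way around, and the sum primitive writes directly over its output operand, eliminating the temporary buffer and the trailing CopyFrom. A minimal sketch of the in-place accumulation this relies on, with plain vectors standing in for mkldnn::memory:

    #include <cassert>
    #include <cstddef>
    #include <vector>

    using Mem = std::vector<float>;

    // Models MKLDNNSum(a, b, out): out may alias an input, which is exactly
    // what the new MKLDNNSum(*out_mem, *in_mem, *out_mem) call relies on.
    void Sum(const Mem &a, const Mem &b, Mem *out) {
      for (std::size_t i = 0; i < a.size(); i++)
        (*out)[i] = a[i] + b[i];
    }

    int main() {
      Mem in{1.f, 2.f};
      Mem out{10.f, 20.f};
      Sum(out, in, &out);                  // in place: out += in
      assert(out[0] == 11.f && out[1] == 22.f);
      return 0;
    }

The CopyEx change then routes the non-MKLDNN fallback through the same FallBackCompute helper, so a kAddTo request is honored there as well instead of degenerating into a plain write.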
diff --git a/tests/cpp/operator/mkldnn.cc b/tests/cpp/operator/mkldnn.cc
index 8e01216527c8..9a7d2fccd745 100644
--- a/tests/cpp/operator/mkldnn.cc
+++ b/tests/cpp/operator/mkldnn.cc
@@ -27,6 +27,7 @@
 #include <mkldnn_types.h>
 #include <cmath>
+#include <set>
 #include "gtest/gtest.h"
 #include "mxnet/imperative.h"
 #include "../../src/operator/nn/mkldnn/mkldnn_base-inl.h"
@@ -363,6 +364,7 @@ struct NDArrayAttrs {
 struct OpAttrs {
   nnvm::NodeAttrs attrs;
   std::vector<DispatchMode> dispatches;
+  std::set<OpReqType> requests;
   int num_inputs;
   int num_outputs;
 };
@@ -375,6 +377,9 @@ OpAttrs GetCopyOp() {
   attrs.dispatches.resize(2);
   attrs.dispatches[0] = DispatchMode::kFCompute;
   attrs.dispatches[1] = DispatchMode::kFComputeEx;
+  attrs.requests.insert(OpReqType::kWriteTo);
+  attrs.requests.insert(OpReqType::kWriteInplace);
+  attrs.requests.insert(OpReqType::kAddTo);
   return attrs;
 }

@@ -386,6 +391,9 @@ OpAttrs GetCopyBackwardsOp() {
   attrs.dispatches.resize(2);
   attrs.dispatches[0] = DispatchMode::kFCompute;
   attrs.dispatches[1] = DispatchMode::kFComputeEx;
+  attrs.requests.insert(OpReqType::kWriteTo);
+  attrs.requests.insert(OpReqType::kWriteInplace);
+  attrs.requests.insert(OpReqType::kAddTo);
   return attrs;
 }

@@ -399,6 +407,9 @@ OpAttrs GetReluOp() {
   attrs.dispatches.resize(2);
   attrs.dispatches[0] = DispatchMode::kFCompute;
   attrs.dispatches[1] = DispatchMode::kFComputeEx;
+  attrs.requests.insert(OpReqType::kWriteTo);
+  attrs.requests.insert(OpReqType::kWriteInplace);
+  attrs.requests.insert(OpReqType::kAddTo);
   return attrs;
 }

@@ -412,6 +423,9 @@ OpAttrs GetReluBackwardsOp() {
   attrs.dispatches.resize(2);
   attrs.dispatches[0] = DispatchMode::kFCompute;
   attrs.dispatches[1] = DispatchMode::kFComputeEx;
+  attrs.requests.insert(OpReqType::kWriteTo);
+  attrs.requests.insert(OpReqType::kWriteInplace);
+  attrs.requests.insert(OpReqType::kAddTo);
   return attrs;
 }

@@ -423,6 +437,9 @@ OpAttrs GetSumOp() {
   attrs.dispatches.resize(2);
   attrs.dispatches[0] = DispatchMode::kFCompute;
   attrs.dispatches[1] = DispatchMode::kFComputeEx;
+  attrs.requests.insert(OpReqType::kWriteTo);
+  attrs.requests.insert(OpReqType::kWriteInplace);
+  attrs.requests.insert(OpReqType::kAddTo);
   return attrs;
 }

@@ -434,6 +451,9 @@ OpAttrs GetSumBackwardsOp() {
   attrs.dispatches.resize(2);
   attrs.dispatches[0] = DispatchMode::kFCompute;
   attrs.dispatches[1] = DispatchMode::kFComputeEx;
+  attrs.requests.insert(OpReqType::kWriteTo);
+  attrs.requests.insert(OpReqType::kWriteInplace);
+  attrs.requests.insert(OpReqType::kAddTo);
   return attrs;
 }

@@ -821,6 +841,23 @@ void VerifyConcatResult(const std::vector<NDArray *> &in_arrs,
   }
 }

+void VerifyAddRequest(const std::vector<NDArray*> &in_arrs,
+                      const std::vector<NDArray*> &original_outputs,
+                      const std::vector<NDArray*> &new_outputs,
+                      VerifyFunc verify_fn) {
+  CHECK(original_outputs.size() == new_outputs.size());
+  // Keep each difference in its own NDArray; reusing a single NDArray
+  // would leave every tmp_outputs entry pointing at the last diff.
+  std::vector<NDArray> tmps;
+  std::vector<NDArray*> tmp_outputs;
+  for (size_t i = 0; i < new_outputs.size(); i++)
+    tmps.push_back(new_outputs[i]->Reorder2Default() - original_outputs[i]->Reorder2Default());
+  for (NDArray &tmp : tmps)
+    tmp_outputs.push_back(&tmp);
+  Engine::Get()->WaitForAll();
+  verify_fn(in_arrs, tmp_outputs);
+}
+
 void VerifyConcatBackwardsResult(const std::vector<NDArray *> &in_arrs,
                                  const std::vector<NDArray *> &out_arrs) {
   // in_arrs is the larger array; out_arrs is the smaller one.
@@ -846,15 +883,6 @@ void VerifyConcatBackwardsResult(const std::vector<NDArray *> &in_arrs,
   }
 }

-void VerifyAddRequest(const std::vector<NDArray *> &in_arrs,
-                      const std::vector<NDArray *> &original_outputs,
-                      const std::vector<NDArray *> &new_outputs,
-                      VerifyFunc verify_fn) {
-  NDArray tmp = new_outputs[0]->Reorder2Default() - original_outputs[0]->Reorder2Default();
-  tmp.WaitToRead();
-  verify_fn(in_arrs, {&tmp});
-}
-
 TEST(MKLDNN_NDArray, CopyFrom) {
   TestArrayShapes tas = GetTestArrayShapes();
   std::vector<mkldnn::memory::primitive_desc> pds = tas.pds;
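VerifyAddRequest reduces a kAddTo check to the existing kWriteTo verifiers: if the operator computed r = f(in) and the engine accumulated out_new = out_orig + r, then out_new - out_orig must satisfy the same predicate as a freshly written output. A standalone sketch of that identity with the copy operator as f, where Buf is a stand-in for NDArray:

    #include <cassert>
    #include <cstddef>
    #include <vector>

    using Buf = std::vector<float>;

    // Elementwise difference, modeling
    // new_output->Reorder2Default() - original_output->Reorder2Default().
    Buf Diff(const Buf &a, const Buf &b) {
      Buf d(a.size());
      for (std::size_t i = 0; i < a.size(); i++)
        d[i] = a[i] - b[i];
      return d;
    }

    int main() {
      Buf in{1.f, 2.f, 3.f};
      Buf original{10.f, 20.f, 30.f};
      Buf result = in;                     // f = copy, the simplest op tested
      Buf out = original;
      for (std::size_t i = 0; i < out.size(); i++)
        out[i] += result[i];               // what kAddTo is required to do
      // The kWriteTo verifier for copy checks out == in; apply it to the diff.
      assert(Diff(out, original) == in);
      return 0;
    }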
@@ -879,54 +907,88 @@ void TestOp(const OpAttrs &attrs, VerifyFunc verify_fn) {
   std::vector<NDArray*> inputs(attrs.num_inputs);
   std::vector<NDArray*> outputs(attrs.num_outputs);
   std::vector<OpReqType> req(attrs.num_outputs);
+  std::vector<NDArrayAttrs> in_arrs;
+  std::vector<std::vector<NDArrayAttrs>> out_arrs(attrs.num_outputs);
   std::vector<DispatchMode> dispatches = attrs.dispatches;

   TestArrayShapes tas = GetTestArrayShapes();
   std::vector<mkldnn::memory::primitive_desc> pds = tas.pds;

-  std::vector<NDArrayAttrs> in_arrs = GetTestInputArrays();
-  for (auto &in_arr : in_arrs) {
+  if (attrs.requests.find(OpReqType::kWriteTo) != attrs.requests.end()) {
+    std::vector<NDArrayAttrs> in_arrs = GetTestInputArrays();
+    for (auto &in_arr : in_arrs) {
+      for (auto &dispatch : dispatches) {
+        std::vector<std::vector<NDArrayAttrs>> out_arrs(attrs.num_outputs);
+        for (int i = 0; i < attrs.num_outputs; i++)
+          out_arrs[i] = GetTestOutputArrays(in_arr.arr.shape(), pds);
+        for (int i = 0; i < attrs.num_inputs; i++)
+          inputs[i] = &in_arr.arr;
+        for (size_t output_i = 0; output_i < out_arrs[0].size(); output_i++) {
+          for (int i = 0; i < attrs.num_outputs; i++) {
+            req[i] = kWriteTo;
+            outputs[i] = &out_arrs[i][output_i].arr;
+          }
+          PrintVerifyMsg(in_arr, out_arrs[0][output_i]);
+          Imperative::Get()->InvokeOp(Context(), attrs.attrs, inputs,
+                                      outputs, req, dispatch, mxnet::OpStatePtr());
+          Engine::Get()->WaitForAll();
+          verify_fn(inputs, outputs);
+        }
+      }
+    }
+  }
+
+  if (attrs.requests.find(OpReqType::kWriteInplace) != attrs.requests.end()) {
     for (auto &dispatch : dispatches) {
-      std::vector<std::vector<NDArrayAttrs>> out_arrs(attrs.num_outputs);
-      for (int i = 0; i < attrs.num_outputs; i++)
-        out_arrs[i] = GetTestOutputArrays(in_arr.arr.shape(), pds);
-      for (int i = 0; i < attrs.num_inputs; i++)
-        inputs[i] = &in_arr.arr;
-      for (size_t output_i = 0; output_i < out_arrs[0].size(); output_i++) {
+      in_arrs = GetTestInputArrays();
+      for (auto &arr : in_arrs) {
+        // If the array is a view, we shouldn't write data to it.
+        if (arr.arr.IsView())
+          continue;
+        NDArrayAttrs orig(arr.arr.Copy(arr.arr.ctx()), "InPlace Copy");
+        for (int i = 0; i < attrs.num_inputs; i++)
+          inputs[i] = &arr.arr;
         for (int i = 0; i < attrs.num_outputs; i++) {
-          req[i] = kWriteTo;
-          outputs[i] = &out_arrs[i][output_i].arr;
+          req[i] = kWriteInplace;
+          outputs[i] = &arr.arr;
         }
-        PrintVerifyMsg(in_arr, out_arrs[0][output_i]);
-        Imperative::Get()->InvokeOp(Context(), attrs.attrs, inputs,
-                                    outputs, req, dispatch, mxnet::OpStatePtr());
+        PrintVerifyMsg(orig, arr);
+        Imperative::Get()->InvokeOp(Context(), attrs.attrs, inputs, outputs, req,
+                                    dispatch, mxnet::OpStatePtr());
         Engine::Get()->WaitForAll();
-        verify_fn(inputs, outputs);
+        std::vector<NDArray *> orig_inputs(attrs.num_inputs);
+        for (int i = 0; i < attrs.num_inputs; i++)
+          orig_inputs[i] = &orig.arr;
+        verify_fn(orig_inputs, outputs);
       }
     }
   }

-  for (auto &dispatch : dispatches) {
+  if (attrs.requests.find(OpReqType::kAddTo) != attrs.requests.end()) {
+    std::vector<NDArray*> original_outputs(attrs.num_outputs);
     in_arrs = GetTestInputArrays();
-    for (auto &arr : in_arrs) {
-      // If the array is a view, we shouldn't write data to it.
-      if (arr.arr.IsView())
-        continue;
-      NDArrayAttrs orig(arr.arr.Copy(arr.arr.ctx()), "InPlace Copy");
-      for (int i = 0; i < attrs.num_inputs; i++)
-        inputs[i] = &arr.arr;
-      for (int i = 0; i < attrs.num_outputs; i++) {
-        req[i] = kWriteInplace;
-        outputs[i] = &arr.arr;
+    for (auto &in_arr : in_arrs) {
+      for (auto &dispatch : dispatches) {
+        for (int i = 0; i < attrs.num_outputs; i++)
+          out_arrs[i] = GetTestOutputArrays(in_arr.arr.shape(), pds);
+        for (size_t i = 0; i < attrs.num_inputs; i++)
+          inputs[i] = &in_arr.arr;
+        for (size_t output_i = 0; output_i < out_arrs[0].size(); output_i++) {
+          // Snapshot each output into its own NDArray; a single reused
+          // NDArray would make every original_outputs entry alias it.
+          std::vector<NDArray> tmps(attrs.num_outputs);
+          for (size_t i = 0; i < attrs.num_outputs; i++) {
+            auto out_arr = out_arrs[i][output_i];
+            tmps[i] = out_arr.arr.Copy(out_arr.arr.ctx());
+            original_outputs[i] = &tmps[i];
+            outputs[i] = &out_arrs[i][output_i].arr;
+            req[i] = kAddTo;
+          }
+          PrintVerifyMsg(in_arr, out_arrs[0][output_i]);
+          Imperative::Get()->InvokeOp(Context(), attrs.attrs, inputs,
+                                      outputs, req, dispatch, mxnet::OpStatePtr());
+          Engine::Get()->WaitForAll();
+          VerifyAddRequest(inputs, original_outputs, outputs, verify_fn);
+        }
       }
-      PrintVerifyMsg(orig, arr);
-      Imperative::Get()->InvokeOp(Context(), attrs.attrs, inputs, outputs, req,
-                                  dispatch, mxnet::OpStatePtr());
-      Engine::Get()->WaitForAll();
-      std::vector<NDArray *> orig_inputs(attrs.num_inputs);
-      for (int i = 0; i < attrs.num_inputs; i++)
-        orig_inputs[i] = &orig.arr;
-      verify_fn(orig_inputs, outputs);
     }
   }
 }
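For reference, the kFComputeEx paths exercised by TestOp commit kAddTo results through the CreateMKLDNNMem/CommitOutput pair; OutDataOp and mkldnn_output_t are MXNet names from mkldnn_base-inl.h, while everything else below is a simplified stand-in. Noop means the primitive wrote the output directly, CopyBack copies a temporary over the output, and AddBack, chosen for kAddTo, accumulates the temporary into it:

    #include <cassert>
    #include <cstddef>
    #include <utility>
    #include <vector>

    enum class OutDataOp { Noop, CopyBack, AddBack };
    using Mem = std::vector<float>;
    using Output = std::pair<OutDataOp, Mem *>;  // models mkldnn_output_t

    // Simplified CommitOutput: apply the action chosen at CreateMKLDNNMem time.
    void Commit(Mem *out, const Output &res) {
      switch (res.first) {
        case OutDataOp::Noop:
          break;                           // primitive already wrote *out
        case OutDataOp::CopyBack:
          *out = *res.second;              // kWriteTo through a temporary
          break;
        case OutDataOp::AddBack:
          for (std::size_t i = 0; i < out->size(); i++)
            (*out)[i] += (*res.second)[i]; // kAddTo: out += result
          break;
      }
    }

    int main() {
      Mem out{10.f, 20.f};
      Mem result{1.f, 2.f};
      Commit(&out, {OutDataOp::AddBack, &result});
      assert(out[0] == 11.f && out[1] == 22.f);
      return 0;
    }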