[MXNET-497] fix bugs in MKLDNN operators to handle the kAddTo request (apache#11129)

* fix lint

* requests added to opattr

* comment out addto

* can invalidate kAddTo request mkldnn arrays

* revert adding kAddTo to invalidate

* use copy of output instead of creating new array

* convert output to default if fallback

* do not make copy when init

* copyex fallback copies to old array with kAddTo

* change input mem desc to output mem desc if not equal

* reorder memory in commitoutput

* allocate temp memory

* fix var names

* create helper reorder function to handle diff format/shapes

* fix typos

* fix typos

* remove unused code

* fix param

* fix header files

* force input memory to output

* reorder2default keeps pointer to mkldnn memory

* pass reference

* remove extra lines

* do not get raw mem from ptr

* remove isView check

* fallback writes back to output

* remove redundant line

* remove commented out code

* use fallback in copy (refactor)

* remove unused header

* fix lint

* reorder2default only if mkldnn flag

* only reorder if mkldnn

* does not assume 1 output

* sum compares input and output shape

* compare address and pd in sum

* refactor mkldnnsum

* fix const param

* fix header

* improve control flow when setting output blob

* fix merge

* remove kAddTo comment

* add requests to operators

* fix spacing

* do sum in place

* fix conditionals

* remove redundant reqs

* use wait to read all

* fix lint

* create multiple outputs

* create multiple copies for kAddTo

* retrigger

* retrigger

* retrigger

* retrigger

* another retrigger

* retrigger

* retrigger

* another another retrigger

* fix merge

* retrigger

* add kAddTo to relu op

* retrigger
azai91 authored and eric-haibin-lin committed Jul 8, 2018
1 parent 3c21216 commit 4520454
Showing 6 changed files with 140 additions and 61 deletions.
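
Note on terminology before the diff: kAddTo is the request type that asks an operator to accumulate into its output (out += result) rather than overwrite it, and it is the case this commit fixes. A minimal sketch of the write requests with plain buffers (the enum values are MXNet's OpReqType; everything else is illustrative, not MXNet code):

#include <cstddef>
#include <vector>

enum OpReqType { kNullOp, kWriteTo, kWriteInplace, kAddTo };

// What an operator kernel is expected to do with its output buffer
// under each request type.
void WriteOutput(OpReqType req, const std::vector<float> &result,
                 std::vector<float> *out) {
  switch (req) {
    case kNullOp:
      break;                                  // leave the output untouched
    case kWriteTo:
    case kWriteInplace:
      *out = result;                          // overwrite
      break;
    case kAddTo:
      for (size_t i = 0; i < out->size(); i++)
        (*out)[i] += result[i];               // accumulate: this commit's subject
      break;
  }
}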
15 changes: 13 additions & 2 deletions src/common/exec_utils.h
@@ -98,10 +98,21 @@ inline bool SetupDefaultBlobsOut(const std::vector<NDArray>& src,
     is_default = nd.IsDefaultData();
 #endif
     if (!is_default) {
-      NDArray temp = bufs != nullptr ? bufs->at(i) : NDArray(nd.shape(), nd.ctx(),
-                                                             true, nd.dtype());
+#if MXNET_USE_MKLDNN == 1
+      NDArray temp;
+      if (bufs != nullptr) {
+        temp = bufs->at(i);
+      } else if (kAddTo == req->at(i) && nd.IsMKLDNNData()) {
+        temp = nd.Reorder2Default();
+      } else if (kAddTo == req->at(i)) {
+        temp = nd;
+      } else {
+        temp = NDArray(nd.shape(), nd.ctx(), true, nd.dtype());
+      }
+      CHECK(temp.IsDefaultData());
+#else
+      NDArray temp = bufs != nullptr ? bufs->at(i) : NDArray(nd.shape(), nd.ctx(),
+                                                             true, nd.dtype());
+#endif
       temp_src->emplace_back(nd);
       temp_dst->emplace_back(temp);
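
Why the new branches matter: when the fallback path runs with kAddTo, the temporary blob handed to the CPU kernel must start out holding the output's current values in default layout, whereas a plain write can use a fresh scratch buffer. A stand-in sketch of the selection logic with simplified types (Arr, choose_temp, and the mkldnn_layout flag are illustrative, not MXNet code):

#include <cassert>
#include <cstddef>
#include <vector>

enum Req { kWriteTo, kWriteInplace, kAddTo };

struct Arr {
  std::vector<float> data;
  bool mkldnn_layout = false;
};

// Mirrors the branch order in the hunk above.
Arr choose_temp(Req req, const Arr &nd, const std::vector<Arr> *bufs, size_t i) {
  if (bufs != nullptr)
    return (*bufs)[i];                         // caller-provided buffer wins
  if (req == kAddTo && nd.mkldnn_layout)
    return Arr{nd.data, false};                // ~ nd.Reorder2Default(): keep values
  if (req == kAddTo)
    return nd;                                 // already default: use the array itself
  return Arr{std::vector<float>(nd.data.size()), false};  // overwrite: empty buffer is fine
}

int main() {
  Arr out{{1.f, 2.f}, true};
  Arr temp = choose_temp(kAddTo, out, nullptr, 0);
  assert(temp.data[0] == 1.f && !temp.mkldnn_layout);  // values survive the reorder
  return 0;
}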
6 changes: 3 additions & 3 deletions src/operator/nn/mkldnn/mkldnn_act.cc
@@ -168,10 +168,10 @@ void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,

   auto input_mem = in_buffer.GetMKLDNNData();
   MKLDNNActForward &fwd = GetActForward(param, ctx, in_buffer, *input_mem);
-  auto out_mem = CreateMKLDNNMem(out_data, fwd.fwd_pd.dst_primitive_desc(), req, &in_buffer);
-  fwd.SetNewMem(*input_mem, *out_mem.second);
+  auto out_mem_t = CreateMKLDNNMem(out_data, fwd.fwd_pd.dst_primitive_desc(), req, &in_buffer);
+  fwd.SetNewMem(*input_mem, *out_mem_t.second);
   stream->RegisterPrim(fwd.GetFwd());
-  CommitOutput(out_data, out_mem);
+  CommitOutput(out_data, out_mem_t);
   stream->Submit();
 }

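
For readers unfamiliar with the helper pair used here: CreateMKLDNNMem decides which memory the primitive writes into, and CommitOutput performs any deferred copy-back or accumulation afterwards. A toy model of that contract with plain buffers (the OutDataOp/AddBack names follow mkldnn_base-inl.h of this era; treat the exact signatures as assumptions):

#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

enum OpReqType { kNullOp, kWriteTo, kWriteInplace, kAddTo };
enum OutDataOp { Noop, CopyBack, AddBack };
using Buffer = std::vector<float>;
using mkldnn_output_t = std::pair<OutDataOp, Buffer *>;

// Under kAddTo the primitive must not clobber the output, so it is handed
// a scratch buffer and the pair is tagged AddBack.
mkldnn_output_t CreateMem(OpReqType req, Buffer *out, Buffer *scratch) {
  if (req == kAddTo) return {AddBack, scratch};
  return {Noop, out};  // write straight into the output
}

void Commit(Buffer *out, const mkldnn_output_t &o) {
  if (o.first == AddBack) {
    for (size_t i = 0; i < out->size(); i++)
      (*out)[i] += (*o.second)[i];  // deferred accumulation into the output
  } else if (o.first == CopyBack) {
    *out = *o.second;
  }
}

int main() {
  Buffer out = {1.f, 1.f}, scratch(2);
  mkldnn_output_t o = CreateMem(kAddTo, &out, &scratch);
  *o.second = {5.f, 6.f};  // the "primitive" writes its result
  Commit(&out, o);
  assert(out[0] == 6.f && out[1] == 7.f);
  return 0;
}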
17 changes: 14 additions & 3 deletions src/operator/nn/mkldnn/mkldnn_base.cc
@@ -22,6 +22,7 @@
 #include <atomic>
 #include "./mkldnn_base-inl.h"
 #include "./mkldnn_ops-inl.h"
+#include "../../../common/exec_utils.h"
 #include "../../operator_common.h"

 namespace mxnet {
@@ -393,18 +394,28 @@ void FallBackCompute(FCompute fn, const nnvm::NodeAttrs &attrs,
   MKLDNNStream::Get()->Submit();

   std::vector<TBlob> out_blobs(outputs.size());
+  std::vector<NDArray> temp_src, temp_dst;
   for (size_t i = 0; i < out_blobs.size(); i++) {
     NDArray output = outputs[i];
     // ensure output does not use mkldnn mem.
     // for inplace, we already converted & copied input above.
-    if ((req[i] == kWriteTo) || (req[i] == kWriteInplace))
+    if ((req[i] == kWriteTo) || (req[i] == kWriteInplace)) {
       const_cast<NDArray &>(output).InvalidateMKLDNNData();
-    else if (req[i] == kAddTo)
-      output = outputs[i].Reorder2Default();
+    } else if (req[i] == kAddTo && output.IsMKLDNNData()) {
+      NDArray temp = outputs[i].Reorder2Default();
+      temp_src.emplace_back(temp);
+      temp_dst.emplace_back(outputs[i]);
+      output = temp;
+    }
+    CHECK(output.IsDefaultData());
     out_blobs[i] = output.data();
   }

   fn(attrs, ctx, in_blobs, req, out_blobs);
+  for (size_t i = 0; i < out_blobs.size(); i++) {
+    if (req[i] == kAddTo && outputs[i].IsMKLDNNData())
+      mxnet::common::CastNonDefaultStorage(temp_src, temp_dst, ctx, false);
+  }
 }

 template<typename DType>
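
The shape of this fix: a fallback CPU kernel can only honor kAddTo if the blob it receives already contains the output's current values, so an MKLDNN-layout output is reordered into a default-layout temporary first and cast back once the kernel has run. The sequence in miniature, with plain buffers (fallback_addto and the mkldnn_layout flag are illustrative stand-ins):

#include <cassert>
#include <cstddef>
#include <vector>

struct Out {
  std::vector<float> data;
  bool mkldnn_layout = false;
};

void fallback_addto(Out *out, const std::vector<float> &result) {
  if (!out->mkldnn_layout) {
    for (size_t i = 0; i < out->data.size(); i++)
      out->data[i] += result[i];        // default layout: accumulate in place
    return;
  }
  std::vector<float> temp = out->data;  // ~ outputs[i].Reorder2Default()
  for (size_t i = 0; i < temp.size(); i++)
    temp[i] += result[i];               // the CPU kernel honors kAddTo
  out->data = temp;                     // ~ CastNonDefaultStorage write-back
}

int main() {
  Out out{{1.f, 2.f}, true};
  fallback_addto(&out, {10.f, 10.f});
  assert(out.data[0] == 11.f && out.data[1] == 12.f);
  return 0;
}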
15 changes: 7 additions & 8 deletions src/operator/nn/mkldnn/mkldnn_copy.cc
@@ -44,14 +44,13 @@ void MKLDNNCopy(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
   auto in_mem = data.GetMKLDNNData();
   if (req == kAddTo) {
     TmpMemMgr::Get()->Init(ctx.requested[0]);
-    // We should try and force the output memory has the same format
-    // as the input memory. If not, we'll have to reorder memory.
-    auto out_mem = out_data.GetMKLDNNData(in_mem->get_primitive_desc());
-    if (out_mem == nullptr)
-      out_mem = out_data.GetMKLDNNData();
-    auto sum_res = TmpMemMgr::Get()->Alloc(out_mem->get_primitive_desc());
-    MKLDNNSum(*in_mem, *out_mem, *sum_res);
-    const_cast<NDArray &>(out_data).CopyFrom(*sum_res);
+    // We should try to force the input memory to have the same format
+    // as the output memory. If not, we'll have to reorder memory.
+    auto out_mem = out_data.GetMKLDNNData();
+    in_mem = data.GetMKLDNNData(out_mem->get_primitive_desc());
+    if (in_mem == nullptr)
+      in_mem = data.GetMKLDNNDataReorder(out_mem->get_primitive_desc());
+    MKLDNNSum(*out_mem, *in_mem, *out_mem);
   } else {
     const_cast<NDArray &>(out_data).CopyFrom(*in_mem);
   }
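
The refactor drops the old allocate-a-temp-then-copy-back dance: the input is brought into the output's format, and the output serves as both an operand and the destination of MKLDNNSum, so the accumulation happens in place. A toy model of just that aliasing (layout and reorder details elided; Sum is a stand-in, not the real primitive):

#include <cassert>
#include <cstddef>
#include <vector>

// out may alias a — MKLDNNSum(*out_mem, *in_mem, *out_mem) above relies on this.
void Sum(const std::vector<float> &a, const std::vector<float> &b,
         std::vector<float> *out) {
  for (size_t i = 0; i < out->size(); i++)
    (*out)[i] = a[i] + b[i];
}

int main() {
  std::vector<float> out = {1.f, 2.f};   // existing output, the kAddTo target
  std::vector<float> in = {10.f, 20.f};  // incoming data to accumulate
  Sum(out, in, &out);                    // in place: no temp, no copy-back
  assert(out[0] == 11.f && out[1] == 22.f);
  return 0;
}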
4 changes: 1 addition & 3 deletions src/operator/tensor/elemwise_unary_op_basic.cc
@@ -168,9 +168,7 @@ static void CopyEx(const nnvm::NodeAttrs& attrs,
     // This happens if inputs are supposed to be in MKLDNN format
     // but MKLDNN doesn't support the data type or the shape. We're
     // forced to convert it to the default format.
-    std::vector<TBlob> in_blobs {inputs[0].data()};
-    std::vector<TBlob> out_blobs {outputs[0].data()};
-    UnaryOp::IdentityCompute<cpu>(attrs, ctx, in_blobs, req, out_blobs);
+    FallBackCompute(UnaryOp::IdentityCompute<cpu>, attrs, ctx, inputs, req, outputs);
     return;
   }
 #endif
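
The old code grabbed inputs[0].data() and outputs[0].data() directly, which for MKLDNN-formatted arrays can expose memory in a non-default layout; routing through FallBackCompute reuses the conversion and kAddTo write-back logic patched above. What the identity kernel itself must do per request, sketched with plain buffers (IdentityKernel is illustrative, not UnaryOp::IdentityCompute):

#include <cassert>
#include <cstddef>
#include <vector>

enum OpReqType { kNullOp, kWriteTo, kWriteInplace, kAddTo };

void IdentityKernel(const std::vector<float> &in, OpReqType req,
                    std::vector<float> *out) {
  if (req == kAddTo) {
    for (size_t i = 0; i < out->size(); i++)
      (*out)[i] += in[i];  // copy-with-accumulate
  } else if (req != kNullOp) {
    *out = in;             // plain copy
  }
}

int main() {
  std::vector<float> in = {3.f}, out = {4.f};
  IdentityKernel(in, kAddTo, &out);
  assert(out[0] == 7.f);
  return 0;
}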
144 changes: 102 additions & 42 deletions tests/cpp/operator/mkldnn.cc
@@ -27,6 +27,7 @@

 #include <cmath>
 #include <climits>
+#include <set>
 #include "gtest/gtest.h"
 #include "mxnet/imperative.h"
 #include "../../src/operator/nn/mkldnn/mkldnn_base-inl.h"
@@ -363,6 +364,7 @@ struct NDArrayAttrs {
 struct OpAttrs {
   nnvm::NodeAttrs attrs;
   std::vector<DispatchMode> dispatches;
+  std::set<OpReqType> requests;
   int num_inputs;
   int num_outputs;
 };
@@ -375,6 +377,9 @@ OpAttrs GetCopyOp() {
   attrs.dispatches.resize(2);
   attrs.dispatches[0] = DispatchMode::kFCompute;
   attrs.dispatches[1] = DispatchMode::kFComputeEx;
+  attrs.requests.insert(OpReqType::kWriteTo);
+  attrs.requests.insert(OpReqType::kWriteInplace);
+  attrs.requests.insert(OpReqType::kAddTo);
   return attrs;
 }

@@ -386,6 +391,9 @@ OpAttrs GetCopyBackwardsOp() {
   attrs.dispatches.resize(2);
   attrs.dispatches[0] = DispatchMode::kFCompute;
   attrs.dispatches[1] = DispatchMode::kFComputeEx;
+  attrs.requests.insert(OpReqType::kWriteTo);
+  attrs.requests.insert(OpReqType::kWriteInplace);
+  attrs.requests.insert(OpReqType::kAddTo);
   return attrs;
 }

@@ -399,6 +407,9 @@ OpAttrs GetReluOp() {
   attrs.dispatches.resize(2);
   attrs.dispatches[0] = DispatchMode::kFCompute;
   attrs.dispatches[1] = DispatchMode::kFComputeEx;
+  attrs.requests.insert(OpReqType::kWriteTo);
+  attrs.requests.insert(OpReqType::kWriteInplace);
+  attrs.requests.insert(OpReqType::kAddTo);
   return attrs;
 }

@@ -412,6 +423,9 @@ OpAttrs GetReluBackwardsOp() {
   attrs.dispatches.resize(2);
   attrs.dispatches[0] = DispatchMode::kFCompute;
   attrs.dispatches[1] = DispatchMode::kFComputeEx;
+  attrs.requests.insert(OpReqType::kWriteTo);
+  attrs.requests.insert(OpReqType::kWriteInplace);
+  attrs.requests.insert(OpReqType::kAddTo);
   return attrs;
 }

@@ -423,6 +437,9 @@ OpAttrs GetSumOp() {
   attrs.dispatches.resize(2);
   attrs.dispatches[0] = DispatchMode::kFCompute;
   attrs.dispatches[1] = DispatchMode::kFComputeEx;
+  attrs.requests.insert(OpReqType::kWriteTo);
+  attrs.requests.insert(OpReqType::kWriteInplace);
+  attrs.requests.insert(OpReqType::kAddTo);
   return attrs;
 }

@@ -434,6 +451,9 @@ OpAttrs GetSumBackwardsOp() {
   attrs.dispatches.resize(2);
   attrs.dispatches[0] = DispatchMode::kFCompute;
   attrs.dispatches[1] = DispatchMode::kFComputeEx;
+  attrs.requests.insert(OpReqType::kWriteTo);
+  attrs.requests.insert(OpReqType::kWriteInplace);
+  attrs.requests.insert(OpReqType::kAddTo);
   return attrs;
 }

@@ -821,6 +841,21 @@ void VerifyConcatResult(const std::vector<NDArray *> &in_arrs,
   }
 }

+void VerifyAddRequest(const std::vector<NDArray*> &in_arrs,
+                      const std::vector<NDArray*> &original_outputs,
+                      const std::vector<NDArray*> &new_outputs,
+                      VerifyFunc verify_fn) {
+  CHECK(original_outputs.size() == new_outputs.size());
+  std::vector<NDArray*> tmp_outputs;
+  NDArray tmp;
+  for (size_t i = 0; i < new_outputs.size(); i++) {
+    tmp = new_outputs[i]->Reorder2Default() - original_outputs[i]->Reorder2Default();
+    tmp_outputs.push_back(&tmp);
+  }
+  Engine::Get()->WaitForAll();
+  verify_fn(in_arrs, tmp_outputs);
+}
+
 void VerifyConcatBackwardsResult(const std::vector<NDArray *> &in_arrs,
                                  const std::vector<NDArray *> &out_arrs) {
   // in_arrs is larger array, out_arr is smaller
@@ -846,15 +881,6 @@ void VerifyConcatBackwardsResult(const std::vector<NDArray *> &in_arrs,
   }
 }

-void VerifyAddRequest(const std::vector<NDArray*> &in_arrs,
-                      const std::vector<NDArray*> &original_outputs,
-                      const std::vector<NDArray*> &new_outputs,
-                      VerifyFunc verify_fn) {
-  NDArray tmp = new_outputs[0]->Reorder2Default() - original_outputs[0]->Reorder2Default();
-  tmp.WaitToRead();
-  verify_fn(in_arrs, {&tmp});
-}
-
 TEST(MKLDNN_NDArray, CopyFrom) {
   TestArrayShapes tas = GetTestArrayShapes();
   std::vector<mkldnn::memory::primitive_desc> pds = tas.pds;
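
The trick in the new VerifyAddRequest: under kAddTo, new_output = original_output + op(inputs), so subtracting the saved original recovers op(inputs) and the existing verify_fn can be reused unchanged. The arithmetic in miniature (plain buffers, illustrative only):

#include <cassert>
#include <cstddef>
#include <vector>

int main() {
  std::vector<float> original = {1.f, 2.f};     // output captured before the op runs
  std::vector<float> op_result = {10.f, 20.f};  // what the op computes
  std::vector<float> new_output = original;
  for (size_t i = 0; i < new_output.size(); i++)
    new_output[i] += op_result[i];              // kAddTo semantics

  // tmp = new_outputs[i]->Reorder2Default() - original_outputs[i]->Reorder2Default()
  for (size_t i = 0; i < new_output.size(); i++)
    assert(new_output[i] - original[i] == op_result[i]);  // op(inputs) recovered
  return 0;
}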
@@ -879,54 +905,88 @@ void TestOp(const OpAttrs &attrs, VerifyFunc verify_fn) {
   std::vector<NDArray*> inputs(attrs.num_inputs);
   std::vector<NDArray*> outputs(attrs.num_outputs);
   std::vector<OpReqType> req(attrs.num_outputs);
+  std::vector<NDArrayAttrs> in_arrs;
+  std::vector<std::vector<NDArrayAttrs>> out_arrs(attrs.num_outputs);
   std::vector<DispatchMode> dispatches = attrs.dispatches;

   TestArrayShapes tas = GetTestArrayShapes();
   std::vector<mkldnn::memory::primitive_desc> pds = tas.pds;

-  std::vector<NDArrayAttrs> in_arrs = GetTestInputArrays();
-  for (auto &in_arr : in_arrs) {
-    for (auto &dispatch : dispatches) {
-      std::vector<std::vector<NDArrayAttrs>> out_arrs(attrs.num_outputs);
-      for (int i = 0; i < attrs.num_outputs; i++)
-        out_arrs[i] = GetTestOutputArrays(in_arr.arr.shape(), pds);
-      for (int i = 0; i < attrs.num_inputs; i++)
-        inputs[i] = &in_arr.arr;
-      for (size_t output_i = 0; output_i < out_arrs[0].size(); output_i++) {
-        for (int i = 0; i < attrs.num_outputs; i++) {
-          req[i] = kWriteTo;
-          outputs[i] = &out_arrs[i][output_i].arr;
-        }
-        PrintVerifyMsg(in_arr, out_arrs[0][output_i]);
-        Imperative::Get()->InvokeOp(Context(), attrs.attrs, inputs,
-                                    outputs, req, dispatch, mxnet::OpStatePtr());
-        Engine::Get()->WaitForAll();
-        verify_fn(inputs, outputs);
+  if (attrs.requests.find(OpReqType::kWriteTo) != attrs.requests.end()) {
+    std::vector<NDArrayAttrs> in_arrs = GetTestInputArrays();
+    for (auto &in_arr : in_arrs) {
+      for (auto &dispatch : dispatches) {
+        std::vector<std::vector<NDArrayAttrs>> out_arrs(attrs.num_outputs);
+        for (int i = 0; i < attrs.num_outputs; i++)
+          out_arrs[i] = GetTestOutputArrays(in_arr.arr.shape(), pds);
+        for (int i = 0; i < attrs.num_inputs; i++)
+          inputs[i] = &in_arr.arr;
+        for (size_t output_i = 0; output_i < out_arrs[0].size(); output_i++) {
+          for (int i = 0; i < attrs.num_outputs; i++) {
+            req[i] = kWriteTo;
+            outputs[i] = &out_arrs[i][output_i].arr;
+          }
+          PrintVerifyMsg(in_arr, out_arrs[0][output_i]);
+          Imperative::Get()->InvokeOp(Context(), attrs.attrs, inputs,
+                                      outputs, req, dispatch, mxnet::OpStatePtr());
+          Engine::Get()->WaitForAll();
+          verify_fn(inputs, outputs);
+        }
       }
     }
   }

-  for (auto &dispatch : dispatches) {
-    in_arrs = GetTestInputArrays();
-    for (auto &arr : in_arrs) {
-      // If the array is a view, we shouldn't write data to it.
-      if (arr.arr.IsView())
-        continue;
-      NDArrayAttrs orig(arr.arr.Copy(arr.arr.ctx()), "InPlace Copy");
-      for (int i = 0; i < attrs.num_inputs; i++)
-        inputs[i] = &arr.arr;
-      for (int i = 0; i < attrs.num_outputs; i++) {
-        req[i] = kWriteInplace;
-        outputs[i] = &arr.arr;
-      }
-      PrintVerifyMsg(orig, arr);
-      Imperative::Get()->InvokeOp(Context(), attrs.attrs, inputs, outputs, req,
-                                  dispatch, mxnet::OpStatePtr());
-      Engine::Get()->WaitForAll();
-      std::vector<NDArray *> orig_inputs(attrs.num_inputs);
-      for (int i = 0; i < attrs.num_inputs; i++)
-        orig_inputs[i] = &orig.arr;
-      verify_fn(orig_inputs, outputs);
+  if (attrs.requests.find(OpReqType::kWriteInplace) != attrs.requests.end()) {
+    for (auto &dispatch : dispatches) {
+      in_arrs = GetTestInputArrays();
+      for (auto &arr : in_arrs) {
+        // If the array is a view, we shouldn't write data to it.
+        if (arr.arr.IsView())
+          continue;
+        NDArrayAttrs orig(arr.arr.Copy(arr.arr.ctx()), "InPlace Copy");
+        for (int i = 0; i < attrs.num_inputs; i++)
+          inputs[i] = &arr.arr;
+        for (int i = 0; i < attrs.num_outputs; i++) {
+          req[i] = kWriteInplace;
+          outputs[i] = &arr.arr;
+        }
+        PrintVerifyMsg(orig, arr);
+        Imperative::Get()->InvokeOp(Context(), attrs.attrs, inputs, outputs, req,
+                                    dispatch, mxnet::OpStatePtr());
+        Engine::Get()->WaitForAll();
+        std::vector<NDArray *> orig_inputs(attrs.num_inputs);
+        for (int i = 0; i < attrs.num_inputs; i++)
+          orig_inputs[i] = &orig.arr;
+        verify_fn(orig_inputs, outputs);
+      }
+    }
+  }
+
+  if (attrs.requests.find(OpReqType::kAddTo) != attrs.requests.end()) {
+    std::vector<NDArray*> original_outputs(attrs.num_outputs);
+    in_arrs = GetTestInputArrays();
+    for (auto &in_arr : in_arrs) {
+      for (auto &dispatch : dispatches) {
+        for (int i = 0; i < attrs.num_outputs; i++)
+          out_arrs[i] = GetTestOutputArrays(in_arr.arr.shape(), pds);
+        for (size_t i = 0; i < attrs.num_inputs; i++)
+          inputs[i] = &in_arr.arr;
+        for (size_t output_i = 0; output_i < out_arrs[0].size(); output_i++) {
+          NDArray tmp;
+          for (size_t i = 0; i < attrs.num_outputs; i++) {
+            auto out_arr = out_arrs[i][output_i];
+            tmp = out_arr.arr.Copy(out_arr.arr.ctx());
+            original_outputs[i] = &tmp;
+            outputs[i] = &out_arrs[i][output_i].arr;
+            req[i] = kAddTo;
+          }
+          PrintVerifyMsg(in_arr, out_arrs[0][output_i]);
+          Imperative::Get()->InvokeOp(Context(), attrs.attrs, inputs,
+                                      outputs, req, dispatch, mxnet::OpStatePtr());
+          Engine::Get()->WaitForAll();
+          VerifyAddRequest(inputs, original_outputs, outputs, verify_fn);
+        }
+      }
     }
   }
 }
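
A closing note on the restructured TestOp: the kWriteInplace phase aliases the output onto the input, so the only trustworthy reference is the copy taken before the op runs (hence the orig "InPlace Copy" bookkeeping and the IsView() skip), and the kAddTo phase likewise snapshots each output before invoking. The in-place pattern in miniature, with a relu standing in for the op under test (all names illustrative):

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

int main() {
  std::vector<float> arr = {-1.f, 2.f};  // serves as both input and output
  std::vector<float> orig = arr;         // ~ arr.arr.Copy(arr.arr.ctx())

  for (float &v : arr)
    v = std::max(v, 0.f);                // relu writing over its own input

  // verify_fn(orig_inputs, outputs): compare against the saved copy,
  // because the live input was just overwritten.
  for (size_t i = 0; i < arr.size(); i++)
    assert(arr[i] == std::max(orig[i], 0.f));
  return 0;
}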
