This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-497] fix bugs in MKLDNN operators to handle the kAddTo request #11129

Merged 69 commits on Jul 8, 2018
Commits (69), showing changes from all commits:
31fdc8b  fix lint  (azai91, Jun 2, 2018)
b644e02  requests added to opattr  (azai91, Jun 4, 2018)
612f64f  comment out addto  (azai91, Jun 5, 2018)
2c4b41e  can invalidate kAddTo request mkldarrays  (azai91, Jun 6, 2018)
2dc646a  revert adding kAddTo to invalidate  (azai91, Jun 6, 2018)
5278ef9  use copy of output instead of creating new array  (azai91, Jun 6, 2018)
3adcd8d  convert output to default if fallback  (azai91, Jun 6, 2018)
2489b86  do not make copy when init  (azai91, Jun 6, 2018)
c7e64f3  copyex fallback copies to old array with kAddTo  (azai91, Jun 6, 2018)
67001ce  change input mem desc to output mem desc if not equal  (azai91, Jun 6, 2018)
5a75f53  reorder memory in commitoutput  (azai91, Jun 6, 2018)
f5b63fc  allocate temp memory  (azai91, Jun 7, 2018)
4d52987  fix var names  (azai91, Jun 7, 2018)
6b62e97  create helper reorder function to handle diff format/shapes  (azai91, Jun 7, 2018)
9da3655  fix typos  (azai91, Jun 7, 2018)
c0c38ca  fix typos  (azai91, Jun 7, 2018)
2338046  remove unused code  (azai91, Jun 7, 2018)
f974c3c  fix param  (azai91, Jun 7, 2018)
918a864  fix header files  (azai91, Jun 7, 2018)
50fc6ca  force input memory to output  (azai91, Jun 7, 2018)
a9915be  reorder2default keeps pointer to mkldnn memory  (azai91, Jun 7, 2018)
630c091  pass reference  (azai91, Jun 7, 2018)
aa6c406  remove extra lines  (azai91, Jun 7, 2018)
75c5160  do not get raw mem from ptr  (azai91, Jun 8, 2018)
f65ea9c  remove isView check  (azai91, Jun 8, 2018)
3483f28  fallback writes back to output  (azai91, Jun 8, 2018)
0428e0f  remove redundant line  (azai91, Jun 8, 2018)
1cdd60c  remove commented out code  (azai91, Jun 8, 2018)
c9e8f85  use fallback in copy (refactor)  (azai91, Jun 8, 2018)
996d0ef  remove unused header  (azai91, Jun 8, 2018)
4532209  fix lint  (azai91, Jun 8, 2018)
410c491  reorder2default only if mkldnn flag  (azai91, Jun 11, 2018)
2efdc3b  only reorder if mkldnn  (azai91, Jun 11, 2018)
dc3cd8d  does not assume 1 output  (azai91, Jun 12, 2018)
ad66611  sum compares input and output shape  (azai91, Jun 13, 2018)
860fa21  compare address and pd in sum  (azai91, Jun 13, 2018)
a727eea  refactor mkldnnsum  (azai91, Jun 13, 2018)
c76aee3  fix const param  (azai91, Jun 13, 2018)
64422aa  fix header  (azai91, Jun 13, 2018)
ac2b3a1  Merge branch 'master' into test-kAddTo  (azai91, Jun 25, 2018)
bb10946  improve control flow when setting output blob  (azai91, Jun 25, 2018)
ad31578  fix merge  (azai91, Jun 25, 2018)
0e03c96  remove kaddto comment  (azai91, Jun 25, 2018)
6ef7b87  add reqests to operators  (azai91, Jun 25, 2018)
90c9acb  fix spacing  (azai91, Jun 25, 2018)
7d0f275  do sum in place  (azai91, Jun 25, 2018)
3edf492  fix conditionals  (azai91, Jun 25, 2018)
5c20e46  remove redundant reqs  (azai91, Jun 25, 2018)
cd70dac  use wait to read all  (azai91, Jun 25, 2018)
0972ffa  fix lint  (azai91, Jun 25, 2018)
637c76a  create multiple outputs  (azai91, Jun 26, 2018)
5718651  create multiple copies for kaddto  (azai91, Jun 26, 2018)
d91df93  retrigger  (azai91, Jun 26, 2018)
993c7aa  retriggrer  (azai91, Jun 26, 2018)
e7d18be  merge  (azai91, Jun 26, 2018)
e2a464d  retrigger  (azai91, Jun 26, 2018)
dc742c8  retrigger  (azai91, Jun 26, 2018)
92c50f0  another retrigger  (azai91, Jun 27, 2018)
eb97f3d  Merge branch 'master' into test-kAddTo  (azai91, Jun 27, 2018)
113903a  retrigger  (azai91, Jun 27, 2018)
ecbde64  retrigger  (azai91, Jun 27, 2018)
be84769  another another retrigger  (azai91, Jun 27, 2018)
5181420  Merge branch 'master' into test-kAddTo  (azai91, Jun 27, 2018)
0731a58  merge  (azai91, Jun 29, 2018)
ad3c70e  fix merge  (azai91, Jun 29, 2018)
2874d0a  retrigger  (azai91, Jun 29, 2018)
0e249f7  merge  (azai91, Jul 2, 2018)
581495f  add kAddto to relu op  (azai91, Jul 3, 2018)
9e7c22e  retrigger  (azai91, Jul 5, 2018)
Files changed:
15 changes: 13 additions & 2 deletions src/common/exec_utils.h
@@ -98,10 +98,21 @@ inline bool SetupDefaultBlobsOut(const std::vector<NDArray>& src,
is_default = nd.IsDefaultData();
#endif
if (!is_default) {
NDArray temp = bufs != nullptr ? bufs->at(i) : NDArray(nd.shape(), nd.ctx(),
true, nd.dtype());
#if MXNET_USE_MKLDNN == 1
NDArray temp;
if (bufs != nullptr) {
temp = bufs->at(i);
} else if (kAddTo == req->at(i) && nd.IsMKLDNNData()) {
temp = nd.Reorder2Default();
} else if (kAddTo == req->at(i)) {
temp = nd;
} else {
temp = NDArray(nd.shape(), nd.ctx(), true, nd.dtype());
}
CHECK(temp.IsDefaultData());
#else
NDArray temp = bufs != nullptr ? bufs->at(i) : NDArray(nd.shape(), nd.ctx(),
true, nd.dtype());
Review comment (Contributor): indent.

#endif
temp_src->emplace_back(nd);
temp_dst->emplace_back(temp);
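For readers new to the request types: kWriteTo lets an operator overwrite its output buffer, kWriteInplace aliases input and output, and kAddTo requires the result to be accumulated into the output's existing values (gradient aggregation depends on this). That is why the kAddTo branches above seed the temporary default-format buffer from the current output instead of allocating a fresh one. A minimal sketch of the request semantics, using simplified stand-in types rather than MXNet's actual NDArray/TBlob:

#include <cstdio>
#include <vector>

// Simplified stand-in for MXNet's write-request semantics (not MXNet code):
// kWriteTo overwrites the output, kAddTo accumulates into it.
enum OpReqType { kWriteTo, kWriteInplace, kAddTo };

void ApplyResult(OpReqType req, const std::vector<float> &result,
                 std::vector<float> *out) {
  for (size_t i = 0; i < result.size(); i++) {
    if (req == kAddTo)
      (*out)[i] += result[i];  // accumulate into existing output values
    else
      (*out)[i] = result[i];   // overwrite
  }
}

int main() {
  std::vector<float> out = {1.f, 1.f};
  ApplyResult(kAddTo, {2.f, 3.f}, &out);
  std::printf("%.1f %.1f\n", out[0], out[1]);  // prints: 3.0 4.0
  return 0;
}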
6 changes: 3 additions & 3 deletions src/operator/nn/mkldnn/mkldnn_act.cc
@@ -168,10 +168,10 @@ void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,

auto input_mem = in_buffer.GetMKLDNNData();
MKLDNNActForward &fwd = GetActForward(param, ctx, in_buffer, *input_mem);
auto out_mem = CreateMKLDNNMem(out_data, fwd.fwd_pd.dst_primitive_desc(), req, &in_buffer);
fwd.SetNewMem(*input_mem, *out_mem.second);
auto out_mem_t = CreateMKLDNNMem(out_data, fwd.fwd_pd.dst_primitive_desc(), req, &in_buffer);
fwd.SetNewMem(*input_mem, *out_mem_t.second);
stream->RegisterPrim(fwd.GetFwd());
CommitOutput(out_data, out_mem);
CommitOutput(out_data, out_mem_t);
stream->Submit();
}

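Roughly speaking, the CreateMKLDNNMem/CommitOutput pairing used above implements a write-back protocol: the returned pair records whether the operator may write straight into the output or must stage its result in a temporary, and CommitOutput finishes the request, adding the staged result back for kAddTo instead of overwriting. The sketch below is a simplified model of that contract with stand-in types, not the real MKLDNN base API:

#include <cassert>
#include <utility>
#include <vector>

// Simplified model (assumed names, not MXNet's): for kAddTo the operator
// writes into a temporary buffer, and the commit step adds it back.
enum OutDataOp { Noop, AddBack };
enum OpReqType { kWriteTo, kWriteInplace, kAddTo };
using Buffer = std::vector<float>;
using Output = std::pair<OutDataOp, Buffer*>;

Output CreateMem(Buffer *out, OpReqType req, Buffer *tmp) {
  if (req == kAddTo) return {AddBack, tmp};  // stage result, add back later
  return {Noop, out};                        // write straight into the output
}

void Commit(Buffer *out, const Output &res) {
  if (res.first == AddBack)
    for (size_t i = 0; i < out->size(); i++) (*out)[i] += (*res.second)[i];
}

int main() {
  Buffer out = {1.f, 1.f}, tmp(2);
  Output res = CreateMem(&out, kAddTo, &tmp);
  (*res.second)[0] = 2.f;  // operator writes its result into the staging
  (*res.second)[1] = 3.f;  // buffer returned by CreateMem
  Commit(&out, res);
  assert(out[0] == 3.f && out[1] == 4.f);
  return 0;
}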
17 changes: 14 additions & 3 deletions src/operator/nn/mkldnn/mkldnn_base.cc
@@ -22,6 +22,7 @@
#include <atomic>
#include "./mkldnn_base-inl.h"
#include "./mkldnn_ops-inl.h"
#include "../../../common/exec_utils.h"
#include "../../operator_common.h"

namespace mxnet {
@@ -393,18 +394,28 @@ void FallBackCompute(FCompute fn, const nnvm::NodeAttrs &attrs,
MKLDNNStream::Get()->Submit();

std::vector<TBlob> out_blobs(outputs.size());
std::vector<NDArray> temp_src, temp_dst;
for (size_t i = 0; i < out_blobs.size(); i++) {
NDArray output = outputs[i];
// ensure output does not use mkldnn mem.
// for inplace, we already converted & copied input above.
if ((req[i] == kWriteTo) || (req[i] == kWriteInplace))
if ((req[i] == kWriteTo) || (req[i] == kWriteInplace)) {
const_cast<NDArray &>(output).InvalidateMKLDNNData();
else if (req[i] == kAddTo)
output = outputs[i].Reorder2Default();
} else if (req[i] == kAddTo && output.IsMKLDNNData()) {
NDArray temp = outputs[i].Reorder2Default();
temp_src.emplace_back(temp);
temp_dst.emplace_back(outputs[i]);
output = temp;
}
CHECK(output.IsDefaultData());
out_blobs[i] = output.data();
}

fn(attrs, ctx, in_blobs, req, out_blobs);
for (size_t i = 0; i < out_blobs.size(); i++) {
if (req[i] == kAddTo && outputs[i].IsMKLDNNData())
mxnet::common::CastNonDefaultStorage(temp_src, temp_dst, ctx, false);
}
}

template<typename DType>
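In outline, the FallBackCompute change above handles kAddTo by reordering any MKLDNN-formatted output to the default layout, running the plain FCompute kernel against default-format blobs (the kernel itself performs the accumulation), and then casting the temporary back into the MKLDNN output via CastNonDefaultStorage. A rough, runnable simulation of that flow, using hypothetical stand-ins for NDArray, Reorder2Default, and the cast-back step:

#include <cassert>
#include <vector>

// Hypothetical stand-in types; not the real MXNet/MKLDNN classes.
struct FakeArray {
  std::vector<float> data;
  bool mkldnn_layout;
};

FakeArray Reorder2Default(const FakeArray &a) {       // stand-in reorder
  return FakeArray{a.data, false};
}

void CopyBack(const FakeArray &src, FakeArray *dst) {  // stand-in cast-back
  dst->data = src.data;
}

int main() {
  FakeArray input = {{2.f, 3.f}, false};
  FakeArray output = {{1.f, 1.f}, true};     // pretend MKLDNN layout

  FakeArray temp = Reorder2Default(output);  // 1. default-layout copy of output
  for (size_t i = 0; i < temp.data.size(); i++)
    temp.data[i] += input.data[i];           // 2. fallback kernel runs kAddTo
  CopyBack(temp, &output);                   // 3. write accumulated values back

  assert(output.data[0] == 3.f && output.data[1] == 4.f);
  return 0;
}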
15 changes: 7 additions & 8 deletions src/operator/nn/mkldnn/mkldnn_copy.cc
@@ -44,14 +44,13 @@ void MKLDNNCopy(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
auto in_mem = data.GetMKLDNNData();
if (req == kAddTo) {
TmpMemMgr::Get()->Init(ctx.requested[0]);
// We should try and force the output memory has the same format
// as the input memory. If not, we'll have to reorder memory.
auto out_mem = out_data.GetMKLDNNData(in_mem->get_primitive_desc());
if (out_mem == nullptr)
out_mem = out_data.GetMKLDNNData();
auto sum_res = TmpMemMgr::Get()->Alloc(out_mem->get_primitive_desc());
MKLDNNSum(*in_mem, *out_mem, *sum_res);
const_cast<NDArray &>(out_data).CopyFrom(*sum_res);
// We should try to force the input memory to have the same format
// as the output memory. If not, we'll have to reorder the input.
auto out_mem = out_data.GetMKLDNNData();
in_mem = data.GetMKLDNNData(out_mem->get_primitive_desc());
if (in_mem == nullptr)
in_mem = data.GetMKLDNNDataReorder(out_mem->get_primitive_desc());
MKLDNNSum(*out_mem, *in_mem, *out_mem);
} else {
const_cast<NDArray &>(out_data).CopyFrom(*in_mem);
}
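The rewritten kAddTo branch accumulates directly into the output: MKLDNNSum(*out_mem, *in_mem, *out_mem) computes out += in in place, after first reordering the input to the output's primitive descriptor when their formats differ. A toy illustration of why the reorder must precede the elementwise sum; the "blocked" layout here (even indices first, then odd) is invented for the example, not a real MKLDNN format:

#include <cassert>
#include <vector>

// Convert the toy blocked layout back to plain element order.
std::vector<float> ToPlain(const std::vector<float> &blocked) {
  std::vector<float> plain(blocked.size());
  size_t half = (blocked.size() + 1) / 2;
  for (size_t i = 0; i < blocked.size(); i++)
    plain[i] = (i % 2 == 0) ? blocked[i / 2] : blocked[half + i / 2];
  return plain;
}

int main() {
  std::vector<float> out = {1.f, 2.f, 3.f, 4.f};             // plain layout
  std::vector<float> in_blocked = {10.f, 30.f, 20.f, 40.f};  // "blocked" layout
  std::vector<float> in = ToPlain(in_blocked);               // reorder first
  for (size_t i = 0; i < out.size(); i++)
    out[i] += in[i];                                         // then sum in place
  assert(out[1] == 22.f && out[2] == 33.f);  // wrong values without the reorder
  return 0;
}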
4 changes: 1 addition & 3 deletions src/operator/tensor/elemwise_unary_op_basic.cc
@@ -171,9 +171,7 @@ static void CopyEx(const nnvm::NodeAttrs& attrs,
// This happens if inputs are supposed to be in MKLDNN format
// but MKLDNN doesn't support the data type or the shape. We're
// forced to convert it to the default format.
std::vector<TBlob> in_blobs {inputs[0].data()};
std::vector<TBlob> out_blobs {outputs[0].data()};
UnaryOp::IdentityCompute<cpu>(attrs, ctx, in_blobs, req, out_blobs);
FallBackCompute(UnaryOp::IdentityCompute<cpu>, attrs, ctx, inputs, req, outputs);
return;
}
#endif
144 changes: 102 additions & 42 deletions tests/cpp/operator/mkldnn.cc
@@ -27,6 +27,7 @@

#include <cmath>
#include <climits>
#include <set>
#include "gtest/gtest.h"
#include "mxnet/imperative.h"
#include "../../src/operator/nn/mkldnn/mkldnn_base-inl.h"
@@ -363,6 +364,7 @@ struct NDArrayAttrs {
struct OpAttrs {
nnvm::NodeAttrs attrs;
std::vector<DispatchMode> dispatches;
std::set<OpReqType> requests;
int num_inputs;
int num_outputs;
};
@@ -375,6 +377,9 @@ OpAttrs GetCopyOp() {
attrs.dispatches.resize(2);
attrs.dispatches[0] = DispatchMode::kFCompute;
attrs.dispatches[1] = DispatchMode::kFComputeEx;
attrs.requests.insert(OpReqType::kWriteTo);
attrs.requests.insert(OpReqType::kWriteInplace);
attrs.requests.insert(OpReqType::kAddTo);
return attrs;
}

@@ -386,6 +391,9 @@ OpAttrs GetCopyBackwardsOp() {
attrs.dispatches.resize(2);
attrs.dispatches[0] = DispatchMode::kFCompute;
attrs.dispatches[1] = DispatchMode::kFComputeEx;
attrs.requests.insert(OpReqType::kWriteTo);
attrs.requests.insert(OpReqType::kWriteInplace);
attrs.requests.insert(OpReqType::kAddTo);
return attrs;
}

@@ -399,6 +407,9 @@ OpAttrs GetReluOp() {
attrs.dispatches.resize(2);
attrs.dispatches[0] = DispatchMode::kFCompute;
attrs.dispatches[1] = DispatchMode::kFComputeEx;
attrs.requests.insert(OpReqType::kWriteTo);
attrs.requests.insert(OpReqType::kWriteInplace);
attrs.requests.insert(OpReqType::kAddTo);
return attrs;
}

@@ -412,6 +423,9 @@ OpAttrs GetReluBackwardsOp() {
attrs.dispatches.resize(2);
attrs.dispatches[0] = DispatchMode::kFCompute;
attrs.dispatches[1] = DispatchMode::kFComputeEx;
attrs.requests.insert(OpReqType::kWriteTo);
attrs.requests.insert(OpReqType::kWriteInplace);
Review comment (Contributor): why isn't there a kAddTo test here?
Reply (Contributor Author): done

attrs.requests.insert(OpReqType::kAddTo);
return attrs;
}

@@ -423,6 +437,9 @@ OpAttrs GetSumOp() {
attrs.dispatches.resize(2);
attrs.dispatches[0] = DispatchMode::kFCompute;
attrs.dispatches[1] = DispatchMode::kFComputeEx;
attrs.requests.insert(OpReqType::kWriteTo);
attrs.requests.insert(OpReqType::kWriteInplace);
attrs.requests.insert(OpReqType::kAddTo);
return attrs;
}

@@ -434,6 +451,9 @@ OpAttrs GetSumBackwardsOp() {
attrs.dispatches.resize(2);
attrs.dispatches[0] = DispatchMode::kFCompute;
attrs.dispatches[1] = DispatchMode::kFComputeEx;
attrs.requests.insert(OpReqType::kWriteTo);
attrs.requests.insert(OpReqType::kWriteInplace);
attrs.requests.insert(OpReqType::kAddTo);
return attrs;
}

@@ -821,6 +841,21 @@ void VerifyConcatResult(const std::vector<NDArray *> &in_arrs,
}
}

void VerifyAddRequest(const std::vector<NDArray*> &in_arrs,
const std::vector<NDArray*> &original_outputs,
const std::vector<NDArray*> &new_outputs,
VerifyFunc verify_fn) {
CHECK(original_outputs.size() == new_outputs.size());
std::vector<NDArray*> tmp_outputs;
NDArray tmp;
for (size_t i = 0; i < new_outputs.size(); i++) {
tmp = new_outputs[i]->Reorder2Default() - original_outputs[i]->Reorder2Default();
tmp_outputs.push_back(&tmp);
}
Engine::Get()->WaitForAll();
verify_fn(in_arrs, tmp_outputs);
}

void VerifyConcatBackwardsResult(const std::vector<NDArray *> &in_arrs,
const std::vector<NDArray *> &out_arrs) {
// in_arrs is the larger array, out_arr is the smaller one
@@ -846,15 +881,6 @@ void VerifyConcatBackwardsResult(const std::vector<NDArray *> &in_arrs,
}
}

void VerifyAddRequest(const std::vector<NDArray*> &in_arrs,
const std::vector<NDArray*> &original_outputs,
const std::vector<NDArray*> &new_outputs,
VerifyFunc verify_fn) {
NDArray tmp = new_outputs[0]->Reorder2Default() - original_outputs[0]->Reorder2Default();
tmp.WaitToRead();
verify_fn(in_arrs, {&tmp});
}

TEST(MKLDNN_NDArray, CopyFrom) {
TestArrayShapes tas = GetTestArrayShapes();
std::vector<mkldnn::memory::primitive_desc> pds = tas.pds;
@@ -879,54 +905,88 @@ void TestOp(const OpAttrs &attrs, VerifyFunc verify_fn) {
std::vector<NDArray*> inputs(attrs.num_inputs);
std::vector<NDArray*> outputs(attrs.num_outputs);
std::vector<OpReqType> req(attrs.num_outputs);
std::vector<NDArrayAttrs> in_arrs;
std::vector<std::vector<NDArrayAttrs>> out_arrs(attrs.num_outputs);
std::vector<DispatchMode> dispatches = attrs.dispatches;

TestArrayShapes tas = GetTestArrayShapes();
std::vector<mkldnn::memory::primitive_desc> pds = tas.pds;

std::vector<NDArrayAttrs> in_arrs = GetTestInputArrays();
for (auto &in_arr : in_arrs) {
if (attrs.requests.find(OpReqType::kWriteTo) != attrs.requests.end()) {
std::vector<NDArrayAttrs> in_arrs = GetTestInputArrays();
for (auto &in_arr : in_arrs) {
for (auto &dispatch : dispatches) {
std::vector<std::vector<NDArrayAttrs>> out_arrs(attrs.num_outputs);
for (int i = 0; i < attrs.num_outputs; i++)
out_arrs[i] = GetTestOutputArrays(in_arr.arr.shape(), pds);
for (int i = 0; i < attrs.num_inputs; i++)
inputs[i] = &in_arr.arr;
for (size_t output_i = 0; output_i < out_arrs[0].size(); output_i++) {
for (int i = 0; i < attrs.num_outputs; i++) {
req[i] = kWriteTo;
outputs[i] = &out_arrs[i][output_i].arr;
}
PrintVerifyMsg(in_arr, out_arrs[0][output_i]);
Imperative::Get()->InvokeOp(Context(), attrs.attrs, inputs,
outputs, req, dispatch, mxnet::OpStatePtr());
Engine::Get()->WaitForAll();
verify_fn(inputs, outputs);
}
}
}
}

if (attrs.requests.find(OpReqType::kWriteInplace) != attrs.requests.end()) {
for (auto &dispatch : dispatches) {
std::vector<std::vector<NDArrayAttrs>> out_arrs(attrs.num_outputs);
for (int i = 0; i < attrs.num_outputs; i++)
out_arrs[i] = GetTestOutputArrays(in_arr.arr.shape(), pds);
for (int i = 0; i < attrs.num_inputs; i++)
inputs[i] = &in_arr.arr;
for (size_t output_i = 0; output_i < out_arrs[0].size(); output_i++) {
in_arrs = GetTestInputArrays();
for (auto &arr : in_arrs) {
// If the array is a view, we shouldn't write data to it.
if (arr.arr.IsView())
continue;
NDArrayAttrs orig(arr.arr.Copy(arr.arr.ctx()), "InPlace Copy");
for (int i = 0; i < attrs.num_inputs; i++)
inputs[i] = &arr.arr;
for (int i = 0; i < attrs.num_outputs; i++) {
req[i] = kWriteTo;
outputs[i] = &out_arrs[i][output_i].arr;
req[i] = kWriteInplace;
outputs[i] = &arr.arr;
}
PrintVerifyMsg(in_arr, out_arrs[0][output_i]);
Imperative::Get()->InvokeOp(Context(), attrs.attrs, inputs,
outputs, req, dispatch, mxnet::OpStatePtr());
PrintVerifyMsg(orig, arr);
Imperative::Get()->InvokeOp(Context(), attrs.attrs, inputs, outputs, req,
dispatch, mxnet::OpStatePtr());
Engine::Get()->WaitForAll();
verify_fn(inputs, outputs);
std::vector<NDArray *> orig_inputs(attrs.num_inputs);
for (int i = 0; i < attrs.num_inputs; i++)
orig_inputs[i] = &orig.arr;
verify_fn(orig_inputs, outputs);
}
}
}

for (auto &dispatch : dispatches) {
if (attrs.requests.find(OpReqType::kAddTo) != attrs.requests.end()) {
std::vector<NDArray*> original_outputs(attrs.num_outputs);
in_arrs = GetTestInputArrays();
for (auto &arr : in_arrs) {
// If the array is a view, we shouldn't write data to it.
if (arr.arr.IsView())
continue;
NDArrayAttrs orig(arr.arr.Copy(arr.arr.ctx()), "InPlace Copy");
for (int i = 0; i < attrs.num_inputs; i++)
inputs[i] = &arr.arr;
for (int i = 0; i < attrs.num_outputs; i++) {
req[i] = kWriteInplace;
outputs[i] = &arr.arr;
for (auto &in_arr : in_arrs) {
for (auto &dispatch : dispatches) {
for (int i = 0; i < attrs.num_outputs; i++)
out_arrs[i] = GetTestOutputArrays(in_arr.arr.shape(), pds);
for (size_t i = 0; i < attrs.num_inputs; i++)
inputs[i] = &in_arr.arr;
for (size_t output_i = 0; output_i < out_arrs[0].size(); output_i++) {
NDArray tmp;
for (size_t i = 0; i < attrs.num_outputs; i++) {
auto out_arr = out_arrs[i][output_i];
tmp = out_arr.arr.Copy(out_arr.arr.ctx());
original_outputs[i] = &tmp;
outputs[i] = &out_arrs[i][output_i].arr;
req[i] = kAddTo;
}
PrintVerifyMsg(in_arr, out_arrs[0][output_i]);
Imperative::Get()->InvokeOp(Context(), attrs.attrs, inputs,
outputs, req, dispatch, mxnet::OpStatePtr());
Engine::Get()->WaitForAll();
VerifyAddRequest(inputs, original_outputs, outputs, verify_fn);
}
}
PrintVerifyMsg(orig, arr);
Imperative::Get()->InvokeOp(Context(), attrs.attrs, inputs, outputs, req,
dispatch, mxnet::OpStatePtr());
Engine::Get()->WaitForAll();
std::vector<NDArray *> orig_inputs(attrs.num_inputs);
for (int i = 0; i < attrs.num_inputs; i++)
orig_inputs[i] = &orig.arr;
verify_fn(orig_inputs, outputs);
}
}
}
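A note on the kAddTo test path above: VerifyAddRequest relies on the identity new_output = original_output + op(inputs), so subtracting the saved copy of each output isolates the operator's contribution, and the ordinary verify_fn written for kWriteTo can be reused unchanged. A toy illustration of that check with a hypothetical doubling op, plain C++ rather than MXNet code:

#include <cassert>
#include <vector>

// kAddTo identity exercised by VerifyAddRequest: out_new - out_orig == op(in).
int main() {
  std::vector<float> in = {1.f, 2.f};
  std::vector<float> original = {5.f, 5.f};
  std::vector<float> expected(in.size());
  for (size_t i = 0; i < in.size(); i++) expected[i] = in[i] * 2.f;  // op(in)

  std::vector<float> out = original;               // simulate a kAddTo run
  for (size_t i = 0; i < out.size(); i++) out[i] += expected[i];

  for (size_t i = 0; i < out.size(); i++)          // verification step
    assert(out[i] - original[i] == expected[i]);
  return 0;
}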