
Commit 1cceb3c: CreateMKLDNNMem accepts input

Author: azai91
Committed: Jun 15, 2018
Parent: c5dc502

Showing 14 changed files with 31 additions and 52 deletions.
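In brief: CreateMKLDNNMem previously took only the output NDArray, while a parallel helper, CreateMKLDNNMemory, also took an input NDArray so the kWriteInplace path could check that input and output share a buffer. This commit folds the two into a single CreateMKLDNNMem that accepts the operator's input NDArrays, and updates every call site to pass them. A minimal sketch of the new calling convention, pieced together from the hunks below (the surrounding operator state, such as fwd and fwd_pd, is hypothetical):

    // Inputs are forwarded so kWriteInplace can be validated against in_arrs[0].
    std::vector<NDArray> in_arrs = {in_data};
    mkldnn_output_t out_mem = CreateMKLDNNMem(out_data, in_arrs,
                                              fwd_pd.dst_primitive_desc(), req);
    fwd.SetNewMem(*in_arrs[0].GetMKLDNNData(), *out_mem.second);
    MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd());
    CommitOutput(out_data, out_mem);  // copies/adds back if a temporary was used
    MKLDNNStream::Get()->Submit();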
src/operator/nn/mkldnn/mkldnn_act.cc (6 additions, 6 deletions)

@@ -160,15 +160,15 @@ void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                              const NDArray &out_data) {
   const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);

-  NDArray in_buffer = in_data;
+  std::vector<NDArray> in_buffer = {in_data};
   MKLDNNStream *stream = MKLDNNStream::Get();

   if (in_data.IsView() && in_data.IsMKLDNNData())
-    in_buffer = in_data.Reorder2Default();
+    in_buffer[0] = in_data.Reorder2Default();

-  auto input_mem = in_buffer.GetMKLDNNData();
-  MKLDNNActForward &fwd = GetActForward(param, ctx, in_buffer, *input_mem);
-  auto out_mem = CreateMKLDNNMemory(out_data, in_buffer, fwd.fwd_pd.dst_primitive_desc(), req);
+  auto input_mem = in_buffer[0].GetMKLDNNData();
+  MKLDNNActForward &fwd = GetActForward(param, ctx, in_buffer[0], *input_mem);
+  auto out_mem = CreateMKLDNNMem(out_data, in_buffer, fwd.fwd_pd.dst_primitive_desc(), req);
   fwd.SetNewMem(*input_mem, *out_mem.second);
   stream->RegisterPrim(fwd.GetFwd());
   CommitOutput(out_data, out_mem);

@@ -210,7 +210,7 @@ void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
   mkldnn::eltwise_backward::primitive_desc bw_pdesc(bw_desc, cpu_engine,
                                                     fw_pdesc);

-  diff_src_memory = CreateMKLDNNMem(in_grad,
+  diff_src_memory = CreateMKLDNNMem(in_grad, {out_grad, in_data},
                                     bw_pdesc.diff_src_primitive_desc(), req);
   stream->RegisterPrim(mkldnn::eltwise_backward(bw_pdesc, *input_mem,
                                                 *diff_dst_memory,
src/operator/nn/mkldnn/mkldnn_base-inl.h (2 additions, 5 deletions)

@@ -327,13 +327,10 @@ typedef std::pair<OutDataOp, mkldnn::memory *> mkldnn_output_t;
  * If these two functions are used, we have to call CommitOutput to write
  * the output back to the output NDArray.
  */
-mkldnn_output_t CreateMKLDNNMem(const NDArray &arr,
+mkldnn_output_t CreateMKLDNNMem(const NDArray &out_arr,
+                                const std::vector<NDArray> &in_arrs,
                                 const mkldnn::memory::primitive_desc &desc,
                                 OpReqType req);
-mkldnn_output_t CreateMKLDNNMemory(const NDArray &out_arr,
-                                   const NDArray &in_arr,
-                                   const mkldnn::memory::primitive_desc &desc,
-                                   OpReqType req);
 mkldnn_output_t CreateMKLDNNWeightGrad(const NDArray &arr,
                                        const mkldnn::memory::primitive_desc &desc,
                                        OpReqType req);
src/operator/nn/mkldnn/mkldnn_base.cc (3 additions, 23 deletions)

@@ -77,36 +77,16 @@ mkldnn::memory *TmpMemMgr::Alloc(const mkldnn::memory::primitive_desc &pd) {
   }
 }

-mkldnn_output_t CreateMKLDNNMem(const NDArray &arr,
+mkldnn_output_t CreateMKLDNNMem(const NDArray &out_arr,
+                                const std::vector<NDArray> &in_arrs,
                                 const mkldnn::memory::primitive_desc &desc,
                                 OpReqType req) {
   if (kAddTo == req) {
     auto tmp = TmpMemMgr::Get()->Alloc(desc);
     return mkldnn_output_t(OutDataOp::AddBack, tmp);
-  } else if (req == kWriteInplace) {
-    auto tmp = TmpMemMgr::Get()->Alloc(desc);
-    return mkldnn_output_t(OutDataOp::CopyBack, tmp);
-  }
-  mkldnn::memory *mem = const_cast<NDArray &>(arr).CreateMKLDNNData(desc);
-  if (mem == nullptr) {
-    auto tmp = TmpMemMgr::Get()->Alloc(desc);
-    return mkldnn_output_t(OutDataOp::CopyBack, tmp);
-  }
-  return mkldnn_output_t(OutDataOp::Noop, mem);
-}
-
-mkldnn_output_t CreateMKLDNNMemory(const NDArray &out_arr,
-                                   const NDArray &in_arr,
-                                   const mkldnn::memory::primitive_desc &desc,
-                                   OpReqType req) {
-  if (kAddTo == req) {
-    auto tmp = TmpMemMgr::Get()->Alloc(desc);
-    return mkldnn_output_t(OutDataOp::AddBack, tmp);
   } else if (req == kWriteInplace) {
     // can only WriteInPlace if data_handle and pdesc are the same
     // we assume arr is both input and output
     if (out_arr.GetMKLDNNData()->get_primitive_desc() == desc &&
-        in_arr.GetMKLDNNData()->get_data_handle() == out_arr.GetMKLDNNData()->get_data_handle()) {
+        in_arrs[0].GetMKLDNNData()->get_data_handle() == out_arr.GetMKLDNNData()->get_data_handle()) {
       mkldnn::memory *mem = const_cast<NDArray &>(out_arr).CreateMKLDNNData(desc);
       return mkldnn_output_t(OutDataOp::Noop, mem);
     }
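Note on the merged kWriteInplace logic above: the request yields OutDataOp::Noop, letting the primitive write straight into the output buffer, only when the output's existing primitive descriptor equals the requested desc and in_arrs[0] shares the output's data handle; the code assumes the aliased input is the first entry of in_arrs. When the check fails, the fall-through (in the truncated remainder of this hunk) presumably allocates a temporary and returns CopyBack, as the pre-merge code did for kWriteInplace.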
src/operator/nn/mkldnn/mkldnn_concat.cc (3 additions, 2 deletions)

@@ -115,6 +115,7 @@ void MKLDNNConcatForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
   }
   MKLDNNConcatFwd &fwd = GetConcatForward(concat_dim, in_data, data_md);
   mxnet::mkldnn_output_t out_mem = CreateMKLDNNMem(out_data[concat_enum::kOut],
+                                                   in_data,
                                                    fwd.fwd_pd.dst_primitive_desc(),
                                                    req[concat_enum::kOut]);
   fwd.SetNewMem(data_mem, *out_mem.second);

@@ -124,7 +125,7 @@
 }

 void MKLDNNConcatBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
-                          const std::vector<NDArray>& inputs,
+                          const std::vector<NDArray> &inputs,
                           const std::vector<OpReqType>& req,
                           const std::vector<NDArray>& outputs) {
   TmpMemMgr::Get()->Init(ctx.requested[concat_enum::kTempSpace]);

@@ -143,7 +144,7 @@
                                  static_cast<int>(inputs[i+1].shape()[2]),
                                  static_cast<int>(inputs[i+1].shape()[3])};
     auto diff_src_mpd = inputs[i+1].GetMKLDNNData()->get_primitive_desc();
-    auto gradi_mem_ = CreateMKLDNNMem(outputs[i], diff_src_mpd, req[i]);
+    auto gradi_mem_ = CreateMKLDNNMem(outputs[i], inputs, diff_src_mpd, req[i]);
     // create view from gy to gxs[i]
     std::shared_ptr<mkldnn::view::primitive_desc> view_pd;
     view_pd.reset(new mkldnn::view::primitive_desc(gz_pd, diff_src_tz, offsets));
src/operator/nn/mkldnn/mkldnn_convolution.cc (3 additions, 3 deletions)

@@ -271,7 +271,7 @@ void MKLDNNConvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
       CHECK(weight_mem->get_primitive_desc() == fwd.fwd_pd.weights_primitive_desc());
     }
   }
-  auto out_mem = CreateMKLDNNMem(out_data[conv::kOut], fwd.fwd_pd.dst_primitive_desc(),
+  auto out_mem = CreateMKLDNNMem(out_data[conv::kOut], in_data, fwd.fwd_pd.dst_primitive_desc(),
                                  req[conv::kOut]);
   const mkldnn::memory *bias_mem = nullptr;
   if (!param.no_bias)

@@ -303,7 +303,7 @@ void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
   if (req[conv::kData]) {
     auto weight_mem = GetWeights(inputs[conv::kWeight + 1],
                                  bwdData_pd.weights_primitive_desc(), param.num_group);
-    auto in_grad_mem = CreateMKLDNNMem(in_grad[conv::kData],
+    auto in_grad_mem = CreateMKLDNNMem(in_grad[conv::kData], inputs,
                                        bwdData_pd.diff_src_primitive_desc(), req[conv::kData]);
     MKLDNNStream::Get()->RegisterPrim(mkldnn::convolution_backward_data(bwdData_pd,
         *out_grad_mem, *weight_mem, *in_grad_mem.second));

@@ -327,7 +327,7 @@ void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
     MKLDNNStream::Get()->RegisterPrim(mkldnn::convolution_backward_weights(
         bwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second));
   } else {
-    in_grad_bias = CreateMKLDNNMem(in_grad[conv::kBias],
+    in_grad_bias = CreateMKLDNNMem(in_grad[conv::kBias], inputs,
                                    bwdWeights_pd.diff_bias_primitive_desc(),
                                    req[conv::kBias]);
     MKLDNNStream::Get()->RegisterPrim(mkldnn::convolution_backward_weights(
src/operator/nn/mkldnn/mkldnn_deconvolution.cc (2 additions, 2 deletions)

@@ -223,7 +223,7 @@ void MKLDNNDeconvForward::SetDataHandle(const DeconvolutionParam& param,
       CHECK(weight_mem->get_primitive_desc() == fwd_pd.weights_primitive_desc());
     }
   }
-  auto out_mem = CreateMKLDNNMem(out_data[deconv::kOut],
+  auto out_mem = CreateMKLDNNMem(out_data[deconv::kOut], in_data,
                                  fwd_pd.diff_src_primitive_desc(), req[deconv::kOut]);
   auto output = out_mem.second;
   this->data->set_data_handle(data_mem->get_data_handle());

@@ -325,7 +325,7 @@ void MKLDNNDeconvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
     auto weight_mem = GetWeights(inputs[deconv::kWeight + 1],
                                  bwdData_pd.weights_primitive_desc(),
                                  param.num_group);
-    auto in_grad_mem = CreateMKLDNNMem(in_grad[deconv::kData],
+    auto in_grad_mem = CreateMKLDNNMem(in_grad[deconv::kData], inputs,
                                        bwdData_pd.dst_primitive_desc(),
                                        req[deconv::kData]);
     MKLDNNStream::Get()->RegisterPrim(mkldnn::convolution_forward(bwdData_pd,
src/operator/nn/mkldnn/mkldnn_fully_connected.cc (3 additions, 3 deletions)

@@ -117,7 +117,7 @@ void MKLDNNFCForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
       param.no_bias ? nullptr : &in_data[fullc::kBias], out_md, ctx.is_train);
   auto data_mem = data.GetMKLDNNDataReorder(ipFwd_pd.src_primitive_desc());
   auto weight_mem = weight.GetMKLDNNDataReorder(ipFwd_pd.weights_primitive_desc());
-  auto out_mem = CreateMKLDNNMem(out_data[fullc::kOut],
+  auto out_mem = CreateMKLDNNMem(out_data[fullc::kOut], in_data,
                                  ipFwd_pd.dst_primitive_desc(), req[fullc::kOut]);
   if (param.no_bias) {
     MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_forward(

@@ -167,7 +167,7 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
     auto out_grad_mem = out_grad.GetMKLDNNDataReorder(
         ipBwdData_pd.diff_dst_primitive_desc());
     auto weight_mem = weight.GetMKLDNNDataReorder(ipBwdData_pd.weights_primitive_desc());
-    auto in_grad_mem = CreateMKLDNNMem(in_grad[fullc::kData],
+    auto in_grad_mem = CreateMKLDNNMem(in_grad[fullc::kData], inputs,
                                        ipBwdData_pd.diff_src_primitive_desc(),
                                        req[fullc::kData]);
     MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_backward_data(

@@ -189,7 +189,7 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
     MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_backward_weights(
         ipBwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second));
   } else {
-    in_grad_bias = CreateMKLDNNMem(in_grad[fullc::kBias],
+    in_grad_bias = CreateMKLDNNMem(in_grad[fullc::kBias], inputs,
                                    ipBwdWeights_pd.diff_bias_primitive_desc(),
                                    req[fullc::kBias]);
     MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_backward_weights(
src/operator/nn/mkldnn/mkldnn_lrn-inl.h (2 additions, 1 deletion)

@@ -200,6 +200,7 @@ void MKLDNNLRNBackward(const OpContext &ctx, const LRNParam &param,
   if (req == kNullOp) {
     return;
   }
+  std::vector<NDArray> inputs = {in_data};
   // Repeat FW for getting workspace
   const mkldnn::memory *data_mem = in_data.GetMKLDNNData();
   const mkldnn::memory::desc data_md = data_mem->get_primitive_desc().desc();

@@ -223,7 +224,7 @@ void MKLDNNLRNBackward(const OpContext &ctx, const LRNParam &param,
   const mkldnn::memory::desc diff_md = diff_mem->get_primitive_desc().desc();
   const mkldnn::lrn_backward::primitive_desc pdesc_bwd = GetLRNBwd(param, data_in_md,
                                                                    diff_md, pdesc_fwd);
-  mkldnn_output_t diff_src_mem = CreateMKLDNNMem(in_grad,
+  mkldnn_output_t diff_src_mem = CreateMKLDNNMem(in_grad, inputs,
                                                  pdesc_bwd.diff_src_primitive_desc(), req);

   MKLDNNStream::Get()->RegisterPrim(
src/operator/nn/mkldnn/mkldnn_pooling.cc (2 additions, 2 deletions)

@@ -262,7 +262,7 @@ void MKLDNNPoolingGradCompute(const OpContext &ctx, const PoolingParam &param,
   if (req == kNullOp) {
     return;
   }
-
+  std::vector<NDArray> inputs = {out_grad, in_data};
   TmpMemMgr::Get()->Init(ctx.requested[0]);
   // mkldnn::memory
   auto diff_dst_mem = out_grad.GetMKLDNNData();

@@ -312,7 +312,7 @@ void MKLDNNPoolingGradCompute(const OpContext &ctx, const PoolingParam &param,
   const pooling_backward::primitive_desc pdesc(desc, cpu_engine, pdesc_fwd);

   auto diff_src_mem =
-      CreateMKLDNNMem(in_grad, pdesc.diff_src_primitive_desc(), req);
+      CreateMKLDNNMem(in_grad, inputs, pdesc.diff_src_primitive_desc(), req);

   if (MKLDNNRequireWorkspace(param)) {
     CHECK(workspace != nullptr);
src/operator/nn/mkldnn/mkldnn_sum.cc (1 addition, 1 deletion)

@@ -92,7 +92,7 @@ void MKLDNNSumForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
   } else {
     // req == kWriteInplace but cannot be handled by mkldnn and
     // req == kAddTo will run into this branch
-    auto mem = CreateMKLDNNMem(out_data, pdesc.dst_primitive_desc(), req);
+    auto mem = CreateMKLDNNMem(out_data, inputs, pdesc.dst_primitive_desc(), req);
     MKLDNNStream *stream = MKLDNNStream::Get();
     stream->RegisterPrim(mkldnn::sum(pdesc, in_prims, *mem.second));
     CommitOutput(out_data, mem);
src/operator/quantization/mkldnn/mkldnn_dequantize-inl.h (1 addition, 1 deletion)

@@ -79,7 +79,7 @@ static void MKLDNNDequantizeComputeKer(const std::vector<NDArray> &inputs,
                                        i_fmt);
   auto o_mpd = memory::primitive_desc(o_desc, cpu_engine);
   auto reorder_pd = reorder::primitive_desc(i_mpd, o_mpd, attr);
-  auto o_mem = CreateMKLDNNMem(outputs[0], o_mpd, req[0]);
+  auto o_mem = CreateMKLDNNMem(outputs[0], inputs, o_mpd, req[0]);
   MKLDNNStream::Get()->RegisterPrim(mkldnn::reorder(reorder_pd, *i_mem, *o_mem.second));
   CommitOutput(outputs[0], o_mem);
   MKLDNNStream::Get()->Submit();
src/operator/quantization/mkldnn/mkldnn_quantize-inl.h (1 addition, 1 deletion)

@@ -85,7 +85,7 @@ static void MKLDNNQuantizeComputeKer(const std::vector<NDArray>& inputs,
                                      i_fmt);
   auto o_mpd = memory::primitive_desc(o_desc, cpu_engine);
   auto reorder_pd = reorder::primitive_desc(i_mpd, o_mpd, attr);
-  auto o_mem = CreateMKLDNNMem(outputs[0], o_mpd, req[0]);
+  auto o_mem = CreateMKLDNNMem(outputs[0], inputs, o_mpd, req[0]);
   MKLDNNStream::Get()->RegisterPrim(mkldnn::reorder(reorder_pd, *i_mem, *o_mem.second));
   CommitOutput(outputs[0], o_mem);
   MKLDNNStream::Get()->Submit();
src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc (1 addition, 1 deletion)

@@ -60,7 +60,7 @@ static void MKLDNNQuantizedConvForward(const nnvm::NodeAttrs& attrs,
     weight_mem = weight.GetMKLDNNData();
     CHECK(weight_mem->get_primitive_desc() == fwd.fwd_pd.weights_primitive_desc());
   }
-  auto out_mem = CreateMKLDNNMem(out_data[conv::kOut], fwd.fwd_pd.dst_primitive_desc(),
+  auto out_mem = CreateMKLDNNMem(out_data[conv::kOut], in_data, fwd.fwd_pd.dst_primitive_desc(),
                                  req[conv::kOut]);
   const mkldnn::memory *bias_mem = nullptr;
   if (!param.no_bias)
src/operator/quantization/mkldnn/mkldnn_requantize-inl.h (1 addition, 1 deletion)

@@ -86,7 +86,7 @@ static void MKLDNNRequantizeForwardKer(const nnvm::NodeAttrs& attrs,
                                        i_fmt);
   auto o_mpd = memory::primitive_desc(o_desc, cpu_engine);
   auto reorder_pd = reorder::primitive_desc(i_mpd, o_mpd, attr);
-  auto o_mem = CreateMKLDNNMem(outputs[0], o_mpd, req[0]);
+  auto o_mem = CreateMKLDNNMem(outputs[0], inputs, o_mpd, req[0]);
   MKLDNNStream::Get()->RegisterPrim(mkldnn::reorder(reorder_pd, *i_mem, *o_mem.second));
   CommitOutput(outputs[0], o_mem);
   MKLDNNStream::Get()->Submit();
