From 192cc5dd3260bede2ff9cadd90f9249d853f0cf0 Mon Sep 17 00:00:00 2001
From: Tomasz Patejko
Date: Tue, 13 Mar 2018 11:07:08 -0400
Subject: [PATCH 1/4] Implementation of MKLDNN LRN

---
 paddle/fluid/operators/lrn_mkldnn_op.cc    | 189 ++++++++++++++++++
 paddle/fluid/operators/lrn_op.cc           |  55 ++++-
 .../fluid/tests/unittests/test_lrn_op.py   |  10 +
 3 files changed, 253 insertions(+), 1 deletion(-)
 create mode 100644 paddle/fluid/operators/lrn_mkldnn_op.cc

diff --git a/paddle/fluid/operators/lrn_mkldnn_op.cc b/paddle/fluid/operators/lrn_mkldnn_op.cc
new file mode 100644
index 0000000000000..334597ab05ecc
--- /dev/null
+++ b/paddle/fluid/operators/lrn_mkldnn_op.cc
@@ -0,0 +1,189 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/operators/lrn_op.h"
+#include "paddle/fluid/platform/mkldnn_helper.h"
+
+namespace paddle {
+namespace operators {
+
+using paddle::framework::Tensor;
+using paddle::platform::MKLDNNDeviceContext;
+
+namespace {
+mkldnn::algorithm LRNAlgorithm(const paddle::framework::ExecutionContext& ctx) {
+  mkldnn::algorithm algorithm = mkldnn::lrn_across_channels;
+
+  std::string algorithm_str = ctx.Attr<std::string>("algorithm");
+  if (algorithm_str == "WITHIN_CHANNEL") {
+    algorithm = mkldnn::lrn_within_channel;
+  }
+  return algorithm;
+}
+}  // namespace
+
+template <typename T>
+class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
+ public:
+  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
+    PADDLE_ENFORCE(std::is_same<T, float>::value,
+                   "MKLDNN LRN must use float data.");
+    PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
+                   "MKLDNN LRN must use CPUPlace.");
+
+    auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
+    const auto& mkldnn_engine = dev_ctx.GetEngine();
+
+    auto x = ctx.Input<Tensor>("X");
+    auto out = ctx.Output<Tensor>("Out");
+    auto mid = ctx.Output<Tensor>("MidOut");
+
+    auto input_data = x->data<T>();
+    auto output_data = out->mutable_data<T>(ctx.GetPlace());
+    mid->mutable_data<T>(ctx.GetPlace());
+
+    const std::string key = ctx.op().Output("Out");
+    const std::string key_src_memory = key + "@lrn_src_memory";
+    const std::string key_pd = key + "@lrn_pd";
+    const std::string key_workspace_memory = key + "@lrn_workspace_memory";
+
+    const int n = ctx.Attr<int>("n");
+    const float alpha = ctx.Attr<float>("alpha");
+    const float beta = ctx.Attr<float>("beta");
+    const float k = ctx.Attr<float>("k");
+
+    auto algorithm = LRNAlgorithm(ctx);
+
+    auto e_mid = framework::EigenTensor<T, 4>::From(*mid);
+    e_mid = e_mid.constant(k);
+
+    auto dims = paddle::framework::vectorize2int(x->dims());
+
+    auto src_md = paddle::platform::MKLDNNMemDesc(
+        dims, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);
+
+    auto dst_md = paddle::platform::MKLDNNMemDesc(
+        dims, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);
+
+    auto forward_desc = mkldnn::lrn_forward::desc{
+        mkldnn::prop_kind::forward, algorithm, src_md, n, alpha, beta, k};
+
+    auto forward_pd = std::make_shared<mkldnn::lrn_forward::primitive_desc>(
+        forward_desc, mkldnn_engine);
+
+    dev_ctx.SetBlob(key_pd, forward_pd);
+
+    auto src_memory_pd = mkldnn::memory::primitive_desc{src_md, mkldnn_engine};
+    auto src_memory = std::make_shared<mkldnn::memory>(
+        src_memory_pd, static_cast<void*>(const_cast<T*>(input_data)));
+
+    dev_ctx.SetBlob(key_src_memory, src_memory);
+    auto dst_memory = mkldnn::memory{{dst_md, mkldnn_engine},
+                                     static_cast<void*>(output_data)};
+
+    auto workspace_md = forward_pd->workspace_primitive_desc();
+    auto workspace_memory = std::make_shared<mkldnn::memory>(workspace_md);
+
+    dev_ctx.SetBlob(key_workspace_memory, workspace_memory);
+
+    auto forward_op = mkldnn::lrn_forward{*forward_pd, *src_memory,
+                                          *workspace_memory, dst_memory};
+
+    std::vector<mkldnn::primitive> pipeline = {forward_op};
+    mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
+  }
+};
+
+template <typename T>
+class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
+ public:
+  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
+    PADDLE_ENFORCE(std::is_same<T, float>::value,
+                   "MKLDNN LRN must use float data.");
+    PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
+                   "MKLDNN LRN must use CPUPlace.");
+
+    auto x = ctx.Input<Tensor>("X");
+
+    auto out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
+
+    const std::string key = ctx.op().Input("Out");
+    const std::string key_src_memory = key + "@lrn_src_memory";
+    const std::string key_pd = key + "@lrn_pd";
+    const std::string key_workspace_memory = key + "@lrn_workspace_memory";
+
+    const int n = ctx.Attr<int>("n");
+    const float alpha = ctx.Attr<float>("alpha");
+    const float beta = ctx.Attr<float>("beta");
+    const float k = ctx.Attr<float>("k");
+
+    auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
+    const auto& mkldnn_engine = dev_ctx.GetEngine();
+
+    auto x_grad_data = x_grad->mutable_data<T>(ctx.GetPlace());
+    auto out_grad_data = out_grad->data<T>();
+
+    auto dims = paddle::framework::vectorize2int(x->dims());
+
+    auto src_md = paddle::platform::MKLDNNMemDesc(
+        dims, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);
+
+    auto diff_src_md = paddle::platform::MKLDNNMemDesc(
+        dims, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);
+
+    auto diff_dst_md = paddle::platform::MKLDNNMemDesc(
+        dims, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);
+
+    auto diff_dst_memory =
+        mkldnn::memory{{diff_dst_md, mkldnn_engine},
+                       static_cast<void*>(const_cast<T*>(out_grad_data))};
+
+    auto diff_src_memory = mkldnn::memory{{diff_src_md, mkldnn_engine},
+                                          static_cast<void*>(x_grad_data)};
+
+    auto algorithm = LRNAlgorithm(ctx);
+
+    auto backward_desc = mkldnn::lrn_backward::desc{
+        algorithm, src_md, diff_src_md, n, alpha, beta, k};
+
+    auto forward_pd = dev_ctx.GetBlob(key_pd);
+
+    auto backward_pd = mkldnn::lrn_backward::primitive_desc{
+        backward_desc, mkldnn_engine,
+        *static_cast<mkldnn::lrn_forward::primitive_desc*>(forward_pd.get())};
+
+    std::shared_ptr<void> workspace_memory =
+        dev_ctx.GetBlob(key_workspace_memory);
+
+    auto src_memory = dev_ctx.GetBlob(key_src_memory);
+    auto backward_op = mkldnn::lrn_backward{
+        backward_pd, *static_cast<mkldnn::memory*>(src_memory.get()),
+        diff_dst_memory, *static_cast<mkldnn::memory*>(workspace_memory.get()),
+        diff_src_memory};
+
+    std::vector<mkldnn::primitive> pipeline = {backward_op};
+    mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OP_KERNEL(lrn, MKLDNN, paddle::platform::CPUPlace,
+                   ops::LRNMKLDNNOpKernel<float>);
+REGISTER_OP_KERNEL(lrn_grad, MKLDNN, paddle::platform::CPUPlace,
+                   ops::LRNMKLDNNGradOpKernel<float>);
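
The forward kernel above rebuilds the whole MKL-DNN pipeline on every invocation: memory descriptors, the primitive descriptor, the primitive itself, and a one-element pipeline submitted to an eager stream. The same call sequence works outside Paddle; the following is a minimal standalone sketch against the MKL-DNN 0.x API used in this patch (shapes and attribute values are made up for illustration, and the exact constructor overloads should be checked against the mkldnn.hpp headers):

    #include <mkldnn.hpp>
    #include <vector>

    int main() {
      auto engine = mkldnn::engine{mkldnn::engine::cpu, 0};

      // One image of 8 channels, 4x4 spatial, in nchw layout.
      mkldnn::memory::dims dims = {1, 8, 4, 4};
      std::vector<float> src(1 * 8 * 4 * 4, 1.0f), dst(src.size());

      auto md = mkldnn::memory::desc{dims, mkldnn::memory::data_type::f32,
                                     mkldnn::memory::format::nchw};
      auto src_memory = mkldnn::memory{{md, engine}, src.data()};
      auto dst_memory = mkldnn::memory{{md, engine}, dst.data()};

      // n, alpha, beta, k mirror the operator attributes in the kernel.
      auto desc = mkldnn::lrn_forward::desc{mkldnn::prop_kind::forward,
                                            mkldnn::lrn_across_channels,
                                            md, 5, 1e-4f, 0.75f, 2.0f};
      auto pd = mkldnn::lrn_forward::primitive_desc{desc, engine};

      // Training-mode LRN needs a workspace, exactly as in the kernel.
      auto workspace = mkldnn::memory{pd.workspace_primitive_desc()};
      auto lrn = mkldnn::lrn_forward{pd, src_memory, workspace, dst_memory};

      std::vector<mkldnn::primitive> pipeline = {lrn};
      mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
      return 0;
    }
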
diff --git a/paddle/fluid/operators/lrn_op.cc b/paddle/fluid/operators/lrn_op.cc
index 692e85dcffa58..6bd451a118afe 100644
--- a/paddle/fluid/operators/lrn_op.cc
+++ b/paddle/fluid/operators/lrn_op.cc
@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/lrn_op.h"
+#ifdef PADDLE_WITH_MKLDNN
+#include "paddle/fluid/platform/mkldnn_helper.h"
+#endif
 
 namespace paddle {
 namespace operators {
@@ -135,6 +138,24 @@ class LRNOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim("MidOut", x_dim);
     ctx->ShareLoD("X", /*->*/ "Out");
   }
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const {
+    framework::LibraryType library_{framework::LibraryType::kPlain};
+#ifdef PADDLE_WITH_MKLDNN
+    if (library_ == framework::LibraryType::kPlain &&
+        platform::CanMKLDNNBeUsed(ctx)) {
+      library_ = framework::LibraryType::kMKLDNN;
+    }
+#endif
+
+    std::string data_format = ctx.Attr<std::string>("data_format");
+    // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
+    framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
+        layout_, library_);
+  }
 };
 
 template <typename T>
@@ -176,6 +197,21 @@ class LRNOpMaker : public framework::OpProtoAndCheckerMaker {
          "beta is the power number.")
         .SetDefault(0.75)
         .GreaterThan(0.0);
+    AddAttr<bool>("use_mkldnn",
+                  "(bool, default false) Only used in mkldnn kernel")
+        .SetDefault(false);
+    AddAttr<std::string>(
+        "data_format",
+        "(string, default \"AnyLayout\") Only used in the MKLDNN kernel. "
+        "An optional string from: \"NHWC\", \"NCHW\". "
+        "Specify the data format of the output data; "
+        "the input will be transformed automatically.")
+        .SetDefault("AnyLayout");
+    AddAttr<std::string>("algorithm",
+                         "(string, default ACROSS_CHANNELS) "
+                         "An optional string: \"ACROSS_CHANNELS\", "
+                         "\"WITHIN_CHANNEL\". Used by MKLDNN library.")
+        .SetDefault("ACROSS_CHANNELS");
     AddComment(R"DOC(
 Local Response Normalization Operator.
 
@@ -223,8 +259,25 @@ class LRNOpGrad : public framework::OperatorWithKernel {
     auto x_dims = ctx->GetInputDim("X");
     ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
   }
-};
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const {
+    framework::LibraryType library_{framework::LibraryType::kPlain};
+#ifdef PADDLE_WITH_MKLDNN
+    if (library_ == framework::LibraryType::kPlain &&
+        platform::CanMKLDNNBeUsed(ctx)) {
+      library_ = framework::LibraryType::kMKLDNN;
+    }
+#endif
+
+    std::string data_format = ctx.Attr<std::string>("data_format");
+    // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
+    framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
+        layout_, library_);
+  }
+};
 } // namespace operators
 } // namespace paddle
 
diff --git a/python/paddle/fluid/tests/unittests/test_lrn_op.py b/python/paddle/fluid/tests/unittests/test_lrn_op.py
index eaff45cbb2a58..2268eafdbd08c 100644
--- a/python/paddle/fluid/tests/unittests/test_lrn_op.py
+++ b/python/paddle/fluid/tests/unittests/test_lrn_op.py
@@ -87,5 +87,15 @@ def test_check_grad_normal(self):
         self.check_grad(['X'], 'Out', max_relative_error=0.01)
 
 
+class TestLRNMKLDNNOp(TestLRNOp):
+    def get_attrs(self):
+        attrs = TestLRNOp.get_attrs(self)
+        attrs['use_mkldnn'] = True
+        return attrs
+
+    def test_check_output(self):
+        self.check_output(atol=0.002)
+
+
 if __name__ == "__main__":
     unittest.main()
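
For reference, across-channel LRN normalizes each activation by a power of the squared sum over neighbouring channels; the CPU kernel's MidOut holds that k + alpha * sum term, which is why the MKLDNN kernel above only fills it with the constant k (MKL-DNN keeps the equivalent state in its workspace instead). The MKLDNN test loosens the output tolerance to atol=0.002 because implementations differ in low-order bits and in conventions such as whether alpha is pre-divided by n. A scalar reference sketch — one plausible convention among several, not Paddle's exact CPU formula:

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Scalar across-channel LRN on NCHW data:
    //   out[c] = in[c] / (k + alpha * sum over c' near c of in[c']^2)^beta
    // Window and alpha conventions vary between implementations; this
    // sketch uses a symmetric window of n channels clipped at the borders.
    std::vector<float> lrn_reference(const std::vector<float>& in, int N,
                                     int C, int H, int W, int n, float alpha,
                                     float beta, float k) {
      std::vector<float> out(in.size());
      auto at = [=](int b, int c, int h, int w) {
        return ((b * C + c) * H + h) * W + w;
      };
      for (int b = 0; b < N; ++b)
        for (int c = 0; c < C; ++c)
          for (int h = 0; h < H; ++h)
            for (int w = 0; w < W; ++w) {
              int lo = std::max(0, c - n / 2);
              int hi = std::min(C - 1, c + n / 2);
              float sum = 0.f;
              for (int i = lo; i <= hi; ++i) {
                float v = in[at(b, i, h, w)];
                sum += v * v;
              }
              // This k + alpha * sum term is what MidOut corresponds to.
              out[at(b, c, h, w)] =
                  in[at(b, c, h, w)] * std::pow(k + alpha * sum, -beta);
            }
      return out;
    }
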
From c51c446221ce63890a0c099da7f26b9bfa41cb48 Mon Sep 17 00:00:00 2001
From: Tomasz Patejko
Date: Fri, 16 Mar 2018 10:05:54 -0400
Subject: [PATCH 2/4] Content of GetExpectedKernelType moved to standalone
 function

---
 paddle/fluid/operators/lrn_op.cc | 54 ++++++++++++++------------------
 1 file changed, 24 insertions(+), 30 deletions(-)

diff --git a/paddle/fluid/operators/lrn_op.cc b/paddle/fluid/operators/lrn_op.cc
index 6bd451a118afe..00db09ece3215 100644
--- a/paddle/fluid/operators/lrn_op.cc
+++ b/paddle/fluid/operators/lrn_op.cc
@@ -119,6 +119,26 @@ struct LRNGradFunctor {
 template struct LRNGradFunctor<platform::CPUDeviceContext, float>;
 template struct LRNGradFunctor<platform::CPUDeviceContext, double>;
 
+namespace {
+  framework::OpKernelType GetExpectedLRNKernel(
+      const framework::ExecutionContext& ctx) {
+    framework::LibraryType library_{framework::LibraryType::kPlain};
+#ifdef PADDLE_WITH_MKLDNN
+    if (library_ == framework::LibraryType::kPlain &&
+        platform::CanMKLDNNBeUsed(ctx)) {
+      library_ = framework::LibraryType::kMKLDNN;
+    }
+#endif
+
+    std::string data_format = ctx.Attr<std::string>("data_format");
+    // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
+    framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
+        layout_, library_);
+  }
+}
+
 class LRNOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
@@ -140,21 +160,8 @@ class LRNOp : public framework::OperatorWithKernel {
   }
 
   framework::OpKernelType GetExpectedKernelType(
-      const framework::ExecutionContext& ctx) const {
-    framework::LibraryType library_{framework::LibraryType::kPlain};
-#ifdef PADDLE_WITH_MKLDNN
-    if (library_ == framework::LibraryType::kPlain &&
-        platform::CanMKLDNNBeUsed(ctx)) {
-      library_ = framework::LibraryType::kMKLDNN;
-    }
-#endif
-
-    std::string data_format = ctx.Attr<std::string>("data_format");
-    // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
-    framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
-        layout_, library_);
+      const framework::ExecutionContext& ctx) const override {
+    return GetExpectedLRNKernel(ctx);
   }
 };
 
@@ -261,21 +268,8 @@ class LRNOpGrad : public framework::OperatorWithKernel {
   }
 
   framework::OpKernelType GetExpectedKernelType(
-      const framework::ExecutionContext& ctx) const {
-    framework::LibraryType library_{framework::LibraryType::kPlain};
-#ifdef PADDLE_WITH_MKLDNN
-    if (library_ == framework::LibraryType::kPlain &&
-        platform::CanMKLDNNBeUsed(ctx)) {
-      library_ = framework::LibraryType::kMKLDNN;
-    }
-#endif
-
-    std::string data_format = ctx.Attr<std::string>("data_format");
-    // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
-    framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
-        layout_, library_);
+      const framework::ExecutionContext& ctx) const override {
+    return GetExpectedLRNKernel(ctx);
   }
 };
 } // namespace operators
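
Beyond deduplicating the two method bodies, note that the refactored signatures gain the override keyword. That is not cosmetic: without override, a small mismatch such as a missing const silently declares a new function that hides the virtual one, and calls through the base class never reach it. A self-contained illustration with made-up class names:

    #include <iostream>

    struct OperatorBase {
      virtual ~OperatorBase() = default;
      virtual int GetExpectedKernelType() const { return 0; }
    };

    struct BadOp : OperatorBase {
      // Missing const: this declares a brand-new function that hides the
      // virtual one, and no compiler diagnostic is required.
      int GetExpectedKernelType() { return 1; }
    };

    struct GoodOp : OperatorBase {
      // With `override`, the same slip would be a hard compile error.
      int GetExpectedKernelType() const override { return 1; }
    };

    int main() {
      BadOp bad;
      GoodOp good;
      // Prints 0 then 1: the base implementation still wins for BadOp.
      std::cout << static_cast<const OperatorBase&>(bad).GetExpectedKernelType()
                << static_cast<const OperatorBase&>(good).GetExpectedKernelType()
                << std::endl;
      return 0;
    }
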
From 2d95527527fe3b27e06f254965c8eb4fbacb4abf Mon Sep 17 00:00:00 2001
From: Tomasz Patejko
Date: Mon, 19 Mar 2018 06:10:27 -0400
Subject: [PATCH 3/4] Removing WITHIN_CHANNEL algorithm for lrn. CPU lrn
 operator works only with ACROSS_CHANNELS

---
 paddle/fluid/operators/lrn_mkldnn_op.cc | 27 ++++++--------------
 paddle/fluid/operators/lrn_op.cc        | 33 +++++++++++--------------
 2 files changed, 22 insertions(+), 38 deletions(-)

diff --git a/paddle/fluid/operators/lrn_mkldnn_op.cc b/paddle/fluid/operators/lrn_mkldnn_op.cc
index 334597ab05ecc..a2971fcd14469 100644
--- a/paddle/fluid/operators/lrn_mkldnn_op.cc
+++ b/paddle/fluid/operators/lrn_mkldnn_op.cc
@@ -22,18 +22,6 @@ namespace operators {
 using paddle::framework::Tensor;
 using paddle::platform::MKLDNNDeviceContext;
 
-namespace {
-mkldnn::algorithm LRNAlgorithm(const paddle::framework::ExecutionContext& ctx) {
-  mkldnn::algorithm algorithm = mkldnn::lrn_across_channels;
-
-  std::string algorithm_str = ctx.Attr<std::string>("algorithm");
-  if (algorithm_str == "WITHIN_CHANNEL") {
-    algorithm = mkldnn::lrn_within_channel;
-  }
-  return algorithm;
-}
-}  // namespace
-
 template <typename T>
 class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
  public:
@@ -64,8 +52,6 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     const float beta = ctx.Attr<float>("beta");
     const float k = ctx.Attr<float>("k");
 
-    auto algorithm = LRNAlgorithm(ctx);
-
     auto e_mid = framework::EigenTensor<T, 4>::From(*mid);
     e_mid = e_mid.constant(k);
 
@@ -77,8 +63,13 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto dst_md = paddle::platform::MKLDNNMemDesc(
         dims, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);
 
-    auto forward_desc = mkldnn::lrn_forward::desc{
-        mkldnn::prop_kind::forward, algorithm, src_md, n, alpha, beta, k};
+    auto forward_desc = mkldnn::lrn_forward::desc{mkldnn::prop_kind::forward,
+                                                  mkldnn::lrn_across_channels,
+                                                  src_md,
+                                                  n,
+                                                  alpha,
+                                                  beta,
+                                                  k};
 
     auto forward_pd = std::make_shared<mkldnn::lrn_forward::primitive_desc>(
         forward_desc, mkldnn_engine);
@@ -154,10 +145,8 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     auto diff_src_memory = mkldnn::memory{{diff_src_md, mkldnn_engine},
                                           static_cast<void*>(x_grad_data)};
 
-    auto algorithm = LRNAlgorithm(ctx);
-
     auto backward_desc = mkldnn::lrn_backward::desc{
-        algorithm, src_md, diff_src_md, n, alpha, beta, k};
+        mkldnn::lrn_across_channels, src_md, diff_src_md, n, alpha, beta, k};
 
     auto forward_pd = dev_ctx.GetBlob(key_pd);
 
diff --git a/paddle/fluid/operators/lrn_op.cc b/paddle/fluid/operators/lrn_op.cc
index 00db09ece3215..bd72f0435e524 100644
--- a/paddle/fluid/operators/lrn_op.cc
+++ b/paddle/fluid/operators/lrn_op.cc
@@ -120,24 +120,24 @@ template struct LRNGradFunctor<platform::CPUDeviceContext, float>;
 template struct LRNGradFunctor<platform::CPUDeviceContext, double>;
 
 namespace {
-  framework::OpKernelType GetExpectedLRNKernel(
-      const framework::ExecutionContext& ctx) {
-    framework::LibraryType library_{framework::LibraryType::kPlain};
+framework::OpKernelType GetExpectedLRNKernel(
+    const framework::ExecutionContext& ctx) {
+  framework::LibraryType library_{framework::LibraryType::kPlain};
 #ifdef PADDLE_WITH_MKLDNN
-    if (library_ == framework::LibraryType::kPlain &&
-        platform::CanMKLDNNBeUsed(ctx)) {
-      library_ = framework::LibraryType::kMKLDNN;
-    }
+  if (library_ == framework::LibraryType::kPlain &&
+      platform::CanMKLDNNBeUsed(ctx)) {
+    library_ = framework::LibraryType::kMKLDNN;
+  }
 #endif
 
-    std::string data_format = ctx.Attr<std::string>("data_format");
-    // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
-    framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
-        layout_, library_);
-  }
-}
+  std::string data_format = ctx.Attr<std::string>("data_format");
+  // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
+  framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
+  return framework::OpKernelType(
+      framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
+      layout_, library_);
+}
+}  // namespace
 
 class LRNOp : public framework::OperatorWithKernel {
  public:
@@ -214,11 +214,6 @@ class LRNOpMaker : public framework::OpProtoAndCheckerMaker {
         "Specify the data format of the output data; "
         "the input will be transformed automatically.")
         .SetDefault("AnyLayout");
-    AddAttr<std::string>("algorithm",
-                         "(string, default ACROSS_CHANNELS) "
-                         "An optional string: \"ACROSS_CHANNELS\", "
-                         "\"WITHIN_CHANNEL\". Used by MKLDNN library.")
-        .SetDefault("ACROSS_CHANNELS");
     AddComment(R"DOC(
 Local Response Normalization Operator.
 
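
The final patch below stops unconditionally writing blobs into the device context: caching is needed only when a backward pass will follow, so it is skipped under a new is_test attribute, and the writes themselves become get-or-create lookups via an insert_to_context helper. The idiom in isolation, sketched with a plain std::map standing in for the MKLDNN device context:

    #include <iostream>
    #include <map>
    #include <memory>
    #include <string>
    #include <utility>

    // A plain map plays the role of the device context's blob storage.
    using Cache = std::map<std::string, std::shared_ptr<void>>;

    // Get-or-create: construct and store the object only on a cache miss,
    // mirroring what insert_to_context does with dev_ctx blobs.
    template <typename T, typename... Args>
    std::shared_ptr<T> get_or_create(Cache& cache, const std::string& key,
                                     Args&&... args) {
      auto p = std::static_pointer_cast<T>(cache[key]);
      if (!p) {
        p = std::make_shared<T>(std::forward<Args>(args)...);
        cache[key] = std::static_pointer_cast<void>(p);
      }
      return p;
    }

    int main() {
      Cache cache;
      // "fc_0.tmp_1" is a hypothetical output variable name.
      auto a = get_or_create<std::string>(cache, "fc_0.tmp_1@lrn_pd", "built once");
      auto b = get_or_create<std::string>(cache, "fc_0.tmp_1@lrn_pd", "ignored");
      std::cout << *b << std::endl;  // "built once": the second call reuses it
      return 0;
    }
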
From 72cc64e40e5d624bcc97bd81f144fcb446167a21 Mon Sep 17 00:00:00 2001
From: Tomasz Patejko
Date: Wed, 21 Mar 2018 10:20:29 -0400
Subject: [PATCH 4/4] Device blobs are created only in training. Added testing
 attribute

---
 paddle/fluid/operators/lrn_mkldnn_op.cc | 71 ++++++++++++++++++-------
 paddle/fluid/operators/lrn_op.cc        |  1 +
 2 files changed, 52 insertions(+), 20 deletions(-)

diff --git a/paddle/fluid/operators/lrn_mkldnn_op.cc b/paddle/fluid/operators/lrn_mkldnn_op.cc
index a2971fcd14469..3bead16ce44c2 100644
--- a/paddle/fluid/operators/lrn_mkldnn_op.cc
+++ b/paddle/fluid/operators/lrn_mkldnn_op.cc
@@ -22,6 +22,22 @@ namespace operators {
 using paddle::framework::Tensor;
 using paddle::platform::MKLDNNDeviceContext;
 
+namespace {
+template <typename T, typename... Args>
+std::shared_ptr<T> insert_to_context(const std::string& key,
+                                     const MKLDNNDeviceContext& dev_ctx,
+                                     Args&&... args) {
+  auto p = std::static_pointer_cast<T>(dev_ctx.GetBlob(key));
+
+  if (!p) {
+    p = std::make_shared<T>(args...);
+    dev_ctx.SetBlob(key, std::static_pointer_cast<void>(p));
+  }
+
+  return p;
+}
+}  // namespace
+
 template <typename T>
 class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
  public:
@@ -42,15 +58,11 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto output_data = out->mutable_data<T>(ctx.GetPlace());
     mid->mutable_data<T>(ctx.GetPlace());
 
-    const std::string key = ctx.op().Output("Out");
-    const std::string key_src_memory = key + "@lrn_src_memory";
-    const std::string key_pd = key + "@lrn_pd";
-    const std::string key_workspace_memory = key + "@lrn_workspace_memory";
-
     const int n = ctx.Attr<int>("n");
     const float alpha = ctx.Attr<float>("alpha");
     const float beta = ctx.Attr<float>("beta");
     const float k = ctx.Attr<float>("k");
+    const bool is_test = ctx.Attr<bool>("is_test");
 
     auto e_mid = framework::EigenTensor<T, 4>::From(*mid);
     e_mid = e_mid.constant(k);
@@ -71,28 +83,47 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
                                                   beta,
                                                   k};
 
-    auto forward_pd = std::make_shared<mkldnn::lrn_forward::primitive_desc>(
-        forward_desc, mkldnn_engine);
-
-    dev_ctx.SetBlob(key_pd, forward_pd);
-
     auto src_memory_pd = mkldnn::memory::primitive_desc{src_md, mkldnn_engine};
-    auto src_memory = std::make_shared<mkldnn::memory>(
-        src_memory_pd, static_cast<void*>(const_cast<T*>(input_data)));
-
-    dev_ctx.SetBlob(key_src_memory, src_memory);
     auto dst_memory = mkldnn::memory{{dst_md, mkldnn_engine},
                                      static_cast<void*>(output_data)};
 
-    auto workspace_md = forward_pd->workspace_primitive_desc();
-    auto workspace_memory = std::make_shared<mkldnn::memory>(workspace_md);
+    std::unique_ptr<mkldnn::lrn_forward> forward_op = nullptr;
 
-    dev_ctx.SetBlob(key_workspace_memory, workspace_memory);
+    if (!is_test) {
+      const std::string key = ctx.op().Output("Out");
+      const std::string key_src_memory = key + "@lrn_src_memory";
+      const std::string key_pd = key + "@lrn_pd";
+      const std::string key_workspace_memory = key + "@lrn_workspace_memory";
+
+      auto forward_pd = insert_to_context<mkldnn::lrn_forward::primitive_desc>(
+          key_pd, dev_ctx, forward_desc, mkldnn_engine);
+
+      auto src_memory = insert_to_context<mkldnn::memory>(
+          key_src_memory, dev_ctx, src_memory_pd);
+
+      src_memory->set_data_handle(
+          static_cast<void*>(const_cast<T*>(input_data)));
+
+      auto workspace_memory = insert_to_context<mkldnn::memory>(
+          key_workspace_memory, dev_ctx,
+          forward_pd->workspace_primitive_desc());
+
+      forward_op.reset(new mkldnn::lrn_forward{*forward_pd, *src_memory,
+                                               *workspace_memory, dst_memory});
+    } else {
+      auto forward_pd =
+          mkldnn::lrn_forward::primitive_desc{forward_desc, mkldnn_engine};
+      auto src_memory = mkldnn::memory{
+          src_memory_pd, static_cast<void*>(const_cast<T*>(input_data))};
+      auto workspace_memory =
+          mkldnn::memory{forward_pd.workspace_primitive_desc()};
 
-    auto forward_op = mkldnn::lrn_forward{*forward_pd, *src_memory,
-                                          *workspace_memory, dst_memory};
+      forward_op.reset(new mkldnn::lrn_forward{forward_pd, src_memory,
+                                               workspace_memory, dst_memory});
+    }
 
-    std::vector<mkldnn::primitive> pipeline = {forward_op};
+    std::vector<mkldnn::primitive> pipeline = {*forward_op};
     mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
   }
 };
") .SetDefault("AnyLayout"); + AddAttr("is_test", "").SetDefault(false); AddComment(R"DOC( Local Response Normalization Operator.