From d7d9b257e687b98026c3914a06da8eb462ef909c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pawe=C5=82=20=C5=BBelazko?=
Date: Thu, 1 Mar 2018 13:06:44 +0100
Subject: [PATCH 1/4] MKLDNN conv2d OP kernel added

---
 paddle/fluid/framework/operator.cc            |  17 ++
 paddle/fluid/framework/operator.h             |   2 +
 paddle/fluid/operators/CMakeLists.txt         |  26 +-
 paddle/fluid/operators/conv_mkldnn_op.cc      | 274 ++++++++++++++++++
 paddle/fluid/operators/conv_op.cc             |  39 +--
 paddle/fluid/platform/device_context.cc       |  76 ++---
 paddle/fluid/platform/device_context.h        |  43 +--
 python/paddle/fluid/layers/nn.py              |   4 +-
 python/paddle/fluid/nets.py                   |  12 +-
 .../fluid/tests/unittests/test_conv2d_op.py   |  24 +-
 10 files changed, 399 insertions(+), 118 deletions(-)
 create mode 100644 paddle/fluid/operators/conv_mkldnn_op.cc

diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc
index ac6289c5abe8f..4ee17a6da1dd2 100644
--- a/paddle/fluid/framework/operator.cc
+++ b/paddle/fluid/framework/operator.cc
@@ -600,6 +600,23 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
   return static_cast<proto::VarType::Type>(data_type);
 }
 
+bool OperatorWithKernel::CanCUDNNBeUsed(const ExecutionContext& ctx) const {
+  bool use_cudnn = ctx.Attr<bool>("use_cudnn");
+  use_cudnn &= paddle::platform::is_gpu_place(ctx.GetPlace());
+#ifdef PADDLE_WITH_CUDA
+  if (use_cudnn) {
+    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
+    use_cudnn &= dev_ctx.cudnn_handle() != nullptr;
+  }
+#endif
+  return use_cudnn;
+}
+
+bool OperatorWithKernel::CanMKLDNNBeUsed(const ExecutionContext& ctx) const {
+  bool use_mkldnn = ctx.Attr<bool>("use_mkldnn");
+  return use_mkldnn && platform::is_cpu_place(ctx.GetPlace());
+}
+
 OpKernelType OperatorWithKernel::GetExpectedKernelType(
     const ExecutionContext& ctx) const {
   return OpKernelType(IndicateDataType(ctx), ctx.GetPlace());
diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h
index 41214b41cb68c..4366aa67483b6 100644
--- a/paddle/fluid/framework/operator.h
+++ b/paddle/fluid/framework/operator.h
@@ -392,6 +392,8 @@ class OperatorWithKernel : public OperatorBase {
   virtual OpKernelType GetKernelTypeForVar(
       const std::string& var_name, const Tensor& tensor,
       const OpKernelType& expected_kernel_type) const;
+  bool CanCUDNNBeUsed(const ExecutionContext& ctx) const;
+  bool CanMKLDNNBeUsed(const ExecutionContext& ctx) const;
 
  private:
   // indicate kernel DataType by input data. By default all input data must be
diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt
index e1c02ec1613a7..3647a0bb9b4e3 100644
--- a/paddle/fluid/operators/CMakeLists.txt
+++ b/paddle/fluid/operators/CMakeLists.txt
@@ -1,5 +1,7 @@
 file(GLOB GENERAL_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_op.cc")
+string(REPLACE "_mkldnn" "" GENERAL_OPS "${GENERAL_OPS}")
 string(REPLACE ".cc" "" GENERAL_OPS "${GENERAL_OPS}")
+list(REMOVE_DUPLICATES GENERAL_OPS)
 set(DEPS_OPS "")
 set(pybind_file ${PADDLE_SOURCE_DIR}/paddle/fluid/pybind/pybind.h)
 file(WRITE ${pybind_file} "// Generated by the paddle/operator/CMakeLists.txt.  DO NOT EDIT!\n\n")
@@ -13,6 +15,8 @@ function(op_library TARGET)
   set(cu_cc_srcs)
   set(cudnn_cu_cc_srcs)
   set(CUDNN_FILE)
+  set(mkldnn_cc_srcs)
+  set(MKLDNN_FILE)
   set(op_common_deps operator op_registry math_function)
   set(options "")
   set(oneValueArgs "")
@@ -36,12 +40,20 @@ function(op_library TARGET)
     if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${CUDNN_FILE}.cu.cc)
       list(APPEND cudnn_cu_cc_srcs ${CUDNN_FILE}.cu.cc)
     endif()
+    if(WITH_MKLDNN)
+      string(REPLACE "_op" "_mkldnn_op" MKLDNN_FILE "${TARGET}")
+      if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${MKLDNN_FILE}.cc)
+        list(APPEND mkldnn_cc_srcs ${MKLDNN_FILE}.cc)
+      endif()
+    endif()
   else()
     foreach(src ${op_library_SRCS})
       if (${src} MATCHES ".*\\.cu$")
         list(APPEND cu_srcs ${src})
       elseif(${src} MATCHES ".*_cudnn_op.cu.cc$")
         list(APPEND cudnn_cu_cc_srcs ${src})
+      elseif(WITH_MKLDNN AND ${src} MATCHES ".*_mkldnn_op.cc$")
+        list(APPEND mkldnn_cc_srcs ${src})
       elseif(${src} MATCHES ".*\\.cu.cc$")
         list(APPEND cu_cc_srcs ${src})
       elseif(${src} MATCHES ".*\\.cc$")
@@ -62,11 +74,11 @@ function(op_library TARGET)
     set(DEPS_OPS ${TARGET} ${DEPS_OPS} PARENT_SCOPE)
   endif()
   if (WITH_GPU)
-    nv_library(${TARGET} SRCS ${cc_srcs} ${cu_cc_srcs} ${cudnn_cu_cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS}
+    nv_library(${TARGET} SRCS ${cc_srcs} ${cu_cc_srcs} ${cudnn_cu_cc_srcs} ${mkldnn_cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS}
       ${op_common_deps})
   else()
-    cc_library(${TARGET} SRCS ${cc_srcs} DEPS ${op_library_DEPS}
-      ${op_common_deps})
+    cc_library(${TARGET} SRCS ${cc_srcs} ${mkldnn_cc_srcs} DEPS ${op_library_DEPS}
+      ${op_common_deps})
   endif()
 
   # Define operators that don't need pybind here.
@@ -101,7 +113,8 @@ function(op_library TARGET)
   # pybind USE_CPU_ONLY_OP
   list(LENGTH cu_srcs cu_srcs_len)
   list(LENGTH cu_cc_srcs cu_cc_srcs_len)
-  if (${pybind_flag} EQUAL 0 AND ${cu_srcs_len} EQUAL 0 AND ${cu_cc_srcs_len} EQUAL 0)
+  list(LENGTH mkldnn_cc_srcs mkldnn_cc_srcs_len)
+  if (${pybind_flag} EQUAL 0 AND ${mkldnn_cc_srcs_len} EQUAL 0 AND ${cu_srcs_len} EQUAL 0 AND ${cu_cc_srcs_len} EQUAL 0)
     file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(${TARGET});\n")
     set(pybind_flag 1)
   endif()
@@ -112,6 +125,11 @@ function(op_library TARGET)
     file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, CUDNN);\n")
   endif()
 
+  # pybind USE_OP_DEVICE_KERNEL for MKLDNN
+  if (WITH_MKLDNN AND ${mkldnn_cc_srcs_len} GREATER 0)
+    file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MKLDNN);\n")
+  endif()
+
   # pybind USE_OP
   if (${pybind_flag} EQUAL 0)
     file(APPEND ${pybind_file} "USE_OP(${TARGET});\n")
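Built with WITH_MKLDNN=ON, these rules should make op_library(conv_op) pick up conv_mkldnn_op.cc (added below) and emit an extra registration into the generated pybind header. Roughly what the generated lines for conv2d would then contain (a sketch only; the exact output depends on the build flags and on the cudnn sources being present):

    // pybind.h (generated) -- illustrative, not checked-in source
    USE_OP(conv2d);
    USE_OP_DEVICE_KERNEL(conv2d, CUDNN);   // GPU build only
    USE_OP_DEVICE_KERNEL(conv2d, MKLDNN);  // emitted by the rule added above
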
diff --git a/paddle/fluid/operators/conv_mkldnn_op.cc b/paddle/fluid/operators/conv_mkldnn_op.cc
new file mode 100644
index 0000000000000..fb3872e730825
--- /dev/null
+++ b/paddle/fluid/operators/conv_mkldnn_op.cc
@@ -0,0 +1,274 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "mkldnn.hpp"
+#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/operators/conv_op.h"
+#include "paddle/fluid/platform/device_context.h"
+
+namespace paddle {
+namespace operators {
+
+using paddle::framework::Tensor;
+using paddle::platform::MKLDNNDeviceContext;
+
+using mkldnn::memory;  // Note: Paddle also has a "memory" namespace
+using mkldnn::primitive;
+using mkldnn::convolution_forward;
+using mkldnn::convolution_backward_weights;
+using mkldnn::convolution_backward_data;
+using mkldnn::convolution_direct;
+using mkldnn::prop_kind;
+using mkldnn::padding_kind;
+using mkldnn::stream;
+
+template <typename T>
+class ConvOpMkldnnKernel : public paddle::framework::OpKernel<T> {
+ public:
+  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
+    PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
+                   "It must use CPUPlace.");
+
+    auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
+    auto mkldnn_engine = dev_ctx.GetEngine();
+
+    auto* input = ctx.Input<Tensor>("Input");
+    auto* filter = ctx.Input<Tensor>("Filter");
+    auto* output = ctx.Output<Tensor>("Output");
+
+    // Get a unique name from the "argument" name of the "Output" variable.
+    // This name is used as the key when saving info into the device context
+    const std::string key = ctx.op().Output("Output");
+    const std::string key_conv_pd = key + "@conv_pd";
+
+    std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
+    std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
+    std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
+    int groups = ctx.Attr<int>("groups");
+
+    PADDLE_ENFORCE(groups == 1, "MKLDNN doesn't support group convolution yet");
+    PADDLE_ENFORCE(
+        dilations.size() == 2 && dilations[0] == 1 && dilations[1] == 1,
+        "MKLDNN doesn't support dilation in convolution yet");
+
+    const T* input_data = input->data<T>();
+    const T* filter_data = filter->data<T>();
+    // allocate memory for output
+    T* output_data = output->mutable_data<T>(ctx.GetPlace());
+
+    PADDLE_ENFORCE(input->dims().size() == 4,
+                   "Input must be with 4 dimensions, i.e. NCHW");
+    PADDLE_ENFORCE(filter->dims().size() == 4,
+                   "Filter must be with 4 dimensions, i.e. OIHW");
+
+    std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
+    std::vector<int> weights_tz =
+        paddle::framework::vectorize2int(filter->dims());
+    std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
+
+    // MKLDNN primitives
+    // memory descriptors for convolution src/weight/dst
+    memory::dims conv_src_tz = {src_tz[0], src_tz[1], src_tz[2], src_tz[3]};
+    memory::dims conv_weights_tz = {weights_tz[0], weights_tz[1], weights_tz[2],
+                                    weights_tz[3]};
+    memory::dims conv_dst_tz = {dst_tz[0], dst_tz[1], dst_tz[2], dst_tz[3]};
+
+    memory::dims conv_strides = {strides[0], strides[1]};
+    memory::dims conv_padding = {paddings[0], paddings[1]};
+
+    auto conv_src_md = memory::desc({conv_src_tz}, memory::data_type::f32,
+                                    memory::format::nchw);
+    auto conv_weights_md = memory::desc(
+        {conv_weights_tz}, memory::data_type::f32, memory::format::oihw);
+    auto conv_dst_md = memory::desc({conv_dst_tz}, memory::data_type::f32,
+                                    memory::format::nchw);
+
+    // create memory primitives
+    auto conv_src_memory =
+        memory({conv_src_md, mkldnn_engine}, (void*)input_data);
+    auto conv_weights_memory =
+        memory({conv_weights_md, mkldnn_engine}, (void*)filter_data);
+    auto conv_dst_memory = memory({conv_dst_md, mkldnn_engine}, output_data);
+
+    // create convolution op descriptor
+    auto conv_desc = convolution_forward::desc(
+        prop_kind::forward, convolution_direct, conv_src_md, conv_weights_md,
+        conv_dst_md, conv_strides, conv_padding, conv_padding,
+        padding_kind::zero);
+
+    // the conv primitive desc needs to be used in the backward path,
+    // so we allocate it on the heap (instead of on the stack)
+    convolution_forward::primitive_desc* p_conv_pd;
+    p_conv_pd =
+        new convolution_forward::primitive_desc(conv_desc, mkldnn_engine);
+    // save conv_pd into dev_ctx to be referred to in the backward path
+    std::shared_ptr<void> conv_pd;
+    conv_pd.reset(p_conv_pd);
+    dev_ctx.SetBlob(key_conv_pd, conv_pd);
+
+    // create convolution op primitive
+    auto conv_prim = convolution_forward(*p_conv_pd, conv_src_memory,
+                                         conv_weights_memory, conv_dst_memory);
+
+    // push the op to the stream and wait until MKLDNN has executed it
+    std::vector<primitive> pipeline{conv_prim};
+    stream(stream::kind::eager).submit(pipeline).wait();
+  }
+};
+
+template <typename T>
+class ConvGradOpMkldnnKernel : public paddle::framework::OpKernel<T> {
+ public:
+  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
+    PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
+                   "It must use CPUPlace.");
+
+    auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
+    auto mkldnn_engine = dev_ctx.GetEngine();
+
+    const Tensor* input = ctx.Input<Tensor>("Input");
+    const Tensor* filter = ctx.Input<Tensor>("Filter");
+    const Tensor* output = ctx.Input<Tensor>("Output");
+    const Tensor* output_grad =
+        ctx.Input<Tensor>(framework::GradVarName("Output"));
+    Tensor* input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
+    Tensor* filter_grad = ctx.Output<Tensor>(framework::GradVarName("Filter"));
+
+    if (!input_grad && !filter_grad) return;
+
+    // Get a unique name from the "argument" name of the "Output" variable.
+    // This name is used as the key when saving info into the device context
+    const std::string key = ctx.op().Input("Output");
+    const std::string key_conv_pd = key + "@conv_pd";
+
+    std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
+    std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
+
+    const T* input_data = input->data<T>();
+    const T* filter_data = filter->data<T>();
+    const T* output_grad_data = output_grad->data<T>();
+    T* input_grad_data = nullptr;
+    T* filter_grad_data = nullptr;
+
+    // allocate memory for the gradients of input/filter
+    if (input_grad) {
+      input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
+    }
+    if (filter_grad) {
+      filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace());
+    }
+
+    std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
+    std::vector<int> weights_tz =
+        paddle::framework::vectorize2int(filter->dims());
+    std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
+
+    // MKLDNN primitives
+    // memory descriptors for convolution src/weight/dst
+    memory::dims conv_src_tz = {src_tz[0], src_tz[1], src_tz[2], src_tz[3]};
+    memory::dims conv_weights_tz = {weights_tz[0], weights_tz[1], weights_tz[2],
+                                    weights_tz[3]};
+    memory::dims conv_dst_tz = {dst_tz[0], dst_tz[1], dst_tz[2], dst_tz[3]};
+
+    memory::dims conv_strides = {strides[0], strides[1]};
+    memory::dims conv_padding = {paddings[0], paddings[1]};
+
+    auto conv_src_md = memory::desc({conv_src_tz}, memory::data_type::f32,
+                                    memory::format::nchw);
+    auto conv_diff_src_md = memory::desc({conv_src_tz}, memory::data_type::f32,
+                                         memory::format::nchw);
+    auto conv_weights_md = memory::desc(
+        {conv_weights_tz}, memory::data_type::f32, memory::format::oihw);
+    auto conv_diff_weights_md = memory::desc(
+        {conv_weights_tz}, memory::data_type::f32, memory::format::oihw);
+    auto conv_diff_dst_md = memory::desc({conv_dst_tz}, memory::data_type::f32,
+                                         memory::format::nchw);
+
+    // create memory
+    auto conv_diff_dst_memory =
+        memory({conv_diff_dst_md, mkldnn_engine}, (void*)output_grad_data);
+    // Retrieve conv_pd from device context
+    std::shared_ptr<void> conv_pd;
+    convolution_forward::primitive_desc* p_conv_pd;
+
+    conv_pd = dev_ctx.GetBlob(key_conv_pd);
+    PADDLE_ENFORCE(conv_pd != nullptr,
+                   "Failed to find conv_pd in device context");
+    p_conv_pd =
+        static_cast<convolution_forward::primitive_desc*>(conv_pd.get());
+
+    // create backward conv primitive for weights
+    if (filter_grad) {
+      // create memory
+      auto conv_diff_weights_memory = memory(
+          {conv_diff_weights_md, mkldnn_engine}, (void*)filter_grad_data);
+      auto conv_src_memory =
+          memory({conv_src_md, mkldnn_engine}, (void*)input_data);
+
+      // create primitive descriptor
+      auto conv_bwd_weights_desc = convolution_backward_weights::desc(
+          convolution_direct, conv_src_md, conv_diff_weights_md,
+          conv_diff_dst_md, conv_strides, conv_padding, conv_padding,
+          padding_kind::zero);
+      auto conv_bwd_weights_pd = convolution_backward_weights::primitive_desc(
+          conv_bwd_weights_desc, mkldnn_engine,
+          *p_conv_pd);  // Needs the forward desc as a hint
+
+      // create backward conv primitive for weights
+      auto conv_bwd_weights_prim = convolution_backward_weights(
+          conv_bwd_weights_pd, conv_src_memory, conv_diff_dst_memory,
+          conv_diff_weights_memory);
+
+      // push the primitive and execute it
+      std::vector<primitive> pipeline{conv_bwd_weights_prim};
+      stream(stream::kind::eager).submit(pipeline).wait();
+    }
+
+    if (input_grad) {
+      // create memory
+      auto conv_diff_src_memory =
+          memory({conv_diff_src_md, mkldnn_engine}, (void*)input_grad_data);
+      auto conv_weights_memory =
+          memory({conv_weights_md, mkldnn_engine}, (void*)filter_data);
+      // create primitive descriptor
+      auto conv_bwd_data_desc = convolution_backward_data::desc(
+          convolution_direct, conv_diff_src_md, conv_weights_md,
+          conv_diff_dst_md, conv_strides, conv_padding, conv_padding,
+          padding_kind::zero);
+      auto conv_bwd_data_pd = convolution_backward_data::primitive_desc(
+          conv_bwd_data_desc, mkldnn_engine,
+          *p_conv_pd);  // Needs the forward desc as a hint
+
+      // create backward conv primitive for data
+      auto conv_bwd_data_prim =
+          convolution_backward_data(conv_bwd_data_pd, conv_diff_dst_memory,
+                                    conv_weights_memory, conv_diff_src_memory);
+
+      // push the primitive and execute it
+      std::vector<primitive> pipeline{conv_bwd_data_prim};
+      stream(stream::kind::eager).submit(pipeline).wait();
+    }
+  }  // Compute()
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OP_KERNEL(conv2d, MKLDNN, ::paddle::platform::CPUPlace,
+                   ops::ConvOpMkldnnKernel<float>);
+
+REGISTER_OP_KERNEL(conv2d_grad, MKLDNN, ::paddle::platform::CPUPlace,
+                   ops::ConvGradOpMkldnnKernel<float>);
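Both kernels above follow the canonical Intel MKL-DNN 0.x execution flow: describe the memory layouts, build a primitive descriptor, wrap the raw buffers, then submit to an eager stream. A condensed, self-contained sketch of that flow with illustrative shapes and no Paddle types (assumes the 0.x C++ API, where mkldnn::memory::dims is a vector of int):

    #include <vector>
    #include "mkldnn.hpp"

    // 3x3 convolution, stride 1, padding 1: 1x3x224x224 -> 1x8x224x224.
    void conv_fwd_sketch(const float* src_buf, const float* wgt_buf,
                         float* dst_buf) {
      using namespace mkldnn;
      engine cpu_engine(engine::cpu, 0);

      // Describe the layouts MKL-DNN should assume for each buffer.
      memory::desc src_md({1, 3, 224, 224}, memory::data_type::f32,
                          memory::format::nchw);
      memory::desc wgt_md({8, 3, 3, 3}, memory::data_type::f32,
                          memory::format::oihw);
      memory::desc dst_md({1, 8, 224, 224}, memory::data_type::f32,
                          memory::format::nchw);

      // Wrap the caller's buffers; MKL-DNN does not take ownership.
      memory src({src_md, cpu_engine}, const_cast<float*>(src_buf));
      memory wgt({wgt_md, cpu_engine}, const_cast<float*>(wgt_buf));
      memory dst({dst_md, cpu_engine}, dst_buf);

      // Descriptor -> primitive descriptor -> primitive, then run eagerly.
      convolution_forward::desc conv_desc(
          prop_kind::forward, convolution_direct, src_md, wgt_md, dst_md,
          {1, 1} /*strides*/, {1, 1} /*padding_l*/, {1, 1} /*padding_r*/,
          padding_kind::zero);
      convolution_forward::primitive_desc pd(conv_desc, cpu_engine);
      convolution_forward conv(pd, src, wgt, dst);

      std::vector<primitive> pipeline{conv};
      stream(stream::kind::eager).submit(pipeline).wait();
    }
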
diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc
index 83b7708bf337b..aff79bca9200c 100644
--- a/paddle/fluid/operators/conv_op.cc
+++ b/paddle/fluid/operators/conv_op.cc
@@ -64,19 +64,11 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
 
 framework::OpKernelType ConvOp::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
-  bool use_cudnn = ctx.Attr<bool>("use_cudnn");
-  use_cudnn &= platform::is_gpu_place(ctx.GetPlace());
-#ifdef PADDLE_WITH_CUDA
-  if (platform::is_gpu_place(ctx.GetPlace())) {
-    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
-    use_cudnn &= dev_ctx.cudnn_handle() != nullptr;
-  }
-#endif
-  framework::LibraryType library_;
-  if (use_cudnn) {
+  framework::LibraryType library_{framework::LibraryType::kPlain};
+  if (CanCUDNNBeUsed(ctx)) {
     library_ = framework::LibraryType::kCUDNN;
-  } else {
-    library_ = framework::LibraryType::kPlain;
+  } else if (CanMKLDNNBeUsed(ctx)) {
+    library_ = framework::LibraryType::kMKLDNN;
   }
 
   std::string data_format = ctx.Attr<std::string>("data_format");
@@ -131,6 +123,9 @@ Conv2DOpMaker::Conv2DOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       "use_cudnn",
       "(bool, default false) Only used in cudnn kernel, need install cudnn")
       .SetDefault(false);
+  AddAttr<bool>("use_mkldnn",
+                "(bool, default false) Only used in mkldnn kernel")
+      .SetDefault(false);
   AddAttr<std::string>(
       "data_format",
       "(string, default NCHW) Only used in "
@@ -224,6 +219,9 @@ Conv3DOpMaker::Conv3DOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       "use_cudnn",
       "(bool, default false) Only used in cudnn kernel, need install cudnn")
       .SetDefault(false);
+  AddAttr<bool>("use_mkldnn",
+                "(bool, default false) Only used in mkldnn kernel")
+      .SetDefault(false);
   AddAttr<std::string>(
       "data_format",
       "(string, default NCHW) Only used in "
@@ -284,20 +282,11 @@ void ConvOpGrad::InferShape(framework::InferShapeContext* ctx) const {
 
 framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
-  bool use_cudnn = ctx.Attr<bool>("use_cudnn");
-  use_cudnn &= platform::is_gpu_place(ctx.GetPlace());
-#ifdef PADDLE_WITH_CUDA
-  if (platform::is_gpu_place(ctx.GetPlace())) {
-    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
-    use_cudnn &= dev_ctx.cudnn_handle() != nullptr;
-  }
-#endif
-
-  framework::LibraryType library_;
-  if (use_cudnn) {
+  framework::LibraryType library_{framework::LibraryType::kPlain};
+  if (CanCUDNNBeUsed(ctx)) {
     library_ = framework::LibraryType::kCUDNN;
-  } else {
-    library_ = framework::LibraryType::kPlain;
+  } else if (CanMKLDNNBeUsed(ctx)) {
+    library_ = framework::LibraryType::kMKLDNN;
   }
 
   std::string data_format = ctx.Attr<std::string>("data_format");
diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc
index 7da6e04d0a8b8..326ff67ab9a01 100644
--- a/paddle/fluid/platform/device_context.cc
+++ b/paddle/fluid/platform/device_context.cc
@@ -33,9 +33,15 @@ DeviceContextPool::DeviceContextPool(
   PADDLE_ENFORCE_GT(places.size(), 0);
   for (size_t i = 0; i < places.size(); i++) {
     if (platform::is_cpu_place(places[i])) {
+#ifdef PADDLE_WITH_MKLDNN
+      device_contexts_.emplace(places[i],
+                               new platform::MKLDNNDeviceContext(
+                                   boost::get<platform::CPUPlace>(places[i])));
+#else
       device_contexts_.emplace(places[i],
                                new platform::CPUDeviceContext(
                                    boost::get<platform::CPUPlace>(places[i])));
+#endif
     } else if (platform::is_gpu_place(places[i])) {
 #ifdef PADDLE_WITH_CUDA
       device_contexts_.emplace(places[i],
@@ -170,64 +176,38 @@ cudaStream_t CUDADeviceContext::stream() const { return stream_; }
 
 #ifdef PADDLE_WITH_MKLDNN
 MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
-    : CPUDeviceContext(place), ready_(false) {
-  stream_.reset(new mkldnn::stream(mkldnn::stream::kind::eager));
-  engine_.reset(new mkldnn::engine(mkldnn::engine::cpu, 0));
+    : CPUDeviceContext(place), engine_(mkldnn::engine::cpu, 0), p_blobs_() {
+  p_blobs_.reset(new std::unordered_map<std::string, std::shared_ptr<void>>());
 }
 
-template <typename T>
-void MKLDNNDeviceContext::AddElement(const std::string& op_key,
-                                     const T& value) {
-  if (GetElement<T>(op_key)) {
-    return;
-  }
-  GetElementPool<T>().emplace(op_key, std::move(value));
-}
+void MKLDNNDeviceContext::SetBlob(const std::string& name,
+                                  std::shared_ptr<void> data) const {
+  std::unordered_map<std::string, std::shared_ptr<void>>* p;
+  p = p_blobs_.get();
 
-template <typename T>
-const T& MKLDNNDeviceContext::GetElement(const std::string& op_key) const {
-  auto it = GetElementPool<T>().find(op_key);
-  return it == GetElementPool<T>().end() ? nullptr : it->second;
-}
+  auto it = p->find(name);
 
-template <>
-const std::unordered_map<std::string, MKLDNNMemoryPtr>&
-MKLDNNDeviceContext::GetElementPool<MKLDNNMemoryPtr>() const {
-  return memory_pool_;
-}
+  if (it == p->end()) {
+    (*p)[name] = data;  // create new blob
+  } else {
+    it->second = data;  // set data to existing blob
+  }
 
-template <>
-const std::unordered_map<std::string, MKLDNNPrimitivePtr>&
-MKLDNNDeviceContext::GetElementPool<MKLDNNPrimitivePtr>() const {
-  return primitive_pool_;
+  return;
 }
 
-template <>
-const std::unordered_map<std::string, MKLDNNPrimitiveDescPtr>&
-MKLDNNDeviceContext::GetElementPool<MKLDNNPrimitiveDescPtr>() const {
-  return primitive_desc_pool_;
-}
+std::shared_ptr<void> MKLDNNDeviceContext::GetBlob(
+    const std::string& name) const {
+  std::unordered_map<std::string, std::shared_ptr<void>>* p;
+  p = p_blobs_.get();
 
-void MKLDNNDeviceContext::Execute(bool block) {
-  if (pipeline_.empty()) {
-    return;
-  }
-  ResetStream();
-  stream_->submit(pipeline_).wait(block);
-  ready_ = false;
-  pipeline_.clear();
-}
+  auto it = p->find(name);
 
-void MKLDNNDeviceContext::ResetStream() {
-  if (ready_) {
-    return;
+  if (it != p->end()) {
+    return it->second;
   }
-  // TODO(TJ): change me when mkldnn has a specific method to reset this state
-  stream_.reset(new mkldnn::stream(mkldnn::stream::kind::eager));
-  ready_ = true;
+
+  return nullptr;
 }
 #endif
diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h
index a294ba5101528..2cfce226b77dd 100644
--- a/paddle/fluid/platform/device_context.h
+++ b/paddle/fluid/platform/device_context.h
@@ -114,46 +114,19 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
  public:
   explicit MKLDNNDeviceContext(CPUPlace place);
 
-  /* \brief Add new element: memory, primitive or primitive desc */
-  template <typename T>
-  void AddElement(const std::string& op_key, const T& value);
-
-  /* \brief Get existing element: memory, primitive or primitive desc */
-  template <typename T>
-  const T& GetElement(const std::string& op_key) const;
-
-  /* \brief Get element pool: memory, primitive or primitive desc pool */
-  template <typename T>
-  const std::unordered_map<std::string, T>& GetElementPool() const;
-
   /* \brief Get the active engine */
-  const MKLDNNEngine& engine() const { return *engine_; }
-
-  /* \brief Submit primitive to pipeline */
-  void Submit(const MKLDNNPrimitivePtr& p) { pipeline_.push_back(*p); }
+  const mkldnn::engine& GetEngine() const { return engine_; }
 
-  /*! \brief Execute all submitted primitives in pipeline */
-  void Execute(bool block = true);
+  // Set data to blob (i.e. name/data pair). Create blob if not existing
+  void SetBlob(const std::string& name, std::shared_ptr<void> data) const;
 
- protected:
-  /*! \brief Reset the stream to prepare for the next execute */
-  void ResetStream();
+  // Find a saved blob. Return nullptr if not found
+  std::shared_ptr<void> GetBlob(const std::string& name) const;
 
  private:
-  std::unordered_map<std::string, MKLDNNMemoryPtr> memory_pool_;
-  std::unordered_map<std::string, MKLDNNPrimitivePtr> primitive_pool_;
-  std::unordered_map<std::string, MKLDNNPrimitiveDescPtr> primitive_desc_pool_;
-  std::vector<MKLDNNPrimitive> pipeline_;
-  MKLDNNStreamPtr stream_;
-  MKLDNNEnginePtr engine_;
-  bool ready_;
+  mkldnn::engine engine_;
+  std::shared_ptr<std::unordered_map<std::string, std::shared_ptr<void>>>
+      p_blobs_;
 };
 
 #endif
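The blob map turns MKLDNNDeviceContext into a small type-erased cache keyed by name, which is how the conv kernels hand the forward primitive descriptor to the backward pass. A hedged sketch of the round trip (function and key names are illustrative, not Paddle source):

    // Forward pass: store a copy of the primitive descriptor under a
    // per-output key; SetBlob overwrites any existing entry with that name.
    void Stash(const paddle::platform::MKLDNNDeviceContext& dev_ctx,
               const mkldnn::convolution_forward::primitive_desc& pd) {
      auto holder =
          std::make_shared<mkldnn::convolution_forward::primitive_desc>(pd);
      dev_ctx.SetBlob("conv2d_out@conv_pd", holder);
    }

    // Backward pass: GetBlob returns shared_ptr<void> (nullptr on a miss),
    // so the caller downcasts to the type it originally stored. The map
    // keeps the blob alive for as long as the context exists.
    mkldnn::convolution_forward::primitive_desc* Fetch(
        const paddle::platform::MKLDNNDeviceContext& dev_ctx) {
      std::shared_ptr<void> blob = dev_ctx.GetBlob("conv2d_out@conv_pd");
      if (blob == nullptr) return nullptr;
      return static_cast<mkldnn::convolution_forward::primitive_desc*>(
          blob.get());
    }
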
""" @@ -90,7 +93,8 @@ def __extend_list__(obj): padding=conv_padding[i], param_attr=param_attr[i], act=local_conv_act, - use_cudnn=use_cudnn) + use_cudnn=use_cudnn, + use_mkldnn=use_mkldnn) if conv_with_batchnorm[i]: tmp = layers.batch_norm(input=tmp, act=conv_act) diff --git a/python/paddle/fluid/tests/unittests/test_conv2d_op.py b/python/paddle/fluid/tests/unittests/test_conv2d_op.py index 1321cfd484ec8..a49fecf09509f 100644 --- a/python/paddle/fluid/tests/unittests/test_conv2d_op.py +++ b/python/paddle/fluid/tests/unittests/test_conv2d_op.py @@ -64,6 +64,7 @@ def conv2d_forward_naive(input, filter, group, conv_param): class TestConv2dOp(OpTest): def setUp(self): self.use_cudnn = False + self.use_mkldnn = False self.init_op_type() self.init_group() self.init_dilation() @@ -85,7 +86,8 @@ def setUp(self): 'paddings': self.pad, 'groups': self.groups, 'dilations': self.dilations, - 'use_cudnn': self.use_cudnn + 'use_cudnn': self.use_cudnn, + 'use_mkldnn': self.use_mkldnn } self.outputs = {'Output': output} @@ -290,5 +292,25 @@ def init_test_case(self): # def init_op_type(self): # self.op_type = "conv_cudnn" + +#----------------Conv2dMKLDNN---------------- +class TestMKLDNN(TestConv2dOp): + def init_op_type(self): + self.use_mkldnn = True + self.op_type = "conv2d" + + +class TestMKLDNNWithPad(TestWithPad): + def init_op_type(self): + self.use_mkldnn = True + self.op_type = "conv2d" + + +class TestMKLDNNWithStride(TestWithStride): + def init_op_type(self): + self.use_mkldnn = True + self.op_type = "conv2d" + + if __name__ == '__main__': unittest.main() From 06ca8d241278d96dd49cbeeeeee15e2146574b0a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20=C5=BBelazko?= Date: Fri, 2 Mar 2018 13:34:01 +0100 Subject: [PATCH 2/4] TODOs added --- paddle/fluid/operators/conv_mkldnn_op.cc | 8 +++++--- paddle/fluid/operators/conv_op.cc | 2 ++ 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/conv_mkldnn_op.cc b/paddle/fluid/operators/conv_mkldnn_op.cc index fb3872e730825..48c1aa42a65d7 100644 --- a/paddle/fluid/operators/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/conv_mkldnn_op.cc @@ -15,7 +15,6 @@ #include "mkldnn.hpp" #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/operators/conv_op.h" -#include "paddle/fluid/platform/device_context.h" namespace paddle { namespace operators { @@ -41,7 +40,7 @@ class ConvOpMkldnnKernel : public paddle::framework::OpKernel { "It must use CPUPlace."); auto& dev_ctx = ctx.template device_context(); - auto mkldnn_engine = dev_ctx.GetEngine(); + const auto& mkldnn_engine = dev_ctx.GetEngine(); auto* input = ctx.Input("Input"); auto* filter = ctx.Input("Filter"); @@ -57,6 +56,7 @@ class ConvOpMkldnnKernel : public paddle::framework::OpKernel { std::vector dilations = ctx.Attr>("dilations"); int groups = ctx.Attr("groups"); + // TODO(pzelazko-intel) enable group convolution PADDLE_ENFORCE(groups == 1, "MKLDNN doesn't support group convolution yet"); PADDLE_ENFORCE( dilations.size() == 2 && dilations[0] == 1 && dilations[1] == 1, @@ -87,6 +87,7 @@ class ConvOpMkldnnKernel : public paddle::framework::OpKernel { memory::dims conv_strides = {strides[0], strides[1]}; memory::dims conv_padding = {paddings[0], paddings[1]}; + // TODO(pzelazko-intel): support more formats auto conv_src_md = memory::desc({conv_src_tz}, memory::data_type::f32, memory::format::nchw); auto conv_weights_md = memory::desc( @@ -135,7 +136,7 @@ class ConvGradOpMkldnnKernel : public paddle::framework::OpKernel { "It must use CPUPlace."); auto& dev_ctx = 
From 06ca8d241278d96dd49cbeeeeee15e2146574b0a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pawe=C5=82=20=C5=BBelazko?=
Date: Fri, 2 Mar 2018 13:34:01 +0100
Subject: [PATCH 2/4] TODOs added

---
 paddle/fluid/operators/conv_mkldnn_op.cc | 8 +++++---
 paddle/fluid/operators/conv_op.cc        | 2 ++
 2 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/paddle/fluid/operators/conv_mkldnn_op.cc b/paddle/fluid/operators/conv_mkldnn_op.cc
index fb3872e730825..48c1aa42a65d7 100644
--- a/paddle/fluid/operators/conv_mkldnn_op.cc
+++ b/paddle/fluid/operators/conv_mkldnn_op.cc
@@ -15,7 +15,6 @@
 #include "mkldnn.hpp"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/operators/conv_op.h"
-#include "paddle/fluid/platform/device_context.h"
 
 namespace paddle {
 namespace operators {
@@ -41,7 +40,7 @@ class ConvOpMkldnnKernel : public paddle::framework::OpKernel<T> {
                    "It must use CPUPlace.");
 
     auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
-    auto mkldnn_engine = dev_ctx.GetEngine();
+    const auto& mkldnn_engine = dev_ctx.GetEngine();
 
     auto* input = ctx.Input<Tensor>("Input");
     auto* filter = ctx.Input<Tensor>("Filter");
@@ -57,6 +56,7 @@ class ConvOpMkldnnKernel : public paddle::framework::OpKernel<T> {
     std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
     int groups = ctx.Attr<int>("groups");
 
+    // TODO(pzelazko-intel) enable group convolution
     PADDLE_ENFORCE(groups == 1, "MKLDNN doesn't support group convolution yet");
     PADDLE_ENFORCE(
         dilations.size() == 2 && dilations[0] == 1 && dilations[1] == 1,
@@ -87,6 +87,7 @@ class ConvOpMkldnnKernel : public paddle::framework::OpKernel<T> {
     memory::dims conv_strides = {strides[0], strides[1]};
     memory::dims conv_padding = {paddings[0], paddings[1]};
 
+    // TODO(pzelazko-intel): support more formats
     auto conv_src_md = memory::desc({conv_src_tz}, memory::data_type::f32,
                                     memory::format::nchw);
     auto conv_weights_md = memory::desc(
@@ -135,7 +136,7 @@ class ConvGradOpMkldnnKernel : public paddle::framework::OpKernel<T> {
                    "It must use CPUPlace.");
 
     auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
-    auto mkldnn_engine = dev_ctx.GetEngine();
+    const auto& mkldnn_engine = dev_ctx.GetEngine();
 
     const Tensor* input = ctx.Input<Tensor>("Input");
     const Tensor* filter = ctx.Input<Tensor>("Filter");
@@ -184,6 +185,7 @@ class ConvGradOpMkldnnKernel : public paddle::framework::OpKernel<T> {
     memory::dims conv_strides = {strides[0], strides[1]};
     memory::dims conv_padding = {paddings[0], paddings[1]};
 
+    // TODO(pzelazko-intel): support more formats
    auto conv_src_md = memory::desc({conv_src_tz}, memory::data_type::f32,
                                     memory::format::nchw);
     auto conv_diff_src_md = memory::desc({conv_src_tz}, memory::data_type::f32,
diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc
index aff79bca9200c..35bd3a991ebec 100644
--- a/paddle/fluid/operators/conv_op.cc
+++ b/paddle/fluid/operators/conv_op.cc
@@ -72,6 +72,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
   }
 
   std::string data_format = ctx.Attr<std::string>("data_format");
+  // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
   framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
   return framework::OpKernelType(
       framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(),
@@ -290,6 +291,7 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
   }
 
   std::string data_format = ctx.Attr<std::string>("data_format");
+  // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
   framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
   return framework::OpKernelType(
       framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(),
From 0e3f110d86ab52954ae762c9dbe2b5b50688ae88 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pawe=C5=82=20=C5=BBelazko?=
Date: Mon, 5 Mar 2018 14:39:30 +0100
Subject: [PATCH 3/4] mkldnn conv2d OP refactor

---
 paddle/fluid/operators/conv_mkldnn_op.cc | 177 ++++++++++++++---------
 paddle/fluid/platform/mkldnn_helper.h    |   8 +
 2 files changed, 115 insertions(+), 70 deletions(-)

diff --git a/paddle/fluid/operators/conv_mkldnn_op.cc b/paddle/fluid/operators/conv_mkldnn_op.cc
index 48c1aa42a65d7..d59cc2c9d424f 100644
--- a/paddle/fluid/operators/conv_mkldnn_op.cc
+++ b/paddle/fluid/operators/conv_mkldnn_op.cc
@@ -15,12 +15,14 @@
 #include "mkldnn.hpp"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/operators/conv_op.h"
+#include "paddle/fluid/platform/mkldnn_helper.h"
 
 namespace paddle {
 namespace operators {
 
 using paddle::framework::Tensor;
 using paddle::platform::MKLDNNDeviceContext;
+using paddle::platform::MKLDNNMemDesc;
 
 using mkldnn::memory;  // Note: Paddle also has a "memory" namespace
 using mkldnn::primitive;
@@ -32,6 +34,28 @@ using mkldnn::prop_kind;
 using mkldnn::padding_kind;
 using mkldnn::stream;
 
+namespace {
+std::unique_ptr<convolution_forward::primitive_desc> ConvFwdPrimitiveDesc(
+    const memory::desc& src, const memory::desc& weights,
+    const memory::desc& dst, const std::vector<int>& strides,
+    const std::vector<int>& paddings, const mkldnn::engine& engine);
+
+convolution_backward_weights::primitive_desc ConvBwdWeightsPrimitiveDesc(
+    const memory::desc& src, const memory::desc& diff_weights,
+    const memory::desc& diff_dst, const std::vector<int>& strides,
+    const std::vector<int>& paddings,
+    const convolution_forward::primitive_desc& conv_pd,
+    const mkldnn::engine& engine);
+
+convolution_backward_data::primitive_desc ConvBwdDataPrimitiveDesc(
+    const memory::desc& diff_src, const memory::desc& weights,
+    const memory::desc& diff_dst, const std::vector<int>& strides,
+    const std::vector<int>& paddings,
+    const convolution_forward::primitive_desc& conv_pd,
+    const mkldnn::engine& engine);
+}  // anonymous namespace
+
 template <typename T>
 class ConvOpMkldnnKernel : public paddle::framework::OpKernel<T> {
  public:
@@ -56,11 +80,11 @@ class ConvOpMkldnnKernel : public paddle::framework::OpKernel<T> {
     std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
     int groups = ctx.Attr<int>("groups");
 
-    // TODO(pzelazko-intel) enable group convolution
-    PADDLE_ENFORCE(groups == 1, "MKLDNN doesn't support group convolution yet");
+    // TODO(pzelazko-intel) add support for group convolution and dilation
+    PADDLE_ENFORCE(groups == 1, "group convolution is not implemented yet");
     PADDLE_ENFORCE(
         dilations.size() == 2 && dilations[0] == 1 && dilations[1] == 1,
-        "MKLDNN doesn't support dilation in convolution yet");
+        "dilation in convolution is not implemented yet");
 
     const T* input_data = input->data<T>();
     const T* filter_data = filter->data<T>();
@@ -77,23 +101,14 @@ class ConvOpMkldnnKernel : public paddle::framework::OpKernel<T> {
         paddle::framework::vectorize2int(filter->dims());
     std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
 
-    // MKLDNN primitives
-    // memory descriptors for convolution src/weight/dst
-    memory::dims conv_src_tz = {src_tz[0], src_tz[1], src_tz[2], src_tz[3]};
-    memory::dims conv_weights_tz = {weights_tz[0], weights_tz[1], weights_tz[2],
-                                    weights_tz[3]};
-    memory::dims conv_dst_tz = {dst_tz[0], dst_tz[1], dst_tz[2], dst_tz[3]};
-
-    memory::dims conv_strides = {strides[0], strides[1]};
-    memory::dims conv_padding = {paddings[0], paddings[1]};
-
     // TODO(pzelazko-intel): support more formats
-    auto conv_src_md = memory::desc({conv_src_tz}, memory::data_type::f32,
-                                    memory::format::nchw);
-    auto conv_weights_md = memory::desc(
-        {conv_weights_tz}, memory::data_type::f32, memory::format::oihw);
-    auto conv_dst_md = memory::desc({conv_dst_tz}, memory::data_type::f32,
-                                    memory::format::nchw);
+    // memory descriptors for convolution src/weight/dst
+    auto conv_src_md =
+        MKLDNNMemDesc(src_tz, memory::data_type::f32, memory::format::nchw);
+    auto conv_weights_md =
+        MKLDNNMemDesc(weights_tz, memory::data_type::f32, memory::format::oihw);
+    auto conv_dst_md =
+        MKLDNNMemDesc(dst_tz, memory::data_type::f32, memory::format::nchw);
 
     // create memory primitives
     auto conv_src_memory =
@@ -102,21 +117,14 @@ class ConvOpMkldnnKernel : public paddle::framework::OpKernel<T> {
         memory({conv_weights_md, mkldnn_engine}, (void*)filter_data);
     auto conv_dst_memory = memory({conv_dst_md, mkldnn_engine}, output_data);
 
-    // create convolution op descriptor
-    auto conv_desc = convolution_forward::desc(
-        prop_kind::forward, convolution_direct, conv_src_md, conv_weights_md,
-        conv_dst_md, conv_strides, conv_padding, conv_padding,
-        padding_kind::zero);
+    std::unique_ptr<convolution_forward::primitive_desc> conv_pd =
+        ConvFwdPrimitiveDesc(conv_src_md, conv_weights_md, conv_dst_md, strides,
+                             paddings, mkldnn_engine);
 
-    // the conv primitive desc needs to be used in the backward path,
-    // so we allocate it on the heap (instead of on the stack)
-    convolution_forward::primitive_desc* p_conv_pd;
-    p_conv_pd =
-        new convolution_forward::primitive_desc(conv_desc, mkldnn_engine);
-    // save conv_pd into dev_ctx to be referred to in the backward path
-    std::shared_ptr<void> conv_pd;
-    conv_pd.reset(p_conv_pd);
-    dev_ctx.SetBlob(key_conv_pd, conv_pd);
+    // save p_conv_pd into dev_ctx to be referred to in the backward path
+    auto p_conv_pd = conv_pd.get();
+    std::shared_ptr<void> conv_pd_value = std::move(conv_pd);
+    dev_ctx.SetBlob(key_conv_pd, conv_pd_value);
 
     // create convolution op primitive
     auto conv_prim = convolution_forward(*p_conv_pd, conv_src_memory,
@@ -175,27 +183,17 @@ class ConvGradOpMkldnnKernel : public paddle::framework::OpKernel<T> {
         paddle::framework::vectorize2int(filter->dims());
     std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
 
-    // MKLDNN primitives
-    // memory descriptors for convolution src/weight/dst
-    memory::dims conv_src_tz = {src_tz[0], src_tz[1], src_tz[2], src_tz[3]};
-    memory::dims conv_weights_tz = {weights_tz[0], weights_tz[1], weights_tz[2],
-                                    weights_tz[3]};
-    memory::dims conv_dst_tz = {dst_tz[0], dst_tz[1], dst_tz[2], dst_tz[3]};
-
-    memory::dims conv_strides = {strides[0], strides[1]};
-    memory::dims conv_padding = {paddings[0], paddings[1]};
-
     // TODO(pzelazko-intel): support more formats
-    auto conv_src_md = memory::desc({conv_src_tz}, memory::data_type::f32,
-                                    memory::format::nchw);
-    auto conv_diff_src_md = memory::desc({conv_src_tz}, memory::data_type::f32,
-                                         memory::format::nchw);
-    auto conv_weights_md = memory::desc(
-        {conv_weights_tz}, memory::data_type::f32, memory::format::oihw);
-    auto conv_diff_weights_md = memory::desc(
-        {conv_weights_tz}, memory::data_type::f32, memory::format::oihw);
-    auto conv_diff_dst_md = memory::desc({conv_dst_tz}, memory::data_type::f32,
-                                         memory::format::nchw);
+    auto conv_src_md =
+        MKLDNNMemDesc(src_tz, memory::data_type::f32, memory::format::nchw);
+    auto conv_diff_src_md =
+        MKLDNNMemDesc(src_tz, memory::data_type::f32, memory::format::nchw);
+    auto conv_weights_md =
+        MKLDNNMemDesc(weights_tz, memory::data_type::f32, memory::format::oihw);
+    auto conv_diff_weights_md =
+        MKLDNNMemDesc(weights_tz, memory::data_type::f32, memory::format::oihw);
+    auto conv_diff_dst_md =
+        MKLDNNMemDesc(dst_tz, memory::data_type::f32, memory::format::nchw);
 
     // create memory
     auto conv_diff_dst_memory =
@@ -212,21 +210,18 @@ class ConvGradOpMkldnnKernel : public paddle::framework::OpKernel<T> {
 
     // create backward conv primitive for weights
     if (filter_grad) {
+      // create primitive descriptor
+      convolution_backward_weights::primitive_desc conv_bwd_weights_pd =
+          ConvBwdWeightsPrimitiveDesc(conv_src_md, conv_diff_weights_md,
+                                      conv_diff_dst_md, strides, paddings,
+                                      *p_conv_pd, mkldnn_engine);
+
       // create memory
       auto conv_diff_weights_memory = memory(
          {conv_diff_weights_md, mkldnn_engine}, (void*)filter_grad_data);
       auto conv_src_memory =
           memory({conv_src_md, mkldnn_engine}, (void*)input_data);
 
-      // create primitive descriptor
-      auto conv_bwd_weights_desc = convolution_backward_weights::desc(
-          convolution_direct, conv_src_md, conv_diff_weights_md,
-          conv_diff_dst_md, conv_strides, conv_padding, conv_padding,
-          padding_kind::zero);
-      auto conv_bwd_weights_pd = convolution_backward_weights::primitive_desc(
-          conv_bwd_weights_desc, mkldnn_engine,
-          *p_conv_pd);  // Needs the forward desc as a hint
-
       // create backward conv primitive for weights
       auto conv_bwd_weights_prim = convolution_backward_weights(
           conv_bwd_weights_pd, conv_src_memory, conv_diff_dst_memory,
@@ -238,19 +233,17 @@ class ConvGradOpMkldnnKernel : public paddle::framework::OpKernel<T> {
     }
 
     if (input_grad) {
+      // create primitive descriptor
+      convolution_backward_data::primitive_desc conv_bwd_data_pd =
+          ConvBwdDataPrimitiveDesc(conv_diff_src_md, conv_weights_md,
+                                   conv_diff_dst_md, strides, paddings,
+                                   *p_conv_pd, mkldnn_engine);
+
       // create memory
       auto conv_diff_src_memory =
          memory({conv_diff_src_md, mkldnn_engine}, (void*)input_grad_data);
       auto conv_weights_memory =
          memory({conv_weights_md, mkldnn_engine}, (void*)filter_data);
-      // create primitive descriptor
-      auto conv_bwd_data_desc = convolution_backward_data::desc(
-          convolution_direct, conv_diff_src_md, conv_weights_md,
-          conv_diff_dst_md, conv_strides, conv_padding, conv_padding,
-          padding_kind::zero);
-      auto conv_bwd_data_pd = convolution_backward_data::primitive_desc(
-          conv_bwd_data_desc, mkldnn_engine,
-          *p_conv_pd);  // Needs the forward desc as a hint
 
       // create backward conv primitive for data
       auto conv_bwd_data_prim =
@@ -264,6 +257,50 @@ class ConvGradOpMkldnnKernel : public paddle::framework::OpKernel<T> {
   }  // Compute()
 };
 
+namespace {
+std::unique_ptr<convolution_forward::primitive_desc> ConvFwdPrimitiveDesc(
+    const memory::desc& src, const memory::desc& weights,
+    const memory::desc& dst, const std::vector<int>& strides,
+    const std::vector<int>& paddings, const mkldnn::engine& engine) {
+  mkldnn::memory::dims stride_dims = {strides[0], strides[1]};
+  mkldnn::memory::dims padding_dims = {paddings[0], paddings[1]};
+
+  auto conv_desc = mkldnn::convolution_forward::desc(
+      mkldnn::prop_kind::forward, mkldnn::convolution_direct, src, weights, dst,
+      stride_dims, padding_dims, padding_dims, mkldnn::padding_kind::zero);
+
+  auto p_conv_pd = new convolution_forward::primitive_desc(conv_desc, engine);
+
+  return std::unique_ptr<convolution_forward::primitive_desc>(p_conv_pd);
+}
+
+convolution_backward_weights::primitive_desc ConvBwdWeightsPrimitiveDesc(
+    const memory::desc& src, const memory::desc& diff_weights,
+    const memory::desc& diff_dst, const std::vector<int>& strides,
+    const std::vector<int>& paddings,
+    const convolution_forward::primitive_desc& conv_pd,
+    const mkldnn::engine& engine) {
+  auto conv_bwd_weights_desc = convolution_backward_weights::desc(
+      convolution_direct, src, diff_weights, diff_dst, strides, paddings,
+      paddings, padding_kind::zero);
+  return convolution_backward_weights::primitive_desc(conv_bwd_weights_desc,
+                                                      engine, conv_pd);
+}
+
+convolution_backward_data::primitive_desc ConvBwdDataPrimitiveDesc(
+    const memory::desc& diff_src, const memory::desc& weights,
+    const memory::desc& diff_dst, const std::vector<int>& strides,
+    const std::vector<int>& paddings,
+    const convolution_forward::primitive_desc& conv_pd,
+    const mkldnn::engine& engine) {
+  auto conv_bwd_data_desc = convolution_backward_data::desc(
+      convolution_direct, diff_src, weights, diff_dst, strides, paddings,
+      paddings, padding_kind::zero);
+  return convolution_backward_data::primitive_desc(conv_bwd_data_desc, engine,
+                                                   conv_pd);
+}
+}  // anonymous namespace
+
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h
index 6d71f352c6eda..e00ee9c52c502 100644
--- a/paddle/fluid/platform/mkldnn_helper.h
+++ b/paddle/fluid/platform/mkldnn_helper.h
@@ -22,6 +22,7 @@ namespace platform {
 using MKLDNNStream = mkldnn::stream;
 using MKLDNNEngine = mkldnn::engine;
 using MKLDNNMemory = mkldnn::memory;
+using MKLDNNMemoryDescriptor = mkldnn::memory::desc;
 using MKLDNNPrimitive = mkldnn::primitive;
 using MKLDNNPrimitiveDesc = mkldnn::handle<mkldnn_primitive_desc_t>;
 
@@ -31,5 +32,12 @@ typedef std::unique_ptr<MKLDNNMemory> MKLDNNMemoryPtr;
 typedef std::unique_ptr<MKLDNNPrimitive> MKLDNNPrimitivePtr;
 typedef std::unique_ptr<MKLDNNPrimitiveDesc> MKLDNNPrimitiveDescPtr;
 
+inline mkldnn::memory::desc MKLDNNMemDesc(const std::vector<int>& dims,
+                                          mkldnn::memory::data_type data_type,
+                                          mkldnn::memory::format format) {
+  mkldnn::memory::dims tz = dims;
+  return mkldnn::memory::desc({tz}, data_type, format);
+}
+
 }  // namespace platform
 }  // namespace paddle
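The visible effect of this refactor is that every hand-packed memory::dims/memory::desc pair collapses into one MKLDNNMemDesc call. Roughly, for the source descriptor (a fragment reusing the kernel's own variables):

    // before: repack std::vector<int> into memory::dims by hand
    memory::dims conv_src_tz = {src_tz[0], src_tz[1], src_tz[2], src_tz[3]};
    auto conv_src_md = memory::desc({conv_src_tz}, memory::data_type::f32,
                                    memory::format::nchw);

    // after: one helper call on the vectorized tensor dims
    auto conv_src_md =
        MKLDNNMemDesc(src_tz, memory::data_type::f32, memory::format::nchw);
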
From dad7a0274acb1ceb0ee24452a0f581cfb98acf0d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pawe=C5=82=20=C5=BBelazko?=
Date: Tue, 6 Mar 2018 11:53:29 +0100
Subject: [PATCH 4/4] CanCUDNNBeUsed and CanMKLDNNBeUsed moved

---
 paddle/fluid/framework/operator.cc     | 17 -----------------
 paddle/fluid/framework/operator.h      |  2 --
 paddle/fluid/operators/conv_op.cc      | 26 ++++++++++++++++++++++----
 paddle/fluid/platform/cudnn_helper.h   | 14 ++++++++++++++
 paddle/fluid/platform/device_context.h |  2 +-
 paddle/fluid/platform/mkldnn_helper.h  |  7 +++++++
 6 files changed, 44 insertions(+), 24 deletions(-)

diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc
index 4ee17a6da1dd2..ac6289c5abe8f 100644
--- a/paddle/fluid/framework/operator.cc
+++ b/paddle/fluid/framework/operator.cc
@@ -600,23 +600,6 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
   return static_cast<proto::VarType::Type>(data_type);
 }
 
-bool OperatorWithKernel::CanCUDNNBeUsed(const ExecutionContext& ctx) const {
-  bool use_cudnn = ctx.Attr<bool>("use_cudnn");
-  use_cudnn &= paddle::platform::is_gpu_place(ctx.GetPlace());
-#ifdef PADDLE_WITH_CUDA
-  if (use_cudnn) {
-    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
-    use_cudnn &= dev_ctx.cudnn_handle() != nullptr;
-  }
-#endif
-  return use_cudnn;
-}
-
-bool OperatorWithKernel::CanMKLDNNBeUsed(const ExecutionContext& ctx) const {
-  bool use_mkldnn = ctx.Attr<bool>("use_mkldnn");
-  return use_mkldnn && platform::is_cpu_place(ctx.GetPlace());
-}
-
 OpKernelType OperatorWithKernel::GetExpectedKernelType(
     const ExecutionContext& ctx) const {
   return OpKernelType(IndicateDataType(ctx), ctx.GetPlace());
diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h
index 4366aa67483b6..41214b41cb68c 100644
--- a/paddle/fluid/framework/operator.h
+++ b/paddle/fluid/framework/operator.h
@@ -392,8 +392,6 @@ class OperatorWithKernel : public OperatorBase {
   virtual OpKernelType GetKernelTypeForVar(
       const std::string& var_name, const Tensor& tensor,
       const OpKernelType& expected_kernel_type) const;
-  bool CanCUDNNBeUsed(const ExecutionContext& ctx) const;
-  bool CanMKLDNNBeUsed(const ExecutionContext& ctx) const;
 
  private:
   // indicate kernel DataType by input data. By default all input data must be
diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc
index 35bd3a991ebec..4b02b80d7772f 100644
--- a/paddle/fluid/operators/conv_op.cc
+++ b/paddle/fluid/operators/conv_op.cc
@@ -13,6 +13,12 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/conv_op.h"
+#ifdef PADDLE_WITH_CUDA
+#include "paddle/fluid/platform/cudnn_helper.h"
+#endif
+#ifdef PADDLE_WITH_MKLDNN
+#include "paddle/fluid/platform/mkldnn_helper.h"
+#endif
 
 namespace paddle {
 namespace operators {
@@ -65,11 +71,17 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
 framework::OpKernelType ConvOp::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
   framework::LibraryType library_{framework::LibraryType::kPlain};
-  if (CanCUDNNBeUsed(ctx)) {
+#ifdef PADDLE_WITH_CUDA
+  if (platform::CanCUDNNBeUsed(ctx)) {
     library_ = framework::LibraryType::kCUDNN;
-  } else if (CanMKLDNNBeUsed(ctx)) {
+  }
+#endif
+#ifdef PADDLE_WITH_MKLDNN
+  if (library_ == framework::LibraryType::kPlain &&
+      platform::CanMKLDNNBeUsed(ctx)) {
     library_ = framework::LibraryType::kMKLDNN;
   }
+#endif
 
   std::string data_format = ctx.Attr<std::string>("data_format");
   // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
@@ -284,11 +296,17 @@ void ConvOpGrad::InferShape(framework::InferShapeContext* ctx) const {
 framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
   framework::LibraryType library_{framework::LibraryType::kPlain};
-  if (CanCUDNNBeUsed(ctx)) {
+#ifdef PADDLE_WITH_CUDA
+  if (platform::CanCUDNNBeUsed(ctx)) {
     library_ = framework::LibraryType::kCUDNN;
-  } else if (CanMKLDNNBeUsed(ctx)) {
+  }
+#endif
+#ifdef PADDLE_WITH_MKLDNN
+  if (library_ == framework::LibraryType::kPlain &&
+      platform::CanMKLDNNBeUsed(ctx)) {
    library_ = framework::LibraryType::kMKLDNN;
   }
+#endif
 
   std::string data_format = ctx.Attr<std::string>("data_format");
   // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
diff --git a/paddle/fluid/platform/cudnn_helper.h b/paddle/fluid/platform/cudnn_helper.h
index 48c967de1155a..1842ecd745e3f 100644
--- a/paddle/fluid/platform/cudnn_helper.h
+++ b/paddle/fluid/platform/cudnn_helper.h
@@ -15,6 +15,8 @@ limitations under the License. */
 #pragma once
 
 #include <vector>
+
+#include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/platform/dynload/cudnn.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/macros.h"
@@ -282,5 +284,17 @@ class ScopedPoolingDescriptor {
   DISABLE_COPY_AND_ASSIGN(ScopedPoolingDescriptor);
 };
 
+inline bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx) {
+  bool use_cudnn = ctx.Attr<bool>("use_cudnn");
+  use_cudnn &= paddle::platform::is_gpu_place(ctx.GetPlace());
+#ifdef PADDLE_WITH_CUDA
+  if (use_cudnn) {
+    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
+    use_cudnn &= dev_ctx.cudnn_handle() != nullptr;
+  }
+#endif
+  return use_cudnn;
+}
+
 }  // namespace platform
 }  // namespace paddle
diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h
index 2cfce226b77dd..01de8c4ab3c54 100644
--- a/paddle/fluid/platform/device_context.h
+++ b/paddle/fluid/platform/device_context.h
@@ -22,7 +22,7 @@ limitations under the License. */
 #endif
 
 #ifdef PADDLE_WITH_MKLDNN
-#include "paddle/fluid/platform/mkldnn_helper.h"
+#include <mkldnn.hpp>
 #endif
 
 #include "paddle/fluid/platform/enforce.h"
diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h
index e00ee9c52c502..90b78142b845e 100644
--- a/paddle/fluid/platform/mkldnn_helper.h
+++ b/paddle/fluid/platform/mkldnn_helper.h
@@ -16,6 +16,8 @@ limitations under the License. */
 
 #include <mkldnn.hpp>
 
+#include "paddle/fluid/framework/operator.h"
+
 namespace paddle {
 namespace platform {
@@ -39,5 +41,10 @@ inline mkldnn::memory::desc MKLDNNMemDesc(const std::vector<int>& dims,
   return mkldnn::memory::desc({tz}, data_type, format);
 }
 
+inline bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx) {
+  bool use_mkldnn = ctx.Attr<bool>("use_mkldnn");
+  return use_mkldnn && platform::is_cpu_place(ctx.GetPlace());
+}
+
 }  // namespace platform
 }  // namespace paddle
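
With the helpers living in platform, any operator that declares the use_cudnn/use_mkldnn attributes can adopt the same guarded dispatch as conv2d. A sketch for a hypothetical operator (SomeOp, its "X" input, and the surrounding class are assumptions, not part of this series):

    #ifdef PADDLE_WITH_MKLDNN
    #include "paddle/fluid/platform/mkldnn_helper.h"
    #endif

    framework::OpKernelType SomeOp::GetExpectedKernelType(
        const framework::ExecutionContext& ctx) const {
      framework::LibraryType library{framework::LibraryType::kPlain};
    #ifdef PADDLE_WITH_MKLDNN
      // CanMKLDNNBeUsed reads the op's "use_mkldnn" attribute and the place.
      if (library == framework::LibraryType::kPlain &&
          platform::CanMKLDNNBeUsed(ctx)) {
        library = framework::LibraryType::kMKLDNN;
      }
    #endif
      return framework::OpKernelType(
          framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
          ctx.GetPlace(), framework::DataLayout::kAnyLayout, library);
    }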