CanCUDNNBeUsed and CanMKLDNNBeUsed moved
pzelazko-intel committed Mar 6, 2018
1 parent 0e3f110 commit dad7a02
Showing 6 changed files with 44 additions and 24 deletions.
17 changes: 0 additions & 17 deletions paddle/fluid/framework/operator.cc
@@ -600,23 +600,6 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
   return static_cast<proto::VarType::Type>(data_type);
 }
 
-bool OperatorWithKernel::CanCUDNNBeUsed(const ExecutionContext& ctx) const {
-  bool use_cudnn = ctx.Attr<bool>("use_cudnn");
-  use_cudnn &= paddle::platform::is_gpu_place(ctx.GetPlace());
-#ifdef PADDLE_WITH_CUDA
-  if (use_cudnn) {
-    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
-    use_cudnn &= dev_ctx.cudnn_handle() != nullptr;
-  }
-#endif
-  return use_cudnn;
-}
-
-bool OperatorWithKernel::CanMKLDNNBeUsed(const ExecutionContext& ctx) const {
-  bool use_mkldnn = ctx.Attr<bool>("use_mkldnn");
-  return use_mkldnn && platform::is_cpu_place(ctx.GetPlace());
-}
-
 OpKernelType OperatorWithKernel::GetExpectedKernelType(
     const ExecutionContext& ctx) const {
   return OpKernelType(IndicateDataType(ctx), ctx.GetPlace());
2 changes: 0 additions & 2 deletions paddle/fluid/framework/operator.h
@@ -392,8 +392,6 @@ class OperatorWithKernel : public OperatorBase {
   virtual OpKernelType GetKernelTypeForVar(
       const std::string& var_name, const Tensor& tensor,
       const OpKernelType& expected_kernel_type) const;
-  bool CanCUDNNBeUsed(const ExecutionContext& ctx) const;
-  bool CanMKLDNNBeUsed(const ExecutionContext& ctx) const;
 
  private:
   // indicate kernel DataType by input data. By default all input data must be
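For orientation, the two predicates removed from OperatorWithKernel above reappear later in this commit as inline free functions in the platform namespace. A signature summary only; the full inline definitions are in the cudnn_helper.h and mkldnn_helper.h hunks further down:

    // Signature summary of the relocated helpers (not new API, just a recap).
    namespace paddle {
    namespace platform {
    bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx);   // cudnn_helper.h
    bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx);  // mkldnn_helper.h
    }  // namespace platform
    }  // namespace paddle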
26 changes: 22 additions & 4 deletions paddle/fluid/operators/conv_op.cc
@@ -13,6 +13,12 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/conv_op.h"
+#ifdef PADDLE_WITH_CUDA
+#include "paddle/fluid/platform/cudnn_helper.h"
+#endif
+#ifdef PADDLE_WITH_MKLDNN
+#include "paddle/fluid/platform/mkldnn_helper.h"
+#endif
 
 namespace paddle {
 namespace operators {
@@ -65,11 +71,17 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
 framework::OpKernelType ConvOp::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
   framework::LibraryType library_{framework::LibraryType::kPlain};
-  if (CanCUDNNBeUsed(ctx)) {
+#ifdef PADDLE_WITH_CUDA
+  if (platform::CanCUDNNBeUsed(ctx)) {
     library_ = framework::LibraryType::kCUDNN;
-  } else if (CanMKLDNNBeUsed(ctx)) {
+  }
+#endif
+#ifdef PADDLE_WITH_MKLDNN
+  if (library_ == framework::LibraryType::kPlain &&
+      platform::CanMKLDNNBeUsed(ctx)) {
     library_ = framework::LibraryType::kMKLDNN;
   }
+#endif
 
   std::string data_format = ctx.Attr<std::string>("data_format");
   // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
@@ -284,11 +296,17 @@ void ConvOpGrad::InferShape(framework::InferShapeContext* ctx) const {
 framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
   framework::LibraryType library_{framework::LibraryType::kPlain};
-  if (CanCUDNNBeUsed(ctx)) {
+#ifdef PADDLE_WITH_CUDA
+  if (platform::CanCUDNNBeUsed(ctx)) {
     library_ = framework::LibraryType::kCUDNN;
-  } else if (CanMKLDNNBeUsed(ctx)) {
+  }
+#endif
+#ifdef PADDLE_WITH_MKLDNN
+  if (library_ == framework::LibraryType::kPlain &&
+      platform::CanMKLDNNBeUsed(ctx)) {
     library_ = framework::LibraryType::kMKLDNN;
   }
+#endif
 
   std::string data_format = ctx.Attr<std::string>("data_format");
   // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
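ConvOp::GetExpectedKernelType and ConvOpGrad::GetExpectedKernelType now repeat the same library-selection block: prefer cuDNN when available, otherwise fall back to MKLDNN, otherwise stay with the plain kernel. As an illustration only, the pattern could be factored into a small helper like the hypothetical sketch below (ChooseLibrary is not part of this patch; it assumes it sits in conv_op.cc's paddle::operators namespace with the includes added above):

    // Hypothetical helper capturing the selection logic used in both hunks.
    static framework::LibraryType ChooseLibrary(
        const framework::ExecutionContext& ctx) {
      // Start from a plain kernel; the checks below may upgrade the choice.
      framework::LibraryType library{framework::LibraryType::kPlain};
    #ifdef PADDLE_WITH_CUDA
      // Prefer cuDNN when the op requests it and a CUDA place/handle is usable.
      if (platform::CanCUDNNBeUsed(ctx)) {
        library = framework::LibraryType::kCUDNN;
      }
    #endif
    #ifdef PADDLE_WITH_MKLDNN
      // Fall back to MKLDNN only if no cuDNN kernel was selected.
      if (library == framework::LibraryType::kPlain &&
          platform::CanMKLDNNBeUsed(ctx)) {
        library = framework::LibraryType::kMKLDNN;
      }
    #endif
      return library;
    }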
14 changes: 14 additions & 0 deletions paddle/fluid/platform/cudnn_helper.h
@@ -15,6 +15,8 @@ limitations under the License. */
 #pragma once
 
 #include <vector>
+
+#include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/platform/dynload/cudnn.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/macros.h"
@@ -282,5 +284,17 @@ class ScopedPoolingDescriptor {
   DISABLE_COPY_AND_ASSIGN(ScopedPoolingDescriptor);
 };
 
+inline bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx) {
+  bool use_cudnn = ctx.Attr<bool>("use_cudnn");
+  use_cudnn &= paddle::platform::is_gpu_place(ctx.GetPlace());
+#ifdef PADDLE_WITH_CUDA
+  if (use_cudnn) {
+    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
+    use_cudnn &= dev_ctx.cudnn_handle() != nullptr;
+  }
+#endif
+  return use_cudnn;
+}
+
 }  // namespace platform
 }  // namespace paddle
2 changes: 1 addition & 1 deletion paddle/fluid/platform/device_context.h
@@ -22,7 +22,7 @@ limitations under the License. */
 #endif
 
 #ifdef PADDLE_WITH_MKLDNN
-#include "paddle/fluid/platform/mkldnn_helper.h"
+#include <mkldnn.hpp>
 #endif
 
 #include "paddle/fluid/platform/enforce.h"
7 changes: 7 additions & 0 deletions paddle/fluid/platform/mkldnn_helper.h
@@ -16,6 +16,8 @@ limitations under the License. */
 
 #include <mkldnn.hpp>
 
+#include "paddle/fluid/framework/operator.h"
+
 namespace paddle {
 namespace platform {
 
@@ -39,5 +41,10 @@ inline mkldnn::memory::desc MKLDNNMemDesc(const std::vector<int>& dims,
   return mkldnn::memory::desc({tz}, data_type, format);
 }
 
+inline bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx) {
+  bool use_mkldnn = ctx.Attr<bool>("use_mkldnn");
+  return use_mkldnn && platform::is_cpu_place(ctx.GetPlace());
+}
+
 }  // namespace platform
 }  // namespace paddle
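As a usage illustration (not part of this commit), any CPU-side operator can now consult the relocated predicate from its GetExpectedKernelType the same way conv_op.cc does above. A minimal sketch, assuming a hypothetical operator with a Tensor input named "X", a bool "use_mkldnn" attribute, and the framework::ToDataType helper of this code base:

    // Hypothetical operator using the relocated predicate (sketch only).
    framework::OpKernelType SomeCpuOp::GetExpectedKernelType(
        const framework::ExecutionContext& ctx) const {
      framework::LibraryType library{framework::LibraryType::kPlain};
    #ifdef PADDLE_WITH_MKLDNN
      // platform::CanMKLDNNBeUsed() checks the "use_mkldnn" attribute and that
      // the op runs on a CPU place, exactly as defined in the hunk above.
      if (platform::CanMKLDNNBeUsed(ctx)) {
        library = framework::LibraryType::kMKLDNN;
      }
    #endif
      return framework::OpKernelType(
          framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
          ctx.GetPlace(), framework::DataLayout::kAnyLayout, library);
    }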
