Skip to content

Commit

Permalink
Removing WITHIN_CHANNEL algorithm for lrn. CPU lrn operator works only with ACROSS_CHANNELS
Browse files Browse the repository at this point in the history
  • Loading branch information
Tomasz Patejko committed Mar 19, 2018
1 parent c51c446 commit 2d95527
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 38 deletions.
27 changes: 8 additions & 19 deletions paddle/fluid/operators/lrn_mkldnn_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,6 @@ namespace operators {
using paddle::framework::Tensor;
using paddle::platform::MKLDNNDeviceContext;

namespace {
// Translates the op's "algorithm" string attribute into the matching MKLDNN
// LRN algorithm enum. Any value other than "WITHIN_CHANNEL" (including the
// default "ACROSS_CHANNELS") maps to across-channel normalization.
mkldnn::algorithm LRNAlgorithm(const paddle::framework::ExecutionContext& ctx) {
  const std::string attr_value = ctx.Attr<std::string>("algorithm");
  return attr_value == "WITHIN_CHANNEL" ? mkldnn::lrn_within_channel
                                        : mkldnn::lrn_across_channels;
}
}  // namespace

template <typename T>
class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
public:
Expand Down Expand Up @@ -64,8 +52,6 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const float beta = ctx.Attr<float>("beta");
const float k = ctx.Attr<float>("k");

auto algorithm = LRNAlgorithm(ctx);

auto e_mid = framework::EigenTensor<T, 4>::From(*mid);
e_mid = e_mid.constant(k);

Expand All @@ -77,8 +63,13 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto dst_md = paddle::platform::MKLDNNMemDesc(
dims, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);

auto forward_desc = mkldnn::lrn_forward::desc{
mkldnn::prop_kind::forward, algorithm, src_md, n, alpha, beta, k};
auto forward_desc = mkldnn::lrn_forward::desc{mkldnn::prop_kind::forward,
mkldnn::lrn_across_channels,
src_md,
n,
alpha,
beta,
k};

auto forward_pd = std::make_shared<mkldnn::lrn_forward::primitive_desc>(
forward_desc, mkldnn_engine);
Expand Down Expand Up @@ -154,10 +145,8 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto diff_src_memory = mkldnn::memory{{diff_src_md, mkldnn_engine},
static_cast<void*>(x_grad_data)};

auto algorithm = LRNAlgorithm(ctx);

auto backward_desc = mkldnn::lrn_backward::desc{
algorithm, src_md, diff_src_md, n, alpha, beta, k};
mkldnn::lrn_across_channels, src_md, diff_src_md, n, alpha, beta, k};

auto forward_pd = dev_ctx.GetBlob(key_pd);

Expand Down
33 changes: 14 additions & 19 deletions paddle/fluid/operators/lrn_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -120,24 +120,24 @@ template struct LRNGradFunctor<platform::CPUDeviceContext, float>;
template struct LRNGradFunctor<platform::CPUDeviceContext, double>;

namespace {
// Selects the kernel type for the LRN operator: prefer the MKLDNN kernel
// when Paddle was built with MKLDNN support and it can be used for this
// execution context; otherwise fall back to the plain CPU implementation.
// The data layout is derived from the "data_format" attribute (MKLDNN
// layout is not enabled yet — see TODO below).
framework::OpKernelType GetExpectedLRNKernel(
    const framework::ExecutionContext& ctx) {
  framework::LibraryType library_{framework::LibraryType::kPlain};
#ifdef PADDLE_WITH_MKLDNN
  if (library_ == framework::LibraryType::kPlain &&
      platform::CanMKLDNNBeUsed(ctx)) {
    library_ = framework::LibraryType::kMKLDNN;
  }
#endif

  std::string data_format = ctx.Attr<std::string>("data_format");
  // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
  framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
  return framework::OpKernelType(
      framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
      layout_, library_);
}
}  // namespace

class LRNOp : public framework::OperatorWithKernel {
public:
Expand Down Expand Up @@ -214,11 +214,6 @@ class LRNOpMaker : public framework::OpProtoAndCheckerMaker {
"Defaults to \"NHWC\". Specify the data format of the output data, "
"the input will be transformed automatically. ")
.SetDefault("AnyLayout");
AddAttr<std::string>("algorithm",
"(string default ACROSS_CHANNELS"
"An optional string: \"ACROSS_CHANNELS\", "
"\"WITHIN_CHANNEL\". Used by MKLDNN library")
.SetDefault("ACROSS_CHANNELS");

AddComment(R"DOC(
Local Response Normalization Operator.
Expand Down

0 comments on commit 2d95527

Please sign in to comment.