Skip to content

Commit

Permalink
Device blobs are created only in training. Added testing attribute
Browse files Browse the repository at this point in the history
  • Loading branch information
Tomasz Patejko committed Mar 21, 2018
1 parent 2d95527 commit 72cc64e
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 20 deletions.
71 changes: 51 additions & 20 deletions paddle/fluid/operators/lrn_mkldnn_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,22 @@ namespace operators {
using paddle::framework::Tensor;
using paddle::platform::MKLDNNDeviceContext;

namespace {
template <typename T, typename... Args>
std::shared_ptr<T> insert_to_context(const std::string& key,
const MKLDNNDeviceContext& dev_ctx,
Args&&... args) {
auto p = std::static_pointer_cast<T, void>(dev_ctx.GetBlob(key));

if (!p) {
p = std::make_shared<T>(args...);
dev_ctx.SetBlob(key, std::static_pointer_cast<void, T>(p));
}

return p;
}
} // namespace

template <typename T>
class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
public:
Expand All @@ -42,15 +58,11 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto output_data = out->mutable_data<T>(ctx.GetPlace());
mid->mutable_data<T>(ctx.GetPlace());

const std::string key = ctx.op().Output("Out");
const std::string key_src_memory = key + "@lrn_src_memory";
const std::string key_pd = key + "@lrn_pd";
const std::string key_workspace_memory = key + "@lrn_workspace_memory";

const int n = ctx.Attr<int>("n");
const float alpha = ctx.Attr<float>("alpha");
const float beta = ctx.Attr<float>("beta");
const float k = ctx.Attr<float>("k");
const bool is_test = ctx.Attr<bool>("is_test");

auto e_mid = framework::EigenTensor<T, 4>::From(*mid);
e_mid = e_mid.constant(k);
Expand All @@ -71,28 +83,47 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
beta,
k};

auto forward_pd = std::make_shared<mkldnn::lrn_forward::primitive_desc>(
forward_desc, mkldnn_engine);

dev_ctx.SetBlob(key_pd, forward_pd);

auto src_memory_pd = mkldnn::memory::primitive_desc{src_md, mkldnn_engine};
auto src_memory = std::make_shared<mkldnn::memory>(
src_memory_pd, static_cast<void*>(const_cast<float*>(input_data)));

dev_ctx.SetBlob(key_src_memory, src_memory);
auto dst_memory = mkldnn::memory{{dst_md, mkldnn_engine},
static_cast<void*>(output_data)};

auto workspace_md = forward_pd->workspace_primitive_desc();
auto workspace_memory = std::make_shared<mkldnn::memory>(workspace_md);
std::unique_ptr<mkldnn::lrn_forward> forward_op = nullptr;

if (!is_test) {
const std::string key = ctx.op().Output("Out");
const std::string key_src_memory = key + "@lrn_src_memory";
const std::string key_pd = key + "@lrn_pd";
const std::string key_workspace_memory = key + "@lrn_workspace_memory";

auto forward_pd = insert_to_context<mkldnn::lrn_forward::primitive_desc>(
key_pd, dev_ctx, forward_desc, mkldnn_engine);

auto src_memory = insert_to_context<mkldnn::memory>(
key_src_memory, dev_ctx, src_memory_pd);

src_memory->set_data_handle(
static_cast<void*>(const_cast<T*>(input_data)));

auto workspace_memory = insert_to_context<mkldnn::memory>(
key_workspace_memory, dev_ctx,
forward_pd->workspace_primitive_desc());

forward_op.reset(new mkldnn::lrn_forward{*forward_pd, *src_memory,
*workspace_memory, dst_memory});

dev_ctx.SetBlob(key_workspace_memory, workspace_memory);
} else {
auto forward_pd =
mkldnn::lrn_forward::primitive_desc{forward_desc, mkldnn_engine};
auto src_memory = mkldnn::memory{
src_memory_pd, static_cast<void*>(const_cast<T*>(input_data))};
auto workspace_memory =
mkldnn::memory{forward_pd.workspace_primitive_desc()};

auto forward_op = mkldnn::lrn_forward{*forward_pd, *src_memory,
*workspace_memory, dst_memory};
forward_op.reset(new mkldnn::lrn_forward{forward_pd, src_memory,
workspace_memory, dst_memory});
}

std::vector<mkldnn::primitive> pipeline = {forward_op};
std::vector<mkldnn::primitive> pipeline = {*forward_op};
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
}
};
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/operators/lrn_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ class LRNOpMaker : public framework::OpProtoAndCheckerMaker {
"Defaults to \"NHWC\". Specify the data format of the output data, "
"the input will be transformed automatically. ")
.SetDefault("AnyLayout");
AddAttr<bool>("is_test", "").SetDefault(false);

AddComment(R"DOC(
Local Response Normalization Operator.
Expand Down

0 comments on commit 72cc64e

Please sign in to comment.