Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MKLDNN: SoftmaxGrad Op #11519

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmake/external/mkldnn.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ ExternalProject_Add(
${EXTERNAL_PROJECT_LOG_ARGS}
DEPENDS ${MKLDNN_DEPENDS}
GIT_REPOSITORY "https://github.com/01org/mkl-dnn.git"
GIT_TAG "db3424ad44901513c03a1ea31ccaacdf633fbe9f"
GIT_TAG "a29d8487a63afca3d5b8c5bbdbb473cf8ccc6e51"
PREFIX ${MKLDNN_SOURCES_DIR}
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${MKLDNN_INSTALL_DIR}
Expand Down
217 changes: 167 additions & 50 deletions paddle/fluid/operators/softmax_mkldnn_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,81 @@ using paddle::platform::MKLDNNMemDesc;
using mkldnn::memory; // Note: paddle has also "memory" namespace
using mkldnn::primitive;
using mkldnn::softmax_forward;
using mkldnn::softmax_backward;
using mkldnn::prop_kind;
using mkldnn::stream;
using platform::to_void_cast;

class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
public:
SoftmaxMKLDNNHandler(
std::shared_ptr<mkldnn::softmax_forward::primitive_desc> softmax_pd,
const platform::MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine,
const std::string& base_key)
: platform::MKLDNNHandler(dev_ctx, engine, base_key),
softmax_pd_(softmax_pd) {}

SoftmaxMKLDNNHandler(
std::shared_ptr<mkldnn::softmax_forward::primitive_desc> softmax_pd,
std::shared_ptr<mkldnn::softmax_backward::primitive_desc> softmax_bwd_pd,
const platform::MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine,
const std::string& base_key)
: platform::MKLDNNHandler(dev_ctx, engine, base_key),
softmax_pd_(softmax_pd),
softmax_bwd_pd_(softmax_bwd_pd) {
// If we are in Grad operatgor then update a key with BWD suffix to
// distinguish from FWD memory primitives
key_ += "-BWD";
}

std::shared_ptr<mkldnn::softmax_forward> AcquireSoftmax(
std::shared_ptr<mkldnn::memory> dst_memory_p,
std::shared_ptr<mkldnn::memory> src_memory_p) {
/*Generate key*/
auto prim_key = key_ + "@softmax_p";

auto softmax_p = std::static_pointer_cast<mkldnn::softmax_forward>(
dev_ctx_.GetBlob(prim_key));
PADDLE_ENFORCE((softmax_p != nullptr) || (is_reusing_ == false),
"Fail to find softmax primitive in device context");
if (softmax_p == nullptr) {
softmax_p = std::make_shared<mkldnn::softmax_forward>(
*(softmax_pd_.get()),
*(static_cast<mkldnn::memory*>(src_memory_p.get())),
*(static_cast<mkldnn::memory*>(dst_memory_p.get())));
dev_ctx_.SetBlob(prim_key, softmax_p);
} else {
is_reusing_ = true;
}

return softmax_p;
}

std::shared_ptr<mkldnn::softmax_backward> AcquireSoftmaxBackward(
std::shared_ptr<mkldnn::memory> dst_memory_p,
std::shared_ptr<mkldnn::memory> diff_dst_memory_p,
std::shared_ptr<mkldnn::memory> diff_src_memory_p) {
auto prim_key = key_ + "@softmax_bwd_p";
auto softmax_bwd_p = std::static_pointer_cast<mkldnn::softmax_backward>(
dev_ctx_.GetBlob(prim_key));
PADDLE_ENFORCE((softmax_bwd_p != nullptr) || (is_reusing_ == false),
"Fail to find softmax backward primitive in device context");
if (softmax_bwd_p == nullptr) {
softmax_bwd_p = std::make_shared<mkldnn::softmax_backward>(
*softmax_bwd_pd_, *(dst_memory_p.get()), *(diff_dst_memory_p.get()),
*(diff_src_memory_p.get()));
dev_ctx_.SetBlob(prim_key, softmax_bwd_p);
} else {
is_reusing_ = true;
}

return softmax_bwd_p;
}

private:
std::shared_ptr<mkldnn::softmax_forward::primitive_desc> softmax_pd_;
std::shared_ptr<mkldnn::softmax_backward::primitive_desc> softmax_bwd_pd_;
};

template <typename T>
class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
Expand All @@ -54,56 +127,27 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
// Same memory descriptor to be used for input and output
memory::dims softmax_tz = {src_tz[0], src_tz[1]};
// Generate keys for storing/retriving primitives for this operator
// TODO(jczaja): Each MKLDNN operator may have diffrent hashing function
auto gethash = [](memory::dims& operand_dims) {
return std::string(std::to_string(operand_dims[0]) + "-" +
std::to_string(operand_dims[1]));
};
const std::string key = gethash(softmax_tz);
const std::string key_softmax_p = key + "@softmax_p";
const std::string key_softmax_src_mem_p = key + "@softmax_src_mem_p";
const std::string key_softmax_dst_mem_p = key + "@softmax_dst_mem_p";

std::shared_ptr<void> softmax_p = dev_ctx.GetBlob(key_softmax_p);
if (softmax_p == nullptr) {
// Currently only NC data format is supported
auto softmax_md =
MKLDNNMemDesc({softmax_tz}, memory::f32, memory::format::nc);
// Normalization is made after innermost dimension eg. C out of NC
auto softmax_desc = softmax_forward::desc(prop_kind::forward_scoring,
softmax_md, 1 /*dim: C*/);
// create memory primitives
auto softmax_src_memory_p = std::make_shared<memory>(
memory::primitive_desc{softmax_md, mkldnn_engine},
static_cast<void*>(const_cast<T*>(input_data)));
dev_ctx.SetBlob(key_softmax_src_mem_p, softmax_src_memory_p);
auto softmax_dst_memory_p = std::make_shared<memory>(
memory::primitive_desc{softmax_md, mkldnn_engine},
static_cast<void*>(output_data));
dev_ctx.SetBlob(key_softmax_dst_mem_p, softmax_dst_memory_p);

auto softmax_forward_pd =
std::make_shared<softmax_forward::primitive_desc>(softmax_desc,
mkldnn_engine);
softmax_p = std::make_shared<softmax_forward>(
*(softmax_forward_pd.get()),
*(static_cast<memory*>(softmax_src_memory_p.get())),
*(static_cast<memory*>(softmax_dst_memory_p.get())));
dev_ctx.SetBlob(key_softmax_p, softmax_p);
} else {
// Primitives already exist
auto src_memory_p = std::static_pointer_cast<memory>(
dev_ctx.GetBlob(key_softmax_src_mem_p));
PADDLE_ENFORCE(src_memory_p != nullptr,
"Fail to find softmax src mem_p in device context");
auto dst_memory_p = std::static_pointer_cast<memory>(
dev_ctx.GetBlob(key_softmax_dst_mem_p));
PADDLE_ENFORCE(dst_memory_p != nullptr,
"Fail to find softmax dst mem_p in device context");
src_memory_p->set_data_handle(
reinterpret_cast<void*>(const_cast<T*>(input_data)));
dst_memory_p->set_data_handle(output_data);
}
const std::string key =
platform::MKLDNNHandler::GetHash(softmax_tz, ctx.op().Output("Out"));
const std::string key_softmax_pd = key + "@softmax_pd";

// Currently only NC data format is supported
auto softmax_md = MKLDNNMemDesc(
{softmax_tz}, platform::MKLDNNGetDataType<T>(), memory::format::nc);
// Normalization is made after innermost dimension eg. C out of NC
auto softmax_desc = softmax_forward::desc(prop_kind::forward_scoring,
softmax_md, 1 /*dim: C*/);
auto softmax_pd = std::make_shared<mkldnn::softmax_forward::primitive_desc>(
softmax_desc, mkldnn_engine);
dev_ctx.SetBlob(key_softmax_pd, softmax_pd);

SoftmaxMKLDNNHandler handler(softmax_pd, dev_ctx, mkldnn_engine, key);
auto softmax_src_memory_p =
handler.AcquireSrcMemory(softmax_md, to_void_cast<T>(input_data));
auto softmax_dst_memory_p =
handler.AcquireDstMemory(softmax_md, to_void_cast<T>(output_data));
auto softmax_p =
handler.AcquireSoftmax(softmax_dst_memory_p, softmax_src_memory_p);

std::vector<primitive> pipeline{
*(static_cast<softmax_forward::primitive*>(softmax_p.get()))};
Expand All @@ -120,10 +164,83 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
}
};

template <typename T>
class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");

auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
auto mkldnn_engine = dev_ctx.GetEngine();
const Tensor* output = ctx.Input<Tensor>("Out");
const T* dst_data = output->data<T>();

auto* dout = ctx.template Input<Tensor>(framework::GradVarName("Out"));
const auto* diff_dst_ptr = dout->template data<T>();

auto* dx =
ctx.template Output<framework::Tensor>(framework::GradVarName("X"));
T* diff_src_ptr = dx->template mutable_data<T>(ctx.GetPlace());

std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
std::vector<int> src_tz(dst_tz);
PADDLE_ENFORCE(output->dims().size() == 2UL,
"The input of softmax op must be a 2D matrix.");
// MKL-DNN does support softmax over selected axis. Having 2D Tensor,
// we will make normalization after final eg. axis: 1
PADDLE_ENFORCE(((src_tz[0] == dst_tz[0]) && (src_tz[1] == dst_tz[1])),
"Softmax input and output dimensions should match");
// Same memory descriptor to be used for input and output
memory::dims softmax_tz = {src_tz[0], src_tz[1]};
// Currently only supports NC data format
// retrieve eltwise primitive desc from device context
const std::string key =
platform::MKLDNNHandler::GetHash(softmax_tz, ctx.op().Input("Out"));
const std::string key_softmax_pd = key + "@softmax_pd";

auto softmax_pd =
std::static_pointer_cast<mkldnn::softmax_forward::primitive_desc>(
dev_ctx.GetBlob(key_softmax_pd));
PADDLE_ENFORCE(softmax_pd != nullptr,
"Fail to find softmax_pd in device context");

// TODO(jczaja): Add layouts support when there is a need to do so
// Two dimensional softmax does support NC format
auto data_softmax_md = MKLDNNMemDesc(
{softmax_tz}, platform::MKLDNNGetDataType<T>(), memory::format::nc);
auto diff_softmax_md = MKLDNNMemDesc(
{softmax_tz}, platform::MKLDNNGetDataType<T>(), memory::format::nc);
// Normalization is made after innermost dimension eg. C out of NC
auto softmax_bwd_desc =
softmax_backward::desc(diff_softmax_md, data_softmax_md, 1 /* dim: C*/);
auto softmax_bwd_pd =
std::make_shared<mkldnn::softmax_backward::primitive_desc>(
softmax_bwd_desc, mkldnn_engine, *softmax_pd);

SoftmaxMKLDNNHandler handler(softmax_pd, softmax_bwd_pd, dev_ctx,
mkldnn_engine, key);
auto dst_memory_p =
handler.AcquireDstMemory(data_softmax_md, to_void_cast<T>(dst_data));
auto diff_dst_memory_p = handler.AcquireDiffDstMemory(
diff_softmax_md, to_void_cast<T>(diff_dst_ptr));
auto diff_src_memory_p = handler.AcquireDiffSrcMemory(
diff_softmax_md, to_void_cast<T>(diff_src_ptr));

// Get primitve from device context
auto softmax_bwd_p = handler.AcquireSoftmaxBackward(
dst_memory_p, diff_dst_memory_p, diff_src_memory_p);

std::vector<primitive> pipeline{*softmax_bwd_p};
stream(stream::kind::eager).submit(pipeline).wait();
}
};
} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP_KERNEL(softmax, MKLDNN, ::paddle::platform::CPUPlace,
ops::SoftmaxMKLDNNKernel<float>);
REGISTER_OP_KERNEL(softmax_grad, MKLDNN, ::paddle::platform::CPUPlace,
ops::SoftmaxMKLDNNGradKernel<float>);
22 changes: 18 additions & 4 deletions paddle/fluid/operators/softmax_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -145,16 +145,30 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override {
// choose cudnn kernel if the runtime supported.
framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);

#ifdef PADDLE_WITH_CUDA
if (platform::CanCUDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kCUDNN;
}
#endif
std::string data_format = ctx.Attr<std::string>("data_format");
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
framework::StringToDataLayout(data_format), library_);
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
}
#endif
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("X")->type());
if (input_data_type == framework::proto::VarType::FP16) {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"float16 can only be used on GPU place");
}

return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_,
library_);
}
};

Expand Down
Loading