This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[v1.x] Backport of "improve log_softmax op performance by using DNNL" #18469

Merged 1 commit on Jun 3, 2020
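For reference, `SupportMKLDNNLogSoftmax` (added below) only takes the MKLDNN path for float32 inputs and outputs, no `temperature`, softmax over the last axis, and 1–4 dimensions; anything else falls back to the existing CPU kernel. A minimal sketch of a call that should hit the accelerated path, assuming a CPU build with `MXNET_USE_MKLDNN=1`:

```python
import mxnet as mx

# float32 data, log_softmax over the last axis: matches the conditions checked
# in SupportMKLDNNLogSoftmax, so an MKLDNN-enabled CPU build should dispatch to
# MKLDNNLogSoftmaxForward.
x = mx.nd.random.uniform(-2, 2, shape=(128, 1000), dtype='float32')
y = mx.nd.log_softmax(x, axis=-1)

# A non-last axis (or a temperature value) is expected to fall back to the
# original SoftmaxCompute CPU kernel.
y_fallback = mx.nd.log_softmax(x, axis=0)

print(y[0, :5].asnumpy(), y_fallback[0, :5].asnumpy())
```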
83 changes: 83 additions & 0 deletions src/operator/nn/log_softmax.cc
@@ -26,10 +26,79 @@
#include "../tensor/elemwise_unary_op.h"
#include "../tensor/elemwise_binary_op.h"
#include "../operator_common.h"
#if MXNET_USE_MKLDNN == 1
#include "mkldnn/mkldnn_base-inl.h"
#include "mkldnn/mkldnn_ops-inl.h"
#endif

namespace mxnet {
namespace op {

#if MXNET_USE_MKLDNN == 1
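// Forward dispatch: run the MKLDNN log_softmax primitive when the input is
// supported, otherwise fall back to the default CPU FCompute kernel.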
static void LogSoftmaxComputeExCPU(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
if (SupportMKLDNNLogSoftmax(param, inputs[0], outputs[0])) {
MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
MKLDNNRun(MKLDNNLogSoftmaxForward, attrs, ctx, inputs[0], req[0], outputs[0]);
auto fn = SoftmaxCompute<cpu, mxnet_op::log_softmax_fwd>;
MKLDNN_OPCHECK_RUN(fn, attrs, ctx, inputs, req, outputs);
return;
}
FallBackCompute(SoftmaxCompute<cpu, mxnet_op::log_softmax_fwd>, attrs, ctx,
inputs, req, outputs);
}

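// Backward dispatch: mirrors the forward path for the gradient computation.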
static void LogSoftmaxGradComputeExCPU(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
if (SupportMKLDNNLogSoftmax(param, inputs[1], outputs[0])) {
MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
MKLDNNRun(MKLDNNLogSoftmaxBackward, attrs, ctx, inputs, req, outputs);
auto fn = SoftmaxGradCompute<cpu, op::mshadow_op::left, mxnet_op::log_softmax_bwd>;
MKLDNN_OPCHECK_RUN(fn, attrs, ctx, inputs, req, outputs);
return;
}
FallBackCompute(SoftmaxGradCompute<cpu, op::mshadow_op::left, mxnet_op::log_softmax_bwd>,
attrs, ctx, inputs, req, outputs);
}

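// Storage-type inference that allows MKLDNN-formatted inputs for the forward op.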
inline static bool LogSoftmaxStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);

return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs,
out_attrs);
}

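// Backward storage-type inference; the dtype-override case (three inputs)
// is not handled by MKLDNN.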
inline static bool LogSoftmaxGradStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
bool support = true;
int num_inputs = 2U;
if (softmax_has_dtype_override(attrs)) {
support = false;
num_inputs = 3U;
}

CHECK_EQ(in_attrs->size(), num_inputs);
CHECK_EQ(out_attrs->size(), 1U);
return MKLDNNStorageType(attrs, dev_mask, support, dispatch_mode, in_attrs, out_attrs);
}
#endif

NNVM_REGISTER_OP(log_softmax)
.add_alias("_npx_log_softmax")
.describe(R"code(Computes the log softmax of the input.
@@ -49,7 +118,16 @@ Examples::

)code")
.set_attr_parser(ParamParser<SoftmaxParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs){
return std::vector<std::string>{"data"};
})
.set_attr<FCompute>("FCompute<cpu>", SoftmaxCompute<cpu, mxnet_op::log_softmax_fwd>)
#if MXNET_USE_MKLDNN == 1
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", LogSoftmaxComputeExCPU)
.set_attr<FInferStorageType>("FInferStorageType", LogSoftmaxStorageType)
#endif
.set_attr<nnvm::FGradient>("FGradient", SoftmaxFGradient{"_backward_log_softmax"})
.set_attr<nnvm::FInferType>("FInferType", SoftmaxOpType)
.set_num_inputs(1)
@@ -71,6 +149,11 @@ NNVM_REGISTER_OP(_backward_log_softmax)
.set_attr<nnvm::FInplaceOption>("FInplaceOption", SoftmaxGradOpInplaceOption)
.add_argument("args", "NDArray-or-Symbol[]", "Positional input arguments")
.set_attr_parser(ParamParser<SoftmaxParam>)
#if MXNET_USE_MKLDNN == 1
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", LogSoftmaxGradComputeExCPU)
.set_attr<FInferStorageType>("FInferStorageType", LogSoftmaxGradStorageType)
#endif
.set_attr<FCompute>("FCompute<cpu>", SoftmaxGradCompute<cpu, mshadow_op::left,
mxnet_op::log_softmax_bwd>);

2 changes: 2 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -211,6 +211,8 @@ bool SupportQuantizedMKLDNNAct(const ActivationParam &param);
bool SupportMKLDNNConv(const ConvolutionParam &params, const NDArray &input);
bool SupportMKLDNNDeconv(const DeconvolutionParam& params, const NDArray &input);
bool SupportMKLDNNSoftmax(const SoftmaxParam& param, const NDArray &input, const NDArray &output);
bool SupportMKLDNNLogSoftmax(const SoftmaxParam& param, const NDArray &input,
const NDArray &output);
bool SupportMKLDNNSoftmaxOutput(const SoftmaxOutputParam &param);
bool SupportMKLDNNTranspose(const TransposeParam& param, const NDArray &data);
} // namespace op
226 changes: 226 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_log_softmax.cc
@@ -0,0 +1,226 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file mkldnn_log_softmax.cc
* \brief Implementation of log_softmax function with MKLDNN support
*/

#include "../softmax-inl.h"
#include "./mkldnn_ops-inl.h"
#include "./mkldnn_base-inl.h"

#if MXNET_USE_MKLDNN == 1
namespace mxnet {
namespace op {

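// Builds the oneDNN log_softmax forward primitive descriptor on the CPU engine.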
static mkldnn::logsoftmax_forward::primitive_desc GetLogSoftmaxFwdPd(
bool is_train,
const int axis,
const mkldnn::memory &input_mem) {
mkldnn::memory::desc data_md = input_mem.get_desc();
auto cpu_engine = CpuEngine::Get()->get_engine();
auto prop = is_train ? mkldnn::prop_kind::forward_training
: mkldnn::prop_kind::forward_scoring;
auto desc = mkldnn::logsoftmax_forward::desc(prop, data_md, axis);
return mkldnn::logsoftmax_forward::primitive_desc(desc, cpu_engine);
}

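// Builds the backward primitive descriptor, using the forward descriptor as a hint.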
static mkldnn::logsoftmax_backward::primitive_desc GetLogSoftmaxBwdPd(
const mkldnn::memory &diff_mem,
const mkldnn::memory &data_mem,
const int axis,
const mkldnn::logsoftmax_forward::primitive_desc &hint_fwd_pd) {
mkldnn::memory::desc diff_md = diff_mem.get_desc();
mkldnn::memory::desc data_md = data_mem.get_desc();
auto cpu_engine = CpuEngine::Get()->get_engine();
auto desc = mkldnn::logsoftmax_backward::desc(diff_md, data_md, axis);
return mkldnn::logsoftmax_backward::primitive_desc(desc, cpu_engine, hint_fwd_pd);
}


bool SupportMKLDNNLogSoftmax(const SoftmaxParam &param,
const NDArray &data,
const NDArray &output) {
const int ndim = data.shape().ndim();
const int in_dtype = data.dtype();
const int out_dtype = output.dtype();
const int axis = CheckAxis(param.axis, ndim);
// MKLDNN does not support the temperature argument in its log_softmax
// primitive yet; update this once it does.
// Currently, MKLDNN also shows poor performance when log_softmax is not performed on the last dimension.
if (param.temperature.has_value() ||
in_dtype != mshadow::kFloat32 ||
in_dtype != out_dtype ||
axis != (ndim - 1)) {
return false;
}

// only supports ndim = 1, 2, 3, 4 for now
return (ndim >= 1 && ndim <= 4);
}

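// Thin wrapper owning the forward primitive together with its descriptor.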
class MKLDNNLogSoftmaxFwd {
public:
mkldnn::logsoftmax_forward::primitive_desc pd;

MKLDNNLogSoftmaxFwd(const bool is_train,
const int axis,
const mkldnn::memory &input) : pd(GetLogSoftmaxFwdPd(is_train, axis, input)) {
fwd_ = std::make_shared<mkldnn::logsoftmax_forward>(pd);
}

const mkldnn::logsoftmax_forward &GetFwd() const {
return *fwd_;
}

private:
std::shared_ptr<mkldnn::logsoftmax_forward> fwd_;
};

typedef ParamOpSign<SoftmaxParam> MKLDNNSoftmaxSignature;

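// Caches forward primitives in a thread-local map keyed by the op signature
// (param, axis, is_train, input/output arrays).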
static MKLDNNLogSoftmaxFwd &GetLogSoftmaxFwd(const SoftmaxParam &param,
const int real_axis,
const bool is_train,
const NDArray &data,
const NDArray &output) {
#if DMLC_CXX11_THREAD_LOCAL
static thread_local std::unordered_map<MKLDNNSoftmaxSignature,
MKLDNNLogSoftmaxFwd,
OpHash> fwds;
#else
static MX_THREAD_LOCAL std::unordered_map<MKLDNNSoftmaxSignature,
MKLDNNLogSoftmaxFwd,
OpHash> fwds;
#endif

MKLDNNSoftmaxSignature key(param);
key.AddSign(real_axis);
key.AddSign(is_train);
key.AddSign(data);
key.AddSign(output);

auto it = fwds.find(key);
if (it == fwds.end()) {
MKLDNNLogSoftmaxFwd fwd(is_train, real_axis, *(data.GetMKLDNNData()));
it = AddToCache(&fwds, key, fwd);
}
return it->second;
}

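// FComputeEx forward entry point: fetches (or creates) the cached primitive
// and submits it on the MKLDNN stream.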
void MKLDNNLogSoftmaxForward(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const NDArray &in_data,
const OpReqType &req,
const NDArray &out_data) {
if (req == kNullOp) return;
// Same as the FCompute path: log_softmax only supports kWriteTo and kWriteInplace for now.
CHECK_NE(req, kAddTo);

const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
int axis = CheckAxis(param.axis, in_data.shape().ndim());
auto fwd = GetLogSoftmaxFwd(param, axis, ctx.is_train, in_data, out_data);

auto in_mem = in_data.GetMKLDNNData();
auto out_mem = out_data.GetMKLDNNData(fwd.pd.dst_desc());
MKLDNNStream *stream = MKLDNNStream::Get();
stream->RegisterPrimArgs(fwd.GetFwd(), {{MKLDNN_ARG_SRC, *in_mem}, {MKLDNN_ARG_DST, *out_mem}});
stream->Submit();
}

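// Thin wrapper owning the backward primitive together with its descriptor.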
class MKLDNNLogSoftmaxBwd {
public:
mkldnn::logsoftmax_backward::primitive_desc pd;

MKLDNNLogSoftmaxBwd(const mkldnn::memory &diff_mem,
const mkldnn::memory &data_mem,
const int axis,
const mkldnn::logsoftmax_forward::primitive_desc &hint_fwd_pd) :
pd(GetLogSoftmaxBwdPd(diff_mem, data_mem, axis, hint_fwd_pd)) {
bwd_ = std::make_shared<mkldnn::logsoftmax_backward>(pd);
}

const mkldnn::logsoftmax_backward &GetBwd() const {
return *bwd_;
}

private:
std::shared_ptr<mkldnn::logsoftmax_backward> bwd_;
};

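// Caches backward primitives, analogous to GetLogSoftmaxFwd.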
static MKLDNNLogSoftmaxBwd &GetLogSoftmaxBwd(const SoftmaxParam &param,
const int real_axis,
const std::vector<NDArray> &data,
const std::vector<NDArray> &output) {
#if DMLC_CXX11_THREAD_LOCAL
static thread_local std::unordered_map<MKLDNNSoftmaxSignature,
MKLDNNLogSoftmaxBwd,
OpHash> bwds;
#else
static MX_THREAD_LOCAL std::unordered_map<MKLDNNSoftmaxSignature,
MKLDNNLogSoftmaxBwd,
OpHash> bwds;
#endif

MKLDNNSoftmaxSignature key(param);
key.AddSign(real_axis);
key.AddSign(data);
key.AddSign(output);

auto it = bwds.find(key);
if (it == bwds.end()) {
auto diff_mem = data[0].GetMKLDNNData();
auto data_mem = data[1].GetMKLDNNData();
auto fwd_pd = GetLogSoftmaxFwdPd(true, real_axis, *data_mem);
MKLDNNLogSoftmaxBwd bwd(*diff_mem, *data_mem, real_axis, fwd_pd);
it = AddToCache(&bwds, key, bwd);
}
return it->second;
}

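// FComputeEx backward entry point: in_data[0] is the output gradient and
// in_data[1] is the forward output (DST).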
void MKLDNNLogSoftmaxBackward(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<NDArray> &in_data,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &out_data) {
if (req[0] == kNullOp) return;
CHECK_EQ(in_data.size(), 2U);
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
int axis = CheckAxis(param.axis, in_data[1].shape().ndim());
auto diff_mem = in_data[0].GetMKLDNNData();
auto data_mem = in_data[1].GetMKLDNNData();
auto bwd = GetLogSoftmaxBwd(param, axis, in_data, out_data);

auto out_mem = CreateMKLDNNMem(out_data[0], bwd.pd.diff_src_desc(), req[0]);
MKLDNNStream *stream = MKLDNNStream::Get();
mkldnn_args_map_t args = {
{ MKLDNN_ARG_DST, *data_mem },
{ MKLDNN_ARG_DIFF_DST, *diff_mem },
{ MKLDNN_ARG_DIFF_SRC, *out_mem.second }
};

stream->RegisterPrimArgs(bwd.GetBwd(), args);
CommitOutput(out_data[0], out_mem);
stream->Submit();
}

} // namespace op
} // namespace mxnet
#endif
9 changes: 9 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_ops-inl.h
@@ -98,6 +98,15 @@ void MKLDNNSoftmaxBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &out_data);

/* For log_softmax */
void MKLDNNLogSoftmaxForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const NDArray &in_data, const OpReqType &req,
const NDArray &out_data);
void MKLDNNLogSoftmaxBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const std::vector<NDArray> &in_data,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &out_data);

/* For softmax_output */
void MKLDNNSoftmaxOutputForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const std::vector<NDArray> &in_data,
4 changes: 2 additions & 2 deletions tests/python/unittest/test_loss.py
@@ -49,10 +49,10 @@ def test_loss_ndarray():

loss = gluon.loss.SoftmaxCrossEntropyLoss()
L = loss(output, label).asnumpy()
-assert_almost_equal(L, np.array([ 2.12692809, 0.04858733]))
+assert_almost_equal(L, np.array([ 2.12692809, 0.04858733]), rtol=1e-3, atol=1e-4)

L = loss(output, label, weighting).asnumpy()
-assert_almost_equal(L, np.array([ 1.06346405, 0.04858733]))
+assert_almost_equal(L, np.array([ 1.06346405, 0.04858733]), rtol=1e-3, atol=1e-4)


def get_net(num_hidden, flatten=True):
4 changes: 2 additions & 2 deletions tests/python/unittest/test_numpy_gluon.py
@@ -153,10 +153,10 @@ def test_np_loss_ndarray():

loss = gluon.loss.SoftmaxCrossEntropyLoss()
L = loss(output, label).asnumpy()
-assert_almost_equal(L, _np.array([2.12692809, 0.04858733]), use_broadcast=False)
+assert_almost_equal(L, _np.array([2.12692809, 0.04858733]), use_broadcast=False, rtol=1e-3)

L = loss(output, label, weighting).asnumpy()
-assert_almost_equal(L, _np.array([1.06346405, 0.04858733]), use_broadcast=False)
+assert_almost_equal(L, _np.array([1.06346405, 0.04858733]), use_broadcast=False, rtol=1e-3)


@with_seed()
4 changes: 2 additions & 2 deletions tests/python/unittest/test_operator.py
@@ -5338,8 +5338,8 @@ def test_log_softmax():
axis = np.random.randint(0, ndim)
data = np.random.uniform(-2, 2, size=shape)
sym = mx.sym.log_softmax(axis=axis-ndim)
-check_symbolic_forward(sym, [data], [np.log(np_softmax(data, axis=axis)+1e-20)])
-check_numeric_gradient(sym, [data], rtol=0.05, atol=1e-3)
+check_symbolic_forward(sym, [data], [np.log(np_softmax(data, axis=axis)+1e-20)], rtol=1e-3, atol=1e-4)
+check_numeric_gradient(sym, [data], rtol=1e-1, atol=1e-2)

def test_softmax_with_large_inputs():
def softmax_forward(input_data, true_output):