This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Integrating the MKL VML functions into MXNet to speed up (element-wise) mathematical computation #14893

Merged
merged 43 commits on May 22, 2019
Changes from 36 commits

Commits (43)
f0c7264
mkl_func test with erf&log op, build success~
pengxin99 Jan 25, 2019
9311777
fix lint and build issues
TaoLv Jan 25, 2019
a79f7db
Try to add support to sparse array
juliusshufan Feb 22, 2019
015fd0a
fix build
TaoLv Mar 3, 2019
495ce36
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
TaoLv Apr 17, 2019
672be6a
add functions
TaoLv Apr 18, 2019
c69a25c
Fix review comments
juliusshufan Mar 9, 2019
2c5c20c
remove unecessary code
juliusshufan Mar 9, 2019
b1b6355
Update test case
juliusshufan Mar 10, 2019
f96c34a
minor fix
juliusshufan Mar 11, 2019
06c51e9
move the position of MKL_Compute
juliusshufan Apr 18, 2019
acd7b56
Merge pull request #6 from juliusshufan/erf
TaoLv Apr 25, 2019
dc0086f
mkl_func test with erf&log op, build success~
pengxin99 Jan 25, 2019
1758e91
fix lint and build issues
TaoLv Jan 25, 2019
4461f62
Try to add support to sparse array
juliusshufan Feb 22, 2019
a3efd02
fix build
TaoLv Mar 3, 2019
64d01a4
add functions
TaoLv Apr 18, 2019
7edca49
Fix review comments
juliusshufan Mar 9, 2019
d6139fc
remove unecessary code
juliusshufan Mar 9, 2019
46a49d6
Update test case
juliusshufan Mar 10, 2019
0e36f93
minor fix
juliusshufan Mar 11, 2019
c6e2518
move the position of MKL_Compute
juliusshufan Apr 18, 2019
1153479
Merge branch 'vml' of https://github.com/juliusshufan/incubator-mxnet…
juliusshufan May 6, 2019
e60493c
fix cpplint
juliusshufan May 7, 2019
f360320
cpp lint
juliusshufan May 8, 2019
15f2f20
trigger ci
juliusshufan May 8, 2019
01d3f7e
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
TaoLv May 14, 2019
7a360e8
address comments
TaoLv May 14, 2019
22a9c4c
coding style
TaoLv May 14, 2019
2b9eca4
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
TaoLv May 16, 2019
56384df
enable layernorm
TaoLv May 17, 2019
8d1dfee
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
TaoLv May 17, 2019
c557788
fix windows build
TaoLv May 17, 2019
7e99f3e
revert changes to FComputeEx
TaoLv May 17, 2019
94bafb0
int -> index_t
TaoLv May 17, 2019
a3e07c5
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
TaoLv May 17, 2019
e275daa
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
TaoLv May 21, 2019
a383f46
remove workspace
TaoLv May 21, 2019
ff76244
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
TaoLv May 21, 2019
b13d6ef
fix lint
TaoLv May 21, 2019
eb4c82b
clean code
TaoLv May 22, 2019
0cb1120
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
TaoLv May 22, 2019
fc51292
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
TaoLv May 22, 2019
202 changes: 202 additions & 0 deletions src/operator/mkl_functions-inl.h
@@ -0,0 +1,202 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file mkl_functions-inl.h
* \brief Wrapper for MKL VML functions
* \author Tao Lv, Shufan Wu
*/
#ifndef MXNET_OPERATOR_MKL_FUNCTIONS_INL_H_
#define MXNET_OPERATOR_MKL_FUNCTIONS_INL_H_

#if MSHADOW_USE_MKL == 1
#include "mkl_vml.h"

namespace mxnet {
namespace op {
namespace mkl_func {

MSHADOW_XINLINE
static bool check_size(const size_t n) {
const size_t MKL_INT_MAX = (sizeof(MKL_INT) == sizeof(int)) ? INT_MAX : LLONG_MAX;
return (n <= MKL_INT_MAX);
}

MSHADOW_XINLINE
static bool check_type(const int t) {
return (t == mshadow::kFloat32 || t == mshadow::kFloat64);
}

#define MXNET_MKL_UNARY_MATH_FUNC(name, func) \
struct name { \
MSHADOW_XINLINE static void Vectorize(const index_t n, const float *src, float *dst) { \
vs##func(static_cast<MKL_INT>(n), src, dst); \
} \
MSHADOW_XINLINE static void Vectorize(const index_t n, const double *src, double *dst) { \
vd##func(static_cast<MKL_INT>(n), src, dst); \
} \
};

#define MXNET_MKL_BINARY_MATH_FUNC(name, func) \
struct name { \
MSHADOW_XINLINE static void Vectorize(const index_t n, \
const float *a, \
const float *b, \
float *c) { \
vs##func(static_cast<MKL_INT>(n), a, b, c); \
} \
MSHADOW_XINLINE static void Vectorize(const index_t n, \
const double *a, \
const double *b, \
double *c) { \
vd##func(static_cast<MKL_INT>(n), a, b, c); \
} \
};

MXNET_MKL_UNARY_MATH_FUNC(erf, Erf);
MXNET_MKL_UNARY_MATH_FUNC(exp, Exp);
MXNET_MKL_UNARY_MATH_FUNC(exp2, Exp2);
MXNET_MKL_UNARY_MATH_FUNC(exp10, Exp10);
MXNET_MKL_UNARY_MATH_FUNC(expm1, Expm1);
MXNET_MKL_UNARY_MATH_FUNC(log, Ln);
MXNET_MKL_UNARY_MATH_FUNC(log2, Log2);
MXNET_MKL_UNARY_MATH_FUNC(log10, Log10);
MXNET_MKL_UNARY_MATH_FUNC(log1p, Log1p);

MXNET_MKL_UNARY_MATH_FUNC(sin, Sin);
MXNET_MKL_UNARY_MATH_FUNC(cos, Cos);
MXNET_MKL_UNARY_MATH_FUNC(tan, Tan);
MXNET_MKL_UNARY_MATH_FUNC(asin, Asin);
MXNET_MKL_UNARY_MATH_FUNC(acos, Acos);
MXNET_MKL_UNARY_MATH_FUNC(atan, Atan);

MXNET_MKL_UNARY_MATH_FUNC(sinh, Sinh);
MXNET_MKL_UNARY_MATH_FUNC(cosh, Cosh);
MXNET_MKL_UNARY_MATH_FUNC(tanh, Tanh);
MXNET_MKL_UNARY_MATH_FUNC(asinh, Asinh);
MXNET_MKL_UNARY_MATH_FUNC(acosh, Acosh);
MXNET_MKL_UNARY_MATH_FUNC(atanh, Atanh);

MXNET_MKL_UNARY_MATH_FUNC(sqrt, Sqrt);
MXNET_MKL_UNARY_MATH_FUNC(abs, Abs);
MXNET_MKL_UNARY_MATH_FUNC(cbrt, Cbrt);
MXNET_MKL_UNARY_MATH_FUNC(round, Round);
MXNET_MKL_UNARY_MATH_FUNC(ceil, Ceil);
MXNET_MKL_UNARY_MATH_FUNC(floor, Floor);
MXNET_MKL_UNARY_MATH_FUNC(trunc, Trunc);

MXNET_MKL_UNARY_MATH_FUNC(lgamma, LGamma);
MXNET_MKL_UNARY_MATH_FUNC(tgamma, TGamma);
MXNET_MKL_UNARY_MATH_FUNC(square, Sqr);

MXNET_MKL_BINARY_MATH_FUNC(add, Add);
MXNET_MKL_BINARY_MATH_FUNC(sub, Sub);
MXNET_MKL_BINARY_MATH_FUNC(mul, Mul);
MXNET_MKL_BINARY_MATH_FUNC(pow, Pow);
MXNET_MKL_BINARY_MATH_FUNC(hypot, Hypot);
Contributor:

Will all of these functions be mapped automatically when MKL is enabled?

Member:

No. We just put all the VML functions here. We think these functions can be leveraged by MXNet in the future, but currently the registration of each operator needs to be changed to use them. In this PR we only optimized some operators used in BERT; we propose to optimize the others when we face performance problems with them.
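For illustration, a minimal sketch of what such a per-operator registration change could look like, modelled on the LayerNorm dispatch later in this PR. The function name ErfComputeMKL and the commented fallback call are hypothetical, not code from this PR:

#if MSHADOW_USE_MKL == 1
static void ErfComputeMKL(const nnvm::NodeAttrs& attrs,
                          const OpContext& ctx,
                          const std::vector<TBlob>& inputs,
                          const std::vector<OpReqType>& req,
                          const std::vector<TBlob>& outputs) {
  if (req[0] == kNullOp) return;
  const TBlob& in = inputs[0];
  const TBlob& out = outputs[0];
  // Take the VML path only for fp32/fp64 tensors whose length fits in MKL_INT.
  if (mkl_func::check_type(in.type_flag_) && mkl_func::check_size(in.Size())) {
    MSHADOW_SGL_DBL_TYPE_SWITCH(in.type_flag_, DType, {
      mkl_func::erf::Vectorize(in.Size(), in.dptr<DType>(), out.dptr<DType>());
    });
  } else {
    // Fall back to the existing element-wise kernel, e.g.:
    // UnaryOp::Compute<cpu, mshadow_op::erf>(attrs, ctx, inputs, req, outputs);
  }
}
#endif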

Contributor:

Thanks for the explanation. We can add it back when we use it; otherwise, it is a little confusing for other developers.



template <typename DType>
MSHADOW_XINLINE static void sub_(index_t n, DType *in, DType b, DType *dst) {
for (index_t i = 0; i < n; i++)
dst[i] = in[i] - b;
}

template <typename DType>
MSHADOW_XINLINE static void div_(index_t n, DType *in, DType b, DType *dst) {
for (index_t i = 0; i < n; i++)
dst[i] = in[i] / b;
}

template <typename DType>
MSHADOW_XINLINE static void sum_(index_t n, DType *in, DType *dst) {
// dst[0] = cblas_sasum(n, in, 1);
DType sum = 0.0f;
for (index_t i = 0; i < n; i++)
sum += in[i];

dst[0] = sum;
}

template <typename DType>
MSHADOW_XINLINE static void max_(index_t n, DType *in, DType *dst) {
dst[0] = in[0];
for (index_t i = 1; i < n; i++)
dst[0] = (dst[0] < in[i]) ? in[i] : dst[0];
}

// LayerNorm on the last dimension
template <typename DType>
MSHADOW_XINLINE static void LayerNormLastDim(index_t m,
index_t n,
DType *a,
DType *b,
DType *ws,
DType *gamma,
DType *beta,
DType *mean,
DType *var,
DType eps) {
auto nthreads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
#pragma omp parallel for num_threads(nthreads)
for (index_t i = 0; i < m; i++) {
DType* in_offset = a + i * n;
DType* out_offset = b + i * n;
DType* ws_offset = ws + i * n;

sum_(n, in_offset, &(mean[i]));
mean[i] /= n;
sub_(n, in_offset, mean[i], out_offset);
square::Vectorize(n, out_offset, ws_offset);
sum_(n, ws_offset, &(var[i]));
var[i] = math::sqrt(var[i] / n + eps);

mul::Vectorize(n, out_offset, gamma, out_offset);
div_(n, out_offset, var[i], out_offset);
add::Vectorize(n, out_offset, beta, out_offset);
Contributor:

Any chance to fuse some of these operations to reduce the memory bandwidth?

Member:

How much faster is this version compared to the mshadow one?

Member:

After reading the code, I think the current implementation, which relies on the vectorized operations, should be fast at scaling and shifting the data (data * gamma and data + beta). One possible improvement is to use Welford's online algorithm (https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance) to calculate the mean and variance in one pass; the code would look like this:

template <typename DType>
MSHADOW_XINLINE static void mean_var_(index_t n, DType *in, DType *mean, DType* variance) {
  DType sigma2 = 0;
  DType mean_v = 0;
  DType old_mean_v = 0;
  for (index_t i = 0; i < n; i++) {
    DType x = in[i];
    old_mean_v = mean_v;
    mean_v += (x - old_mean_v) / (i + 1);
    sigma2 += (x - old_mean_v) * (x - mean_v);
  }    
  mean[0] = mean_v;
  variance[0] = sigma2 / n;
}


template <typename DType>
MSHADOW_XINLINE static void LayerNormLastDim(index_t m,
                                             index_t n,
                                             DType *a,
                                             DType *b,
                                             DType *ws,
                                             DType *gamma,
                                             DType *beta,
                                             DType *mean,
                                             DType *var,
                                             DType eps) {
  auto nthreads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
#pragma omp parallel for num_threads(nthreads)
  for (index_t i = 0; i < m; i++) {
    DType ele_mean, ele_var;
    DType* in_offset = a + i * n;
    DType* out_offset = b + i * n;
    mean_var_(n, in_offset, &ele_mean, &ele_var);
    sub_(n, in_offset, ele_mean, out_offset);
    ele_var = math::sqrt(ele_var + eps);
    mul::Vectorize(n, out_offset, gamma, out_offset);
    div_(n, out_offset, ele_var, out_offset);
    add::Vectorize(n, out_offset, beta, out_offset);
    mean[i] = ele_mean;
    var[i] = ele_var;
  }
}

Member:

@pengzhao-intel @sxjscience Loops are fused in the latest commit. I also removed the required workspace, but that means we cannot leverage VML functions and need to rely on the compiler for vectorization.
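For readers following along, a rough sketch of what a fused, workspace-free per-row loop of that kind looks like when vectorization is left to the compiler. This is an illustration only, not the exact code of that commit:

template <typename DType>
static void FusedLayerNormRow(index_t n, const DType* x,
                              const DType* gamma, const DType* beta,
                              DType eps, DType* y,
                              DType* out_mean, DType* out_std) {
  // Pass 1: mean and variance of the row, with no temporary workspace buffer.
  DType mean = 0, var = 0;
  for (index_t i = 0; i < n; ++i) mean += x[i];
  mean /= n;
  for (index_t i = 0; i < n; ++i) var += (x[i] - mean) * (x[i] - mean);
  const DType stdev = math::sqrt(var / n + eps);
  // Pass 2: normalize, scale and shift in one fused loop; the compiler is
  // expected to auto-vectorize this instead of calling VML.
  for (index_t i = 0; i < n; ++i)
    y[i] = (x[i] - mean) / stdev * gamma[i] + beta[i];
  *out_mean = mean;
  *out_std = stdev;  // as in this PR, the "std" output stores the standard deviation
}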

}
}

template <typename DType>
MSHADOW_XINLINE static void LogSoftmaxLastDim(index_t m,
index_t n,
DType *a,
DType *b) {
auto nthreads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
#pragma omp parallel for num_threads(nthreads)
for (index_t i = 0; i < m; i++) {
DType* in_offset = a + i * n;
DType* out_offset = b + i * n;

DType b, logsum;
max_(n, in_offset, &b);
sub_(n, in_offset, b, out_offset);
exp::Vectorize(n, out_offset, out_offset);
sum_(n, out_offset, &logsum);
logsum = b + logf(logsum);
sub_(n, in_offset, logsum, out_offset);
}
}

} // namespace mkl_func
} // namespace op
} // namespace mxnet
#endif // MSHADOW_USE_MKL == 1
#endif // MXNET_OPERATOR_MKL_FUNCTIONS_INL_H_
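As a reading aid (not part of the diff): each MXNET_MKL_UNARY_MATH_FUNC invocation above generates a small struct with two Vectorize overloads; for example, MXNET_MKL_UNARY_MATH_FUNC(log, Ln) produces wrappers around vsLn/vdLn. A hypothetical call site, with the function and buffer names invented for the example and guarded by the helpers from this header, would look like:

#include <vector>

void log_example(const std::vector<float>& src, std::vector<float>* dst) {
  const size_t n = src.size();
  dst->resize(n);
  // Dispatch to VML only when the size fits in MKL_INT and the type is supported.
  if (mkl_func::check_size(n) && mkl_func::check_type(mshadow::kFloat32)) {
    // For float inputs this expands to vsLn(static_cast<MKL_INT>(n), src, dst).
    mkl_func::log::Vectorize(static_cast<index_t>(n), src.data(), dst->data());
  }
}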
8 changes: 4 additions & 4 deletions src/operator/nn/layer_norm-inl.h
@@ -63,6 +63,9 @@ struct LayerNormParam : public dmlc::Parameter<LayerNormParam> {
}
};

static int GetRealAxis(int axis, int ndim) {
return axis < 0 ? (axis + ndim) : axis;
}

template<typename xpu>
void LayerNormCompute(const nnvm::NodeAttrs& attrs,
@@ -74,10 +77,7 @@ void LayerNormCompute(const nnvm::NodeAttrs& attrs,
const LayerNormParam& param = nnvm::get<LayerNormParam>(attrs.parsed);
if (req[0] == kNullOp) return;
CHECK_NE(req[0], kAddTo);
int axis = param.axis;
if (axis < 0) {
axis += static_cast<int>(inputs[0].ndim());
}
int axis = GetRealAxis(param.axis, inputs[0].ndim());
CHECK(axis >= 0 && axis < inputs[0].ndim()) << "Channel axis out of range: " << param.axis;
CHECK_EQ(inputs.size(), 3U);
Stream<xpu> *s = ctx.get_stream<xpu>();
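Aside (not in the diff): the new GetRealAxis helper only canonicalizes a possibly negative axis, so for a 3-D input the default axis=-1 resolves to the last dimension:

GetRealAxis(-1, 3);  // == 2, the last dimension
GetRealAxis(1, 3);   // == 1, non-negative axes pass through unchanged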
65 changes: 61 additions & 4 deletions src/operator/nn/layer_norm.cc
@@ -27,6 +27,10 @@
#include <nnvm/op_attr_types.h>
#include "../elemwise_op_common.h"

#if MSHADOW_USE_MKL == 1
#include "../mkl_functions-inl.h"
#endif

namespace mxnet {
namespace op {

@@ -39,10 +43,7 @@ static bool LayerNormShape(const nnvm::NodeAttrs& attrs,
using namespace mshadow;
CHECK_EQ(in_shape->size(), 3U) << "Input:[data, gamma, beta]";
const mxnet::TShape &dshape = in_shape->at(layernorm::kData);
int axis = param.axis;
if (axis < 0) {
axis += dshape.ndim();
}
int axis = GetRealAxis(param.axis, dshape.ndim());
CHECK(axis >= 0 && axis < dshape.ndim())
<< "Channel axis out of range: axis=" << param.axis;

@@ -64,6 +65,58 @@ static bool LayerNormShape(const nnvm::NodeAttrs& attrs,
return true;
}

#if MSHADOW_USE_MKL == 1
void LayerNormComputeMKL(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
const LayerNormParam& param = nnvm::get<LayerNormParam>(attrs.parsed);
if (req[0] == kNullOp) return;
CHECK_NE(req[0], kAddTo);
CHECK_EQ(inputs.size(), 3U);
int axis = GetRealAxis(param.axis, inputs[0].ndim());

if (axis == (inputs[layernorm::kData].ndim() - 1) ||
(inputs[0].type_flag_ != kFloat32 && inputs[0].type_flag_ != kFloat64)) {
Stream<cpu> *s = ctx.get_stream<cpu>();
// Compute necessary data for the reduce operation.
mxnet::TShape red_src_shape, red_dst_shape;
BroadcastReduceShapeCompact(inputs[layernorm::kData].shape_, outputs[layernorm::kMean].shape_,
&red_src_shape, &red_dst_shape);
const TBlob in_data = inputs[layernorm::kData].reshape(red_src_shape);
const TBlob mean_data = outputs[layernorm::kMean].reshape(red_dst_shape);
const TBlob std_data = outputs[layernorm::kStd].reshape(red_dst_shape);
const int outter_size = red_dst_shape.Size();
const int channel_size = red_src_shape.Size() / red_dst_shape.Size();

// Initialize the workspace
Tensor<cpu, 1, char> workspace;
size_t workspace_size = 0;
MSHADOW_SGL_DBL_TYPE_SWITCH(in_data.type_flag_, DType, {
workspace_size = in_data.Size() * sizeof (DType);
});
workspace = ctx.requested[0].get_space_typed<cpu, 1, char>(Shape1(workspace_size), s);

// call
MSHADOW_SGL_DBL_TYPE_SWITCH(in_data.type_flag_, DType, {
mkl_func::LayerNormLastDim(outter_size, channel_size,
in_data.dptr<DType>(),
outputs[layernorm::kOut].dptr<DType>(),
reinterpret_cast<DType*>(workspace.dptr_),
inputs[layernorm::kGamma].dptr<DType>(),
inputs[layernorm::kBeta].dptr<DType>(),
outputs[layernorm::kMean].dptr<DType>(),
outputs[layernorm::kStd].dptr<DType>(),
static_cast<DType>(param.eps));
});
} else {
// fallback
LayerNormCompute<cpu>(attrs, ctx, inputs, req, outputs);
}
}
#endif

NNVM_REGISTER_OP(LayerNorm)
.describe(R"code(Layer normalization.
@@ -110,7 +163,11 @@ axis to be the last item in the input shape.
})
.set_attr<mxnet::FInferShape>("FInferShape", LayerNormShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 3>)
#if MSHADOW_USE_MKL == 1
.set_attr<FCompute>("FCompute<cpu>", LayerNormComputeMKL)
#else
.set_attr<FCompute>("FCompute<cpu>", LayerNormCompute<cpu>)
#endif
.set_attr<nnvm::FGradient>("FGradient", [](const nnvm::NodePtr& n,
const std::vector<nnvm::NodeEntry>& ograds) {
std::vector<nnvm::NodeEntry> heads;