Add axes support to Dropout for variational dropout in NLP #9931
Changes from 6 commits
@@ -21,7 +21,7 @@
  * Copyright (c) 2015 by Contributors
  * \file dropout-inl.h
  * \brief
- * \author Bing Xu, Da Zheng
+ * \author Bing Xu, Da Zheng, Hang Zhang
  */

 #ifndef MXNET_OPERATOR_NN_DROPOUT_INL_H_
@@ -37,6 +37,7 @@
 #include "../mxnet_op.h"
 #include "../mshadow_op.h"
 #include "../random/sampler.h"
+#include "../tensor/elemwise_binary_broadcast_op.h"

 #if defined(USE_MKL) && defined(_OPENMP)
 #include <omp.h>
@@ -55,9 +56,12 @@ enum DropoutOpMode {kTraining, kAlways};
 namespace mxnet {
 namespace op {

+const int MAX_DIM = 5;
+
 struct DropoutParam : public dmlc::Parameter<DropoutParam> {
   float p;
   int mode;
+  TShape axes;
   DMLC_DECLARE_PARAMETER(DropoutParam) {
     DMLC_DECLARE_FIELD(p).set_default(0.5)
     .set_range(0, 1)
@@ -67,6 +71,8 @@ struct DropoutParam : public dmlc::Parameter<DropoutParam> {
     .add_enum("always", dropout::kAlways)
     .set_default(dropout::kTraining)
     .describe("Whether to only turn on dropout during training or to also turn on for inference.");
+    DMLC_DECLARE_FIELD(axes).set_default(TShape())
+    .describe("Axes for variational dropout kernel.");
   }
 };  // struct DropoutParam

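Note: to make the new parameter concrete, here is a minimal NumPy sketch of the semantics `axes` enables, as read from the patch; `variational_dropout` is a hypothetical helper, not MXNet code. The mask is drawn with extent 1 on the listed axes and broadcast, so one dropout pattern is shared along them:

```python
import numpy as np

def variational_dropout(x, p, axes, rng=np.random.default_rng()):
    # Hypothetical sketch: mask shape = input shape with `axes` collapsed to 1.
    mask_shape = tuple(1 if i in axes else d for i, d in enumerate(x.shape))
    pkeep = 1.0 - p
    # Inverted dropout: keep with probability pkeep, scale kept values by 1/pkeep.
    mask = (rng.uniform(size=mask_shape) < pkeep) / pkeep
    return x * mask, mask  # NumPy broadcasting shares the mask along `axes`

# A (seq_len, batch, hidden) input with the mask shared across time steps:
y, mask = variational_dropout(np.ones((4, 2, 3)), p=0.5, axes=(0,))
assert (y == y[0]).all()  # same dropout pattern at every time step
```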
@@ -178,37 +184,42 @@ class DropoutOp {
-  /*!
-   * \brief Dropout kernel, compute dropout tensor
-   */
-  struct DropoutKernel {
-    /*!
-     * \brief Dropout kernel function
-     * \param id Thread number (0-based representing count)
-     * \param gen Random number generator
-     * \param N Total number of items in the output
-     * \param step Step between items, related to parallelism
-     * \param dropout_out Output dropout values
-     * \param mask_out Output mask (is multiplied to create dropout output, may be 0)
-     * \param input_data Input data to perform the dropout on
-     * \param pkeep Dropout rate (keep when the generated random number is less than this value)
-     */
+  struct BernoulliKernel {
+    /*! \brief Bernoulli kernel for generating mask */
     MSHADOW_XINLINE static void Map(int id,
                                     RandGenerator<xpu, DType> gen,
                                     const int N,
                                     const int step,
-                                    DType *dropout_out,
                                     DType *mask_out,
-                                    const DType *input_data,
                                     const real_t pkeep) {
       RNG_KERNEL_LOOP(xpu, DType, id, gen, N, step, {
         const real_t rand_num = static_cast<real_t>(genImpl.uniform());
         mask_out[i] = mshadow_op::threshold::Map<real_t>(rand_num, pkeep) * (1.0f / pkeep);
-        dropout_out[i] = input_data[i] * mask_out[i];
       });
     }
   };

+  struct ElemwiseMulKernel {
+    /*! Elementwise multiply kernel */
+    MSHADOW_XINLINE static void Map(int id,
+                                    RandGenerator<xpu, DType> gen,
+                                    const int N,
+                                    const int step,
+                                    DType *dropout_out,
+                                    DType *mask_out,
+                                    const DType *input_data) {
+      const int start = id * step;
+      const int end = start + step;
+      for (int i = start; i < end && i < N; ++i) {
+        dropout_out[i] = input_data[i] * mask_out[i];
+      }
+    }
+  };
+
   void Init(const DropoutParam &param) {
     this->pkeep_ = 1.0f - param.p;
     this->mode_ = static_cast<dropout::DropoutOpMode>(param.mode);
+    this->axes = param.axes;
   }

   void Forward(const OpContext &ctx,
@@ -225,14 +236,37 @@ class DropoutOp {
     if (ctx.is_train || this->mode_ == dropout::kAlways) {
       RandGenerator<xpu, DType> *pgen = ctx.requested[0].get_parallel_random<xpu, DType>();
       CHECK_NOTNULL(pgen);
-      if (!MKLForward(s, pgen, this->pkeep_, in_data, out_data)) {
+      if (this->axes.ndim() != 0 || !MKLForward(s, pgen, this->pkeep_, in_data, out_data)) {
         const TBlob &mask = out_data[dropout::kMask];
         CHECK(req[dropout::kOut] != kAddTo);
-        LaunchRNG<DropoutKernel, xpu>(s, pgen, out.Size(),
-                                      out.dptr<DType>(),
-                                      mask.dptr<DType>(),
-                                      in_data[dropout::kData].dptr<DType>(),
-                                      this->pkeep_);
+        // initialize the mask
+        LaunchRNG<BernoulliKernel, xpu>(s, pgen, out.Size(),
+                                        mask.dptr<DType>(),
+                                        this->pkeep_);
+        // broadcast mul
+        TShape new_lshape, new_rshape, new_oshape;
+        int ndim = BinaryBroadcastShapeCompact(in_data[dropout::kData].shape_,
+                                               mask.shape_, out.shape_,
+                                               &new_lshape, &new_rshape, &new_oshape);
+        if (!ndim) {
+          MXNET_ASSIGN_REQ_SWITCH(req[dropout::kOut], Req, {
+            mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::mul, Req>, xpu>::Launch(
+              s, out.Size(), out.dptr<DType>(), in_data[dropout::kData].dptr<DType>(),
+              mask.dptr<DType>());
+          });
+        } else {
+          BROADCAST_NDIM_SWITCH(ndim, NDim, {
+            mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
+            mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>());
+            mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>());
+            mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, DType,
+                                                               mshadow_op::mul>, xpu>::
+            template LaunchEx(s, new_oshape.Size(), req[dropout::kOut],
+                              lstride, rstride, oshape,
+                              in_data[dropout::kData].dptr<DType>(),
+                              mask.dptr<DType>(), out.dptr<DType>());
+          });
+        }
       }
Review thread on this hunk:

Reviewer: If MKL is enabled, will the broadcast not happen?

Author: I haven't updated the MKL code for variational dropout (enabling axes). I need help with MKL.

Reviewer: For non-default axes, you can fall back to this op instead of the MKL op.

Author: Thx @cjolivier01. I added the condition check here: https://github.com/apache/incubator-mxnet/pull/9931/files#diff-4aea2cc24c0bb4e8e48face9faf4aa26R249
     } else {
       const TBlob& data = in_data[dropout::kData];
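Note: for readers unfamiliar with binary_broadcast_kernel, the broadcast path hinges on giving size-1 (broadcast) dimensions a stride of 0, so the same mask element is re-read along those axes. Below is a rough Python rendering of that indexing scheme; the helpers are hypothetical stand-ins for mxnet_op::calc_stride and the kernel's LaunchEx loop, assuming row-major layout:

```python
import numpy as np

def calc_stride(shape):
    # Row-major strides; size-1 (broadcast) axes get stride 0.
    stride, out = 1, [0] * len(shape)
    for i in reversed(range(len(shape))):
        out[i] = stride if shape[i] != 1 else 0
        stride *= shape[i]
    return out

def broadcast_mul(lhs, rhs, lshape, rshape, oshape):
    lstride, rstride = calc_stride(lshape), calc_stride(rshape)
    out = np.empty(int(np.prod(oshape)))
    for idx in range(out.size):
        li, ri, rem = 0, 0, idx
        # Unravel idx over oshape, then re-ravel with each operand's strides.
        for d in reversed(range(len(oshape))):
            coord = rem % oshape[d]
            rem //= oshape[d]
            li += coord * lstride[d]
            ri += coord * rstride[d]
        out[idx] = lhs[li] * rhs[ri]
    return out.reshape(oshape)

# (2, 3) data times a (2, 1) mask: each mask element is reused along axis 1.
print(broadcast_mul(np.arange(6.0), np.array([10.0, 20.0]), [2, 3], [2, 1], [2, 3]))
```

When BinaryBroadcastShapeCompact finds that no broadcasting is needed it returns 0, and the code above takes the cheaper elementwise op_with_req path instead.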
@@ -257,15 +291,31 @@ class DropoutOp {
     using namespace mshadow::expr;
     Stream<xpu> *s = ctx.get_stream<xpu>();
     if (ctx.is_train || mode_ == dropout::kAlways) {
-      if (!MKLBackward(s, this->pkeep_, in_grad, out_data, out_grad)) {
+      if (this->axes.ndim() != 0 || !MKLBackward(s, this->pkeep_, in_grad, out_data, out_grad)) {
         const TBlob &gdata = in_grad[dropout::kData];
         const TBlob &grad = out_grad[dropout::kOut];
         const TBlob &mask = out_data[dropout::kMask];
         CHECK_EQ(grad.Size(), mask.Size());
-        MXNET_ASSIGN_REQ_SWITCH(req[dropout::kData], Req, {
-          mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::mul, Req>, xpu>::Launch(
-            s, gdata.Size(), gdata.dptr<DType>(), grad.dptr<DType>(), mask.dptr<DType>());
-        });
+        // broadcast mul
+        TShape new_lshape, new_rshape, new_oshape;
+        int ndim = BinaryBroadcastShapeCompact(grad.shape_,
+                                               mask.shape_, gdata.shape_,
+                                               &new_lshape, &new_rshape, &new_oshape);
+        if (!ndim) {
+          MXNET_ASSIGN_REQ_SWITCH(req[dropout::kData], Req, {
+            mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::mul, Req>, xpu>::Launch(
+              s, gdata.Size(), gdata.dptr<DType>(), grad.dptr<DType>(), mask.dptr<DType>());
+          });
+        } else {
+          BROADCAST_NDIM_SWITCH(ndim, NDim, {
+            mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
+            mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>());
+            mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>());
+            mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, DType,
+                                                               mshadow_op::mul>, xpu>::
+            template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape,
+                              grad.dptr<DType>(), mask.dptr<DType>(), gdata.dptr<DType>());
+          });
+        }
       }
     } else {
       const TBlob& gdata = in_grad[dropout::kData];
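Note: the backward pass needs no axis reduction even though the mask is smaller than the gradient, because the mask (not the gradient) is the broadcast operand; each input-gradient element is simply scaled by its shared mask value. In NumPy terms, a two-line sketch under the same mask layout as above:

```python
def dropout_backward(dy, mask):
    # dy has the full input shape; mask broadcasts exactly as in forward.
    return dy * mask
```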
@@ -286,6 +336,7 @@ class DropoutOp {
   real_t pkeep_;
   /*! \brief Dropout mode */
   dropout::DropoutOpMode mode_;
+  TShape axes;
Review thread on this line:

Reviewer: nit: member variable name should end in an underscore

Author: Thx 👍
 };  // class DropoutOp

 template<typename xpu>
Review thread on the kernel split:

Reviewer: I am not saying it needs to be done, but have you considered merging this operation with the other kernel, perhaps by deriving from broadcast_kernel or passing a modified version of the mul OP to broadcast_kernel? Making two full passes across the memory is going to cause a performance hit due to both OMP overhead and CPU cache effects.

Author: I am not sure I understand you clearly. I separated the original dropout kernel into two parts, 1) BernoulliKernel and 2) broadcast mul, so that we can enable axes support for variational dropout. Thx

Reviewer: Right. What's the performance impact of using two kernels instead of one?

Author: Thx @cjolivier01. I get your point about efficiency. I have added a condition check for standard dropout, which has the same efficiency when no axes are provided:
https://github.com/apache/incubator-mxnet/pull/9931/files#diff-4aea2cc24c0bb4e8e48face9faf4aa26R252