This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add axes support to Dropout for variational dropout in NLP #9931

Merged
merged 8 commits on Mar 6, 2018
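The new `axes` argument makes the operator reuse a single dropout mask along the listed axes, i.e. variational ("locked") dropout, where NLP models typically share the mask across the time axis. A minimal usage sketch, assuming the Python frontend simply exposes the parameter added in this patch (the shapes and the `mode='always'` setting are only for illustration):

    import mxnet as mx

    # (batch, seq_len, hidden) input; share the mask across the time axis (axis 1)
    x = mx.nd.random.uniform(shape=(2, 5, 4))
    # mode='always' applies dropout outside of a training scope so the effect is visible
    y = mx.nd.Dropout(x, p=0.5, axes=(1,), mode='always')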
97 changes: 88 additions & 9 deletions src/operator/nn/dropout-inl.h
@@ -21,7 +21,7 @@
* Copyright (c) 2015 by Contributors
* \file dropout-inl.h
* \brief
* \author Bing Xu, Da Zheng
* \author Bing Xu, Da Zheng, Hang Zhang
*/

#ifndef MXNET_OPERATOR_NN_DROPOUT_INL_H_
@@ -37,6 +37,7 @@
#include "../mxnet_op.h"
#include "../mshadow_op.h"
#include "../random/sampler.h"
#include "../tensor/elemwise_binary_broadcast_op.h"

#if defined(USE_MKL) && defined(_OPENMP)
#include <omp.h>
@@ -55,9 +56,12 @@ enum DropoutOpMode {kTraining, kAlways};
namespace mxnet {
namespace op {

const int MAX_DIM = 5;

struct DropoutParam : public dmlc::Parameter<DropoutParam> {
float p;
int mode;
TShape axes;
DMLC_DECLARE_PARAMETER(DropoutParam) {
DMLC_DECLARE_FIELD(p).set_default(0.5)
.set_range(0, 1)
@@ -67,6 +71,8 @@ struct DropoutParam : public dmlc::Parameter<DropoutParam> {
.add_enum("always", dropout::kAlways)
.set_default(dropout::kTraining)
.describe("Whether to only turn on dropout during training or to also turn on for inference.");
DMLC_DECLARE_FIELD(axes).set_default(TShape())
.describe("Axes for variational dropout kernel.");
}
}; // struct DropoutParam

@@ -205,10 +211,25 @@ class DropoutOp {
});
}
};
struct BernoulliKernel {
/*! \brief Bernoulli kernel for generating mask */
MSHADOW_XINLINE static void Map(int id,
RandGenerator<xpu, DType> gen,
const int N,
const int step,
DType *mask_out,
const real_t pkeep) {
RNG_KERNEL_LOOP(xpu, DType, id, gen, N, step, {
const real_t rand_num = static_cast<real_t>(genImpl.uniform());
mask_out[i] = mshadow_op::threshold::Map<real_t>(rand_num, pkeep) * (1.0f / pkeep);
});
}
};

void Init(const DropoutParam &param) {
this->pkeep_ = 1.0f - param.p;
this->mode_ = static_cast<dropout::DropoutOpMode>(param.mode);
this->axes_ = param.axes;
}

void Forward(const OpContext &ctx,
@@ -225,14 +246,46 @@
if (ctx.is_train || this->mode_ == dropout::kAlways) {
RandGenerator<xpu, DType> *pgen = ctx.requested[0].get_parallel_random<xpu, DType>();
CHECK_NOTNULL(pgen);
if (!MKLForward(s, pgen, this->pkeep_, in_data, out_data)) {
if (this->axes_.ndim() != 0 || !MKLForward(s, pgen, this->pkeep_, in_data, out_data)) {
const TBlob &mask = out_data[dropout::kMask];
CHECK(req[dropout::kOut] != kAddTo);
LaunchRNG<DropoutKernel, xpu>(s, pgen, out.Size(),
if (this->axes_.ndim() == 0) {
// standard case for dropout
LaunchRNG<DropoutKernel, xpu>(s, pgen, out.Size(),
out.dptr<DType>(),
mask.dptr<DType>(),
in_data[dropout::kData].dptr<DType>(),
this->pkeep_);
return;
}
// initialize the mask
LaunchRNG<BernoulliKernel, xpu>(s, pgen, out.Size(),
mask.dptr<DType>(),
this->pkeep_);
// broadcast mul
TShape new_lshape, new_rshape, new_oshape;
int ndim = BinaryBroadcastShapeCompact(in_data[dropout::kData].shape_,
mask.shape_, out.shape_,
&new_lshape, &new_rshape, &new_oshape);
if (!ndim) {
MXNET_ASSIGN_REQ_SWITCH(req[dropout::kOut], Req, {
mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::mul, Req>, xpu>::Launch(
s, out.Size(), out.dptr<DType>(), in_data[dropout::kData].dptr<DType>(),
mask.dptr<DType>());
});
} else {
BROADCAST_NDIM_SWITCH(ndim, NDim, {
mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>());
mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>());
mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, DType,
mshadow_op::mul>, xpu>::
template LaunchEx(s, new_oshape.Size(), req[dropout::kOut],
lstride, rstride, oshape,
in_data[dropout::kData].dptr<DType>(),
mask.dptr<DType>(), out.dptr<DType>());
});
}
}
Member:
If MKL is enabled, will the broadcast not happen?

Contributor Author:
I haven't updated the MKL code for variational dropout (enabling axes). I need help with MKL.

Member:
For non-default axes, you can fall back to this op instead of the MKL op.

} else {
const TBlob& data = in_data[dropout::kData];
@@ -257,15 +310,40 @@ class DropoutOp {
using namespace mshadow::expr;
Stream<xpu> *s = ctx.get_stream<xpu>();
if (ctx.is_train || mode_ == dropout::kAlways) {
if (!MKLBackward(s, this->pkeep_, in_grad, out_data, out_grad)) {
if (this->axes_.ndim() != 0 || !MKLBackward(s, this->pkeep_, in_grad, out_data, out_grad)) {
const TBlob &gdata = in_grad[dropout::kData];
const TBlob &grad = out_grad[dropout::kOut];
const TBlob &mask = out_data[dropout::kMask];
CHECK_EQ(grad.Size(), mask.Size());
MXNET_ASSIGN_REQ_SWITCH(req[dropout::kData], Req, {
mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::mul, Req>, xpu>::Launch(
s, gdata.Size(), gdata.dptr<DType>(), grad.dptr<DType>(), mask.dptr<DType>());
});
if (this->axes_.ndim() == 0) {
// standard case for dropout
CHECK_EQ(grad.Size(), mask.Size());
MXNET_ASSIGN_REQ_SWITCH(req[dropout::kData], Req, {
mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::mul, Req>, xpu>::Launch(
s, gdata.Size(), gdata.dptr<DType>(), grad.dptr<DType>(), mask.dptr<DType>());
});
return;
}
// broadcast mul
TShape new_lshape, new_rshape, new_oshape;
int ndim = BinaryBroadcastShapeCompact(grad.shape_,
mask.shape_, gdata.shape_,
&new_lshape, &new_rshape, &new_oshape);
if (!ndim) {
MXNET_ASSIGN_REQ_SWITCH(req[dropout::kData], Req, {
mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::mul, Req>, xpu>::Launch(
s, gdata.Size(), gdata.dptr<DType>(), grad.dptr<DType>(), mask.dptr<DType>());
});
} else {
BROADCAST_NDIM_SWITCH(ndim, NDim, {
mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>());
mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>());
mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, DType,
mshadow_op::mul>, xpu>::
template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape,
grad.dptr<DType>(), mask.dptr<DType>(), gdata.dptr<DType>());
});
}
}
} else {
const TBlob& gdata = in_grad[dropout::kData];
@@ -286,6 +364,7 @@
real_t pkeep_;
/*! \brief Dropout mode */
dropout::DropoutOpMode mode_;
TShape axes_;
}; // class DropoutOp

template<typename xpu>
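In short, the non-MKL path added above draws a Bernoulli(pkeep) mask at the reduced mask shape, pre-scales it by 1/pkeep (inverted dropout), and broadcast-multiplies it with the input in the forward pass and with the incoming gradient in the backward pass. A NumPy sketch of that semantics (function names and shapes are illustrative, not the operator's API):

    import numpy as np

    def variational_dropout_forward(x, pkeep, axes, rng=np.random):
        # mask shape: the input shape with each listed axis collapsed to 1
        mshape = tuple(1 if i in axes else d for i, d in enumerate(x.shape))
        # BernoulliKernel: keep with probability pkeep, pre-scaled by 1/pkeep
        mask = (rng.uniform(size=mshape) < pkeep) / pkeep
        return x * mask, mask                 # broadcast multiply, as in Forward

    def variational_dropout_backward(grad_out, mask):
        return grad_out * mask                # broadcast multiply, as in Backward

    x = np.random.randn(2, 5, 4)
    y, mask = variational_dropout_forward(x, pkeep=0.5, axes=(1,))  # one mask entry per (batch, hidden) position, shared over time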
8 changes: 6 additions & 2 deletions src/operator/nn/dropout.cc
@@ -21,7 +21,7 @@
* Copyright (c) 2015 by Contributors
* \file dropout.cc
* \brief
* \author Bing Xu, Da Zheng
* \author Bing Xu, Da Zheng, Hang Zhang
*/

#include "./dropout-inl.h"
@@ -93,10 +93,14 @@ Example::
std::vector<TShape> *in_shape, std::vector<TShape> *out_shape){
using namespace mshadow;
CHECK_EQ(in_shape->size(), 1U);
const TShape &dshape = in_shape->at(0);
const DropoutParam& param = nnvm::get<DropoutParam>(attrs.parsed);
TShape dshape(in_shape->at(0));
if (dshape.ndim() == 0) return false;
out_shape->clear();
out_shape->push_back(dshape);
for (index_t i = 0; i < param.axes.ndim(); ++i) {
dshape[param.axes[i]] = 1;
}
out_shape->push_back(dshape);
return true;
})
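With this InferShape change, the first output keeps the data shape while the mask output gets size 1 on every axis listed in `axes`, which is what lets the broadcast multiply above share one mask along those axes. A tiny illustration (the shapes are hypothetical):

    def dropout_mask_shape(dshape, axes):
        # mirrors the loop above: set each listed axis to 1 for the mask output
        mshape = list(dshape)
        for ax in axes:
            mshape[ax] = 1
        return tuple(mshape)

    assert dropout_mask_shape((2, 5, 4), axes=(1,)) == (2, 1, 4)   # mask shared across seq_len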
2 changes: 1 addition & 1 deletion src/operator/nn/dropout.cu
@@ -21,7 +21,7 @@
* Copyright (c) 2015 by Contributors
* \file dropout.cc
* \brief
* \author Bing Xu, Da Zheng
* \author Bing Xu, Da Zheng, Hang Zhang
*/

#include "./dropout-inl.h"
Expand Down