
Add axes support to Dropout for variational dropout in NLP (#9931)
* add axes support to dropout for variational dropout, test pending, mkl part hasn't been updated

* avoid copying code

* fix typo

* consider non-broadcast case

* fix backward

* avoid mkl

* rm redundant

* condition check for standard dropout
zhanghang1989 authored and yzhliu committed Mar 6, 2018
1 parent e2b1a56 commit 40de6ab
Showing 3 changed files with 95 additions and 12 deletions.
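
Before the per-file diffs, a note on what the new `axes` parameter means: the dropout mask is drawn once with every axis listed in `axes` collapsed to size 1, then broadcast-multiplied over the input, so the same dropout decision is reused along those axes — the "variational dropout" pattern used for RNNs in NLP. Below is a minimal standalone C++ sketch of those semantics; the (batch, time, hidden) layout, axes = {1}, and pkeep value are illustrative assumptions, not MXNet code.

#include <cstdio>
#include <random>
#include <vector>

int main() {
  const int B = 2, T = 3, H = 4;   // batch, time, hidden (assumed layout)
  const float pkeep = 0.6f;        // pkeep = 1 - p
  std::mt19937 rng(42);
  std::bernoulli_distribution keep(pkeep);

  // axes = {1}: mask shape (B, 1, H) -- one draw per (batch, hidden) slot,
  // shared across all T time steps, scaled by 1/pkeep (inverted dropout).
  std::vector<float> mask(B * H);
  for (float &m : mask) m = keep(rng) ? 1.0f / pkeep : 0.0f;

  std::vector<float> x(B * T * H, 1.0f), y(B * T * H);
  for (int b = 0; b < B; ++b)
    for (int t = 0; t < T; ++t)        // t is the broadcast axis
      for (int h = 0; h < H; ++h)
        y[(b * T + t) * H + h] = x[(b * T + t) * H + h] * mask[b * H + h];

  // Every time step of a given (b, h) pair sees the same dropout decision:
  for (int t = 0; t < T; ++t) printf("%g ", y[t * H]);  // slice b=0, h=0
  printf("\n");
  return 0;
}

Because surviving elements are scaled by 1/pkeep, the output expectation matches the input and inference needs no rescaling.
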
97 changes: 88 additions & 9 deletions src/operator/nn/dropout-inl.h
@@ -21,7 +21,7 @@
* Copyright (c) 2015 by Contributors
* \file dropout-inl.h
* \brief
- * \author Bing Xu, Da Zheng
+ * \author Bing Xu, Da Zheng, Hang Zhang
*/

#ifndef MXNET_OPERATOR_NN_DROPOUT_INL_H_
@@ -37,6 +37,7 @@
#include "../mxnet_op.h"
#include "../mshadow_op.h"
#include "../random/sampler.h"
#include "../tensor/elemwise_binary_broadcast_op.h"

#if defined(USE_MKL) && defined(_OPENMP)
#include <omp.h>
@@ -55,9 +56,12 @@ enum DropoutOpMode {kTraining, kAlways};
namespace mxnet {
namespace op {

+const int MAX_DIM = 5;

struct DropoutParam : public dmlc::Parameter<DropoutParam> {
float p;
int mode;
+TShape axes;
DMLC_DECLARE_PARAMETER(DropoutParam) {
DMLC_DECLARE_FIELD(p).set_default(0.5)
.set_range(0, 1)
@@ -67,6 +71,8 @@ struct DropoutParam : public dmlc::Parameter<DropoutParam> {
.add_enum("always", dropout::kAlways)
.set_default(dropout::kTraining)
.describe("Whether to only turn on dropout during training or to also turn on for inference.");
+DMLC_DECLARE_FIELD(axes).set_default(TShape())
+.describe("Axes for variational dropout kernel.");
}
}; // struct DropoutParam

@@ -205,10 +211,25 @@ class DropoutOp {
});
}
};
+struct BernoulliKernel {
+/*! \brief Bernoulli kernel for generating mask */
+MSHADOW_XINLINE static void Map(int id,
+RandGenerator<xpu, DType> gen,
+const int N,
+const int step,
+DType *mask_out,
+const real_t pkeep) {
+RNG_KERNEL_LOOP(xpu, DType, id, gen, N, step, {
+const real_t rand_num = static_cast<real_t>(genImpl.uniform());
+mask_out[i] = mshadow_op::threshold::Map<real_t>(rand_num, pkeep) * (1.0f / pkeep);
+});
+}
+};

void Init(const DropoutParam &param) {
this->pkeep_ = 1.0f - param.p;
this->mode_ = static_cast<dropout::DropoutOpMode>(param.mode);
+this->axes_ = param.axes;
}

void Forward(const OpContext &ctx,
@@ -225,14 +246,46 @@ class DropoutOp {
if (ctx.is_train || this->mode_ == dropout::kAlways) {
RandGenerator<xpu, DType> *pgen = ctx.requested[0].get_parallel_random<xpu, DType>();
CHECK_NOTNULL(pgen);
-if (!MKLForward(s, pgen, this->pkeep_, in_data, out_data)) {
+if (this->axes_.ndim() != 0 || !MKLForward(s, pgen, this->pkeep_, in_data, out_data)) {
const TBlob &mask = out_data[dropout::kMask];
CHECK(req[dropout::kOut] != kAddTo);
-LaunchRNG<DropoutKernel, xpu>(s, pgen, out.Size(),
+if (this->axes_.ndim() == 0) {
+// standard case for dropout
+LaunchRNG<DropoutKernel, xpu>(s, pgen, out.Size(),
+out.dptr<DType>(),
+mask.dptr<DType>(),
+in_data[dropout::kData].dptr<DType>(),
+this->pkeep_);
+return;
+}
+// initialize the mask
+LaunchRNG<BernoulliKernel, xpu>(s, pgen, out.Size(),
+mask.dptr<DType>(),
+this->pkeep_);
+// broadcast mul
+TShape new_lshape, new_rshape, new_oshape;
+int ndim = BinaryBroadcastShapeCompact(in_data[dropout::kData].shape_,
+mask.shape_, out.shape_,
+&new_lshape, &new_rshape, &new_oshape);
+if (!ndim) {
+MXNET_ASSIGN_REQ_SWITCH(req[dropout::kOut], Req, {
+mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::mul, Req>, xpu>::Launch(
+s, out.Size(), out.dptr<DType>(), in_data[dropout::kData].dptr<DType>(),
+mask.dptr<DType>());
+});
+} else {
+BROADCAST_NDIM_SWITCH(ndim, NDim, {
+mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
+mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>());
+mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>());
+mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, DType,
+mshadow_op::mul>, xpu>::
+template LaunchEx(s, new_oshape.Size(), req[dropout::kOut],
+lstride, rstride, oshape,
+in_data[dropout::kData].dptr<DType>(),
+mask.dptr<DType>(), out.dptr<DType>());
+});
+}
}
} else {
const TBlob& data = in_data[dropout::kData];
@@ -257,15 +310,40 @@ class DropoutOp {
using namespace mshadow::expr;
Stream<xpu> *s = ctx.get_stream<xpu>();
if (ctx.is_train || mode_ == dropout::kAlways) {
-if (!MKLBackward(s, this->pkeep_, in_grad, out_data, out_grad)) {
+if (this->axes_.ndim() != 0 || !MKLBackward(s, this->pkeep_, in_grad, out_data, out_grad)) {
const TBlob &gdata = in_grad[dropout::kData];
const TBlob &grad = out_grad[dropout::kOut];
const TBlob &mask = out_data[dropout::kMask];
-CHECK_EQ(grad.Size(), mask.Size());
-MXNET_ASSIGN_REQ_SWITCH(req[dropout::kData], Req, {
-mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::mul, Req>, xpu>::Launch(
-s, gdata.Size(), gdata.dptr<DType>(), grad.dptr<DType>(), mask.dptr<DType>());
-});
+if (this->axes_.ndim() == 0) {
+// standard case for dropout
+CHECK_EQ(grad.Size(), mask.Size());
+MXNET_ASSIGN_REQ_SWITCH(req[dropout::kData], Req, {
+mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::mul, Req>, xpu>::Launch(
+s, gdata.Size(), gdata.dptr<DType>(), grad.dptr<DType>(), mask.dptr<DType>());
+});
+return;
+}
+// broadcast mul
+TShape new_lshape, new_rshape, new_oshape;
+int ndim = BinaryBroadcastShapeCompact(grad.shape_,
+mask.shape_, gdata.shape_,
+&new_lshape, &new_rshape, &new_oshape);
+if (!ndim) {
+MXNET_ASSIGN_REQ_SWITCH(req[dropout::kData], Req, {
+mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::mul, Req>, xpu>::Launch(
+s, gdata.Size(), gdata.dptr<DType>(), grad.dptr<DType>(), mask.dptr<DType>());
+});
+} else {
+BROADCAST_NDIM_SWITCH(ndim, NDim, {
+mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
+mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>());
+mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>());
+mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, DType,
+mshadow_op::mul>, xpu>::
+template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape,
+grad.dptr<DType>(), mask.dptr<DType>(), gdata.dptr<DType>());
+});
+}
}
} else {
const TBlob& gdata = in_grad[dropout::kData];
@@ -286,6 +364,7 @@ class DropoutOp {
real_t pkeep_;
/*! \brief Dropout mode */
dropout::DropoutOpMode mode_;
+TShape axes_;
}; // class DropoutOp

template<typename xpu>
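
Two details of the dropout-inl.h change worth calling out: BernoulliKernel writes an inverted-dropout mask — each element is 0 with probability p, else 1/pkeep — so E[mask * x] = x and the stored mask can be reused verbatim in Backward, where it is broadcast over the output gradient exactly as in Forward. A quick standalone check of the scaling claim (plain C++, illustrative only):

#include <cstdio>
#include <random>

int main() {
  const float p = 0.4f, pkeep = 1.0f - p, x = 2.0f;
  std::mt19937 rng(0);
  std::bernoulli_distribution keep(pkeep);
  double sum = 0.0;
  const int n = 1000000;
  // mask element is 0 with probability p, else 1/pkeep
  for (int i = 0; i < n; ++i)
    sum += (keep(rng) ? 1.0f / pkeep : 0.0f) * x;
  printf("mean = %.4f (expected %.4f)\n", sum / n, x);  // ~= x
  return 0;
}
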
8 changes: 6 additions & 2 deletions src/operator/nn/dropout.cc
@@ -21,7 +21,7 @@
* Copyright (c) 2015 by Contributors
* \file dropout.cc
* \brief
- * \author Bing Xu, Da Zheng
+ * \author Bing Xu, Da Zheng, Hang Zhang
*/

#include "./dropout-inl.h"
@@ -93,10 +93,14 @@ Example::
std::vector<TShape> *in_shape, std::vector<TShape> *out_shape){
using namespace mshadow;
CHECK_EQ(in_shape->size(), 1U);
-const TShape &dshape = in_shape->at(0);
+const DropoutParam& param = nnvm::get<DropoutParam>(attrs.parsed);
+TShape dshape(in_shape->at(0));
if (dshape.ndim() == 0) return false;
out_shape->clear();
out_shape->push_back(dshape);
+for (index_t i = 0; i < param.axes.ndim(); ++i) {
+dshape[param.axes[i]] = 1;
+}
+out_shape->push_back(dshape);
return true;
})
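
The dropout.cc change above adjusts shape inference so the second output (the mask) gets the data shape with each axis in `axes` set to 1, which is what lets the broadcast multiply in dropout-inl.h share one mask along those axes. A standalone sketch of that rule — the MaskShape helper and the example shapes are assumptions for illustration:

#include <cstdio>
#include <vector>

// Hypothetical helper mirroring the InferShape rule above: the mask keeps
// the data shape, with every axis listed in `axes` collapsed to 1.
std::vector<int> MaskShape(std::vector<int> dshape, const std::vector<int> &axes) {
  for (int ax : axes) dshape[ax] = 1;
  return dshape;
}

int main() {
  // e.g. data (seq_len=20, batch=8, hidden=512) with axes = {0}:
  // one mask shared across all 20 time steps.
  for (int d : MaskShape({20, 8, 512}, {0})) printf("%d ", d);  // prints: 1 8 512
  printf("\n");
  return 0;
}
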
2 changes: 1 addition & 1 deletion src/operator/nn/dropout.cu
@@ -21,7 +21,7 @@
* Copyright (c) 2015 by Contributors
* \file dropout.cc
* \brief
- * \author Bing Xu, Da Zheng
+ * \author Bing Xu, Da Zheng, Hang Zhang
*/

#include "./dropout-inl.h"
