diff --git a/src/operator/nn/dropout-inl.h b/src/operator/nn/dropout-inl.h
index cff35a3cef7f..b57ab45891e9 100644
--- a/src/operator/nn/dropout-inl.h
+++ b/src/operator/nn/dropout-inl.h
@@ -21,7 +21,7 @@
  * Copyright (c) 2015 by Contributors
  * \file dropout-inl.h
  * \brief
- * \author Bing Xu, Da Zheng
+ * \author Bing Xu, Da Zheng, Hang Zhang
 */
 
 #ifndef MXNET_OPERATOR_NN_DROPOUT_INL_H_
@@ -37,6 +37,7 @@
 #include "../mxnet_op.h"
 #include "../mshadow_op.h"
 #include "../random/sampler.h"
+#include "../tensor/elemwise_binary_broadcast_op.h"
 
 #if defined(USE_MKL) && defined(_OPENMP)
 #include <omp.h>
@@ -55,9 +56,12 @@ enum DropoutOpMode {kTraining, kAlways};
 namespace mxnet {
 namespace op {
 
+const int MAX_DIM = 5;
+
 struct DropoutParam : public dmlc::Parameter<DropoutParam> {
   float p;
   int mode;
+  TShape axes;
   DMLC_DECLARE_PARAMETER(DropoutParam) {
     DMLC_DECLARE_FIELD(p).set_default(0.5)
     .set_range(0, 1)
@@ -67,6 +71,8 @@ struct DropoutParam : public dmlc::Parameter<DropoutParam> {
     .add_enum("always", dropout::kAlways)
     .set_default(dropout::kTraining)
     .describe("Whether to only turn on dropout during training or to also turn on for inference.");
+    DMLC_DECLARE_FIELD(axes).set_default(TShape())
+    .describe("Axes for variational dropout kernel.");
   }
 };  // struct DropoutParam
 
@@ -205,10 +211,25 @@ class DropoutOp {
       });
     }
   };
+  struct BernoulliKernel {
+    /*! \brief Bernoulli kernel for generating mask */
+    MSHADOW_XINLINE static void Map(int id,
+                                    RandGenerator<xpu, DType> gen,
+                                    const int N,
+                                    const int step,
+                                    DType *mask_out,
+                                    const real_t pkeep) {
+      RNG_KERNEL_LOOP(xpu, DType, id, gen, N, step, {
+        const real_t rand_num = static_cast<real_t>(genImpl.uniform());
+        mask_out[i] = mshadow_op::threshold::Map<real_t>(rand_num, pkeep) * (1.0f / pkeep);
+      });
+    }
+  };
 
   void Init(const DropoutParam &param) {
     this->pkeep_ = 1.0f - param.p;
     this->mode_ = static_cast<dropout::DropoutOpMode>(param.mode);
+    this->axes_ = param.axes;
   }
 
   void Forward(const OpContext &ctx,
@@ -225,14 +246,46 @@ class DropoutOp {
     if (ctx.is_train || this->mode_ == dropout::kAlways) {
       RandGenerator<xpu, DType> *pgen = ctx.requested[0].get_parallel_random<xpu, DType>();
       CHECK_NOTNULL(pgen);
-      if (!MKLForward(s, pgen, this->pkeep_, in_data, out_data)) {
+      if (this->axes_.ndim() != 0 || !MKLForward(s, pgen, this->pkeep_, in_data, out_data)) {
         const TBlob &mask = out_data[dropout::kMask];
         CHECK(req[dropout::kOut] != kAddTo);
-        LaunchRNG<DropoutKernel, xpu>(s, pgen, out.Size(),
+        if (this->axes_.ndim() == 0) {
+          // standard case for dropout
+          LaunchRNG<DropoutKernel, xpu>(s, pgen, out.Size(),
                                       out.dptr<DType>(),
                                       mask.dptr<DType>(),
                                       in_data[dropout::kData].dptr<DType>(),
                                       this->pkeep_);
+          return;
+        }
+        // initialize the mask
+        LaunchRNG<BernoulliKernel, xpu>(s, pgen, out.Size(),
+                                        mask.dptr<DType>(),
+                                        this->pkeep_);
+        // broadcast mul
+        TShape new_lshape, new_rshape, new_oshape;
+        int ndim = BinaryBroadcastShapeCompact(in_data[dropout::kData].shape_,
+                                               mask.shape_, out.shape_,
+                                               &new_lshape, &new_rshape, &new_oshape);
+        if (!ndim) {
+          MXNET_ASSIGN_REQ_SWITCH(req[dropout::kOut], Req, {
+            mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::mul, Req>, xpu>::Launch(
+              s, out.Size(), out.dptr<DType>(), in_data[dropout::kData].dptr<DType>(),
+              mask.dptr<DType>());
+          });
+        } else {
+          BROADCAST_NDIM_SWITCH(ndim, NDim, {
+            mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
+            mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>());
+            mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>());
+            mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, DType, mshadow_op::mul>, xpu>::
+            template LaunchEx(s, new_oshape.Size(), req[dropout::kOut],
+            lstride, rstride, oshape,
+            in_data[dropout::kData].dptr<DType>(),
+            mask.dptr<DType>(), out.dptr<DType>());
+          });
+        }
       }
     } else {
       const TBlob& data = in_data[dropout::kData];
@@ -257,15 +310,40 @@ class DropoutOp {
     using namespace mshadow::expr;
     Stream<xpu> *s = ctx.get_stream<xpu>();
     if (ctx.is_train || mode_ == dropout::kAlways) {
-      if (!MKLBackward(s, this->pkeep_, in_grad, out_data, out_grad)) {
+      if (this->axes_.ndim() != 0 || !MKLBackward(s, this->pkeep_, in_grad, out_data, out_grad)) {
         const TBlob &gdata = in_grad[dropout::kData];
         const TBlob &grad = out_grad[dropout::kOut];
         const TBlob &mask = out_data[dropout::kMask];
-        CHECK_EQ(grad.Size(), mask.Size());
-        MXNET_ASSIGN_REQ_SWITCH(req[dropout::kData], Req, {
-          mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::mul, Req>, xpu>::Launch(
-            s, gdata.Size(), gdata.dptr<DType>(), grad.dptr<DType>(), mask.dptr<DType>());
-        });
+        if (this->axes_.ndim() == 0) {
+          // standard case for dropout
+          CHECK_EQ(grad.Size(), mask.Size());
+          MXNET_ASSIGN_REQ_SWITCH(req[dropout::kData], Req, {
+            mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::mul, Req>, xpu>::Launch(
+              s, gdata.Size(), gdata.dptr<DType>(), grad.dptr<DType>(), mask.dptr<DType>());
+          });
+          return;
+        }
+        // broadcast mul
+        TShape new_lshape, new_rshape, new_oshape;
+        int ndim = BinaryBroadcastShapeCompact(grad.shape_,
+                                               mask.shape_, gdata.shape_,
+                                               &new_lshape, &new_rshape, &new_oshape);
+        if (!ndim) {
+          MXNET_ASSIGN_REQ_SWITCH(req[dropout::kData], Req, {
+            mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::mul, Req>, xpu>::Launch(
+              s, gdata.Size(), gdata.dptr<DType>(), grad.dptr<DType>(), mask.dptr<DType>());
+          });
+        } else {
+          BROADCAST_NDIM_SWITCH(ndim, NDim, {
+            mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
+            mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>());
+            mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>());
+            mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, DType, mshadow_op::mul>, xpu>::
+            template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape,
+            grad.dptr<DType>(), mask.dptr<DType>(), gdata.dptr<DType>());
+          });
+        }
       }
     } else {
       const TBlob& gdata = in_grad[dropout::kData];
@@ -286,6 +364,7 @@ class DropoutOp {
   real_t pkeep_;
   /*! \brief Dropout mode */
   dropout::DropoutOpMode mode_;
+  TShape axes_;
 };  // class DropoutOp
 
 template<typename xpu>
diff --git a/src/operator/nn/dropout.cc b/src/operator/nn/dropout.cc
index dd5f1e58fbe5..3021e0105b4f 100644
--- a/src/operator/nn/dropout.cc
+++ b/src/operator/nn/dropout.cc
@@ -21,7 +21,7 @@
  * Copyright (c) 2015 by Contributors
  * \file dropout.cc
  * \brief
- * \author Bing Xu, Da Zheng
+ * \author Bing Xu, Da Zheng, Hang Zhang
 */
 
 #include "./dropout-inl.h"
@@ -93,10 +93,14 @@ Example::
                                     std::vector<TShape> *in_shape,
                                     std::vector<TShape> *out_shape){
   using namespace mshadow;
   CHECK_EQ(in_shape->size(), 1U);
-  const TShape &dshape = in_shape->at(0);
+  const DropoutParam& param = nnvm::get<DropoutParam>(attrs.parsed);
+  TShape dshape(in_shape->at(0));
   if (dshape.ndim() == 0) return false;
   out_shape->clear();
   out_shape->push_back(dshape);
+  for (index_t i = 0; i < param.axes.ndim(); ++i) {
+    dshape[param.axes[i]] = 1;
+  }
   out_shape->push_back(dshape);
   return true;
 })
diff --git a/src/operator/nn/dropout.cu b/src/operator/nn/dropout.cu
index e655278822a4..832490b08f1f 100644
--- a/src/operator/nn/dropout.cu
+++ b/src/operator/nn/dropout.cu
@@ -21,7 +21,7 @@
  * Copyright (c) 2015 by Contributors
  * \file dropout.cc
  * \brief
- * \author Bing Xu, Da Zheng
+ * \author Bing Xu, Da Zheng, Hang Zhang
 */
 
 #include "./dropout-inl.h"
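Note (not part of the patch): the idea behind the new `axes` parameter is that the mask is sampled with the listed axes collapsed to 1 (see the shape-inference change in dropout.cc) and then broadcast-multiplied over the input in both Forward and Backward, so the same Bernoulli draw is reused along the dropped axes (variational dropout). Below is a minimal standalone C++ sketch of that semantics; the shapes, names, and plain loops are made up for illustration and stand in for the BernoulliKernel / binary_broadcast_kernel launches above.

// Standalone illustration: sample one scaled Bernoulli mask with the dropped
// axis collapsed to 1, then broadcast-multiply it over the input.
#include <cstdio>
#include <random>
#include <vector>

int main() {
  const int batch = 2, seq = 4, hidden = 3;   // hypothetical input shape (2, 4, 3)
  const float pkeep = 0.5f;                   // pkeep = 1 - p
  std::mt19937 rng(0);
  std::bernoulli_distribution keep(pkeep);

  // axes = (1,): the mask has shape (2, 1, 3); kept entries are scaled by
  // 1/pkeep (inverted dropout), matching the kernel in the patch.
  std::vector<float> mask(batch * 1 * hidden);
  for (auto &m : mask) m = keep(rng) ? 1.0f / pkeep : 0.0f;

  std::vector<float> in(batch * seq * hidden, 1.0f), out(in.size());
  for (int b = 0; b < batch; ++b)
    for (int t = 0; t < seq; ++t)
      for (int h = 0; h < hidden; ++h)
        // Broadcast multiply: the mask index ignores the time axis t.
        out[(b * seq + t) * hidden + h] =
            in[(b * seq + t) * hidden + h] * mask[b * hidden + h];

  for (int b = 0; b < batch; ++b)
    for (int t = 0; t < seq; ++t) {
      for (int h = 0; h < hidden; ++h)
        std::printf("%4.1f ", out[(b * seq + t) * hidden + h]);
      std::printf("\n");
    }
  return 0;
}

With axes=(1,) on a (batch, seq, hidden) input, every time step within a sequence sees the same scaled mask, which is the behaviour the broadcast paths in Forward and Backward implement; the gradient is masked with the identical broadcast pattern.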