From 14dbff439adac0eb7f2692bf8219ff78339172b2 Mon Sep 17 00:00:00 2001 From: Zhang Date: Wed, 28 Feb 2018 16:51:38 -0800 Subject: [PATCH 1/8] add axes support to dropout for variational dropout, test pending, mkl part hasn't been updated --- src/operator/nn/dropout-inl.h | 148 ++++++++++++++++++++++++++++------ src/operator/nn/dropout.cc | 10 ++- src/operator/nn/dropout.cu | 2 +- 3 files changed, 132 insertions(+), 28 deletions(-) diff --git a/src/operator/nn/dropout-inl.h b/src/operator/nn/dropout-inl.h index cff35a3cef7f..0f8f73d21647 100644 --- a/src/operator/nn/dropout-inl.h +++ b/src/operator/nn/dropout-inl.h @@ -21,7 +21,7 @@ * Copyright (c) 2015 by Contributors * \file dropout-inl.h * \brief - * \author Bing Xu, Da Zheng + * \author Bing Xu, Da Zheng, Hang Zhang */ #ifndef MXNET_OPERATOR_NN_DROPOUT_INL_H_ @@ -55,9 +55,12 @@ enum DropoutOpMode {kTraining, kAlways}; namespace mxnet { namespace op { +const int MAX_DIM = 5; + struct DropoutParam : public dmlc::Parameter { float p; int mode; + TShape axes; DMLC_DECLARE_PARAMETER(DropoutParam) { DMLC_DECLARE_FIELD(p).set_default(0.5) .set_range(0, 1) @@ -67,9 +70,92 @@ struct DropoutParam : public dmlc::Parameter { .add_enum("always", dropout::kAlways) .set_default(dropout::kTraining) .describe("Whether to only turn on dropout during training or to also turn on for inference."); + DMLC_DECLARE_FIELD(axes).set_default(TShape()) + .describe("Axes for variational dropout kernel."); } }; // struct DropoutParam +namespace mxnet_op { +template +struct binary_broadcast_kernel { + /*! \brief Map function for binary_broadcast_kernel */ + MSHADOW_XINLINE static void Map(int base, int length, OpReqType req, + const Shape &lstride, const Shape &rstride, + const Shape &oshape, DType *lhs, DType *rhs, + DType *out) { + Shape coord = unravel(base, oshape); + auto lidx = static_cast(dot(coord, lstride)); + auto ridx = static_cast(dot(coord, rstride)); + KERNEL_ASSIGN(out[base], req, OP::Map(lhs[lidx], rhs[ridx])); + // starts from 1 to avoid extra inc at end of loop + for (int i = 1; i < length; ++i) { + inc(&coord, oshape, &lidx, lstride, &ridx, rstride); + // When tuning, don't actually run the op, since it's not going to be tuned against + // the actual op we'll eventually be using + KERNEL_ASSIGN(out[base + i], req, OP::Map(lhs[lidx], rhs[ridx])); + } + } +}; +} // namespace mxnet_op + +#define BROADCAST_NDIM_SWITCH(ndim, NDim, ...) 
\ + if (ndim <= 2) { \ + const int NDim = 2; \ + {__VA_ARGS__} \ + } else if (ndim <= 4) { \ + const int NDim = 4; \ + {__VA_ARGS__} \ + } else if (ndim <= MAX_DIM) { \ + const int NDim = MAX_DIM; \ + {__VA_ARGS__} \ + } else { \ + LOG(FATAL) << "NDim too large "; \ + } + +inline int BinaryBroadcastShapeCompact(const TShape& lshape, const TShape& rshape, + const TShape& oshape, TShape *new_lshape, + TShape *new_rshape, TShape *new_oshape) { + if (lshape == rshape) return 0; + index_t odim = std::max(oshape.ndim(), MAX_DIM); + *new_lshape = TShape(odim); + *new_rshape = TShape(odim); + *new_oshape = TShape(odim); + index_t bl = oshape.ndim() - lshape.ndim(); + index_t br = oshape.ndim() - rshape.ndim(); + index_t j = 0, lprod = 1, rprod = 1, oprod = 1; + for (index_t i = 0; i < oshape.ndim(); ++i) { + index_t l = 1, r = 1, o = oshape[i]; + if (i >= bl) l = lshape[i-bl]; + if (i >= br) r = rshape[i-br]; + if ((lprod != rprod || l != r) && + lprod*l > 1 && rprod*r > 1) { + (*new_lshape)[j] = lprod; + (*new_rshape)[j] = rprod; + (*new_oshape)[j] = oprod; + lprod = rprod = oprod = 1; ++j; + } + lprod *= l; + rprod *= r; + oprod *= o; + } + if (lprod > 1 || rprod > 1) { + (*new_lshape)[j] = lprod; + (*new_rshape)[j] = rprod; + (*new_oshape)[j] = oprod; + ++j; + } + if (j <= MAX_DIM) { + BROADCAST_NDIM_SWITCH(j, NDim, { + new_lshape->assign(&(*new_lshape)[0], &(*new_lshape)[NDim]); + new_rshape->assign(&(*new_rshape)[0], &(*new_rshape)[NDim]); + new_oshape->assign(&(*new_oshape)[0], &(*new_oshape)[NDim]); + }); + } else { + LOG(FATAL) << "Too many broadcast dimensions with operands " << lshape << " " << rshape; + } + return j; +} + template class DropoutOp { #if defined(USE_MKL) && defined(_OPENMP) @@ -178,30 +264,17 @@ class DropoutOp { /*! * \brief Dropout kernel, compute dropout tensor */ - struct DropoutKernel { - /*! - * \brief Dropout kernel function - * \param id Thread number (0-based representing count) - * \param gen Random number generator - * \param N Total number of items in the output - * \param step Step between items, related to parallelism - * \param dropout_out Output dropout values - * \param mask_out Output mask (is multiplied to create dropout output, may be 0) - * \param input_data Input data to perform the dropout on - * \param pkeep Dropout rate (keep when the generated random number is less than this value) - */ + struct BernoulliKernel { + /*! 
\brief Bernoulli kernel for generating mask */ MSHADOW_XINLINE static void Map(int id, RandGenerator gen, const int N, const int step, - DType *dropout_out, DType *mask_out, - const DType *input_data, const real_t pkeep) { RNG_KERNEL_LOOP(xpu, DType, id, gen, N, step, { const real_t rand_num = static_cast(genImpl.uniform()); mask_out[i] = mshadow_op::threshold::Map(rand_num, pkeep) * (1.0f / pkeep); - dropout_out[i] = input_data[i] * mask_out[i]; }); } }; @@ -228,11 +301,27 @@ class DropoutOp { if (!MKLForward(s, pgen, this->pkeep_, in_data, out_data)) { const TBlob &mask = out_data[dropout::kMask]; CHECK(req[dropout::kOut] != kAddTo); - LaunchRNG(s, pgen, out.Size(), - out.dptr(), - mask.dptr(), - in_data[dropout::kData].dptr(), - this->pkeep_); + // initialize the mask + LaunchRNG(s, pgen, out.Size(), + mask.dptr(), + this->pkeep_); + if (req[0] != kNullOp) { + // broardcast mul + TShape new_lshape, new_rshape, new_oshape; + int ndim = BinaryBroadcastShapeCompact(in_data[dropout::kData].shape_, + mask.shape_, out.shape_, + &new_lshape, &new_rshape, &new_oshape); + BROADCAST_NDIM_SWITCH(ndim, NDim, { + mshadow::Shape oshape = new_oshape.get(); + mshadow::Shape lstride = mxnet_op::calc_stride(new_lshape.get()); + mshadow::Shape rstride = mxnet_op::calc_stride(new_rshape.get()); + mxnet_op::Kernel, xpu>:: + template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape, + in_data[dropout::kData].dptr(), + mask.dptr(), out.dptr()); + }); + } } } else { const TBlob& data = in_data[dropout::kData]; @@ -261,10 +350,19 @@ class DropoutOp { const TBlob &gdata = in_grad[dropout::kData]; const TBlob &grad = out_grad[dropout::kOut]; const TBlob &mask = out_data[dropout::kMask]; - CHECK_EQ(grad.Size(), mask.Size()); - MXNET_ASSIGN_REQ_SWITCH(req[dropout::kData], Req, { - mxnet_op::Kernel, xpu>::Launch( - s, gdata.Size(), gdata.dptr(), grad.dptr(), mask.dptr()); + // broardcast mul + TShape new_lshape, new_rshape, new_oshape; + int ndim = BinaryBroadcastShapeCompact(grad.shape_, + mask.shape_, gdata.shape_, + &new_lshape, &new_rshape, &new_oshape); + BROADCAST_NDIM_SWITCH(ndim, NDim, { + mshadow::Shape oshape = new_oshape.get(); + mshadow::Shape lstride = mxnet_op::calc_stride(new_lshape.get()); + mshadow::Shape rstride = mxnet_op::calc_stride(new_rshape.get()); + mxnet_op::Kernel, xpu>:: + template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape, + grad.dptr(), mask.dptr(), gdata.dptr()); }); } } else { diff --git a/src/operator/nn/dropout.cc b/src/operator/nn/dropout.cc index dd5f1e58fbe5..c457d939dda3 100644 --- a/src/operator/nn/dropout.cc +++ b/src/operator/nn/dropout.cc @@ -21,7 +21,7 @@ * Copyright (c) 2015 by Contributors * \file dropout.cc * \brief - * \author Bing Xu, Da Zheng + * \author Bing Xu, Da Zheng, Hang Zhang */ #include "./dropout-inl.h" @@ -93,10 +93,16 @@ Example:: std::vector *in_shape, std::vector *out_shape){ using namespace mshadow; CHECK_EQ(in_shape->size(), 1U); - const TShape &dshape = in_shape->at(0); + const DropoutParam& param = nnvm::get(attrs.parsed); + TShape dshape(in_shape->at(0)); if (dshape.ndim() == 0) return false; out_shape->clear(); out_shape->push_back(dshape); + if (param.axes.ndim() != 0) { + for (index_t i = 0; i < param.axes.ndim(); ++i) { + dshape[param.axes[i]] = 1; + } + } out_shape->push_back(dshape); return true; }) diff --git a/src/operator/nn/dropout.cu b/src/operator/nn/dropout.cu index e655278822a4..832490b08f1f 100644 --- a/src/operator/nn/dropout.cu +++ b/src/operator/nn/dropout.cu @@ -21,7 +21,7 @@ * Copyright (c) 
2015 by Contributors * \file dropout.cc * \brief - * \author Bing Xu, Da Zheng + * \author Bing Xu, Da Zheng, Hang Zhang */ #include "./dropout-inl.h" From 372a154ef5ead53064e837bfd8aada98fa76a900 Mon Sep 17 00:00:00 2001 From: Zhang Date: Fri, 2 Mar 2018 10:37:53 -0800 Subject: [PATCH 2/8] avoid copy code --- src/operator/nn/dropout-inl.h | 82 +---------------------------------- 1 file changed, 1 insertion(+), 81 deletions(-) diff --git a/src/operator/nn/dropout-inl.h b/src/operator/nn/dropout-inl.h index 0f8f73d21647..2d2e0c56fe03 100644 --- a/src/operator/nn/dropout-inl.h +++ b/src/operator/nn/dropout-inl.h @@ -37,6 +37,7 @@ #include "../mxnet_op.h" #include "../mshadow_op.h" #include "../random/sampler.h" +#include "../tensor/elemwise_binary_broadcast_op.h" #if defined(USE_MKL) && defined(_OPENMP) #include @@ -75,87 +76,6 @@ struct DropoutParam : public dmlc::Parameter { } }; // struct DropoutParam -namespace mxnet_op { -template -struct binary_broadcast_kernel { - /*! \brief Map function for binary_broadcast_kernel */ - MSHADOW_XINLINE static void Map(int base, int length, OpReqType req, - const Shape &lstride, const Shape &rstride, - const Shape &oshape, DType *lhs, DType *rhs, - DType *out) { - Shape coord = unravel(base, oshape); - auto lidx = static_cast(dot(coord, lstride)); - auto ridx = static_cast(dot(coord, rstride)); - KERNEL_ASSIGN(out[base], req, OP::Map(lhs[lidx], rhs[ridx])); - // starts from 1 to avoid extra inc at end of loop - for (int i = 1; i < length; ++i) { - inc(&coord, oshape, &lidx, lstride, &ridx, rstride); - // When tuning, don't actually run the op, since it's not going to be tuned against - // the actual op we'll eventually be using - KERNEL_ASSIGN(out[base + i], req, OP::Map(lhs[lidx], rhs[ridx])); - } - } -}; -} // namespace mxnet_op - -#define BROADCAST_NDIM_SWITCH(ndim, NDim, ...) 
\ - if (ndim <= 2) { \ - const int NDim = 2; \ - {__VA_ARGS__} \ - } else if (ndim <= 4) { \ - const int NDim = 4; \ - {__VA_ARGS__} \ - } else if (ndim <= MAX_DIM) { \ - const int NDim = MAX_DIM; \ - {__VA_ARGS__} \ - } else { \ - LOG(FATAL) << "NDim too large "; \ - } - -inline int BinaryBroadcastShapeCompact(const TShape& lshape, const TShape& rshape, - const TShape& oshape, TShape *new_lshape, - TShape *new_rshape, TShape *new_oshape) { - if (lshape == rshape) return 0; - index_t odim = std::max(oshape.ndim(), MAX_DIM); - *new_lshape = TShape(odim); - *new_rshape = TShape(odim); - *new_oshape = TShape(odim); - index_t bl = oshape.ndim() - lshape.ndim(); - index_t br = oshape.ndim() - rshape.ndim(); - index_t j = 0, lprod = 1, rprod = 1, oprod = 1; - for (index_t i = 0; i < oshape.ndim(); ++i) { - index_t l = 1, r = 1, o = oshape[i]; - if (i >= bl) l = lshape[i-bl]; - if (i >= br) r = rshape[i-br]; - if ((lprod != rprod || l != r) && - lprod*l > 1 && rprod*r > 1) { - (*new_lshape)[j] = lprod; - (*new_rshape)[j] = rprod; - (*new_oshape)[j] = oprod; - lprod = rprod = oprod = 1; ++j; - } - lprod *= l; - rprod *= r; - oprod *= o; - } - if (lprod > 1 || rprod > 1) { - (*new_lshape)[j] = lprod; - (*new_rshape)[j] = rprod; - (*new_oshape)[j] = oprod; - ++j; - } - if (j <= MAX_DIM) { - BROADCAST_NDIM_SWITCH(j, NDim, { - new_lshape->assign(&(*new_lshape)[0], &(*new_lshape)[NDim]); - new_rshape->assign(&(*new_rshape)[0], &(*new_rshape)[NDim]); - new_oshape->assign(&(*new_oshape)[0], &(*new_oshape)[NDim]); - }); - } else { - LOG(FATAL) << "Too many broadcast dimensions with operands " << lshape << " " << rshape; - } - return j; -} - template class DropoutOp { #if defined(USE_MKL) && defined(_OPENMP) From 2d4333b1d443f743b09da040d032fb92e761022a Mon Sep 17 00:00:00 2001 From: Zhang Date: Fri, 2 Mar 2018 10:44:45 -0800 Subject: [PATCH 3/8] fix typo --- src/operator/nn/dropout-inl.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/nn/dropout-inl.h b/src/operator/nn/dropout-inl.h index 2d2e0c56fe03..d939f0912507 100644 --- a/src/operator/nn/dropout-inl.h +++ b/src/operator/nn/dropout-inl.h @@ -226,7 +226,7 @@ class DropoutOp { mask.dptr(), this->pkeep_); if (req[0] != kNullOp) { - // broardcast mul + // broadcast mul TShape new_lshape, new_rshape, new_oshape; int ndim = BinaryBroadcastShapeCompact(in_data[dropout::kData].shape_, mask.shape_, out.shape_, From 9785fd650ff907887c1c109cde39449a114dfda3 Mon Sep 17 00:00:00 2001 From: Zhang Date: Mon, 5 Mar 2018 11:07:19 -0800 Subject: [PATCH 4/8] consider non broadcast case --- src/operator/nn/dropout-inl.h | 44 +++++++++++++++++++++++++++-------- src/operator/nn/dropout.cc | 6 ++--- 2 files changed, 36 insertions(+), 14 deletions(-) diff --git a/src/operator/nn/dropout-inl.h b/src/operator/nn/dropout-inl.h index d939f0912507..7af0d64a51d6 100644 --- a/src/operator/nn/dropout-inl.h +++ b/src/operator/nn/dropout-inl.h @@ -199,6 +199,23 @@ class DropoutOp { } }; + struct ElemwiseMulKernel { + /*! 
Elementwise multiply kernel */ + MSHADOW_XINLINE static void Map(int id, + RandGenerator gen, + const int N, + const int step, + DType *dropout_out, + DType *mask_out, + const DType *input_data) { + const int start = id * step; + const int end = start + step; + for (int i = start; i < end && i < N; ++i) { + dropout_out[i] = input_data[i] * mask_out[i]; + } + } + }; + void Init(const DropoutParam ¶m) { this->pkeep_ = 1.0f - param.p; this->mode_ = static_cast(param.mode); @@ -231,16 +248,23 @@ class DropoutOp { int ndim = BinaryBroadcastShapeCompact(in_data[dropout::kData].shape_, mask.shape_, out.shape_, &new_lshape, &new_rshape, &new_oshape); - BROADCAST_NDIM_SWITCH(ndim, NDim, { - mshadow::Shape oshape = new_oshape.get(); - mshadow::Shape lstride = mxnet_op::calc_stride(new_lshape.get()); - mshadow::Shape rstride = mxnet_op::calc_stride(new_rshape.get()); - mxnet_op::Kernel, xpu>:: - template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape, - in_data[dropout::kData].dptr(), - mask.dptr(), out.dptr()); - }); + if (!ndim) { + LaunchRNG(s, pgen, out.Size(), + out.dptr(), + mask.dptr(), + in_data[dropout::kData].dptr()); + } else { + BROADCAST_NDIM_SWITCH(ndim, NDim, { + mshadow::Shape oshape = new_oshape.get(); + mshadow::Shape lstride = mxnet_op::calc_stride(new_lshape.get()); + mshadow::Shape rstride = mxnet_op::calc_stride(new_rshape.get()); + mxnet_op::Kernel, xpu>:: + template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape, + in_data[dropout::kData].dptr(), + mask.dptr(), out.dptr()); + }); + } } } } else { diff --git a/src/operator/nn/dropout.cc b/src/operator/nn/dropout.cc index c457d939dda3..3021e0105b4f 100644 --- a/src/operator/nn/dropout.cc +++ b/src/operator/nn/dropout.cc @@ -98,10 +98,8 @@ Example:: if (dshape.ndim() == 0) return false; out_shape->clear(); out_shape->push_back(dshape); - if (param.axes.ndim() != 0) { - for (index_t i = 0; i < param.axes.ndim(); ++i) { - dshape[param.axes[i]] = 1; - } + for (index_t i = 0; i < param.axes.ndim(); ++i) { + dshape[param.axes[i]] = 1; } out_shape->push_back(dshape); return true; From b4057b892e2bb186989d67b9a3cc0a0e85ee070c Mon Sep 17 00:00:00 2001 From: Zhang Date: Mon, 5 Mar 2018 13:46:00 -0800 Subject: [PATCH 5/8] fix backward --- src/operator/nn/dropout-inl.h | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/src/operator/nn/dropout-inl.h b/src/operator/nn/dropout-inl.h index 7af0d64a51d6..214609ae06bf 100644 --- a/src/operator/nn/dropout-inl.h +++ b/src/operator/nn/dropout-inl.h @@ -253,7 +253,7 @@ class DropoutOp { out.dptr(), mask.dptr(), in_data[dropout::kData].dptr()); - } else { + } else { BROADCAST_NDIM_SWITCH(ndim, NDim, { mshadow::Shape oshape = new_oshape.get(); mshadow::Shape lstride = mxnet_op::calc_stride(new_lshape.get()); @@ -299,15 +299,24 @@ class DropoutOp { int ndim = BinaryBroadcastShapeCompact(grad.shape_, mask.shape_, gdata.shape_, &new_lshape, &new_rshape, &new_oshape); - BROADCAST_NDIM_SWITCH(ndim, NDim, { - mshadow::Shape oshape = new_oshape.get(); - mshadow::Shape lstride = mxnet_op::calc_stride(new_lshape.get()); - mshadow::Shape rstride = mxnet_op::calc_stride(new_rshape.get()); - mxnet_op::Kernel, xpu>:: - template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape, - grad.dptr(), mask.dptr(), gdata.dptr()); - }); + if (!ndim) { + RandGenerator *pgen = ctx.requested[0].get_parallel_random(); + CHECK_NOTNULL(pgen); + LaunchRNG(s, pgen, new_oshape.Size(), + grad.dptr(), + mask.dptr(), + gdata.dptr()); + } else { 
+ BROADCAST_NDIM_SWITCH(ndim, NDim, { + mshadow::Shape oshape = new_oshape.get(); + mshadow::Shape lstride = mxnet_op::calc_stride(new_lshape.get()); + mshadow::Shape rstride = mxnet_op::calc_stride(new_rshape.get()); + mxnet_op::Kernel, xpu>:: + template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape, + grad.dptr(), mask.dptr(), gdata.dptr()); + }); + } } } else { const TBlob& gdata = in_grad[dropout::kData]; From d03be57e3dee838915e020390f50ac42d00e8f6d Mon Sep 17 00:00:00 2001 From: Zhang Date: Mon, 5 Mar 2018 17:48:49 -0800 Subject: [PATCH 6/8] avoid mkl --- src/operator/nn/dropout-inl.h | 62 +++++++++++++++++------------------ 1 file changed, 31 insertions(+), 31 deletions(-) diff --git a/src/operator/nn/dropout-inl.h b/src/operator/nn/dropout-inl.h index 214609ae06bf..908073a9c8a5 100644 --- a/src/operator/nn/dropout-inl.h +++ b/src/operator/nn/dropout-inl.h @@ -219,6 +219,7 @@ class DropoutOp { void Init(const DropoutParam ¶m) { this->pkeep_ = 1.0f - param.p; this->mode_ = static_cast(param.mode); + this->axes = param.axes; } void Forward(const OpContext &ctx, @@ -235,36 +236,36 @@ class DropoutOp { if (ctx.is_train || this->mode_ == dropout::kAlways) { RandGenerator *pgen = ctx.requested[0].get_parallel_random(); CHECK_NOTNULL(pgen); - if (!MKLForward(s, pgen, this->pkeep_, in_data, out_data)) { + if (this->axes.ndim() != 0 || !MKLForward(s, pgen, this->pkeep_, in_data, out_data)) { const TBlob &mask = out_data[dropout::kMask]; CHECK(req[dropout::kOut] != kAddTo); // initialize the mask LaunchRNG(s, pgen, out.Size(), mask.dptr(), this->pkeep_); - if (req[0] != kNullOp) { - // broadcast mul - TShape new_lshape, new_rshape, new_oshape; - int ndim = BinaryBroadcastShapeCompact(in_data[dropout::kData].shape_, - mask.shape_, out.shape_, - &new_lshape, &new_rshape, &new_oshape); - if (!ndim) { - LaunchRNG(s, pgen, out.Size(), - out.dptr(), - mask.dptr(), - in_data[dropout::kData].dptr()); - } else { - BROADCAST_NDIM_SWITCH(ndim, NDim, { - mshadow::Shape oshape = new_oshape.get(); - mshadow::Shape lstride = mxnet_op::calc_stride(new_lshape.get()); - mshadow::Shape rstride = mxnet_op::calc_stride(new_rshape.get()); - mxnet_op::Kernel, xpu>:: - template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape, - in_data[dropout::kData].dptr(), - mask.dptr(), out.dptr()); - }); - } + // broadcast mul + TShape new_lshape, new_rshape, new_oshape; + int ndim = BinaryBroadcastShapeCompact(in_data[dropout::kData].shape_, + mask.shape_, out.shape_, + &new_lshape, &new_rshape, &new_oshape); + if (!ndim) { + MXNET_ASSIGN_REQ_SWITCH(req[dropout::kOut], Req, { + mxnet_op::Kernel, xpu>::Launch( + s, out.Size(), out.dptr(), in_data[dropout::kData].dptr(), + mask.dptr()); + }); + } else { + BROADCAST_NDIM_SWITCH(ndim, NDim, { + mshadow::Shape oshape = new_oshape.get(); + mshadow::Shape lstride = mxnet_op::calc_stride(new_lshape.get()); + mshadow::Shape rstride = mxnet_op::calc_stride(new_rshape.get()); + mxnet_op::Kernel, xpu>:: + template LaunchEx(s, new_oshape.Size(), req[dropout::kOut], + lstride, rstride, oshape, + in_data[dropout::kData].dptr(), + mask.dptr(), out.dptr()); + }); } } } else { @@ -290,7 +291,7 @@ class DropoutOp { using namespace mshadow::expr; Stream *s = ctx.get_stream(); if (ctx.is_train || mode_ == dropout::kAlways) { - if (!MKLBackward(s, this->pkeep_, in_grad, out_data, out_grad)) { + if (this->axes.ndim() != 0 || !MKLBackward(s, this->pkeep_, in_grad, out_data, out_grad)) { const TBlob &gdata = in_grad[dropout::kData]; const TBlob &grad = 
out_grad[dropout::kOut]; const TBlob &mask = out_data[dropout::kMask]; @@ -300,12 +301,10 @@ class DropoutOp { mask.shape_, gdata.shape_, &new_lshape, &new_rshape, &new_oshape); if (!ndim) { - RandGenerator *pgen = ctx.requested[0].get_parallel_random(); - CHECK_NOTNULL(pgen); - LaunchRNG(s, pgen, new_oshape.Size(), - grad.dptr(), - mask.dptr(), - gdata.dptr()); + MXNET_ASSIGN_REQ_SWITCH(req[dropout::kData], Req, { + mxnet_op::Kernel, xpu>::Launch( + s, gdata.Size(), gdata.dptr(), grad.dptr(), mask.dptr()); + }); } else { BROADCAST_NDIM_SWITCH(ndim, NDim, { mshadow::Shape oshape = new_oshape.get(); @@ -337,6 +336,7 @@ class DropoutOp { real_t pkeep_; /*! \brief Dropout mode */ dropout::DropoutOpMode mode_; + TShape axes; }; // class DropoutOp template From 785c2124b631365946f1569d5698c0dd68980728 Mon Sep 17 00:00:00 2001 From: Zhang Date: Mon, 5 Mar 2018 17:50:17 -0800 Subject: [PATCH 7/8] rm redundent --- src/operator/nn/dropout-inl.h | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/src/operator/nn/dropout-inl.h b/src/operator/nn/dropout-inl.h index 908073a9c8a5..78ca6d067abc 100644 --- a/src/operator/nn/dropout-inl.h +++ b/src/operator/nn/dropout-inl.h @@ -199,23 +199,6 @@ class DropoutOp { } }; - struct ElemwiseMulKernel { - /*! Elementwise multiply kernel */ - MSHADOW_XINLINE static void Map(int id, - RandGenerator gen, - const int N, - const int step, - DType *dropout_out, - DType *mask_out, - const DType *input_data) { - const int start = id * step; - const int end = start + step; - for (int i = start; i < end && i < N; ++i) { - dropout_out[i] = input_data[i] * mask_out[i]; - } - } - }; - void Init(const DropoutParam ¶m) { this->pkeep_ = 1.0f - param.p; this->mode_ = static_cast(param.mode); From 69718862dca2d86548b348cc87100b040b80bf82 Mon Sep 17 00:00:00 2001 From: Zhang Date: Mon, 5 Mar 2018 17:57:10 -0800 Subject: [PATCH 8/8] condition check for standard dropout --- src/operator/nn/dropout-inl.h | 53 ++++++++++++++++++++++++++++++++--- 1 file changed, 49 insertions(+), 4 deletions(-) diff --git a/src/operator/nn/dropout-inl.h b/src/operator/nn/dropout-inl.h index 78ca6d067abc..b57ab45891e9 100644 --- a/src/operator/nn/dropout-inl.h +++ b/src/operator/nn/dropout-inl.h @@ -184,6 +184,33 @@ class DropoutOp { /*! * \brief Dropout kernel, compute dropout tensor */ + struct DropoutKernel { + /*! + * \brief Dropout kernel function + * \param id Thread number (0-based representing count) + * \param gen Random number generator + * \param N Total number of items in the output + * \param step Step between items, related to parallelism + * \param dropout_out Output dropout values + * \param mask_out Output mask (is multiplied to create dropout output, may be 0) + * \param input_data Input data to perform the dropout on + * \param pkeep Dropout rate (keep when the generated random number is less than this value) + */ + MSHADOW_XINLINE static void Map(int id, + RandGenerator gen, + const int N, + const int step, + DType *dropout_out, + DType *mask_out, + const DType *input_data, + const real_t pkeep) { + RNG_KERNEL_LOOP(xpu, DType, id, gen, N, step, { + const real_t rand_num = static_cast(genImpl.uniform()); + mask_out[i] = mshadow_op::threshold::Map(rand_num, pkeep) * (1.0f / pkeep); + dropout_out[i] = input_data[i] * mask_out[i]; + }); + } + }; struct BernoulliKernel { /*! 
\brief Bernoulli kernel for generating mask */ MSHADOW_XINLINE static void Map(int id, @@ -202,7 +229,7 @@ class DropoutOp { void Init(const DropoutParam ¶m) { this->pkeep_ = 1.0f - param.p; this->mode_ = static_cast(param.mode); - this->axes = param.axes; + this->axes_ = param.axes; } void Forward(const OpContext &ctx, @@ -219,9 +246,18 @@ class DropoutOp { if (ctx.is_train || this->mode_ == dropout::kAlways) { RandGenerator *pgen = ctx.requested[0].get_parallel_random(); CHECK_NOTNULL(pgen); - if (this->axes.ndim() != 0 || !MKLForward(s, pgen, this->pkeep_, in_data, out_data)) { + if (this->axes_.ndim() != 0 || !MKLForward(s, pgen, this->pkeep_, in_data, out_data)) { const TBlob &mask = out_data[dropout::kMask]; CHECK(req[dropout::kOut] != kAddTo); + if (this->axes_.ndim() == 0) { + // standard case for dropout + LaunchRNG(s, pgen, out.Size(), + out.dptr(), + mask.dptr(), + in_data[dropout::kData].dptr(), + this->pkeep_); + return; + } // initialize the mask LaunchRNG(s, pgen, out.Size(), mask.dptr(), @@ -274,10 +310,19 @@ class DropoutOp { using namespace mshadow::expr; Stream *s = ctx.get_stream(); if (ctx.is_train || mode_ == dropout::kAlways) { - if (this->axes.ndim() != 0 || !MKLBackward(s, this->pkeep_, in_grad, out_data, out_grad)) { + if (this->axes_.ndim() != 0 || !MKLBackward(s, this->pkeep_, in_grad, out_data, out_grad)) { const TBlob &gdata = in_grad[dropout::kData]; const TBlob &grad = out_grad[dropout::kOut]; const TBlob &mask = out_data[dropout::kMask]; + if (this->axes_.ndim() == 0) { + // standard case for dropout + CHECK_EQ(grad.Size(), mask.Size()); + MXNET_ASSIGN_REQ_SWITCH(req[dropout::kData], Req, { + mxnet_op::Kernel, xpu>::Launch( + s, gdata.Size(), gdata.dptr(), grad.dptr(), mask.dptr()); + }); + return; + } // broardcast mul TShape new_lshape, new_rshape, new_oshape; int ndim = BinaryBroadcastShapeCompact(grad.shape_, @@ -319,7 +364,7 @@ class DropoutOp { real_t pkeep_; /*! \brief Dropout mode */ dropout::DropoutOpMode mode_; - TShape axes; + TShape axes_; }; // class DropoutOp template
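
Note for readers following the series: below is a minimal standalone C++ sketch, not part of any of the patches above, of the semantics the new `axes` parameter implements. The shapes, variable names, and the use of std::rand are made up for illustration. The idea matches the patches: the mask is sampled at a shape whose extent is 1 along each listed axis, and the broadcast multiply reuses the same mask entries along those axes, which is the variational-dropout behaviour the binary_broadcast_kernel path provides.

// Sketch only (assumed shapes and names): dropout over a (T, N, C) tensor with axes = {0},
// so one mask of shape (1, N, C) is shared across all T time steps.
#include <cstdlib>
#include <iostream>
#include <vector>

int main() {
  const int T = 3, N = 2, C = 4;   // hypothetical (seq_len, batch, channels)
  const float pkeep = 0.6f;        // keep probability = 1 - p, as in the patch

  // Bernoulli mask at the reduced shape (1, N, C); kept entries are scaled by 1/pkeep,
  // mirroring mshadow_op::threshold::Map(rand, pkeep) * (1/pkeep) in BernoulliKernel.
  std::vector<float> mask(N * C);
  for (float &m : mask)
    m = (static_cast<float>(std::rand()) / RAND_MAX < pkeep) ? 1.0f / pkeep : 0.0f;

  // Broadcast multiply over axis 0: every time step sees the same dropout pattern.
  std::vector<float> data(T * N * C, 1.0f), out(T * N * C);
  for (int t = 0; t < T; ++t)
    for (int i = 0; i < N * C; ++i)
      out[t * N * C + i] = data[t * N * C + i] * mask[i];

  std::cout << out[0] << " == " << out[N * C] << '\n';  // identical pattern at t=0 and t=1
  return 0;
}

The backward pass in the patches follows the same pattern: the output gradient is broadcast-multiplied with the stored reduced-shape mask, falling back to the plain elementwise kernel when BinaryBroadcastShapeCompact reports no broadcasting is needed.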