Add axes support to Dropout for variational dropout in NLP #9931
Conversation
…mkl part hasn't been updated
What's the speed difference between the old and new default-axes implementations (CPU, GPU, MKL)? It can be measured with the dropout_perf test.
src/operator/nn/dropout-inl.h
Outdated
LOG(FATAL) << "NDim too large "; \
}

inline int BinaryBroadcastShapeCompact(const TShape& lshape, const TShape& rshape,
These look like a copy from src/operator/tensor/elemwise_binary_broadcast_op.h. Can we avoid copying the code?
src/operator/nn/dropout-inl.h
Outdated
*new_lshape = TShape(odim);
*new_rshape = TShape(odim);
*new_oshape = TShape(odim);
index_t bl = oshape.ndim() - lshape.ndim();
nit: const index_t
Should we do a CHECK() for oshape.ndim() >= lshape.ndim()?
If MKL doesn't support the non-default axes behavior, it should skip using MKL for non-default axes, similarly to how batch norm doesn't use MKL or CUDNN for a non-default channel axis.
src/operator/nn/dropout.cc
Outdated
if (dshape.ndim() == 0) return false;
out_shape->clear();
out_shape->push_back(dshape);
if (param.axes.ndim() != 0) {
can be removed
The axes parameter can be empty for normal dropout :)
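(For context, a minimal sketch of the shape rule being discussed, assuming axes lists the dimensions along which the mask is shared; mask_shape is an illustrative helper, not MXNet code:)

```python
# Dims listed in `axes` collapse to 1 in the mask shape, so the same
# sampled mask broadcasts along those axes; empty axes = ordinary dropout.
def mask_shape(data_shape, axes=()):
    return tuple(1 if ax in axes else d for ax, d in enumerate(data_shape))

print(mask_shape((3, 2, 4)))             # (3, 2, 4): ordinary dropout
print(mask_shape((3, 2, 4), axes=(0,)))  # (1, 2, 4): mask shared over axis 0
```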
src/operator/nn/dropout-inl.h
Outdated
for (int i = 1; i < length; ++i) {
inc(&coord, oshape, &lidx, lstride, &ridx, rstride);
// When tuning, don't actually run the op, since it's not going to be tuned against
// the actual op we'll eventually be using
I cannot quite get the point.
in_data[dropout::kData].dptr<DType>(),
mask.dptr<DType>(), out.dptr<DType>());
});
}
}
If MKL is enabled, the broadcast will not happen?
I haven't updated the MKL code for variational dropout (enabling axes). I need help with the MKL part.
For non-default axes, you can fall back to this op instead of the MKL op.
Thx @cjolivier01. I added the condition check here: https://github.com/apache/incubator-mxnet/pull/9931/files#diff-4aea2cc24c0bb4e8e48face9faf4aa26R249
src/operator/nn/dropout-inl.h
Outdated
mask.dptr<DType>(),
this->pkeep_);
if (req[0] != kNullOp) {
// broardcast mul
typo: broadcast
src/operator/nn/dropout-inl.h
Outdated
const real_t pkeep) {
RNG_KERNEL_LOOP(xpu, DType, id, gen, N, step, {
const real_t rand_num = static_cast<real_t>(genImpl.uniform());
mask_out[i] = mshadow_op::threshold::Map<real_t>(rand_num, pkeep) * (1.0f / pkeep);
dropout_out[i] = input_data[i] * mask_out[i];
I am not saying it needs to be done, but have you considered merging this operation with the other kernel, perhaps by deriving from broadcast_kernel or by passing a modified version of the mul op to broadcast_kernel?
Making two full passes over the memory is going to cause a performance hit, due to both OMP overhead and CPU cache misses.
I am not sure I understand you clearly.
I separated the original dropout kernel into two parts: 1) BernoulliKernel, 2) broadcast mul, so that we can enable axes support for variational dropout.
Thx
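(A minimal NumPy sketch of that two-part split; the names are illustrative, not the actual kernels:)

```python
import numpy as np

def variational_dropout(x, pkeep, axes, rng=np.random):
    mshape = tuple(1 if ax in axes else d for ax, d in enumerate(x.shape))
    # 1) Bernoulli kernel: sample the mask at the reduced shape,
    #    pre-scaled by 1/pkeep (inverted dropout)
    mask = (rng.uniform(size=mshape) < pkeep) / pkeep
    # 2) broadcast mul: apply the shared mask to the full input
    return x * mask

x = np.ones((5, 2, 3))                          # (seq_len, batch, hidden)
y = variational_dropout(x, 0.5, axes=(0,))      # same mask at every time step
```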
Right. What's the performance impact of using two kernels instead of one?
Thx @cjolivier01. I get your point about efficiency. I have added a condition check for standard dropout, which has the same efficiency as before when no axes are provided:
https://github.com/apache/incubator-mxnet/pull/9931/files#diff-4aea2cc24c0bb4e8e48face9faf4aa26R252
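(Conceptually, that check dispatches roughly like this; a NumPy sketch under assumed names, not the actual C++:)

```python
import numpy as np

def dropout_forward(x, pkeep, axes=()):
    if len(axes) == 0:
        # standard dropout: full-shape mask, matching the original
        # single fused kernel's cost profile
        mask = (np.random.uniform(size=x.shape) < pkeep) / pkeep
        return x * mask
    # variational dropout: reduced-shape mask, then a broadcast multiply
    mshape = tuple(1 if ax in axes else d for ax, d in enumerate(x.shape))
    mask = (np.random.uniform(size=mshape) < pkeep) / pkeep
    return x * mask
```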
src/operator/nn/dropout-inl.h
Outdated
@@ -337,6 +336,7 @@ class DropoutOp {
real_t pkeep_;
/*! \brief Dropout mode */
dropout::DropoutOpMode mode_;
TShape axes;
nit: member variable name should end in an underscore
Thx 👍
Looks good to me.
@zhanghang1989 could you provide some performance numbers for the speed difference before and after the change?
Pinging @cjolivier01 and @yzhliu for a final review. I intend to merge this as soon as possible, so I will wait for either your approval or three days for lazy consensus, whichever is earlier.
@szha The performance of dropout should be the same as before when no axes are given.
If you have a request for changes from a committer, you can't merge per Apache guidelines.
What is the performance impact of these changes for the default axes behavior compared to the older code?
Hang suggested that if
Did the guidelines explain how committers deal with stale reviews?
Apache says: "A code-modification proposal may be stopped dead in its tracks by a -1 vote by a qualified voter. This constitutes a veto, and it cannot be overruled nor overridden by anyone. Vetos stand until and unless withdrawn by their casters." I am guessing for a "stale" review (stale I would imagine > 2 months old?), a death certificate of said committer would be useful.
Haha OK. Jokes aside, did Hang sufficiently address your concern?
Yeah, I'm good.
@zhanghang1989 is working on it.
I am creating another PR for the unit tests.
Next time please make sure to have code changes and tests within the same PR instead of splitting them.
@yzhliu @zhanghang1989
👍 Got it. My bad.
* add axes support to dropout for variational dropout, test pending, mkl part hasn't been updated
* avoid copy code
* fix typo
* consider non broadcast case
* fix backward
* avoid mkl
* rm redundent
* condition check for standard dropout
@szha Could you test this implementation? Pinging @yzhliu for the MKL implementation.
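(For reference, the user-facing behavior this PR enables should look roughly as follows; a hedged sketch, with mode='always' used only so the imperative calls apply dropout outside of a training pass:)

```python
import mxnet as mx

x = mx.nd.ones((5, 2, 4))   # (seq_len, batch, hidden)

# Standard dropout: an independent mask for every element
y0 = mx.nd.Dropout(x, p=0.5, mode='always')

# Variational dropout: one mask per (batch, hidden) position,
# reused across the time axis
y1 = mx.nd.Dropout(x, p=0.5, axes=(0,), mode='always')
```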