-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add max pool op (with index) #4461
Add max pool op (with index) #4461
Conversation
1cb037f
to
8a97324
Compare
e4e6964
to
d987c44
Compare
2149e40
to
6326c40
Compare
2d72558
to
d907eba
Compare
d907eba
to
5b606f3
Compare
5b606f3
to
bee95fc
Compare
I've commented out the check_grad in max pooling because the current method of testing can not effectively test the gradient of max pooling. But it ensures that the gradients computed by GPU and CPU are equal. |
59f20d5
to
f5e625a
Compare
42b9a34
to
bb33c2b
Compare
… Add_maxpool_withIdx_only
paddle/operators/CMakeLists.txt
Outdated
@@ -75,6 +75,12 @@ function(op_library TARGET) | |||
file(APPEND ${pybind_file} "USE_OP(reduce_sum);\n") | |||
endif() | |||
|
|||
if ("${TARGET}" STREQUAL "pool_with_index_op") | |||
set(pybind_flag 1) | |||
# It's enough to just adding one operator to pybind |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This comment does not make clear why need add another name to this operator. See above comments, like:
# reduce_op contains several operators
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
paddle/operators/math/pooling.cc
Outdated
@@ -458,6 +458,233 @@ template class Pool3dGradFunctor< | |||
platform::CPUPlace, paddle::operators::math::MaxPoolGrad<double>, double>; | |||
template class Pool3dGradFunctor< | |||
platform::CPUPlace, paddle::operators::math::AvgPoolGrad<double>, double>; | |||
|
|||
template <typename T> | |||
class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Need to add some comments about these functors, like "Input order NCHW or NHWC"
Please take a glance at https://google.github.io/styleguide/cppguide.html#Function_Comments
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done. Thx !!!
namespace paddle { | ||
namespace operators { | ||
|
||
int OutputSizeMaxPool(int input_size, int filter_size, int padding, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function can be a inline
. And, this function can be of more general use for conv
and pool
operators.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
This function is frequently called in conv2/3d, conv2/3d_cudnn, pool2/3d and pool2/3_cudnn.
I think we should convert this function to functor, and move to namespace math.
|
||
namespace ops = paddle::operators; | ||
|
||
REGISTER_OP(maxPool2dWithIndex, ops::MaxPoolWithIndexOp, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems operator names are commonly xxx_xxx_op
but not "camal case".
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
232077d
to
811cf0b
Compare
811cf0b
to
fcfce48
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some minor comments.
if ("${TARGET}" STREQUAL "pool_with_index_op") | ||
set(pybind_flag 1) | ||
# It's enough to just adding one operator to pybind | ||
file(APPEND ${pybind_file} "USE_OP(max_pool2d_with_index);\n") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also need to add max_pool3d_with_index
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just adding one operator of pool_with_index_op is ok. And unit test(max_pool2d_with_index, max_pool3d_with_index) passed.
"X(Input) of Pooling should not be null."); | ||
PADDLE_ENFORCE( | ||
ctx->HasOutput(framework::GradVarName("X")), | ||
"X@GRAD(Input@GRAD) of MaxPoolWithIndexOpGrad should not be null."); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@GRAD
suffix is for internal use, may confuse users if this message came out.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I approve with you. If X@GRAD is null, it must be an internet error. But I think it is helpful to developer.
"ksize", | ||
"Pooling size(height, width) of pooling operator." | ||
"If globalPooling = true, ksize is ignored and need not be " | ||
"specified."); // TODO(Add checker) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems you need to configure pre-commit on your local environment before running git commit
, so the style check can automatically prints some style errors.
In this case, "TODO" comments should be like: https://google.github.io/styleguide/cppguide.html#TODO_Comments
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, this style of writing is really nonstandard.
I have configured pre-commit on my local environment, but it also passed.
"image."); | ||
AddOutput("Out", | ||
"The output tensor of pooling operator." | ||
"The format of output tensor is also NCDHW." |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"The format of output tensor is also NCDHW."
"Where N is batch size, C is "
"the number of channels, D, H and W is the depth, height and "
"width of image."
Can put these explains to AddComment
, so here can be more simpler.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
c93063a
to
6c6474c
Compare
5308f84
to
68b0508
Compare
68b0508
to
36da825
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM++
fix #4327