Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

max pool Layer with mask #4891

Merged
merged 19 commits into from
Nov 14, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
720274d
add max-pool-with-mask python interface
NHZlX Oct 18, 2017
9621213
add max-pool-with-mask c++ impl
NHZlX Oct 18, 2017
afa6902
add cuda and cpu pool_forward_with_mask impl
NHZlX Oct 18, 2017
ff20a11
add layerGrad test and maskoutput test
NHZlX Oct 18, 2017
c876dfa
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
NHZlX Oct 18, 2017
8106f41
add the max pool with mask layer
NHZlX Nov 10, 2017
a54565e
delete mask pool interface from poolprojection
NHZlX Nov 10, 2017
ba2e5de
pull develop and fix conflict
NHZlX Nov 10, 2017
2ef1867
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
NHZlX Nov 10, 2017
921f723
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
NHZlX Nov 10, 2017
3428a52
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
NHZlX Nov 10, 2017
9e894f6
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
NHZlX Nov 11, 2017
4ebb05c
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
NHZlX Nov 11, 2017
b29cbd4
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
NHZlX Nov 11, 2017
69147da
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
NHZlX Nov 13, 2017
471573e
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
NHZlX Nov 13, 2017
5aa3e76
fix bug with default parameter
NHZlX Nov 13, 2017
0b9c4cd
fix comments
NHZlX Nov 14, 2017
e19b931
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
NHZlX Nov 14, 2017
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions paddle/cuda/include/hl_cnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ limitations under the License. */
#include "hl_base.h"

/**
* @brief Maximum pool forward.
* @brief Maximum pool forward with Mask output.
*
* @param[in] frameCnt batch size of input image.
* @param[in] inputData input data.
Expand All @@ -35,7 +35,7 @@ limitations under the License. */
* @param[in] paddingW padding width.
* @param[out] tgtData output data.
* @param[in] tgtStride stride between output data samples.
*
* @param[out] maskData the location indices of select max data.
*/
extern void hl_maxpool_forward(const int frameCnt,
const real* inputData,
Expand All @@ -51,7 +51,8 @@ extern void hl_maxpool_forward(const int frameCnt,
const int paddingH,
const int paddingW,
real* tgtData,
const int tgtStride);
const int tgtStride,
real* maskData = NULL);

/**
* @brief Maximum pool backward.
Expand Down
3 changes: 2 additions & 1 deletion paddle/cuda/include/stub/hl_cnn_stub.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ inline void hl_maxpool_forward(const int frameCnt,
const int paddingH,
const int paddingW,
real* tgtData,
const int tgtStride) {}
const int tgtStride,
real* MaskData) {}

inline void hl_maxpool_backward(const int frameCnt,
const real* inputData,
Expand Down
19 changes: 14 additions & 5 deletions paddle/cuda/src/hl_cuda_cnn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ __global__ void KeMaxPoolForward(const int nthreads,
const int offsetH,
const int offsetW,
real* tgtData,
const int tgtStride) {
const int tgtStride,
real* maskData) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < nthreads) {
int pw = index % pooledW;
Expand All @@ -45,16 +46,22 @@ __global__ void KeMaxPoolForward(const int nthreads,
hstart = max(hstart, 0);
wstart = max(wstart, 0);
real maxval = -FLT_MAX;
int max_index = -1;
inputData += (frameNum * channels + c) * height * width;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
if (maxval < inputData[h * width + w])
maxval = inputData[h * width + w];
if (maxval < inputData[h * width + w]) {
max_index = h * width + w;
maxval = inputData[max_index];
}
}
}
int tgtIndex =
index % (pooledW * pooledH * channels) + frameNum * tgtStride;
tgtData[tgtIndex] = maxval;
if (maskData != NULL) {
maskData[tgtIndex] = max_index;
}
}
}

Expand All @@ -72,7 +79,8 @@ void hl_maxpool_forward(const int frameCnt,
const int paddingH,
const int paddingW,
real* tgtData,
const int tgtStride) {
const int tgtStride,
real* maskData) {
int num_kernels = pooledH * pooledW * channels * frameCnt;
int blocks = (num_kernels + 1024 - 1) / 1024;
dim3 threads(1024, 1);
Expand All @@ -92,7 +100,8 @@ void hl_maxpool_forward(const int frameCnt,
paddingH,
paddingW,
tgtData,
tgtStride);
tgtStride,
maskData);
CHECK_SYNC("hl_maxpool_forward failed");
}

Expand Down
109 changes: 109 additions & 0 deletions paddle/gserver/layers/MaxPoolWithMaskLayer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "MaxPoolWithMaskLayer.h"
#include "paddle/utils/Logging.h"
#include "paddle/utils/Stat.h"

namespace paddle {

bool MaxPoolWithMaskLayer::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) {
PoolLayer::init(layerMap, parameterMap);
setOutput("mask", &mask_);
return true;
}

size_t MaxPoolWithMaskLayer::getSize() {
CHECK_EQ(inputLayers_.size(), 1UL);
size_t layerSize = 0;

outputY_ = outputSize(imgSizeY_,
sizeY_,
confPaddingY_,
strideY_,
/* caffeMode */ false);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it can be modified into caffeMode = true? The other Layer (Pooling, Conv) default is true.

Copy link
Contributor Author

@NHZlX NHZlX Nov 14, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The caffeMode in poolProjection.cpp is false, i referred there. https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/gserver/layers/PoolProjection.cpp#L51

outputX_ = outputSize(imgSize_,
sizeX_,
confPadding_,
stride_,
/* caffeMode */ false);

layerSize = outputX_ * outputY_ * channels_;
getOutput().setFrameHeight(outputY_);
getOutput().setFrameWidth(outputX_);

return layerSize;
}

void MaxPoolWithMaskLayer::forward(PassType passType) {
size_t size = getSize();
MatrixPtr inputV = inputLayers_[0]->getOutputValue();
int batchSize = inputV->getHeight();
resetOutput(batchSize, size);

MatrixPtr outV = getOutputValue();
CHECK_EQ(size, outV->getWidth());

resetSpecifyOutput(mask_,
batchSize,
size,
/* isValueClean */ false,
/* isGradClean */ true);

MatrixPtr maskV = mask_.value;
outV->maxPoolForward(*inputV,
imgSizeY_,
imgSize_,
channels_,
sizeX_,
sizeY_,
strideY_,
stride_,
outputY_,
outputX_,
confPaddingY_,
confPadding_,
maskV);
}

void MaxPoolWithMaskLayer::backward(const UpdateCallback& callback) {
(void)callback;
if (NULL == getInputGrad(0)) {
return;
}

MatrixPtr outGrad = getOutputGrad();
MatrixPtr inputV = inputLayers_[0]->getOutputValue();
MatrixPtr outV = getOutputValue();
MatrixPtr inputGrad = inputLayers_[0]->getOutputGrad();

inputGrad->maxPoolBackward(*inputV,
imgSizeY_,
imgSize_,
*outGrad,
*outV,
sizeX_,
sizeY_,
strideY_,
stride_,
outputY_,
outputX_,
1,
1,
confPaddingY_,
confPadding_);
}

} // namespace paddle
40 changes: 40 additions & 0 deletions paddle/gserver/layers/MaxPoolWithMaskLayer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <vector>
#include "PoolLayer.h"
#include "paddle/math/Matrix.h"

namespace paddle {
/**
* @brief Basic parent layer of different kinds of pooling
*/
class MaxPoolWithMaskLayer : public PoolLayer {
protected:
Argument mask_;

public:
explicit MaxPoolWithMaskLayer(const LayerConfig& config)
: PoolLayer(config) {}

size_t getSize();

void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
};
} // namespace paddle
4 changes: 3 additions & 1 deletion paddle/gserver/layers/PoolLayer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "PoolLayer.h"
#include "MaxPoolWithMaskLayer.h"
#include "PoolProjectionLayer.h"
#include "paddle/utils/Logging.h"
#ifdef PADDLE_WITH_CUDA
Expand Down Expand Up @@ -44,7 +45,6 @@ bool PoolLayer::init(const LayerMap& layerMap,
strideY_ = conf.has_stride_y() ? conf.stride_y() : conf.stride();
confPaddingY_ = conf.has_padding_y() ? conf.padding_y() : conf.padding();
outputY_ = conf.has_output_y() ? conf.output_y() : conf.output_x();

return true;
}

Expand All @@ -57,6 +57,8 @@ Layer* PoolLayer::create(const LayerConfig& config) {
} else if (CudnnPoolLayer::typeCheck(pool)) {
return new CudnnPoolLayer(config);
#endif
} else if (pool == "max-pool-with-mask") {
return new MaxPoolWithMaskLayer(config);
} else {
LOG(FATAL) << "Unknown pool type: " << pool;
return nullptr;
Expand Down
1 change: 1 addition & 0 deletions paddle/gserver/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ gserver_test(test_ConvUnify)
gserver_test(test_BatchNorm)
gserver_test(test_KmaxSeqScore)
gserver_test(test_Expand)
gserver_test(test_MaxPoolingWithMaskOutput)

########## test_Mkldnn layers and activations ##########
if(WITH_MKLDNN)
Expand Down
2 changes: 2 additions & 0 deletions paddle/gserver/tests/test_LayerGrad.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1234,6 +1234,7 @@ void testPoolLayer2(const string& poolType, bool trans, bool useGpu) {
TEST(Layer, PoolLayer) {
testPoolLayer("avg-projection", /* trans= */ false, /* useGpu= */ false);
testPoolLayer("max-projection", /* trans= */ false, /* useGpu= */ false);
testPoolLayer("max-pool-with-mask", /* trans= */ false, /* useGpu= */ false);

#ifdef PADDLE_WITH_CUDA
testPoolLayer("avg-projection", /* trans= */ false, /* useGpu= */ true);
Expand All @@ -1242,6 +1243,7 @@ TEST(Layer, PoolLayer) {
testPoolLayer("cudnn-avg-pool", /* trans= */ false, /* useGpu= */ true);
testPoolLayer2("cudnn-max-pool", /* trans= */ false, /* useGpu= */ true);
testPoolLayer2("cudnn-avg-pool", /* trans= */ false, /* useGpu= */ true);
testPoolLayer("max-pool-with-mask", /* trans= */ false, /* useGpu= */ true);
#endif
}

Expand Down
Loading