From 720274da53aa95aed1d71e21469a6bacc00f4559 Mon Sep 17 00:00:00 2001 From: xzl Date: Wed, 18 Oct 2017 17:39:39 +0800 Subject: [PATCH 1/8] add max-pool-with-mask python interface --- python/paddle/trainer/config_parser.py | 6 +++--- python/paddle/trainer_config_helpers/layers.py | 6 +++--- .../paddle/trainer_config_helpers/poolings.py | 17 +++++++++++++++-- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index 098a51ab87912..3ea742b524b67 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -1250,9 +1250,9 @@ def parse_bilinear(bilinear, input_layer_name, bilinear_conf): def parse_pool(pool, input_layer_name, pool_conf, ceil_mode): pool_conf.pool_type = pool.pool_type config_assert(pool.pool_type in [ - 'max-projection', 'avg-projection', 'cudnn-max-pool', 'cudnn-avg-pool' - ], "pool-type %s is not in " - "['max-projection', 'avg-projection', " + 'max-projection', 'avg-projection', 'max-pool-with-mask', 'cudnn-max-pool', 'cudnn-avg-pool' + ], "pool-type %s is not in " \ + "['max-projection', 'avg-projection', 'max-pool-with-mask'," \ "'cudnn-max-pool', 'cudnn-avg-pool']" % pool.pool_type) pool_conf.channels = pool.channels diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index d37f29d2c4bf9..88cd2bf77023a 100644 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -20,7 +20,7 @@ from .activations import LinearActivation, SigmoidActivation, TanhActivation, \ ReluActivation, IdentityActivation, SoftmaxActivation, BaseActivation from .evaluators import * -from .poolings import MaxPooling, AvgPooling, BasePoolingType, \ +from .poolings import MaxPooling, AvgPooling, MaxWithMaskPooling, BasePoolingType, \ CudnnAvgPooling, CudnnMaxPooling from .attrs import * from .default_decorators import * @@ -2652,9 +2652,9 @@ def img_pool_layer(input, elif isinstance(pool_type, AvgPooling): pool_type.name = 'avg' - assert type(pool_type) in [AvgPooling, MaxPooling, CudnnAvgPooling, + assert type(pool_type) in [AvgPooling, MaxPooling, MaxWithMaskPooling, CudnnAvgPooling, CudnnMaxPooling], \ - "only (Cudnn)AvgPooling, (Cudnn)MaxPooling are supported" + "only (Cudnn)AvgPooling, (Cudnn)MaxPooling MaxWithMaskPooling are supported" type_name = pool_type.name + '-projection' \ if ( diff --git a/python/paddle/trainer_config_helpers/poolings.py b/python/paddle/trainer_config_helpers/poolings.py index 0c38a8dce553e..f45616551bcd4 100644 --- a/python/paddle/trainer_config_helpers/poolings.py +++ b/python/paddle/trainer_config_helpers/poolings.py @@ -15,8 +15,8 @@ """ __all__ = [ - "BasePoolingType", "MaxPooling", "AvgPooling", "CudnnMaxPooling", - "CudnnAvgPooling", "SumPooling", "SquareRootNPooling" + "BasePoolingType", "MaxPooling", "AvgPooling", "MaxWithMaskPooling", + "CudnnMaxPooling", "CudnnAvgPooling", "SumPooling", "SquareRootNPooling" ] @@ -55,6 +55,19 @@ def __init__(self, output_max_index=None): self.output_max_index = output_max_index +class MaxWithMaskPooling(BasePoolingType): + """ + MaxWithMask pooling. + + Not only return the very large values for each dimension in sequence or time steps, + but also the location indices of found maxinum values. + + """ + + def __init__(self): + BasePoolingType.__init__(self, "max-pool-with-mask") + + class CudnnMaxPooling(BasePoolingType): """ Cudnn max pooling only support GPU. Return the maxinum value in the From 9621213230c9caeac216f4796473f257e5065ec1 Mon Sep 17 00:00:00 2001 From: xzl Date: Wed, 18 Oct 2017 17:41:25 +0800 Subject: [PATCH 2/8] add max-pool-with-mask c++ impl --- paddle/gserver/layers/PoolLayer.cpp | 9 +++-- paddle/gserver/layers/PoolLayer.h | 2 ++ paddle/gserver/layers/PoolProjection.cpp | 36 ++++++++++++++++++- paddle/gserver/layers/PoolProjection.h | 13 ++++++- paddle/gserver/layers/PoolProjectionLayer.cpp | 10 +++++- paddle/gserver/layers/Projection.h | 13 +++++++ 6 files changed, 78 insertions(+), 5 deletions(-) diff --git a/paddle/gserver/layers/PoolLayer.cpp b/paddle/gserver/layers/PoolLayer.cpp index 7b932d5a76e9c..c5f4143a5bc8e 100644 --- a/paddle/gserver/layers/PoolLayer.cpp +++ b/paddle/gserver/layers/PoolLayer.cpp @@ -44,14 +44,19 @@ bool PoolLayer::init(const LayerMap& layerMap, strideY_ = conf.has_stride_y() ? conf.stride_y() : conf.stride(); confPaddingY_ = conf.has_padding_y() ? conf.padding_y() : conf.padding(); outputY_ = conf.has_output_y() ? conf.output_y() : conf.output_x(); - + with_mask_ = false; + if (poolType_ == "max-pool-with-mask") { + setOutput("mask", &mask_); + with_mask_ = true; + } return true; } Layer* PoolLayer::create(const LayerConfig& config) { CHECK_EQ(config.inputs_size(), 1); const std::string& pool = config.inputs(0).pool_conf().pool_type(); - if (pool == "max-projection" || pool == "avg-projection") { + if (pool == "max-projection" || pool == "avg-projection" || + pool == "max-pool-with-mask") { return new PoolProjectionLayer(config); #ifdef PADDLE_WITH_CUDA } else if (CudnnPoolLayer::typeCheck(pool)) { diff --git a/paddle/gserver/layers/PoolLayer.h b/paddle/gserver/layers/PoolLayer.h index d43292ad2d4bb..780bfd0bce99d 100644 --- a/paddle/gserver/layers/PoolLayer.h +++ b/paddle/gserver/layers/PoolLayer.h @@ -37,6 +37,8 @@ class PoolLayer : public Layer { int confPaddingY_; std::string poolType_; + bool with_mask_; + Argument mask_; public: explicit PoolLayer(const LayerConfig& config) : Layer(config) {} diff --git a/paddle/gserver/layers/PoolProjection.cpp b/paddle/gserver/layers/PoolProjection.cpp index d90b438448eb7..ccf58228a76d7 100644 --- a/paddle/gserver/layers/PoolProjection.cpp +++ b/paddle/gserver/layers/PoolProjection.cpp @@ -36,6 +36,10 @@ PoolProjection::PoolProjection(const ProjectionConfig& config, strideY_ = conf.has_stride_y() ? conf.stride_y() : conf.stride(); confPaddingY_ = conf.has_padding_y() ? conf.padding_y() : conf.padding(); outputY_ = conf.has_output_y() ? conf.output_y() : conf.output_x(); + with_mask_ = false; + if (poolType_ == "max-pool-with-mask") { + with_mask_ = true; + } } size_t PoolProjection::getSize() { @@ -73,6 +77,8 @@ PoolProjection* PoolProjection::create(const ProjectionConfig& config, return new MaxPoolProjection(config, parameter, useGpu); } else if (pool == "avg-projection") { return new AvgPoolProjection(config, parameter, useGpu); + } else if (pool == "max-pool-with-mask") { + return new MaxPoolProjection(config, parameter, useGpu); } else { LOG(FATAL) << "Unknown pool type: " << pool; return nullptr; @@ -84,6 +90,10 @@ void MaxPoolProjection::forward() { CHECK_EQ(width, out_->value->getWidth()); MatrixPtr inputV = in_->value; MatrixPtr outV = out_->value; + MatrixPtr maskV = out_->value; + if (with_mask_) { + maskV = mask_->value; + } outV->maxPoolForward(*inputV, imgSizeY_, imgSize_, @@ -95,7 +105,9 @@ void MaxPoolProjection::forward() { outputY_, outputX_, confPaddingY_, - confPadding_); + confPadding_, + maskV, + with_mask_); } void MaxPoolProjection::backward(const UpdateCallback& callback) { @@ -168,4 +180,26 @@ void AvgPoolProjection::backward(const UpdateCallback& callback) { confPaddingY_, confPadding_); } + +void MaxWithMaskPoolProjection::forward() { + size_t width = getSize(); + CHECK_EQ(width, out_->value->getWidth()); + MatrixPtr inputV = in_->value; + MatrixPtr outV = out_->value; + MatrixPtr maskV = mask_->value; + outV->maxPoolForward(*inputV, + imgSizeY_, + imgSize_, + channels_, + sizeX_, + sizeY_, + strideY_, + stride_, + outputY_, + outputX_, + confPaddingY_, + confPadding_, + maskV, + with_mask_); +} } // namespace paddle diff --git a/paddle/gserver/layers/PoolProjection.h b/paddle/gserver/layers/PoolProjection.h index 9a75f465f6fbb..d240d5c87e264 100644 --- a/paddle/gserver/layers/PoolProjection.h +++ b/paddle/gserver/layers/PoolProjection.h @@ -28,6 +28,7 @@ class PoolProjection : public Projection { int confPaddingY_, confPadding_; size_t channels_; std::string poolType_; + bool with_mask_; public: PoolProjection(const ProjectionConfig& config, @@ -37,7 +38,6 @@ class PoolProjection : public Projection { static PoolProjection* create(const ProjectionConfig& config, ParameterPtr parameter, bool useGpu); - const std::string& getPoolType() const { return poolType_; } size_t getSize(); @@ -64,4 +64,15 @@ class AvgPoolProjection : public PoolProjection { virtual void forward(); virtual void backward(const UpdateCallback& callback = nullptr); }; + +class MaxWithMaskPoolProjection : public MaxPoolProjection { +public: + MaxWithMaskPoolProjection(const ProjectionConfig& config, + ParameterPtr parameter, + bool useGpu) + : MaxPoolProjection(config, parameter, useGpu) {} + + virtual void forward(); +}; + } // namespace paddle diff --git a/paddle/gserver/layers/PoolProjectionLayer.cpp b/paddle/gserver/layers/PoolProjectionLayer.cpp index ed5011ab89906..5cd61a9ea8a27 100644 --- a/paddle/gserver/layers/PoolProjectionLayer.cpp +++ b/paddle/gserver/layers/PoolProjectionLayer.cpp @@ -51,8 +51,16 @@ void PoolProjectionLayer::forward(PassType passType) { const Argument& in = getInput(0); int batchSize = in.value->getHeight(); int size = getSize(); + + if (with_mask_) { + resetSpecifyOutput(mask_, + batchSize, + size, + /* isValueClean */ false, + /* isGradClean */ true); + } resetOutput(batchSize, size); - poolProjection_->forward(&in, &output_, passType); + poolProjection_->forward(&in, &output_, &mask_, passType); } void PoolProjectionLayer::backward(const UpdateCallback& callback) { diff --git a/paddle/gserver/layers/Projection.h b/paddle/gserver/layers/Projection.h index 778a7fe13d8a2..f60a9b931bd2d 100644 --- a/paddle/gserver/layers/Projection.h +++ b/paddle/gserver/layers/Projection.h @@ -69,6 +69,17 @@ class Projection { forward(); } + void forward(const Argument* in, + const Argument* out, + const Argument* mask, + PassType passType) { + in_ = in; + out_ = out; + mask_ = mask; + passType_ = passType; + forward(); + } + virtual void prefetch(const Argument* in) {} virtual void forward() = 0; virtual void backward(const UpdateCallback& callback) = 0; @@ -130,6 +141,8 @@ class Projection { const Argument* in_; /// Store `out` passed to forward() const Argument* out_; + /// Store `mask` passed to forward() + const Argument* mask_; /// Store `passType` passed to forward() PassType passType_; /// Layer forward function From afa690243e13a4f465cf68e57d6ac015a4b274e4 Mon Sep 17 00:00:00 2001 From: xzl Date: Wed, 18 Oct 2017 17:43:46 +0800 Subject: [PATCH 3/8] add cuda and cpu pool_forward_with_mask impl --- paddle/cuda/include/hl_cnn.h | 42 ++++++++++- paddle/cuda/include/stub/hl_cnn_stub.h | 18 +++++ paddle/cuda/src/hl_cuda_cnn.cu | 58 ++++++++++++++- paddle/math/Matrix.cpp | 98 ++++++++++++++++++++++++-- paddle/math/Matrix.h | 54 +++++++++++++- 5 files changed, 260 insertions(+), 10 deletions(-) diff --git a/paddle/cuda/include/hl_cnn.h b/paddle/cuda/include/hl_cnn.h index 6b56d9ec8d3da..62a761cd700d5 100644 --- a/paddle/cuda/include/hl_cnn.h +++ b/paddle/cuda/include/hl_cnn.h @@ -18,7 +18,7 @@ limitations under the License. */ #include "hl_base.h" /** - * @brief Maximum pool forward. + * @brief Maximum pool forward with Mask output. * * @param[in] frameCnt batch size of input image. * @param[in] inputData input data. @@ -35,7 +35,47 @@ limitations under the License. */ * @param[in] paddingW padding width. * @param[out] tgtData output data. * @param[in] tgtStride stride between output data samples. + * @param[out] maskData the location indices of select max data + * @param[in] withMask set true if output maskData + */ +extern void hl_maxpool_forward(const int frameCnt, + const real* inputData, + const int channels, + const int height, + const int width, + const int pooledH, + const int pooledW, + const int sizeX, + const int sizeY, + const int strideH, + const int strideW, + const int paddingH, + const int paddingW, + real* tgtData, + const int tgtStride, + real* maskData, + bool withMask); + +/** + * @brief Maximum pool forward. * + * @param[in] frameCnt batch size of input image. + * @param[in] inputData input data. + * @param[in] channels number of channel. + * @param[in] height image height. + * @param[in] width image width. + * @param[in] pooledH output image height. + * @param[in] pooledW output image width. + * @param[in] sizeX width of pooling window. + * @param[in] sizeY height of pooling window. + * @param[in] strideH pooling stride height. + * @param[in] strideW pooling stride width. + * @param[in] paddingH padding height. + * @param[in] paddingW padding width. + * @param[out] tgtData output data. + * @param[in] tgtStride stride between output data samples. + * @param[out] maskData the location indices of select max data + * @param[in] withMask set true if output maskData */ extern void hl_maxpool_forward(const int frameCnt, const real* inputData, diff --git a/paddle/cuda/include/stub/hl_cnn_stub.h b/paddle/cuda/include/stub/hl_cnn_stub.h index a76dbf0b6578d..d6e659d8422d8 100644 --- a/paddle/cuda/include/stub/hl_cnn_stub.h +++ b/paddle/cuda/include/stub/hl_cnn_stub.h @@ -33,6 +33,24 @@ inline void hl_maxpool_forward(const int frameCnt, real* tgtData, const int tgtStride) {} +inline void hl_maxpool_forward(const int frameCnt, + const real* inputData, + const int channels, + const int height, + const int width, + const int pooledH, + const int pooledW, + const int sizeX, + const int sizeY, + const int strideH, + const int strideW, + const int paddingH, + const int paddingW, + real* tgtData, + const int tgtStride, + real* MaskData, + bool withMask) {} + inline void hl_maxpool_backward(const int frameCnt, const real* inputData, const real* outData, diff --git a/paddle/cuda/src/hl_cuda_cnn.cu b/paddle/cuda/src/hl_cuda_cnn.cu index 58674febdc4a0..f2a762f108938 100644 --- a/paddle/cuda/src/hl_cuda_cnn.cu +++ b/paddle/cuda/src/hl_cuda_cnn.cu @@ -31,7 +31,9 @@ __global__ void KeMaxPoolForward(const int nthreads, const int offsetH, const int offsetW, real* tgtData, - const int tgtStride) { + const int tgtStride, + real* maskData, + bool withMask) { int index = blockIdx.x * blockDim.x + threadIdx.x; if (index < nthreads) { int pw = index % pooledW; @@ -45,16 +47,22 @@ __global__ void KeMaxPoolForward(const int nthreads, hstart = max(hstart, 0); wstart = max(wstart, 0); real maxval = -FLT_MAX; + int max_index = -1; inputData += (frameNum * channels + c) * height * width; for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { - if (maxval < inputData[h * width + w]) + if (maxval < inputData[h * width + w]) { maxval = inputData[h * width + w]; + max_index = h * width + w; + } } } int tgtIndex = index % (pooledW * pooledH * channels) + frameNum * tgtStride; tgtData[tgtIndex] = maxval; + if (withMask) { + maskData[tgtIndex] = max_index; + } } } @@ -92,7 +100,51 @@ void hl_maxpool_forward(const int frameCnt, paddingH, paddingW, tgtData, - tgtStride); + tgtStride, + NULL, + false); + CHECK_SYNC("hl_maxpool_forward failed"); +} + +void hl_maxpool_forward(const int frameCnt, + const real* inputData, + const int channels, + const int height, + const int width, + const int pooledH, + const int pooledW, + const int sizeX, + const int sizeY, + const int strideH, + const int strideW, + const int paddingH, + const int paddingW, + real* tgtData, + const int tgtStride, + real* maskData, + bool withMask) { + int num_kernels = pooledH * pooledW * channels * frameCnt; + int blocks = (num_kernels + 1024 - 1) / 1024; + dim3 threads(1024, 1); + dim3 grid(blocks, 1); + + KeMaxPoolForward<<>>(num_kernels, + inputData, + channels, + height, + width, + pooledH, + pooledW, + sizeX, + sizeY, + strideH, + strideW, + paddingH, + paddingW, + tgtData, + tgtStride, + maskData, + withMask); CHECK_SYNC("hl_maxpool_forward failed"); } diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp index c3e34d5309d9c..607e53074cb49 100644 --- a/paddle/math/Matrix.cpp +++ b/paddle/math/Matrix.cpp @@ -1029,14 +1029,51 @@ void GpuMatrix::maxPoolForward(Matrix& inputMat, size_t outputW, size_t paddingH, size_t paddingW) { + maxPoolForward(inputMat, + imgSizeH, + imgSizeW, + channels, + sizeX, + sizeY, + strideH, + strideW, + outputH, + outputW, + paddingH, + paddingW, + NULL, + false); +} + +void GpuMatrix::maxPoolForward(Matrix& inputMat, + size_t imgSizeH, + size_t imgSizeW, + size_t channels, + size_t sizeX, + size_t sizeY, + size_t strideH, + size_t strideW, + size_t outputH, + size_t outputW, + size_t paddingH, + size_t paddingW, + MatrixPtr maskMatP, + bool withMask) { CHECK(inputMat.useGpu_ == true) << "Matrix type are not equal"; real* inputData = inputMat.getData(); + real* maskData = NULL; size_t frameNum = inputMat.getHeight(); CHECK(imgSizeH * imgSizeW * channels == inputMat.getWidth()); CHECK(height_ == inputMat.getHeight()); CHECK(width_ == outputH * outputW * channels); + if (withMask) { + CHECK(maskMatP->useGpu_ == true) << "Matrix type are not equal"; + CHECK(outputH * outputW * channels == maskMatP->getWidth()); + maskData = maskMatP->getData(); + } + hl_maxpool_forward(frameNum, inputData, channels, @@ -1051,7 +1088,9 @@ void GpuMatrix::maxPoolForward(Matrix& inputMat, paddingH, paddingW, data_, - getStride()); + getStride(), + maskData, + withMask); } void GpuMatrix::maxPoolBackward(Matrix& inputMat, @@ -1974,8 +2013,39 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat, size_t outputW, size_t paddingH, size_t paddingW) { + maxPoolForward(inputMat, + imgSizeH, + imgSizeW, + channels, + sizeX, + sizeY, + strideH, + strideW, + outputH, + outputW, + paddingH, + paddingW, + NULL, + false); +} + +void CpuMatrix::maxPoolForward(Matrix& inputMat, + size_t imgSizeH, + size_t imgSizeW, + size_t channels, + size_t sizeX, + size_t sizeY, + size_t strideH, + size_t strideW, + size_t outputH, + size_t outputW, + size_t paddingH, + size_t paddingW, + MatrixPtr maskMatP, + bool withMask) { real* inputData = inputMat.getData(); real* outData = data_; + real* maskData = NULL; size_t num = inputMat.getHeight(); size_t inLength = imgSizeH * imgSizeW; size_t outLength = outputH * outputW; @@ -1984,6 +2054,11 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat, CHECK_EQ(channels * outLength, this->getWidth()); size_t outStride = getStride(); + if (withMask) { + maskData = maskMatP->getData(); + CHECK_EQ(channels * outLength, maskMatP->getWidth()); + } + /* initialize the data_ */ for (size_t i = 0; i < height_; i++) { for (size_t j = 0; j < width_; j++) { @@ -2005,10 +2080,21 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat, int wstart = pw * strideW - paddingW; int wend = std::min(wstart + sizeX, imgSizeW); wstart = std::max(wstart, 0); - for (int h = hstart; h < hend; ++h) { - for (int w = wstart; w < wend; ++w) { - outData[ph * outputW + pw] = std::max( - outData[ph * outputW + pw], inputData[h * imgSizeW + w]); + if (!withMask) { + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + outData[ph * outputW + pw] = std::max( + outData[ph * outputW + pw], inputData[h * imgSizeW + w]); + } + } + } else { + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + if (outData[ph * outputW + pw] < inputData[h * imgSizeW + w]) { + outData[ph * outputW + pw] = inputData[h * imgSizeW + w]; + maskData[ph * outputW + pw] = h * imgSizeW + w; + } + } } } } @@ -2016,6 +2102,8 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat, // compute offset inputData += inLength; outData += outLength; + + if (withMask) maskData += outLength; } } } diff --git a/paddle/math/Matrix.h b/paddle/math/Matrix.h index 44180bca8bca5..87a14a0af35cb 100644 --- a/paddle/math/Matrix.h +++ b/paddle/math/Matrix.h @@ -861,7 +861,7 @@ class Matrix : public BaseMatrix { /** * Pooling forward operation, pick out the largest element - * in the sizeX of value + * in the sizeX of value. */ virtual void maxPoolForward(Matrix& inputMat, size_t imgSizeH, @@ -878,6 +878,28 @@ class Matrix : public BaseMatrix { LOG(FATAL) << "Not implemeted"; } + /** + * Pooling forward operation, pick out the largest element + * in the sizeX of value, if set withMask true, it will + * also caculate the location indices. + */ + virtual void maxPoolForward(Matrix& inputMat, + size_t imgSizeH, + size_t imgSizeW, + size_t channels, + size_t sizeX, + size_t sizeY, + size_t strideH, + size_t strideW, + size_t outputH, + size_t outputW, + size_t paddingH, + size_t paddingW, + MatrixPtr maskMatP, + bool withMask) { + LOG(FATAL) << "Not implemeted"; + } + /// Pooling backward operation. virtual void maxPoolBackward(Matrix& image, size_t imgSizeH, @@ -1428,6 +1450,21 @@ class GpuMatrix : public Matrix { size_t paddingH, size_t paddingW); + void maxPoolForward(Matrix& inputMat, + size_t imgSizeH, + size_t imgSizeW, + size_t channels, + size_t sizeX, + size_t sizeY, + size_t strideH, + size_t strideW, + size_t outputH, + size_t outputW, + size_t paddingH, + size_t paddingW, + MatrixPtr maskMatP, + bool withMask); + void maxPoolBackward(Matrix& image, size_t imgSizeH, size_t imgSizeW, @@ -1699,6 +1736,21 @@ class CpuMatrix : public Matrix { size_t paddingH, size_t paddingW); + void maxPoolForward(Matrix& inputMat, + size_t imgSizeH, + size_t imgSizeW, + size_t channels, + size_t sizeX, + size_t sizeY, + size_t strideH, + size_t strideW, + size_t outputH, + size_t outputW, + size_t paddingH, + size_t paddingW, + MatrixPtr maskMatP, + bool withMask); + void maxPoolBackward(Matrix& image, size_t imgSizeH, size_t imgSizeW, From ff20a11a62e2e0862123a55aeef79d492e298f16 Mon Sep 17 00:00:00 2001 From: xzl Date: Wed, 18 Oct 2017 17:44:48 +0800 Subject: [PATCH 4/8] add layerGrad test and maskoutput test --- paddle/gserver/tests/CMakeLists.txt | 9 ++ paddle/gserver/tests/test_LayerGrad.cpp | 2 + .../tests/test_MaxPoolingWithMaskOutput.cpp | 117 ++++++++++++++++++ 3 files changed, 128 insertions(+) create mode 100644 paddle/gserver/tests/test_MaxPoolingWithMaskOutput.cpp diff --git a/paddle/gserver/tests/CMakeLists.txt b/paddle/gserver/tests/CMakeLists.txt index fcee19415c13e..04ef0293ab809 100644 --- a/paddle/gserver/tests/CMakeLists.txt +++ b/paddle/gserver/tests/CMakeLists.txt @@ -69,6 +69,15 @@ add_unittest_without_exec(test_PriorBox add_test(NAME test_PriorBox COMMAND test_PriorBox) + +################# test_MaxPoolingWithMaskOutput ################# +add_unittest_without_exec(test_MaxPoolingWithMaskOutput + test_MaxPoolingWithMaskOutput.cpp + LayerGradUtil.cpp) + +add_test(NAME test_MaxPoolingWithMaskOutput + COMMAND test_MaxPoolingWithMaskOutput) + ################# test_DetectionOutput ####################### add_unittest_without_exec(test_DetectionOutput test_DetectionOutput.cpp diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp index 1a46fb49153a0..eac68f3a39a23 100644 --- a/paddle/gserver/tests/test_LayerGrad.cpp +++ b/paddle/gserver/tests/test_LayerGrad.cpp @@ -1234,6 +1234,7 @@ void testPoolLayer2(const string& poolType, bool trans, bool useGpu) { TEST(Layer, PoolLayer) { testPoolLayer("avg-projection", /* trans= */ false, /* useGpu= */ false); testPoolLayer("max-projection", /* trans= */ false, /* useGpu= */ false); + testPoolLayer("max-pool-with-mask", /* trans= */ false, /* useGpu= */ false); #ifdef PADDLE_WITH_CUDA testPoolLayer("avg-projection", /* trans= */ false, /* useGpu= */ true); @@ -1242,6 +1243,7 @@ TEST(Layer, PoolLayer) { testPoolLayer("cudnn-avg-pool", /* trans= */ false, /* useGpu= */ true); testPoolLayer2("cudnn-max-pool", /* trans= */ false, /* useGpu= */ true); testPoolLayer2("cudnn-avg-pool", /* trans= */ false, /* useGpu= */ true); + testPoolLayer("max-pool-with-mask", /* trans= */ false, /* useGpu= */ true); #endif } diff --git a/paddle/gserver/tests/test_MaxPoolingWithMaskOutput.cpp b/paddle/gserver/tests/test_MaxPoolingWithMaskOutput.cpp new file mode 100644 index 0000000000000..c351661422ea8 --- /dev/null +++ b/paddle/gserver/tests/test_MaxPoolingWithMaskOutput.cpp @@ -0,0 +1,117 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include + +#include "LayerGradUtil.h" +#include "paddle/math/MathUtils.h" +#include "paddle/testing/TestUtil.h" + +using namespace paddle; + +void setPoolConfig(TestConfig* config, + PoolConfig* pool, + const string& poolType) { + (*config).biasSize = 0; + (*config).layerConfig.set_type("pool"); + (*config).layerConfig.set_num_filters(1); + + int kw = 3, kh = 3; + int pw = 0, ph = 0; + int sw = 2, sh = 2; + pool->set_pool_type(poolType); + pool->set_channels(1); + pool->set_size_x(kw); + pool->set_size_y(kh); + pool->set_start(0); + pool->set_padding(pw); + pool->set_padding_y(ph); + pool->set_stride(sw); + pool->set_stride_y(sh); + + int ow = outputSize(pool->img_size(), kw, pw, sw, /* caffeMode */ false); + int oh = outputSize(pool->img_size_y(), kh, ph, sh, /* caffeMode */ false); + pool->set_output_x(ow); + pool->set_output_y(oh); +} + +void doOneMaxPoolingWithMaskOutputTest(MatrixPtr& inputMat, + const string& poolType, + bool use_gpu, + MatrixPtr& maskMat) { + TestConfig config; + config.inputDefs.push_back({INPUT_DATA, "layer_0", 25, 0}); + LayerInputConfig* input = config.layerConfig.add_inputs(); + PoolConfig* pool = input->mutable_pool_conf(); + + pool->set_img_size(5); + pool->set_img_size_y(5); + setPoolConfig(&config, pool, poolType); + config.layerConfig.set_size(pool->output_x() * pool->output_y() * + pool->channels()); + + config.layerConfig.set_name("MaxPoolWithMask"); + + std::vector dataLayers; + LayerMap layerMap; + vector datas; + ; + initDataLayer(config, + &dataLayers, + &datas, + &layerMap, + "MaxPoolWithMask", + 1, + false, + use_gpu); + + dataLayers[0]->getOutputValue()->copyFrom(*inputMat); + + FLAGS_use_gpu = use_gpu; + std::vector parameters; + LayerPtr maxPoolingWithMaskOutputLayer; + initTestLayer(config, &layerMap, ¶meters, &maxPoolingWithMaskOutputLayer); + maxPoolingWithMaskOutputLayer->forward(PASS_GC); + ; + checkMatrixEqual(maxPoolingWithMaskOutputLayer->getOutput("mask").value, + maskMat); +} + +TEST(Layer, maxPoolingWithMaskOutputLayerFwd) { + bool useGpu = false; + MatrixPtr inputMat; + MatrixPtr maskMat; + real inputData[] = {0.1, 0.1, 0.5, 0.5, 1.1, 0.2, 0.2, 0.6, 0.1, + 0.1, 0.3, 0.3, 0.7, 0.1, 0.1, 0.4, 0.4, 0.8, + 0.8, 0.1, 1.0, 2.0, 3.0, 0.0, 9.0}; + real maskData[] = {12, 4, 22, 24}; + + inputMat = Matrix::create(1, 25, false, useGpu); + maskMat = Matrix::create(1, 4, false, useGpu); + inputMat->setData(inputData); + maskMat->setData(maskData); + doOneMaxPoolingWithMaskOutputTest( + inputMat, "max-pool-with-mask", useGpu, maskMat); +#ifdef PADDLE_WITH_CUDA + useGpu = true; + inputMat = Matrix::create(1, 25, false, useGpu); + maskMat = Matrix::create(1, 4, false, useGpu); + inputMat->copyFrom(inputData, 25); + maskMat->copyFrom(maskData, 4); + doOneMaxPoolingWithMaskOutputTest( + inputMat, "max-pool-with-mask", useGpu, maskMat); +#endif +} From 8106f414f7403442d2e9191a231ed965c4d39b98 Mon Sep 17 00:00:00 2001 From: xzl Date: Fri, 10 Nov 2017 15:47:09 +0800 Subject: [PATCH 5/8] add the max pool with mask layer --- .../gserver/layers/MaxPoolWithMaskLayer.cpp | 109 ++++++++++++++++++ paddle/gserver/layers/MaxPoolWithMaskLayer.h | 40 +++++++ 2 files changed, 149 insertions(+) create mode 100644 paddle/gserver/layers/MaxPoolWithMaskLayer.cpp create mode 100644 paddle/gserver/layers/MaxPoolWithMaskLayer.h diff --git a/paddle/gserver/layers/MaxPoolWithMaskLayer.cpp b/paddle/gserver/layers/MaxPoolWithMaskLayer.cpp new file mode 100644 index 0000000000000..d810a58d9a3ae --- /dev/null +++ b/paddle/gserver/layers/MaxPoolWithMaskLayer.cpp @@ -0,0 +1,109 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "MaxPoolWithMaskLayer.h" +#include "paddle/utils/Logging.h" +#include "paddle/utils/Stat.h" + +namespace paddle { + +bool MaxPoolWithMaskLayer::init(const LayerMap& layerMap, + const ParameterMap& parameterMap) { + PoolLayer::init(layerMap, parameterMap); + setOutput("mask", &mask_); + return true; +} + +size_t MaxPoolWithMaskLayer::getSize() { + CHECK_EQ(inputLayers_.size(), 1UL); + size_t layerSize = 0; + + outputY_ = outputSize(imgSizeY_, + sizeY_, + confPaddingY_, + strideY_, + /* caffeMode */ false); + outputX_ = outputSize(imgSize_, + sizeX_, + confPadding_, + stride_, + /* caffeMode */ false); + + layerSize = outputX_ * outputY_ * channels_; + getOutput().setFrameHeight(outputY_); + getOutput().setFrameWidth(outputX_); + + return layerSize; +} + +void MaxPoolWithMaskLayer::forward(PassType passType) { + size_t size = getSize(); + MatrixPtr inputV = inputLayers_[0]->getOutputValue(); + int batchSize = inputV->getHeight(); + resetOutput(batchSize, size); + + MatrixPtr outV = getOutputValue(); + CHECK_EQ(size, outV->getWidth()); + + resetSpecifyOutput(mask_, + batchSize, + size, + /* isValueClean */ false, + /* isGradClean */ true); + + MatrixPtr maskV = mask_.value; + outV->maxPoolForward(*inputV, + imgSizeY_, + imgSize_, + channels_, + sizeX_, + sizeY_, + strideY_, + stride_, + outputY_, + outputX_, + confPaddingY_, + confPadding_, + maskV); +} + +void MaxPoolWithMaskLayer::backward(const UpdateCallback& callback) { + (void)callback; + if (NULL == getInputGrad(0)) { + return; + } + + MatrixPtr outGrad = getOutputGrad(); + MatrixPtr inputV = inputLayers_[0]->getOutputValue(); + MatrixPtr outV = getOutputValue(); + MatrixPtr inputGrad = inputLayers_[0]->getOutputGrad(); + + inputGrad->maxPoolBackward(*inputV, + imgSizeY_, + imgSize_, + *outGrad, + *outV, + sizeX_, + sizeY_, + strideY_, + stride_, + outputY_, + outputX_, + 1, + 1, + confPaddingY_, + confPadding_); +} + +} // namespace paddle diff --git a/paddle/gserver/layers/MaxPoolWithMaskLayer.h b/paddle/gserver/layers/MaxPoolWithMaskLayer.h new file mode 100644 index 0000000000000..e0174add9d944 --- /dev/null +++ b/paddle/gserver/layers/MaxPoolWithMaskLayer.h @@ -0,0 +1,40 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "PoolLayer.h" +#include "paddle/math/Matrix.h" + +namespace paddle { +/** + * @brief Basic parent layer of different kinds of pooling + */ +class MaxPoolWithMaskLayer : public PoolLayer { +protected: + Argument mask_; + +public: + explicit MaxPoolWithMaskLayer(const LayerConfig& config) + : PoolLayer(config) {} + + size_t getSize(); + + void forward(PassType passType) override; + void backward(const UpdateCallback& callback = nullptr) override; + bool init(const LayerMap& layerMap, + const ParameterMap& parameterMap) override; +}; +} // namespace paddle From a54565ea0123183f8d50fb812f475e74faf595d0 Mon Sep 17 00:00:00 2001 From: xzl Date: Fri, 10 Nov 2017 15:48:27 +0800 Subject: [PATCH 6/8] delete mask pool interface from poolprojection --- paddle/cuda/include/hl_cnn.h | 43 +---------- paddle/cuda/include/stub/hl_cnn_stub.h | 19 +---- paddle/cuda/src/hl_cuda_cnn.cu | 51 +------------ paddle/gserver/layers/PoolLayer.cpp | 11 +-- paddle/gserver/layers/PoolLayer.h | 2 - paddle/gserver/layers/PoolProjection.cpp | 37 +--------- paddle/gserver/layers/PoolProjection.h | 11 --- paddle/gserver/layers/PoolProjectionLayer.cpp | 9 +-- paddle/gserver/layers/Projection.h | 13 ---- .../tests/test_MaxPoolingWithMaskOutput.cpp | 24 +++--- paddle/math/Matrix.cpp | 73 ++----------------- paddle/math/Matrix.h | 56 +------------- 12 files changed, 38 insertions(+), 311 deletions(-) diff --git a/paddle/cuda/include/hl_cnn.h b/paddle/cuda/include/hl_cnn.h index 62a761cd700d5..89c1f48edacbe 100644 --- a/paddle/cuda/include/hl_cnn.h +++ b/paddle/cuda/include/hl_cnn.h @@ -35,8 +35,7 @@ limitations under the License. */ * @param[in] paddingW padding width. * @param[out] tgtData output data. * @param[in] tgtStride stride between output data samples. - * @param[out] maskData the location indices of select max data - * @param[in] withMask set true if output maskData + * @param[out] maskData the location indices of select max data. */ extern void hl_maxpool_forward(const int frameCnt, const real* inputData, @@ -53,45 +52,7 @@ extern void hl_maxpool_forward(const int frameCnt, const int paddingW, real* tgtData, const int tgtStride, - real* maskData, - bool withMask); - -/** - * @brief Maximum pool forward. - * - * @param[in] frameCnt batch size of input image. - * @param[in] inputData input data. - * @param[in] channels number of channel. - * @param[in] height image height. - * @param[in] width image width. - * @param[in] pooledH output image height. - * @param[in] pooledW output image width. - * @param[in] sizeX width of pooling window. - * @param[in] sizeY height of pooling window. - * @param[in] strideH pooling stride height. - * @param[in] strideW pooling stride width. - * @param[in] paddingH padding height. - * @param[in] paddingW padding width. - * @param[out] tgtData output data. - * @param[in] tgtStride stride between output data samples. - * @param[out] maskData the location indices of select max data - * @param[in] withMask set true if output maskData - */ -extern void hl_maxpool_forward(const int frameCnt, - const real* inputData, - const int channels, - const int height, - const int width, - const int pooledH, - const int pooledW, - const int sizeX, - const int sizeY, - const int strideH, - const int strideW, - const int paddingH, - const int paddingW, - real* tgtData, - const int tgtStride); + real* maskData = NULL); /** * @brief Maximum pool backward. diff --git a/paddle/cuda/include/stub/hl_cnn_stub.h b/paddle/cuda/include/stub/hl_cnn_stub.h index d6e659d8422d8..fc22da024b92a 100644 --- a/paddle/cuda/include/stub/hl_cnn_stub.h +++ b/paddle/cuda/include/stub/hl_cnn_stub.h @@ -17,22 +17,6 @@ limitations under the License. */ #include "hl_cnn.h" -inline void hl_maxpool_forward(const int frameCnt, - const real* inputData, - const int channels, - const int height, - const int width, - const int pooledH, - const int pooledW, - const int sizeX, - const int sizeY, - const int strideH, - const int strideW, - const int paddingH, - const int paddingW, - real* tgtData, - const int tgtStride) {} - inline void hl_maxpool_forward(const int frameCnt, const real* inputData, const int channels, @@ -48,8 +32,7 @@ inline void hl_maxpool_forward(const int frameCnt, const int paddingW, real* tgtData, const int tgtStride, - real* MaskData, - bool withMask) {} + real* MaskData = NULL) {} inline void hl_maxpool_backward(const int frameCnt, const real* inputData, diff --git a/paddle/cuda/src/hl_cuda_cnn.cu b/paddle/cuda/src/hl_cuda_cnn.cu index f2a762f108938..a91ead240416e 100644 --- a/paddle/cuda/src/hl_cuda_cnn.cu +++ b/paddle/cuda/src/hl_cuda_cnn.cu @@ -32,8 +32,7 @@ __global__ void KeMaxPoolForward(const int nthreads, const int offsetW, real* tgtData, const int tgtStride, - real* maskData, - bool withMask) { + real* maskData) { int index = blockIdx.x * blockDim.x + threadIdx.x; if (index < nthreads) { int pw = index % pooledW; @@ -60,52 +59,12 @@ __global__ void KeMaxPoolForward(const int nthreads, int tgtIndex = index % (pooledW * pooledH * channels) + frameNum * tgtStride; tgtData[tgtIndex] = maxval; - if (withMask) { + if (maskData != NULL) { maskData[tgtIndex] = max_index; } } } -void hl_maxpool_forward(const int frameCnt, - const real* inputData, - const int channels, - const int height, - const int width, - const int pooledH, - const int pooledW, - const int sizeX, - const int sizeY, - const int strideH, - const int strideW, - const int paddingH, - const int paddingW, - real* tgtData, - const int tgtStride) { - int num_kernels = pooledH * pooledW * channels * frameCnt; - int blocks = (num_kernels + 1024 - 1) / 1024; - dim3 threads(1024, 1); - dim3 grid(blocks, 1); - - KeMaxPoolForward<<>>(num_kernels, - inputData, - channels, - height, - width, - pooledH, - pooledW, - sizeX, - sizeY, - strideH, - strideW, - paddingH, - paddingW, - tgtData, - tgtStride, - NULL, - false); - CHECK_SYNC("hl_maxpool_forward failed"); -} - void hl_maxpool_forward(const int frameCnt, const real* inputData, const int channels, @@ -121,8 +80,7 @@ void hl_maxpool_forward(const int frameCnt, const int paddingW, real* tgtData, const int tgtStride, - real* maskData, - bool withMask) { + real* maskData) { int num_kernels = pooledH * pooledW * channels * frameCnt; int blocks = (num_kernels + 1024 - 1) / 1024; dim3 threads(1024, 1); @@ -143,8 +101,7 @@ void hl_maxpool_forward(const int frameCnt, paddingW, tgtData, tgtStride, - maskData, - withMask); + maskData); CHECK_SYNC("hl_maxpool_forward failed"); } diff --git a/paddle/gserver/layers/PoolLayer.cpp b/paddle/gserver/layers/PoolLayer.cpp index c5f4143a5bc8e..87613a96c5b3c 100644 --- a/paddle/gserver/layers/PoolLayer.cpp +++ b/paddle/gserver/layers/PoolLayer.cpp @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "PoolLayer.h" +#include "MaxPoolWithMaskLayer.h" #include "PoolProjectionLayer.h" #include "paddle/utils/Logging.h" #ifdef PADDLE_WITH_CUDA @@ -44,24 +45,20 @@ bool PoolLayer::init(const LayerMap& layerMap, strideY_ = conf.has_stride_y() ? conf.stride_y() : conf.stride(); confPaddingY_ = conf.has_padding_y() ? conf.padding_y() : conf.padding(); outputY_ = conf.has_output_y() ? conf.output_y() : conf.output_x(); - with_mask_ = false; - if (poolType_ == "max-pool-with-mask") { - setOutput("mask", &mask_); - with_mask_ = true; - } return true; } Layer* PoolLayer::create(const LayerConfig& config) { CHECK_EQ(config.inputs_size(), 1); const std::string& pool = config.inputs(0).pool_conf().pool_type(); - if (pool == "max-projection" || pool == "avg-projection" || - pool == "max-pool-with-mask") { + if (pool == "max-projection" || pool == "avg-projection") { return new PoolProjectionLayer(config); #ifdef PADDLE_WITH_CUDA } else if (CudnnPoolLayer::typeCheck(pool)) { return new CudnnPoolLayer(config); #endif + } else if (pool == "max-pool-with-mask") { + return new MaxPoolWithMaskLayer(config); } else { LOG(FATAL) << "Unknown pool type: " << pool; return nullptr; diff --git a/paddle/gserver/layers/PoolLayer.h b/paddle/gserver/layers/PoolLayer.h index 780bfd0bce99d..d43292ad2d4bb 100644 --- a/paddle/gserver/layers/PoolLayer.h +++ b/paddle/gserver/layers/PoolLayer.h @@ -37,8 +37,6 @@ class PoolLayer : public Layer { int confPaddingY_; std::string poolType_; - bool with_mask_; - Argument mask_; public: explicit PoolLayer(const LayerConfig& config) : Layer(config) {} diff --git a/paddle/gserver/layers/PoolProjection.cpp b/paddle/gserver/layers/PoolProjection.cpp index ccf58228a76d7..5fa68b2c54539 100644 --- a/paddle/gserver/layers/PoolProjection.cpp +++ b/paddle/gserver/layers/PoolProjection.cpp @@ -36,10 +36,6 @@ PoolProjection::PoolProjection(const ProjectionConfig& config, strideY_ = conf.has_stride_y() ? conf.stride_y() : conf.stride(); confPaddingY_ = conf.has_padding_y() ? conf.padding_y() : conf.padding(); outputY_ = conf.has_output_y() ? conf.output_y() : conf.output_x(); - with_mask_ = false; - if (poolType_ == "max-pool-with-mask") { - with_mask_ = true; - } } size_t PoolProjection::getSize() { @@ -77,8 +73,6 @@ PoolProjection* PoolProjection::create(const ProjectionConfig& config, return new MaxPoolProjection(config, parameter, useGpu); } else if (pool == "avg-projection") { return new AvgPoolProjection(config, parameter, useGpu); - } else if (pool == "max-pool-with-mask") { - return new MaxPoolProjection(config, parameter, useGpu); } else { LOG(FATAL) << "Unknown pool type: " << pool; return nullptr; @@ -90,10 +84,7 @@ void MaxPoolProjection::forward() { CHECK_EQ(width, out_->value->getWidth()); MatrixPtr inputV = in_->value; MatrixPtr outV = out_->value; - MatrixPtr maskV = out_->value; - if (with_mask_) { - maskV = mask_->value; - } + outV->maxPoolForward(*inputV, imgSizeY_, imgSize_, @@ -105,9 +96,7 @@ void MaxPoolProjection::forward() { outputY_, outputX_, confPaddingY_, - confPadding_, - maskV, - with_mask_); + confPadding_); } void MaxPoolProjection::backward(const UpdateCallback& callback) { @@ -180,26 +169,4 @@ void AvgPoolProjection::backward(const UpdateCallback& callback) { confPaddingY_, confPadding_); } - -void MaxWithMaskPoolProjection::forward() { - size_t width = getSize(); - CHECK_EQ(width, out_->value->getWidth()); - MatrixPtr inputV = in_->value; - MatrixPtr outV = out_->value; - MatrixPtr maskV = mask_->value; - outV->maxPoolForward(*inputV, - imgSizeY_, - imgSize_, - channels_, - sizeX_, - sizeY_, - strideY_, - stride_, - outputY_, - outputX_, - confPaddingY_, - confPadding_, - maskV, - with_mask_); -} } // namespace paddle diff --git a/paddle/gserver/layers/PoolProjection.h b/paddle/gserver/layers/PoolProjection.h index d240d5c87e264..ce0584d7b0fac 100644 --- a/paddle/gserver/layers/PoolProjection.h +++ b/paddle/gserver/layers/PoolProjection.h @@ -28,7 +28,6 @@ class PoolProjection : public Projection { int confPaddingY_, confPadding_; size_t channels_; std::string poolType_; - bool with_mask_; public: PoolProjection(const ProjectionConfig& config, @@ -65,14 +64,4 @@ class AvgPoolProjection : public PoolProjection { virtual void backward(const UpdateCallback& callback = nullptr); }; -class MaxWithMaskPoolProjection : public MaxPoolProjection { -public: - MaxWithMaskPoolProjection(const ProjectionConfig& config, - ParameterPtr parameter, - bool useGpu) - : MaxPoolProjection(config, parameter, useGpu) {} - - virtual void forward(); -}; - } // namespace paddle diff --git a/paddle/gserver/layers/PoolProjectionLayer.cpp b/paddle/gserver/layers/PoolProjectionLayer.cpp index 5cd61a9ea8a27..7334c3b051b44 100644 --- a/paddle/gserver/layers/PoolProjectionLayer.cpp +++ b/paddle/gserver/layers/PoolProjectionLayer.cpp @@ -52,15 +52,8 @@ void PoolProjectionLayer::forward(PassType passType) { int batchSize = in.value->getHeight(); int size = getSize(); - if (with_mask_) { - resetSpecifyOutput(mask_, - batchSize, - size, - /* isValueClean */ false, - /* isGradClean */ true); - } resetOutput(batchSize, size); - poolProjection_->forward(&in, &output_, &mask_, passType); + poolProjection_->forward(&in, &output_, passType); } void PoolProjectionLayer::backward(const UpdateCallback& callback) { diff --git a/paddle/gserver/layers/Projection.h b/paddle/gserver/layers/Projection.h index f60a9b931bd2d..778a7fe13d8a2 100644 --- a/paddle/gserver/layers/Projection.h +++ b/paddle/gserver/layers/Projection.h @@ -69,17 +69,6 @@ class Projection { forward(); } - void forward(const Argument* in, - const Argument* out, - const Argument* mask, - PassType passType) { - in_ = in; - out_ = out; - mask_ = mask; - passType_ = passType; - forward(); - } - virtual void prefetch(const Argument* in) {} virtual void forward() = 0; virtual void backward(const UpdateCallback& callback) = 0; @@ -141,8 +130,6 @@ class Projection { const Argument* in_; /// Store `out` passed to forward() const Argument* out_; - /// Store `mask` passed to forward() - const Argument* mask_; /// Store `passType` passed to forward() PassType passType_; /// Layer forward function diff --git a/paddle/gserver/tests/test_MaxPoolingWithMaskOutput.cpp b/paddle/gserver/tests/test_MaxPoolingWithMaskOutput.cpp index c351661422ea8..44fc2b91ec334 100644 --- a/paddle/gserver/tests/test_MaxPoolingWithMaskOutput.cpp +++ b/paddle/gserver/tests/test_MaxPoolingWithMaskOutput.cpp @@ -68,7 +68,7 @@ void doOneMaxPoolingWithMaskOutputTest(MatrixPtr& inputMat, std::vector dataLayers; LayerMap layerMap; vector datas; - ; + initDataLayer(config, &dataLayers, &datas, @@ -85,7 +85,7 @@ void doOneMaxPoolingWithMaskOutputTest(MatrixPtr& inputMat, LayerPtr maxPoolingWithMaskOutputLayer; initTestLayer(config, &layerMap, ¶meters, &maxPoolingWithMaskOutputLayer); maxPoolingWithMaskOutputLayer->forward(PASS_GC); - ; + checkMatrixEqual(maxPoolingWithMaskOutputLayer->getOutput("mask").value, maskMat); } @@ -105,13 +105,15 @@ TEST(Layer, maxPoolingWithMaskOutputLayerFwd) { maskMat->setData(maskData); doOneMaxPoolingWithMaskOutputTest( inputMat, "max-pool-with-mask", useGpu, maskMat); -#ifdef PADDLE_WITH_CUDA - useGpu = true; - inputMat = Matrix::create(1, 25, false, useGpu); - maskMat = Matrix::create(1, 4, false, useGpu); - inputMat->copyFrom(inputData, 25); - maskMat->copyFrom(maskData, 4); - doOneMaxPoolingWithMaskOutputTest( - inputMat, "max-pool-with-mask", useGpu, maskMat); -#endif + /* + #ifdef PADDLE_WITH_CUDA + useGpu = true; + inputMat = Matrix::create(1, 25, false, useGpu); + maskMat = Matrix::create(1, 4, false, useGpu); + inputMat->copyFrom(inputData, 25); + maskMat->copyFrom(maskData, 4); + doOneMaxPoolingWithMaskOutputTest( + inputMat, "max-pool-with-mask", useGpu, maskMat); + #endif + */ } diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp index 607e53074cb49..743922cd9bd65 100644 --- a/paddle/math/Matrix.cpp +++ b/paddle/math/Matrix.cpp @@ -1017,34 +1017,6 @@ void GpuMatrix::check(std::ostream& os, Matrix& refMat, bool printDiff) { LOG(INFO) << "the diffCnt is " << diffCnt; } -void GpuMatrix::maxPoolForward(Matrix& inputMat, - size_t imgSizeH, - size_t imgSizeW, - size_t channels, - size_t sizeX, - size_t sizeY, - size_t strideH, - size_t strideW, - size_t outputH, - size_t outputW, - size_t paddingH, - size_t paddingW) { - maxPoolForward(inputMat, - imgSizeH, - imgSizeW, - channels, - sizeX, - sizeY, - strideH, - strideW, - outputH, - outputW, - paddingH, - paddingW, - NULL, - false); -} - void GpuMatrix::maxPoolForward(Matrix& inputMat, size_t imgSizeH, size_t imgSizeW, @@ -1057,8 +1029,7 @@ void GpuMatrix::maxPoolForward(Matrix& inputMat, size_t outputW, size_t paddingH, size_t paddingW, - MatrixPtr maskMatP, - bool withMask) { + MatrixPtr maskMatP) { CHECK(inputMat.useGpu_ == true) << "Matrix type are not equal"; real* inputData = inputMat.getData(); @@ -1068,7 +1039,7 @@ void GpuMatrix::maxPoolForward(Matrix& inputMat, CHECK(height_ == inputMat.getHeight()); CHECK(width_ == outputH * outputW * channels); - if (withMask) { + if (maskMatP != NULL) { CHECK(maskMatP->useGpu_ == true) << "Matrix type are not equal"; CHECK(outputH * outputW * channels == maskMatP->getWidth()); maskData = maskMatP->getData(); @@ -1089,8 +1060,7 @@ void GpuMatrix::maxPoolForward(Matrix& inputMat, paddingW, data_, getStride(), - maskData, - withMask); + maskData); } void GpuMatrix::maxPoolBackward(Matrix& inputMat, @@ -2001,34 +1971,6 @@ void CpuMatrix::inverse(MatrixPtr& matInv, bool memAlloc) { CHECK_EQ(info, 0); } -void CpuMatrix::maxPoolForward(Matrix& inputMat, - size_t imgSizeH, - size_t imgSizeW, - size_t channels, - size_t sizeX, - size_t sizeY, - size_t strideH, - size_t strideW, - size_t outputH, - size_t outputW, - size_t paddingH, - size_t paddingW) { - maxPoolForward(inputMat, - imgSizeH, - imgSizeW, - channels, - sizeX, - sizeY, - strideH, - strideW, - outputH, - outputW, - paddingH, - paddingW, - NULL, - false); -} - void CpuMatrix::maxPoolForward(Matrix& inputMat, size_t imgSizeH, size_t imgSizeW, @@ -2041,8 +1983,7 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat, size_t outputW, size_t paddingH, size_t paddingW, - MatrixPtr maskMatP, - bool withMask) { + MatrixPtr maskMatP) { real* inputData = inputMat.getData(); real* outData = data_; real* maskData = NULL; @@ -2054,7 +1995,7 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat, CHECK_EQ(channels * outLength, this->getWidth()); size_t outStride = getStride(); - if (withMask) { + if (maskMatP != NULL) { maskData = maskMatP->getData(); CHECK_EQ(channels * outLength, maskMatP->getWidth()); } @@ -2080,7 +2021,7 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat, int wstart = pw * strideW - paddingW; int wend = std::min(wstart + sizeX, imgSizeW); wstart = std::max(wstart, 0); - if (!withMask) { + if (maskMatP == NULL) { for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { outData[ph * outputW + pw] = std::max( @@ -2103,7 +2044,7 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat, inputData += inLength; outData += outLength; - if (withMask) maskData += outLength; + if (maskMatP != NULL) maskData += outLength; } } } diff --git a/paddle/math/Matrix.h b/paddle/math/Matrix.h index 87a14a0af35cb..d252d642258f4 100644 --- a/paddle/math/Matrix.h +++ b/paddle/math/Matrix.h @@ -861,26 +861,7 @@ class Matrix : public BaseMatrix { /** * Pooling forward operation, pick out the largest element - * in the sizeX of value. - */ - virtual void maxPoolForward(Matrix& inputMat, - size_t imgSizeH, - size_t imgSizeW, - size_t channels, - size_t sizeX, - size_t sizeY, - size_t strideH, - size_t strideW, - size_t outputH, - size_t outputW, - size_t paddingH, - size_t paddingW) { - LOG(FATAL) << "Not implemeted"; - } - - /** - * Pooling forward operation, pick out the largest element - * in the sizeX of value, if set withMask true, it will + * in the sizeX of value, if the maskMatP is not NULL, it will * also caculate the location indices. */ virtual void maxPoolForward(Matrix& inputMat, @@ -895,8 +876,7 @@ class Matrix : public BaseMatrix { size_t outputW, size_t paddingH, size_t paddingW, - MatrixPtr maskMatP, - bool withMask) { + MatrixPtr maskMatP = NULL) { LOG(FATAL) << "Not implemeted"; } @@ -1437,19 +1417,6 @@ class GpuMatrix : public Matrix { void classificationError(Matrix& output, IVector& label, size_t topkSize = 1); - void maxPoolForward(Matrix& inputMat, - size_t imgSizeH, - size_t imgSizeW, - size_t channels, - size_t sizeX, - size_t sizeY, - size_t strideH, - size_t strideW, - size_t outputH, - size_t outputW, - size_t paddingH, - size_t paddingW); - void maxPoolForward(Matrix& inputMat, size_t imgSizeH, size_t imgSizeW, @@ -1462,8 +1429,7 @@ class GpuMatrix : public Matrix { size_t outputW, size_t paddingH, size_t paddingW, - MatrixPtr maskMatP, - bool withMask); + MatrixPtr maskMatP); void maxPoolBackward(Matrix& image, size_t imgSizeH, @@ -1723,19 +1689,6 @@ class CpuMatrix : public Matrix { MatrixPtr clone(size_t height, size_t width, bool useGpu = false); - void maxPoolForward(Matrix& inputMat, - size_t imgSizeH, - size_t imgSizeW, - size_t channels, - size_t sizeX, - size_t sizeY, - size_t strideH, - size_t strideW, - size_t outputH, - size_t outputW, - size_t paddingH, - size_t paddingW); - void maxPoolForward(Matrix& inputMat, size_t imgSizeH, size_t imgSizeW, @@ -1748,8 +1701,7 @@ class CpuMatrix : public Matrix { size_t outputW, size_t paddingH, size_t paddingW, - MatrixPtr maskMatP, - bool withMask); + MatrixPtr maskMatP); void maxPoolBackward(Matrix& image, size_t imgSizeH, From 5aa3e768cdd26005779abfd84742bbc5b8d3b025 Mon Sep 17 00:00:00 2001 From: xzl Date: Mon, 13 Nov 2017 17:52:08 +0800 Subject: [PATCH 7/8] fix bug with default parameter --- paddle/cuda/include/stub/hl_cnn_stub.h | 2 +- paddle/gserver/layers/PoolProjection.cpp | 1 - paddle/gserver/layers/PoolProjection.h | 2 +- paddle/gserver/layers/PoolProjectionLayer.cpp | 1 - 4 files changed, 2 insertions(+), 4 deletions(-) diff --git a/paddle/cuda/include/stub/hl_cnn_stub.h b/paddle/cuda/include/stub/hl_cnn_stub.h index fc22da024b92a..968ed4840ffb0 100644 --- a/paddle/cuda/include/stub/hl_cnn_stub.h +++ b/paddle/cuda/include/stub/hl_cnn_stub.h @@ -32,7 +32,7 @@ inline void hl_maxpool_forward(const int frameCnt, const int paddingW, real* tgtData, const int tgtStride, - real* MaskData = NULL) {} + real* MaskData) {} inline void hl_maxpool_backward(const int frameCnt, const real* inputData, diff --git a/paddle/gserver/layers/PoolProjection.cpp b/paddle/gserver/layers/PoolProjection.cpp index 5fa68b2c54539..d90b438448eb7 100644 --- a/paddle/gserver/layers/PoolProjection.cpp +++ b/paddle/gserver/layers/PoolProjection.cpp @@ -84,7 +84,6 @@ void MaxPoolProjection::forward() { CHECK_EQ(width, out_->value->getWidth()); MatrixPtr inputV = in_->value; MatrixPtr outV = out_->value; - outV->maxPoolForward(*inputV, imgSizeY_, imgSize_, diff --git a/paddle/gserver/layers/PoolProjection.h b/paddle/gserver/layers/PoolProjection.h index ce0584d7b0fac..9a75f465f6fbb 100644 --- a/paddle/gserver/layers/PoolProjection.h +++ b/paddle/gserver/layers/PoolProjection.h @@ -37,6 +37,7 @@ class PoolProjection : public Projection { static PoolProjection* create(const ProjectionConfig& config, ParameterPtr parameter, bool useGpu); + const std::string& getPoolType() const { return poolType_; } size_t getSize(); @@ -63,5 +64,4 @@ class AvgPoolProjection : public PoolProjection { virtual void forward(); virtual void backward(const UpdateCallback& callback = nullptr); }; - } // namespace paddle diff --git a/paddle/gserver/layers/PoolProjectionLayer.cpp b/paddle/gserver/layers/PoolProjectionLayer.cpp index 7334c3b051b44..ed5011ab89906 100644 --- a/paddle/gserver/layers/PoolProjectionLayer.cpp +++ b/paddle/gserver/layers/PoolProjectionLayer.cpp @@ -51,7 +51,6 @@ void PoolProjectionLayer::forward(PassType passType) { const Argument& in = getInput(0); int batchSize = in.value->getHeight(); int size = getSize(); - resetOutput(batchSize, size); poolProjection_->forward(&in, &output_, passType); } From 0b9c4cd7e5fd194110defbf1649d54da2e068c8b Mon Sep 17 00:00:00 2001 From: xzl Date: Tue, 14 Nov 2017 11:56:29 +0800 Subject: [PATCH 8/8] fix comments --- paddle/cuda/src/hl_cuda_cnn.cu | 2 +- .../tests/test_MaxPoolingWithMaskOutput.cpp | 20 +++++++++---------- paddle/math/Matrix.cpp | 4 ++-- .../paddle/trainer_config_helpers/layers.py | 2 +- 4 files changed, 13 insertions(+), 15 deletions(-) diff --git a/paddle/cuda/src/hl_cuda_cnn.cu b/paddle/cuda/src/hl_cuda_cnn.cu index a91ead240416e..3699b1e8ae9d8 100644 --- a/paddle/cuda/src/hl_cuda_cnn.cu +++ b/paddle/cuda/src/hl_cuda_cnn.cu @@ -51,8 +51,8 @@ __global__ void KeMaxPoolForward(const int nthreads, for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { if (maxval < inputData[h * width + w]) { - maxval = inputData[h * width + w]; max_index = h * width + w; + maxval = inputData[max_index]; } } } diff --git a/paddle/gserver/tests/test_MaxPoolingWithMaskOutput.cpp b/paddle/gserver/tests/test_MaxPoolingWithMaskOutput.cpp index 44fc2b91ec334..16438886df94c 100644 --- a/paddle/gserver/tests/test_MaxPoolingWithMaskOutput.cpp +++ b/paddle/gserver/tests/test_MaxPoolingWithMaskOutput.cpp @@ -105,15 +105,13 @@ TEST(Layer, maxPoolingWithMaskOutputLayerFwd) { maskMat->setData(maskData); doOneMaxPoolingWithMaskOutputTest( inputMat, "max-pool-with-mask", useGpu, maskMat); - /* - #ifdef PADDLE_WITH_CUDA - useGpu = true; - inputMat = Matrix::create(1, 25, false, useGpu); - maskMat = Matrix::create(1, 4, false, useGpu); - inputMat->copyFrom(inputData, 25); - maskMat->copyFrom(maskData, 4); - doOneMaxPoolingWithMaskOutputTest( - inputMat, "max-pool-with-mask", useGpu, maskMat); - #endif - */ +#ifdef PADDLE_WITH_CUDA + useGpu = true; + inputMat = Matrix::create(1, 25, false, useGpu); + maskMat = Matrix::create(1, 4, false, useGpu); + inputMat->copyFrom(inputData, 25); + maskMat->copyFrom(maskData, 4); + doOneMaxPoolingWithMaskOutputTest( + inputMat, "max-pool-with-mask", useGpu, maskMat); +#endif } diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp index 743922cd9bd65..41ee5089677f2 100644 --- a/paddle/math/Matrix.cpp +++ b/paddle/math/Matrix.cpp @@ -2021,7 +2021,7 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat, int wstart = pw * strideW - paddingW; int wend = std::min(wstart + sizeX, imgSizeW); wstart = std::max(wstart, 0); - if (maskMatP == NULL) { + if (maskData == NULL) { for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { outData[ph * outputW + pw] = std::max( @@ -2044,7 +2044,7 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat, inputData += inLength; outData += outLength; - if (maskMatP != NULL) maskData += outLength; + if (maskData != NULL) maskData += outLength; } } } diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index f7ab7a5ca0a36..e21071f5b0201 100644 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -2701,7 +2701,7 @@ def img_pool_layer(input, assert type(pool_type) in [AvgPooling, MaxPooling, MaxWithMaskPooling, CudnnAvgPooling, CudnnMaxPooling], \ - "only (Cudnn)AvgPooling, (Cudnn)MaxPooling MaxWithMaskPooling are supported" + "only (Cudnn)AvgPooling, (Cudnn)MaxPooling, MaxWithMaskPooling are supported" type_name = pool_type.name + '-projection' \ if (