diff --git a/paddle/cuda/include/hl_cnn.h b/paddle/cuda/include/hl_cnn.h
index 6b56d9ec8d3da..89c1f48edacbe 100644
--- a/paddle/cuda/include/hl_cnn.h
+++ b/paddle/cuda/include/hl_cnn.h
@@ -18,7 +18,7 @@ limitations under the License. */
 #include "hl_base.h"
 
 /**
- * @brief Maximum pool forward.
+ * @brief Maximum pool forward with mask output.
  *
  * @param[in] frameCnt batch size of input image.
  * @param[in] inputData input data.
@@ -35,7 +35,7 @@ limitations under the License. */
  * @param[in] paddingW padding width.
  * @param[out] tgtData output data.
  * @param[in] tgtStride stride between output data samples.
- *
+ * @param[out] maskData the location indices of the selected max data.
  */
 extern void hl_maxpool_forward(const int frameCnt,
                                const real* inputData,
@@ -51,7 +51,8 @@ extern void hl_maxpool_forward(const int frameCnt,
                                const int paddingH,
                                const int paddingW,
                                real* tgtData,
-                               const int tgtStride);
+                               const int tgtStride,
+                               real* maskData = NULL);
 
 /**
  * @brief Maximum pool backward.
diff --git a/paddle/cuda/include/stub/hl_cnn_stub.h b/paddle/cuda/include/stub/hl_cnn_stub.h
index a76dbf0b6578d..968ed4840ffb0 100644
--- a/paddle/cuda/include/stub/hl_cnn_stub.h
+++ b/paddle/cuda/include/stub/hl_cnn_stub.h
@@ -31,7 +31,8 @@ inline void hl_maxpool_forward(const int frameCnt,
                                const int paddingH,
                                const int paddingW,
                                real* tgtData,
-                               const int tgtStride) {}
+                               const int tgtStride,
+                               real* maskData) {}
 
 inline void hl_maxpool_backward(const int frameCnt,
                                 const real* inputData,
diff --git a/paddle/cuda/src/hl_cuda_cnn.cu b/paddle/cuda/src/hl_cuda_cnn.cu
index 58674febdc4a0..3699b1e8ae9d8 100644
--- a/paddle/cuda/src/hl_cuda_cnn.cu
+++ b/paddle/cuda/src/hl_cuda_cnn.cu
@@ -31,7 +31,8 @@ __global__ void KeMaxPoolForward(const int nthreads,
                                  const int offsetH,
                                  const int offsetW,
                                  real* tgtData,
-                                 const int tgtStride) {
+                                 const int tgtStride,
+                                 real* maskData) {
   int index = blockIdx.x * blockDim.x + threadIdx.x;
   if (index < nthreads) {
     int pw = index % pooledW;
@@ -45,16 +46,22 @@ __global__ void KeMaxPoolForward(const int nthreads,
     hstart = max(hstart, 0);
     wstart = max(wstart, 0);
     real maxval = -FLT_MAX;
+    int max_index = -1;
     inputData += (frameNum * channels + c) * height * width;
     for (int h = hstart; h < hend; ++h) {
       for (int w = wstart; w < wend; ++w) {
-        if (maxval < inputData[h * width + w])
-          maxval = inputData[h * width + w];
+        if (maxval < inputData[h * width + w]) {
+          max_index = h * width + w;
+          maxval = inputData[max_index];
+        }
       }
     }
     int tgtIndex =
         index % (pooledW * pooledH * channels) + frameNum * tgtStride;
     tgtData[tgtIndex] = maxval;
+    if (maskData != NULL) {
+      maskData[tgtIndex] = max_index;
+    }
   }
 }
 
@@ -72,7 +79,8 @@ void hl_maxpool_forward(const int frameCnt,
                         const int paddingH,
                         const int paddingW,
                         real* tgtData,
-                        const int tgtStride) {
+                        const int tgtStride,
+                        real* maskData) {
   int num_kernels = pooledH * pooledW * channels * frameCnt;
   int blocks = (num_kernels + 1024 - 1) / 1024;
   dim3 threads(1024, 1);
@@ -92,7 +100,8 @@ void hl_maxpool_forward(const int frameCnt,
                                              paddingH,
                                              paddingW,
                                              tgtData,
-                                             tgtStride);
+                                             tgtStride,
+                                             maskData);
   CHECK_SYNC("hl_maxpool_forward failed");
 }
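Note: for every pooled output element, KeMaxPoolForward records in maskData the flat index h * width + w of the winning input element, relative to that sample's (frame, channel) feature map rather than to the whole batch. A consumer can decode an entry back into coordinates with integer division and modulo; a minimal sketch (hypothetical helper, not part of this change):

    # mask_value is stored as a real but is integral by construction;
    # img_width is the width of the pooled input's feature map.
    def decode_mask_index(mask_value, img_width):
        index = int(mask_value)
        return index // img_width, index % img_width  # (row, col)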
diff --git a/paddle/gserver/layers/MaxPoolWithMaskLayer.cpp b/paddle/gserver/layers/MaxPoolWithMaskLayer.cpp
new file mode 100644
index 0000000000000..d810a58d9a3ae
--- /dev/null
+++ b/paddle/gserver/layers/MaxPoolWithMaskLayer.cpp
@@ -0,0 +1,109 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "MaxPoolWithMaskLayer.h"
+#include "paddle/utils/Logging.h"
+#include "paddle/utils/Stat.h"
+
+namespace paddle {
+
+bool MaxPoolWithMaskLayer::init(const LayerMap& layerMap,
+                                const ParameterMap& parameterMap) {
+  PoolLayer::init(layerMap, parameterMap);
+  setOutput("mask", &mask_);
+  return true;
+}
+
+size_t MaxPoolWithMaskLayer::getSize() {
+  CHECK_EQ(inputLayers_.size(), 1UL);
+  size_t layerSize = 0;
+
+  outputY_ = outputSize(imgSizeY_,
+                        sizeY_,
+                        confPaddingY_,
+                        strideY_,
+                        /* caffeMode */ false);
+  outputX_ = outputSize(imgSize_,
+                        sizeX_,
+                        confPadding_,
+                        stride_,
+                        /* caffeMode */ false);
+
+  layerSize = outputX_ * outputY_ * channels_;
+  getOutput().setFrameHeight(outputY_);
+  getOutput().setFrameWidth(outputX_);
+
+  return layerSize;
+}
+
+void MaxPoolWithMaskLayer::forward(PassType passType) {
+  size_t size = getSize();
+  MatrixPtr inputV = inputLayers_[0]->getOutputValue();
+  int batchSize = inputV->getHeight();
+  resetOutput(batchSize, size);
+
+  MatrixPtr outV = getOutputValue();
+  CHECK_EQ(size, outV->getWidth());
+
+  resetSpecifyOutput(mask_,
+                     batchSize,
+                     size,
+                     /* isValueClean */ false,
+                     /* isGradClean */ true);
+
+  MatrixPtr maskV = mask_.value;
+  outV->maxPoolForward(*inputV,
+                       imgSizeY_,
+                       imgSize_,
+                       channels_,
+                       sizeX_,
+                       sizeY_,
+                       strideY_,
+                       stride_,
+                       outputY_,
+                       outputX_,
+                       confPaddingY_,
+                       confPadding_,
+                       maskV);
+}
+
+void MaxPoolWithMaskLayer::backward(const UpdateCallback& callback) {
+  (void)callback;
+  if (NULL == getInputGrad(0)) {
+    return;
+  }
+
+  MatrixPtr outGrad = getOutputGrad();
+  MatrixPtr inputV = inputLayers_[0]->getOutputValue();
+  MatrixPtr outV = getOutputValue();
+  MatrixPtr inputGrad = inputLayers_[0]->getOutputGrad();
+
+  inputGrad->maxPoolBackward(*inputV,
+                             imgSizeY_,
+                             imgSize_,
+                             *outGrad,
+                             *outV,
+                             sizeX_,
+                             sizeY_,
+                             strideY_,
+                             stride_,
+                             outputY_,
+                             outputX_,
+                             1,
+                             1,
+                             confPaddingY_,
+                             confPadding_);
+}
+
+}  // namespace paddle
diff --git a/paddle/gserver/layers/MaxPoolWithMaskLayer.h b/paddle/gserver/layers/MaxPoolWithMaskLayer.h
new file mode 100644
index 0000000000000..e0174add9d944
--- /dev/null
+++ b/paddle/gserver/layers/MaxPoolWithMaskLayer.h
@@ -0,0 +1,40 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <vector>
+#include "PoolLayer.h"
+#include "paddle/math/Matrix.h"
+
+namespace paddle {
+/**
+ * @brief Max pooling layer that additionally outputs the location
+ * indices (mask) of the max values.
+ */
+class MaxPoolWithMaskLayer : public PoolLayer {
+protected:
+  Argument mask_;
+
+public:
+  explicit MaxPoolWithMaskLayer(const LayerConfig& config)
+      : PoolLayer(config) {}
+
+  size_t getSize();
+
+  void forward(PassType passType) override;
+  void backward(const UpdateCallback& callback = nullptr) override;
+  bool init(const LayerMap& layerMap,
+            const ParameterMap& parameterMap) override;
+};
+}  // namespace paddle
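Note: backward() does not consume the mask. It calls the pre-existing maxPoolBackward, which rediscovers the max locations by comparing the input values against the pooled output values, so the mask is purely an extra forward-pass output for consumers that need the max locations.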
diff --git a/paddle/gserver/layers/PoolLayer.cpp b/paddle/gserver/layers/PoolLayer.cpp
index 7b932d5a76e9c..87613a96c5b3c 100644
--- a/paddle/gserver/layers/PoolLayer.cpp
+++ b/paddle/gserver/layers/PoolLayer.cpp
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "PoolLayer.h"
+#include "MaxPoolWithMaskLayer.h"
 #include "PoolProjectionLayer.h"
 #include "paddle/utils/Logging.h"
 #ifdef PADDLE_WITH_CUDA
@@ -44,7 +45,6 @@ bool PoolLayer::init(const LayerMap& layerMap,
   strideY_ = conf.has_stride_y() ? conf.stride_y() : conf.stride();
   confPaddingY_ = conf.has_padding_y() ? conf.padding_y() : conf.padding();
   outputY_ = conf.has_output_y() ? conf.output_y() : conf.output_x();
-
   return true;
 }
 
@@ -57,6 +57,8 @@ Layer* PoolLayer::create(const LayerConfig& config) {
   } else if (CudnnPoolLayer::typeCheck(pool)) {
     return new CudnnPoolLayer(config);
 #endif
+  } else if (pool == "max-pool-with-mask") {
+    return new MaxPoolWithMaskLayer(config);
   } else {
     LOG(FATAL) << "Unknown pool type: " << pool;
     return nullptr;
diff --git a/paddle/gserver/tests/CMakeLists.txt b/paddle/gserver/tests/CMakeLists.txt
index aa94ee406e27c..e24d423b9e646 100644
--- a/paddle/gserver/tests/CMakeLists.txt
+++ b/paddle/gserver/tests/CMakeLists.txt
@@ -24,6 +24,7 @@ gserver_test(test_ConvUnify)
 gserver_test(test_BatchNorm)
 gserver_test(test_KmaxSeqScore)
 gserver_test(test_Expand)
+gserver_test(test_MaxPoolingWithMaskOutput)
 
 ########## test_Mkldnn layers and activations ##########
 if(WITH_MKLDNN)
diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp
index fcbcb5b0f1f4c..3afaab630b21e 100644
--- a/paddle/gserver/tests/test_LayerGrad.cpp
+++ b/paddle/gserver/tests/test_LayerGrad.cpp
@@ -1234,6 +1234,7 @@ void testPoolLayer2(const string& poolType, bool trans, bool useGpu) {
 TEST(Layer, PoolLayer) {
   testPoolLayer("avg-projection", /* trans= */ false, /* useGpu= */ false);
   testPoolLayer("max-projection", /* trans= */ false, /* useGpu= */ false);
+  testPoolLayer("max-pool-with-mask", /* trans= */ false, /* useGpu= */ false);
 
 #ifdef PADDLE_WITH_CUDA
   testPoolLayer("avg-projection", /* trans= */ false, /* useGpu= */ true);
@@ -1242,6 +1243,7 @@ TEST(Layer, PoolLayer) {
   testPoolLayer("cudnn-avg-pool", /* trans= */ false, /* useGpu= */ true);
   testPoolLayer2("cudnn-max-pool", /* trans= */ false, /* useGpu= */ true);
   testPoolLayer2("cudnn-avg-pool", /* trans= */ false, /* useGpu= */ true);
+  testPoolLayer("max-pool-with-mask", /* trans= */ false, /* useGpu= */ true);
 #endif
 }
 
diff --git a/paddle/gserver/tests/test_MaxPoolingWithMaskOutput.cpp b/paddle/gserver/tests/test_MaxPoolingWithMaskOutput.cpp
new file mode 100644
index 0000000000000..16438886df94c
--- /dev/null
+++ b/paddle/gserver/tests/test_MaxPoolingWithMaskOutput.cpp
@@ -0,0 +1,117 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <gtest/gtest.h>
+#include <string>
+#include <vector>
+
+#include "LayerGradUtil.h"
+#include "paddle/math/MathUtils.h"
+#include "paddle/testing/TestUtil.h"
+
+using namespace paddle;
+
+void setPoolConfig(TestConfig* config,
+                   PoolConfig* pool,
+                   const string& poolType) {
+  (*config).biasSize = 0;
+  (*config).layerConfig.set_type("pool");
+  (*config).layerConfig.set_num_filters(1);
+
+  int kw = 3, kh = 3;
+  int pw = 0, ph = 0;
+  int sw = 2, sh = 2;
+  pool->set_pool_type(poolType);
+  pool->set_channels(1);
+  pool->set_size_x(kw);
+  pool->set_size_y(kh);
+  pool->set_start(0);
+  pool->set_padding(pw);
+  pool->set_padding_y(ph);
+  pool->set_stride(sw);
+  pool->set_stride_y(sh);
+
+  int ow = outputSize(pool->img_size(), kw, pw, sw, /* caffeMode */ false);
+  int oh = outputSize(pool->img_size_y(), kh, ph, sh, /* caffeMode */ false);
+  pool->set_output_x(ow);
+  pool->set_output_y(oh);
+}
+
+void doOneMaxPoolingWithMaskOutputTest(MatrixPtr& inputMat,
+                                       const string& poolType,
+                                       bool use_gpu,
+                                       MatrixPtr& maskMat) {
+  TestConfig config;
+  config.inputDefs.push_back({INPUT_DATA, "layer_0", 25, 0});
+  LayerInputConfig* input = config.layerConfig.add_inputs();
+  PoolConfig* pool = input->mutable_pool_conf();
+
+  pool->set_img_size(5);
+  pool->set_img_size_y(5);
+  setPoolConfig(&config, pool, poolType);
+  config.layerConfig.set_size(pool->output_x() * pool->output_y() *
+                              pool->channels());
+
+  config.layerConfig.set_name("MaxPoolWithMask");
+
+  std::vector<DataLayerPtr> dataLayers;
+  LayerMap layerMap;
+  vector<Argument> datas;
+
+  initDataLayer(config,
+                &dataLayers,
+                &datas,
+                &layerMap,
+                "MaxPoolWithMask",
+                1,
+                false,
+                use_gpu);
+
+  dataLayers[0]->getOutputValue()->copyFrom(*inputMat);
+
+  FLAGS_use_gpu = use_gpu;
+  std::vector<ParameterPtr> parameters;
+  LayerPtr maxPoolingWithMaskOutputLayer;
+  initTestLayer(config, &layerMap, &parameters, &maxPoolingWithMaskOutputLayer);
+  maxPoolingWithMaskOutputLayer->forward(PASS_GC);
+
+  checkMatrixEqual(maxPoolingWithMaskOutputLayer->getOutput("mask").value,
+                   maskMat);
+}
+
+TEST(Layer, maxPoolingWithMaskOutputLayerFwd) {
+  bool useGpu = false;
+  MatrixPtr inputMat;
+  MatrixPtr maskMat;
+  real inputData[] = {0.1, 0.1, 0.5, 0.5, 1.1, 0.2, 0.2, 0.6, 0.1,
+                      0.1, 0.3, 0.3, 0.7, 0.1, 0.1, 0.4, 0.4, 0.8,
+                      0.8, 0.1, 1.0, 2.0, 3.0, 0.0, 9.0};
+  real maskData[] = {12, 4, 22, 24};
+
+  inputMat = Matrix::create(1, 25, false, useGpu);
+  maskMat = Matrix::create(1, 4, false, useGpu);
+  inputMat->setData(inputData);
+  maskMat->setData(maskData);
+  doOneMaxPoolingWithMaskOutputTest(
+      inputMat, "max-pool-with-mask", useGpu, maskMat);
+#ifdef PADDLE_WITH_CUDA
+  useGpu = true;
+  inputMat = Matrix::create(1, 25, false, useGpu);
+  maskMat = Matrix::create(1, 4, false, useGpu);
+  inputMat->copyFrom(inputData, 25);
+  maskMat->copyFrom(maskData, 4);
+  doOneMaxPoolingWithMaskOutputTest(
+      inputMat, "max-pool-with-mask", useGpu, maskMat);
+#endif
+}
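Note: the expected mask in this test can be checked by hand. A 3x3 window at stride 2 over the 5x5 input yields a 2x2 output. The window for output (0, 0) covers input rows 0-2 and columns 0-2; its maximum 0.7 sits at row 2, column 2, i.e. flat index 2 * 5 + 2 = 12, which is maskData[0]. The entries 4, 22 and 24 locate the maxima 1.1, 3.0 and 9.0 of the other three windows in the same way.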
diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp
index c3e34d5309d9c..41ee5089677f2 100644
--- a/paddle/math/Matrix.cpp
+++ b/paddle/math/Matrix.cpp
@@ -1028,15 +1028,23 @@ void GpuMatrix::maxPoolForward(Matrix& inputMat,
                                size_t outputH,
                                size_t outputW,
                                size_t paddingH,
-                               size_t paddingW) {
+                               size_t paddingW,
+                               MatrixPtr maskMatP) {
   CHECK(inputMat.useGpu_ == true) << "Matrix type are not equal";
 
   real* inputData = inputMat.getData();
+  real* maskData = NULL;
   size_t frameNum = inputMat.getHeight();
   CHECK(imgSizeH * imgSizeW * channels == inputMat.getWidth());
   CHECK(height_ == inputMat.getHeight());
   CHECK(width_ == outputH * outputW * channels);
 
+  if (maskMatP != NULL) {
+    CHECK(maskMatP->useGpu_ == true) << "Matrix type are not equal";
+    CHECK(outputH * outputW * channels == maskMatP->getWidth());
+    maskData = maskMatP->getData();
+  }
+
   hl_maxpool_forward(frameNum,
                      inputData,
                      channels,
@@ -1051,7 +1059,8 @@ void GpuMatrix::maxPoolForward(Matrix& inputMat,
                      paddingH,
                      paddingW,
                      data_,
-                     getStride());
+                     getStride(),
+                     maskData);
 }
 
 void GpuMatrix::maxPoolBackward(Matrix& inputMat,
@@ -1973,9 +1982,11 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat,
                                size_t outputH,
                                size_t outputW,
                                size_t paddingH,
-                               size_t paddingW) {
+                               size_t paddingW,
+                               MatrixPtr maskMatP) {
   real* inputData = inputMat.getData();
   real* outData = data_;
+  real* maskData = NULL;
   size_t num = inputMat.getHeight();
   size_t inLength = imgSizeH * imgSizeW;
   size_t outLength = outputH * outputW;
@@ -1984,6 +1995,11 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat,
   CHECK_EQ(channels * outLength, this->getWidth());
   size_t outStride = getStride();
 
+  if (maskMatP != NULL) {
+    maskData = maskMatP->getData();
+    CHECK_EQ(channels * outLength, maskMatP->getWidth());
+  }
+
   /* initialize the data_ */
   for (size_t i = 0; i < height_; i++) {
     for (size_t j = 0; j < width_; j++) {
@@ -2005,10 +2021,21 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat,
         int wstart = pw * strideW - paddingW;
         int wend = std::min(wstart + sizeX, imgSizeW);
         wstart = std::max(wstart, 0);
-        for (int h = hstart; h < hend; ++h) {
-          for (int w = wstart; w < wend; ++w) {
-            outData[ph * outputW + pw] = std::max(
-                outData[ph * outputW + pw], inputData[h * imgSizeW + w]);
+        if (maskData == NULL) {
+          for (int h = hstart; h < hend; ++h) {
+            for (int w = wstart; w < wend; ++w) {
+              outData[ph * outputW + pw] = std::max(
+                  outData[ph * outputW + pw], inputData[h * imgSizeW + w]);
+            }
+          }
+        } else {
+          for (int h = hstart; h < hend; ++h) {
+            for (int w = wstart; w < wend; ++w) {
+              if (outData[ph * outputW + pw] < inputData[h * imgSizeW + w]) {
+                outData[ph * outputW + pw] = inputData[h * imgSizeW + w];
+                maskData[ph * outputW + pw] = h * imgSizeW + w;
+              }
+            }
           }
         }
       }
@@ -2016,6 +2043,8 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat,
       // compute offset
       inputData += inLength;
      outData += outLength;
+
+      if (maskData != NULL) maskData += outLength;
     }
   }
 }
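Note: the masked CPU branch replaces std::max with an explicit less-than test so the winning index can be recorded alongside the value. It relies on the preceding "initialize the data_" loop having pre-filled the output with the smallest representable value, which guarantees the comparison fires at least once for any non-empty window, so every mask entry gets written.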
diff --git a/paddle/math/Matrix.h b/paddle/math/Matrix.h
index 44180bca8bca5..d252d642258f4 100644
--- a/paddle/math/Matrix.h
+++ b/paddle/math/Matrix.h
@@ -861,7 +861,8 @@ class Matrix : public BaseMatrix {
 
   /**
    * Pooling forward operation, pick out the largest element
-   * in the sizeX of value
+   * in the sizeX of value; if maskMatP is not NULL, it will
+   * also record the location indices of the max values.
    */
   virtual void maxPoolForward(Matrix& inputMat,
                               size_t imgSizeH,
@@ -874,7 +875,8 @@ class Matrix : public BaseMatrix {
                               size_t outputH,
                               size_t outputW,
                               size_t paddingH,
-                              size_t paddingW) {
+                              size_t paddingW,
+                              MatrixPtr maskMatP = NULL) {
     LOG(FATAL) << "Not implemeted";
   }
 
@@ -1426,7 +1428,8 @@ class GpuMatrix : public Matrix {
                       size_t outputH,
                       size_t outputW,
                       size_t paddingH,
-                      size_t paddingW);
+                      size_t paddingW,
+                      MatrixPtr maskMatP);
 
   void maxPoolBackward(Matrix& image,
                        size_t imgSizeH,
@@ -1697,7 +1700,8 @@ class CpuMatrix : public Matrix {
                       size_t outputH,
                       size_t outputW,
                       size_t paddingH,
-                      size_t paddingW);
+                      size_t paddingW,
+                      MatrixPtr maskMatP);
 
   void maxPoolBackward(Matrix& image,
                        size_t imgSizeH,
diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py
index 43d02bf70e74c..5891b1f0cd657 100644
--- a/python/paddle/trainer/config_parser.py
+++ b/python/paddle/trainer/config_parser.py
@@ -1253,9 +1253,9 @@ def parse_bilinear(bilinear, input_layer_name, bilinear_conf):
 def parse_pool(pool, input_layer_name, pool_conf, ceil_mode):
     pool_conf.pool_type = pool.pool_type
     config_assert(pool.pool_type in [
-        'max-projection', 'avg-projection', 'cudnn-max-pool', 'cudnn-avg-pool'
-    ], "pool-type %s is not in "
-                  "['max-projection', 'avg-projection', "
+        'max-projection', 'avg-projection', 'max-pool-with-mask', 'cudnn-max-pool', 'cudnn-avg-pool'
+    ], "pool-type %s is not in "
+                  "['max-projection', 'avg-projection', 'max-pool-with-mask', "
                   "'cudnn-max-pool', 'cudnn-avg-pool']" % pool.pool_type)
 
     pool_conf.channels = pool.channels
diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py
index 617fbff948bf0..e21071f5b0201 100644
--- a/python/paddle/trainer_config_helpers/layers.py
+++ b/python/paddle/trainer_config_helpers/layers.py
@@ -20,7 +20,7 @@
 from .activations import LinearActivation, SigmoidActivation, TanhActivation, \
     ReluActivation, IdentityActivation, SoftmaxActivation, BaseActivation
 from .evaluators import *
-from .poolings import MaxPooling, AvgPooling, BasePoolingType, \
+from .poolings import MaxPooling, AvgPooling, MaxWithMaskPooling, BasePoolingType, \
     CudnnAvgPooling, CudnnMaxPooling
 from .attrs import *
 from .default_decorators import *
@@ -2699,9 +2699,9 @@ def img_pool_layer(input,
     elif isinstance(pool_type, AvgPooling):
         pool_type.name = 'avg'
 
-    assert type(pool_type) in [AvgPooling, MaxPooling, CudnnAvgPooling,
+    assert type(pool_type) in [AvgPooling, MaxPooling, MaxWithMaskPooling, CudnnAvgPooling,
                                CudnnMaxPooling], \
-        "only (Cudnn)AvgPooling, (Cudnn)MaxPooling are supported"
+        "only (Cudnn)AvgPooling, (Cudnn)MaxPooling, MaxWithMaskPooling are supported"
 
     type_name = pool_type.name + '-projection' \
         if (
diff --git a/python/paddle/trainer_config_helpers/poolings.py b/python/paddle/trainer_config_helpers/poolings.py
index 0c38a8dce553e..f45616551bcd4 100644
--- a/python/paddle/trainer_config_helpers/poolings.py
+++ b/python/paddle/trainer_config_helpers/poolings.py
@@ -15,8 +15,8 @@
 """
 
 __all__ = [
-    "BasePoolingType", "MaxPooling", "AvgPooling", "CudnnMaxPooling",
-    "CudnnAvgPooling", "SumPooling", "SquareRootNPooling"
+    "BasePoolingType", "MaxPooling", "AvgPooling", "MaxWithMaskPooling",
+    "CudnnMaxPooling", "CudnnAvgPooling", "SumPooling", "SquareRootNPooling"
 ]
 
 
@@ -55,6 +55,19 @@ def __init__(self, output_max_index=None):
         self.output_max_index = output_max_index
 
 
+class MaxWithMaskPooling(BasePoolingType):
+    """
+    MaxWithMask pooling.
+
+    Returns not only the maximum values for each dimension in the sequence
+    or time steps, but also the location indices (mask) of those maxima.
+    """
+
+    def __init__(self):
+        BasePoolingType.__init__(self, "max-pool-with-mask")
+
+
 class CudnnMaxPooling(BasePoolingType):
     """
     Cudnn max pooling only support GPU. Return the maxinum value in the
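For context, a minimal trainer-config sketch using the new pool type (the layer names and sizes are illustrative assumptions, not part of this patch):

    from paddle.trainer_config_helpers import *

    # A single-channel 5x5 input pooled with a 3x3 window at stride 2,
    # mirroring the C++ unit test above.
    img = data_layer(name="image", size=25)
    pooled = img_pool_layer(input=img,
                            num_channels=1,
                            pool_size=3,
                            stride=2,
                            padding=0,
                            pool_type=MaxWithMaskPooling())

Because MaxWithMaskPooling is neither MaxPooling nor AvgPooling, img_pool_layer passes its name through without the '-projection' suffix, so the generated layer type is "max-pool-with-mask" and PoolLayer::create instantiates MaxPoolWithMaskLayer; the indices are then available through the layer's "mask" sub-output set up in init().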