diff --git a/paddle/cuda/include/hl_cnn.h b/paddle/cuda/include/hl_cnn.h index f55197c8c9ebb..9f84db72da24b 100644 --- a/paddle/cuda/include/hl_cnn.h +++ b/paddle/cuda/include/hl_cnn.h @@ -17,73 +17,6 @@ limitations under the License. */ #include "hl_base.h" -/** - * @brief Shrink column to feature. - * - * @param[in] dataCol expand data. - * @param[in] channels number of channel. - * @param[in] height image height. - * @param[in] width image width. - * @param[in] blockH filter height. - * @param[in] blockW filter width. - * @param[in] strideH stride height. - * @param[in] strideW stride width. - * @param[in] paddingH padding height. - * @param[in] paddingW padding width. - * @param[in] outputH output height. - * @param[in] outputW output width. - * @param[out] dataIm output image data. - * @param[in] alpha - * @param[in] beta - */ -extern void hl_shrink_col2feature(const real* dataCol, - size_t channels, - size_t height, - size_t width, - size_t blockH, - size_t blockW, - size_t strideH, - size_t strideW, - size_t paddingH, - size_t paddingW, - size_t outputH, - size_t outputW, - real* dataIm, - real alpha = 1.0f, - real beta = 0.0f); - -/** - * @brief Expand feature to column. - * - * @param[in] dataIm input image data. - * @param[in] channels number of channel. - * @param[in] height image height. - * @param[in] width image width. - * @param[in] blockH filter height. - * @param[in] blockW filter width. - * @param[in] strideH stride height. - * @param[in] strideW stride width. - * @param[in] paddingH padding height. - * @param[in] paddingW padding width. - * @param[in] outputH output height. - * @param[in] outputW output width. - * @param[out] dataCol expand data. - * - */ -extern void hl_expand_feature2col(const real* dataIm, - size_t channels, - size_t height, - size_t width, - size_t blockH, - size_t blockW, - size_t strideH, - size_t strideW, - size_t paddingH, - size_t paddingW, - size_t outputH, - size_t outputW, - real* dataCol); - /** * @brief Maximum pool forward. * diff --git a/paddle/cuda/include/stub/hl_cnn_stub.h b/paddle/cuda/include/stub/hl_cnn_stub.h index 039551c6cc695..2bbb9fa8dfd5e 100644 --- a/paddle/cuda/include/stub/hl_cnn_stub.h +++ b/paddle/cuda/include/stub/hl_cnn_stub.h @@ -17,36 +17,6 @@ limitations under the License. */ #include "hl_cnn.h" -inline void hl_shrink_col2feature(const real* dataCol, - size_t channels, - size_t height, - size_t width, - size_t blockH, - size_t blockW, - size_t strideH, - size_t strideW, - size_t paddingH, - size_t paddingW, - size_t outputH, - size_t outputW, - real* dataIm, - real alpha, - real beta) {} - -inline void hl_expand_feature2col(const real* dataIm, - size_t channels, - size_t height, - size_t width, - size_t blockH, - size_t blockW, - size_t strideH, - size_t strideW, - size_t paddingH, - size_t paddingW, - size_t outputH, - size_t outputW, - real* dataCol) {} - inline void hl_maxpool_forward(const int frameCnt, const real* inputData, const int channels, diff --git a/paddle/cuda/src/hl_cuda_cnn.cu b/paddle/cuda/src/hl_cuda_cnn.cu index b94f4d8fe4a25..b6e3e63a4f522 100644 --- a/paddle/cuda/src/hl_cuda_cnn.cu +++ b/paddle/cuda/src/hl_cuda_cnn.cu @@ -18,134 +18,6 @@ limitations under the License. */ #include "hl_cnn.h" #include "hl_device_functions.cuh" -__global__ void KeFeature2col(size_t n, size_t height, const real* data_im, - size_t blockH, size_t blockW, size_t width, - size_t strideH, size_t strideW, - size_t paddingH, size_t paddingW, - size_t height_col, size_t width_col, - real* data_col) { - size_t index = - (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; - if (index < n) { - size_t w_out = index % width_col; - index /= width_col; - size_t h_out = index % height_col; - size_t channel_in = index / height_col; - size_t channel_out = channel_in * blockH * blockW; - size_t h_in = h_out * strideH; - size_t w_in = w_out * strideW; - - data_col += (channel_out * height_col + h_out) * width_col + w_out; - for (size_t i = 0; i < blockH; ++i) { - for (size_t j = 0; j < blockW; ++j) { - int rIdx = int(h_in+i); - int cIdx = int(w_in+j); - if ((rIdx-(int)paddingH) >= (int)height || - (rIdx-(int)paddingH) < 0 || - (cIdx-(int)paddingW) >= (int)width || - (cIdx-(int)paddingW) < 0) { - *data_col = 0; - } else { - rIdx = rIdx + channel_in*height - paddingH; - cIdx = cIdx - paddingW; - *data_col = data_im[rIdx* width + cIdx]; - } - data_col += height_col * width_col; - } - } - } -} - -void hl_expand_feature2col(const real* dataIm, size_t channels, - size_t height, size_t width, - size_t blockH, size_t blockW, - size_t strideH, size_t strideW, - size_t paddingH, size_t paddingW, - size_t outputH, size_t outputW, - real* dataCol) { - size_t numKernels = channels * outputH * outputW; - - size_t blocks = (numKernels + 1024 -1) / 1024; - size_t blockX = 512; - size_t blockY = (blocks+512-1)/512; - dim3 threads(1024, 1); - dim3 grid(blockX, blockY); - KeFeature2col<<< grid, threads, 0, STREAM_DEFAULT >>> - (numKernels, height, dataIm, blockH, blockW, width, - strideH, strideW, paddingH, paddingW, - outputH, outputW, dataCol); - CHECK_SYNC("hl_expand_feature2col failed"); -} - -__global__ void KeCol2Feature(size_t n, const real* data_col, size_t height, - size_t width, size_t channels, - size_t blockH, size_t blockW, - size_t strideH, size_t strideW, - size_t paddingH, size_t paddingW, - size_t height_col, size_t width_col, - real* data_im, real alpha, real beta) { - size_t index = - (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; - if (index < n) { - real val = 0; - int w = int(index % width); - int h = int((index / width) % height); - int c = int(index / (width * height)); - if ((w - (int)paddingW) >= 0 && - (w - (int)paddingW) < (width-2 * paddingW) && - (h - (int)paddingH) >= 0 && - (h - paddingH) < (height - 2 * paddingH)) { - // compute the start and end of the output - int w_col_start = - (w < (int)blockW) ? 0 : (w - int(blockW)) / (int)strideW + 1; - int w_col_end = - min((int)(w / (int)strideW + 1), (int)(width_col)); - int h_col_start = - (h < (int)blockH) ? 0 : (h - (int)blockH) / (int)strideH + 1; - int h_col_end = min(int(h / strideH + 1), int(height_col)); - for (int h_col = h_col_start; h_col < h_col_end; ++h_col) { - for (int w_col = w_col_start; w_col < w_col_end; ++w_col) { - // the col location: [c * width * height + h_out, w_out] - int c_col = int(c * blockH* blockW) + \ - (h - h_col * (int)strideH) * (int)blockW + - (w - w_col * (int)strideW); - val += data_col[(c_col * height_col + h_col) * width_col + w_col]; - } - } - h -= paddingH; - w -= paddingW; - real tD = data_im[c*((width-2*paddingW) * (height-2*paddingH)) + - h*(width-2*paddingW) + w]; - data_im[c*((width-2*paddingW) * (height-2*paddingH)) + - h*(width-2*paddingW) + w] = alpha * val + beta*tD; - } - } -} - -void hl_shrink_col2feature(const real * dataCol, size_t channels, - size_t height, size_t width, - size_t blockH, size_t blockW, - size_t strideH, size_t strideW, - size_t paddingH, size_t paddingW, - size_t outputH, size_t outputW, - real* dataIm, real alpha, real beta) { - size_t numKernels = channels * (height + 2*paddingH) * (width + 2*paddingW); - - size_t blocks = (numKernels + 1024 -1) / 1024; - size_t blockX = 512; - size_t blockY = (blocks+512-1)/512; - dim3 threads(1024, 1); - dim3 grid(blockX, blockY); - - // To avoid involving atomic operations, we will launch one kernel per - // bottom dimension, and then in the kernel add up the top dimensions. - KeCol2Feature<<< grid, threads, 0, STREAM_DEFAULT >>> - (numKernels, dataCol, height + 2*paddingH, width + 2*paddingW, - channels, blockH, blockW, strideH, strideW, paddingH, paddingW, - outputH, outputW, dataIm, alpha, beta); - CHECK_SYNC("hl_shrink_col2feature failed"); -} - __global__ void KeMaxPoolForward(const int nthreads, const real* inputData, const int channels, const int height, const int width, diff --git a/paddle/function/BlockExpandOp.cpp b/paddle/function/BlockExpandOp.cpp new file mode 100644 index 0000000000000..a89b6bba45843 --- /dev/null +++ b/paddle/function/BlockExpandOp.cpp @@ -0,0 +1,202 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "Function.h" +#include "Im2Col.h" + +namespace paddle { + +/* + * \brief Converts the image data of four dimensions(NCHW) into + * a sequence data of three dimensions(NST) in the forward calculation, + * which is reversed in the backward calculation. + * Where N is batch size, S is the length of the sequence after each + * image is expanded, T is the size of each time step in the sequence. + * + * Arguments in forward function: + * \param inputs[0] Image data of NCHW format. + * \param outputs[0] Sequence data of NST format. + * + * Arguments in backward function: + * \param inputs[0] Sequence data of NST format. + * \param outputs[0] Image data of NCHW format. + */ +class BlockExpandFunction : public FunctionBase { +public: + void init(const FuncConfig& config) override { + // function arguments + strides_ = config.get>("strides"); + paddings_ = config.get>("paddings"); + blocks_ = config.get>("blocks"); + + // number of inputs and outputs + numInputs_ = 1; + numOutputs_ = 1; + } + + void checkShape(const TensorShape& image, const TensorShape& sequence) const { + // image shape should be 4-dimensional. + CHECK_EQ(image.ndims(), (size_t)4); + // sequence shape should be 3-dimensional. + CHECK_EQ(sequence.ndims(), (size_t)3); + // The batchSize of the image needs to be equal to + // the batchSize of the sequence. + CHECK_EQ(image[0], sequence[0]); + } + + // Calculate the shape of colData based on the shape of the image + // and the shape of the sequence. + TensorShape getColShape(const TensorShape& image, + const TensorShape& sequence) const { + size_t inputChannels = image[1]; + size_t inputHeight = image[2]; + size_t inputWidth = image[3]; + size_t seqLength = sequence[1]; + size_t stepSize = sequence[2]; + size_t outputHeight = + 1 + + (inputHeight + 2 * paddingH() - blockH() + strideH() - 1) / strideH(); + size_t outputWidth = + 1 + + (inputWidth + 2 * paddingW() - blockW() + strideW() - 1) / strideW(); + CHECK_EQ(seqLength, outputHeight * outputWidth); + CHECK_EQ(stepSize, inputChannels * blockH() * blockW()); + + // [outputHeight, outputWidth, inputChannels, filterHeight, filterWidth] + return TensorShape({outputHeight, + outputWidth, + inputChannels, + (size_t)blockH(), + (size_t)blockW()}); + } + +protected: + std::vector strides_; + std::vector paddings_; + std::vector blocks_; + + inline int strideH() const { return strides_[0]; } + + inline int strideW() const { return strides_[1]; } + + inline int paddingH() const { return paddings_[0]; } + + inline int paddingW() const { return paddings_[1]; } + + inline int blockH() const { return blocks_[0]; } + + inline int blockW() const { return blocks_[1]; } +}; + +template +class BlockExpandForward : public BlockExpandFunction { +public: + void init(const FuncConfig& config) override { + BlockExpandFunction::init(config); + } + + void check(const BufferArgs& inputs, const BufferArgs& outputs) override { + const TensorShape& image = inputs[0].shape(); + const TensorShape& sequence = outputs[0].shape(); + checkShape(image, sequence); + } + + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_EQ(numInputs_, inputs.size()); + CHECK_EQ(numOutputs_, outputs.size()); + check(inputs, outputs); + CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO); + const TensorShape& image = inputs[0].shape(); + const TensorShape& sequence = outputs[0].shape(); + + TensorShape imShape = TensorShape({image[1], image[2], image[3]}); + TensorShape colShape = getColShape(image, sequence); + size_t batchSize = image[0]; + + real* imageData = inputs[0].data(); + real* seqData = outputs[0].data(); + Im2ColFunctor im2col; + for (size_t i = 0; i < batchSize; i++) { + // The result of im2col is [outputHeight, outputWidth, + // inputChannels, filterHeight, filterWidth], and it is easy to + // reshape into [seqLength, stepSize], where seqLength is equal + // output_height * output_width, stepSize is equal + // input_channels * filter_height * filter_width + im2col(imageData, + imShape, + seqData, + colShape, + strideH(), + strideW(), + paddingH(), + paddingW()); + imageData += imShape.getElements(); + seqData += colShape.getElements(); + } + } +}; + +template +class BlockExpandBackward : public BlockExpandFunction { +public: + void init(const FuncConfig& config) override { + BlockExpandFunction::init(config); + } + + void check(const BufferArgs& inputs, const BufferArgs& outputs) override { + const TensorShape& image = outputs[0].shape(); + const TensorShape& sequence = inputs[0].shape(); + checkShape(image, sequence); + } + + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_EQ(numInputs_, inputs.size()); + CHECK_EQ(numOutputs_, outputs.size()); + check(inputs, outputs); + // Since the implementation of Col2ImFunctor is ADD_TO, + // this function only supports ADD_TO mode. + CHECK_EQ(outputs[0].getArgType(), ADD_TO); + const TensorShape& image = outputs[0].shape(); + const TensorShape& sequence = inputs[0].shape(); + + TensorShape imShape = TensorShape({image[1], image[2], image[3]}); + TensorShape colShape = getColShape(image, sequence); + size_t batchSize = image[0]; + + real* imageData = outputs[0].data(); + real* seqData = inputs[0].data(); + Col2ImFunctor col2im; + for (size_t i = 0; i < batchSize; i++) { + col2im(imageData, + imShape, + seqData, + colShape, + strideH(), + strideW(), + paddingH(), + paddingW()); + imageData += imShape.getElements(); + seqData += colShape.getElements(); + } + } +}; + +REGISTER_TYPED_FUNC(BlockExpand, CPU, BlockExpandForward); +REGISTER_TYPED_FUNC(BlockExpandGrad, CPU, BlockExpandBackward); +#ifndef PADDLE_ONLY_CPU +REGISTER_TYPED_FUNC(BlockExpand, GPU, BlockExpandForward); +REGISTER_TYPED_FUNC(BlockExpandGrad, GPU, BlockExpandBackward); +#endif + +} // namespace paddle diff --git a/paddle/function/BlockExpandOpTest.cpp b/paddle/function/BlockExpandOpTest.cpp new file mode 100644 index 0000000000000..5e4897e72ba9f --- /dev/null +++ b/paddle/function/BlockExpandOpTest.cpp @@ -0,0 +1,107 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "FunctionTest.h" + +namespace paddle { + +TEST(BlockExpandForward, real) { + for (size_t batchSize : {5, 32}) { + for (size_t channels : {1, 5, 32}) { + for (size_t inputHeight : {5, 33, 100}) { + for (size_t inputWidth : {5, 32, 96}) { + for (size_t block : {1, 3, 5}) { + for (size_t stride : {1, 2}) { + for (size_t padding : {0, 1}) { + // init Test object + std::vector strides = {stride, stride}; + std::vector paddings = {padding, padding}; + std::vector blocks = {block, block}; + CpuGpuFuncCompare test("BlockExpand", + FuncConfig() + .set("strides", strides) + .set("paddings", paddings) + .set("blocks", blocks)); + + size_t outputHeight = + 1 + + (inputHeight + 2 * padding - block + stride - 1) / stride; + size_t outputWidth = + 1 + + (inputWidth + 2 * padding - block + stride - 1) / stride; + TensorShape inputShape = + TensorShape({batchSize, channels, inputHeight, inputWidth}); + TensorShape outputShape = + TensorShape({batchSize, + outputHeight * outputWidth, + channels * block * block}); + test.addInputs(BufferArg(VALUE_TYPE_FLOAT, inputShape)); + test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, outputShape)); + // run Function + test.run(); + } + } + } + } + } + } + } +} + +TEST(BlockExpandBackward, real) { + for (size_t batchSize : {5, 32}) { + for (size_t channels : {1, 5, 32}) { + for (size_t inputHeight : {5, 33, 100}) { + for (size_t inputWidth : {5, 32, 96}) { + for (size_t block : {1, 3, 5}) { + for (size_t stride : {1, 2}) { + for (size_t padding : {0, 1}) { + // init Test object + std::vector strides = {stride, stride}; + std::vector paddings = {padding, padding}; + std::vector blocks = {block, block}; + CpuGpuFuncCompare test("BlockExpandGrad", + FuncConfig() + .set("strides", strides) + .set("paddings", paddings) + .set("blocks", blocks)); + + size_t outputHeight = + 1 + + (inputHeight + 2 * padding - block + stride - 1) / stride; + size_t outputWidth = + 1 + + (inputWidth + 2 * padding - block + stride - 1) / stride; + TensorShape inputShape = + TensorShape({batchSize, channels, inputHeight, inputWidth}); + TensorShape outputShape = + TensorShape({batchSize, + outputHeight * outputWidth, + channels * block * block}); + test.addInputs(BufferArg(VALUE_TYPE_FLOAT, outputShape)); + test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, inputShape), + ADD_TO); + // run Function + test.run(); + } + } + } + } + } + } + } +} + +} // namespace paddle diff --git a/paddle/function/CMakeLists.txt b/paddle/function/CMakeLists.txt index 2bec00cdb2d32..93304f7303769 100644 --- a/paddle/function/CMakeLists.txt +++ b/paddle/function/CMakeLists.txt @@ -36,10 +36,12 @@ if(WITH_GPU) add_simple_unittest(MulOpTest) add_simple_unittest(CosSimOpTest) add_simple_unittest(RowConvOpTest) + add_simple_unittest(BlockExpandOpTest) add_simple_unittest(CropOpTest) endif() add_simple_unittest(ConvOpTest) +add_simple_unittest(Im2ColTest) endif() add_style_check_target(paddle_function ${h_files}) diff --git a/paddle/function/GemmConvOp.cpp b/paddle/function/GemmConvOp.cpp index 00880effc59cc..9deb2739fcfff 100644 --- a/paddle/function/GemmConvOp.cpp +++ b/paddle/function/GemmConvOp.cpp @@ -12,101 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "GemmConvOp.h" +#include "ConvOp.h" #include "GemmFunctor.h" +#include "Im2Col.h" #include "paddle/math/MemoryHandle.h" namespace paddle { -/* - * imData = [input_channels, input_height, input_width] - * colData = [input_channels, filter_height, filter_width, - * output_height, output_width] - */ -template -class Im2ColFunctor { -public: - void operator()(const T* imData, - int inputChannels, - int inputHeight, - int inputWidth, - int filterHeight, - int filterWidth, - int strideHeight, - int strideWidth, - int paddingHeight, - int paddingWidth, - int outputHeight, - int outputWidth, - T* colData) { - int channelsCol = inputChannels * filterHeight * filterWidth; - - for (int c = 0; c < channelsCol; ++c) { - int wOffset = c % filterWidth; - int hOffset = (c / filterWidth) % filterHeight; - int c_im = c / filterWidth / filterHeight; - for (int h = 0; h < outputHeight; ++h) { - for (int w = 0; w < outputWidth; ++w) { - int imRowIdx = h * strideHeight + hOffset; - int imColIdx = w * strideWidth + wOffset; - if ((imRowIdx - paddingHeight) < 0 || - (imRowIdx - paddingHeight) >= inputHeight || - (imColIdx - paddingWidth) < 0 || - (imColIdx - paddingWidth) >= inputWidth) { - colData[(c * outputHeight + h) * outputWidth + w] = T(0); - } else { - imRowIdx += c_im * inputHeight - paddingHeight; - imColIdx -= paddingWidth; - colData[(c * outputHeight + h) * outputWidth + w] = - imData[imRowIdx * inputWidth + imColIdx]; - } - } - } - } - } -}; - -template -class Col2ImFunctor { -public: - void operator()(const T* colData, - int inputChannels, - int inputHeight, - int inputWidth, - int filterHeight, - int filterWidth, - int strideHeight, - int strideWidth, - int paddingHeight, - int paddingWidth, - int outputHeight, - int outputWidth, - T* imData) { - int channelsCol = inputChannels * filterHeight * filterWidth; - - for (int c = 0; c < channelsCol; ++c) { - int wOffset = c % filterWidth; - int hOffset = (c / filterWidth) % filterHeight; - int c_im = c / filterWidth / filterHeight; - for (int h = 0; h < outputHeight; ++h) { - for (int w = 0; w < outputWidth; ++w) { - int imRowIdx = h * strideHeight + hOffset; - int imColIdx = w * strideWidth + wOffset; - if ((imRowIdx - paddingHeight) >= 0 && - (imRowIdx - paddingHeight) < inputHeight && - (imColIdx - paddingWidth) >= 0 && - (imColIdx - paddingWidth) < inputWidth) { - imRowIdx += c_im * inputHeight - paddingHeight; - imColIdx -= paddingWidth; - imData[imRowIdx * inputWidth + imColIdx] += - colData[(c * outputHeight + h) * outputWidth + w]; - } - } - } - } - } -}; - /* * \brief Forward calculation of convolution. */ @@ -154,15 +66,20 @@ class GemmConvFunction : public ConvFunctionBase { real* inputData = inputs[0].data(); real* filterData = inputs[1].data(); real* outputData = outputs[0].data(); - - size_t size = inputChannels / groups_ * filterHeight * filterWidth * - outputHeight * outputWidth; - resizeBuffer(size); + TensorShape imShape = + TensorShape({inputChannels / groups_, inputHeight, inputWidth}); + TensorShape colShape = TensorShape({inputChannels / groups_, + filterHeight, + filterWidth, + outputHeight, + outputWidth}); + + resizeBuffer(colShape.getElements()); real* colData = reinterpret_cast(memory_->getBuf()); - Im2ColFunctor im2col; + Im2ColFunctor im2col; GemmFunctor gemm; - size_t inputOffset = (inputChannels / groups_) * inputHeight * inputWidth; + size_t inputOffset = imShape.getElements(); size_t outputOffset = (outputChannels / groups_) * outputHeight * outputWidth; size_t filterOffset = filter.getElements() / groups_; @@ -170,18 +87,13 @@ class GemmConvFunction : public ConvFunctionBase { for (size_t i = 0; i < batchSize; i++) { for (size_t g = 0; g < groups_; g++) { im2col(inputData + g * inputOffset, - inputChannels / groups_, - inputHeight, - inputWidth, - filterHeight, - filterWidth, + imShape, + colData, + colShape, strideH(), strideW(), paddingH(), - paddingW(), - outputHeight, - outputWidth, - colData); + paddingW()); int M = outputChannels / groups_; int N = outputHeight * outputWidth; @@ -247,15 +159,20 @@ class GemmConvGradInputFunction : public ConvFunctionBase { real* outputGrad = inputs[0].data(); real* filterData = inputs[1].data(); real* inputGrad = outputs[0].data(); - - size_t size = inputChannels / groups_ * filterHeight * filterWidth * - outputHeight * outputWidth; - resizeBuffer(size); + TensorShape imShape = + TensorShape({inputChannels / groups_, inputHeight, inputWidth}); + TensorShape colShape = TensorShape({inputChannels / groups_, + filterHeight, + filterWidth, + outputHeight, + outputWidth}); + + resizeBuffer(colShape.getElements()); real* colData = reinterpret_cast(memory_->getBuf()); - Col2ImFunctor col2im; + Col2ImFunctor col2im; GemmFunctor gemm; - size_t inputOffset = (inputChannels / groups_) * inputHeight * inputWidth; + size_t inputOffset = imShape.getElements(); size_t outputOffset = (outputChannels / groups_) * outputHeight * outputWidth; size_t filterOffset = filter.getElements() / groups_; @@ -278,20 +195,14 @@ class GemmConvGradInputFunction : public ConvFunctionBase { 0.0f, colData, N); - - col2im(colData, - inputChannels / groups_, - inputHeight, - inputWidth, - filterHeight, - filterWidth, + col2im(inputGrad + g * inputOffset, + imShape, + colData, + colShape, strideH(), strideW(), paddingH(), - paddingW(), - outputHeight, - outputWidth, - inputGrad + g * inputOffset); + paddingW()); } inputGrad += inputChannels * inputHeight * inputWidth; outputGrad += outputChannels * outputHeight * outputWidth; @@ -344,33 +255,33 @@ class GemmConvGradFilterFunction : public ConvFunctionBase { real* outputGrad = inputs[0].data(); real* inputData = inputs[1].data(); real* filterGrad = outputs[0].data(); - - size_t size = inputChannels / groups_ * filterHeight * filterWidth * - outputHeight * outputWidth; - resizeBuffer(size); + TensorShape imShape = + TensorShape({inputChannels / groups_, inputHeight, inputWidth}); + TensorShape colShape = TensorShape({inputChannels / groups_, + filterHeight, + filterWidth, + outputHeight, + outputWidth}); + + resizeBuffer(colShape.getElements()); real* colData = reinterpret_cast(memory_->getBuf()); - Im2ColFunctor im2col; + Im2ColFunctor im2col; GemmFunctor gemm; - size_t inputOffset = (inputChannels / groups_) * inputHeight * inputWidth; + size_t inputOffset = imShape.getElements(); size_t outputOffset = (outputChannels / groups_) * outputHeight * outputWidth; size_t filterOffset = filter.getElements() / groups_; for (size_t i = 0; i < batchSize; i++) { for (size_t g = 0; g < groups_; g++) { im2col(inputData + g * inputOffset, - inputChannels / groups_, - inputHeight, - inputWidth, - filterHeight, - filterWidth, + imShape, + colData, + colShape, strideH(), strideW(), paddingH(), - paddingW(), - outputHeight, - outputWidth, - colData); + paddingW()); int M = outputChannels / groups_; int K = outputHeight * outputWidth; diff --git a/paddle/function/GemmConvOp.h b/paddle/function/GemmConvOp.h deleted file mode 100644 index 9f11cce597a07..0000000000000 --- a/paddle/function/GemmConvOp.h +++ /dev/null @@ -1,62 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "ConvOp.h" - -namespace paddle { - -/* - * imData = [input_channels, input_height, input_width] - * colData = [input_channels, filter_height, filter_width, - * output_height, output_width] - */ -template -class Im2ColFunctor { -public: - void operator()(const T* imData, - int inputChannels, - int inputHeight, - int inputWidth, - int filterHeight, - int filterWidth, - int strideHeight, - int strideWidth, - int paddingHeight, - int paddingWidth, - int outputHeight, - int outputWidth, - T* colData); -}; - -template -class Col2ImFunctor { -public: - void operator()(const T* colData, - int inputChannels, - int inputHeight, - int inputWidth, - int filterHeight, - int filterWidth, - int strideHeight, - int strideWidth, - int paddingHeight, - int paddingWidth, - int outputHeight, - int outputWidth, - T* imData); -}; - -} // namespace paddle diff --git a/paddle/function/GemmConvOpGpu.cu b/paddle/function/GemmConvOpGpu.cu deleted file mode 100644 index 2a1795ff0fb56..0000000000000 --- a/paddle/function/GemmConvOpGpu.cu +++ /dev/null @@ -1,186 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "ConvOp.h" -#include "GemmConvOp.h" - -namespace paddle { - -template -__global__ -void im2col(const T* data_im, int numOuts, int height, int width, - int blockH, int blockW, - int strideH, int strideW, - int paddingH, int paddingW, - int height_col, int width_col, - T* data_col) { - int index = - (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; - if (index < numOuts) { - int w_out = index % width_col; - index /= width_col; - int h_out = index % height_col; - int channel_in = index / height_col; - int channel_out = channel_in * blockH * blockW; - int h_in = h_out * strideH; - int w_in = w_out * strideW; - - data_col += (channel_out * height_col + h_out) * width_col + w_out; - for (int i = 0; i < blockH; ++i) { - for (int j = 0; j < blockW; ++j) { - int rIdx = int(h_in+i); - int cIdx = int(w_in+j); - if ((rIdx-(int)paddingH) >= (int)height || - (rIdx-(int)paddingH) < 0 || - (cIdx-(int)paddingW) >= (int)width || - (cIdx-(int)paddingW) < 0) { - *data_col = 0; - } else { - rIdx = rIdx + channel_in*height - paddingH; - cIdx = cIdx - paddingW; - *data_col = data_im[rIdx* width + cIdx]; - } - data_col += height_col * width_col; - } - } - } -} - -template -class Im2ColFunctor { -public: - void operator()(const T* imData, - int inputChannels, - int inputHeight, - int inputWidth, - int filterHeight, - int filterWidth, - int strideHeight, - int strideWidth, - int paddingHeight, - int paddingWidth, - int outputHeight, - int outputWidth, - T* colData) { - int numKernels = inputChannels * outputHeight * outputWidth; - int blocks = (numKernels + 1024 -1) / 1024; - int blockX = 512; - int blockY = (blocks + 512 - 1) / 512; - dim3 threads(1024, 1); - dim3 grid(blockX, blockY); - im2col<<< grid, threads, 0, STREAM_DEFAULT >>> - (imData, numKernels, inputHeight, inputWidth, filterHeight, filterWidth, - strideHeight, strideWidth, paddingHeight, paddingWidth, - outputHeight, outputWidth, colData); - CHECK_SYNC("Im2ColFunctor GPU failed"); - } -}; - -template -__global__ -void col2im(size_t n, const T* data_col, size_t height, - size_t width, size_t channels, - size_t blockH, size_t blockW, - size_t strideH, size_t strideW, - size_t paddingH, size_t paddingW, - size_t height_col, size_t width_col, - T* data_im) { - size_t index = - (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; - if (index < n) { - T val = 0; - int w = int(index % width); - int h = int((index / width) % height); - int c = int(index / (width * height)); - if ((w - (int)paddingW) >= 0 && - (w - (int)paddingW) < (width-2 * paddingW) && - (h - (int)paddingH) >= 0 && - (h - paddingH) < (height - 2 * paddingH)) { - // compute the start and end of the output - int w_col_start = - (w < (int)blockW) ? 0 : (w - int(blockW)) / (int)strideW + 1; - int w_col_end = - min((int)(w / (int)strideW + 1), (int)(width_col)); - int h_col_start = - (h < (int)blockH) ? 0 : (h - (int)blockH) / (int)strideH + 1; - int h_col_end = min(int(h / strideH + 1), int(height_col)); - for (int h_col = h_col_start; h_col < h_col_end; ++h_col) { - for (int w_col = w_col_start; w_col < w_col_end; ++w_col) { - // the col location: [c * width * height + h_out, w_out] - int c_col = int(c * blockH* blockW) + \ - (h - h_col * (int)strideH) * (int)blockW + - (w - w_col * (int)strideW); - val += data_col[(c_col * height_col + h_col) * width_col + w_col]; - } - } - h -= paddingH; - w -= paddingW; - data_im[c*((width-2*paddingW) * (height-2*paddingH)) + - h*(width-2*paddingW) + w] += val; - } - } -} - -template -class Col2ImFunctor { -public: - void operator()(const T* colData, - int inputChannels, - int inputHeight, - int inputWidth, - int filterHeight, - int filterWidth, - int strideHeight, - int strideWidth, - int paddingHeight, - int paddingWidth, - int outputHeight, - int outputWidth, - T* imData) { - size_t numKernels = inputChannels * (inputHeight + 2*paddingHeight) - * (inputWidth + 2*paddingWidth); - - size_t blocks = (numKernels + 1024 -1) / 1024; - size_t blockX = 512; - size_t blockY = (blocks+512-1)/512; - dim3 threads(1024, 1); - dim3 grid(blockX, blockY); - - // To avoid involving atomic operations, we will launch one kernel per - // bottom dimension, and then in the kernel add up the top dimensions. - col2im<<< grid, threads, 0, STREAM_DEFAULT >>> - (numKernels, - colData, - inputHeight + 2*paddingHeight, - inputWidth + 2*paddingWidth, - inputChannels, - filterHeight, - filterWidth, - strideHeight, - strideWidth, - paddingHeight, - paddingWidth, - outputHeight, - outputWidth, - imData); - CHECK_SYNC("Col2ImFunctor GPU failed"); - } -}; - -template class Im2ColFunctor; -template class Im2ColFunctor; -template class Col2ImFunctor; -template class Col2ImFunctor; - -} // namespace paddle diff --git a/paddle/function/Im2Col.h b/paddle/function/Im2Col.h new file mode 100644 index 0000000000000..48e2e32f9256f --- /dev/null +++ b/paddle/function/Im2Col.h @@ -0,0 +1,96 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "TensorShape.h" +#include "TensorType.h" + +namespace paddle { + +/* The storage format of the coldata in the Im2ColFunctor and Col2ImFunctor. */ +enum ColFormat { kCFO = 0, kOCF = 1 }; + +/* + * \brief Converts the image data of three dimensions(CHW) into a colData of + * five dimensions in the Im2ColFunctor calculation, + * And in the Col2ImFunctor calculation, it is reversed. + * + * \param imData Image data. + * \param imShape The shape of imData, + * [inputChannels, inputHeight, inputWidth]. + * \param colData Column data. + * \param colShape The shape of colData. + * + * If the template argument Format is kCFO, the shape of colData is: + * [inputChannels, filterHeight, filterWidth, outputHeight, outputWidth] + * So, it is easy to reshape into a convolution matrix for convolution + * calculation based on matrix multiplication. + * The shape of convolution matrix is [height, width], where the height is equal + * inputChannels * filterHeight * filterWidth, and the width is equal + * outputHeight * outputWidth. + * + * Reshape: + * shape of colData shape of convolution matrix + * [inputChannels, + * filterHeight, + * filterWidth, ======> [height, width] + * outputHeight, + * outputWidth] + * + * If the template argument Format is kOCF, the shape of colData is: + * [outputHeight, outputWidth, inputChannels, filterHeight, filterWidth] + * So, it is easy to reshape into a sequence matrix for rnn calculation. + * The shape of sequence matrix is [seqLength, stepSize], where the seqLength + * is equal outputHeight * outputWidth, and the stepSize is equal + * inputChannels * filterHeight * filterWidth. + * + * Reshape: + * shape of colData shape of sequence matrix + * [outputHeight, + * outputWidth, + * inputChannels, ======> [seqLength, stepSize] + * filterHeight, + * filterWidth] + * + * \note The caller needs to ensure that imShape.inputChannels is equal to + * colShape.inputChannels. + */ +template +class Im2ColFunctor { +public: + void operator()(const T* imData, + const TensorShape& imShape, + T* colData, + const TensorShape& colShape, + int strideHeight, + int strideWidth, + int paddingHeight, + int paddingWidth); +}; + +template +class Col2ImFunctor { +public: + void operator()(T* imData, + const TensorShape& imShape, + const T* colData, + const TensorShape& colShape, + int strideHeight, + int strideWidth, + int paddingHeight, + int paddingWidth); +}; + +} // namespace paddle diff --git a/paddle/function/Im2ColOp.cpp b/paddle/function/Im2ColOp.cpp new file mode 100644 index 0000000000000..b7d1eb1eded7a --- /dev/null +++ b/paddle/function/Im2ColOp.cpp @@ -0,0 +1,235 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "Im2Col.h" + +namespace paddle { + +/* + * imShape = [inputChannels, inputHeight, inputWidth] + * colShape = + * [inputChannels, filterHeight, filterWidth, outputHeight, outputWidth] + */ +template +class Im2ColFunctor { +public: + void operator()(const T* imData, + const TensorShape& imShape, + T* colData, + const TensorShape& colShape, + int strideHeight, + int strideWidth, + int paddingHeight, + int paddingWidth) { + int inputChannels = imShape[0]; + int inputHeight = imShape[1]; + int inputWidth = imShape[2]; + int filterHeight = colShape[1]; + int filterWidth = colShape[2]; + int outputHeight = colShape[3]; + int outputWidth = colShape[4]; + int channelsCol = inputChannels * filterHeight * filterWidth; + + for (int c = 0; c < channelsCol; ++c) { + int wOffset = c % filterWidth; + int hOffset = (c / filterWidth) % filterHeight; + int c_im = c / filterWidth / filterHeight; + for (int h = 0; h < outputHeight; ++h) { + for (int w = 0; w < outputWidth; ++w) { + int imRowIdx = h * strideHeight + hOffset; + int imColIdx = w * strideWidth + wOffset; + if ((imRowIdx - paddingHeight) < 0 || + (imRowIdx - paddingHeight) >= inputHeight || + (imColIdx - paddingWidth) < 0 || + (imColIdx - paddingWidth) >= inputWidth) { + colData[(c * outputHeight + h) * outputWidth + w] = T(0); + } else { + imRowIdx += c_im * inputHeight - paddingHeight; + imColIdx -= paddingWidth; + colData[(c * outputHeight + h) * outputWidth + w] = + imData[imRowIdx * inputWidth + imColIdx]; + } + } + } + } + } +}; + +/* + * imShape = [inputChannels, inputHeight, inputWidth] + * colShape = + * [inputChannels, filterHeight, filterWidth, outputHeight, outputWidth] + */ +template +class Col2ImFunctor { +public: + void operator()(T* imData, + const TensorShape& imShape, + const T* colData, + const TensorShape& colShape, + int strideHeight, + int strideWidth, + int paddingHeight, + int paddingWidth) { + int inputChannels = imShape[0]; + int inputHeight = imShape[1]; + int inputWidth = imShape[2]; + int filterHeight = colShape[1]; + int filterWidth = colShape[2]; + int outputHeight = colShape[3]; + int outputWidth = colShape[4]; + int channelsCol = inputChannels * filterHeight * filterWidth; + + for (int c = 0; c < channelsCol; ++c) { + int wOffset = c % filterWidth; + int hOffset = (c / filterWidth) % filterHeight; + int c_im = c / filterWidth / filterHeight; + for (int h = 0; h < outputHeight; ++h) { + for (int w = 0; w < outputWidth; ++w) { + int imRowIdx = h * strideHeight + hOffset; + int imColIdx = w * strideWidth + wOffset; + if ((imRowIdx - paddingHeight) >= 0 && + (imRowIdx - paddingHeight) < inputHeight && + (imColIdx - paddingWidth) >= 0 && + (imColIdx - paddingWidth) < inputWidth) { + imRowIdx += c_im * inputHeight - paddingHeight; + imColIdx -= paddingWidth; + imData[imRowIdx * inputWidth + imColIdx] += + colData[(c * outputHeight + h) * outputWidth + w]; + } + } + } + } + } +}; + +template class Im2ColFunctor; +template class Im2ColFunctor; +template class Col2ImFunctor; +template class Col2ImFunctor; + +/* + * imShape = [inputChannels, inputHeight, inputWidth] + * colShape = + * [outputHeight, outputWidth, inputChannels, filterHeight, filterWidth] + */ +template +class Im2ColFunctor { +public: + void operator()(const T* imData, + const TensorShape& imShape, + T* colData, + const TensorShape& colShape, + int strideHeight, + int strideWidth, + int paddingHeight, + int paddingWidth) { + int inputChannels = imShape[0]; + int inputHeight = imShape[1]; + int inputWidth = imShape[2]; + int filterHeight = colShape[3]; + int filterWidth = colShape[4]; + int outputHeight = colShape[0]; + int outputWidth = colShape[1]; + for (int outputH = 0; outputH < outputHeight; ++outputH) { + for (int outputW = 0; outputW < outputWidth; ++outputW) { + for (int channel = 0; channel < inputChannels; ++channel) { + for (int filterH = 0; filterH < filterHeight; ++filterH) { + for (int filterW = 0; filterW < filterWidth; ++filterW) { + int imRowOffset = + outputH * strideHeight + filterH - paddingHeight; + int imColOffset = outputW * strideWidth + filterW - paddingWidth; + int colDataOffset = + (((outputH * outputWidth + outputW) * inputChannels + + channel) * + filterHeight + + filterH) * + filterWidth + + filterW; + if (imRowOffset < 0 || imRowOffset >= inputHeight || + imColOffset < 0 || imColOffset >= inputWidth) { + colData[colDataOffset] = float(0); + } else { + int imDataOffset = + (channel * inputHeight + imRowOffset) * inputWidth + + imColOffset; + colData[colDataOffset] = imData[imDataOffset]; + } + } + } + } + } + } + } +}; + +/* + * imShape = [inputChannels, inputHeight, inputWidth] + * colShape = + * [outputHeight, outputWidth, inputChannels, filterHeight, filterWidth] + */ +template +class Col2ImFunctor { +public: + void operator()(T* imData, + const TensorShape& imShape, + const T* colData, + const TensorShape& colShape, + int strideHeight, + int strideWidth, + int paddingHeight, + int paddingWidth) { + int inputChannels = imShape[0]; + int inputHeight = imShape[1]; + int inputWidth = imShape[2]; + int filterHeight = colShape[3]; + int filterWidth = colShape[4]; + int outputHeight = colShape[0]; + int outputWidth = colShape[1]; + for (int outputH = 0; outputH < outputHeight; ++outputH) { + for (int outputW = 0; outputW < outputWidth; ++outputW) { + for (int channel = 0; channel < inputChannels; ++channel) { + for (int filterH = 0; filterH < filterHeight; ++filterH) { + for (int filterW = 0; filterW < filterWidth; ++filterW) { + int imRowOffset = + outputH * strideHeight + filterH - paddingHeight; + int imColOffset = outputW * strideWidth + filterW - paddingWidth; + int colDataOffset = + (((outputH * outputWidth + outputW) * inputChannels + + channel) * + filterHeight + + filterH) * + filterWidth + + filterW; + if (imRowOffset >= 0 && imRowOffset < inputHeight && + imColOffset >= 0 && imColOffset < inputWidth) { + int imDataOffset = + (channel * inputHeight + imRowOffset) * inputWidth + + imColOffset; + imData[imDataOffset] += colData[colDataOffset]; + } + } + } + } + } + } + } +}; + +template class Im2ColFunctor; +template class Im2ColFunctor; +template class Col2ImFunctor; +template class Col2ImFunctor; + +} // namespace paddle diff --git a/paddle/function/Im2ColOpGpu.cu b/paddle/function/Im2ColOpGpu.cu new file mode 100644 index 0000000000000..15ba854009636 --- /dev/null +++ b/paddle/function/Im2ColOpGpu.cu @@ -0,0 +1,381 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "Im2Col.h" +#include "hl_device_functions.cuh" + +namespace paddle { + +template +__global__ +void im2col(const T* data_im, int numOuts, int height, int width, + int blockH, int blockW, + int strideH, int strideW, + int paddingH, int paddingW, + int height_col, int width_col, + T* data_col) { + int index = + (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; + if (index < numOuts) { + int w_out = index % width_col; + index /= width_col; + int h_out = index % height_col; + int channel_in = index / height_col; + int channel_out = channel_in * blockH * blockW; + int h_in = h_out * strideH; + int w_in = w_out * strideW; + + data_col += (channel_out * height_col + h_out) * width_col + w_out; + for (int i = 0; i < blockH; ++i) { + for (int j = 0; j < blockW; ++j) { + int rIdx = int(h_in+i); + int cIdx = int(w_in+j); + if ((rIdx-(int)paddingH) >= (int)height || + (rIdx-(int)paddingH) < 0 || + (cIdx-(int)paddingW) >= (int)width || + (cIdx-(int)paddingW) < 0) { + *data_col = 0; + } else { + rIdx = rIdx + channel_in*height - paddingH; + cIdx = cIdx - paddingW; + *data_col = data_im[rIdx* width + cIdx]; + } + data_col += height_col * width_col; + } + } + } +} + +/* + * imShape = [inputChannels, inputHeight, inputWidth] + * colShape = + * [inputChannels, filterHeight, filterWidth, outputHeight, outputWidth] + */ +template +class Im2ColFunctor { +public: + void operator()(const T* imData, + const TensorShape& imShape, + T* colData, + const TensorShape& colShape, + int strideHeight, + int strideWidth, + int paddingHeight, + int paddingWidth) { + int inputChannels = imShape[0]; + int inputHeight = imShape[1]; + int inputWidth = imShape[2]; + int filterHeight = colShape[1]; + int filterWidth = colShape[2]; + int outputHeight = colShape[3]; + int outputWidth = colShape[4]; + + int numKernels = inputChannels * outputHeight * outputWidth; + int blocks = (numKernels + 1024 -1) / 1024; + int blockX = 512; + int blockY = (blocks + 512 - 1) / 512; + dim3 threads(1024, 1); + dim3 grid(blockX, blockY); + im2col<<< grid, threads, 0, STREAM_DEFAULT >>> + (imData, numKernels, inputHeight, inputWidth, filterHeight, filterWidth, + strideHeight, strideWidth, paddingHeight, paddingWidth, + outputHeight, outputWidth, colData); + CHECK_SYNC("Im2ColFunctor GPU failed"); + } +}; + +template +__global__ +void col2im(size_t n, const T* data_col, size_t height, + size_t width, size_t channels, + size_t blockH, size_t blockW, + size_t strideH, size_t strideW, + size_t paddingH, size_t paddingW, + size_t height_col, size_t width_col, + T* data_im) { + size_t index = + (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; + if (index < n) { + T val = 0; + int w = int(index % width); + int h = int((index / width) % height); + int c = int(index / (width * height)); + if ((w - (int)paddingW) >= 0 && + (w - (int)paddingW) < (width-2 * paddingW) && + (h - (int)paddingH) >= 0 && + (h - paddingH) < (height - 2 * paddingH)) { + // compute the start and end of the output + int w_col_start = + (w < (int)blockW) ? 0 : (w - int(blockW)) / (int)strideW + 1; + int w_col_end = + min((int)(w / (int)strideW + 1), (int)(width_col)); + int h_col_start = + (h < (int)blockH) ? 0 : (h - (int)blockH) / (int)strideH + 1; + int h_col_end = min(int(h / strideH + 1), int(height_col)); + for (int h_col = h_col_start; h_col < h_col_end; ++h_col) { + for (int w_col = w_col_start; w_col < w_col_end; ++w_col) { + // the col location: [c * width * height + h_out, w_out] + int c_col = int(c * blockH* blockW) + \ + (h - h_col * (int)strideH) * (int)blockW + + (w - w_col * (int)strideW); + val += data_col[(c_col * height_col + h_col) * width_col + w_col]; + } + } + h -= paddingH; + w -= paddingW; + data_im[c*((width-2*paddingW) * (height-2*paddingH)) + + h*(width-2*paddingW) + w] += val; + } + } +} + +/* + * imShape = [inputChannels, inputHeight, inputWidth] + * colShape = + * [inputChannels, filterHeight, filterWidth, outputHeight, outputWidth] + */ +template +class Col2ImFunctor { +public: + void operator()(T* imData, + const TensorShape& imShape, + const T* colData, + const TensorShape& colShape, + int strideHeight, + int strideWidth, + int paddingHeight, + int paddingWidth) { + int inputChannels = imShape[0]; + int inputHeight = imShape[1]; + int inputWidth = imShape[2]; + int filterHeight = colShape[1]; + int filterWidth = colShape[2]; + int outputHeight = colShape[3]; + int outputWidth = colShape[4]; + + size_t numKernels = inputChannels * (inputHeight + 2*paddingHeight) + * (inputWidth + 2*paddingWidth); + + size_t blocks = (numKernels + 1024 -1) / 1024; + size_t blockX = 512; + size_t blockY = (blocks+512-1)/512; + dim3 threads(1024, 1); + dim3 grid(blockX, blockY); + + // To avoid involving atomic operations, we will launch one kernel per + // bottom dimension, and then in the kernel add up the top dimensions. + col2im<<< grid, threads, 0, STREAM_DEFAULT >>> + (numKernels, + colData, + inputHeight + 2*paddingHeight, + inputWidth + 2*paddingWidth, + inputChannels, + filterHeight, + filterWidth, + strideHeight, + strideWidth, + paddingHeight, + paddingWidth, + outputHeight, + outputWidth, + imData); + CHECK_SYNC("Col2ImFunctor GPU failed"); + } +}; + +template class Im2ColFunctor; +template class Im2ColFunctor; +template class Col2ImFunctor; +template class Col2ImFunctor; + +template +__global__ +void im2colOCF(const T* imData, T* colData, + int inputChannels, + int inputHeight, int inputWidth, + int filterHeight, int filterWidth, + int strideHeight, int strideWidth, + int paddingHeight, int paddingWidth, + int outputHeight, int outputWidth) { + int swId = blockIdx.x; + int shId = blockIdx.y; + for (int channelId = threadIdx.z; + channelId < inputChannels; + channelId += blockDim.z) { + for (int idy = threadIdx.y; idy < filterHeight; idy += blockDim.y) { + for (int idx = threadIdx.x; idx < filterWidth; idx += blockDim.x) { + int widthOffset = idx + swId * strideWidth - paddingWidth; + int heightOffset = idy + shId * strideHeight - paddingHeight; + int imOffset = widthOffset + heightOffset * inputWidth + + channelId * inputHeight * inputWidth; + + int colOffset = idx + idy * filterWidth + + channelId * filterHeight * filterWidth + + (shId * outputWidth + swId) + * (inputChannels * filterHeight * filterWidth); + + if (heightOffset >= inputHeight || heightOffset < 0 || + widthOffset >= inputWidth || widthOffset < 0) { + colData[colOffset] = T(0); + } else { + colData[colOffset] = imData[imOffset]; + } + } + } + } +} + +/* + * imShape = [inputChannels, inputHeight, inputWidth] + * colShape = + * [outputHeight, outputWidth, inputChannels, filterHeight, filterWidth] + */ +template +class Im2ColFunctor { +public: + void operator()(const T* imData, + const TensorShape& imShape, + T* colData, + const TensorShape& colShape, + int strideHeight, + int strideWidth, + int paddingHeight, + int paddingWidth) { + int inputChannels = imShape[0]; + int inputHeight = imShape[1]; + int inputWidth = imShape[2]; + int filterHeight = colShape[3]; + int filterWidth = colShape[4]; + int outputHeight = colShape[0]; + int outputWidth = colShape[1]; + + int blockDimX = 0; + int blockDimY = 0; + if (filterHeight <= 4 && filterWidth <= 4) { + blockDimX = 4; + blockDimY = 4; + } else if (filterHeight <= 8 && filterWidth <= 8) { + blockDimX = 8; + blockDimY = 8; + } else if (filterHeight <= 16 && filterWidth <= 16) { + blockDimX = 16; + blockDimY = 16; + } else { + blockDimX = 32; + blockDimY = 32; + } + + int blockDimZ = 1024 / blockDimX / blockDimY; + dim3 threads(blockDimX, blockDimY, std::min(blockDimZ, inputChannels)); + dim3 grid(outputWidth, outputHeight); + im2colOCF<<< grid, threads, 0, STREAM_DEFAULT >>> + (imData, colData, inputChannels, inputHeight, inputWidth, + filterHeight, filterWidth, strideHeight, strideWidth, + paddingHeight, paddingWidth, outputHeight, outputWidth); + CHECK_SYNC("Im2ColFunctor GPU failed"); + } +}; + +template +__global__ +void col2imOCF(T* imData, const T* colData, + int inputChannels, + int inputHeight, int inputWidth, + int filterHeight, int filterWidth, + int strideHeight, int strideWidth, + int paddingHeight, int paddingWidth, + int outputHeight, int outputWidth) { + int swId = blockIdx.x; + int shId = blockIdx.y; + for (int channelId = threadIdx.z; + channelId < inputChannels; + channelId += blockDim.z) { + for (int idy = threadIdx.y; idy < filterHeight; idy += blockDim.y) { + for (int idx = threadIdx.x; idx < filterWidth; idx += blockDim.x) { + int widthOffset = idx + swId * strideWidth - paddingWidth; + int heightOffset = idy + shId * strideHeight - paddingHeight; + int imOffset = widthOffset + heightOffset * inputWidth + + channelId * inputHeight * inputWidth; + + int colOffset = idx + idy * filterWidth + + channelId * filterHeight * filterWidth + + (shId * outputWidth + swId) + * (inputChannels * filterHeight * filterWidth); + + if (heightOffset >= 0 && heightOffset < inputHeight && + widthOffset >= 0 && widthOffset < inputWidth) { + paddle::paddleAtomicAdd(imData + imOffset, colData[colOffset]); + } + } + } + } +} + +/* + * imShape = [inputChannels, inputHeight, inputWidth] + * colShape = + * [outputHeight, outputWidth, inputChannels, filterHeight, filterWidth] + */ +template +class Col2ImFunctor { +public: + void operator()(T* imData, + const TensorShape& imShape, + const T* colData, + const TensorShape& colShape, + int strideHeight, + int strideWidth, + int paddingHeight, + int paddingWidth) { + int inputChannels = imShape[0]; + int inputHeight = imShape[1]; + int inputWidth = imShape[2]; + int filterHeight = colShape[3]; + int filterWidth = colShape[4]; + int outputHeight = colShape[0]; + int outputWidth = colShape[1]; + + int blockDimX = 0; + int blockDimY = 0; + if (filterHeight <= 4 && filterWidth <= 4) { + blockDimX = 4; + blockDimY = 4; + } else if (filterHeight <= 8 && filterWidth <= 8) { + blockDimX = 8; + blockDimY = 8; + } else if (filterHeight <= 16 && filterWidth <= 16) { + blockDimX = 16; + blockDimY = 16; + } else { + blockDimX = 32; + blockDimY = 32; + } + + int blockDimZ = 1024 / blockDimX / blockDimY; + dim3 threads(blockDimX, blockDimY, std::min(blockDimZ, inputChannels)); + dim3 grid(outputWidth, outputHeight); + col2imOCF<<< grid, threads, 0, STREAM_DEFAULT >>> + (imData, colData, inputChannels, inputHeight, inputWidth, + filterHeight, filterWidth, strideHeight, strideWidth, + paddingHeight, paddingWidth, outputHeight, outputWidth); + CHECK_SYNC("Col2ImFunctor GPU failed"); + } +}; + +template class Im2ColFunctor; +template class Im2ColFunctor; +template class Col2ImFunctor; +template class Col2ImFunctor; + +} // namespace paddle diff --git a/paddle/function/Im2ColTest.cpp b/paddle/function/Im2ColTest.cpp new file mode 100644 index 0000000000000..acc88a553abe7 --- /dev/null +++ b/paddle/function/Im2ColTest.cpp @@ -0,0 +1,125 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "Im2Col.h" +#include +#include "Function.h" +#include "paddle/math/Matrix.h" +#include "paddle/math/tests/TensorCheck.h" + +namespace paddle { + +template +void TestIm2ColFunctor() { + for (size_t channels : {1, 5, 32}) { + for (size_t inputHeight : {5, 33, 100}) { + for (size_t inputWidth : {5, 32, 96}) { + for (size_t filterHeight : {1, 5}) { + for (size_t filterWidth : {3, 7}) { + for (size_t stride : {1, 2}) { + for (size_t padding : {0, 1}) { + if (inputHeight <= filterHeight || inputWidth <= filterWidth) + break; + if (padding >= filterHeight || padding >= filterWidth) break; + size_t outputHeight = + (inputHeight - filterHeight + 2 * padding + stride) / + stride; + size_t outputWidth = + (inputWidth - filterWidth + 2 * padding + stride) / stride; + + TensorShape imShape = + TensorShape({channels, inputHeight, inputWidth}); + TensorShape colShape1 = TensorShape({channels, + filterHeight, + filterWidth, + outputHeight, + outputWidth}); + TensorShape colShape2 = TensorShape({outputHeight, + outputWidth, + channels, + filterHeight, + filterWidth}); + + size_t height = channels * filterHeight * filterWidth; + size_t width = outputHeight * outputWidth; + VectorPtr input1 = Vector::create(imShape.getElements(), false); + VectorPtr input2 = Vector::create(imShape.getElements(), false); + MatrixPtr output1 = Matrix::create(height, width, false, false); + MatrixPtr output2 = Matrix::create(width, height, false, false); + input1->uniform(0.001, 1); + input2->copyFrom(*input1); + + Im2ColFunctor im2Col1; + Im2ColFunctor im2Col2; + im2Col1(input1->getData(), + imShape, + output1->getData(), + colShape1, + stride, + stride, + padding, + padding); + im2Col2(input2->getData(), + imShape, + output2->getData(), + colShape2, + stride, + stride, + padding, + padding); + + // The transposition of the result of ColFormat == kCFO + // is equal to the result of ColFormat == kOCF. + MatrixPtr test; + output2->transpose(test, true); + autotest::TensorCheckErr(*output1, *test); + + Col2ImFunctor col2Im1; + Col2ImFunctor col2Im2; + col2Im1(input1->getData(), + imShape, + output1->getData(), + colShape1, + stride, + stride, + padding, + padding); + col2Im2(input2->getData(), + imShape, + output2->getData(), + colShape2, + stride, + stride, + padding, + padding); + + autotest::TensorCheckErr(*input1, *input2); + } + } + } + } + } + } + } +} + +TEST(Im2ColFunctor, CPU) { TestIm2ColFunctor(); } + +#ifndef PADDLE_ONLY_CPU + +TEST(Im2ColFunctor, GPU) { TestIm2ColFunctor(); } + +#endif + +} // namespace paddle diff --git a/paddle/gserver/layers/BlockExpandLayer.cpp b/paddle/gserver/layers/BlockExpandLayer.cpp index 2bafeb92158c5..3b1f346359172 100644 --- a/paddle/gserver/layers/BlockExpandLayer.cpp +++ b/paddle/gserver/layers/BlockExpandLayer.cpp @@ -37,6 +37,22 @@ bool BlockExpandLayer::init(const LayerMap& layerMap, imgSizeH_ = blockConf.img_size_y(); imgSizeW_ = blockConf.img_size_x(); + std::vector strides = {(size_t)strideH_, (size_t)strideW_}; + std::vector paddings = {(size_t)paddingH_, (size_t)paddingW_}; + std::vector blocks = {(size_t)blockH_, (size_t)blockW_}; + createFunction(forward_, + "BlockExpand", + FuncConfig() + .set("strides", strides) + .set("paddings", paddings) + .set("blocks", blocks)); + createFunction(backward_, + "BlockExpandGrad", + FuncConfig() + .set("strides", strides) + .set("paddings", paddings) + .set("blocks", blocks)); + return true; } @@ -63,48 +79,27 @@ void BlockExpandLayer::forward(PassType passType) { Layer::forward(passType); size_t batchSize = inputLayers_[0]->getOutputValue()->getHeight(); - size_t blockNum = getBlockNum(); size_t blockSize = blockH_ * blockW_ * channels_; resetOutput(blockNum * batchSize, blockSize); - Argument& out = getOutput(); - MatrixPtr outV = getOutputValue(); - MatrixPtr input = getPrev(0)->getOutputValue(); - Matrix::resizeOrCreate(outVTrans_, blockSize, blockNum, false, useGpu_); + // calculate output_.value + inputShape_ = TensorShape({batchSize, channels_, imgSizeH_, imgSizeW_}); + outputShape_ = TensorShape({batchSize, blockNum, blockSize}); + BufferArgs inputs; + BufferArgs outputs; + inputs.addArg(*getInputValue(0), inputShape_); + outputs.addArg(*getOutputValue(), outputShape_, ASSIGN_TO); + forward_[0]->calc(inputs, outputs); + + // calculate output_.sequenceStartPositions and output_.cpuSequenceDims + Argument& out = getOutput(); ICpuGpuVector::resizeOrCreate( out.sequenceStartPositions, batchSize + 1, false); IVector::resizeOrCreate(out.cpuSequenceDims, 2 * batchSize, false); int* start = out.sequenceStartPositions->getMutableData(false); int* dims = out.cpuSequenceDims->getData(); for (size_t i = 0; i < batchSize; i++) { - outVTrans_->zeroMem(); - /* expand each block as one row */ - MatrixPtr inputTmp = - Matrix::create(input->getData() + i * input->getWidth(), - 1, - input->getWidth(), - false, - useGpu_); - outVTrans_->convExpand(*inputTmp, - imgSizeH_, - imgSizeW_, - channels_, - blockH_, - blockW_, - strideH_, - strideW_, - paddingH_, - paddingW_, - outputH_, - outputW_); - MatrixPtr outVTmp = - Matrix::create(outV->getData() + i * blockNum * blockSize, - blockNum, - blockSize, - false, - useGpu_); - outVTrans_->transpose(outVTmp, false); start[i] = i * blockNum; dims[2 * i] = outputH_; dims[2 * i + 1] = outputW_; @@ -113,48 +108,13 @@ void BlockExpandLayer::forward(PassType passType) { } void BlockExpandLayer::backward(const UpdateCallback& callback) { - size_t blockNum = outputH_ * outputW_; - size_t blockSize = blockH_ * blockW_ * channels_; /* Calculate the input layers error */ - MatrixPtr preGrad = inputLayers_[0]->getOutputGrad(); - if (!preGrad) { - return; - } - MatrixPtr grad = getOutputGrad(); - MatrixPtr gradTrans = Matrix::create(blockSize, blockNum, false, useGpu_); - size_t batchSize = preGrad->getHeight(); - - CHECK_EQ(batchSize * blockNum, grad->getHeight()); - CHECK_EQ(blockSize, grad->getWidth()); - - for (size_t i = 0; i < batchSize; i++) { - MatrixPtr gradTmp = - Matrix::create(grad->getData() + i * blockNum * blockSize, - blockNum, - blockSize, - false, - useGpu_); - gradTmp->transpose(gradTrans, false); - MatrixPtr preGradTmp = - Matrix::create(preGrad->getData() + i * preGrad->getWidth(), - 1, - preGrad->getWidth(), - false, - useGpu_); - preGradTmp->convShrink(*gradTrans, - imgSizeH_, - imgSizeW_, - channels_, - blockH_, - blockW_, - strideH_, - strideW_, - paddingH_, - paddingW_, - outputH_, - outputW_, - 1.0, - 1.0); + if (getInputGrad(0)) { + BufferArgs inputs; + BufferArgs outputs; + inputs.addArg(*getOutputGrad(), outputShape_); + outputs.addArg(*getInputGrad(0), inputShape_, ADD_TO); + backward_[0]->calc(inputs, outputs); } } diff --git a/paddle/gserver/layers/BlockExpandLayer.h b/paddle/gserver/layers/BlockExpandLayer.h index 8f347400e60ec..15ce73ab8b2ca 100644 --- a/paddle/gserver/layers/BlockExpandLayer.h +++ b/paddle/gserver/layers/BlockExpandLayer.h @@ -50,8 +50,8 @@ class BlockExpandLayer : public Layer { size_t blockH_, blockW_, strideH_, strideW_, paddingH_, paddingW_; size_t imgSizeH_, imgSizeW_, outputH_, outputW_, channels_; - /// auxiliary variable, which saves the transposed output value. - MatrixPtr outVTrans_; + TensorShape inputShape_; + TensorShape outputShape_; public: explicit BlockExpandLayer(const LayerConfig& config) : Layer(config) {} diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp index 4431d613f655c..27f7d95b752d4 100644 --- a/paddle/math/Matrix.cpp +++ b/paddle/math/Matrix.cpp @@ -1016,81 +1016,6 @@ void GpuMatrix::check(std::ostream& os, Matrix& refMat, bool printDiff) { LOG(INFO) << "the diffCnt is " << diffCnt; } -void GpuMatrix::convExpand(Matrix& feature, - int feaImgHeight, - int feaImgWidth, - int channels, - int blockH, - int blockW, - int strideH, - int strideW, - int paddingH, - int paddingW, - int outputH, - int outputW) { - CHECK(feature.useGpu_ == true) << "Matrix type are not equal"; - - CHECK_EQ(size_t(feaImgHeight * feaImgWidth * channels), - feature.getHeight() * feature.getWidth()) - << "Matrix dimensions are not equal"; - - size_t elemCnt = outputH * outputW * blockH * blockW * channels; - CHECK_EQ(elemCnt, height_ * width_) << "Matrix dimensions are not equal"; - - hl_expand_feature2col(feature.getData(), - channels, - feaImgHeight, - feaImgWidth, - blockH, - blockW, - strideH, - strideW, - paddingH, - paddingW, - outputH, - outputW, - getData()); -} - -void GpuMatrix::convShrink(Matrix& expandFeat, - int thisImgHeight, - int thisImgWidth, - int channels, - int blockH, - int blockW, - int strideH, - int strideW, - int paddingH, - int paddingW, - int outputH, - int outputW, - real alpha, - real beta) { - CHECK(expandFeat.useGpu_ == true) << "Matrix type are not equal"; - CHECK_EQ(size_t(thisImgHeight * thisImgWidth * channels), - getHeight() * getWidth()) - << "Matrix dimensions are not equal"; - - size_t elemCnt = outputH * outputW * blockW * blockH * channels; - CHECK(elemCnt == expandFeat.getHeight() * expandFeat.getWidth()) - << "Matrix dimensions are not equal"; - hl_shrink_col2feature(expandFeat.getData(), - channels, - thisImgHeight, - thisImgWidth, - blockH, - blockW, - strideH, - strideW, - paddingH, - paddingW, - outputH, - outputW, - getData(), - alpha, - beta); -} - void GpuMatrix::maxPoolForward(Matrix& inputMat, size_t imgSizeH, size_t imgSizeW, @@ -1777,103 +1702,6 @@ void CpuMatrix::inverse(MatrixPtr& matInv, bool memAlloc) { CHECK_EQ(info, 0); } -void CpuMatrix::convExpand(Matrix& feature, - int feaImgHeight, - int feaImgWidth, - int channels, - int blockH, - int blockW, - int strideH, - int strideW, - int paddingH, - int paddingW, - int outputH, - int outputW) { - CHECK(feature.useGpu_ == false) << "Matrix type are not equal"; - - CHECK_EQ(size_t(feaImgHeight * feaImgWidth * channels), - feature.getHeight() * feature.getWidth()) - << "Matrix dimensions are not equal"; - - size_t elemCnt = outputH * outputW * blockH * blockW * channels; - CHECK_EQ(elemCnt, height_ * width_) << "Matrix dimensions are not equal"; - - int channelsCol = channels * blockH * blockW; - real* srcData = feature.getData(); - for (int c = 0; c < channelsCol; ++c) { - int wOffset = c % blockW; - int hOffset = (c / blockW) % blockH; - int c_im = c / blockH / blockW; - for (int h = 0; h < outputH; ++h) { - for (int w = 0; w < outputW; ++w) { - // no c_im*height to Exclude the channel number - int imgRowIdx = h * strideH + hOffset; - int imgColIdx = w * strideW + wOffset; - if ((imgRowIdx - paddingH) < 0 || - (imgRowIdx - paddingH) >= feaImgHeight || - (imgColIdx - paddingW) < 0 || - (imgColIdx - paddingW) >= feaImgWidth) { - data_[(c * outputH + h) * outputW + w] = 0; - } else { - imgRowIdx += c_im * feaImgHeight - paddingH; - imgColIdx -= paddingW; - data_[(c * outputH + h) * outputW + w] = - srcData[imgRowIdx * feaImgWidth + imgColIdx]; - } - } - } - } -} - -void CpuMatrix::convShrink(Matrix& expandFeat, - int thisImgHeight, - int thisImgWidth, - int channels, - int blockH, - int blockW, - int strideH, - int strideW, - int paddingH, - int paddingW, - int outputH, - int outputW, - real alpha, - real beta) { - CHECK(expandFeat.useGpu_ == false) << "Matrix type are not equal"; - CHECK_EQ(size_t(thisImgHeight * thisImgWidth * channels), - getHeight() * getWidth()) - << "Matrix dimensions are not equal"; - - size_t elemCnt = outputH * outputW * blockH * blockW * channels; - - CHECK(elemCnt == expandFeat.getHeight() * expandFeat.getWidth()) - << "Matrix dimensions are not equal"; - - real* expandData = expandFeat.getData(); - int channelsCol = channels * blockH * blockW; - for (int c = 0; c < channelsCol; ++c) { - int wOffset = c % blockW; - int hOffset = (c / blockW) % blockH; - int c_im = c / blockW / blockH; - for (int h = 0; h < outputH; ++h) { - for (int w = 0; w < outputW; ++w) { - int imRowIdx = h * strideH + hOffset; - int imColIdx = w * strideW + wOffset; - if ((imRowIdx - paddingH) >= 0 && - (imRowIdx - paddingH) < thisImgHeight && - (imColIdx - paddingW) >= 0 && - (imColIdx - paddingW) < thisImgWidth) { - imRowIdx += c_im * thisImgHeight - paddingH; - imColIdx -= paddingW; - data_[imRowIdx * thisImgWidth + imColIdx] = - alpha * expandData[(c * outputH + h) * outputW + w] + - beta * data_[imRowIdx * thisImgWidth + imColIdx]; - } - } - } - } -} - void CpuMatrix::maxPoolForward(Matrix& inputMat, size_t imgSizeH, size_t imgSizeW, diff --git a/paddle/math/Matrix.h b/paddle/math/Matrix.h index 7dfd593225065..bb802bbb2c752 100644 --- a/paddle/math/Matrix.h +++ b/paddle/math/Matrix.h @@ -859,49 +859,6 @@ class Matrix : public BaseMatrix { LOG(FATAL) << "Not implemented"; } - /** - * This function is used to calculate the convolution: - * - * It will expand a feature matrix according to the - * convolution filters - */ - virtual void convExpand(Matrix& feature, - int feaImgHeight, - int feaImgWidth, - int channels, - int blockH, - int blockW, - int strideH, - int strideW, - int paddingH, - int paddingW, - int outputH, - int outputW) { - LOG(FATAL) << "Not implemeted"; - } - - /** - * This function is the reverse implementation of convExpand: - * - * Its function is to restore a expanded-matrix into a feature matrix - */ - virtual void convShrink(Matrix& expandColMat, - int thisImgHeight, - int thisImgWidth, - int channels, - int blockH, - int blockW, - int strideH, - int strideW, - int paddingH, - int paddingW, - int outputH, - int outputW, - real alpha = 1.0f, - real beta = 0.0f) { - LOG(FATAL) << "Not implemeted"; - } - /** * Pooling forward operation, pick out the largest element * in the sizeX of value @@ -1335,34 +1292,6 @@ class GpuMatrix : public Matrix { void classificationError(Matrix& output, IVector& label, size_t topkSize = 1); - void convExpand(Matrix& feature, - int feaImgHeight, - int feaImgWidth, - int channels, - int blockH, - int blockW, - int strideH, - int strideW, - int paddingH, - int paddingW, - int outputH, - int outputW); - - void convShrink(Matrix& expandColMat, - int thisImgHeight, - int thisImgWidth, - int channels, - int blockH, - int blochW, - int strideH, - int strideW, - int paddingH, - int paddingWreal, - int outputH, - int outputW, - real alpha = 1.0f, - real beta = 0.0f); - void maxPoolForward(Matrix& inputMat, size_t imgSizeH, size_t imgSizeW, @@ -1522,34 +1451,6 @@ class CpuMatrix : public Matrix { MatrixPtr clone(size_t height, size_t width, bool useGpu = false); - void convExpand(Matrix& feature, - int feaImgHeight, - int feaImgWidth, - int channels, - int blcokH, - int blockW, - int strideH, - int strideW, - int paddingH, - int paddingW, - int outputH, - int outputW); - - void convShrink(Matrix& expandFeat, - int thisImgHeight, - int thisImgWidth, - int channels, - int blockH, - int blockW, - int strideH, - int strideW, - int paddingH, - int paddingW, - int outputH, - int outputW, - real alpha = 1.0f, - real beta = 0.0f); - void maxPoolForward(Matrix& inputMat, size_t imgSizeH, size_t imgSizeW,