diff --git a/paddle/function/CMakeLists.txt b/paddle/function/CMakeLists.txt
index 1f54ac1231c6a..5e170714cf5b1 100644
--- a/paddle/function/CMakeLists.txt
+++ b/paddle/function/CMakeLists.txt
@@ -14,8 +14,8 @@ add_library(paddle_function STATIC ${cpp_files} ${cu_objs})
 add_dependencies(paddle_function ${external_project_dependencies})
 add_dependencies(paddle_function gen_proto_cpp)
 
-if(WITH_GPU)
 if(WITH_TESTING)
+if(WITH_GPU)
     # TODO:
     # file(GLOB test_files . *OpTest.cpp)
     # add_executable(${test_bin} EXCLUDE_FROM_ALL ${test_files})
@@ -30,6 +30,8 @@ if(WITH_TESTING)
     add_simple_unittest(CosSimOpTest)
     add_simple_unittest(RowConvOpTest)
 endif()
+
+add_simple_unittest(ConvOpTest)
 endif()
 
 add_style_check_target(paddle_function ${h_files})
diff --git a/paddle/function/ContextProjectionOpTest.cpp b/paddle/function/ContextProjectionOpTest.cpp
index 1b25172ca5c0c..9e9dd20e6f3ab 100644
--- a/paddle/function/ContextProjectionOpTest.cpp
+++ b/paddle/function/ContextProjectionOpTest.cpp
@@ -28,7 +28,7 @@ void testMatrixProjectionForward(int context_start,
       std::max(0, (int)(context_start + context_length - 1));
   if (pad == 0) is_padding = false;
 
-  FunctionCompare test(
+  CpuGpuFuncCompare test(
       "ContextProjectionForward",
       FuncConfig()
           .set("context_length", context_length)
@@ -60,7 +60,7 @@ void testMatrixProjectionBackward(int context_start,
       std::max(0, (int)(context_start + context_length - 1));
   if (pad == 0) is_padding = false;
 
-  FunctionCompare test(
+  CpuGpuFuncCompare test(
       "ContextProjectionBackward",
       FuncConfig()
          .set("context_length", context_length)
diff --git a/paddle/function/ConvOp.h b/paddle/function/ConvOp.h
new file mode 100644
index 0000000000000..65b9d1d53f921
--- /dev/null
+++ b/paddle/function/ConvOp.h
@@ -0,0 +1,146 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "Function.h"
+
+namespace paddle {
+
+/*
+ * \brief Based on the ConvFunctionBase class, the forward calculation,
+ *        the backward input calculation and the backward filter calculation
+ *        of a convolution operation can be implemented.
+ *
+ * Arguments of the forward and backward calculations:
+ *   1. Forward calculation of convolution.
+ *      inputs = {INPUT, FILTER}, outputs = {OUTPUT}
+ *      The first and second input arguments are the input image and the
+ *      filter data; the output argument is the output image.
+ *
+ *   2. Backward input calculation of convolution.
+ *      inputs = {OUTPUT_GRAD, FILTER}, outputs = {INPUT_GRAD}
+ *      The first and second input arguments are the output gradient image
+ *      and the filter data; the output argument is the input gradient image.
+ *
+ *   3. Backward filter calculation of convolution.
+ *      inputs = {OUTPUT_GRAD, INPUT}, outputs = {FILTER_GRAD}
+ *      The first and second input arguments are the output gradient image
+ *      and the input image; the output argument is the filter gradient.
+ *
+ * Argument format of input, filter and output:
+ *   1. The input image, output image, input image gradient and output image
+ *      gradient are all in NCHW format, where N is the batch size, C is the
+ *      number of channels, and H and W are the height and width of the
+ *      image or image gradient.
+ *
+ *   2. The format of the filter data is MCHW, where M is the number of
+ *      output image channels, C is the number of input image channels, and
+ *      H and W are the height and width of the filter.
+ *
+ *      If `groups` is greater than 1, the filter's data format should be
+ *      GMCHW, where G is `groups`, G * M is the number of output image
+ *      channels, G * C is the number of input image channels, and H and W
+ *      are the height and width of the filter.
+ */
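+
+/*
+ * A minimal usage sketch (illustrative only, not part of the original
+ * header; "GemmConv-CPU" is the type name that
+ * REGISTER_TYPED_FUNC(GemmConv, CPU, ...) in GemmConvOp.cpp registers, and
+ * `inputs`/`outputs` are BufferArgs prepared by the caller as above):
+ *
+ *   auto conv = FunctionBase::funcRegistrar_.createByType("GemmConv-CPU");
+ *   conv->init(FuncConfig()
+ *                  .set("paddings", std::vector<size_t>{0, 0})
+ *                  .set("strides", std::vector<size_t>{1, 1})
+ *                  .set("groups", (size_t)1));
+ *   conv->calc(inputs, outputs);  // inputs = {INPUT, FILTER},
+ *                                 // outputs = {OUTPUT}
+ */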
+class ConvFunctionBase : public FunctionBase {
+public:
+  void init(const FuncConfig& config) override {
+    // function arguments
+    strides_ = config.get<std::vector<size_t>>("strides");
+    paddings_ = config.get<std::vector<size_t>>("paddings");
+    groups_ = config.get<size_t>("groups");
+
+    // number of inputs and outputs
+    numInputs_ = 2;
+    numOutputs_ = 1;
+  }
+
+  virtual void calc(const BufferArgs& inputs, const BufferArgs& outputs) {}
+
+  // input can be INPUT and INPUT_GRAD
+  // filter can be FILTER and FILTER_GRAD
+  // output can be OUTPUT and OUTPUT_GRAD
+  void check(const TensorShape& input,
+             const TensorShape& filter,
+             const TensorShape& output) {
+    // inputs and outputs arguments should be 4-dimensional.
+    CHECK_EQ(input.ndims(), (size_t)4);
+    CHECK_EQ(output.ndims(), (size_t)4);
+    // The batchSize of the input needs to be equal to
+    // the batchSize of the output.
+    CHECK_EQ(input[0], output[0]);
+
+    if (filter.ndims() == (size_t)4) {
+      // If the filter's dimension is 4, grouped convolution is not supported.
+      CHECK_EQ(groups_, (size_t)1);
+      // The input and output channel dimensions are the second and first
+      // dimensions of the filter shape.
+      CHECK_EQ(input[1], filter[1]);
+      CHECK_EQ(output[1], filter[0]);
+    } else {
+      // The filter argument should be 5-dimensional.
+      CHECK_EQ(filter.ndims(), (size_t)5);
+      // The first dimension of the filter is the size of the group.
+      CHECK_EQ(filter[0], groups_);
+      // The input and output channel dimensions are the third and second
+      // dimensions of the filter shape.
+      CHECK_EQ(input[1], filter[2] * groups_);
+      CHECK_EQ(output[1], filter[1] * groups_);
+    }
+  }
+
+protected:
+  size_t getFilterHeight(const TensorShape& filter) const {
+    return filter[filter.ndims() - 2];
+  }
+
+  size_t getFilterWidth(const TensorShape& filter) const {
+    return filter[filter.ndims() - 1];
+  }
+
+  std::vector<size_t> strides_;
+  std::vector<size_t> paddings_;
+
+  /// Group size, refer to grouped convolution in
+  /// Alex Krizhevsky's paper: when group=2, the first half of the
+  /// filters are only connected to the first half of the input channels,
+  /// and the second half only connected to the second half.
+  size_t groups_;
+
+  inline int strideH() const { return strides_[0]; }
+
+  inline int strideW() const { return strides_[1]; }
+
+  inline int paddingH() const { return paddings_[0]; }
+
+  inline int paddingW() const { return paddings_[1]; }
+
+  // A temporary memory in convolution calculation.
+  MemoryHandlePtr memory_;
+
+  template <DeviceType Device>
+  void resizeBuffer(size_t newSize) {
+    if (!memory_ || newSize * sizeof(real) > memory_->getAllocSize()) {
+      if (Device == DEVICE_TYPE_CPU) {
+        memory_ = std::make_shared<CpuMemoryHandle>(newSize * sizeof(real));
+      } else {
+        memory_ = std::make_shared<GpuMemoryHandle>(newSize * sizeof(real));
+      }
+    }
+  }
+};
+
+}  // namespace paddle
diff --git a/paddle/function/ConvOpTest.cpp b/paddle/function/ConvOpTest.cpp
new file mode 100644
index 0000000000000..dfa2f784610b0
--- /dev/null
+++ b/paddle/function/ConvOpTest.cpp
@@ -0,0 +1,210 @@
+/* Copyright (c) 2016 PaddlePaddle Authors.
All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <gtest/gtest.h>
+#include <memory>
+#include "Function.h"
+#include "FunctionTest.h"
+
+namespace paddle {
+
+enum TestType {
+  kForwardTest = 0,
+  kBackwardInputTest = 1,
+  kBackwardFilterTest = 2,
+};
+
+template <DeviceType DType1, DeviceType DType2>
+class ConvolutionTest {
+public:
+  ConvolutionTest(const std::string& conv1,
+                  const std::string& conv2,
+                  TestType type,
+                  std::string algo = "auto") {
+    for (size_t batchSize : {1, 32}) {
+      for (size_t inputSize : {7, 14, 54}) {
+        for (size_t filterSize : {1, 3, 5}) {
+          for (size_t inputChannels : {3, 64}) {
+            for (size_t outputChannels : {3, 64, 128}) {
+              if (inputChannels < outputChannels) break;
+              for (size_t stride : {1, 2}) {
+                for (size_t padding : {0, 1}) {
+                  if (padding >= filterSize) break;
+                  size_t outputSize =
+                      (inputSize - filterSize + 2 * padding + stride) / stride;
+                  VLOG(3) << " batchSize=" << batchSize
+                          << " inputChannels=" << inputChannels
+                          << " inputHeight=" << inputSize
+                          << " inputWidth=" << inputSize
+                          << " outputChannels=" << outputChannels
+                          << " filterHeight=" << filterSize
+                          << " filterWidth=" << filterSize
+                          << " outputHeight=" << outputSize
+                          << " outputWidth=" << outputSize
+                          << " stride=" << stride << " padding=" << padding;
+
+                  std::vector<size_t> paddings = {padding, padding};
+                  std::vector<size_t> strides = {stride, stride};
+                  Compare2Function<DType1, DType2> test(
+                      conv1,
+                      conv2,
+                      FuncConfig()
+                          .set("paddings", paddings)
+                          .set("strides", strides)
+                          .set("groups", (size_t)1)
+                          .set("algo", algo));
+
+                  TensorShape input{
+                      batchSize, inputChannels, inputSize, inputSize};
+                  TensorShape filter{
+                      outputChannels, inputChannels, filterSize, filterSize};
+                  TensorShape output{
+                      batchSize, outputChannels, outputSize, outputSize};
+
+                  if (type == kForwardTest) {
+                    test.addInputs(BufferArg(VALUE_TYPE_FLOAT, input));
+                    test.addInputs(BufferArg(VALUE_TYPE_FLOAT, filter));
+                    test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, output));
+                    test.run();
+                  } else if (type == kBackwardInputTest) {
+                    test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output));
+                    test.addInputs(BufferArg(VALUE_TYPE_FLOAT, filter));
+                    test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, input), ADD_TO);
+                    test.run();
+                  } else if (type == kBackwardFilterTest) {
+                    test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output));
+                    test.addInputs(BufferArg(VALUE_TYPE_FLOAT, input));
+                    test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, filter));
+                    test.run();
+                  }
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+};
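+
+// Both ConvolutionTest above and ConvolutionTest2 below size the output as
+//   outputSize = (inputSize - filterSize + 2 * padding + stride) / stride,
+// which, for the truncating integer division used here, is the familiar
+//   outputSize = (inputSize + 2 * padding - filterSize) / stride + 1.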
+
+// Mainly used to test cases where the height and width (input, filter)
+// are not equal.
+template <DeviceType DType1, DeviceType DType2>
+class ConvolutionTest2 {
+public:
+  ConvolutionTest2(const std::string& conv1,
+                   const std::string& conv2,
+                   TestType type,
+                   std::string algo = "auto") {
+    for (size_t batchSize : {16}) {
+      for (size_t inputHeight : {7, 31}) {
+        for (size_t inputWidth : {10, 54}) {
+          for (size_t filterHeight : {1, 5}) {
+            for (size_t filterWidth : {3, 7}) {
+              for (size_t inputChannels : {7}) {
+                for (size_t outputChannels : {32}) {
+                  size_t stride = 1;
+                  size_t padding = 0;
+                  size_t outputHeight =
+                      (inputHeight - filterHeight + 2 * padding + stride) /
+                      stride;
+                  size_t outputWidth =
+                      (inputWidth - filterWidth + 2 * padding + stride) /
+                      stride;
+                  VLOG(3) << " batchSize=" << batchSize
+                          << " inputChannels=" << inputChannels
+                          << " inputHeight=" << inputHeight
+                          << " inputWidth=" << inputWidth
+                          << " outputChannels=" << outputChannels
+                          << " filterHeight=" << filterHeight
+                          << " filterWidth=" << filterWidth
+                          << " outputHeight=" << outputHeight
+                          << " outputWidth=" << outputWidth
+                          << " stride=" << stride << " padding=" << padding;
+
+                  std::vector<size_t> paddings = {padding, padding};
+                  std::vector<size_t> strides = {stride, stride};
+                  Compare2Function<DType1, DType2> test(
+                      conv1,
+                      conv2,
+                      FuncConfig()
+                          .set("paddings", paddings)
+                          .set("strides", strides)
+                          .set("groups", (size_t)1)
+                          .set("algo", algo));
+
+                  TensorShape input{
+                      batchSize, inputChannels, inputHeight, inputWidth};
+                  TensorShape filter{
+                      outputChannels, inputChannels, filterHeight, filterWidth};
+                  TensorShape output{
+                      batchSize, outputChannels, outputHeight, outputWidth};
+
+                  if (type == kForwardTest) {
+                    test.addInputs(BufferArg(VALUE_TYPE_FLOAT, input));
+                    test.addInputs(BufferArg(VALUE_TYPE_FLOAT, filter));
+                    test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, output));
+                    test.run();
+                  } else if (type == kBackwardInputTest) {
+                    test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output));
+                    test.addInputs(BufferArg(VALUE_TYPE_FLOAT, filter));
+                    test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, input), ADD_TO);
+                    test.run();
+                  } else if (type == kBackwardFilterTest) {
+                    test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output));
+                    test.addInputs(BufferArg(VALUE_TYPE_FLOAT, input));
+                    test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, filter));
+                    test.run();
+                  }
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+};
+
+TEST(Forward, GEMM) {
+  ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_CPU> test(
+      "NaiveConv-CPU", "GemmConv-CPU", kForwardTest);
+  ConvolutionTest2<DEVICE_TYPE_CPU, DEVICE_TYPE_CPU> test2(
+      "NaiveConv-CPU", "GemmConv-CPU", kForwardTest);
+}
+
+#ifndef PADDLE_ONLY_CPU
+TEST(Forward, GEMM2) {
+  ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test(
+      "GemmConv-CPU", "GemmConv-GPU", kForwardTest);
+  ConvolutionTest2<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test2(
+      "GemmConv-CPU", "GemmConv-GPU", kForwardTest);
+}
+
+TEST(BackwardInput, GEMM) {
+  ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test(
+      "GemmConvGradInput-CPU", "GemmConvGradInput-GPU", kBackwardInputTest);
+  ConvolutionTest2<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test2(
+      "GemmConvGradInput-CPU", "GemmConvGradInput-GPU", kBackwardInputTest);
+}
+
+TEST(BackwardFilter, GEMM) {
+  ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test(
+      "GemmConvGradFilter-CPU", "GemmConvGradFilter-GPU", kBackwardFilterTest);
+  ConvolutionTest2<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test2(
+      "GemmConvGradFilter-CPU", "GemmConvGradFilter-GPU", kBackwardFilterTest);
+}
+#endif
+
+}  // namespace paddle
diff --git a/paddle/function/CosSimOpTest.cpp b/paddle/function/CosSimOpTest.cpp
index 48c815f027161..f6c0041101f50 100644
--- a/paddle/function/CosSimOpTest.cpp
+++ b/paddle/function/CosSimOpTest.cpp
@@ -22,7 +22,7 @@ void testCosSimForward(size_t height_x,
                        size_t height_y,
                        size_t width,
                        real scale) {
-  FunctionCompare test("CosSimForward", FuncConfig().set("scale", scale));
+  CpuGpuFuncCompare test("CosSimForward",
FuncConfig().set("scale", scale)); // prepare input arguments test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_x, width})); test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_y, width})); @@ -36,7 +36,7 @@ void testCosSimBackward(size_t height_x, size_t height_y, size_t width, real scale) { - FunctionCompare test("CosSimBackward", FuncConfig().set("scale", scale)); + CpuGpuFuncCompare test("CosSimBackward", FuncConfig().set("scale", scale)); // prepare input arguments test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_x, 1})); test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_x, 1})); diff --git a/paddle/function/CrossMapNormalOpTest.cpp b/paddle/function/CrossMapNormalOpTest.cpp index 51f5da81bfc9a..ed17b17da616d 100644 --- a/paddle/function/CrossMapNormalOpTest.cpp +++ b/paddle/function/CrossMapNormalOpTest.cpp @@ -28,11 +28,11 @@ TEST(CrossMapNormal, real) { << " size=" << size; // init Test object - FunctionCompare test("CrossMapNormal", - FuncConfig() - .set("size", size) - .set("scale", (real)1.5) - .set("pow", (real)0.5)); + CpuGpuFuncCompare test("CrossMapNormal", + FuncConfig() + .set("size", size) + .set("scale", (real)1.5) + .set("pow", (real)0.5)); // prepare input arguments TensorShape shape{numSamples, channels, imgSizeH, imgSizeW}; test.addInputs(BufferArg(VALUE_TYPE_FLOAT, shape)); @@ -57,11 +57,11 @@ TEST(CrossMapNormalGrad, real) { << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW << " size=" << size; - FunctionCompare test("CrossMapNormalGrad", - FuncConfig() - .set("size", size) - .set("scale", (real)1.5) - .set("pow", (real)0.5)); + CpuGpuFuncCompare test("CrossMapNormalGrad", + FuncConfig() + .set("size", size) + .set("scale", (real)1.5) + .set("pow", (real)0.5)); TensorShape shape{numSamples, channels, imgSizeH, imgSizeW}; test.addInputs(BufferArg(VALUE_TYPE_FLOAT, shape)); test.addInputs(BufferArg(VALUE_TYPE_FLOAT, shape)); diff --git a/paddle/function/FunctionTest.h b/paddle/function/FunctionTest.h index 0cfafdb27f55a..ba446bf92da26 100644 --- a/paddle/function/FunctionTest.h +++ b/paddle/function/FunctionTest.h @@ -22,14 +22,62 @@ namespace paddle { typedef std::shared_ptr BufferArgPtr; +namespace test { +template +struct Allocator; + +template <> +struct Allocator { + using type = CpuMemoryHandle; +}; + +template <> +struct Allocator { + using type = GpuMemoryHandle; +}; + +// Copy argument1 to argument2 +template +class CopyArgument { +public: + void operator()(const BufferArg& arg1, BufferArg& arg2) { + CHECK_EQ(arg1.valueType(), arg2.valueType()); + CHECK_LE(arg1.shape().getElements(), arg2.shape().getElements()); + + if (arg1.valueType() == VALUE_TYPE_INT32) { + IVectorPtr vector1 = + IVector::create((int*)arg1.data(), + arg1.shape().getElements(), + DType1 == DEVICE_TYPE_CPU ? false : true); + IVectorPtr vector2 = + IVector::create((int*)arg2.data(), + arg2.shape().getElements(), + DType2 == DEVICE_TYPE_CPU ? false : true); + vector2->copyFrom(*vector1); + } else { + VectorPtr vector1 = + Vector::create((real*)arg1.data(), + arg1.shape().getElements(), + DType1 == DEVICE_TYPE_CPU ? false : true); + VectorPtr vector2 = + Vector::create((real*)arg2.data(), + arg2.shape().getElements(), + DType2 == DEVICE_TYPE_CPU ? false : true); + vector2->copyFrom(*vector1); + } + } +}; +} // namespace test + /** - * \brief A class for comparing CPU and GPU implementations of Function. - * + * \brief A class for comparing two Functions of different implementations. 
+ * For example, can be used to compare the CPU and GPU implementation + * of the function is consistent. * * Use case: * // Initializes a test object, the corresponding cpu and gpu Function * // are constructed according to FunctionName and FuncConfig. - * FunctionCompare test(FunctionName, FuncConfig); + * CpuGpuFuncCompare test(FunctionName, FuncConfig); * // Prepare inputs and outputs arguments. * // Here the input and output can not contain real data, * // only contains the argument type and shape. @@ -45,28 +93,38 @@ typedef std::shared_ptr BufferArgPtr; * // Compares CPU and GPU calculation results for consistency. * test.run(); */ -class FunctionCompare { +template +class Compare2Function { public: - FunctionCompare(const std::string& name, const FuncConfig& config) - : cpuFunc_(FunctionBase::funcRegistrar_.createByType(name + "-CPU")), - gpuFunc_(FunctionBase::funcRegistrar_.createByType(name + "-GPU")) { - cpuFunc_->init(config); - gpuFunc_->init(config); + typedef typename test::Allocator::type Allocator1; + typedef typename test::Allocator::type Allocator2; + typedef typename Tensor::Vector Vector1; + typedef typename Tensor::Vector Vector2; + typedef typename Tensor::SparseMatrix SparseMatrix1; + typedef typename Tensor::SparseMatrix SparseMatrix2; + + Compare2Function(const std::string& name1, + const std::string& name2, + const FuncConfig& config) + : function1_(FunctionBase::funcRegistrar_.createByType(name1)), + function2_(FunctionBase::funcRegistrar_.createByType(name2)) { + function1_->init(config); + function2_->init(config); } - ~FunctionCompare() {} + ~Compare2Function() {} // input need only contains shape, do not contains data. void addInputs(const BufferArg& input) { size_t size = input.shape().getElements() * sizeOfValuType(input.valueType()); - cpuMemory_.emplace_back(std::make_shared(size)); - gpuMemory_.emplace_back(std::make_shared(size)); + func1Memory_.emplace_back(std::make_shared(size)); + func2Memory_.emplace_back(std::make_shared(size)); - cpuInputs_.emplace_back(std::make_shared( - cpuMemory_.back()->getBuf(), input.valueType(), input.shape())); - gpuInputs_.emplace_back(std::make_shared( - gpuMemory_.back()->getBuf(), input.valueType(), input.shape())); + func1Inputs_.emplace_back(std::make_shared( + func1Memory_.back()->getBuf(), input.valueType(), input.shape())); + func2Inputs_.emplace_back(std::make_shared( + func2Memory_.back()->getBuf(), input.valueType(), input.shape())); } // assume one copy of sequence is shared by different SequenceArgs @@ -75,62 +133,57 @@ class FunctionCompare { size_t batchSize = input.shape()[0]; size_t numSeqs = batchSize / 10 + 1; size_t sizeId = (numSeqs + 1) * sizeOfValuType(VALUE_TYPE_INT32); - cpuMemory_.emplace_back(std::make_shared(sizeId)); - gpuMemory_.emplace_back(std::make_shared(sizeId)); - cpuSeq_ = std::make_shared(cpuMemory_.back()->getBuf(), - TensorShape{numSeqs + 1}); - gpuSeq_ = std::make_shared(gpuMemory_.back()->getBuf(), - TensorShape{numSeqs + 1}); + func1Memory_.emplace_back(std::make_shared(sizeId)); + func2Memory_.emplace_back(std::make_shared(sizeId)); + seq1_ = std::make_shared(func1Memory_.back()->getBuf(), + TensorShape{numSeqs + 1}); + seq2_ = std::make_shared(func2Memory_.back()->getBuf(), + TensorShape{numSeqs + 1}); /// init sequence Id - initArg(*cpuSeq_, batchSize); + initArg(*seq1_, batchSize); - // todo(tianbing), delete it - CHECK_EQ(cpuSeq_->shape().getElements(), cpuSeq_->numSeqs() + 1); - - CpuIVector cpuSeq(cpuSeq_->shape().getElements(), (int*)cpuSeq_->data()); - GpuIVector 
gpuSeq(gpuSeq_->shape().getElements(), (int*)gpuSeq_->data()); - gpuSeq.copyFrom(cpuSeq); + copyArg_(*seq1_, *seq2_); } void addInputs(const SequenceArg& input) { CHECK_EQ(input.shape().ndims(), 2UL); size_t batchSize = input.shape()[0]; - if (!cpuSeq_ || !gpuSeq_) { // sequence not exist + if (!seq1_ || !seq2_) { // sequence not exist addSequence(SequenceIdArg(TensorShape{batchSize})); } size_t size = input.shape().getElements() * sizeOfValuType(input.valueType()); - cpuMemory_.emplace_back(std::make_shared(size)); - gpuMemory_.emplace_back(std::make_shared(size)); + func1Memory_.emplace_back(std::make_shared(size)); + func2Memory_.emplace_back(std::make_shared(size)); /// SequenceArg - cpuInputs_.emplace_back( - std::make_shared(cpuMemory_.back()->getBuf(), + func1Inputs_.emplace_back( + std::make_shared(func1Memory_.back()->getBuf(), input.valueType(), input.shape(), - *cpuSeq_)); - gpuInputs_.emplace_back( - std::make_shared(gpuMemory_.back()->getBuf(), + *seq1_)); + func2Inputs_.emplace_back( + std::make_shared(func2Memory_.back()->getBuf(), input.valueType(), input.shape(), - *gpuSeq_)); + *seq2_)); } // output need only contains shape, do not contains data. void addOutputs(const BufferArg& output, ArgType argType = ASSIGN_TO) { size_t size = output.shape().getElements() * sizeOfValuType(output.valueType()); - cpuMemory_.emplace_back(std::make_shared(size)); - gpuMemory_.emplace_back(std::make_shared(size)); + func1Memory_.emplace_back(std::make_shared(size)); + func2Memory_.emplace_back(std::make_shared(size)); - cpuOutputs_.emplace_back( - std::make_shared(cpuMemory_.back()->getBuf(), + func1Outputs_.emplace_back( + std::make_shared(func1Memory_.back()->getBuf(), output.valueType(), output.shape(), argType)); - gpuOutputs_.emplace_back( - std::make_shared(gpuMemory_.back()->getBuf(), + func2Outputs_.emplace_back( + std::make_shared(func2Memory_.back()->getBuf(), output.valueType(), output.shape(), argType)); @@ -138,14 +191,14 @@ class FunctionCompare { /// add and init output sparse matrix void addOutputs(const SparseMatrixArg& output, ArgType argType = ASSIGN_TO) { - cpuSparse_ = std::make_shared( + sparse1_ = std::make_shared( output.shape()[0], output.shape()[1], output.nnz(), static_cast(output.dataType()), static_cast(output.dataFormat())); - gpuSparse_ = std::make_shared( + sparse2_ = std::make_shared( output.shape()[0], output.shape()[1], output.nnz(), @@ -154,52 +207,52 @@ class FunctionCompare { /// init sparse matrix hl_stream_t stream(HPPL_STREAM_1); - cpuSparse_->randomizeUniform(); - gpuSparse_->copyFrom(*cpuSparse_, stream); + sparse1_->randomizeUniform(); + sparse2_->copyFrom(*sparse1_, stream); hl_stream_synchronize(stream); - cpuOutputs_.emplace_back( - std::make_shared(*cpuSparse_, argType)); - gpuOutputs_.emplace_back( - std::make_shared(*gpuSparse_, argType)); + func1Outputs_.emplace_back( + std::make_shared(*sparse1_, argType)); + func2Outputs_.emplace_back( + std::make_shared(*sparse2_, argType)); } void addOutputs(const SequenceArg& output, ArgType argType = ASSIGN_TO) { CHECK_EQ(output.shape().ndims(), 2UL); size_t batchSize = output.shape()[0]; - if (!cpuSeq_ || !gpuSeq_) { // sequence not exist + if (!seq1_ || !seq2_) { // sequence not exist addSequence(SequenceIdArg(TensorShape{batchSize})); } size_t size = output.shape().getElements() * sizeOfValuType(output.valueType()); - cpuMemory_.emplace_back(std::make_shared(size)); - gpuMemory_.emplace_back(std::make_shared(size)); + func1Memory_.emplace_back(std::make_shared(size)); + 
func2Memory_.emplace_back(std::make_shared(size)); /// SequenceArg - cpuOutputs_.emplace_back( - std::make_shared(cpuMemory_.back()->getBuf(), + func1Outputs_.emplace_back( + std::make_shared(func1Memory_.back()->getBuf(), output.valueType(), output.shape(), - *cpuSeq_, + *seq1_, argType)); - gpuOutputs_.emplace_back( - std::make_shared(gpuMemory_.back()->getBuf(), + func2Outputs_.emplace_back( + std::make_shared(func2Memory_.back()->getBuf(), output.valueType(), output.shape(), - *gpuSeq_, + *seq2_, argType)); } void addInputs(const SparseMatrixArg& input) { - cpuSparse_ = std::make_shared( + sparse1_ = std::make_shared( input.shape()[0], input.shape()[1], input.nnz(), static_cast(input.dataType()), static_cast(input.dataFormat())); - gpuSparse_ = std::make_shared( + sparse2_ = std::make_shared( input.shape()[0], input.shape()[1], input.nnz(), @@ -208,12 +261,12 @@ class FunctionCompare { /// init sparse matrix hl_stream_t stream(HPPL_STREAM_1); - cpuSparse_->randomizeUniform(); - gpuSparse_->copyFrom(*cpuSparse_, stream); + sparse1_->randomizeUniform(); + sparse2_->copyFrom(*sparse1_, stream); hl_stream_synchronize(stream); - cpuInputs_.emplace_back(std::make_shared(*cpuSparse_)); - gpuInputs_.emplace_back(std::make_shared(*gpuSparse_)); + func1Inputs_.emplace_back(std::make_shared(*sparse1_)); + func2Inputs_.emplace_back(std::make_shared(*sparse2_)); } void run() { @@ -236,27 +289,27 @@ class FunctionCompare { function->calc(inArgs, outArgs); }; - callFunction(cpuFunc_.get(), cpuInputs_, cpuOutputs_); - callFunction(gpuFunc_.get(), gpuInputs_, gpuOutputs_); + callFunction(function1_.get(), func1Inputs_, func1Outputs_); + callFunction(function2_.get(), func2Inputs_, func2Outputs_); // check outputs compareOutputs(); } - std::shared_ptr getCpuFunction() const { return cpuFunc_; } + std::shared_ptr getFunction1() const { return function1_; } - std::shared_ptr getGpuFunction() const { return gpuFunc_; } + std::shared_ptr getFunction2() const { return function2_; } protected: // only init cpu argument, gpu argument copy from cpu argument. void initArg(BufferArg& arg) { - CpuVector vector(arg.shape().getElements(), (real*)arg.data()); + Vector1 vector(arg.shape().getElements(), (real*)arg.data()); vector.uniform(0.001, 1); } void initArg(SequenceArg& arg) { /// init only matrix - CpuVector vector(arg.shape().getElements(), (real*)arg.data()); + Vector1 vector(arg.shape().getElements(), (real*)arg.data()); vector.uniform(0.001, 1); } @@ -276,73 +329,72 @@ class FunctionCompare { } void initInputs() { - for (size_t i = 0; i < cpuInputs_.size(); i++) { - if (cpuInputs_[i]->isSparseArg()) { + for (size_t i = 0; i < func1Inputs_.size(); i++) { + if (func1Inputs_[i]->isSparseArg()) { continue; /// sparse matrix already init } - if (cpuInputs_[i]->isSequenceArg()) { - initArg(dynamic_cast(*cpuInputs_[i])); + if (func1Inputs_[i]->isSequenceArg()) { + initArg(dynamic_cast(*func1Inputs_[i])); } else { - initArg(*cpuInputs_[i]); + initArg(*func1Inputs_[i]); } - // TODO: Need a BufferCopy used to copy from one BufferArg to another. 
- CpuVector cpuVector(cpuInputs_[i]->shape().getElements(), - (real*)cpuInputs_[i]->data()); - GpuVector gpuVector(gpuInputs_[i]->shape().getElements(), - (real*)gpuInputs_[i]->data()); - gpuVector.copyFrom(cpuVector); + copyArg_(*func1Inputs_[i], *func2Inputs_[i]); } } void initOutputs() { - for (size_t i = 0; i < cpuOutputs_.size(); i++) { - if (cpuOutputs_[i]->isSparseArg()) { + for (size_t i = 0; i < func1Outputs_.size(); i++) { + if (func1Outputs_[i]->isSparseArg()) { continue; /// sparse matrix already init } - if (cpuOutputs_[i]->isSequenceArg()) { - initArg(dynamic_cast(*cpuOutputs_[i])); + if (func1Outputs_[i]->isSequenceArg()) { + initArg(dynamic_cast(*func1Outputs_[i])); } else { - initArg(*cpuOutputs_[i]); + initArg(*func1Outputs_[i]); } - // TODO: Need a BufferCopy used to copy from one BufferArg to another. - CpuVector cpuVector(cpuOutputs_[i]->shape().getElements(), - (real*)cpuOutputs_[i]->data()); - GpuVector gpuVector(gpuOutputs_[i]->shape().getElements(), - (real*)gpuOutputs_[i]->data()); - - gpuVector.copyFrom(cpuVector); + copyArg_(*func1Outputs_[i], *func2Outputs_[i]); } } void compareOutputs() { - for (size_t i = 0; i < cpuOutputs_.size(); i++) { + for (size_t i = 0; i < func1Outputs_.size(); i++) { // TODO, Need a BufferCheck used to compare the two buffers. - const auto cpu = cpuOutputs_[i]; - const auto gpu = gpuOutputs_[i]; + const auto cpu = func1Outputs_[i]; + const auto gpu = func2Outputs_[i]; CHECK_EQ(cpu->numElements(), gpu->numElements()); - CpuVector cpuVector(cpu->numElements(), (real*)cpu->data()); - GpuVector gpuVector(gpu->numElements(), (real*)gpu->data()); + Vector1 cpuVector(cpu->numElements(), (real*)cpu->data()); + Vector2 gpuVector(gpu->numElements(), (real*)gpu->data()); autotest::TensorCheckErr(cpuVector, gpuVector); } } protected: - std::shared_ptr cpuFunc_; - std::shared_ptr gpuFunc_; - std::vector cpuMemory_; - std::vector gpuMemory_; - std::vector cpuInputs_; - std::vector cpuOutputs_; - std::vector gpuInputs_; - std::vector gpuOutputs_; - std::shared_ptr cpuSparse_; - std::shared_ptr gpuSparse_; - std::shared_ptr cpuSeq_; - std::shared_ptr gpuSeq_; + std::shared_ptr function1_; + std::shared_ptr function2_; + std::vector> func1Memory_; + std::vector> func2Memory_; + std::vector func1Inputs_; + std::vector func1Outputs_; + std::vector func2Inputs_; + std::vector func2Outputs_; + std::shared_ptr sparse1_; + std::shared_ptr sparse2_; + std::shared_ptr seq1_; + std::shared_ptr seq2_; + test::CopyArgument copyArg_; +}; + +class CpuGpuFuncCompare + : public Compare2Function { +public: + CpuGpuFuncCompare(const std::string& name, const FuncConfig& config) + : Compare2Function(name + "-CPU", name + "-GPU", config) {} + + ~CpuGpuFuncCompare() {} }; } // namespace paddle diff --git a/paddle/function/GemmConvOp.cpp b/paddle/function/GemmConvOp.cpp new file mode 100644 index 0000000000000..c7a57801ed609 --- /dev/null +++ b/paddle/function/GemmConvOp.cpp @@ -0,0 +1,386 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "GemmConvOp.h"
+#include "GemmFunctor.h"
+#include "paddle/math/MemoryHandle.h"
+
+namespace paddle {
+
+/*
+ * imData = [input_channels, input_height, input_width]
+ * colData = [input_channels, filter_height, filter_width,
+ *            output_height, output_width]
+ */
+template <class T>
+class Im2ColFunctor<DEVICE_TYPE_CPU, T> {
+public:
+  void operator()(const T* imData,
+                  int inputChannels,
+                  int inputHeight,
+                  int inputWidth,
+                  int filterHeight,
+                  int filterWidth,
+                  int strideHeight,
+                  int strideWidth,
+                  int paddingHeight,
+                  int paddingWidth,
+                  int outputHeight,
+                  int outputWidth,
+                  T* colData) {
+    int channelsCol = inputChannels * filterHeight * filterWidth;
+
+    for (int c = 0; c < channelsCol; ++c) {
+      int wOffset = c % filterWidth;
+      int hOffset = (c / filterWidth) % filterHeight;
+      int c_im = c / filterWidth / filterHeight;
+      for (int h = 0; h < outputHeight; ++h) {
+        for (int w = 0; w < outputWidth; ++w) {
+          int imRowIdx = h * strideHeight + hOffset;
+          int imColIdx = w * strideWidth + wOffset;
+          if ((imRowIdx - paddingHeight) < 0 ||
+              (imRowIdx - paddingHeight) >= inputHeight ||
+              (imColIdx - paddingWidth) < 0 ||
+              (imColIdx - paddingWidth) >= inputWidth) {
+            colData[(c * outputHeight + h) * outputWidth + w] = T(0);
+          } else {
+            imRowIdx += c_im * inputHeight - paddingHeight;
+            imColIdx -= paddingWidth;
+            colData[(c * outputHeight + h) * outputWidth + w] =
+                imData[imRowIdx * inputWidth + imColIdx];
+          }
+        }
+      }
+    }
+  }
+};
+
+template <class T>
+class Col2ImFunctor<DEVICE_TYPE_CPU, T> {
+public:
+  void operator()(const T* colData,
+                  int inputChannels,
+                  int inputHeight,
+                  int inputWidth,
+                  int filterHeight,
+                  int filterWidth,
+                  int strideHeight,
+                  int strideWidth,
+                  int paddingHeight,
+                  int paddingWidth,
+                  int outputHeight,
+                  int outputWidth,
+                  T* imData) {
+    int channelsCol = inputChannels * filterHeight * filterWidth;
+
+    for (int c = 0; c < channelsCol; ++c) {
+      int wOffset = c % filterWidth;
+      int hOffset = (c / filterWidth) % filterHeight;
+      int c_im = c / filterWidth / filterHeight;
+      for (int h = 0; h < outputHeight; ++h) {
+        for (int w = 0; w < outputWidth; ++w) {
+          int imRowIdx = h * strideHeight + hOffset;
+          int imColIdx = w * strideWidth + wOffset;
+          if ((imRowIdx - paddingHeight) >= 0 &&
+              (imRowIdx - paddingHeight) < inputHeight &&
+              (imColIdx - paddingWidth) >= 0 &&
+              (imColIdx - paddingWidth) < inputWidth) {
+            imRowIdx += c_im * inputHeight - paddingHeight;
+            imColIdx -= paddingWidth;
+            imData[imRowIdx * inputWidth + imColIdx] +=
+                colData[(c * outputHeight + h) * outputWidth + w];
+          }
+        }
+      }
+    }
+  }
+};
+
+/*
+ * \brief Forward calculation of convolution.
+ */
+template <DeviceType Device>
+class GemmConvFunction : public ConvFunctionBase {
+public:
+  void init(const FuncConfig& config) override {
+    ConvFunctionBase::init(config);
+  }
+
+  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
+    CHECK_EQ(numInputs_, inputs.size());
+    CHECK_EQ(numOutputs_, outputs.size());
+    // TODO(hedaoyuan): Need to define some index macros,
+    // to avoid using 0 and 1.
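+    // Argument convention (see ConvOp.h): inputs[0] is the image,
+    // inputs[1] is the filter, and outputs[0] is the result.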
+ const TensorShape& input = inputs[0].shape(); + const TensorShape& filter = inputs[1].shape(); + const TensorShape& output = outputs[0].shape(); + check(input, filter, output); + + real beta; + if (outputs[0].getArgType() == ADD_TO) { + beta = 1.0; + } else { + beta = 0.0; + } + + size_t batchSize = input[0]; + size_t inputChannels = input[1]; + size_t inputHeight = input[2]; + size_t inputWidth = input[3]; + size_t filterHeight = getFilterHeight(filter); + size_t filterWidth = getFilterWidth(filter); + size_t outputChannels = output[1]; + size_t outputHeight = output[2]; + size_t outputWidth = output[3]; + + real* inputData = inputs[0].data(); + real* filterData = inputs[1].data(); + real* outputData = outputs[0].data(); + + size_t size = inputChannels / groups_ * filterHeight * filterWidth * + outputHeight * outputWidth; + resizeBuffer(size); + real* colData = reinterpret_cast(memory_->getBuf()); + + Im2ColFunctor im2col; + GemmFunctor gemm; + size_t inputOffset = (inputChannels / groups_) * inputHeight * inputWidth; + size_t outputOffset = + (outputChannels / groups_) * outputHeight * outputWidth; + size_t filterOffset = filter.getElements() / groups_; + + for (size_t i = 0; i < batchSize; i++) { + for (size_t g = 0; g < groups_; g++) { + im2col(inputData + g * inputOffset, + inputChannels / groups_, + inputHeight, + inputWidth, + filterHeight, + filterWidth, + strideH(), + strideW(), + paddingH(), + paddingW(), + outputHeight, + outputWidth, + colData); + + int M = outputChannels / groups_; + int N = outputHeight * outputWidth; + int K = inputChannels / groups_ * filterHeight * filterWidth; + gemm(CblasNoTrans, + CblasNoTrans, + M, + N, + K, + 1.0f, + filterData + g * filterOffset, + K, + colData, + N, + beta, + outputData + g * outputOffset, + N); + } + inputData += inputChannels * inputHeight * inputWidth; + outputData += outputChannels * outputHeight * outputWidth; + } + } +}; + +/* + * \brief Backward input calculation of convolution. + */ +template +class GemmConvGradInputFunction : public ConvFunctionBase { +public: + void init(const FuncConfig& config) override { + ConvFunctionBase::init(config); + } + + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_EQ(numInputs_, inputs.size()); + CHECK_EQ(numOutputs_, outputs.size()); + // Since the implementation of Col2ImFunctor is ADD_TO, + // this function only supports ADD_TO mode. 
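+    // That is, col2im accumulates into inputGrad rather than assigning to
+    // it, which is why ConvOpTest registers the INPUT_GRAD output with
+    // ADD_TO.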
+ CHECK_EQ(outputs[0].getArgType(), ADD_TO); + const TensorShape& output = inputs[0].shape(); + const TensorShape& filter = inputs[1].shape(); + const TensorShape& input = outputs[0].shape(); + check(input, filter, output); + + size_t batchSize = input[0]; + size_t inputChannels = input[1]; + size_t inputHeight = input[2]; + size_t inputWidth = input[3]; + size_t filterHeight = getFilterHeight(filter); + size_t filterWidth = getFilterWidth(filter); + size_t outputChannels = output[1]; + size_t outputHeight = output[2]; + size_t outputWidth = output[3]; + + real* outputGrad = inputs[0].data(); + real* filterData = inputs[1].data(); + real* inputGrad = outputs[0].data(); + + size_t size = inputChannels / groups_ * filterHeight * filterWidth * + outputHeight * outputWidth; + resizeBuffer(size); + real* colData = reinterpret_cast(memory_->getBuf()); + + Col2ImFunctor col2im; + GemmFunctor gemm; + size_t inputOffset = (inputChannels / groups_) * inputHeight * inputWidth; + size_t outputOffset = + (outputChannels / groups_) * outputHeight * outputWidth; + size_t filterOffset = filter.getElements() / groups_; + + for (size_t i = 0; i < batchSize; i++) { + for (size_t g = 0; g < groups_; g++) { + int K = outputChannels / groups_; + int N = outputHeight * outputWidth; + int M = inputChannels / groups_ * filterHeight * filterWidth; + gemm(CblasTrans, + CblasNoTrans, + M, + N, + K, + 1.0f, + filterData + g * filterOffset, + M, + outputGrad + g * outputOffset, + N, + 0.0f, + colData, + N); + + col2im(colData, + inputChannels / groups_, + inputHeight, + inputWidth, + filterHeight, + filterWidth, + strideH(), + strideW(), + paddingH(), + paddingW(), + outputHeight, + outputWidth, + inputGrad + g * inputOffset); + } + inputGrad += inputChannels * inputHeight * inputWidth; + outputGrad += outputChannels * outputHeight * outputWidth; + } + } +}; + +/* + * \brief Backward filter calculation of convolution. 
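+ *
+ * The filter gradient is accumulated over the batch: sample 0 uses the
+ * caller-requested beta (0.0 for ASSIGN_TO, 1.0 for ADD_TO), and every
+ * later sample uses beta = 1.0 (see the `i == 0 ? beta : 1.0f` argument
+ * of gemm below).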
+ */ +template +class GemmConvGradFilterFunction : public ConvFunctionBase { +public: + void init(const FuncConfig& config) override { + ConvFunctionBase::init(config); + } + + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_EQ(numInputs_, inputs.size()); + CHECK_EQ(numOutputs_, outputs.size()); + const TensorShape& output = inputs[0].shape(); + const TensorShape& input = inputs[1].shape(); + const TensorShape& filter = outputs[0].shape(); + check(input, filter, output); + + real beta; + if (outputs[0].getArgType() == ADD_TO) { + beta = 1.0; + } else { + beta = 0.0; + } + + size_t batchSize = input[0]; + size_t inputChannels = input[1]; + size_t inputHeight = input[2]; + size_t inputWidth = input[3]; + size_t filterHeight = getFilterHeight(filter); + size_t filterWidth = getFilterWidth(filter); + size_t outputChannels = output[1]; + size_t outputHeight = output[2]; + size_t outputWidth = output[3]; + + real* outputGrad = inputs[0].data(); + real* inputData = inputs[1].data(); + real* filterGrad = outputs[0].data(); + + size_t size = inputChannels / groups_ * filterHeight * filterWidth * + outputHeight * outputWidth; + resizeBuffer(size); + real* colData = reinterpret_cast(memory_->getBuf()); + + Im2ColFunctor im2col; + GemmFunctor gemm; + size_t inputOffset = (inputChannels / groups_) * inputHeight * inputWidth; + size_t outputOffset = + (outputChannels / groups_) * outputHeight * outputWidth; + size_t filterOffset = filter.getElements() / groups_; + for (size_t i = 0; i < batchSize; i++) { + for (size_t g = 0; g < groups_; g++) { + im2col(inputData + g * inputOffset, + inputChannels / groups_, + inputHeight, + inputWidth, + filterHeight, + filterWidth, + strideH(), + strideW(), + paddingH(), + paddingW(), + outputHeight, + outputWidth, + colData); + + int M = outputChannels / groups_; + int K = outputHeight * outputWidth; + int N = inputChannels / groups_ * filterHeight * filterWidth; + gemm(CblasNoTrans, + CblasTrans, + M, + N, + K, + 1.0f, + outputGrad + g * outputOffset, + K, + colData, + K, + i == 0 ? beta : 1.0f, + filterGrad + g * filterOffset, + N); + } + inputData += inputChannels * inputHeight * inputWidth; + outputGrad += outputChannels * outputHeight * outputWidth; + } + } +}; + +REGISTER_TYPED_FUNC(GemmConv, CPU, GemmConvFunction); +REGISTER_TYPED_FUNC(GemmConvGradInput, CPU, GemmConvGradInputFunction); +REGISTER_TYPED_FUNC(GemmConvGradFilter, CPU, GemmConvGradFilterFunction); +#ifndef PADDLE_ONLY_CPU +REGISTER_TYPED_FUNC(GemmConv, GPU, GemmConvFunction); +REGISTER_TYPED_FUNC(GemmConvGradInput, GPU, GemmConvGradInputFunction); +REGISTER_TYPED_FUNC(GemmConvGradFilter, GPU, GemmConvGradFilterFunction); +#endif + +} // namespace paddle diff --git a/paddle/function/GemmConvOp.h b/paddle/function/GemmConvOp.h new file mode 100644 index 0000000000000..9f11cce597a07 --- /dev/null +++ b/paddle/function/GemmConvOp.h @@ -0,0 +1,62 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include "ConvOp.h" + +namespace paddle { + +/* + * imData = [input_channels, input_height, input_width] + * colData = [input_channels, filter_height, filter_width, + * output_height, output_width] + */ +template +class Im2ColFunctor { +public: + void operator()(const T* imData, + int inputChannels, + int inputHeight, + int inputWidth, + int filterHeight, + int filterWidth, + int strideHeight, + int strideWidth, + int paddingHeight, + int paddingWidth, + int outputHeight, + int outputWidth, + T* colData); +}; + +template +class Col2ImFunctor { +public: + void operator()(const T* colData, + int inputChannels, + int inputHeight, + int inputWidth, + int filterHeight, + int filterWidth, + int strideHeight, + int strideWidth, + int paddingHeight, + int paddingWidth, + int outputHeight, + int outputWidth, + T* imData); +}; + +} // namespace paddle diff --git a/paddle/function/GemmConvOpGpu.cu b/paddle/function/GemmConvOpGpu.cu new file mode 100644 index 0000000000000..2a1795ff0fb56 --- /dev/null +++ b/paddle/function/GemmConvOpGpu.cu @@ -0,0 +1,186 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "ConvOp.h" +#include "GemmConvOp.h" + +namespace paddle { + +template +__global__ +void im2col(const T* data_im, int numOuts, int height, int width, + int blockH, int blockW, + int strideH, int strideW, + int paddingH, int paddingW, + int height_col, int width_col, + T* data_col) { + int index = + (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; + if (index < numOuts) { + int w_out = index % width_col; + index /= width_col; + int h_out = index % height_col; + int channel_in = index / height_col; + int channel_out = channel_in * blockH * blockW; + int h_in = h_out * strideH; + int w_in = w_out * strideW; + + data_col += (channel_out * height_col + h_out) * width_col + w_out; + for (int i = 0; i < blockH; ++i) { + for (int j = 0; j < blockW; ++j) { + int rIdx = int(h_in+i); + int cIdx = int(w_in+j); + if ((rIdx-(int)paddingH) >= (int)height || + (rIdx-(int)paddingH) < 0 || + (cIdx-(int)paddingW) >= (int)width || + (cIdx-(int)paddingW) < 0) { + *data_col = 0; + } else { + rIdx = rIdx + channel_in*height - paddingH; + cIdx = cIdx - paddingW; + *data_col = data_im[rIdx* width + cIdx]; + } + data_col += height_col * width_col; + } + } + } +} + +template +class Im2ColFunctor { +public: + void operator()(const T* imData, + int inputChannels, + int inputHeight, + int inputWidth, + int filterHeight, + int filterWidth, + int strideHeight, + int strideWidth, + int paddingHeight, + int paddingWidth, + int outputHeight, + int outputWidth, + T* colData) { + int numKernels = inputChannels * outputHeight * outputWidth; + int blocks = (numKernels + 1024 -1) / 1024; + int blockX = 512; + int blockY = (blocks + 512 - 1) / 512; + dim3 threads(1024, 1); + dim3 grid(blockX, blockY); + im2col<<< grid, threads, 0, STREAM_DEFAULT >>> + (imData, numKernels, inputHeight, inputWidth, filterHeight, filterWidth, + strideHeight, 
strideWidth, paddingHeight, paddingWidth, + outputHeight, outputWidth, colData); + CHECK_SYNC("Im2ColFunctor GPU failed"); + } +}; + +template +__global__ +void col2im(size_t n, const T* data_col, size_t height, + size_t width, size_t channels, + size_t blockH, size_t blockW, + size_t strideH, size_t strideW, + size_t paddingH, size_t paddingW, + size_t height_col, size_t width_col, + T* data_im) { + size_t index = + (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; + if (index < n) { + T val = 0; + int w = int(index % width); + int h = int((index / width) % height); + int c = int(index / (width * height)); + if ((w - (int)paddingW) >= 0 && + (w - (int)paddingW) < (width-2 * paddingW) && + (h - (int)paddingH) >= 0 && + (h - paddingH) < (height - 2 * paddingH)) { + // compute the start and end of the output + int w_col_start = + (w < (int)blockW) ? 0 : (w - int(blockW)) / (int)strideW + 1; + int w_col_end = + min((int)(w / (int)strideW + 1), (int)(width_col)); + int h_col_start = + (h < (int)blockH) ? 0 : (h - (int)blockH) / (int)strideH + 1; + int h_col_end = min(int(h / strideH + 1), int(height_col)); + for (int h_col = h_col_start; h_col < h_col_end; ++h_col) { + for (int w_col = w_col_start; w_col < w_col_end; ++w_col) { + // the col location: [c * width * height + h_out, w_out] + int c_col = int(c * blockH* blockW) + \ + (h - h_col * (int)strideH) * (int)blockW + + (w - w_col * (int)strideW); + val += data_col[(c_col * height_col + h_col) * width_col + w_col]; + } + } + h -= paddingH; + w -= paddingW; + data_im[c*((width-2*paddingW) * (height-2*paddingH)) + + h*(width-2*paddingW) + w] += val; + } + } +} + +template +class Col2ImFunctor { +public: + void operator()(const T* colData, + int inputChannels, + int inputHeight, + int inputWidth, + int filterHeight, + int filterWidth, + int strideHeight, + int strideWidth, + int paddingHeight, + int paddingWidth, + int outputHeight, + int outputWidth, + T* imData) { + size_t numKernels = inputChannels * (inputHeight + 2*paddingHeight) + * (inputWidth + 2*paddingWidth); + + size_t blocks = (numKernels + 1024 -1) / 1024; + size_t blockX = 512; + size_t blockY = (blocks+512-1)/512; + dim3 threads(1024, 1); + dim3 grid(blockX, blockY); + + // To avoid involving atomic operations, we will launch one kernel per + // bottom dimension, and then in the kernel add up the top dimensions. + col2im<<< grid, threads, 0, STREAM_DEFAULT >>> + (numKernels, + colData, + inputHeight + 2*paddingHeight, + inputWidth + 2*paddingWidth, + inputChannels, + filterHeight, + filterWidth, + strideHeight, + strideWidth, + paddingHeight, + paddingWidth, + outputHeight, + outputWidth, + imData); + CHECK_SYNC("Col2ImFunctor GPU failed"); + } +}; + +template class Im2ColFunctor; +template class Im2ColFunctor; +template class Col2ImFunctor; +template class Col2ImFunctor; + +} // namespace paddle diff --git a/paddle/function/GemmFunctor.h b/paddle/function/GemmFunctor.h new file mode 100644 index 0000000000000..d5db5cf5e7a85 --- /dev/null +++ b/paddle/function/GemmFunctor.h @@ -0,0 +1,96 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/math/MathFunctions.h"
+
+namespace paddle {
+
+// TODO(hedaoyuan): Since the hl_matrix_mul interface does not conform to the
+// cblas_dgemm interface's parameter format, it is necessary to introduce
+// GemmFunctor as a new interface. Later, when considering the implementation
+// of MatMulFunction, we need to consider the reconstruction of hl_matrix_mul
+// interface.
+template <DeviceType Device, class T>
+class GemmFunctor {
+public:
+  void operator()(const CBLAS_TRANSPOSE transA,
+                  const CBLAS_TRANSPOSE TransB,
+                  const int M,
+                  const int N,
+                  const int K,
+                  const T alpha,
+                  const T* A,
+                  const int lda,
+                  const T* B,
+                  const int ldb,
+                  const T beta,
+                  T* C,
+                  const int ldc);
+};
+
+template <class T>
+class GemmFunctor<DEVICE_TYPE_CPU, T> {
+public:
+  void operator()(const CBLAS_TRANSPOSE transA,
+                  const CBLAS_TRANSPOSE TransB,
+                  const int M,
+                  const int N,
+                  const int K,
+                  const T alpha,
+                  const T* A,
+                  const int lda,
+                  const T* B,
+                  const int ldb,
+                  const T beta,
+                  T* C,
+                  const int ldc) {
+    gemm<T>(transA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
+  }
+};
+
+template <class T>
+class GemmFunctor<DEVICE_TYPE_GPU, T> {
+public:
+  void operator()(const CBLAS_TRANSPOSE transA,
+                  const CBLAS_TRANSPOSE TransB,
+                  const int M,
+                  const int N,
+                  const int K,
+                  const T alpha,
+                  const T* A,
+                  const int lda,
+                  const T* B,
+                  const int ldb,
+                  const T beta,
+                  T* C,
+                  const int ldc) {
+    hl_matrix_mul((T*)A,
+                  transA == CblasNoTrans ? HPPL_OP_N : HPPL_OP_T,
+                  (T*)B,
+                  TransB == CblasNoTrans ?
HPPL_OP_N : HPPL_OP_T, + C, + M, + N, + K, + alpha, + beta, + lda, + ldb, + ldc); + } +}; + +} // namespace paddle diff --git a/paddle/function/MulOpTest.cpp b/paddle/function/MulOpTest.cpp index 8753057ebf73c..d31eb0c74f25f 100644 --- a/paddle/function/MulOpTest.cpp +++ b/paddle/function/MulOpTest.cpp @@ -35,7 +35,7 @@ void testFuncDDDMatrix( size_t heightC = dimM; size_t widthC = dimN; // init Test object - FunctionCompare test( + CpuGpuFuncCompare test( "MulOp", FuncConfig().set("aTrans", transa).set("bTrans", transb)); // prepare input arguments /// matrix A : HA * WA @@ -81,8 +81,8 @@ void testFuncDSparseDMatrix( size_t dimM, size_t dimN, size_t dimK, size_t nnz, SparseFormat FORMAT) { real scaleT = 1.0; // init Test object - FunctionCompare test("MulOp", - FuncConfig().set("aTrans", false).set("bTrans", false)); + CpuGpuFuncCompare test( + "MulOp", FuncConfig().set("aTrans", false).set("bTrans", false)); // prepare input arguments /// sparse matrix A : M * K test.addInputs(SparseMatrixArg( @@ -126,8 +126,8 @@ void testFuncDDSparseMatrix( size_t dimM, size_t dimN, size_t dimK, size_t nnz, SparseFormat FORMAT) { real scaleT = 1.0; // init Test object - FunctionCompare test("MulOp", - FuncConfig().set("aTrans", false).set("bTrans", false)); + CpuGpuFuncCompare test( + "MulOp", FuncConfig().set("aTrans", false).set("bTrans", false)); // prepare input arguments /// matrix A : M * K test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{dimM, dimK})); @@ -172,8 +172,8 @@ void testFuncSparseDDMatrix( size_t dimM, size_t dimN, size_t dimK, size_t nnz, SparseFormat FORMAT) { real scaleT = 1.0; // init Test object - FunctionCompare test("MulOp", - FuncConfig().set("aTrans", false).set("bTrans", false)); + CpuGpuFuncCompare test( + "MulOp", FuncConfig().set("aTrans", false).set("bTrans", false)); // prepare input arguments /// matrix A : M * K test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{dimM, dimK})); diff --git a/paddle/function/NaiveConvOp.cpp b/paddle/function/NaiveConvOp.cpp new file mode 100644 index 0000000000000..1d204f99e0e12 --- /dev/null +++ b/paddle/function/NaiveConvOp.cpp @@ -0,0 +1,137 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "ConvOp.h" + +namespace paddle { + +/* + * The three arguments are stored in memory in row major order. 
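+ * Element (i0, i1, i2, i3) of a [d0, d1, d2, d3] tensor therefore lives at
+ * linear offset ((i0 * d1 + i1) * d2 + i2) * d3 + i3.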
+ * inputData = [batchSize, inputChannels, inputHeight, inputWidth] + * filterData = [outputChannels, inputChannels, filterHeight, filterWidth] + * outputData = [batchSize, outputChannels, outputHeight, outputWidth] + */ +template +class NaiveConvFunctor { +public: + void operator()(const T* inputData, + size_t batchSize, + size_t inputChannels, + size_t inputHeight, + size_t inputWidth, + const T* filterData, + size_t filterHeight, + size_t filterWidth, + T* outputData, + size_t outputChannels, + size_t outputHeight, + size_t outputWidth, + size_t paddingH, + size_t paddingW, + size_t strideH, + size_t strideW) { + for (size_t batch = 0; batch < batchSize; batch++) { + for (size_t outC = 0; outC < outputChannels; outC++) { + for (size_t outH = 0; outH < outputHeight; outH++) { + for (size_t outW = 0; outW < outputWidth; outW++) { + const int inStartH = (outH * strideH) - paddingH; + const int inStartW = (outW * strideW) - paddingW; + T outValue = (T)0; + for (size_t inC = 0; inC < inputChannels; inC++) { + for (size_t fH = 0; fH < filterHeight; fH++) { + for (size_t fW = 0; fW < filterWidth; fW++) { + T inValue; + const int inH = inStartH + fH; + const int inW = inStartW + fW; + if ((inH >= 0 && inH < inputHeight) && + (inW >= 0 && inW < inputWidth)) { + size_t offsetInput = + batch * inputChannels * inputHeight * inputWidth + + inC * inputHeight * inputWidth + inH * inputWidth + inW; + inValue = inputData[offsetInput]; + } else { + inValue = (T)0; + } + size_t offsetFilter = + outC * inputChannels * filterHeight * filterWidth + + inC * filterHeight * filterWidth + fH * filterWidth + fW; + T filterValue = filterData[offsetFilter]; + outValue += (inValue * filterValue); + } + } + } + + size_t offset = + batch * outputChannels * outputHeight * outputWidth + + outC * outputHeight * outputWidth + outH * outputWidth + outW; + outputData[offset] = outValue; + } + } + } + } + } +}; + +template +class NaiveConvFunction : public ConvFunctionBase { +public: + void init(const FuncConfig& config) override { + ConvFunctionBase::init(config); + } + + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_EQ(numInputs_, inputs.size()); + CHECK_EQ(numOutputs_, outputs.size()); + const TensorShape& input = inputs[0].shape(); + const TensorShape& filter = inputs[1].shape(); + const TensorShape& output = outputs[0].shape(); + check(input, filter, output); + CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO); + + size_t batchSize = inputs[0].shape()[0]; + size_t inputChannels = inputs[0].shape()[1]; + size_t inputHeight = inputs[0].shape()[2]; + size_t inputWidth = inputs[0].shape()[3]; + size_t filterHeight = inputs[1].shape()[2]; + size_t filterWidth = inputs[1].shape()[3]; + size_t outputChannels = outputs[0].shape()[1]; + size_t outputHeight = outputs[0].shape()[2]; + size_t outputWidth = outputs[0].shape()[3]; + + real* inputData = inputs[0].data(); + real* filterData = inputs[1].data(); + real* outputData = outputs[0].data(); + NaiveConvFunctor conv; + conv(inputData, + batchSize, + inputChannels, + inputHeight, + inputWidth, + filterData, + filterHeight, + filterWidth, + outputData, + outputChannels, + outputHeight, + outputWidth, + paddingH(), + paddingW(), + strideH(), + strideW()); + } +}; + +REGISTER_TYPED_FUNC(NaiveConv, CPU, NaiveConvFunction); + +} // namespace paddle diff --git a/paddle/function/PadOpTest.cpp b/paddle/function/PadOpTest.cpp index f77ac2a8c49c8..e286f4e5b8a42 100644 --- a/paddle/function/PadOpTest.cpp +++ b/paddle/function/PadOpTest.cpp @@ -25,7 +25,7 @@ 
TEST(Pad, real) { VLOG(3) << " numSamples=" << numSamples << " channels=" << channels << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW; for (bool test_grad : {false, true}) { - FunctionCompare compare( + CpuGpuFuncCompare compare( test_grad ? "PadGrad" : "Pad", FuncConfig() .set>("channel", {2, 3}) diff --git a/paddle/function/RowConvOpTest.cpp b/paddle/function/RowConvOpTest.cpp index 1c95d3ff2cccb..f52d18b0491ec 100644 --- a/paddle/function/RowConvOpTest.cpp +++ b/paddle/function/RowConvOpTest.cpp @@ -18,7 +18,7 @@ limitations under the License. */ namespace paddle { void testRowConvFw(size_t batchSize, size_t dim, size_t contextLength) { - FunctionCompare test("RowConv", FuncConfig()); + CpuGpuFuncCompare test("RowConv", FuncConfig()); test.addSequence(SequenceIdArg(TensorShape{batchSize})); test.addInputs(SequenceArg(VALUE_TYPE_FLOAT, TensorShape{batchSize, dim})); @@ -31,7 +31,7 @@ void testRowConvFw(size_t batchSize, size_t dim, size_t contextLength) { } void testRowConvBw(size_t batchSize, size_t dim, size_t contextLength) { - FunctionCompare test("RowConvGrad", FuncConfig()); + CpuGpuFuncCompare test("RowConvGrad", FuncConfig()); test.addSequence(SequenceIdArg(TensorShape{batchSize})); test.addInputs(SequenceArg(VALUE_TYPE_FLOAT, TensorShape{batchSize, dim})); diff --git a/paddle/gserver/layers/ConvBaseLayer.cpp b/paddle/gserver/layers/ConvBaseLayer.cpp index 7b234dc2a6663..e161d89c38a29 100644 --- a/paddle/gserver/layers/ConvBaseLayer.cpp +++ b/paddle/gserver/layers/ConvBaseLayer.cpp @@ -118,11 +118,7 @@ size_t ConvBaseLayer::calOutputSize() { layerSize = outH[0] * outW[0] * size_t(numFilters_); }; - if (isDeconv_) { - setLayerSize(outputH_, outputW_, imgSizeH_, imgSizeW_); - } else { - setLayerSize(imgSizeH_, imgSizeW_, outputH_, outputW_); - } + setLayerSize(imgSizeH_, imgSizeW_, outputH_, outputW_); return layerSize; } diff --git a/paddle/gserver/layers/CudnnConvBaseLayer.cpp b/paddle/gserver/layers/CudnnConvBaseLayer.cpp index 24363bb8b09cc..c056bbe4d1d35 100644 --- a/paddle/gserver/layers/CudnnConvBaseLayer.cpp +++ b/paddle/gserver/layers/CudnnConvBaseLayer.cpp @@ -70,14 +70,8 @@ void CudnnConvBaseLayer::forward(PassType passType) { if (biases_) { REGISTER_TIMER_INFO("CudnnConvBiasTimer", getName().c_str()); int batchSize = inputLayers_[0]->getOutputValue()->getHeight(); - int outH, outW; - if (isDeconv_) { - outH = imgSizeH_[0]; - outW = imgSizeW_[0]; - } else { - outH = outputH_[0]; - outW = outputW_[0]; - } + int outH = outputH_[0]; + int outW = outputW_[0]; hl_tensor_reshape(outputDesc_, batchSize, diff --git a/paddle/gserver/layers/ExpandConvBaseLayer.cpp b/paddle/gserver/layers/ExpandConvBaseLayer.cpp index fdcf994cdb47f..77736e78f9349 100644 --- a/paddle/gserver/layers/ExpandConvBaseLayer.cpp +++ b/paddle/gserver/layers/ExpandConvBaseLayer.cpp @@ -22,26 +22,8 @@ bool ExpandConvBaseLayer::init(const LayerMap &layerMap, /* Initialize the basic convolutional parent class */ ConvBaseLayer::init(layerMap, parameterMap); - /* The class fields channels_ and numFilters_ are the same as in the config - * i.e., channels_ is the for the input and numFilters_ is for the output - * - * But in order for the variables in convTrans having the same semantic - * meaning as in conv, we need to swap channels_ and numFilters here for - * convTrans, and in other functions too. - * */ - - /* Initialize the projection */ for (auto &inputConfig : config_.inputs()) { const ConvConfig &conf = inputConfig.conv_conf(); - int numFilters = isDeconv_ ? 
diff --git a/paddle/gserver/layers/ExpandConvBaseLayer.cpp b/paddle/gserver/layers/ExpandConvBaseLayer.cpp
index fdcf994cdb47f..77736e78f9349 100644
--- a/paddle/gserver/layers/ExpandConvBaseLayer.cpp
+++ b/paddle/gserver/layers/ExpandConvBaseLayer.cpp
@@ -22,26 +22,8 @@ bool ExpandConvBaseLayer::init(const LayerMap &layerMap,
   /* Initialize the basic convolutional parent class */
   ConvBaseLayer::init(layerMap, parameterMap);
 
-  /* The class fields channels_ and numFilters_ are the same as in the config
-   * i.e., channels_ is the for the input and numFilters_ is for the output
-   *
-   * But in order for the variables in convTrans having the same semantic
-   * meaning as in conv, we need to swap channels_ and numFilters here for
-   * convTrans, and in other functions too.
-   * */
-
-  /* Initialize the projection */
   for (auto &inputConfig : config_.inputs()) {
     const ConvConfig &conf = inputConfig.conv_conf();
-    int numFilters = isDeconv_ ? conf.channels() : numFilters_;
-    subM_.push_back(numFilters / conf.groups());
-    subN_.push_back(conf.output_x() *
-                    (conf.has_output_y() ? conf.output_y() : conf.output_x()));
-    int channel = isDeconv_ ? numFilters_ : conf.channels();
-    subK_.push_back(
-        channel * conf.filter_size() *
-        (conf.has_filter_size_y() ? conf.filter_size_y() : conf.filter_size()) /
-        conf.groups());
 
     /* Consistent caffe mode for multiple input */
     caffeMode_ = conf.caffe_mode();
   }
@@ -54,17 +36,9 @@ bool ExpandConvBaseLayer::init(const LayerMap &layerMap,
 size_t ExpandConvBaseLayer::getOutputSize() {
   CHECK_NE(inputLayers_.size(), 0UL);
   size_t layerSize = ConvBaseLayer::calOutputSize();
-  subN_.clear();
-  for (size_t i = 0; i < inputLayers_.size(); i++) {
-    subN_.push_back(outputH_[i] * outputW_[i]);
-  }
   return layerSize;
 }
 
-void ExpandConvBaseLayer::resetExpandInput(size_t height, size_t width) {
-  Matrix::resizeOrCreate(expandInput_, height, width, false, useGpu_);
-}
-
 void ExpandConvBaseLayer::addSharedBias() {
   size_t mapW = getOutputSize() / numFilters_;
   size_t mapH = getOutputValue()->getElementCnt() / mapW;
@@ -101,173 +75,6 @@ void ExpandConvBaseLayer::addUnsharedBias() {
   outValue->addBias(*bias, 1.0f);
 }
 
-void ExpandConvBaseLayer::expandOneFrame(MatrixPtr image,
-                                         size_t startIdx,
-                                         int inIdx) {
-  int channel = isDeconv_ ? numFilters_ : channels_[inIdx];
-
-  resetExpandInput(subK_[inIdx] * groups_[inIdx], subN_[inIdx]);
-
-  CHECK_EQ(image->getWidth(),
-           static_cast<size_t>(imgSizeH_[inIdx] * imgSizeW_[inIdx] * channel));
-
-  real *imgData = image->getData() + startIdx * image->getWidth();
-  MatrixPtr imageTmp =
-      Matrix::create(imgData,
-                     1,
-                     imgSizeH_[inIdx] * imgSizeW_[inIdx] * channel,
-                     false,
-                     useGpu_);
-  expandInput_->convExpand(*imageTmp,
-                           imgSizeH_[inIdx],
-                           imgSizeW_[inIdx],
-                           channel,
-                           filterSizeY_[inIdx],
-                           filterSize_[inIdx],
-                           strideY_[inIdx],
-                           stride_[inIdx],
-                           paddingY_[inIdx],
-                           padding_[inIdx],
-                           outputH_[inIdx],
-                           outputW_[inIdx]);
-  imageTmp->clear();
-}
-
-void ExpandConvBaseLayer::expandFwdOnce(MatrixPtr image,
-                                        MatrixPtr out,
-                                        int inIdx,
-                                        int startIdx) {
-  int subM = subM_[inIdx];
-  int subN = subN_[inIdx];
-  int subK = subK_[inIdx];
-
-  expandOneFrame(image, startIdx, inIdx);
-
-  int numFilters = isDeconv_ ? channels_[inIdx] : numFilters_;
-
-  real *outData = out->getData() + startIdx * subN * numFilters;
-
-  real *wgtData = weights_[inIdx]->getW()->getData();
-  real *expInData = expandInput_->getData();
-  for (int g = 0; g < groups_[inIdx]; ++g) {
-    MatrixPtr A =
-        Matrix::create(wgtData, subM, subK, false, useGpu_);  // mark transpose
-    MatrixPtr B = Matrix::create(expInData, subK, subN, false, useGpu_);
-    MatrixPtr C = Matrix::create(outData, subM, subN, false, useGpu_);
-    C->mul(*A, *B, 1, 1);
-
-    A->clear();
-    B->clear();
-    C->clear();
-    wgtData += subK * subM;
-    expInData += subK * subN;
-    outData += subM * subN;
-  }
-}
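Note: the deleted expandOneFrame/expandFwdOnce pair is the classic
im2col + GEMM forward pass (convExpand is Paddle's im2col kernel). A
self-contained sketch of the im2col step for a single [C, H, W] image,
assuming unit dilation (all names below are mine, not Paddle API):

    #include <cstddef>
    #include <vector>

    // Unfold an image [C, H, W] into a column matrix [C*fH*fW, outH*outW] so
    // that convolution reduces to one matrix multiply per group.
    void im2col(const float* img, size_t C, size_t H, size_t W,
                size_t fH, size_t fW, size_t strideH, size_t strideW,
                size_t padH, size_t padW, size_t outH, size_t outW,
                std::vector<float>& col) {
      col.assign(C * fH * fW * outH * outW, 0.f);
      for (size_t c = 0; c < C; ++c)
        for (size_t kh = 0; kh < fH; ++kh)
          for (size_t kw = 0; kw < fW; ++kw)
            for (size_t oh = 0; oh < outH; ++oh)
              for (size_t ow = 0; ow < outW; ++ow) {
                const int ih = (int)(oh * strideH + kh) - (int)padH;
                const int iw = (int)(ow * strideW + kw) - (int)padW;
                const size_t row = (c * fH + kh) * fW + kw;
                if (ih >= 0 && ih < (int)H && iw >= 0 && iw < (int)W) {
                  col[row * outH * outW + oh * outW + ow] =
                      img[(c * H + ih) * W + iw];
                }
              }
    }

The per-group GEMM in expandFwdOnce then computes
output (subM x subN) = W (subM x subK) * col (subK x subN), which is the
C->mul(*A, *B, 1, 1) product above.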
-
-void ExpandConvBaseLayer::bpropActs(MatrixPtr out,
-                                    MatrixPtr image,
-                                    int inpIdx) {
-  int channel = isDeconv_ ? numFilters_ : channels_[inpIdx];
-
-  int subM = subM_[inpIdx];
-  int subN = subN_[inpIdx];
-  int subK = subK_[inpIdx];
-  size_t batchSize = image->getHeight();
-
-  /* reset the expand-grad memory */
-  resetExpandInput(subK * groups_[inpIdx], subN);
-
-  real *localGradData = out->getData();
-  real *tgtGradData = image->getData();
-  for (size_t n = 0; n < batchSize; n++) {
-    real *wgtData = weights_[inpIdx]->getW()->getData();
-    real *expandInData = expandInput_->getData();
-
-    for (int g = 0; g < groups_[inpIdx]; g++) {
-      // create temporary matrix
-      MatrixPtr C = Matrix::create(expandInData, subK, subN, false, useGpu_);
-      MatrixPtr B = Matrix::create(localGradData, subM, subN, false, useGpu_);
-      MatrixPtr A = Matrix::create(wgtData, subM, subK, true, useGpu_);
-      C->mul(*A, *B);  // mul
-
-      // clear the temporary matrix
-      A->clear();
-      B->clear();
-      C->clear();
-
-      expandInData += subK * subN;
-      localGradData += subM * subN;
-      wgtData += subK * subM;
-    }
-
-    // shrink one frame outGrad
-    MatrixPtr oneGradTmp = Matrix::create(
-        expandInput_->getData(), subK * groups_[inpIdx], subN, false, useGpu_);
-    MatrixPtr vTmp =
-        Matrix::create(tgtGradData,
-                       1,
-                       imgSizeH_[inpIdx] * imgSizeW_[inpIdx] * channel,
-                       false,
-                       useGpu_);
-    vTmp->convShrink(*oneGradTmp,
-                     imgSizeH_[inpIdx],
-                     imgSizeW_[inpIdx],
-                     channel,
-                     filterSizeY_[inpIdx],
-                     filterSize_[inpIdx],
-                     strideY_[inpIdx],
-                     stride_[inpIdx],
-                     paddingY_[inpIdx],
-                     padding_[inpIdx],
-                     outputH_[inpIdx],
-                     outputW_[inpIdx],
-                     1.0f,
-                     1.0f);
-    vTmp->clear();
-    oneGradTmp->clear();
-
-    // move the data-pointer
-    tgtGradData += imgSizeH_[inpIdx] * imgSizeW_[inpIdx] * channel;
-  }
-}
-
-void ExpandConvBaseLayer::bpropWeights(MatrixPtr image,
-                                       MatrixPtr out,
-                                       int inpIdx) {
-  MatrixPtr weightGrad = weights_[inpIdx]->getWGrad();
-
-  int subM = subM_[inpIdx];
-  int subN = subN_[inpIdx];
-  int subK = subK_[inpIdx];
-  size_t batchSize = image->getHeight();
-  resetExpandInput(subK * groups_[inpIdx], subN);
-
-  real *gradData = out->getData();
-
-  for (size_t n = 0; n < batchSize; n++) {  // frame by frame
-    // expand
-    expandOneFrame(image, n, inpIdx);
-    real *wGradData = weightGrad->getData();
-    real *expandInData = expandInput_->getData();
-
-    // expand-mul one-group by one
-    for (int g = 0; g < groups_[inpIdx]; g++) {
-      MatrixPtr A = Matrix::create(expandInData, subK, subN, true, useGpu_);
-      MatrixPtr B = Matrix::create(gradData, subM, subN, false, useGpu_);
-      MatrixPtr C = Matrix::create(wGradData, subM, subK, false, useGpu_);
-      C->mul(*B, *A, 1, 1);
-
-      A->clear();
-      B->clear();
-      C->clear();
-      gradData += subM * subN;
-      wGradData += subK * subM;
-      expandInData += subK * subN;
-    }
-  }
-}
-
 void ExpandConvBaseLayer::bpropSharedBias(MatrixPtr biases, MatrixPtr v) {
   size_t mapW = getOutputSize() / numFilters_;
   size_t mapH = v->getElementCnt() / mapW;
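Note: for readers tracking the deleted backward GEMMs, the per-group shapes
are (my summary, derived from the code above):

    // bpropActs:    colGrad (subK x subN) = W^T (subK x subM) * outGrad (subM x subN),
    //               then convShrink (col2im) scatters colGrad onto the image grad.
    // bpropWeights: wGrad (subM x subK) += outGrad (subM x subN) * col^T (subN x subK),
    //               with col produced by expandOneFrame (im2col), frame by frame.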
diff --git a/paddle/gserver/layers/ExpandConvBaseLayer.h b/paddle/gserver/layers/ExpandConvBaseLayer.h
index aabcdfc392d3e..01c699d234444 100644
--- a/paddle/gserver/layers/ExpandConvBaseLayer.h
+++ b/paddle/gserver/layers/ExpandConvBaseLayer.h
@@ -26,19 +26,6 @@ namespace paddle {
  */
 class ExpandConvBaseLayer : public ConvBaseLayer {
 protected:
-  /// For expand convolution.
-  /// subM_ = numFilters_ / groups_.
-  IntV subM_;
-  /// subN_ = outputH_ * outputW_.
-  IntV subN_;
-  /// subK_ = channels_ * filterPixels_ * groups_.
-  IntV subK_;
-
-  /*The expandInput_ and transOutValue_ are used for CPU expand conv calc
-   * Expand one sample at a time. shape:
-   * (numChannels * filterPixels_, outputSizeH * outputSizeW)
-   * */
-  MatrixPtr expandInput_;
   /// The transpose of output, which is an auxiliary matrix.
   MatrixPtr transOutValue_;
@@ -52,10 +39,6 @@ class ExpandConvBaseLayer : public ConvBaseLayer {
                     const ParameterMap& parameterMap) override;
 
   size_t getOutputSize();
-  /**
-   * Create or resize expandInput_.
-   */
-  void resetExpandInput(size_t height, size_t width);
 
   /**
    * Add shared bias.
@@ -66,20 +49,9 @@ class ExpandConvBaseLayer : public ConvBaseLayer {
    * Add unshared bias.
    */
   void addUnsharedBias();
-  /**
-   * Expand one input sample.
-   */
-  void expandOneFrame(MatrixPtr image, size_t startIdx, int inIdx);
-
-  /**
-   * Expand one input sample and perform matrix multiplication.
-   */
-  void expandFwdOnce(MatrixPtr image, MatrixPtr out, int inIdx, int startIdx);
 
   void bpropSharedBias(MatrixPtr biases, MatrixPtr v);
   void bpropBiases(MatrixPtr v);
-  void bpropWeights(MatrixPtr image, MatrixPtr out, int inpIdx);
-  void bpropActs(MatrixPtr image, MatrixPtr out, int inpIdx);
 };
 
 }  // namespace paddle
diff --git a/paddle/gserver/layers/ExpandConvLayer.cpp b/paddle/gserver/layers/ExpandConvLayer.cpp
index f9267b81a7d42..914689e66cdb8 100644
--- a/paddle/gserver/layers/ExpandConvLayer.cpp
+++ b/paddle/gserver/layers/ExpandConvLayer.cpp
@@ -18,32 +18,94 @@ limitations under the License. */
 
 namespace paddle {
 
+/*
+ * The exconvt layer (convolution transpose, i.e. deconv) is computed by
+ * swapping the forward and backward passes of the exconv calculation.
+ */
 REGISTER_LAYER(exconv, ExpandConvLayer);
+REGISTER_LAYER(exconvt, ExpandConvLayer);
 
 bool ExpandConvLayer::init(const LayerMap &layerMap,
                            const ParameterMap &parameterMap) {
   /* Initialize the basic convolutional parent class */
   ExpandConvBaseLayer::init(layerMap, parameterMap);
+
+  size_t numInputs = config_.inputs_size();
+  inputShape_.resize(numInputs);
+  filterShape_.resize(numInputs);
+  outputShape_.resize(numInputs);
+  for (int i = 0; i < config_.inputs_size(); i++) {
+    std::vector<size_t> paddings = {(size_t)paddingY_[i], (size_t)padding_[i]};
+    std::vector<size_t> strides = {(size_t)strideY_[i], (size_t)stride_[i]};
+    createFunction(forward_,
+                   !isDeconv_ ? "GemmConv" : "GemmConvGradInput",
+                   FuncConfig()
+                       .set("paddings", paddings)
+                       .set("strides", strides)
+                       .set("groups", (size_t)groups_[i]));
+
+    createFunction(backward_,
+                   !isDeconv_ ? "GemmConvGradInput" : "GemmConv",
+                   FuncConfig()
+                       .set("paddings", paddings)
+                       .set("strides", strides)
+                       .set("groups", (size_t)groups_[i]));
+
+    createFunction(backward_,
+                   "GemmConvGradFilter",
+                   FuncConfig()
+                       .set("paddings", paddings)
+                       .set("strides", strides)
+                       .set("groups", (size_t)groups_[i]));
+  }
   return true;
 }
 
+// i is the index of input layers
+#define BACKWARD_INPUT(i, inputs, outputs) \
+  backward_[2 * i]->calc(inputs, outputs)
+#define BACKWARD_FILTER(i, inputs, outputs) \
+  backward_[2 * i + 1]->calc(inputs, outputs)
+
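Note: a compact reference for the forward/backward swap wired up in init()
above (a summary, not code from the patch):

    // exconv  : forward -> GemmConv            backward(input) -> GemmConvGradInput
    // exconvt : forward -> GemmConvGradInput   backward(input) -> GemmConv
    // either  : backward(filter) -> GemmConvGradFilter
    //
    // Per input i, the two backward Functions are registered in this order,
    // hence backward_[2 * i] in BACKWARD_INPUT and backward_[2 * i + 1] in
    // BACKWARD_FILTER.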
 void ExpandConvLayer::forward(PassType passType) {
   Layer::forward(passType);
 
-  /* malloc memory for the output_ if necessary */
-  int batchSize = inputLayers_[0]->getOutputValue()->getHeight();
+  size_t batchSize = inputLayers_[0]->getOutputValue()->getHeight();
   resetOutput(batchSize, getOutputSize());
 
-  MatrixPtr image = nullptr;
-  MatrixPtr outV = getOutputValue();
+  // Calculate the shapes of the input, output, and filter.
   for (size_t i = 0; i < inputLayers_.size(); ++i) {
-    LayerPtr prevLayer = getPrev(i);
-    image = prevLayer->getOutputValue();
-    for (size_t off = 0; off < image->getHeight(); off++) {
-      REGISTER_TIMER_INFO("expandFwdOnce", getName().c_str());
-      expandFwdOnce(image, outV, i, off);
-    }
+    inputShape_[i] = TensorShape({(size_t)batchSize,
+                                  (size_t)channels_[i],
+                                  (size_t)imgSizeH_[i],
+                                  (size_t)imgSizeW_[i]});
+    filterShape_[i] =
+        TensorShape({(size_t)groups_[i],
+                     !isDeconv_ ? (size_t)numFilters_ / groups_[i]
+                                : (size_t)channels_[i] / groups_[i],
+                     !isDeconv_ ? (size_t)channels_[i] / groups_[i]
+                                : (size_t)numFilters_ / groups_[i],
+                     (size_t)filterSizeY_[i],
+                     (size_t)filterSize_[i]});
+    outputShape_[i] = TensorShape({(size_t)batchSize,
+                                   (size_t)numFilters_,
+                                   (size_t)outputH_[i],
+                                   (size_t)outputW_[i]});
   }
+
+  // Calculate the output value.
+  for (size_t i = 0; i < inputLayers_.size(); ++i) {
+    BufferArgs inputs;
+    BufferArgs outputs;
+    inputs.addArg(*getInputValue(i), inputShape_[i]);
+    inputs.addArg(*weights_[i]->getW(), filterShape_[i]);
+    outputs.addArg(*getOutputValue(),
+                   outputShape_[i],
+                   !isDeconv_ && i == 0 ? ASSIGN_TO : ADD_TO);
+
+    forward_[i]->calc(inputs, outputs);
+  }
 
   /* add the bias-vector */
   if (biases_.get()) {
     if (sharedBiases_) {
@@ -67,14 +129,30 @@ void ExpandConvLayer::backward(const UpdateCallback &callback) {
     biases_->getParameterPtr()->incUpdate(callback);
   }
 
+  // Calculate the input grad and filter grad.
   for (size_t i = 0; i < inputLayers_.size(); ++i) {
-    /* First, calculate the input layers error */
-    if (getPrev(i)->getOutputGrad()) {
-      bpropActs(outGrad, getPrev(i)->getOutputGrad(), i);
+    if (getInputGrad(i)) {
+      BufferArgs inputs;
+      BufferArgs outputs;
+      inputs.addArg(*getOutputGrad(), outputShape_[i]);
+      inputs.addArg(*weights_[i]->getW(), filterShape_[i]);
+      outputs.addArg(*getInputGrad(i), inputShape_[i], ADD_TO);
+      BACKWARD_INPUT(i, inputs, outputs);
     }
+
     if (weights_[i]->getWGrad()) {
-      /* Then, calculate the W-gradient for the current layer */
-      bpropWeights(getPrev(i)->getOutputValue(), outGrad, i);
+      BufferArgs inputs;
+      BufferArgs outputs;
+      if (!isDeconv_) {
+        inputs.addArg(*getOutputGrad(), outputShape_[i]);
+        inputs.addArg(*getInputValue(i), inputShape_[i]);
+      } else {
+        inputs.addArg(*getInputValue(i), inputShape_[i]);
+        inputs.addArg(*getOutputGrad(), outputShape_[i]);
+      }
+      outputs.addArg(*weights_[i]->getWGrad(), filterShape_[i], ADD_TO);
+      BACKWARD_FILTER(i, inputs, outputs);
+
       /* Increasing the number of gradient */
       weights_[i]->getParameterPtr()->incUpdate(callback);
     }
diff --git a/paddle/gserver/layers/ExpandConvLayer.h b/paddle/gserver/layers/ExpandConvLayer.h
index 60681690e5dd5..a1f943d152154 100644
--- a/paddle/gserver/layers/ExpandConvLayer.h
+++ b/paddle/gserver/layers/ExpandConvLayer.h
@@ -40,6 +40,11 @@ class ExpandConvLayer : public ExpandConvBaseLayer {
 
   void forward(PassType passType) override;
   void backward(const UpdateCallback& callback) override;
+
+protected:
+  std::vector<TensorShape> inputShape_;
+  std::vector<TensorShape> filterShape_;
+  std::vector<TensorShape> outputShape_;
 };
 
 }  // namespace paddle
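Note: the five-dimensional filterShape_ built above follows the GMCHW layout,
with the M and C slots swapped for deconv. A worked example with hypothetical
numbers:

    // groups = 2, channels = 4, numFilters = 6, 3x3 filter:
    // conv   (exconv):  filterShape_ = {2, 6 / 2, 4 / 2, 3, 3} = {2, 3, 2, 3, 3}
    // deconv (exconvt): filterShape_ = {2, 4 / 2, 6 / 2, 3, 3} = {2, 2, 3, 3, 3}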
diff --git a/paddle/gserver/layers/ExpandConvTransLayer.cpp b/paddle/gserver/layers/ExpandConvTransLayer.cpp
deleted file mode 100644
index 520586b138897..0000000000000
--- a/paddle/gserver/layers/ExpandConvTransLayer.cpp
+++ /dev/null
@@ -1,90 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "ExpandConvTransLayer.h"
-#include "paddle/utils/Logging.h"
-#include "paddle/utils/Stat.h"
-
-/* The implementation of the convTransLayer is basically a swap of forward and
- * backward of the original convLayer.
- * The variable naming follows the convention of the convLayer.
- * */
-
-namespace paddle {
-
-REGISTER_LAYER(exconvt, ExpandConvTransLayer);
-
-bool ExpandConvTransLayer::init(const LayerMap &layerMap,
-                                const ParameterMap &parameterMap) {
-  /* Initialize the basic convolutional parent class */
-  ExpandConvBaseLayer::init(layerMap, parameterMap);
-
-  return true;
-}
-
-void ExpandConvTransLayer::forward(PassType passType) {
-  Layer::forward(passType);
-
-  /* malloc memory for the output_ if necessary */
-  int batchSize = inputLayers_[0]->getOutputValue()->getHeight();
-  resetOutput(batchSize, getOutputSize());
-
-  MatrixPtr output = nullptr;
-  for (size_t i = 0; i < inputLayers_.size(); ++i) {
-    LayerPtr prevLayer = getPrev(i);
-    output = prevLayer->getOutputValue();
-    REGISTER_TIMER_INFO("shrinkFwd", getName().c_str());
-    bpropActs(output, getOutputValue(), i);
-  }
-
-  /* add the bias-vector */
-  if (biases_.get()) {
-    if (sharedBiases_) {
-      addSharedBias();
-    } else {
-      addUnsharedBias();
-    }
-  }
-
-  /* activation */
-  forwardActivation();
-}
-
-void ExpandConvTransLayer::backward(const UpdateCallback &callback) {
-  backwardActivation();
-
-  MatrixPtr imageGrad = getOutputGrad();
-  if (biases_ && biases_->getWGrad()) {
-    bpropBiases(imageGrad);
-    /* Increasing the number of gradient */
-    biases_->getParameterPtr()->incUpdate(callback);
-  }
-
-  for (size_t i = 0; i < inputLayers_.size(); ++i) {
-    /* First, calculate the input layers error */
-    for (size_t off = 0; off < imageGrad->getHeight(); off++) {
-      if (getPrev(i)->getOutputGrad()) {
-        expandFwdOnce(imageGrad, getPrev(i)->getOutputGrad(), i, off);
-      }
-    }
-    if (weights_[i]->getWGrad()) {
-      /* Then, calculate the W-gradient for the current layer */
-      bpropWeights(imageGrad, getPrev(i)->getOutputValue(), i);
-      /* Increasing the number of gradient */
-      weights_[i]->getParameterPtr()->incUpdate(callback);
-    }
-  }
-}
-
-}  // namespace paddle
diff --git a/paddle/gserver/layers/ExpandConvTransLayer.h b/paddle/gserver/layers/ExpandConvTransLayer.h
deleted file mode 100644
index 00b8f241889fd..0000000000000
--- a/paddle/gserver/layers/ExpandConvTransLayer.h
+++ /dev/null
@@ -1,44 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-
-#include <vector>
-#include "ExpandConvBaseLayer.h"
-#include "paddle/math/Matrix.h"
-
-namespace paddle {
-
-/**
- * @brief A subclass of convolution layer.
- * This layer expands input and use matrix multiplication to
- * calculate convolution transpose (deconv) operation.
- *
- * The config file api is img_conv_layer with flag trans=True.
- */
-class ExpandConvTransLayer : public ExpandConvBaseLayer {
-public:
-  explicit ExpandConvTransLayer(const LayerConfig& config)
-      : ExpandConvBaseLayer(config) {}
-
-  ~ExpandConvTransLayer() {}
-
-  bool init(const LayerMap& layerMap,
-            const ParameterMap& parameterMap) override;
-
-  void forward(PassType passType) override;
-  void backward(const UpdateCallback& callback) override;
-};
-
-}  // namespace paddle
diff --git a/paddle/gserver/tests/test_BatchNorm.cpp b/paddle/gserver/tests/test_BatchNorm.cpp
index d07299bfe3c41..83fcfed46cd56 100644
--- a/paddle/gserver/tests/test_BatchNorm.cpp
+++ b/paddle/gserver/tests/test_BatchNorm.cpp
@@ -17,7 +17,6 @@ limitations under the License. */
 #include <vector>
 #include "ModelConfig.pb.h"
 #include "paddle/gserver/layers/DataLayer.h"
-#include "paddle/gserver/layers/ExpandConvTransLayer.h"
 #include "paddle/trainer/Trainer.h"
 #include "paddle/utils/GlobalConstants.h"
diff --git a/paddle/gserver/tests/test_ConvTrans.cpp b/paddle/gserver/tests/test_ConvTrans.cpp
index 40bb1e2d73c81..6035a866b4eee 100644
--- a/paddle/gserver/tests/test_ConvTrans.cpp
+++ b/paddle/gserver/tests/test_ConvTrans.cpp
@@ -17,7 +17,6 @@ limitations under the License. */
 #include <vector>
 #include "ModelConfig.pb.h"
 #include "paddle/gserver/layers/DataLayer.h"
-#include "paddle/gserver/layers/ExpandConvTransLayer.h"
 #include "paddle/math/MathUtils.h"
 #include "paddle/trainer/Trainer.h"
 #include "paddle/utils/GlobalConstants.h"
diff --git a/paddle/gserver/tests/test_ConvUnify.cpp b/paddle/gserver/tests/test_ConvUnify.cpp
index 54b72375b743f..e7325e0cc3b71 100644
--- a/paddle/gserver/tests/test_ConvUnify.cpp
+++ b/paddle/gserver/tests/test_ConvUnify.cpp
@@ -17,7 +17,6 @@ limitations under the License. */
 #include <vector>
 #include "ModelConfig.pb.h"
 #include "paddle/gserver/layers/DataLayer.h"
-#include "paddle/gserver/layers/ExpandConvTransLayer.h"
 #include "paddle/math/MathUtils.h"
 #include "paddle/trainer/Trainer.h"
 #include "paddle/utils/GlobalConstants.h"