diff --git a/cmake/util.cmake b/cmake/util.cmake
index 38366373c6dbc..8a71b23c62d9f 100644
--- a/cmake/util.cmake
+++ b/cmake/util.cmake
@@ -96,6 +96,7 @@ function(link_paddle_exe TARGET_NAME)
     target_circle_link_libraries(${TARGET_NAME}
         ARCHIVE_START
         paddle_gserver
+        paddle_function
         ${METRIC_LIBS}
         ARCHIVE_END
         paddle_pserver
@@ -106,6 +107,7 @@ function(link_paddle_exe TARGET_NAME)
         paddle_parameter
         paddle_proto
         paddle_cuda
+        paddle_test_main
         ${METRIC_LIBS}
         ${PROTOBUF_LIBRARY}
         ${LIBGLOG_LIBRARY}
diff --git a/paddle/CMakeLists.txt b/paddle/CMakeLists.txt
index fb3af8ea92fee..2daea052b01ad 100644
--- a/paddle/CMakeLists.txt
+++ b/paddle/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_subdirectory(cuda)
+add_subdirectory(function)
 add_subdirectory(utils)
 add_subdirectory(math)
 add_subdirectory(parameter)
diff --git a/paddle/api/CMakeLists.txt b/paddle/api/CMakeLists.txt
index 6ad1d79e59b11..ed69bd764f30a 100644
--- a/paddle/api/CMakeLists.txt
+++ b/paddle/api/CMakeLists.txt
@@ -46,6 +46,7 @@ add_custom_command(OUTPUT ${PROJ_ROOT}/paddle/dist/.timestamp
     WORKING_DIRECTORY ${PROJ_ROOT}/paddle
     DEPENDS python_swig_sources
             paddle_parameter
+            paddle_function
             paddle_math
             paddle_utils
             paddle_gserver
diff --git a/paddle/api/paddle_ld_flags.py b/paddle/api/paddle_ld_flags.py
index 51d7dfee58b78..7c8206e3fe097 100644
--- a/paddle/api/paddle_ld_flags.py
+++ b/paddle/api/paddle_ld_flags.py
@@ -30,8 +30,8 @@
 whole_end = ""
 
 LIB_DIRS = [
-    "math", 'utils', 'parameter', "gserver", "api", "cuda", "pserver",
-    "trainer"
+    "math", 'function', 'utils', 'parameter', "gserver", "api", "cuda",
+    "pserver", "trainer"
 ]
 PARENT_LIB_DIRS = ['proto']
 
@@ -75,6 +75,7 @@ def libs_str(self):
         libs = [
             whole_start,
             "-lpaddle_gserver",
+            "-lpaddle_function",
             whole_end,
             "-lpaddle_pserver",
             "-lpaddle_trainer_lib",
diff --git a/paddle/cuda/include/hl_cnn.h b/paddle/cuda/include/hl_cnn.h
index 06ee3b3654b57..c5787630abbe1 100644
--- a/paddle/cuda/include/hl_cnn.h
+++ b/paddle/cuda/include/hl_cnn.h
@@ -240,62 +240,6 @@ extern void hl_avgpool_backward(const int frameCnt,
                                 real* backGrad,
                                 const int outStride);
 
-/**
- * @brief   Cross-map-respose normalize forward.
- *
- * @param[in]   frameCnt   batch size of input image.
- * @param[in]   in         input data.
- * @param[in]   scale      buffer.
- * @param[out]  out        output data.
- * @param[in]   channels   number of channel.
- * @param[in]   height     image height.
- * @param[in]   width      image width.
- * @param[in]   sizeX      size.
- * @param[in]   alpha      scale.
- * @param[in]   beta       scale.
- *
- */
-extern void hl_CMRNorm_forward(size_t frameCnt,
-                               const real* in,
-                               real* scale,
-                               real* out,
-                               size_t channels,
-                               size_t height,
-                               size_t width,
-                               size_t sizeX,
-                               real alpha,
-                               real beta);
-
-/**
- * @brief   Cross-map-respose normalize backward.
- *
- * @param[in]   frameCnt   batch size of input image.
- * @param[in]   inV        input data.
- * @param[in]   scale      buffer.
- * @param[out]  outV       output value.
- * @param[out]  outDiff    output grad.
- * @param[out]  inDiff     input grad.
- * @param[in]   channels   number of channel.
- * @param[in]   height     image height.
- * @param[in]   width      image width.
- * @param[in]   sizeX      size.
- * @param[in]   alpha      scale.
- * @param[in]   beta       scale.
- *
- */
-extern void hl_CMRNorm_backward(size_t frameCnt,
-                                const real* inV,
-                                const real* scale,
-                                const real* outV,
-                                const real* outDiff,
-                                real* inDiff,
-                                size_t channels,
-                                size_t height,
-                                size_t width,
-                                size_t sizeX,
-                                real alpha,
-                                real beta);
-
 /**
  * @brief   Bilinear interpolation forward.
  *
diff --git a/paddle/cuda/include/stub/hl_cnn_stub.h b/paddle/cuda/include/stub/hl_cnn_stub.h
index 52c978735279e..039551c6cc695 100644
--- a/paddle/cuda/include/stub/hl_cnn_stub.h
+++ b/paddle/cuda/include/stub/hl_cnn_stub.h
@@ -117,30 +117,6 @@ inline void hl_avgpool_backward(const int frameCnt,
                                 real* backGrad,
                                 const int outStride) {}
 
-inline void hl_CMRNorm_forward(size_t frameCnt,
-                               const real* in,
-                               real* scale,
-                               real* out,
-                               size_t channels,
-                               size_t height,
-                               size_t width,
-                               size_t sizeX,
-                               real alpha,
-                               real beta) {}
-
-inline void hl_CMRNorm_backward(size_t frameCnt,
-                                const real* inV,
-                                const real* scale,
-                                const real* outV,
-                                const real* outDiff,
-                                real* inDiff,
-                                size_t channels,
-                                size_t height,
-                                size_t width,
-                                size_t sizeX,
-                                real alpha,
-                                real beta) {}
-
 inline void hl_bilinear_forward(const real* inData,
                                 const size_t inImgH,
                                 const size_t inImgW,
diff --git a/paddle/cuda/src/hl_cuda_cnn.cu b/paddle/cuda/src/hl_cuda_cnn.cu
index 0992286f360fb..b94f4d8fe4a25 100644
--- a/paddle/cuda/src/hl_cuda_cnn.cu
+++ b/paddle/cuda/src/hl_cuda_cnn.cu
@@ -381,164 +381,6 @@ void hl_avgpool_backward(const int frameCnt, const real* outGrad,
   CHECK_SYNC("hl_avgpool_backward failed");
 }
 
-__global__ void KeCMRNormFillScale(size_t nthreads, const real* in,
-                                   real* scale, size_t channels,
-                                   size_t height, size_t width, size_t size,
-                                   real alpha) {
-  size_t index = threadIdx.x + blockIdx.x * blockDim.x;
-  if (index < nthreads) {
-    // find out the local offset
-    size_t w = index % width;
-    size_t h = (index / width) % height;
-    size_t n = index / width / height;
-    size_t offset = (n * channels * height + h) * width + w;
-    size_t step = height * width;
-    in += offset;
-    scale += offset;
-    size_t head = 0;
-    size_t pre_pad = (size - 1) / 2;
-    size_t post_pad = size - pre_pad - 1;
-    real accum_scale = 0;
-    // fill the scale at [n, :, h, w]
-    // accumulate values
-    while (head < post_pad) {
-      accum_scale += in[head * step] * in[head * step];
-      ++head;
-    }
-    // until we reach size, nothing needs to be subtracted
-    while (head < size) {
-      accum_scale += in[head * step] * in[head * step];
-      scale[(head - post_pad) * step] = 1. + accum_scale * alpha;
-      ++head;
-    }
-    // both add and subtract
-    while (head < channels) {
-      accum_scale += in[head * step] * in[head * step];
-      accum_scale -= in[(head - size) * step] * in[(head - size) * step];
-      scale[(head - post_pad) * step] = 1. + accum_scale * alpha;
-      ++head;
-    }
-    // subtract only
-    while (head < channels + post_pad) {
-      accum_scale -= in[(head - size) * step] * in[(head - size) * step];
-      scale[(head - post_pad) * step] = 1. + accum_scale * alpha;
-      ++head;
-    }
-  }
-}
-
-__global__ void KeCMRNormOutput(size_t nthreads, const real* in,
-                                const real* scale, real negative_beta,
-                                real* out) {
-  size_t index = threadIdx.x + blockIdx.x * blockDim.x;
-  if (index < nthreads) {
-    out[index] = in[index] * pow(scale[index], negative_beta);
-  }
-}
-
-void hl_CMRNorm_forward(size_t frameCnt, const real* in, real* scale,
-                        real* out, size_t channels,
-                        size_t height, size_t width, size_t sizeX,
-                        real alpha, real beta) {
-  size_t threadsNum = frameCnt * height * width;
-  size_t blocksX = (threadsNum + 1024 - 1) / 1024;
-  size_t blocksY = 1;
-  dim3 threads(1024, 1);
-  dim3 grid(blocksX, blocksY);
-
-  KeCMRNormFillScale<<<grid, threads, 0, STREAM_DEFAULT>>>
-      (threadsNum, in, scale, channels, height, width, sizeX, alpha);
-
-  threadsNum = frameCnt * height * width *channels;
-  blocksX = (threadsNum + 1024 -1) / 1024;
-  dim3 threads2(1024, 1);
-  dim3 grid2(blocksX, blocksY);
-  KeCMRNormOutput<<<grid2, threads2, 0, STREAM_DEFAULT>>>
-      (threadsNum, in, scale, beta, out);
-  CHECK_SYNC("hl_CMRNorm_forward");
-}
-
-__global__ void KeCMRNormDiff(size_t nthreads, const real* bottom_data,
-                              const real* top_data, const real* scale,
-                              const real* top_diff, size_t channels,
-                              size_t height, size_t width, size_t size,
-                              real negative_beta, real cache_ratio,
-                              real* bottom_diff ) {
-  int index = threadIdx.x + blockIdx.x * blockDim.x;
-  if (index < nthreads) {
-    // find out the local offset
-    size_t w = index % width;
-    size_t h = (index / width) % height;
-    size_t n = index / width / height;
-    size_t offset = (n * channels * height + h) * width + w;
-    size_t step = height * width;
-    bottom_data += offset;
-    top_data += offset;
-    scale += offset;
-    top_diff += offset;
-    bottom_diff += offset;
-    int head = 0;
-    int pre_pad = size - (size + 1) / 2;
-    int post_pad = size - pre_pad - 1;
-    real accum_ratio = 0;
-    // accumulate values
-    while (head < post_pad) {
-      accum_ratio += top_diff[head * step] *
-        top_data[head * step] / scale[head * step];
-      ++head;
-    }
-    // until we reach size, nothing needs to be subtracted
-    while (head < size) {
-      accum_ratio += top_diff[head * step] *
-        top_data[head * step] / scale[head * step];
-      bottom_diff[(head - post_pad) * step] +=
-        top_diff[(head - post_pad) * step] *
-        pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio *
-        bottom_data[(head - post_pad) * step] * accum_ratio;
-      ++head;
-    }
-    // both add and subtract
-    while (head < channels) {
-      accum_ratio += top_diff[head * step] * top_data[head * step] /
-        scale[head * step];
-      accum_ratio -= top_diff[(head - size) * step] *
-        top_data[(head - size) * step] / scale[(head - size) * step];
-      bottom_diff[(head - post_pad) * step] +=
-        top_diff[(head - post_pad) * step] *
-        pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio *
-        bottom_data[(head - post_pad) * step] * accum_ratio;
-      ++head;
-    }
-    // subtract only
-    while (head < channels + post_pad) {
-      accum_ratio -= top_diff[(head - size) * step] *
-        top_data[(head - size) * step] / scale[(head - size) * step];
-      bottom_diff[(head - post_pad) * step] +=
-        top_diff[(head - post_pad) * step] *
-        pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio *
-        bottom_data[(head - post_pad) * step] * accum_ratio;
-      ++head;
-    }
-  }
-}
-
-void hl_CMRNorm_backward(size_t frameCnt, const real* inV,
-                         const real* scale,
-                         const real* outV, const real* outDiff,
-                         real *inDiff, size_t channels,
-                         size_t height, size_t width, size_t sizeX,
-                         real alpha, real beta) {
-  size_t threadsNum = frameCnt * height * width;
-  size_t blocksX = (threadsNum + 1024 - 1) / 1024;
-  size_t blocksY = 1;
-  dim3 threads(1024, 1);
-  dim3 grid(blocksX, blocksY);
-  KeCMRNormDiff <<<grid, threads, 0, STREAM_DEFAULT>>>
-      (threadsNum, inV, outV, scale, outDiff, channels,
-       height, width, sizeX, alpha, beta, inDiff);
-  CHECK_SYNC("hl_CMRNorm_backward");
-}
-
 __global__ void KeBilinearInterpFw(const real* in,
                                    const size_t inImgH,
                                    const size_t inImgW,
diff --git a/paddle/function/CMakeLists.txt b/paddle/function/CMakeLists.txt
new file mode 100644
index 0000000000000..0697842bbef62
--- /dev/null
+++ b/paddle/function/CMakeLists.txt
@@ -0,0 +1,27 @@
+file(GLOB h_files . *_op.h)
+file(GLOB cpp_files . *_op.cpp)
+
+list(APPEND h_files Function.h)
+list(APPEND cpp_files Function.cpp)
+
+if(WITH_GPU)
+    file(GLOB cu_files . *_op_gpu.cu)
+    cuda_compile(cu_objs ${cu_files})
+endif()
+
+add_library(paddle_function STATIC ${cpp_files} ${cu_objs})
+
+add_library(paddle_test_main STATIC TestMain.cpp)
+
+if(WITH_GPU)
+    # TODO:
+    # file(GLOB test_files . *_op_test.cpp)
+    # add_executable(${test_bin} EXCLUDE_FROM_ALL ${test_files})
+    add_simple_unittest(cross_map_normal_op_test)
+endif()
+
+add_style_check_target(paddle_function ${h_files})
+add_style_check_target(paddle_function ${cpp_files})
+if(WITH_GPU)
+    add_style_check_target(paddle_function ${cu_files})
+endif()
diff --git a/paddle/function/Function.cpp b/paddle/function/Function.cpp
new file mode 100644
index 0000000000000..02880e5ea1acb
--- /dev/null
+++ b/paddle/function/Function.cpp
@@ -0,0 +1,49 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "Function.h"
+
+namespace paddle {
+
+template <>
+size_t FuncConfig::get<size_t>(const std::string& key) const {
+  auto it = valueMap_.find(key);
+  CHECK(it != valueMap_.end()) << "Cannot find value: '" << key << "'";
+  return it->second.s;
+}
+
+template <>
+real FuncConfig::get<real>(const std::string& key) const {
+  auto it = valueMap_.find(key);
+  CHECK(it != valueMap_.end()) << "Cannot find value: '" << key << "'";
+  return it->second.r;
+}
+
+template <>
+FuncConfig& FuncConfig::set<size_t>(const std::string& key, size_t v) {
+  CHECK(valueMap_.count(key) == 0) << "Duplicated value: " << key;
+  valueMap_[key].s = v;
+  return *this;
+}
+
+template <>
+FuncConfig& FuncConfig::set<real>(const std::string& key, real v) {
+  CHECK(valueMap_.count(key) == 0) << "Duplicated value: " << key;
+  valueMap_[key].r = v;
+  return *this;
+}
+
+ClassRegistrar<FunctionBase> FunctionBase::funcRegistrar_;
+
+}  // namespace paddle
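FuncConfig is a small typed key/value store: get/set exist only for the two explicitly specialized types above (size_t and real), and set CHECK-fails on a duplicate key. A usage sketch, not part of the patch, assuming Function.h (next file) is available and that `real` is PaddlePaddle's global float typedef:

```cpp
// Sketch only: illustrates the FuncConfig call pattern this patch expects.
#include "paddle/function/Function.h"

void funcConfigSketch() {
  paddle::FuncConfig config;
  // set<T>() is chainable and T is deduced from the argument, so the casts
  // matter: only the size_t and real specializations are defined.
  config.set("size", (size_t)5).set("scale", (real)1.5).set("pow", (real)0.5);

  size_t size = config.get<size_t>("size");  // size_t specialization
  real pow = config.get<real>("pow");        // real specialization
  // config.set("size", (size_t)7);   // would CHECK-fail: "Duplicated value: size"
  // config.get<size_t>("missing");   // would CHECK-fail: "Cannot find value: 'missing'"
  (void)size;
  (void)pow;
}
```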
diff --git a/paddle/function/Function.h b/paddle/function/Function.h
new file mode 100644
index 0000000000000..095584c0b19f7
--- /dev/null
+++ b/paddle/function/Function.h
@@ -0,0 +1,96 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <map>
+#include <vector>
+#include "paddle/math/Matrix.h"
+#include "paddle/utils/ClassRegistrar.h"
+
+namespace paddle {
+
+enum DeviceType {
+  DEVICE_TYPE_UNSPECIFIED = 0,
+  DEVICE_TYPE_CPU = 1,
+  DEVICE_TYPE_GPU = 2,
+};
+
+template <DeviceType Device>
+struct MatrixT;
+
+template <>
+struct MatrixT<DEVICE_TYPE_CPU> {
+  using type = CpuMatrix;
+};
+
+template <>
+struct MatrixT<DEVICE_TYPE_GPU> {
+  using type = GpuMatrix;
+};
+
+typedef std::vector<size_t> Dims;
+
+class Tensor {
+public:
+  Tensor(real* data, const Dims& dim) : buf_(data), dims_(dim) {}
+
+  real* getData() const { return buf_; }
+
+  real* buf_;
+  Dims dims_;
+};
+
+typedef std::vector<Tensor> Arguments;
+
+class FuncConfig {
+public:
+  union value {
+    size_t s;
+    real r;
+  };
+
+  template <typename T>
+  T get(const std::string& key) const;
+
+  template <typename T>
+  FuncConfig& set(const std::string& key, T v);
+
+protected:
+  std::map<std::string, value> valueMap_;
+};
+
+class FunctionBase {
+public:
+  virtual ~FunctionBase() {}
+
+  virtual void init(const FuncConfig& config) {}
+
+  virtual void calc(const Arguments& inputs,
+                    const Arguments& outputs,
+                    const Arguments& inouts) {}
+
+  static ClassRegistrar<FunctionBase> funcRegistrar_;
+};
+
+#define FUNC_NAME(typeName, deviceName) #typeName "-" #deviceName
+
+#define REGISTER_TYPED_FUNC(typeName, deviceName, className)   \
+  static InitFunction __reg_type_##typeName##deviceName([]() { \
+    FunctionBase::funcRegistrar_                               \
+        .registerClass<className<DEVICE_TYPE_##deviceName>>(   \
+            FUNC_NAME(typeName, deviceName));                  \
+  })
+
+}  // namespace paddle
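REGISTER_TYPED_FUNC stamps out one registrar entry per (function, device) pair, keyed by the FUNC_NAME string, e.g. "CrossMapNormal-CPU". A hedged sketch of how a new function would plug in; MyScaleFunc is hypothetical, but the pattern matches cross_map_normal_op.cpp below:

```cpp
// Hypothetical new function following this header's pattern.
namespace paddle {

template <DeviceType Device>
class MyScaleFunc : public FunctionBase {
public:
  void init(const FuncConfig& config) override {
    scale_ = config.get<real>("scale");
  }

  void calc(const Arguments& inputs,
            const Arguments& outputs,
            const Arguments& inouts) override {
    // A device-specific kernel, dispatched on the Device parameter,
    // would read inputs[i].getData() / inputs[i].dims_ and fill outputs.
  }

private:
  real scale_;
};

// Registers the class under "MyScale-CPU" at static-init time.
REGISTER_TYPED_FUNC(MyScale, CPU, MyScaleFunc);

}  // namespace paddle

// Callers then recover an instance by name, exactly as Layer::createFunction
// and FunctionCompare do:
//   auto f = paddle::FunctionBase::funcRegistrar_.createByType("MyScale-CPU");
//   f->init(paddle::FuncConfig().set("scale", (real)2.0));
```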
diff --git a/paddle/function/FunctionTest.h b/paddle/function/FunctionTest.h
new file mode 100644
index 0000000000000..a8c5e412bd12d
--- /dev/null
+++ b/paddle/function/FunctionTest.h
@@ -0,0 +1,102 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "Function.h"
+#include "paddle/math/Vector.h"
+#include "paddle/math/tests/TensorCheck.h"
+
+namespace paddle {
+
+class FunctionCompare {
+public:
+  FunctionCompare(const std::string& name, const FuncConfig& config)
+      : cpu(FunctionBase::funcRegistrar_.createByType(name + "-CPU")),
+        gpu(FunctionBase::funcRegistrar_.createByType(name + "-GPU")) {
+    cpu->init(config);
+    gpu->init(config);
+  }
+
+  void cmpWithArg(const Arguments& inputs,
+                  const Arguments& outputs,
+                  const Arguments& inouts) {
+    // init cpu and gpu arguments
+    auto initArgs = [=](
+        Arguments& cpuArgs, Arguments& gpuArgs, const Arguments& inArgs) {
+      for (auto arg : inArgs) {
+        size_t size = sizeof(real);
+        for (auto dim : arg.dims_) {
+          size *= dim;
+        }
+        cpuMemory.emplace_back(std::make_shared<CpuMemoryHandle>(size));
+        gpuMemory.emplace_back(std::make_shared<GpuMemoryHandle>(size));
+        cpuArgs.emplace_back(
+            Tensor((real*)cpuMemory.back()->getBuf(), arg.dims_));
+        gpuArgs.emplace_back(
+            Tensor((real*)gpuMemory.back()->getBuf(), arg.dims_));
+
+        // will use an api to refactor this code.
+        CpuVector cpuVector(size / sizeof(real),
+                            (real*)cpuArgs.back().getData());
+        GpuVector gpuVector(size / sizeof(real),
+                            (real*)gpuArgs.back().getData());
+        cpuVector.uniform(0.001, 1);
+        gpuVector.copyFrom(cpuVector);
+      }
+    };
+    initArgs(cpuInputs, gpuInputs, inputs);
+    initArgs(cpuOutputs, gpuOutputs, outputs);
+    initArgs(cpuInouts, gpuInouts, inouts);
+
+    // function calculate
+    cpu->calc(cpuInputs, cpuOutputs, cpuInouts);
+    gpu->calc(gpuInputs, gpuOutputs, gpuInouts);
+
+    // check outputs and inouts
+    auto checkArgs = [=](const Arguments& cpuArgs, const Arguments& gpuArgs) {
+      for (size_t i = 0; i < cpuArgs.size(); i++) {
+        auto cpu = cpuArgs[i];
+        auto gpu = gpuArgs[i];
+        size_t size = 1;
+        for (auto dim : cpu.dims_) {
+          size *= dim;
+        }
+        CpuVector cpuVector(size, (real*)cpu.getData());
+        GpuVector gpuVector(size, (real*)gpu.getData());
+
+        autotest::TensorCheckErr(cpuVector, gpuVector);
+      }
+    };
+    checkArgs(cpuOutputs, gpuOutputs);
+    checkArgs(cpuInouts, gpuInouts);
+  }
+
+protected:
+  std::shared_ptr<FunctionBase> cpu;
+  std::shared_ptr<FunctionBase> gpu;
+  std::vector<CpuMemHandlePtr> cpuMemory;
+  std::vector<GpuMemHandlePtr> gpuMemory;
+  Arguments cpuInputs;
+  Arguments cpuOutputs;
+  Arguments cpuInouts;
+  Arguments gpuInputs;
+  Arguments gpuOutputs;
+  Arguments gpuInouts;
+};
+
+}  // namespace paddle
+
+using paddle::FunctionCompare;
+using paddle::FuncConfig;
+using paddle::Dims;
+using paddle::Tensor;
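FunctionCompare runs the same function on CPU and GPU with identical uniformly-random inputs and asserts the results agree. Only the dims_ of each Tensor argument are read (the data pointer may be nullptr) because the harness allocates its own buffers. The tests added below in cross_map_normal_op_test.cpp reduce to this shape:

```cpp
// Condensed from cross_map_normal_op_test.cpp below.
FunctionCompare compare("CrossMapNormal",
                        FuncConfig()
                            .set("size", (size_t)5)
                            .set("scale", (real)1.5)
                            .set("pow", (real)0.5));
Dims dims{32, 5, 33, 32};  // numSamples, channels, imgSizeH, imgSizeW
compare.cmpWithArg({Tensor(nullptr, dims)},                         // input value
                   {Tensor(nullptr, dims), Tensor(nullptr, dims)},  // output value, denoms
                   {});                                             // no inouts
```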
diff --git a/paddle/function/TestMain.cpp b/paddle/function/TestMain.cpp
new file mode 100644
index 0000000000000..3e14532d1878f
--- /dev/null
+++ b/paddle/function/TestMain.cpp
@@ -0,0 +1,22 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <gtest/gtest.h>
+#include "paddle/utils/Util.h"
+
+int main(int argc, char** argv) {
+  testing::InitGoogleTest(&argc, argv);
+  paddle::initMain(argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/paddle/function/cross_map_normal_op.cpp b/paddle/function/cross_map_normal_op.cpp
new file mode 100644
index 0000000000000..a9c7693830542
--- /dev/null
+++ b/paddle/function/cross_map_normal_op.cpp
@@ -0,0 +1,227 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "cross_map_normal_op.h"
+#include "paddle/math/Vector.h"
+
+namespace paddle {
+
+template <>
+void CrossMapNormal<DEVICE_TYPE_CPU>(real* outputs,
+                                     real* denoms,
+                                     const real* inputs,
+                                     size_t numSamples,
+                                     size_t channels,
+                                     size_t height,
+                                     size_t width,
+                                     size_t size,
+                                     real scale,
+                                     real pow) {
+  size_t oneImage = height * width;
+  size_t oneSample = channels * oneImage;
+
+  CpuVector outputsV(numSamples * oneSample, outputs);
+  CpuVector inputsV(numSamples * oneSample, const_cast<real*>(inputs));
+  CpuVector denomsV(numSamples * oneSample, denoms);
+
+  // f(x) = x * ( 1 + scale * SUM((x)^2) )^(-pow)
+  // x represents inputs
+  // f(x) represents outputs
+  // denoms save the intermediate result for backward
+  denomsV = denomsV.constant(1.0);
+  const int start = -((int)size - 1) / 2;
+  const int end = (int)size + start;
+  for (size_t i = 0; i < numSamples; i++) {
+    real* oneDenom = denoms + i * oneSample;
+    real* oneInput = const_cast<real*>(inputs) + i * oneSample;
+    for (int c = 0; c < (int)channels; c++) {
+      CpuVector denom(oneImage, oneDenom + c * oneImage);
+      for (int s = start; s < end; s++) {
+        if (c + s >= 0 && c + s < (int)channels) {
+          CpuVector input(oneImage, oneInput + (c + s) * oneImage);
+          denom += input.square() * scale;
+        }
+      }
+    }
+  }
+
+  outputsV = inputsV * denomsV.pow(-pow);
+}
+
+template <>
+void CrossMapNormalGrad<DEVICE_TYPE_CPU>(real* inputsGrad,
+                                         const real* inputsValue,
+                                         const real* outputsValue,
+                                         const real* outputsGrad,
+                                         const real* denoms,
+                                         size_t numSamples,
+                                         size_t channels,
+                                         size_t height,
+                                         size_t width,
+                                         size_t size,
+                                         real scale,
+                                         real pow) {
+  size_t oneSample = channels * height * width;
+  std::function<CpuVector(real*, size_t)> oneImage = [=](real* data,
+                                                         size_t offset) {
+    return CpuVector(height * width, data + offset);
+  };
+
+  const int start = -((int)size) / 2;
+  const int end = (int)size + start;
+  const real ratio = -(real)2 * scale * pow;
+  for (size_t i = 0; i < numSamples; i++) {
+    size_t sOffset = i * oneSample;
+    real* oneInputGrad = inputsGrad + sOffset;
+    real* oneInputValue = const_cast<real*>(inputsValue) + sOffset;
+    real* oneDenom = const_cast<real*>(denoms) + sOffset;
+    real* oneOutputGrad = const_cast<real*>(outputsGrad) + sOffset;
+    real* oneOutputValue = const_cast<real*>(outputsValue) + sOffset;
+
+    for (int c = 0; c < (int)channels; c++) {
+      size_t cOffset = c * height * width;
+      CpuVector inputGrad = oneImage(oneInputGrad, cOffset);
+      CpuVector inputValue = oneImage(oneInputValue, cOffset);
+      CpuVector denom = oneImage(oneDenom, cOffset);
+      CpuVector outputGrad = oneImage(oneOutputGrad, cOffset);
+
+      inputGrad = inputGrad + denom.pow(-pow) * outputGrad;
+      for (int s = start; s < end; s++) {
+        if (c + s >= 0 && c + s < (int)channels) {
+          size_t offset = (c + s) * height * width;
+          CpuVector output = oneImage(oneOutputValue, offset);
+          CpuVector outputGrad = oneImage(oneOutputGrad, offset);
+          CpuVector denom = oneImage(oneDenom, offset);
+
+          inputGrad += ((outputGrad * output * ratio) / denom) * inputValue;
+        }
+      }
+    }
+  }
+}
+
+/**
+ * \param inputs[0] input value.
+ * \param outputs[0] output value.
+ * \param outputs[1] denoms.
+ */
+template <DeviceType Device>
+class CrossMapNormalFunc : public FunctionBase {
+public:
+  void init(const FuncConfig& config) override {
+    size_ = config.get<size_t>("size");
+    scale_ = config.get<real>("scale");
+    pow_ = config.get<real>("pow");
+  }
+
+  void calc(const Arguments& inputs,
+            const Arguments& outputs,
+            const Arguments& inouts) override {
+    CHECK_EQ(1, inputs.size());
+    CHECK_EQ(2, outputs.size());
+    CHECK_EQ(0, inouts.size());
+
+    CHECK_EQ(inputs[0].dims_.size(), 4);
+    for (size_t i = 0; i < inputs[0].dims_.size(); i++) {
+      CHECK_EQ(inputs[0].dims_[i], outputs[0].dims_[i]);
+      CHECK_EQ(inputs[0].dims_[i], outputs[1].dims_[i]);
+    }
+
+    size_t samples = inputs[0].dims_[0];
+    size_t channels = inputs[0].dims_[1];
+    size_t height = inputs[0].dims_[2];
+    size_t width = inputs[0].dims_[3];
+
+    CrossMapNormal<Device>(outputs[0].getData(),
+                           outputs[1].getData(),
+                           inputs[0].getData(),
+                           samples,
+                           channels,
+                           height,
+                           width,
+                           size_,
+                           scale_,
+                           pow_);
+  }
+
+private:
+  size_t size_;
+  real scale_;
+  real pow_;
+};
+
+/**
+ * \param inputs[0] input value.
+ * \param inputs[1] output value.
+ * \param inputs[2] output grad.
+ * \param inputs[3] denoms.
+ * \param outputs[0] input grad.
+ */
+template <DeviceType Device>
+class CrossMapNormalGradFunc : public FunctionBase {
+public:
+  void init(const FuncConfig& config) override {
+    size_ = config.get<size_t>("size");
+    scale_ = config.get<real>("scale");
+    pow_ = config.get<real>("pow");
+  }
+
+  void calc(const Arguments& inputs,
+            const Arguments& outputs,
+            const Arguments& inouts) override {
+    CHECK_EQ(4, inputs.size());
+    CHECK_EQ(1, outputs.size());
+    CHECK_EQ(0, inouts.size());
+
+    CHECK_EQ(inputs[0].dims_.size(), 4);
+    for (size_t i = 0; i < inputs[0].dims_.size(); i++) {
+      CHECK_EQ(inputs[0].dims_[i], inputs[1].dims_[i]);
+      CHECK_EQ(inputs[0].dims_[i], inputs[2].dims_[i]);
+      CHECK_EQ(inputs[0].dims_[i], inputs[3].dims_[i]);
+      CHECK_EQ(inputs[0].dims_[i], outputs[0].dims_[i]);
+    }
+
+    size_t samples = inputs[0].dims_[0];
+    size_t channels = inputs[0].dims_[1];
+    size_t height = inputs[0].dims_[2];
+    size_t width = inputs[0].dims_[3];
+
+    CrossMapNormalGrad<Device>(outputs[0].getData(),
+                               inputs[0].getData(),
+                               inputs[1].getData(),
+                               inputs[2].getData(),
+                               inputs[3].getData(),
+                               samples,
+                               channels,
+                               height,
+                               width,
+                               size_,
+                               scale_,
+                               pow_);
+  }
+
+private:
+  size_t size_;
+  real scale_;
+  real pow_;
+};
+
+REGISTER_TYPED_FUNC(CrossMapNormal, CPU, CrossMapNormalFunc);
+REGISTER_TYPED_FUNC(CrossMapNormalGrad, CPU, CrossMapNormalGradFunc);
+#ifndef PADDLE_ONLY_CPU
+REGISTER_TYPED_FUNC(CrossMapNormal, GPU, CrossMapNormalFunc);
+REGISTER_TYPED_FUNC(CrossMapNormalGrad, GPU, CrossMapNormalGradFunc);
+#endif
+
+}  // namespace paddle
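For review, here is the computation as I read the CPU path above (the CUDA kernels below implement the same thing), with \(\alpha\) = scale, \(p\) = pow, \(x\) the input, \(g\) the output gradient, and \(N(c)\) the window of `size` channels centered on channel \(c\), per spatial position:

```latex
d_c = 1 + \alpha \sum_{c' \in N(c)} x_{c'}^2
\qquad
y_c = x_c \, d_c^{-p}

\frac{\partial L}{\partial x_c}
  \mathrel{+}= g_c \, d_c^{-p}
  \;-\; 2 \alpha p \, x_c \sum_{c'\,:\,c \in N(c')} \frac{g_{c'} \, y_{c'}}{d_{c'}}
```

The first gradient term is the `inputGrad = inputGrad + denom.pow(-pow) * outputGrad` line; the windowed sum is the inner `s` loop with `ratio = -2 * scale * pow`.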
diff --git a/paddle/function/cross_map_normal_op.h b/paddle/function/cross_map_normal_op.h
new file mode 100644
index 0000000000000..b1e401ad0a2f5
--- /dev/null
+++ b/paddle/function/cross_map_normal_op.h
@@ -0,0 +1,81 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "Function.h"
+
+namespace paddle {
+
+/**
+ * \brief Cross map response normalize forward.
+ *        The data structure of image data is NCHW.
+ *
+ * \param[out] outputs     output data.
+ * \param[out] denoms      denoms buffer.
+ * \param[in]  inputs      input data.
+ * \param[in]  numSamples  batch size of input image.
+ * \param[in]  channels    number of channel.
+ * \param[in]  height      image height.
+ * \param[in]  width       image width.
+ * \param[in]  size        size.
+ * \param[in]  scale       scale.
+ * \param[in]  pow         scale.
+ *
+ */
+template <DeviceType Device>
+void CrossMapNormal(real* outputs,
+                    real* denoms,
+                    const real* inputs,
+                    size_t numSamples,
+                    size_t channels,
+                    size_t height,
+                    size_t width,
+                    size_t size,
+                    real scale,
+                    real pow);
+
+/**
+ * \brief Cross map response normalize backward.
+ *        The data structure of image data is NCHW.
+ *
+ * \param[out] inputsGrad    input grad.
+ * \param[in]  inputsValue   input value.
+ * \param[in]  outputsValue  output value.
+ * \param[in]  outputsGrad   output grad.
+ * \param[in]  denoms        denoms buffer.
+ * \param[in]  numSamples    batch size of input image.
+ * \param[in]  channels      number of channel.
+ * \param[in]  height        image height.
+ * \param[in]  width         image width.
+ * \param[in]  size          size.
+ * \param[in]  scale         scale.
+ * \param[in]  pow           scale.
+ *
+ */
+template <DeviceType Device>
+void CrossMapNormalGrad(real* inputsGrad,
+                        const real* inputsValue,
+                        const real* outputsValue,
+                        const real* outputsGrad,
+                        const real* denoms,
+                        size_t numSamples,
+                        size_t channels,
+                        size_t height,
+                        size_t width,
+                        size_t size,
+                        real scale,
+                        real pow);
+
+}  // namespace paddle
diff --git a/paddle/function/cross_map_normal_op_gpu.cu b/paddle/function/cross_map_normal_op_gpu.cu
new file mode 100644
index 0000000000000..aae4f461b6f57
--- /dev/null
+++ b/paddle/function/cross_map_normal_op_gpu.cu
@@ -0,0 +1,156 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "hl_base.h"
+#include "cross_map_normal_op.h"
+
+namespace paddle {
+
+__global__ void KeCMRNormFillScale(size_t imageSize, const real* in,
+                                   real* scale, size_t channels,
+                                   size_t height, size_t width, size_t size,
+                                   real alpha) {
+  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
+  if (idx < imageSize) {
+    const int w = idx % width;
+    const int h = (idx / width) % height;
+    const int n = idx / width / height;
+    const int offset = (n * channels * height + h) * width + w;
+
+    in += offset;
+    scale += offset;
+    const int step = height * width;
+    const int pre_pad = (size - 1) / 2;
+    const int post_pad = size - pre_pad - 1;
+
+    real accum = 0;
+    int index = 0;
+    while (index < channels + post_pad) {
+      if (index < channels) {
+        accum += in[index * step] * in[index * step];
+      }
+      if (index >= size) {
+        accum -= in[(index - size) * step] * in[(index - size) * step];
+      }
+      if (index >= post_pad) {
+        scale[(index - post_pad) * step] = 1. + accum * alpha;
+      }
+      ++index;
+    }
+  }
+}
+
+__global__ void KeCMRNormOutput(size_t inputSize, const real* in,
+                                const real* scale, real negative_beta,
+                                real* out) {
+  const int index = threadIdx.x + blockIdx.x * blockDim.x;
+  if (index < inputSize) {
+    out[index] = in[index] * pow(scale[index], negative_beta);
+  }
+}
+
+template <>
+void CrossMapNormal<DEVICE_TYPE_GPU>(real* outputs,
+                                     real* denoms,
+                                     const real* inputs,
+                                     size_t numSamples,
+                                     size_t channels,
+                                     size_t height,
+                                     size_t width,
+                                     size_t size,
+                                     real scale,
+                                     real pow) {
+  size_t imageSize = numSamples * height * width;
+  int blockSize = 1024;
+  int gridSize = (imageSize + 1024 - 1) / 1024;
+  KeCMRNormFillScale<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>
+      (imageSize, inputs, denoms, channels, height, width, size, scale);
+
+  size_t inputSize = numSamples * height * width *channels;
+  blockSize = 1024;
+  gridSize = (inputSize + 1024 - 1) / 1024;
+  KeCMRNormOutput<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>
+      (inputSize, inputs, denoms, -pow, outputs);
+
+  CHECK_SYNC("CrossMapNormal");
+}
+
+__global__ void KeCMRNormDiff(size_t imageSize, const real* bottom_data,
+                              const real* top_data, const real* scale,
+                              const real* top_diff, size_t channels,
+                              size_t height, size_t width, size_t size,
+                              real negative_beta, real cache_ratio,
+                              real* bottom_diff ) {
+  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
+  if (idx < imageSize) {
+    const int w = idx % width;
+    const int h = (idx / width) % height;
+    const int n = idx / width / height;
+    const int offset = (n * channels * height + h) * width + w;
+    bottom_data += offset;
+    top_data += offset;
+    scale += offset;
+    top_diff += offset;
+    bottom_diff += offset;
+
+    const int step = height * width;
+    const int pre_pad = size - (size + 1) / 2;
+    const int post_pad = size - pre_pad - 1;
+
+    int index = 0;
+    real accum = 0;
+    while (index < channels + post_pad) {
+      if (index < channels) {
+        accum += top_diff[index * step] * top_data[index * step] /
+            scale[index * step];
+      }
+      if (index >= size) {
+        accum -= top_diff[(index - size) * step] *
+            top_data[(index - size) * step] / scale[(index - size) * step];
+      }
+      if (index >= post_pad) {
+        bottom_diff[(index - post_pad) * step] +=
+            top_diff[(index - post_pad) * step] *
+            pow(scale[(index - post_pad) * step], negative_beta) - cache_ratio *
+            bottom_data[(index - post_pad) * step] * accum;
+      }
+      ++index;
+    }
+  }
+}
+
+template <>
+void CrossMapNormalGrad<DEVICE_TYPE_GPU>(real* inputsGrad,
+                                         const real* inputsValue,
+                                         const real* outputsValue,
+                                         const real* outputsGrad,
+                                         const real* denoms,
+                                         size_t numSamples,
+                                         size_t channels,
+                                         size_t height,
+                                         size_t width,
+                                         size_t size,
+                                         real scale,
+                                         real pow) {
+  size_t imageSize = numSamples * height * width;
+
+  int blockSize = 1024;
+  int gridSize = (imageSize + 1024 - 1) / 1024;
+  KeCMRNormDiff <<<gridSize, blockSize, 0, STREAM_DEFAULT>>>
+      (imageSize, inputsValue, outputsValue, denoms, outputsGrad, channels,
+       height, width, size, -pow, 2.0f * pow * scale, inputsGrad);
+  CHECK_SYNC("CrossMapNormalGrad");
+}
+
+}  // namespace paddle
diff --git a/paddle/function/cross_map_normal_op_test.cpp b/paddle/function/cross_map_normal_op_test.cpp
new file mode 100644
index 0000000000000..22692691bdb64
--- /dev/null
+++ b/paddle/function/cross_map_normal_op_test.cpp
@@ -0,0 +1,71 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <gtest/gtest.h>
+#include "FunctionTest.h"
+
+TEST(CrossMapNormal, real) {
+  for (size_t numSamples : {5, 32}) {
+    for (size_t channels : {1, 5, 32}) {
+      for (size_t imgSizeH : {5, 33, 100}) {
+        for (size_t imgSizeW : {5, 32, 96}) {
+          for (size_t size : {1, 2, 3, 5, 7}) {
+            VLOG(3) << " numSamples=" << numSamples << " channels=" << channels
+                    << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW
+                    << " size=" << size;
+
+            FunctionCompare compare("CrossMapNormal",
+                                    FuncConfig()
+                                        .set("size", size)
+                                        .set("scale", (real)1.5)
+                                        .set("pow", (real)0.5));
+            Dims dims{numSamples, channels, imgSizeH, imgSizeW};
+            compare.cmpWithArg({Tensor(nullptr, dims)},
+                               {Tensor(nullptr, dims), Tensor(nullptr, dims)},
+                               {});
+          }
+        }
+      }
+    }
+  }
+}
+
+TEST(CrossMapNormalGrad, real) {
+  for (size_t numSamples : {5, 32}) {
+    for (size_t channels : {1, 5, 32}) {
+      for (size_t imgSizeH : {5, 33, 100}) {
+        for (size_t imgSizeW : {5, 32, 96}) {
+          for (size_t size : {1, 2, 3, 5, 7}) {
+            VLOG(3) << " numSamples=" << numSamples << " channels=" << channels
+                    << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW
+                    << " size=" << size;
+
+            FunctionCompare compare("CrossMapNormalGrad",
+                                    FuncConfig()
+                                        .set("size", size)
+                                        .set("scale", (real)1.5)
+                                        .set("pow", (real)0.5));
+            Dims dims{numSamples, channels, imgSizeH, imgSizeW};
+            compare.cmpWithArg({Tensor(nullptr, dims),
+                                Tensor(nullptr, dims),
+                                Tensor(nullptr, dims),
+                                Tensor(nullptr, dims)},
+                               {Tensor(nullptr, dims)},
+                               {});
+          }
+        }
+      }
+    }
+  }
+}
diff --git a/paddle/gserver/CMakeLists.txt b/paddle/gserver/CMakeLists.txt
index a066f80c221ee..4f92150ec84d6 100644
--- a/paddle/gserver/CMakeLists.txt
+++ b/paddle/gserver/CMakeLists.txt
@@ -27,16 +27,12 @@ if(NOT WITH_GPU)
     list(REMOVE_ITEM GSERVER_HEADER
         layers/CudnnConvLayer.h
         layers/CudnnPoolLayer.h
-        layers/CudnnBatchNormLayer.h
-        layers/NormProjectionLayer.h
-        layers/NormLayer.h)
+        layers/CudnnBatchNormLayer.h)
 
     list(REMOVE_ITEM GSERVER_SOURCES
         layers/CudnnConvLayer.cpp
         layers/CudnnPoolLayer.cpp
-        layers/CudnnBatchNormLayer.cpp
-        layers/NormProjectionLayer.cpp
-        layers/NormLayer.cpp)
+        layers/CudnnBatchNormLayer.cpp)
     compile_cu_as_cpp(layers/LstmCompute.cu)
     compile_cu_as_cpp(layers/GruCompute.cu)
 endif()
diff --git a/paddle/gserver/layers/Layer.h b/paddle/gserver/layers/Layer.h
index 172e558b82945..6dfd48fb96618 100644
--- a/paddle/gserver/layers/Layer.h
+++ b/paddle/gserver/layers/Layer.h
@@ -18,6 +18,7 @@ limitations under the License. */
 #include <functional>
 #include <memory>
 #include "ModelConfig.pb.h"
+#include "paddle/function/Function.h"
 #include "paddle/math/CpuSparseMatrix.h"
 #include "paddle/parameter/Parameter.h"
 #include "paddle/utils/ClassRegistrar.h"
@@ -100,6 +101,11 @@ class Layer {
   /// Mark input grad in(true) or out(false) of backward function.
   std::vector<bool> markInBackward_;
 
+  /// Layer forward function
+  std::vector<std::shared_ptr<FunctionBase>> forward_;
+  /// Layer backward function
+  std::vector<std::shared_ptr<FunctionBase>> backward_;
+
 public:
   /**
    * Wait until all input value ready.
@@ -126,6 +132,26 @@ class Layer {
   virtual void markAllInputGrad();
 
 protected:
+  /**
+   * Create layer function. Function is called in forward or backward.
+   * \param function, Layer::forward_ or Layer::backward_
+   * \param name, function name
+   * \param config, initialization configuration for the function
+   */
+  void createFunction(std::vector<std::shared_ptr<FunctionBase>>& function,
+                      const std::string& name,
+                      const FuncConfig& config) {
+    if (useGpu_) {
+      function.emplace_back(
+          FunctionBase::funcRegistrar_.createByType(name + "-GPU"));
+    } else {
+      function.emplace_back(
+          FunctionBase::funcRegistrar_.createByType(name + "-CPU"));
+    }
+    auto& func = function.back();
+    func->init(config);
+  }
+
   /**
    * Notify specified layer the output grad ready.
    * Called in the backward function.
diff --git a/paddle/gserver/layers/NormProjectionLayer.cpp b/paddle/gserver/layers/NormProjectionLayer.cpp
index 934fc31e0acf9..262d757c67e10 100644
--- a/paddle/gserver/layers/NormProjectionLayer.cpp
+++ b/paddle/gserver/layers/NormProjectionLayer.cpp
@@ -45,6 +45,15 @@ bool CMRProjectionNormLayer::init(const LayerMap& layerMap,
   /* the size of inputs for norm-layer is 1 */
   CHECK_EQ(config_.inputs_size(), 1);
 
+  createFunction(
+      forward_,
+      "CrossMapNormal",
+      FuncConfig().set("size", size_).set("scale", scale_).set("pow", pow_));
+  createFunction(
+      backward_,
+      "CrossMapNormalGrad",
+      FuncConfig().set("size", size_).set("scale", scale_).set("pow", pow_));
+
   return true;
 }
 
@@ -54,7 +63,7 @@ void CMRProjectionNormLayer::forward(PassType passType) {
   /* malloc memory for the output_ if necessary */
   /* note: one sample correspond to one row */
   MatrixPtr input = inputLayers_[0]->getOutputValue();
-  int batchSize = input->getHeight();
+  size_t batchSize = input->getHeight();
   int size = getSize();
   resetOutput(batchSize, size);
 
@@ -62,10 +71,11 @@ void CMRProjectionNormLayer::forward(PassType passType) {
 
   Matrix::resizeOrCreate(denoms_, batchSize, size, /* trans */ false, useGpu_);
 
-  denoms_->zeroMem();
-
-  outV->crossMapNormalFwd(
-      *input, imgSizeH_, imgSizeW_, *denoms_, channels_, size_, scale_, pow_);
+  dims_ = {batchSize, channels_, imgSizeH_, imgSizeW_};
+  forward_[0]->calc(
+      {Tensor(input->getData(), dims_)},
+      {Tensor(outV->getData(), dims_), Tensor(denoms_->getData(), dims_)},
+      {});
 }
 
 void CMRProjectionNormLayer::backward(const UpdateCallback& callback) {
@@ -80,15 +90,11 @@ void CMRProjectionNormLayer::backward(const UpdateCallback& callback) {
   MatrixPtr localOutV = getOutputValue();
   MatrixPtr preOutV = inputLayers_[0]->getOutputValue();
 
-  preOutGrad->crossMapNormalBwd(*localGrad,
-                                *denoms_,
-                                *preOutV,
-                                *localOutV,
-                                channels_,
-                                imgSizeH_,
-                                imgSizeW_,
-                                size_,
-                                scale_,
-                                pow_);
+  backward_[0]->calc({Tensor(preOutV->getData(), dims_),
+                      Tensor(localOutV->getData(), dims_),
+                      Tensor(localGrad->getData(), dims_),
+                      Tensor(denoms_->getData(), dims_)},
+                     {Tensor(preOutGrad->getData(), dims_)},
+                     {});
 }
 }  // namespace paddle
diff --git a/paddle/gserver/layers/NormProjectionLayer.h b/paddle/gserver/layers/NormProjectionLayer.h
index 4f7b638334afe..6b2c5dde0d74d 100644
--- a/paddle/gserver/layers/NormProjectionLayer.h
+++ b/paddle/gserver/layers/NormProjectionLayer.h
@@ -39,5 +39,8 @@ class CMRProjectionNormLayer : public ResponseNormLayer {
   bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
   void forward(PassType passType);
   void backward(const UpdateCallback& callback = nullptr);
+
+protected:
+  Dims dims_;
 };
 }  // namespace paddle
diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp
index 8a8d094ed357a..2cc25f6b211e3 100644
--- a/paddle/gserver/tests/test_LayerGrad.cpp
+++ b/paddle/gserver/tests/test_LayerGrad.cpp
@@ -1021,11 +1021,10 @@ void testNormLayer(const string& normType, bool trans, bool useGpu) {
   testLayerGrad(config, "norm", 100, trans, useGpu);
 }
 
-#ifndef PADDLE_ONLY_CPU
 TEST(Layer, NormLayer) {
   testNormLayer("cmrnorm-projection", /* trans= */ false, /* useGpu= */ true);
+  testNormLayer("cmrnorm-projection", /* trans= */ false, /* useGpu= */ false);
 }
-#endif
 
 void setPoolConfig(TestConfig* config,
                    PoolConfig* pool,
diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp
index c69e074a76399..a36c31d32b7ca 100644
--- a/paddle/math/Matrix.cpp
+++ b/paddle/math/Matrix.cpp
@@ -1265,69 +1265,6 @@ void GpuMatrix::avgPoolBackward(Matrix& outGrad,
                                 outGrad.getStride());
 }
 
-void GpuMatrix::crossMapNormalFwd(Matrix& input,
-                                  size_t imgSizeH,
-                                  size_t imgSizeW,
-                                  Matrix& denoms,
-                                  size_t channels,
-                                  size_t sizeX,
-                                  float scale,
-                                  float pow) {
-  size_t num = input.getHeight();
-  size_t height = imgSizeH;
-  size_t width = imgSizeW;
-
-  CHECK(height * width * channels == input.getWidth());
-  CHECK(denoms.getHeight() == input.getHeight() &&
-        denoms.getWidth() == input.getWidth() && input.getHeight() == height_ &&
-        input.getWidth() == width_);
-  hl_CMRNorm_forward(num,
-                     input.getData(),
-                     denoms.getData(),
-                     data_,
-                     channels,
-                     height,
-                     width,
-                     sizeX,
-                     scale,
-                     -pow);
-}
-
-void GpuMatrix::crossMapNormalBwd(Matrix& localGrad,
-                                  Matrix& denoms,
-                                  Matrix& preOutV,
-                                  Matrix& localOutV,
-                                  size_t channels,
-                                  size_t imgSizeH,
-                                  size_t imgSizeW,
-                                  size_t sizeX,
-                                  float scale,
-                                  float pow) {
-  size_t num = preOutV.getHeight();
-  size_t height = imgSizeH;
-  size_t width = imgSizeW;
-
-  CHECK(width * height * channels == preOutV.getWidth());
-  CHECK(denoms.getHeight() == preOutV.getHeight() &&
-        denoms.getWidth() == preOutV.getWidth() &&
-        preOutV.getHeight() == height_ && preOutV.getWidth() == width_);
-  CHECK(denoms.getHeight() == localGrad.getHeight() &&
-        denoms.getWidth() == localGrad.getWidth());
-
-  hl_CMRNorm_backward(num,
-                      preOutV.getData(),
-                      denoms.getData(),
-                      localOutV.getData(),
-                      localGrad.getData(),
-                      data_,
-                      channels,
-                      height,
-                      width,
-                      sizeX,
-                      -pow,
-                      2.0f * pow * scale);
-}
-
 void GpuMatrix::maxSequenceForward(Matrix& input,
                                    const IVector& sequence,
                                    IVector& index) {
@@ -2219,84 +2156,6 @@ void CpuMatrix::avgPoolBackward(Matrix& input,
     }
   }
 }
 
-void CpuMatrix::crossMapNormalFwd(Matrix& input,
-                                  size_t imgSizeH,
-                                  size_t imgSizeW,
-                                  Matrix& denoms,
-                                  size_t channels,
-                                  size_t sizeX,
-                                  float scale,
-                                  float pow) {
-  size_t num = input.getHeight();
-  size_t height = imgSizeH;
-  size_t width = imgSizeW;
-  size_t numCols = input.getWidth();
-  CHECK(height * width * channels == input.getWidth());
-  CHECK(denoms.getHeight() == input.getHeight() &&
-        denoms.getWidth() == input.getWidth() && input.getHeight() == height_ &&
-        input.getWidth() == width_);
-  real* imgData = input.getData();
-  real* diffData = input.getData();
-  real* targetData = getData();
-  size_t halfSize = sizeX / 2;
-  size_t imgPixels = height * width;
-
-  // use integral vector to implement the sum in local window
-  real* integralData =
-      (real*)malloc((channels + sizeX + 1) * sizeof(real));  // NOLINT // TODO:
-  for (size_t i = 0; i <= halfSize; i++) {
-    integralData[i] = 0;
-  }
-  for (size_t i = 0; i < num; i++) {
-    real* targetPtr = targetData + i * numCols;
-    real* imgPtr = imgData + i * numCols;
-    real* diffPtr = diffData + i * numCols;
-    for (size_t m = 0; m < height; m++) {
-      for (size_t n = 0; n < width; n++) {
-        for (size_t c = 0; c < channels; c++) {
-          integralData[c + halfSize + 1] =
-              integralData[c + halfSize] + _square(*(diffPtr + c * imgPixels));
-        }
-        for (size_t k = channels + halfSize + 1; k <= channels + sizeX; k++) {
-          integralData[k] = integralData[channels + halfSize];
-        }
-        for (size_t k = 0; k < channels; k += 1) {
-          real a = integralData[k + sizeX] - integralData[k];
-          a = scale * a + 1;
-          targetPtr[k * imgPixels] = imgPtr[k * imgPixels] * _pow(a, -pow);
-        }
-        diffPtr++;
-        targetPtr++;
-        imgPtr++;
-      }
-    }
-  }
-  free(integralData);
-  integralData = NULL;
-}
-
-void CpuMatrix::crossMapNormalBwd(Matrix& localGrad,
-                                  Matrix& denoms,
-                                  Matrix& preOutV,
-                                  Matrix& localOutV,
-                                  size_t channels,
-                                  size_t imgSizeH,
-                                  size_t imgSizeW,
-                                  size_t size,
-                                  float scale,
-                                  float pow) {
-  LOG(FATAL) << "Not implemented";
-
-  CHECK(imgSizeH * imgSizeW * channels == preOutV.getWidth());
-  CHECK(denoms.getHeight() == preOutV.getHeight() &&
-        denoms.getWidth() == preOutV.getWidth() &&
-        preOutV.getHeight() == height_ && preOutV.getWidth() == width_);
-  CHECK(denoms.getHeight() == localGrad.getHeight() &&
-        denoms.getWidth() == localGrad.getWidth());
-
-  // NOLINT // TODO:
-}
-
 /**
  * Input: one or more sequences. Each sequence contains some instances.
  * Output: output size is the number of input sequences (NOT input instances).
diff --git a/paddle/math/Matrix.h b/paddle/math/Matrix.h
index 1cfb90a9dbf19..dccadc8a38721 100644
--- a/paddle/math/Matrix.h
+++ b/paddle/math/Matrix.h
@@ -952,31 +952,6 @@ class Matrix : public BaseMatrix {
     LOG(FATAL) << "Not implemeted";
   }
 
-  /// normalize-operation.
-  virtual void crossMapNormalFwd(Matrix& input,
-                                 size_t imgSizeH,
-                                 size_t imgSizeW,
-                                 Matrix& denoms,
-                                 size_t channels,
-                                 size_t sizeX,
-                                 float scale,
-                                 float pow) {
-    LOG(FATAL) << "Not implemeted";
-  }
-
-  virtual void crossMapNormalBwd(Matrix& localGrad,
-                                 Matrix& denoms,
-                                 Matrix& preOutV,
-                                 Matrix& localOutV,
-                                 size_t channels,
-                                 size_t imgSizeH,
-                                 size_t imgSizeW,
-                                 size_t size,
-                                 float scale,
-                                 float pow) {
-    LOG(FATAL) << "Not implemeted";
-  }
-
   /**
    * Input: one or more sequences. Each sequence contains some instances.
    *
@@ -1459,26 +1434,6 @@ class GpuMatrix : public Matrix {
                  size_t paddingH,
                  size_t paddingW);
 
-  void crossMapNormalFwd(Matrix& input,
-                         size_t imgSizeH,
-                         size_t imgSizeW,
-                         Matrix& denoms,
-                         size_t channels,
-                         size_t sizeX,
-                         float scale,
-                         float pow);
-
-  void crossMapNormalBwd(Matrix& localGrad,
-                         Matrix& denoms,
-                         Matrix& preOutV,
-                         Matrix& localOutV,
-                         size_t channels,
-                         size_t imgSizeH,
-                         size_t imgSizeW,
-                         size_t sizeX,
-                         float scale,
-                         float pow);
-
   void maxSequenceForward(Matrix& input,
                           const IVector& sequence,
                           IVector& index);
@@ -1685,26 +1640,6 @@ class CpuMatrix : public Matrix {
                  size_t paddingH,
                  size_t paddingW);
 
-  void crossMapNormalFwd(Matrix& input,
-                         size_t imgSizeH,
-                         size_t imgSizeW,
-                         Matrix& denoms,
-                         size_t channels,
-                         size_t sizeX,
-                         float scale,
-                         float pow);
-
-  void crossMapNormalBwd(Matrix& localGrad,
-                         Matrix& denoms,
-                         Matrix& preOutV,
-                         Matrix& localOutV,
-                         size_t channels,
-                         size_t imgSizeH,
-                         size_t imgSizeW,
-                         size_t sizeX,
-                         float scale,
-                         float pow);
-
   void maxSequenceForward(Matrix& input,
                           const IVector& sequence,
                           IVector& index);
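With the Matrix methods removed, cross-map normalization is reachable only through the function registry. A sketch of the direct replacement for the deleted crossMapNormalFwd call, assuming inputData/outputData/denomData each hold batchSize x channels x height x width reals on the matching device:

```cpp
// Sketch (not part of the patch): direct use of the new function API,
// equivalent to what CMRProjectionNormLayer::forward now does.
auto norm = std::shared_ptr<paddle::FunctionBase>(
    paddle::FunctionBase::funcRegistrar_.createByType("CrossMapNormal-CPU"));
norm->init(paddle::FuncConfig()
               .set("size", (size_t)5)
               .set("scale", (real)1.5)
               .set("pow", (real)0.5));

paddle::Dims dims{batchSize, channels, height, width};
norm->calc({paddle::Tensor(inputData, dims)},    // inputs
           {paddle::Tensor(outputData, dims),    // output value
            paddle::Tensor(denomData, dims)},    // denoms
           {});                                  // no inouts
```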