From 529f24c262850974dd8ba4c5b7ad1a4e3e0230fc Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Mon, 12 Dec 2016 18:17:27 +0800 Subject: [PATCH 01/15] cpu cmrnorm --- paddle/cuda/src/hl_cuda_cnn.cu | 192 +++++++++-------------- paddle/gserver/tests/test_LayerGrad.cpp | 3 +- paddle/math/Matrix.cpp | 137 ++++++++++------ paddle/math/tests/test_matrixCompare.cpp | 115 ++++++++++++++ 4 files changed, 279 insertions(+), 168 deletions(-) diff --git a/paddle/cuda/src/hl_cuda_cnn.cu b/paddle/cuda/src/hl_cuda_cnn.cu index 0992286f360fb..1516accaae17f 100644 --- a/paddle/cuda/src/hl_cuda_cnn.cu +++ b/paddle/cuda/src/hl_cuda_cnn.cu @@ -381,57 +381,45 @@ void hl_avgpool_backward(const int frameCnt, const real* outGrad, CHECK_SYNC("hl_avgpool_backward failed"); } -__global__ void KeCMRNormFillScale(size_t nthreads, const real* in, +__global__ void KeCMRNormFillScale(size_t imageSize, const real* in, real* scale, size_t channels, size_t height, size_t width, size_t size, real alpha) { - size_t index = threadIdx.x + blockIdx.x * blockDim.x; - if (index < nthreads) { - // find out the local offset - size_t w = index % width; - size_t h = (index / width) % height; - size_t n = index / width / height; - size_t offset = (n * channels * height + h) * width + w; - size_t step = height * width; + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < imageSize) { + const int w = idx % width; + const int h = (idx / width) % height; + const int n = idx / width / height; + const int offset = (n * channels * height + h) * width + w; + in += offset; scale += offset; - size_t head = 0; - size_t pre_pad = (size - 1) / 2; - size_t post_pad = size - pre_pad - 1; - real accum_scale = 0; - // fill the scale at [n, :, h, w] - // accumulate values - while (head < post_pad) { - accum_scale += in[head * step] * in[head * step]; - ++head; - } - // until we reach size, nothing needs to be subtracted - while (head < size) { - accum_scale += in[head * step] * in[head * step]; - scale[(head - post_pad) * step] = 1. + accum_scale * alpha; - ++head; - } - // both add and subtract - while (head < channels) { - accum_scale += in[head * step] * in[head * step]; - accum_scale -= in[(head - size) * step] * in[(head - size) * step]; - scale[(head - post_pad) * step] = 1. + accum_scale * alpha; - ++head; - } - // subtract only - while (head < channels + post_pad) { - accum_scale -= in[(head - size) * step] * in[(head - size) * step]; - scale[(head - post_pad) * step] = 1. + accum_scale * alpha; - ++head; + const int step = height * width; + const int pre_pad = (size - 1) / 2; + const int post_pad = size - pre_pad - 1; + + real accum = 0; + int index = 0; + while (index < channels + post_pad) { + if (index < channels) { + accum += in[index * step] * in[index * step]; + } + if (index >= size) { + accum -= in[(index - size) * step] * in[(index - size) * step]; + } + if (index >= post_pad) { + scale[(index - post_pad) * step] = 1. 
+ accum * alpha; + } + ++index; } } } - __global__ void KeCMRNormOutput(size_t nthreads, const real* in, - const real* scale, real negative_beta, - real* out) { - size_t index = threadIdx.x + blockIdx.x * blockDim.x; - if (index < nthreads) { +__global__ void KeCMRNormOutput(size_t inputSize, const real* in, + const real* scale, real negative_beta, + real* out) { + const int index = threadIdx.x + blockIdx.x * blockDim.x; + if (index < inputSize) { out[index] = in[index] * pow(scale[index], negative_beta); } } @@ -440,84 +428,60 @@ void hl_CMRNorm_forward(size_t frameCnt, const real* in, real* scale, real* out, size_t channels, size_t height, size_t width, size_t sizeX, real alpha, real beta) { - size_t threadsNum = frameCnt * height * width; - size_t blocksX = (threadsNum + 1024 - 1) / 1024; - size_t blocksY = 1; - dim3 threads(1024, 1); - dim3 grid(blocksX, blocksY); - - KeCMRNormFillScale<<>> - (threadsNum, in, scale, channels, height, width, sizeX, alpha); - - threadsNum = frameCnt * height * width *channels; - blocksX = (threadsNum + 1024 -1) / 1024; - dim3 threads2(1024, 1); - dim3 grid2(blocksX, blocksY); - KeCMRNormOutput<<>> - (threadsNum, in, scale, beta, out); + size_t imageSize = frameCnt * height * width; + int blockSize = 1024; + int gridSize = (imageSize + 1024 - 1) / 1024; + KeCMRNormFillScale<<>> + (imageSize, in, scale, channels, height, width, sizeX, alpha); + + size_t inputSize = frameCnt * height * width *channels; + blockSize = 1024; + gridSize = (inputSize + 1024 - 1) / 1024; + KeCMRNormOutput<<>> + (inputSize, in, scale, beta, out); CHECK_SYNC("hl_CMRNorm_forward"); } -__global__ void KeCMRNormDiff(size_t nthreads, const real* bottom_data, +__global__ void KeCMRNormDiff(size_t imageSize, const real* bottom_data, const real* top_data, const real* scale, const real* top_diff, size_t channels, size_t height, size_t width, size_t size, real negative_beta, real cache_ratio, real* bottom_diff ) { - int index = threadIdx.x + blockIdx.x * blockDim.x; - if (index < nthreads) { - // find out the local offset - size_t w = index % width; - size_t h = (index / width) % height; - size_t n = index / width / height; - size_t offset = (n * channels * height + h) * width + w; - size_t step = height * width; + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < imageSize) { + const int w = idx % width; + const int h = (idx / width) % height; + const int n = idx / width / height; + const int offset = (n * channels * height + h) * width + w; bottom_data += offset; top_data += offset; scale += offset; top_diff += offset; bottom_diff += offset; - int head = 0; - int pre_pad = size - (size + 1) / 2; - int post_pad = size - pre_pad - 1; - real accum_ratio = 0; - // accumulate values - while (head < post_pad) { - accum_ratio += top_diff[head * step] * - top_data[head * step] / scale[head * step]; - ++head; - } - // until we reach size, nothing needs to be subtracted - while (head < size) { - accum_ratio += top_diff[head * step] * - top_data[head * step] / scale[head * step]; - bottom_diff[(head - post_pad) * step] += - top_diff[(head - post_pad) * step] * - pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio * - bottom_data[(head - post_pad) * step] * accum_ratio; - ++head; - } - // both add and subtract - while (head < channels) { - accum_ratio += top_diff[head * step] * top_data[head * step] / - scale[head * step]; - accum_ratio -= top_diff[(head - size) * step] * - top_data[(head - size) * step] / scale[(head - size) * step]; - bottom_diff[(head - post_pad) * 
step] += - top_diff[(head - post_pad) * step] * - pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio * - bottom_data[(head - post_pad) * step] * accum_ratio; - ++head; - } - // subtract only - while (head < channels + post_pad) { - accum_ratio -= top_diff[(head - size) * step] * - top_data[(head - size) * step] / scale[(head - size) * step]; - bottom_diff[(head - post_pad) * step] += - top_diff[(head - post_pad) * step] * - pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio * - bottom_data[(head - post_pad) * step] * accum_ratio; - ++head; + + const int step = height * width; + const int pre_pad = size - (size + 1) / 2; + const int post_pad = size - pre_pad - 1; + + int index = 0; + real accum = 0; + while (index < channels + post_pad) { + if (index < channels) { + accum += top_diff[index * step] * top_data[index * step] / + scale[index * step]; + } + if (index >= size) { + accum -= top_diff[(index - size) * step] * + top_data[(index - size) * step] / scale[(index - size) * step]; + } + if (index >= post_pad) { + bottom_diff[(index - post_pad) * step] += + top_diff[(index - post_pad) * step] * + pow(scale[(index - post_pad) * step], negative_beta) - cache_ratio * + bottom_data[(index - post_pad) * step] * accum; + } + ++index; } } } @@ -528,14 +492,12 @@ void hl_CMRNorm_backward(size_t frameCnt, const real* inV, real *inDiff, size_t channels, size_t height, size_t width, size_t sizeX, real alpha, real beta) { - size_t threadsNum = frameCnt * height * width; - size_t blocksX = (threadsNum + 1024 - 1) / 1024; - size_t blocksY = 1; - dim3 threads(1024, 1); - dim3 grid(blocksX, blocksY); - KeCMRNormDiff <<>> - (threadsNum, inV, outV, scale, outDiff, channels, - height, width, sizeX, alpha, beta, inDiff); + size_t imageSize = frameCnt * height * width; + int blockSize = 1024; + int gridSize = (imageSize + 1024 - 1) / 1024; + KeCMRNormDiff <<>> + (imageSize, inV, outV, scale, outDiff, channels, + height, width, sizeX, alpha, beta, inDiff); CHECK_SYNC("hl_CMRNorm_backward"); } diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp index 7983d9fe64c61..8ade15daac860 100644 --- a/paddle/gserver/tests/test_LayerGrad.cpp +++ b/paddle/gserver/tests/test_LayerGrad.cpp @@ -1021,11 +1021,10 @@ void testNormLayer(const string& normType, bool trans, bool useGpu) { testLayerGrad(config, "norm", 100, trans, useGpu); } -#ifndef PADDLE_ONLY_CPU TEST(Layer, NormLayer) { testNormLayer("cmrnorm-projection", /* trans= */ false, /* useGpu= */ true); + testNormLayer("cmrnorm-projection", /* trans= */ false, /* useGpu= */ false); } -#endif void setPoolConfig(TestConfig* config, PoolConfig* pool, diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp index c69e074a76399..2cde11dd479dc 100644 --- a/paddle/math/Matrix.cpp +++ b/paddle/math/Matrix.cpp @@ -2227,52 +2227,43 @@ void CpuMatrix::crossMapNormalFwd(Matrix& input, size_t sizeX, float scale, float pow) { - size_t num = input.getHeight(); + CHECK(isContiguous()); + CHECK(input.isContiguous()); + CHECK(denoms.isContiguous()); + CHECK_EQ(getHeight(), input.getHeight()); + CHECK_EQ(getWidth(), input.getWidth()); + CHECK_EQ(getHeight(), denoms.getHeight()); + CHECK_EQ(getWidth(), denoms.getWidth()); + + size_t numSample = input.getHeight(); + size_t numCols = input.getWidth(); size_t height = imgSizeH; size_t width = imgSizeW; - size_t numCols = input.getWidth(); - CHECK(height * width * channels == input.getWidth()); - CHECK(denoms.getHeight() == input.getHeight() && - denoms.getWidth() == 
input.getWidth() && input.getHeight() == height_ && - input.getWidth() == width_); - real* imgData = input.getData(); - real* diffData = input.getData(); - real* targetData = getData(); - size_t halfSize = sizeX / 2; - size_t imgPixels = height * width; - - // use integral vector to implement the sum in local window - real* integralData = - (real*)malloc((channels + sizeX + 1) * sizeof(real)); // NOLINT // TODO: - for (size_t i = 0; i <= halfSize; i++) { - integralData[i] = 0; - } - for (size_t i = 0; i < num; i++) { - real* targetPtr = targetData + i * numCols; - real* imgPtr = imgData + i * numCols; - real* diffPtr = diffData + i * numCols; - for (size_t m = 0; m < height; m++) { - for (size_t n = 0; n < width; n++) { - for (size_t c = 0; c < channels; c++) { - integralData[c + halfSize + 1] = - integralData[c + halfSize] + _square(*(diffPtr + c * imgPixels)); - } - for (size_t k = channels + halfSize + 1; k <= channels + sizeX; k++) { - integralData[k] = integralData[channels + halfSize]; + CHECK(height * width * channels == numCols); + + // TODO(hedaoyuan) After commit TensorExpress code, + // Reconstruction this code to remove the temporary memory. + CpuMatrix tmp(channels, height * width); + CpuMatrix tmp2(tmp.getData(), 1, channels * height * width); + denoms.zero(); + const int start = -((int)sizeX - 1) / 2; + const int end = (int)sizeX + start; + for (size_t i = 0; i < numSample; i++) { + input.subMatrix(i, 1)->square2(tmp2); + CpuMatrix subDen( + denoms.subMatrix(i, 1)->getData(), channels, height * width); + for (int c = 0; c < (int)channels; c++) { + for (int s = start; s < end; s++) { + if (c + s >= 0 && c + s < (int)channels) { + subDen.subMatrix(c, 1)->add(*tmp.subMatrix(c + s, 1)); } - for (size_t k = 0; k < channels; k += 1) { - real a = integralData[k + sizeX] - integralData[k]; - a = scale * a + 1; - targetPtr[k * imgPixels] = imgPtr[k * imgPixels] * _pow(a, -pow); - } - diffPtr++; - targetPtr++; - imgPtr++; } } } - free(integralData); - integralData = NULL; + + denoms.add(scale, (real)1); + this->pow2(denoms, -pow); + this->dotMul(input); } void CpuMatrix::crossMapNormalBwd(Matrix& localGrad, @@ -2282,19 +2273,63 @@ void CpuMatrix::crossMapNormalBwd(Matrix& localGrad, size_t channels, size_t imgSizeH, size_t imgSizeW, - size_t size, + size_t sizeX, float scale, float pow) { - LOG(FATAL) << "Not implemented"; - - CHECK(imgSizeH * imgSizeW * channels == preOutV.getWidth()); - CHECK(denoms.getHeight() == preOutV.getHeight() && - denoms.getWidth() == preOutV.getWidth() && - preOutV.getHeight() == height_ && preOutV.getWidth() == width_); - CHECK(denoms.getHeight() == localGrad.getHeight() && - denoms.getWidth() == localGrad.getWidth()); - - // NOLINT // TODO: + CHECK(isContiguous()); + CHECK(localGrad.isContiguous()); + CHECK(denoms.isContiguous()); + CHECK(preOutV.isContiguous()); + CHECK(localOutV.isContiguous()); + CHECK_EQ(getHeight(), localGrad.getHeight()); + CHECK_EQ(getWidth(), localGrad.getWidth()); + CHECK_EQ(getHeight(), denoms.getHeight()); + CHECK_EQ(getWidth(), denoms.getWidth()); + CHECK_EQ(getHeight(), preOutV.getHeight()); + CHECK_EQ(getWidth(), preOutV.getWidth()); + CHECK_EQ(getHeight(), localOutV.getHeight()); + CHECK_EQ(getWidth(), localOutV.getWidth()); + + size_t numSample = getHeight(); + size_t numCols = getWidth(); + size_t height = imgSizeH; + size_t width = imgSizeW; + CHECK(height * width * channels == numCols); + + // TODO(hedaoyuan) After commit TensorExpress code, + // Reconstruction this code to remove the temporary memory. 
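+  // For reference: with denom_c = 1 + scale * sum_{s in window(c)} in_{c+s}^2
+  // and out_c = in_c * denom_c^(-pow), the loops below accumulate, per channel:
+  //   inGrad_c += outGrad_c * denom_c^(-pow)
+  //       - 2 * scale * pow * in_c *
+  //         sum_{s in window(c)} outGrad_{c+s} * out_{c+s} / denom_{c+s}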
+ CpuMatrix tmp(1, height * width); + + const int start = -((int)sizeX) / 2; + const int end = (int)sizeX + start; + const real ratio = -(real)2 * scale * pow; + for (size_t i = 0; i < numSample; i++) { + CpuMatrix inputDiff( + this->subMatrix(i, 1)->getData(), channels, height * width); + CpuMatrix outDiff( + localGrad.subMatrix(i, 1)->getData(), channels, height * width); + CpuMatrix input( + preOutV.subMatrix(i, 1)->getData(), channels, height * width); + CpuMatrix output( + localOutV.subMatrix(i, 1)->getData(), channels, height * width); + CpuMatrix subDen( + denoms.subMatrix(i, 1)->getData(), channels, height * width); + + for (int c = 0; c < (int)channels; c++) { + tmp.pow2(*subDen.subMatrix(c, 1), -pow); + inputDiff.subMatrix(c, 1) + ->addDotMul(tmp, *outDiff.subMatrix(c, 1), (real)1, (real)1); + for (int s = start; s < end; s++) { + if (c + s >= 0 && c + s < (int)channels) { + tmp.dotMul(*outDiff.subMatrix(c + s, 1), *output.subMatrix(c + s, 1)); + tmp.mulScalar(ratio); + tmp.dotDiv(tmp, *subDen.subMatrix(c + s, 1)); + tmp.dotMul(*input.subMatrix(c, 1)); + inputDiff.subMatrix(c, 1)->add(tmp); + } + } + } + } } /** diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index 713792d82b3c5..5233a9af40155 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -1261,6 +1261,121 @@ TEST(Matrix, MaxOutFwdBwd) { } } } +void testCrossMapNormalFwd( + int numSamples, int channels, int imgSizeH, int imgSizeW, int sizeX) { + float scale = 1.5; + float pow = 0.5; + int width = imgSizeH * imgSizeW * channels; + MatrixPtr input = CpuMatrix::create(numSamples, width, false, false); + MatrixPtr denorms = CpuMatrix::create(numSamples, width, false, false); + MatrixPtr target = CpuMatrix::create(numSamples, width, false, false); + MatrixPtr inputGpu = GpuMatrix::create(numSamples, width, false, true); + MatrixPtr denormsGpu = GpuMatrix::create(numSamples, width, false, true); + MatrixPtr targetGpu = GpuMatrix::create(numSamples, width, false, true); + + input->randomizeUniform(); + target->randomizeUniform(); + inputGpu->copyFrom(*input); + targetGpu->copyFrom(*target); + + target->crossMapNormalFwd( + *input, imgSizeH, imgSizeW, *denorms, channels, sizeX, scale, pow); + targetGpu->crossMapNormalFwd( + *inputGpu, imgSizeH, imgSizeW, *denormsGpu, channels, sizeX, scale, pow); + + TensorCheckErr(*target, *targetGpu); + TensorCheckErr(*denorms, *denormsGpu); +} + +TEST(Matrix, crossMapNormalFwd) { + for (auto numSamples : {5, 32}) { + for (auto channels : {1, 5, 32}) { + for (auto imgSizeH : {5, 33, 100}) { + for (auto imgSizeW : {5, 32, 96}) { + for (auto sizeX : {1, 2, 3, 5, 7}) { + VLOG(3) << " numSamples=" << numSamples << " channels=" << channels + << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW + << " sizeX=" << sizeX; + testCrossMapNormalFwd( + numSamples, channels, imgSizeH, imgSizeW, sizeX); + } + } + } + } + } +} + +void testCrossMapNormalBwd( + int numSamples, int channels, int imgSizeH, int imgSizeW, int sizeX) { + float scale = 1.5; + float pow = 0.5; + size_t width = imgSizeH * imgSizeW * channels; + MatrixPtr localGrad = CpuMatrix::create(numSamples, width, false, false); + MatrixPtr denoms = CpuMatrix::create(numSamples, width, false, false); + MatrixPtr output = CpuMatrix::create(numSamples, width, false, false); + MatrixPtr preOutV = CpuMatrix::create(numSamples, width, false, false); + MatrixPtr localOutV = CpuMatrix::create(numSamples, width, false, false); + + 
localGrad->randomizeUniform(); + denoms->randomizeUniform(); + preOutV->randomizeUniform(); + localOutV->randomizeUniform(); + output->randomizeUniform(); + denoms->add(0.01); + + MatrixPtr localGradGpu = GpuMatrix::create(numSamples, width, false, true); + MatrixPtr denomsGpu = GpuMatrix::create(numSamples, width, false, true); + MatrixPtr outputGpu = GpuMatrix::create(numSamples, width, false, true); + MatrixPtr preOutVGpu = GpuMatrix::create(numSamples, width, false, true); + MatrixPtr localOutVGpu = GpuMatrix::create(numSamples, width, false, true); + + localGradGpu->copyFrom(*localGrad); + denomsGpu->copyFrom(*denoms); + preOutVGpu->copyFrom(*preOutV); + localOutVGpu->copyFrom(*localOutV); + outputGpu->copyFrom(*output); + + output->crossMapNormalBwd(*localGrad, + *denoms, + *preOutV, + *localOutV, + channels, + imgSizeH, + imgSizeW, + sizeX, + scale, + pow); + outputGpu->crossMapNormalBwd(*localGradGpu, + *denomsGpu, + *preOutVGpu, + *localOutVGpu, + channels, + imgSizeH, + imgSizeW, + sizeX, + scale, + pow); + + TensorCheckErr(*output, *outputGpu); +} + +TEST(Matrix, crossMapNormalBwd) { + for (auto numSamples : {5, 32}) { + for (auto channels : {1, 5, 32}) { + for (auto imgSizeH : {5, 33, 100}) { + for (auto imgSizeW : {5, 32, 96}) { + for (auto sizeX : {1, 2, 3, 5, 7}) { + VLOG(3) << " numSamples=" << numSamples << " channels=" << channels + << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW + << " sizeX=" << sizeX; + testCrossMapNormalBwd( + numSamples, channels, imgSizeH, imgSizeW, sizeX); + } + } + } + } + } +} int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); From 95035908b4f47e61bad12d0ed49bf62a1734b2cf Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Tue, 13 Dec 2016 14:27:42 +0800 Subject: [PATCH 02/15] add CrossMapNormal --- paddle/math/cross_map_normal_op.cpp | 129 +++++++++++++++++++++ paddle/math/cross_map_normal_op.h | 47 ++++++++ paddle/math/tests/test_matrixCompare.cpp | 137 ++++++++++++----------- 3 files changed, 248 insertions(+), 65 deletions(-) create mode 100644 paddle/math/cross_map_normal_op.cpp create mode 100644 paddle/math/cross_map_normal_op.h diff --git a/paddle/math/cross_map_normal_op.cpp b/paddle/math/cross_map_normal_op.cpp new file mode 100644 index 0000000000000..3eb51b5998fc4 --- /dev/null +++ b/paddle/math/cross_map_normal_op.cpp @@ -0,0 +1,129 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "cross_map_normal_op.h" + +namespace paddle { + +// NCHW +void CrossMapNormal::operator()(CpuMatrix& outputs, + CpuMatrix& denoms, + CpuMatrix& inputs, + size_t channels, + size_t imgSizeH, + size_t imgSizeW, + size_t sizeX, + real scale, + real pow) { + CHECK(outputs.isContiguous()); + CHECK(inputs.isContiguous()); + CHECK(denoms.isContiguous()); + CHECK_EQ(outputs.getHeight(), inputs.getHeight()); + CHECK_EQ(outputs.getWidth(), inputs.getWidth()); + CHECK_EQ(outputs.getHeight(), denoms.getHeight()); + CHECK_EQ(outputs.getWidth(), denoms.getWidth()); + + size_t numSample = inputs.getHeight(); + size_t numCols = inputs.getWidth(); + size_t imageSize = imgSizeH * imgSizeW; + CHECK(imageSize * channels == numCols); + + denoms = denoms.constant(1.0); + const int start = -((int)sizeX - 1) / 2; + const int end = (int)sizeX + start; + for (size_t i = 0; i < numSample; i++) { + real* denomsData = denoms.getData() + i * numCols; + real* inputData = inputs.getData() + i * numCols; + for (int c = 0; c < (int)channels; c++) { + CpuVector denom(imageSize, denomsData + c * imageSize); + for (int s = start; s < end; s++) { + if (c + s >= 0 && c + s < (int)channels) { + CpuVector input(imageSize, inputData + (c + s) * imageSize); + denom += input.square() * scale; + } + } + } + } + outputs = inputs * denoms.pow(-pow); +} + +void CrossMapNormalGrad::operator()(CpuMatrix& inputsGrad, + CpuMatrix& inputsValue, + CpuMatrix& outputsGrad, + CpuMatrix& outputsValue, + CpuMatrix& denoms, + size_t channels, + size_t imgSizeH, + size_t imgSizeW, + size_t sizeX, + real scale, + real pow) { + CHECK(inputsGrad.isContiguous()); + CHECK(outputsGrad.isContiguous()); + CHECK(denoms.isContiguous()); + CHECK(inputsValue.isContiguous()); + CHECK(outputsValue.isContiguous()); + CHECK_EQ(inputsGrad.getHeight(), outputsGrad.getHeight()); + CHECK_EQ(inputsGrad.getWidth(), outputsGrad.getWidth()); + CHECK_EQ(inputsGrad.getHeight(), denoms.getHeight()); + CHECK_EQ(inputsGrad.getWidth(), denoms.getWidth()); + CHECK_EQ(inputsGrad.getHeight(), inputsValue.getHeight()); + CHECK_EQ(inputsGrad.getWidth(), inputsValue.getWidth()); + CHECK_EQ(inputsGrad.getHeight(), outputsValue.getHeight()); + CHECK_EQ(inputsGrad.getWidth(), outputsValue.getWidth()); + + size_t numSample = inputsGrad.getHeight(); + size_t numCols = inputsGrad.getWidth(); + size_t imageSize = imgSizeH * imgSizeW; + CHECK(imageSize * channels == numCols); + + std::function oneImage = [=](real* data, + size_t offset) { + return CpuVector(imageSize, data + offset); + }; + + const int start = -((int)sizeX) / 2; + const int end = (int)sizeX + start; + const real ratio = -(real)2 * scale * pow; + for (size_t i = 0; i < numSample; i++) { + size_t sOffset = i * numCols; + real* inputGradData = inputsGrad.getData() + sOffset; + real* inputData = inputsValue.getData() + sOffset; + real* denomData = denoms.getData() + sOffset; + real* outputGradData = outputsGrad.getData() + sOffset; + real* outputData = outputsValue.getData() + sOffset; + + for (int c = 0; c < (int)channels; c++) { + size_t cOffset = c * imageSize; + CpuVector inputGrad = oneImage(inputGradData, cOffset); + CpuVector inputValue = oneImage(inputData, cOffset); + CpuVector denom = oneImage(denomData, cOffset); + CpuVector outputGrad = oneImage(outputGradData, cOffset); + + inputGrad = inputGrad + denom.pow(-pow) * outputGrad; + for (int s = start; s < end; s++) { + if (c + s >= 0 && c + s < (int)channels) { + size_t offset = (c + s) * imageSize; + CpuVector output = oneImage(outputData, offset); + 
CpuVector outputGrad = oneImage(outputGradData, offset); + CpuVector denom = oneImage(denomData, offset); + + inputGrad += ((outputGrad * output * ratio) / denom) * inputValue; + } + } + } + } +} + +} // namespace paddle diff --git a/paddle/math/cross_map_normal_op.h b/paddle/math/cross_map_normal_op.h new file mode 100644 index 0000000000000..2f996072528a0 --- /dev/null +++ b/paddle/math/cross_map_normal_op.h @@ -0,0 +1,47 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/math/Matrix.h" + +namespace paddle { + +struct CrossMapNormal { + void operator()(CpuMatrix& outputs, + CpuMatrix& denoms, + CpuMatrix& inputs, + size_t channels, + size_t imgSizeH, + size_t imgSizeW, + size_t sizeX, + real scale, + real pow); +}; + +struct CrossMapNormalGrad { + void operator()(CpuMatrix& inputsGrad, + CpuMatrix& inputsValue, + CpuMatrix& outputsGrad, + CpuMatrix& outputsValue, + CpuMatrix& denoms, + size_t channels, + size_t imgSizeH, + size_t imgSizeW, + size_t sizeX, + real scale, + real pow); +}; + +} // namespace paddle diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index 5233a9af40155..9bb1fdbdab83a 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -23,6 +23,7 @@ limitations under the License. 
*/ #include "paddle/gserver/tests/TestUtil.h" #include "paddle/utils/Stat.h" #include "TensorCheck.h" +#include "paddle/math/cross_map_normal_op.h" using namespace paddle; // NOLINT using namespace std; // NOLINT @@ -1261,30 +1262,32 @@ TEST(Matrix, MaxOutFwdBwd) { } } } + void testCrossMapNormalFwd( int numSamples, int channels, int imgSizeH, int imgSizeW, int sizeX) { float scale = 1.5; float pow = 0.5; int width = imgSizeH * imgSizeW * channels; - MatrixPtr input = CpuMatrix::create(numSamples, width, false, false); - MatrixPtr denorms = CpuMatrix::create(numSamples, width, false, false); - MatrixPtr target = CpuMatrix::create(numSamples, width, false, false); - MatrixPtr inputGpu = GpuMatrix::create(numSamples, width, false, true); - MatrixPtr denormsGpu = GpuMatrix::create(numSamples, width, false, true); - MatrixPtr targetGpu = GpuMatrix::create(numSamples, width, false, true); - - input->randomizeUniform(); - target->randomizeUniform(); - inputGpu->copyFrom(*input); - targetGpu->copyFrom(*target); - - target->crossMapNormalFwd( - *input, imgSizeH, imgSizeW, *denorms, channels, sizeX, scale, pow); - targetGpu->crossMapNormalFwd( - *inputGpu, imgSizeH, imgSizeW, *denormsGpu, channels, sizeX, scale, pow); - - TensorCheckErr(*target, *targetGpu); - TensorCheckErr(*denorms, *denormsGpu); + CpuMatrix inputs(numSamples, width); + CpuMatrix denoms(numSamples, width); + CpuMatrix outputs(numSamples, width); + GpuMatrix inputsGpu(numSamples, width); + GpuMatrix denomsGpu(numSamples, width); + GpuMatrix outputsGpu(numSamples, width); + + inputs.randomizeUniform(); + outputs.randomizeUniform(); + inputsGpu.copyFrom(inputs); + outputsGpu.copyFrom(outputs); + + CrossMapNormal cross; + cross( + outputs, denoms, inputs, channels, imgSizeH, imgSizeW, sizeX, scale, pow); + outputsGpu.crossMapNormalFwd( + inputsGpu, imgSizeH, imgSizeW, denomsGpu, channels, sizeX, scale, pow); + + TensorCheckErr(outputs, outputsGpu); + TensorCheckErr(denoms, denomsGpu); } TEST(Matrix, crossMapNormalFwd) { @@ -1310,53 +1313,57 @@ void testCrossMapNormalBwd( float scale = 1.5; float pow = 0.5; size_t width = imgSizeH * imgSizeW * channels; - MatrixPtr localGrad = CpuMatrix::create(numSamples, width, false, false); - MatrixPtr denoms = CpuMatrix::create(numSamples, width, false, false); - MatrixPtr output = CpuMatrix::create(numSamples, width, false, false); - MatrixPtr preOutV = CpuMatrix::create(numSamples, width, false, false); - MatrixPtr localOutV = CpuMatrix::create(numSamples, width, false, false); - - localGrad->randomizeUniform(); - denoms->randomizeUniform(); - preOutV->randomizeUniform(); - localOutV->randomizeUniform(); - output->randomizeUniform(); - denoms->add(0.01); - - MatrixPtr localGradGpu = GpuMatrix::create(numSamples, width, false, true); - MatrixPtr denomsGpu = GpuMatrix::create(numSamples, width, false, true); - MatrixPtr outputGpu = GpuMatrix::create(numSamples, width, false, true); - MatrixPtr preOutVGpu = GpuMatrix::create(numSamples, width, false, true); - MatrixPtr localOutVGpu = GpuMatrix::create(numSamples, width, false, true); - - localGradGpu->copyFrom(*localGrad); - denomsGpu->copyFrom(*denoms); - preOutVGpu->copyFrom(*preOutV); - localOutVGpu->copyFrom(*localOutV); - outputGpu->copyFrom(*output); - output->crossMapNormalBwd(*localGrad, - *denoms, - *preOutV, - *localOutV, - channels, - imgSizeH, - imgSizeW, - sizeX, - scale, - pow); - outputGpu->crossMapNormalBwd(*localGradGpu, - *denomsGpu, - *preOutVGpu, - *localOutVGpu, - channels, - imgSizeH, - imgSizeW, - sizeX, - scale, - pow); - - 
TensorCheckErr(*output, *outputGpu); + CpuMatrix inputsGrad(numSamples, width); + CpuMatrix inputsValue(numSamples, width); + CpuMatrix outputsGrad(numSamples, width); + CpuMatrix outputsValue(numSamples, width); + CpuMatrix denoms(numSamples, width); + + outputsGrad.randomizeUniform(); + denoms.randomizeUniform(); + inputsValue.randomizeUniform(); + outputsValue.randomizeUniform(); + inputsGrad.randomizeUniform(); + denoms.add(0.01); + + GpuMatrix inputsGradGpu(numSamples, width); + GpuMatrix inputsValueGpu(numSamples, width); + GpuMatrix outputsGradGpu(numSamples, width); + GpuMatrix outputsValueGpu(numSamples, width); + GpuMatrix denomsGpu(numSamples, width); + + outputsGradGpu.copyFrom(outputsGrad); + denomsGpu.copyFrom(denoms); + inputsValueGpu.copyFrom(inputsValue); + outputsValueGpu.copyFrom(outputsValue); + inputsGradGpu.copyFrom(inputsGrad); + + CrossMapNormalGrad cross; + cross(inputsGrad, + inputsValue, + outputsGrad, + outputsValue, + denoms, + channels, + imgSizeH, + imgSizeW, + sizeX, + scale, + pow); + + inputsGradGpu.crossMapNormalBwd(outputsGradGpu, + denomsGpu, + inputsValueGpu, + outputsValueGpu, + channels, + imgSizeH, + imgSizeW, + sizeX, + scale, + pow); + + TensorCheckErr(inputsGrad, inputsGradGpu); } TEST(Matrix, crossMapNormalBwd) { From e357f2715843cd531ce0b0143647ed5561d2fceb Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Tue, 13 Dec 2016 17:55:31 +0800 Subject: [PATCH 03/15] add GPU CrossMapNormal --- paddle/math/cross_map_normal_op.cpp | 42 ++--- paddle/math/cross_map_normal_op.h | 37 ++++- paddle/math/cross_map_normal_op_gpu.cu | 194 +++++++++++++++++++++++ paddle/math/tests/test_matrixCompare.cpp | 66 +++++--- 4 files changed, 286 insertions(+), 53 deletions(-) create mode 100644 paddle/math/cross_map_normal_op_gpu.cu diff --git a/paddle/math/cross_map_normal_op.cpp b/paddle/math/cross_map_normal_op.cpp index 3eb51b5998fc4..be242926aff16 100644 --- a/paddle/math/cross_map_normal_op.cpp +++ b/paddle/math/cross_map_normal_op.cpp @@ -17,15 +17,16 @@ limitations under the License. 
 */
 
 namespace paddle {
 
 // NCHW
-void CrossMapNormal::operator()(CpuMatrix& outputs,
-                                CpuMatrix& denoms,
-                                CpuMatrix& inputs,
-                                size_t channels,
-                                size_t imgSizeH,
-                                size_t imgSizeW,
-                                size_t sizeX,
-                                real scale,
-                                real pow) {
+template <>
+void CrossMapNormal<DEVICE_TYPE_CPU>::operator()(CpuMatrix& outputs,
+                                                 CpuMatrix& denoms,
+                                                 CpuMatrix& inputs,
+                                                 size_t channels,
+                                                 size_t imgSizeH,
+                                                 size_t imgSizeW,
+                                                 size_t sizeX,
+                                                 real scale,
+                                                 real pow) {
   CHECK(outputs.isContiguous());
   CHECK(inputs.isContiguous());
   CHECK(denoms.isContiguous());
@@ -58,17 +59,18 @@ void CrossMapNormal::operator()(CpuMatrix& outputs,
   outputs = inputs * denoms.pow(-pow);
 }
 
-void CrossMapNormalGrad::operator()(CpuMatrix& inputsGrad,
-                                    CpuMatrix& inputsValue,
-                                    CpuMatrix& outputsGrad,
-                                    CpuMatrix& outputsValue,
-                                    CpuMatrix& denoms,
-                                    size_t channels,
-                                    size_t imgSizeH,
-                                    size_t imgSizeW,
-                                    size_t sizeX,
-                                    real scale,
-                                    real pow) {
+template <>
+void CrossMapNormalGrad<DEVICE_TYPE_CPU>::operator()(CpuMatrix& inputsGrad,
+                                                     CpuMatrix& inputsValue,
+                                                     CpuMatrix& outputsGrad,
+                                                     CpuMatrix& outputsValue,
+                                                     CpuMatrix& denoms,
+                                                     size_t channels,
+                                                     size_t imgSizeH,
+                                                     size_t imgSizeW,
+                                                     size_t sizeX,
+                                                     real scale,
+                                                     real pow) {
   CHECK(inputsGrad.isContiguous());
   CHECK(outputsGrad.isContiguous());
   CHECK(denoms.isContiguous());
diff --git a/paddle/math/cross_map_normal_op.h b/paddle/math/cross_map_normal_op.h
index 2f996072528a0..c2bb95f6b11fb 100644
--- a/paddle/math/cross_map_normal_op.h
+++ b/paddle/math/cross_map_normal_op.h
@@ -18,10 +18,30 @@ limitations under the License. */
 
 namespace paddle {
 
+enum DeviceType {
+  DEVICE_TYPE_UNSPECIFIED = 0,
+  DEVICE_TYPE_CPU = 1,
+  DEVICE_TYPE_GPU = 2,
+};
+
+template <DeviceType Device>
+struct MatrixT;
+
+template <>
+struct MatrixT<DEVICE_TYPE_CPU> {
+  using type = CpuMatrix;
+};
+
+template <>
+struct MatrixT<DEVICE_TYPE_GPU> {
+  using type = GpuMatrix;
+};
+
+template <DeviceType Device>
 struct CrossMapNormal {
-  void operator()(CpuMatrix& outputs,
-                  CpuMatrix& denoms,
-                  CpuMatrix& inputs,
+  void operator()(typename MatrixT<Device>::type& outputs,
+                  typename MatrixT<Device>::type& denoms,
+                  typename MatrixT<Device>::type& inputs,
                   size_t channels,
                   size_t imgSizeH,
                   size_t imgSizeW,
@@ -30,12 +50,13 @@ struct CrossMapNormal {
                   real pow);
 };
 
+template <DeviceType Device>
 struct CrossMapNormalGrad {
-  void operator()(CpuMatrix& inputsGrad,
-                  CpuMatrix& inputsValue,
-                  CpuMatrix& outputsGrad,
-                  CpuMatrix& outputsValue,
-                  CpuMatrix& denoms,
+  void operator()(typename MatrixT<Device>::type& inputsGrad,
+                  typename MatrixT<Device>::type& inputsValue,
+                  typename MatrixT<Device>::type& outputsGrad,
+                  typename MatrixT<Device>::type& outputsValue,
+                  typename MatrixT<Device>::type& denoms,
                   size_t channels,
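For reference, the CPU specializations above and the GPU kernels in the new
cross_map_normal_op_gpu.cu below implement the same NCHW cross-map
normalization; reading it off the code:

    denoms(n, c, h, w)  = 1 + scale * sum over s of inputs(n, c+s, h, w)^2
    outputs(n, c, h, w) = inputs(n, c, h, w) * denoms(n, c, h, w)^(-pow)

where s ranges over the sizeX-wide window around channel c (centered for odd
sizeX, shifted by one toward higher channels for even sizeX).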
diff --git a/paddle/math/cross_map_normal_op_gpu.cu b/paddle/math/cross_map_normal_op_gpu.cu
new file mode 100644
index 0000000000000..0a154d97ac02f
--- /dev/null
+++ b/paddle/math/cross_map_normal_op_gpu.cu
@@ -0,0 +1,194 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "hl_base.h"
+#include "cross_map_normal_op.h"
+
+namespace paddle {
+
+__global__ void KeCMRNormFillScale(size_t imageSize, const real* in,
+                                   real* scale, size_t channels,
+                                   size_t height, size_t width, size_t size,
+                                   real alpha) {
+  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
+  if (idx < imageSize) {
+    const int w = idx % width;
+    const int h = (idx / width) % height;
+    const int n = idx / width / height;
+    const int offset = (n * channels * height + h) * width + w;
+
+    in += offset;
+    scale += offset;
+    const int step = height * width;
+    const int pre_pad = (size - 1) / 2;
+    const int post_pad = size - pre_pad - 1;
+
+    real accum = 0;
+    int index = 0;
+    while (index < channels + post_pad) {
+      if (index < channels) {
+        accum += in[index * step] * in[index * step];
+      }
+      if (index >= size) {
+        accum -= in[(index - size) * step] * in[(index - size) * step];
+      }
+      if (index >= post_pad) {
+        scale[(index - post_pad) * step] = 1. + accum * alpha;
+      }
+      ++index;
+    }
+  }
+}
+
+__global__ void KeCMRNormOutput(size_t inputSize, const real* in,
+                                const real* scale, real negative_beta,
+                                real* out) {
+  const int index = threadIdx.x + blockIdx.x * blockDim.x;
+  if (index < inputSize) {
+    out[index] = in[index] * pow(scale[index], negative_beta);
+  }
+}
+
+template <>
+void CrossMapNormal<DEVICE_TYPE_GPU>::operator()(GpuMatrix& outputs,
+                                                 GpuMatrix& denoms,
+                                                 GpuMatrix& inputs,
+                                                 size_t channels,
+                                                 size_t imgSizeH,
+                                                 size_t imgSizeW,
+                                                 size_t sizeX,
+                                                 real scale,
+                                                 real pow) {
+  CHECK(outputs.isContiguous());
+  CHECK(inputs.isContiguous());
+  CHECK(denoms.isContiguous());
+  CHECK_EQ(outputs.getHeight(), inputs.getHeight());
+  CHECK_EQ(outputs.getWidth(), inputs.getWidth());
+  CHECK_EQ(outputs.getHeight(), denoms.getHeight());
+  CHECK_EQ(outputs.getWidth(), denoms.getWidth());
+
+  size_t numSample = inputs.getHeight();
+  size_t numCols = inputs.getWidth();
+  CHECK(imgSizeH * imgSizeW * channels == numCols);
+
+  real* inputsData = inputs.getData();
+  real* denomsData = denoms.getData();
+  real* outputsData = outputs.getData();
+
+  size_t imageSize = numSample * imgSizeH * imgSizeW;
+  int blockSize = 1024;
+  int gridSize = (imageSize + 1024 - 1) / 1024;
+  KeCMRNormFillScale<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>
+    (imageSize, inputsData, denomsData,
+     channels, imgSizeH, imgSizeW, sizeX, scale);
+
+  size_t inputSize = numSample * imgSizeH * imgSizeW * channels;
+  blockSize = 1024;
+  gridSize = (inputSize + 1024 - 1) / 1024;
+  KeCMRNormOutput<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>
+    (inputSize, inputsData, denomsData, -pow, outputsData);
+
+  CHECK_SYNC("CrossMapNormalFwd");
+}
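+// Launch-configuration note: KeCMRNormFillScale walks one (n, h, w) column
+// across all channels per thread, so it needs only numSamples * height * width
+// threads; e.g. numSamples = 5, height = 33, width = 100 gives
+// imageSize = 16500 and gridSize = (16500 + 1024 - 1) / 1024 = 17 blocks of
+// 1024 threads. KeCMRNormOutput is elementwise, so it launches one thread per
+// value (numSamples * height * width * channels) instead.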
+__global__ void KeCMRNormDiff(size_t imageSize, const real* bottom_data,
+                              const real* top_data, const real* scale,
+                              const real* top_diff, size_t channels,
+                              size_t height, size_t width, size_t size,
+                              real negative_beta, real cache_ratio,
+                              real* bottom_diff ) {
+  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
+  if (idx < imageSize) {
+    const int w = idx % width;
+    const int h = (idx / width) % height;
+    const int n = idx / width / height;
+    const int offset = (n * channels * height + h) * width + w;
+    bottom_data += offset;
+    top_data += offset;
+    scale += offset;
+    top_diff += offset;
+    bottom_diff += offset;
+
+    const int step = height * width;
+    const int pre_pad = size - (size + 1) / 2;
+    const int post_pad = size - pre_pad - 1;
+
+    int index = 0;
+    real accum = 0;
+    while (index < channels + post_pad) {
+      if (index < channels) {
+        accum += top_diff[index * step] * top_data[index * step] /
+            scale[index * step];
+      }
+      if (index >= size) {
+        accum -= top_diff[(index - size) * step] *
+            top_data[(index - size) * step] / scale[(index - size) * step];
+      }
+      if (index >= post_pad) {
+        bottom_diff[(index - post_pad) * step] +=
+            top_diff[(index - post_pad) * step] *
+            pow(scale[(index - post_pad) * step], negative_beta) - cache_ratio *
+            bottom_data[(index - post_pad) * step] * accum;
+      }
+      ++index;
+    }
+  }
+}
+
+template <>
+void CrossMapNormalGrad<DEVICE_TYPE_GPU>::operator()(GpuMatrix& inputsGrad,
+                                                     GpuMatrix& inputsValue,
+                                                     GpuMatrix& outputsGrad,
+                                                     GpuMatrix& outputsValue,
+                                                     GpuMatrix& denoms,
+                                                     size_t channels,
+                                                     size_t imgSizeH,
+                                                     size_t imgSizeW,
+                                                     size_t sizeX,
+                                                     real scale,
+                                                     real pow) {
+  CHECK(inputsGrad.isContiguous());
+  CHECK(outputsGrad.isContiguous());
+  CHECK(denoms.isContiguous());
+  CHECK(inputsValue.isContiguous());
+  CHECK(outputsValue.isContiguous());
+  CHECK_EQ(inputsGrad.getHeight(), outputsGrad.getHeight());
+  CHECK_EQ(inputsGrad.getWidth(), outputsGrad.getWidth());
+  CHECK_EQ(inputsGrad.getHeight(), denoms.getHeight());
+  CHECK_EQ(inputsGrad.getWidth(), denoms.getWidth());
+  CHECK_EQ(inputsGrad.getHeight(), inputsValue.getHeight());
+  CHECK_EQ(inputsGrad.getWidth(), inputsValue.getWidth());
+  CHECK_EQ(inputsGrad.getHeight(), outputsValue.getHeight());
+  CHECK_EQ(inputsGrad.getWidth(), outputsValue.getWidth());
+
+  size_t numSample = inputsGrad.getHeight();
+  size_t numCols = inputsGrad.getWidth();
+  CHECK(imgSizeH * imgSizeW * channels == numCols);
+
+  size_t imageSize = numSample * imgSizeH * imgSizeW;
+  real* inputsGradData = inputsGrad.getData();
+  real* inputsData = inputsValue.getData();
+  real* denomsData = denoms.getData();
+  real* outputsGradData = outputsGrad.getData();
+  real* outputsData = outputsValue.getData();
+
+  int blockSize = 1024;
+  int gridSize = (imageSize + 1024 - 1) / 1024;
+  KeCMRNormDiff <<<gridSize, blockSize, 0, STREAM_DEFAULT>>>
+    (imageSize, inputsData, outputsData, denomsData, outputsGradData, channels,
+     imgSizeH, imgSizeW, sizeX, -pow, 2.0f * pow * scale, inputsGradData);
+  CHECK_SYNC("KeCMRNormDiff");
+}
+
+}  // namespace paddle
diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp
index 9bb1fdbdab83a..8d7a4fb94d0a1 100644
--- a/paddle/math/tests/test_matrixCompare.cpp
+++ b/paddle/math/tests/test_matrixCompare.cpp
@@ -1280,11 +1280,25 @@ void testCrossMapNormalFwd(
   inputsGpu.copyFrom(inputs);
   outputsGpu.copyFrom(outputs);
 
-  CrossMapNormal cross;
-  cross(
+  CrossMapNormal<DEVICE_TYPE_CPU> cpuCross;
+  cpuCross(
     outputs, denoms, inputs, channels, imgSizeH, imgSizeW, sizeX, scale, pow);
+
+  CrossMapNormal<DEVICE_TYPE_GPU> gpuCross;
+  gpuCross(outputsGpu,
+           denomsGpu,
+           inputsGpu,
+           channels,
+           imgSizeH,
+           imgSizeW,
+           sizeX,
+           scale,
+           pow);
+
+#if 0
   outputsGpu.crossMapNormalFwd(
     inputsGpu, imgSizeH, imgSizeW, denomsGpu, channels, sizeX, scale, pow);
+#endif
 
   TensorCheckErr(outputs, outputsGpu);
   TensorCheckErr(denoms, denomsGpu);
@@ -1339,29 +1353,31 @@ void testCrossMapNormalBwd(
   outputsValueGpu.copyFrom(outputsValue);
   inputsGradGpu.copyFrom(inputsGrad);
 
-  CrossMapNormalGrad cross;
-  cross(inputsGrad,
-        inputsValue,
-        outputsGrad,
-        outputsValue,
-        denoms,
-        channels,
-        imgSizeH,
-        imgSizeW,
-        sizeX,
-        scale,
-        pow);
-
-  inputsGradGpu.crossMapNormalBwd(outputsGradGpu,
-                                  denomsGpu,
-                                  inputsValueGpu,
-                                  outputsValueGpu,
-                                  channels,
-                                  imgSizeH,
-                                  imgSizeW,
-                                  sizeX,
-                                  scale,
-                                  pow);
+  CrossMapNormalGrad<DEVICE_TYPE_CPU> cpuCross;
+  cpuCross(inputsGrad,
+           inputsValue,
+           outputsGrad,
+           outputsValue,
+           denoms,
+           channels,
+           imgSizeH,
+           imgSizeW,
+           sizeX,
+           scale,
+           pow);
+
+  CrossMapNormalGrad<DEVICE_TYPE_GPU> gpuCross;
+  gpuCross(inputsGradGpu,
+           inputsValueGpu,
+           outputsGradGpu,
+           outputsValueGpu,
+           denomsGpu,
+           channels,
+           imgSizeH,
+           imgSizeW,
+           sizeX,
+           scale,
+           pow);
 
   TensorCheckErr(inputsGrad, inputsGradGpu);
 }

From a1d2abc16d9c7b42af6dcb41902423ae2904ee9a Mon Sep 17 00:00:00 2001
From: hedaoyuan
Date: Wed, 14 Dec 2016 18:46:40 +0800
Subject: [PATCH 04/15] add Function

---
 paddle/math/Function.cpp                 | 47 +++++++++++
 paddle/math/Function.h                   | 84 ++++++++++++++++++++++++
 paddle/math/cross_map_normal_op.cpp      | 46 +++++++++++
 paddle/math/cross_map_normal_op.h        | 20 +-----
 paddle/math/tests/test_matrixCompare.cpp | 15 +++--
 5 files changed, 188 insertions(+), 24 deletions(-)
 create mode 100644 paddle/math/Function.cpp
 create mode 100644 paddle/math/Function.h
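A note on the API this commit introduces: functions derive from FunctionBase,
register themselves per device under a "name-DEVICE" key, and are looked up by
that key at runtime. A minimal sketch of the intended usage, with hypothetical
FuncName/MyFunc names (only CrossMapNormal is registered in this series):

    // MyFunc must be a class template over DeviceType deriving from FunctionBase.
    REGISTER_TYPED_FUNC(FuncName, CPU, MyFunc);  // registers "FuncName-CPU"

    FunctionBase* f =
        FunctionBase::funcRegistrar_.createByType(FUNC_NAME(FuncName, CPU));
    FuncConfig config;
    config.set("size", (size_t)5);  // set() returns void until patch 06
    f->init(config);

Note the static registrar object is named __reg_type_##typeName only, so
registering a second device for the same typeName in one file would collide;
patch 06 appends ##deviceName to fix exactly this before the GPU variant is
registered.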
diff --git a/paddle/math/Function.cpp b/paddle/math/Function.cpp
new file mode 100644
index 0000000000000..21d2719172870
--- /dev/null
+++ b/paddle/math/Function.cpp
@@ -0,0 +1,47 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "Function.h"
+
+namespace paddle {
+
+template <>
+size_t FuncConfig::get<size_t>(const std::string& key) const {
+  auto it = valueMap_.find(key);
+  CHECK(it != valueMap_.end()) << "Cannot find value: '" << key << "'";
+  return it->second.s;
+}
+
+template <>
+real FuncConfig::get<real>(const std::string& key) const {
+  auto it = valueMap_.find(key);
+  CHECK(it != valueMap_.end()) << "Cannot find value: '" << key << "'";
+  return it->second.r;
+}
+
+template <>
+void FuncConfig::set<size_t>(const std::string& key, size_t v) {
+  CHECK(valueMap_.count(key) == 0) << "Duplicated value: " << key;
+  valueMap_[key].s = v;
+}
+
+template <>
+void FuncConfig::set<real>(const std::string& key, real v) {
+  CHECK(valueMap_.count(key) == 0) << "Duplicated value: " << key;
+  valueMap_[key].r = v;
+}
+
+ClassRegistrar<FunctionBase> FunctionBase::funcRegistrar_;
+
+}  // namespace paddle
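FuncConfig stores every value in a union (see Function.h below), so type
discipline is purely by convention: get<size_t> reads the .s member that
set(key, size_t) wrote, and a mismatched set/get pair silently reads the wrong
union member rather than failing. A sketch of a matched pair:

    FuncConfig config;
    config.set("scale", (real)1.5);          // writes value.r
    real scale = config.get<real>("scale");  // reads value.r back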
diff --git a/paddle/math/Function.h b/paddle/math/Function.h
new file mode 100644
index 0000000000000..b41ba2a13d377
--- /dev/null
+++ b/paddle/math/Function.h
@@ -0,0 +1,84 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <map>
+#include <vector>
+#include "paddle/utils/ClassRegistrar.h"
+#include "paddle/math/Matrix.h"
+
+namespace paddle {
+
+enum DeviceType {
+  DEVICE_TYPE_UNSPECIFIED = 0,
+  DEVICE_TYPE_CPU = 1,
+  DEVICE_TYPE_GPU = 2,
+};
+
+template <DeviceType Device>
+struct MatrixT;
+
+template <>
+struct MatrixT<DEVICE_TYPE_CPU> {
+  using type = CpuMatrix;
+};
+
+template <>
+struct MatrixT<DEVICE_TYPE_GPU> {
+  using type = GpuMatrix;
+};
+
+typedef std::vector Arguments;
+
+class FuncConfig {
+public:
+  union value {
+    size_t s;
+    real r;
+  };
+
+  template <typename T>
+  T get(const std::string& key) const;
+
+  template <typename T>
+  void set(const std::string& key, T v);
+
+protected:
+  std::map<std::string, value> valueMap_;
+};
+
+class FunctionBase {
+public:
+  virtual ~FunctionBase() {}
+
+  virtual void init(const FuncConfig& config) {}
+
+  virtual void calc(const Arguments& inputs,
+                    const Arguments& outputs,
+                    const Arguments& inouts) {}
+
+  static ClassRegistrar<FunctionBase> funcRegistrar_;
+};
+
+#define FUNC_NAME(typeName, deviceName) #typeName "-" #deviceName
+
+#define REGISTER_TYPED_FUNC(typeName, deviceName, className)  \
+  static InitFunction __reg_type_##typeName([]() {            \
+    FunctionBase::funcRegistrar_                              \
+        .registerClass<className<DEVICE_TYPE_##deviceName>>(  \
+            FUNC_NAME(typeName, deviceName));                 \
+  })
+
+}  // namespace paddle
diff --git a/paddle/math/cross_map_normal_op.cpp b/paddle/math/cross_map_normal_op.cpp
index be242926aff16..0b7273206381d 100644
--- a/paddle/math/cross_map_normal_op.cpp
+++ b/paddle/math/cross_map_normal_op.cpp
@@ -128,4 +128,50 @@ void CrossMapNormalGrad<DEVICE_TYPE_CPU>::operator()(CpuMatrix& inputsGrad,
   }
 }
 
+template <DeviceType Device>
+class CrossMapNormalFunc : public FunctionBase {
+public:
+  void init(const FuncConfig& config) override {
+    size_ = config.get<size_t>("size");
+    scale_ = config.get<real>("scale");
+    pow_ = config.get<real>("pow");
+  }
+
+  void calc(const Arguments& inputs,
+            const Arguments& outputs,
+            const Arguments& inouts) override {
+    CHECK_EQ(1, inputs.size());
+    CHECK_EQ(2, outputs.size());
+    CHECK_EQ(0, inouts.size());
+
+    auto input = dynamic_cast<const typename MatrixT<Device>::type&>(inputs[0]);
+    auto output =
+        dynamic_cast<typename MatrixT<Device>::type&>(outputs[0]);
+    auto denom =
+        dynamic_cast<typename MatrixT<Device>::type&>(outputs[1]);
+
+    CHECK(input.isContiguous());
+    CHECK(output.isContiguous());
+    CHECK(denom.isContiguous());
+    CHECK_EQ(output.getHeight(), input.getHeight());
+    CHECK_EQ(output.getWidth(), input.getWidth());
+    CHECK_EQ(output.getHeight(), denom.getHeight());
+    CHECK_EQ(output.getWidth(), denom.getWidth());
+
+    // CrossMapNormal<Device> cross;
+    // need:
+    // size_t channels,
+    // size_t imgSizeH,
+    // size_t imgSizeW,
+    // cross(output, denom, input, );
+  }
+
+private:
+  size_t size_;
+  real scale_;
+  real pow_;
+};
+
+REGISTER_TYPED_FUNC(CrossMapNormal, CPU, CrossMapNormalFunc);
+
 }  // namespace paddle
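At this point CrossMapNormalFunc::calc only validates its arguments: as the
commented-out cross call notes, the channel count and image dimensions cannot
be recovered from a bare matrix argument, which is precisely the gap the
shape-carrying Tensor argument added in the next commit closes.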
diff --git a/paddle/math/cross_map_normal_op.h b/paddle/math/cross_map_normal_op.h
index c2bb95f6b11fb..86f54abde108d 100644
--- a/paddle/math/cross_map_normal_op.h
+++ b/paddle/math/cross_map_normal_op.h
@@ -14,29 +14,11 @@ limitations under the License. */
 
 #pragma once
 
+#include "Function.h"
 #include "paddle/math/Matrix.h"
 
 namespace paddle {
 
-enum DeviceType {
-  DEVICE_TYPE_UNSPECIFIED = 0,
-  DEVICE_TYPE_CPU = 1,
-  DEVICE_TYPE_GPU = 2,
-};
-
-template <DeviceType Device>
-struct MatrixT;
-
-template <>
-struct MatrixT<DEVICE_TYPE_CPU> {
-  using type = CpuMatrix;
-};
-
-template <>
-struct MatrixT<DEVICE_TYPE_GPU> {
-  using type = GpuMatrix;
-};
-
+template <DeviceType Device>
 struct CrossMapNormal {
diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp
index 8d7a4fb94d0a1..0b75785528f5d 100644
--- a/paddle/math/tests/test_matrixCompare.cpp
+++ b/paddle/math/tests/test_matrixCompare.cpp
@@ -24,6 +24,7 @@ limitations under the License. */
 #include "paddle/utils/Stat.h"
 #include "TensorCheck.h"
 #include "paddle/math/cross_map_normal_op.h"
+#include "paddle/math/Function.h"
 
 using namespace paddle;  // NOLINT
 using namespace std;     // NOLINT
@@ -1280,6 +1281,15 @@ void testCrossMapNormalFwd(
   inputsGpu.copyFrom(inputs);
   outputsGpu.copyFrom(outputs);
 
+  FuncConfig config;
+  config.set("size", (size_t)sizeX);
+  config.set("scale", scale);
+  config.set("pow", pow);
+  FunctionBase* cpu =
+      FunctionBase::funcRegistrar_.createByType(FUNC_NAME(CrossMapNormal, CPU));
+  cpu->init(config);
+  // cpu->calc();
+
   CrossMapNormal<DEVICE_TYPE_CPU> cpuCross;
   cpuCross(
     outputs, denoms, inputs, channels, imgSizeH, imgSizeW, sizeX, scale, pow);

From ce1d98e083017afadac9fcd9f94f5c59aceaf6c0 Mon Sep 17 00:00:00 2001
From: hedaoyuan
Date: Thu, 15 Dec 2016 10:31:45 +0800
Subject: [PATCH 05/15] Add a Tensor to use as a Function argument

---
 paddle/math/Function.h                   | 12 +++++++-
 paddle/math/cross_map_normal_op.cpp      | 37 +++++++++++-------------
 paddle/math/tests/test_matrixCompare.cpp |  9 ++++--
 3 files changed, 35 insertions(+), 23 deletions(-)

diff --git a/paddle/math/Function.h b/paddle/math/Function.h
index b41ba2a13d377..539759782be3a 100644
--- a/paddle/math/Function.h
+++ b/paddle/math/Function.h
@@ -40,7 +40,17 @@ struct MatrixT<DEVICE_TYPE_GPU> {
   using type = GpuMatrix;
 };
 
-typedef std::vector Arguments;
+typedef std::vector<size_t> Dims;
+
+class Tensor {
+public:
+  Tensor(real* data, const Dims& dim) : buf_(data), dims_(dim) {}
+
+  real* buf_;
+  Dims dims_;
+};
+
+typedef std::vector<Tensor> Arguments;
 
diff --git a/paddle/math/cross_map_normal_op.cpp b/paddle/math/cross_map_normal_op.cpp
index 0b7273206381d..d55bd78c628f7 100644
--- a/paddle/math/cross_map_normal_op.cpp
+++ b/paddle/math/cross_map_normal_op.cpp
@@ -144,26 +144,23 @@ class CrossMapNormalFunc : public FunctionBase {
     CHECK_EQ(2, outputs.size());
     CHECK_EQ(0, inouts.size());
 
-    auto input = dynamic_cast<const typename MatrixT<Device>::type&>(inputs[0]);
-    auto output =
-        dynamic_cast<typename MatrixT<Device>::type&>(outputs[0]);
-    auto denom =
-        dynamic_cast<typename MatrixT<Device>::type&>(outputs[1]);
-
-    CHECK(input.isContiguous());
-    CHECK(output.isContiguous());
-    CHECK(denom.isContiguous());
-    CHECK_EQ(output.getHeight(), input.getHeight());
-    CHECK_EQ(output.getWidth(), input.getWidth());
-    CHECK_EQ(output.getHeight(), denom.getHeight());
-    CHECK_EQ(output.getWidth(), denom.getWidth());
-
-    // CrossMapNormal<Device> cross;
-    // need:
-    // size_t channels,
-    // size_t imgSizeH,
-    // size_t imgSizeW,
-    // cross(output, denom, input, );
+    CHECK_EQ(inputs[0].dims_.size(), 4);
+    for (size_t i = 0; i < inputs[0].dims_.size(); i++) {
+      CHECK_EQ(inputs[0].dims_[i], outputs[0].dims_[i]);
+      CHECK_EQ(inputs[0].dims_[i], outputs[1].dims_[i]);
+    }
+
+    size_t samples = inputs[0].dims_[0];
+    size_t channels = inputs[0].dims_[1];
+    size_t height = inputs[0].dims_[2];
+    size_t width = inputs[0].dims_[3];
+    size_t imageSize = channels * height * width;
+    CpuMatrix input(inputs[0].buf_, samples, imageSize);
+    CpuMatrix output(outputs[0].buf_, samples, imageSize);
+    CpuMatrix denom(outputs[1].buf_, samples, imageSize);
+
+    CrossMapNormal<DEVICE_TYPE_CPU> cross;
+    cross(output, denom, input, channels, height, width, size_, scale_, pow_);
   }
 
 private:
diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp
index 0b75785528f5d..cd34ea18a70ea 100644
--- a/paddle/math/tests/test_matrixCompare.cpp
+++ b/paddle/math/tests/test_matrixCompare.cpp
@@ -1288,12 +1288,17 @@ void testCrossMapNormalFwd(
   FunctionBase* cpu =
       FunctionBase::funcRegistrar_.createByType(FUNC_NAME(CrossMapNormal, CPU));
   cpu->init(config);
-  // cpu->calc();
+  Dims dims{
+      (size_t)numSamples, (size_t)channels, (size_t)imgSizeH, (size_t)imgSizeW};
+  cpu->calc({Tensor(inputs.getData(), dims)},
+            {Tensor(outputs.getData(), dims), Tensor(denoms.getData(), dims)},
+            {});
 
+#if 0
   CrossMapNormal<DEVICE_TYPE_CPU> cpuCross;
   cpuCross(
     outputs, denoms, inputs, channels, imgSizeH, imgSizeW, sizeX, scale, pow);
-
+#endif
   CrossMapNormal<DEVICE_TYPE_GPU> gpuCross;
   gpuCross(outputsGpu,
            denomsGpu,
            inputsGpu,
            channels,
            imgSizeH,
            imgSizeW,
            sizeX,
            scale,
            pow);

From 4ebb3eb759903bf95968b578eec99b1364d3bd10 Mon Sep 17 00:00:00 2001
From: hedaoyuan
Date: Thu, 15 Dec 2016 11:55:35 +0800
Subject: [PATCH 06/15] improve Function

---
 paddle/gserver/layers/NormProjectionLayer.cpp | 60 +++++++++++----
 paddle/gserver/layers/NormProjectionLayer.h   |  4 +
 paddle/math/Function.cpp                      |  6 +-
 paddle/math/Function.h                        | 14 ++--
 paddle/math/cross_map_normal_op.cpp           | 75 ++++++++++---------
 paddle/math/cross_map_normal_op.h             | 13 ++++
 paddle/math/cross_map_normal_op_gpu.cu        | 46 ++++--------
 paddle/math/tests/test_matrixCompare.cpp      | 21 +++++-
 8 files changed, 147 insertions(+), 92 deletions(-)
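One thing worth noting before the layer changes below: Tensor is a non-owning
view, just a raw buffer pointer plus NCHW dims, so a layer can wrap its
existing activation matrices without copying. The pattern, with hypothetical
names (matrix, batchSize, ...) standing in for whatever the caller owns:

    Dims dims{batchSize, channels, height, width};  // NCHW
    Tensor view(matrix->getData(), dims);           // borrows matrix's buffer

forward() below builds exactly these views around input, output, and denoms_.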
diff --git a/paddle/gserver/layers/NormProjectionLayer.cpp b/paddle/gserver/layers/NormProjectionLayer.cpp
index ea301292e0dcc..5dda7ee205f40 100644
--- a/paddle/gserver/layers/NormProjectionLayer.cpp
+++ b/paddle/gserver/layers/NormProjectionLayer.cpp
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #include "paddle/utils/Logging.h"
 #include "paddle/utils/Stat.h"
+#include "paddle/math/cross_map_normal_op.h"
 #include "NormProjectionLayer.h"
 
 namespace paddle {
@@ -45,6 +46,16 @@ bool CMRProjectionNormLayer::init(const LayerMap& layerMap,
   /* the size of inputs for norm-layer is 1 */
   CHECK_EQ(config_.inputs_size(), 1);
 
+  if (useGpu_) {
+    normal_ = FunctionBase::funcRegistrar_.createByType(
+        FUNC_NAME(CrossMapNormal, GPU));
+  } else {
+    normal_ = FunctionBase::funcRegistrar_.createByType(
+        FUNC_NAME(CrossMapNormal, CPU));
+  }
+  normal_->init(
+      FuncConfig().set("size", size_).set("scale", scale_).set("pow", pow_));
+
   return true;
 }
 
@@ -62,10 +73,14 @@ void CMRProjectionNormLayer::forward(PassType passType) {
 
   Matrix::resizeOrCreate(denoms_, batchSize, size, /* trans */ false, useGpu_);
 
-  denoms_->zeroMem();
-
-  outV->crossMapNormalFwd(
-      *input, imgSizeH_, imgSizeW_, *denoms_, channels_, size_, scale_, pow_);
+  Dims dims{(size_t)batchSize,
+            (size_t)channels_,
+            (size_t)imgSizeH_,
+            (size_t)imgSizeW_};
+  normal_->calc(
+      {Tensor(input->getData(), dims)},
+      {Tensor(outV->getData(), dims), Tensor(denoms_->getData(), dims)},
+      {});
 }
 
 void CMRProjectionNormLayer::backward(const UpdateCallback& callback) {
@@ -80,15 +95,32 @@ void CMRProjectionNormLayer::backward(const UpdateCallback& callback) {
   MatrixPtr localOutV = getOutputValue();
   MatrixPtr preOutV = inputLayers_[0]->getOutputValue();
 
-  preOutGrad->crossMapNormalBwd(*localGrad,
-                                *denoms_,
-                                *preOutV,
-                                *localOutV,
-                                channels_,
-                                imgSizeH_,
-                                imgSizeW_,
-                                size_,
-                                scale_,
-                                pow_);
+  if (useGpu_) {
+    CrossMapNormalGrad<DEVICE_TYPE_GPU> crossGrad;
+    crossGrad(dynamic_cast<GpuMatrix&>(*preOutGrad),
+              dynamic_cast<GpuMatrix&>(*preOutV),
+              dynamic_cast<GpuMatrix&>(*localGrad),
+              dynamic_cast<GpuMatrix&>(*localOutV),
+              dynamic_cast<GpuMatrix&>(*denoms_),
+              channels_,
+              imgSizeH_,
+              imgSizeW_,
+              size_,
+              scale_,
+              pow_);
+  } else {
+    CrossMapNormalGrad<DEVICE_TYPE_CPU> crossGrad;
+    crossGrad(dynamic_cast<CpuMatrix&>(*preOutGrad),
+              dynamic_cast<CpuMatrix&>(*preOutV),
+              dynamic_cast<CpuMatrix&>(*localGrad),
+              dynamic_cast<CpuMatrix&>(*localOutV),
+              dynamic_cast<CpuMatrix&>(*denoms_),
+              channels_,
+              imgSizeH_,
+              imgSizeW_,
+              size_,
+              scale_,
+              pow_);
+  }
 }
 }  // namespace paddle
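Note the asymmetry this commit leaves in place: forward() now dispatches
through the Function resolved once in init() (CPU or GPU by useGpu_), while
backward() still instantiates the CrossMapNormalGrad functors directly, with
explicit per-device dynamic_casts; only the forward path has been ported to
the Function API so far.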
diff --git a/paddle/gserver/layers/NormProjectionLayer.h b/paddle/gserver/layers/NormProjectionLayer.h
index 0db8e2551f06d..ea44669be3f8b 100644
--- a/paddle/gserver/layers/NormProjectionLayer.h
+++ b/paddle/gserver/layers/NormProjectionLayer.h
@@ -16,6 +16,7 @@ limitations under the License. */
 
 #include "NormLayer.h"
 #include "paddle/math/Matrix.h"
+#include "paddle/math/Function.h"
 #include <vector>
 
 namespace paddle {
@@ -39,5 +40,8 @@ class CMRProjectionNormLayer : public ResponseNormLayer {
   bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
   void forward(PassType passType);
   void backward(const UpdateCallback& callback = nullptr);
+
+protected:
+  FunctionBase* normal_;
 };
 }  // namespace paddle
diff --git a/paddle/math/Function.cpp b/paddle/math/Function.cpp
index 21d2719172870..02880e5ea1acb 100644
--- a/paddle/math/Function.cpp
+++ b/paddle/math/Function.cpp
@@ -31,15 +31,17 @@ real FuncConfig::get<real>(const std::string& key) const {
 }
 
 template <>
-void FuncConfig::set<size_t>(const std::string& key, size_t v) {
+FuncConfig& FuncConfig::set<size_t>(const std::string& key, size_t v) {
   CHECK(valueMap_.count(key) == 0) << "Duplicated value: " << key;
   valueMap_[key].s = v;
+  return *this;
 }
 
 template <>
-void FuncConfig::set<real>(const std::string& key, real v) {
+FuncConfig& FuncConfig::set<real>(const std::string& key, real v) {
   CHECK(valueMap_.count(key) == 0) << "Duplicated value: " << key;
   valueMap_[key].r = v;
+  return *this;
 }
 
 ClassRegistrar<FunctionBase> FunctionBase::funcRegistrar_;
diff --git a/paddle/math/Function.h b/paddle/math/Function.h
index 539759782be3a..f8fab972a6902 100644
--- a/paddle/math/Function.h
+++ b/paddle/math/Function.h
@@ -46,6 +46,8 @@ class Tensor {
 public:
   Tensor(real* data, const Dims& dim) : buf_(data), dims_(dim) {}
 
+  real* getData() const { return buf_; }
+
   real* buf_;
   Dims dims_;
 };
@@ -63,7 +65,7 @@ class FuncConfig {
   T get(const std::string& key) const;
 
   template <typename T>
-  void set(const std::string& key, T v);
+  FuncConfig& set(const std::string& key, T v);
 
 protected:
   std::map<std::string, value> valueMap_;
@@ -84,11 +86,11 @@ class FunctionBase {
 
 #define FUNC_NAME(typeName, deviceName) #typeName "-" #deviceName
 
-#define REGISTER_TYPED_FUNC(typeName, deviceName, className)  \
-  static InitFunction __reg_type_##typeName([]() {            \
-    FunctionBase::funcRegistrar_                              \
-        .registerClass<className<DEVICE_TYPE_##deviceName>>(  \
-            FUNC_NAME(typeName, deviceName));                 \
+#define REGISTER_TYPED_FUNC(typeName, deviceName, className)    \
+  static InitFunction __reg_type_##typeName##deviceName([]() {  \
+    FunctionBase::funcRegistrar_                                \
+        .registerClass<className<DEVICE_TYPE_##deviceName>>(    \
+            FUNC_NAME(typeName, deviceName));                   \
   })
diff --git a/paddle/math/cross_map_normal_op.cpp b/paddle/math/cross_map_normal_op.cpp
index d55bd78c628f7..e520351d2e3b8 100644
--- a/paddle/math/cross_map_normal_op.cpp
+++ b/paddle/math/cross_map_normal_op.cpp
@@ -18,45 +18,41 @@ namespace paddle {
 
 // NCHW
 template <>
-void CrossMapNormal<DEVICE_TYPE_CPU>::operator()(CpuMatrix& outputs,
-                                                 CpuMatrix& denoms,
-                                                 CpuMatrix& inputs,
-                                                 size_t channels,
-                                                 size_t imgSizeH,
-                                                 size_t imgSizeW,
-                                                 size_t sizeX,
-                                                 real scale,
-                                                 real pow) {
-  CHECK(outputs.isContiguous());
-  CHECK(inputs.isContiguous());
-  CHECK(denoms.isContiguous());
-  CHECK_EQ(outputs.getHeight(), inputs.getHeight());
-  CHECK_EQ(outputs.getWidth(), inputs.getWidth());
-  CHECK_EQ(outputs.getHeight(), denoms.getHeight());
-  CHECK_EQ(outputs.getWidth(), denoms.getWidth());
-
-  size_t numSample = inputs.getHeight();
-  size_t numCols = inputs.getWidth();
-  size_t imageSize = imgSizeH * imgSizeW;
-  CHECK(imageSize * channels == numCols);
-
-  denoms = denoms.constant(1.0);
-  const int start = -((int)sizeX - 1) / 2;
-  const int end = (int)sizeX + start;
-  for (size_t i = 0; i < numSample; i++) {
-    real* denomsData = denoms.getData() + i * numCols;
-    real* inputData = inputs.getData() + i * numCols;
+void CrossMapNormal<DEVICE_TYPE_CPU>(real* outputs,
+                                     real* denoms,
+                                     real* 
diff --git a/paddle/math/cross_map_normal_op.cpp b/paddle/math/cross_map_normal_op.cpp
index d55bd78c628f7..e520351d2e3b8 100644
--- a/paddle/math/cross_map_normal_op.cpp
+++ b/paddle/math/cross_map_normal_op.cpp
@@ -18,45 +18,41 @@ namespace paddle {
 
 // NCHW
 template <>
-void CrossMapNormal<DEVICE_TYPE_CPU>::operator()(CpuMatrix& outputs,
-                                                 CpuMatrix& denoms,
-                                                 CpuMatrix& inputs,
-                                                 size_t channels,
-                                                 size_t imgSizeH,
-                                                 size_t imgSizeW,
-                                                 size_t sizeX,
-                                                 real scale,
-                                                 real pow) {
-  CHECK(outputs.isContiguous());
-  CHECK(inputs.isContiguous());
-  CHECK(denoms.isContiguous());
-  CHECK_EQ(outputs.getHeight(), inputs.getHeight());
-  CHECK_EQ(outputs.getWidth(), inputs.getWidth());
-  CHECK_EQ(outputs.getHeight(), denoms.getHeight());
-  CHECK_EQ(outputs.getWidth(), denoms.getWidth());
-
-  size_t numSample = inputs.getHeight();
-  size_t numCols = inputs.getWidth();
-  size_t imageSize = imgSizeH * imgSizeW;
-  CHECK(imageSize * channels == numCols);
-
-  denoms = denoms.constant(1.0);
-  const int start = -((int)sizeX - 1) / 2;
-  const int end = (int)sizeX + start;
-  for (size_t i = 0; i < numSample; i++) {
-    real* denomsData = denoms.getData() + i * numCols;
-    real* inputData = inputs.getData() + i * numCols;
+void CrossMapNormal<DEVICE_TYPE_CPU>(real* outputs,
+                                     real* denoms,
+                                     real* inputs,
+                                     size_t numSamples,
+                                     size_t channels,
+                                     size_t height,
+                                     size_t width,
+                                     size_t size,
+                                     real scale,
+                                     real pow) {
+  size_t oneImage = height * width;
+  size_t oneSample = channels * oneImage;
+
+  CpuVector outputsV(numSamples * oneSample, outputs);
+  CpuVector inputsV(numSamples * oneSample, inputs);
+  CpuVector denomsV(numSamples * oneSample, denoms);
+
+  denomsV = denomsV.constant(1.0);
+  const int start = -((int)size - 1) / 2;
+  const int end = (int)size + start;
+  for (size_t i = 0; i < numSamples; i++) {
+    real* oneDenom = denoms + i * oneSample;
+    real* oneInput = inputs + i * oneSample;
     for (int c = 0; c < (int)channels; c++) {
-      CpuVector denom(imageSize, denomsData + c * imageSize);
+      CpuVector denom(oneImage, oneDenom + c * oneImage);
       for (int s = start; s < end; s++) {
         if (c + s >= 0 && c + s < (int)channels) {
-          CpuVector input(imageSize, inputData + (c + s) * imageSize);
+          CpuVector input(oneImage, oneInput + (c + s) * oneImage);
           denom += input.square() * scale;
         }
       }
     }
   }
-  outputs = inputs * denoms.pow(-pow);
+
+  outputsV = inputsV * denomsV.pow(-pow);
 }
 
 template <>
@@ -154,13 +150,17 @@ class CrossMapNormalFunc : public FunctionBase {
     size_t channels = inputs[0].dims_[1];
     size_t height = inputs[0].dims_[2];
     size_t width = inputs[0].dims_[3];
-    size_t imageSize = channels * height * width;
 
-    CpuMatrix input(inputs[0].buf_, samples, imageSize);
-    CpuMatrix output(outputs[0].buf_, samples, imageSize);
-    CpuMatrix denom(outputs[1].buf_, samples, imageSize);
-
-    CrossMapNormal<DEVICE_TYPE_CPU> cross;
-    cross(output, denom, input, channels, height, width, size_, scale_, pow_);
+    CrossMapNormal<Device>(outputs[0].getData(),
+                           outputs[1].getData(),
+                           inputs[0].getData(),
+                           samples,
+                           channels,
+                           height,
+                           width,
+                           size_,
+                           scale_,
+                           pow_);
   }
 
 private:
@@ -170,5 +170,6 @@ class CrossMapNormalFunc : public FunctionBase {
 };
 
 REGISTER_TYPED_FUNC(CrossMapNormal, CPU, CrossMapNormalFunc);
+REGISTER_TYPED_FUNC(CrossMapNormal, GPU, CrossMapNormalFunc);
 
 }  // namespace paddle
diff --git a/paddle/math/cross_map_normal_op.h b/paddle/math/cross_map_normal_op.h
index 86f54abde108d..ef9533485ec9c 100644
--- a/paddle/math/cross_map_normal_op.h
+++ b/paddle/math/cross_map_normal_op.h
@@ -19,6 +19,18 @@ limitations under the License.
 */
 
 namespace paddle {
 
+template <DeviceType Device>
+void CrossMapNormal(real* outputs,
+                    real* denoms,
+                    real* inputs,
+                    size_t numSamples,
+                    size_t channels,
+                    size_t height,
+                    size_t width,
+                    size_t size,
+                    real scale,
+                    real pow);
+#if 0
 template <DeviceType Device>
 struct CrossMapNormal {
   void operator()(typename MatrixT<Device>::type& outputs,
@@ -31,6 +43,7 @@ struct CrossMapNormal {
                   real scale,
                   real pow);
 };
+#endif
 
 template <DeviceType Device>
 struct CrossMapNormalGrad {
diff --git a/paddle/math/cross_map_normal_op_gpu.cu b/paddle/math/cross_map_normal_op_gpu.cu
index 0a154d97ac02f..9b92974344955 100644
--- a/paddle/math/cross_map_normal_op_gpu.cu
+++ b/paddle/math/cross_map_normal_op_gpu.cu
@@ -61,45 +61,29 @@ __global__ void KeCMRNormOutput(size_t inputSize, const real* in,
 }
 
 template <>
-void CrossMapNormal<DEVICE_TYPE_GPU>::operator()(GpuMatrix& outputs,
-                                                 GpuMatrix& denoms,
-                                                 GpuMatrix& inputs,
-                                                 size_t channels,
-                                                 size_t imgSizeH,
-                                                 size_t imgSizeW,
-                                                 size_t sizeX,
-                                                 real scale,
-                                                 real pow) {
-  CHECK(outputs.isContiguous());
-  CHECK(inputs.isContiguous());
-  CHECK(denoms.isContiguous());
-  CHECK_EQ(outputs.getHeight(), inputs.getHeight());
-  CHECK_EQ(outputs.getWidth(), inputs.getWidth());
-  CHECK_EQ(outputs.getHeight(), denoms.getHeight());
-  CHECK_EQ(outputs.getWidth(), denoms.getWidth());
-
-  size_t numSample = inputs.getHeight();
-  size_t numCols = inputs.getWidth();
-  CHECK(imgSizeH * imgSizeW * channels == numCols);
-
-  real* inputsData = inputs.getData();
-  real* denomsData = denoms.getData();
-  real* outputsData = outputs.getData();
-
-  size_t imageSize = numSample * imgSizeH * imgSizeW;
+void CrossMapNormal<DEVICE_TYPE_GPU>(real* outputs,
+                                     real* denoms,
+                                     real* inputs,
+                                     size_t numSamples,
+                                     size_t channels,
+                                     size_t height,
+                                     size_t width,
+                                     size_t size,
+                                     real scale,
+                                     real pow) {
+  size_t imageSize = numSamples * height * width;
   int blockSize = 1024;
   int gridSize = (imageSize + 1024 - 1) / 1024;
   KeCMRNormFillScale<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>
-      (imageSize, inputsData, denomsData,
-       channels, imgSizeH, imgSizeW, sizeX, scale);
+      (imageSize, inputs, denoms, channels, height, width, size, scale);
 
-  size_t inputSize = numSample * imgSizeH * imgSizeW *channels;
+  size_t inputSize = numSamples * height * width *channels;
   blockSize = 1024;
   gridSize = (inputSize + 1024 - 1) / 1024;
   KeCMRNormOutput<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>
-      (inputSize, inputsData, denomsData, -pow, outputsData);
+      (inputSize, inputs, denoms, -pow, outputs);
 
-  CHECK_SYNC("CrossMapNormalFwd");
+  CHECK_SYNC("CrossMapNormal");
 }
 
 __global__ void KeCMRNormDiff(size_t imageSize, const real* bottom_data,
diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp
index cd34ea18a70ea..aac3f757996eb 100644
--- a/paddle/math/tests/test_matrixCompare.cpp
+++ b/paddle/math/tests/test_matrixCompare.cpp
@@ -1281,24 +1281,40 @@ void testCrossMapNormalFwd(
   inputsGpu.copyFrom(inputs);
   outputsGpu.copyFrom(outputs);
 
+#if 0
   FuncConfig config;
   config.set("size", (size_t)sizeX);
   config.set("scale", scale);
   config.set("pow", pow);
+#endif
   FunctionBase* cpu =
       FunctionBase::funcRegistrar_.createByType(FUNC_NAME(CrossMapNormal, CPU));
-  cpu->init(config);
+  FunctionBase* gpu =
+      FunctionBase::funcRegistrar_.createByType(FUNC_NAME(CrossMapNormal, GPU));
+  cpu->init(FuncConfig()
+                .set("size", (size_t)sizeX)
+                .set("scale", scale)
+                .set("pow", pow));
+  gpu->init(FuncConfig()
+                .set("size", (size_t)sizeX)
+                .set("scale", scale)
+                .set("pow", pow));
 
   Dims dims{
       (size_t)numSamples, (size_t)channels, (size_t)imgSizeH, (size_t)imgSizeW};
   cpu->calc({Tensor(inputs.getData(), dims)},
             {Tensor(outputs.getData(), dims), Tensor(denoms.getData(), dims)},
             {});
+
+  gpu->calc(
+      {Tensor(inputsGpu.getData(), dims)},
+      {Tensor(outputsGpu.getData(), dims), Tensor(denomsGpu.getData(), dims)},
+      {});
 #if 0
   CrossMapNormal<DEVICE_TYPE_CPU> cpuCross;
   cpuCross(
       outputs, denoms, inputs, channels, imgSizeH, imgSizeW, sizeX, scale, pow);
-#endif
+
   CrossMapNormal<DEVICE_TYPE_GPU> gpuCross;
   gpuCross(outputsGpu,
            denomsGpu,
           inputsGpu,
           channels,
           imgSizeH,
           imgSizeW,
           sizeX,
           scale,
           pow);
+#endif
 
   TensorCheckErr(outputs, outputsGpu);
   TensorCheckErr(denoms, denomsGpu);

From d2d0010609b6ba621360973b6c6972b836607de3 Mon Sep 17 00:00:00 2001
From: hedaoyuan
Date: Thu, 15 Dec 2016 16:19:10 +0800
Subject: [PATCH 07/15] add CrossMapNormalGradFunc

---
 paddle/gserver/layers/NormProjectionLayer.cpp |  41 +++--
 paddle/gserver/layers/NormProjectionLayer.h   |   7 +-
 paddle/math/Function.h                        |   2 +-
 paddle/math/cross_map_normal_op.cpp           | 145 ++++++++++++------
 paddle/math/cross_map_normal_op.h             |  40 ++---
 paddle/math/cross_map_normal_op_gpu.cu        |  54 ++-----
 paddle/math/tests/test_matrixCompare.cpp      |  57 ++++---
 7 files changed, 190 insertions(+), 156 deletions(-)
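[Note: this patch moves the backward pass onto the same registered-function
convention as the forward pass. A condensed sketch of that convention
(identifiers follow the patch; the buffer variable names are illustrative):

    FunctionBase* func =
        FunctionBase::funcRegistrar_.createByType(FUNC_NAME(CrossMapNormal, CPU));
    func->init(
        FuncConfig().set("size", (size_t)5).set("scale", scale).set("pow", pow));

    Dims dims{batchSize, channels, height, width};
    func->calc({Tensor(inputData, dims)},                             // inputs
               {Tensor(outputData, dims), Tensor(denomsData, dims)},  // outputs
               {});                                                   // inouts

The layout is NCHW and the buffers are assumed contiguous.]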
diff --git a/paddle/gserver/layers/NormProjectionLayer.cpp b/paddle/gserver/layers/NormProjectionLayer.cpp
index 03c6952c30b0c..d6923c2192cf1 100644
--- a/paddle/gserver/layers/NormProjectionLayer.cpp
+++ b/paddle/gserver/layers/NormProjectionLayer.cpp
@@ -13,10 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "NormProjectionLayer.h"
+#include "paddle/math/cross_map_normal_op.h"
 #include "paddle/utils/Logging.h"
 #include "paddle/utils/Stat.h"
-#include "paddle/math/cross_map_normal_op.h"
-#include "NormProjectionLayer.h"
 
 namespace paddle {
 size_t CMRProjectionNormLayer::getSize() {
@@ -48,13 +47,23 @@ bool CMRProjectionNormLayer::init(const LayerMap& layerMap,
   CHECK_EQ(config_.inputs_size(), 1);
 
   if (useGpu_) {
-    normal_ = FunctionBase::funcRegistrar_.createByType(
+    forward_ = FunctionBase::funcRegistrar_.createByType(
         FUNC_NAME(CrossMapNormal, GPU));
   } else {
-    normal_ = FunctionBase::funcRegistrar_.createByType(
+    forward_ = FunctionBase::funcRegistrar_.createByType(
         FUNC_NAME(CrossMapNormal, CPU));
   }
-  normal_->init(
+  forward_->init(
+      FuncConfig().set("size", size_).set("scale", scale_).set("pow", pow_));
+
+  if (useGpu_) {
+    backward_ = FunctionBase::funcRegistrar_.createByType(
+        FUNC_NAME(CrossMapNormalGrad, GPU));
+  } else {
+    backward_ = FunctionBase::funcRegistrar_.createByType(
+        FUNC_NAME(CrossMapNormalGrad, CPU));
+  }
+  backward_->init(
       FuncConfig().set("size", size_).set("scale", scale_).set("pow", pow_));
 
   return true;
@@ -74,13 +83,13 @@ void CMRProjectionNormLayer::forward(PassType passType) {
 
   Matrix::resizeOrCreate(denoms_, batchSize, size, /* trans */ false, useGpu_);
 
-  Dims dims{(size_t)batchSize,
-            (size_t)channels_,
-            (size_t)imgSizeH_,
-            (size_t)imgSizeW_};
-  normal_->calc(
-      {Tensor(input->getData(), dims)},
-      {Tensor(outV->getData(), dims), Tensor(denoms_->getData(), dims)},
+  dims_ = {(size_t)batchSize,
+           (size_t)channels_,
+           (size_t)imgSizeH_,
+           (size_t)imgSizeW_};
+  forward_->calc(
+      {Tensor(input->getData(), dims_)},
+      {Tensor(outV->getData(), dims_), Tensor(denoms_->getData(), dims_)},
       {});
 }
 
@@ -96,6 +105,13 @@ void CMRProjectionNormLayer::backward(const UpdateCallback& callback) {
   MatrixPtr localOutV = getOutputValue();
   MatrixPtr preOutV = inputLayers_[0]->getOutputValue();
 
+  backward_->calc({Tensor(preOutV->getData(), dims_),
+                   Tensor(localOutV->getData(), dims_),
+                   Tensor(localGrad->getData(), dims_),
+                   Tensor(denoms_->getData(), dims_)},
+                  {Tensor(preOutGrad->getData(), dims_)},
+                  {});
+#if 0
   if (useGpu_) {
     CrossMapNormalGrad<DEVICE_TYPE_GPU> crossGrad;
     crossGrad(dynamic_cast<GpuMatrix&>(*preOutGrad),
@@ -123,5 +139,6 @@ void CMRProjectionNormLayer::backward(const UpdateCallback& callback) {
               scale_,
               pow_);
   }
+#endif
 }
 }  // namespace paddle
diff --git a/paddle/gserver/layers/NormProjectionLayer.h b/paddle/gserver/layers/NormProjectionLayer.h
index 1dc3921283ca7..82aa427f8d425 100644
--- a/paddle/gserver/layers/NormProjectionLayer.h
+++ b/paddle/gserver/layers/NormProjectionLayer.h
@@ -16,9 +16,8 @@ limitations under the License. */
 
 #include <vector>
 #include "NormLayer.h"
-#include "paddle/math/Matrix.h"
 #include "paddle/math/Function.h"
-#include <vector>
+#include "paddle/math/Matrix.h"
 
 namespace paddle {
 
@@ -43,6 +42,8 @@ class CMRProjectionNormLayer : public ResponseNormLayer {
   void backward(const UpdateCallback& callback = nullptr);
 
 protected:
-  FunctionBase* normal_;
+  Dims dims_;
+  FunctionBase* forward_;
+  FunctionBase* backward_;
 };
 }  // namespace paddle
diff --git a/paddle/math/Function.h b/paddle/math/Function.h
index f8fab972a6902..095584c0b19f7 100644
--- a/paddle/math/Function.h
+++ b/paddle/math/Function.h
@@ -16,8 +16,8 @@ limitations under the License. */
 
 #include <map>
 #include <vector>
-#include "paddle/utils/ClassRegistrar.h"
 #include "paddle/math/Matrix.h"
+#include "paddle/utils/ClassRegistrar.h"
 
 namespace paddle {
diff --git a/paddle/math/cross_map_normal_op.cpp b/paddle/math/cross_map_normal_op.cpp
index e520351d2e3b8..8547978c99160 100644
--- a/paddle/math/cross_map_normal_op.cpp
+++ b/paddle/math/cross_map_normal_op.cpp
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "cross_map_normal_op.h"
+#include "paddle/math/Vector.h"
 
 namespace paddle {
 
@@ -56,66 +57,49 @@ void CrossMapNormal(real* outputs,
 }
 
 template <>
-void CrossMapNormalGrad<DEVICE_TYPE_CPU>::operator()(CpuMatrix& inputsGrad,
-                                                     CpuMatrix& inputsValue,
-                                                     CpuMatrix& outputsGrad,
-                                                     CpuMatrix& outputsValue,
-                                                     CpuMatrix& denoms,
-                                                     size_t channels,
-                                                     size_t imgSizeH,
-                                                     size_t imgSizeW,
-                                                     size_t sizeX,
-                                                     real scale,
-                                                     real pow) {
-  CHECK(inputsGrad.isContiguous());
-  CHECK(outputsGrad.isContiguous());
-  CHECK(denoms.isContiguous());
-  CHECK(inputsValue.isContiguous());
-  CHECK(outputsValue.isContiguous());
-  CHECK_EQ(inputsGrad.getHeight(), outputsGrad.getHeight());
-  CHECK_EQ(inputsGrad.getWidth(), outputsGrad.getWidth());
-  CHECK_EQ(inputsGrad.getHeight(), denoms.getHeight());
-  CHECK_EQ(inputsGrad.getWidth(), denoms.getWidth());
-  CHECK_EQ(inputsGrad.getHeight(), inputsValue.getHeight());
-  CHECK_EQ(inputsGrad.getWidth(), inputsValue.getWidth());
-  CHECK_EQ(inputsGrad.getHeight(), outputsValue.getHeight());
-  CHECK_EQ(inputsGrad.getWidth(), outputsValue.getWidth());
-
-  size_t numSample = inputsGrad.getHeight();
-  size_t numCols = inputsGrad.getWidth();
-  size_t imageSize = imgSizeH * imgSizeW;
-  CHECK(imageSize * channels == numCols);
-
+void CrossMapNormalGrad<DEVICE_TYPE_CPU>(real* inputsGrad,
+                                         real* inputsValue,
+                                         real* outputsValue,
+                                         real* outputsGrad,
+                                         real* denoms,
+                                         size_t numSamples,
+                                         size_t channels,
+                                         size_t height,
+                                         size_t width,
+                                         size_t size,
+                                         real scale,
+                                         real pow) {
+  size_t oneSample = channels * height * width;
   std::function<CpuVector(real*, size_t)> oneImage = [=](real* data,
                                                          size_t offset) {
-    return CpuVector(imageSize, data + offset);
+    return CpuVector(height * width, data + offset);
   };
 
-  const int start = -((int)sizeX) / 2;
-  const int end = (int)sizeX + start;
+  const int start = -((int)size) / 2;
+  const int end = (int)size + start;
   const real ratio = -(real)2 * scale * pow;
-  for (size_t i = 0; i < numSample; i++) {
-    size_t sOffset = i * numCols;
-    real* inputGradData = inputsGrad.getData() + sOffset;
-    real* inputData = inputsValue.getData() + sOffset;
-    real* denomData = denoms.getData() + sOffset;
-    real* outputGradData = outputsGrad.getData() + sOffset;
-    real* outputData = outputsValue.getData() + sOffset;
+  for (size_t i = 0; i < numSamples; i++) {
+    size_t sOffset = i * oneSample;
+    real* oneInputGrad = inputsGrad + sOffset;
+    real* oneInputValue = inputsValue + sOffset;
+    real* oneDenom = denoms + sOffset;
+    real* oneOutputGrad = outputsGrad + sOffset;
+    real* oneOutputValue = outputsValue + sOffset;
     for (int c = 0; c < (int)channels; c++) {
-      size_t cOffset = c * imageSize;
-      CpuVector inputGrad = oneImage(inputGradData, cOffset);
-      CpuVector inputValue = oneImage(inputData, cOffset);
-      CpuVector denom = oneImage(denomData, cOffset);
-      CpuVector outputGrad = oneImage(outputGradData, cOffset);
+      size_t cOffset = c * height * width;
+      CpuVector inputGrad = oneImage(oneInputGrad, cOffset);
+      CpuVector inputValue = oneImage(oneInputValue, cOffset);
+      CpuVector denom = oneImage(oneDenom, cOffset);
+      CpuVector outputGrad = oneImage(oneOutputGrad, cOffset);
       inputGrad = inputGrad + denom.pow(-pow) * outputGrad;
       for (int s = start; s < end; s++) {
         if (c + s >= 0 && c + s < (int)channels) {
-          size_t offset = (c + s) * imageSize;
-          CpuVector output = oneImage(outputData, offset);
-          CpuVector outputGrad = oneImage(outputGradData, offset);
-          CpuVector denom = oneImage(denomData, offset);
+          size_t offset = (c + s) * height * width;
+          CpuVector output = oneImage(oneOutputValue, offset);
+          CpuVector outputGrad = oneImage(oneOutputGrad, offset);
+          CpuVector denom = oneImage(oneDenom, offset);
 
           inputGrad += ((outputGrad * output * ratio) / denom) * inputValue;
         }
@@ -124,6 +108,11 @@ void CrossMapNormalGrad<DEVICE_TYPE_CPU>::operator()(CpuMatrix& inputsGrad,
     }
   }
 }
 
+/**
+ * \param inputs[0] input value.
+ * \param outputs[0] output value.
+ * \param outputs[1] denoms.
+ */
 template <DeviceType Device>
 class CrossMapNormalFunc : public FunctionBase {
 public:
@@ -169,7 +158,65 @@ class CrossMapNormalFunc : public FunctionBase {
   real pow_;
 };
 
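[Note: the ratio constant above encodes the derivative of the forward pass.
With out_c = in_c * denom_c^(-pow) and denom_c = 1 + scale * SUM in^2 over
the channel window, differentiating gives, per channel c,

    d out_c' / d in_c = (c' == c) * denom_c^(-pow)
                        - 2 * scale * pow * (out_c' / denom_c') * in_c

so the loop adds denom^(-pow) * outputGrad for the diagonal term and
accumulates ratio * outputGrad * outputValue / denom * inputValue over the
window, with ratio = -2 * scale * pow. The CrossMapNormalGradFunc wrapper
below merely dispatches to this kernel.]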
+/**
+ * \param inputs[0] input value.
+ * \param inputs[1] output value.
+ * \param inputs[2] output grad.
+ * \param inputs[3] denoms.
+ * \param outputs[0] input grad.
+ */
+template <DeviceType Device>
+class CrossMapNormalGradFunc : public FunctionBase {
+public:
+  void init(const FuncConfig& config) override {
+    size_ = config.get<size_t>("size");
+    scale_ = config.get<real>("scale");
+    pow_ = config.get<real>("pow");
+  }
+
+  void calc(const Arguments& inputs,
+            const Arguments& outputs,
+            const Arguments& inouts) override {
+    CHECK_EQ(4, inputs.size());
+    CHECK_EQ(1, outputs.size());
+    CHECK_EQ(0, inouts.size());
+
+    CHECK_EQ(inputs[0].dims_.size(), 4);
+    for (size_t i = 0; i < inputs[0].dims_.size(); i++) {
+      CHECK_EQ(inputs[0].dims_[i], inputs[1].dims_[i]);
+      CHECK_EQ(inputs[0].dims_[i], inputs[2].dims_[i]);
+      CHECK_EQ(inputs[0].dims_[i], inputs[3].dims_[i]);
+      CHECK_EQ(inputs[0].dims_[i], outputs[0].dims_[i]);
+    }
+
+    size_t samples = inputs[0].dims_[0];
+    size_t channels = inputs[0].dims_[1];
+    size_t height = inputs[0].dims_[2];
+    size_t width = inputs[0].dims_[3];
+
+    CrossMapNormalGrad<Device>(outputs[0].getData(),
+                               inputs[0].getData(),
+                               inputs[1].getData(),
+                               inputs[2].getData(),
+                               inputs[3].getData(),
+                               samples,
+                               channels,
+                               height,
+                               width,
+                               size_,
+                               scale_,
+                               pow_);
+  }
+
+private:
+  size_t size_;
+  real scale_;
+  real pow_;
+};
+
 REGISTER_TYPED_FUNC(CrossMapNormal, CPU, CrossMapNormalFunc);
 REGISTER_TYPED_FUNC(CrossMapNormal, GPU, CrossMapNormalFunc);
+REGISTER_TYPED_FUNC(CrossMapNormalGrad, CPU, CrossMapNormalGradFunc);
+REGISTER_TYPED_FUNC(CrossMapNormalGrad, GPU, CrossMapNormalGradFunc);
 
 }  // namespace paddle
diff --git a/paddle/math/cross_map_normal_op.h b/paddle/math/cross_map_normal_op.h
index ef9533485ec9c..f065208084f1d 100644
--- a/paddle/math/cross_map_normal_op.h
+++ b/paddle/math/cross_map_normal_op.h
@@ -15,7 +15,6 @@ limitations under the License. */
 #pragma once
 
 #include "Function.h"
-#include "paddle/math/Matrix.h"
 
 namespace paddle {
 
@@ -30,34 +29,19 @@ void CrossMapNormal(real* outputs,
                     size_t size,
                     real scale,
                     real pow);
-#if 0
-template <DeviceType Device>
-struct CrossMapNormal {
-  void operator()(typename MatrixT<Device>::type& outputs,
-                  typename MatrixT<Device>::type& denoms,
-                  typename MatrixT<Device>::type& inputs,
-                  size_t channels,
-                  size_t imgSizeH,
-                  size_t imgSizeW,
-                  size_t sizeX,
-                  real scale,
-                  real pow);
-};
-#endif
 
 template <DeviceType Device>
-struct CrossMapNormalGrad {
-  void operator()(typename MatrixT<Device>::type& inputsGrad,
-                  typename MatrixT<Device>::type& inputsValue,
-                  typename MatrixT<Device>::type& outputsGrad,
-                  typename MatrixT<Device>::type& outputsValue,
-                  typename MatrixT<Device>::type& denoms,
-                  size_t channels,
-                  size_t imgSizeH,
-                  size_t imgSizeW,
-                  size_t sizeX,
-                  real scale,
-                  real pow);
-};
+void CrossMapNormalGrad(real* inputsGrad,
+                        real* inputsValue,
+                        real* outputsValue,
+                        real* outputsGrad,
+                        real* denoms,
+                        size_t numSamples,
+                        size_t channels,
+                        size_t height,
+                        size_t width,
+                        size_t size,
+                        real scale,
+                        real pow);
 
 }  // namespace paddle
diff --git a/paddle/math/cross_map_normal_op_gpu.cu b/paddle/math/cross_map_normal_op_gpu.cu
index 9b92974344955..6339c04194834 100644
--- a/paddle/math/cross_map_normal_op_gpu.cu
+++ b/paddle/math/cross_map_normal_op_gpu.cu
@@ -131,48 +131,26 @@ __global__ void KeCMRNormDiff(size_t imageSize, const real* bottom_data,
 }
 
 template <>
-void CrossMapNormalGrad<DEVICE_TYPE_GPU>::operator()(GpuMatrix& inputsGrad,
-                                                     GpuMatrix& inputsValue,
-                                                     GpuMatrix& outputsGrad,
-                                                     GpuMatrix& outputsValue,
-                                                     GpuMatrix& denoms,
-                                                     size_t channels,
-                                                     size_t imgSizeH,
-                                                     size_t imgSizeW,
-                                                     size_t sizeX,
-                                                     real scale,
-                                                     real pow) {
-  CHECK(inputsGrad.isContiguous());
-  CHECK(outputsGrad.isContiguous());
-  CHECK(denoms.isContiguous());
-  CHECK(inputsValue.isContiguous());
-  CHECK(outputsValue.isContiguous());
-
CHECK_EQ(inputsGrad.getHeight(), outputsGrad.getHeight()); - CHECK_EQ(inputsGrad.getWidth(), outputsGrad.getWidth()); - CHECK_EQ(inputsGrad.getHeight(), denoms.getHeight()); - CHECK_EQ(inputsGrad.getWidth(), denoms.getWidth()); - CHECK_EQ(inputsGrad.getHeight(), inputsValue.getHeight()); - CHECK_EQ(inputsGrad.getWidth(), inputsValue.getWidth()); - CHECK_EQ(inputsGrad.getHeight(), outputsValue.getHeight()); - CHECK_EQ(inputsGrad.getWidth(), outputsValue.getWidth()); - - size_t numSample = inputsGrad.getHeight(); - size_t numCols = inputsGrad.getWidth(); - CHECK(imgSizeH * imgSizeW * channels == numCols); - - size_t imageSize = numSample * imgSizeH * imgSizeW; - real* inputsGradData = inputsGrad.getData(); - real* inputsData = inputsValue.getData(); - real* denomsData = denoms.getData(); - real* outputsGradData = outputsGrad.getData(); - real* outputsData = outputsValue.getData(); +void CrossMapNormalGrad(real* inputsGrad, + real* inputsValue, + real* outputsValue, + real* outputsGrad, + real* denoms, + size_t numSamples, + size_t channels, + size_t height, + size_t width, + size_t size, + real scale, + real pow) { + size_t imageSize = numSamples * height * width; int blockSize = 1024; int gridSize = (imageSize + 1024 - 1) / 1024; KeCMRNormDiff <<>> - (imageSize, inputsData, outputsData, denomsData, outputsGradData, channels, - imgSizeH, imgSizeW, sizeX, -pow, 2.0f * pow * scale, inputsGradData); - CHECK_SYNC("KeCMRNormDiff"); + (imageSize, inputsValue, outputsValue, denoms, outputsGrad, channels, + height, width, size, -pow, 2.0f * pow * scale, inputsGrad); + CHECK_SYNC("CrossMapNormalGrad"); } } // namespace paddle diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index 0341d757f31b9..bc146514572a4 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -19,12 +19,11 @@ limitations under the License. 
*/ #include #include "TensorCheck.h" #include "paddle/gserver/tests/TestUtil.h" +#include "paddle/math/Function.h" #include "paddle/math/Matrix.h" #include "paddle/math/SparseMatrix.h" -#include "paddle/utils/Stat.h" -#include "TensorCheck.h" #include "paddle/math/cross_map_normal_op.h" -#include "paddle/math/Function.h" +#include "paddle/utils/Stat.h" #include "paddle/utils/Util.h" using namespace paddle; // NOLINT @@ -1282,12 +1281,6 @@ void testCrossMapNormalFwd( inputsGpu.copyFrom(inputs); outputsGpu.copyFrom(outputs); -#if 0 - FuncConfig config; - config.set("size", (size_t)sizeX); - config.set("scale", scale); - config.set("pow", pow); -#endif FunctionBase* cpu = FunctionBase::funcRegistrar_.createByType(FUNC_NAME(CrossMapNormal, CPU)); FunctionBase* gpu = @@ -1311,22 +1304,6 @@ void testCrossMapNormalFwd( {Tensor(inputsGpu.getData(), dims)}, {Tensor(outputsGpu.getData(), dims), Tensor(denomsGpu.getData(), dims)}, {}); -#if 0 - CrossMapNormal cpuCross; - cpuCross( - outputs, denoms, inputs, channels, imgSizeH, imgSizeW, sizeX, scale, pow); - - CrossMapNormal gpuCross; - gpuCross(outputsGpu, - denomsGpu, - inputsGpu, - channels, - imgSizeH, - imgSizeW, - sizeX, - scale, - pow); -#endif TensorCheckErr(outputs, outputsGpu); TensorCheckErr(denoms, denomsGpu); @@ -1381,6 +1358,35 @@ void testCrossMapNormalBwd( outputsValueGpu.copyFrom(outputsValue); inputsGradGpu.copyFrom(inputsGrad); + FunctionBase* cpu = FunctionBase::funcRegistrar_.createByType( + FUNC_NAME(CrossMapNormalGrad, CPU)); + FunctionBase* gpu = FunctionBase::funcRegistrar_.createByType( + FUNC_NAME(CrossMapNormalGrad, GPU)); + cpu->init(FuncConfig() + .set("size", (size_t)sizeX) + .set("scale", scale) + .set("pow", pow)); + gpu->init(FuncConfig() + .set("size", (size_t)sizeX) + .set("scale", scale) + .set("pow", pow)); + + Dims dims{ + (size_t)numSamples, (size_t)channels, (size_t)imgSizeH, (size_t)imgSizeW}; + cpu->calc({Tensor(inputsValue.getData(), dims), + Tensor(outputsValue.getData(), dims), + Tensor(outputsGrad.getData(), dims), + Tensor(denoms.getData(), dims)}, + {Tensor(inputsGrad.getData(), dims)}, + {}); + + gpu->calc({Tensor(inputsValueGpu.getData(), dims), + Tensor(outputsValueGpu.getData(), dims), + Tensor(outputsGradGpu.getData(), dims), + Tensor(denomsGpu.getData(), dims)}, + {Tensor(inputsGradGpu.getData(), dims)}, + {}); +#if 0 CrossMapNormalGrad cpuCross; cpuCross(inputsGrad, inputsValue, @@ -1406,6 +1412,7 @@ void testCrossMapNormalBwd( sizeX, scale, pow); +#endif TensorCheckErr(inputsGrad, inputsGradGpu); } From 22a5e478f3b6ecc0e43d31abce39a686b6331165 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Thu, 15 Dec 2016 16:36:51 +0800 Subject: [PATCH 08/15] move Function to function dir --- paddle/{math => function}/Function.cpp | 0 paddle/{math => function}/Function.h | 0 paddle/{math => function}/cross_map_normal_op.cpp | 0 paddle/{math => function}/cross_map_normal_op.h | 0 paddle/{math => function}/cross_map_normal_op_gpu.cu | 0 paddle/gserver/layers/NormProjectionLayer.cpp | 1 - paddle/gserver/layers/NormProjectionLayer.h | 2 +- paddle/math/tests/test_matrixCompare.cpp | 3 +-- 8 files changed, 2 insertions(+), 4 deletions(-) rename paddle/{math => function}/Function.cpp (100%) rename paddle/{math => function}/Function.h (100%) rename paddle/{math => function}/cross_map_normal_op.cpp (100%) rename paddle/{math => function}/cross_map_normal_op.h (100%) rename paddle/{math => function}/cross_map_normal_op_gpu.cu (100%) diff --git a/paddle/math/Function.cpp b/paddle/function/Function.cpp similarity index 
100% rename from paddle/math/Function.cpp rename to paddle/function/Function.cpp diff --git a/paddle/math/Function.h b/paddle/function/Function.h similarity index 100% rename from paddle/math/Function.h rename to paddle/function/Function.h diff --git a/paddle/math/cross_map_normal_op.cpp b/paddle/function/cross_map_normal_op.cpp similarity index 100% rename from paddle/math/cross_map_normal_op.cpp rename to paddle/function/cross_map_normal_op.cpp diff --git a/paddle/math/cross_map_normal_op.h b/paddle/function/cross_map_normal_op.h similarity index 100% rename from paddle/math/cross_map_normal_op.h rename to paddle/function/cross_map_normal_op.h diff --git a/paddle/math/cross_map_normal_op_gpu.cu b/paddle/function/cross_map_normal_op_gpu.cu similarity index 100% rename from paddle/math/cross_map_normal_op_gpu.cu rename to paddle/function/cross_map_normal_op_gpu.cu diff --git a/paddle/gserver/layers/NormProjectionLayer.cpp b/paddle/gserver/layers/NormProjectionLayer.cpp index d6923c2192cf1..e69c406993054 100644 --- a/paddle/gserver/layers/NormProjectionLayer.cpp +++ b/paddle/gserver/layers/NormProjectionLayer.cpp @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "NormProjectionLayer.h" -#include "paddle/math/cross_map_normal_op.h" #include "paddle/utils/Logging.h" #include "paddle/utils/Stat.h" diff --git a/paddle/gserver/layers/NormProjectionLayer.h b/paddle/gserver/layers/NormProjectionLayer.h index 82aa427f8d425..3c4876ece609e 100644 --- a/paddle/gserver/layers/NormProjectionLayer.h +++ b/paddle/gserver/layers/NormProjectionLayer.h @@ -16,7 +16,7 @@ limitations under the License. */ #include #include "NormLayer.h" -#include "paddle/math/Function.h" +#include "paddle/function/Function.h" #include "paddle/math/Matrix.h" namespace paddle { diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index bc146514572a4..da7a585484e7e 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -18,11 +18,10 @@ limitations under the License. 
*/ #include #include "TensorCheck.h" +#include "paddle/function/Function.h" #include "paddle/gserver/tests/TestUtil.h" -#include "paddle/math/Function.h" #include "paddle/math/Matrix.h" #include "paddle/math/SparseMatrix.h" -#include "paddle/math/cross_map_normal_op.h" #include "paddle/utils/Stat.h" #include "paddle/utils/Util.h" From 558e86927caa2bbe0bc97b287f9d1abe73cfaaa3 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Thu, 15 Dec 2016 17:12:22 +0800 Subject: [PATCH 09/15] add CMakeLists --- cmake/util.cmake | 1 + paddle/CMakeLists.txt | 1 + paddle/function/CMakeLists.txt | 12 ++++++++++++ paddle/function/cross_map_normal_op.cpp | 4 +++- paddle/gserver/CMakeLists.txt | 8 ++------ 5 files changed, 19 insertions(+), 7 deletions(-) create mode 100644 paddle/function/CMakeLists.txt diff --git a/cmake/util.cmake b/cmake/util.cmake index 38366373c6dbc..03734e7839d74 100644 --- a/cmake/util.cmake +++ b/cmake/util.cmake @@ -96,6 +96,7 @@ function(link_paddle_exe TARGET_NAME) target_circle_link_libraries(${TARGET_NAME} ARCHIVE_START paddle_gserver + paddle_function ${METRIC_LIBS} ARCHIVE_END paddle_pserver diff --git a/paddle/CMakeLists.txt b/paddle/CMakeLists.txt index fb3af8ea92fee..2daea052b01ad 100644 --- a/paddle/CMakeLists.txt +++ b/paddle/CMakeLists.txt @@ -1,4 +1,5 @@ add_subdirectory(cuda) +add_subdirectory(function) add_subdirectory(utils) add_subdirectory(math) add_subdirectory(parameter) diff --git a/paddle/function/CMakeLists.txt b/paddle/function/CMakeLists.txt new file mode 100644 index 0000000000000..8fad0e3ebdfb2 --- /dev/null +++ b/paddle/function/CMakeLists.txt @@ -0,0 +1,12 @@ +file(GLOB FUNCTION_HEADERS . *.h) + +if(NOT WITH_GPU) + file(GLOB FUNCTION_SOURCES . *.cpp) + add_library(paddle_function STATIC ${FUNCTION_SOURCES}) +else() + file(GLOB FUNCTION_SOURCES . 
*.cpp *.cu) + cuda_add_library(paddle_function ${FUNCTION_SOURCES}) +endif() + +add_style_check_target(paddle_function ${FUNCTION_SOURCES}) +add_style_check_target(paddle_function ${FUNCTION_HEADERS}) diff --git a/paddle/function/cross_map_normal_op.cpp b/paddle/function/cross_map_normal_op.cpp index 8547978c99160..0391a58d89f4a 100644 --- a/paddle/function/cross_map_normal_op.cpp +++ b/paddle/function/cross_map_normal_op.cpp @@ -215,8 +215,10 @@ class CrossMapNormalGradFunc : public FunctionBase { }; REGISTER_TYPED_FUNC(CrossMapNormal, CPU, CrossMapNormalFunc); -REGISTER_TYPED_FUNC(CrossMapNormal, GPU, CrossMapNormalFunc); REGISTER_TYPED_FUNC(CrossMapNormalGrad, CPU, CrossMapNormalGradFunc); +#ifndef PADDLE_ONLY_CPU +REGISTER_TYPED_FUNC(CrossMapNormal, GPU, CrossMapNormalFunc); REGISTER_TYPED_FUNC(CrossMapNormalGrad, GPU, CrossMapNormalGradFunc); +#endif } // namespace paddle diff --git a/paddle/gserver/CMakeLists.txt b/paddle/gserver/CMakeLists.txt index a066f80c221ee..4f92150ec84d6 100644 --- a/paddle/gserver/CMakeLists.txt +++ b/paddle/gserver/CMakeLists.txt @@ -27,16 +27,12 @@ if(NOT WITH_GPU) list(REMOVE_ITEM GSERVER_HEADER layers/CudnnConvLayer.h layers/CudnnPoolLayer.h - layers/CudnnBatchNormLayer.h - layers/NormProjectionLayer.h - layers/NormLayer.h) + layers/CudnnBatchNormLayer.h) list(REMOVE_ITEM GSERVER_SOURCES layers/CudnnConvLayer.cpp layers/CudnnPoolLayer.cpp - layers/CudnnBatchNormLayer.cpp - layers/NormProjectionLayer.cpp - layers/NormLayer.cpp) + layers/CudnnBatchNormLayer.cpp) compile_cu_as_cpp(layers/LstmCompute.cu) compile_cu_as_cpp(layers/GruCompute.cu) endif() From d11e2b401348c147b20507863a43b8952f17d6a1 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Thu, 15 Dec 2016 17:33:01 +0800 Subject: [PATCH 10/15] Remove some useless code --- paddle/cuda/include/hl_cnn.h | 56 ------ paddle/cuda/include/stub/hl_cnn_stub.h | 24 --- paddle/cuda/src/hl_cuda_cnn.cu | 120 ------------ paddle/gserver/layers/NormProjectionLayer.cpp | 29 --- paddle/math/Matrix.cpp | 176 ------------------ paddle/math/Matrix.h | 65 ------- paddle/math/tests/test_matrixCompare.cpp | 27 --- 7 files changed, 497 deletions(-) diff --git a/paddle/cuda/include/hl_cnn.h b/paddle/cuda/include/hl_cnn.h index 06ee3b3654b57..c5787630abbe1 100644 --- a/paddle/cuda/include/hl_cnn.h +++ b/paddle/cuda/include/hl_cnn.h @@ -240,62 +240,6 @@ extern void hl_avgpool_backward(const int frameCnt, real* backGrad, const int outStride); -/** - * @brief Cross-map-respose normalize forward. - * - * @param[in] frameCnt batch size of input image. - * @param[in] in input data. - * @param[in] scale buffer. - * @param[out] out output data. - * @param[in] channels number of channel. - * @param[in] height image height. - * @param[in] width image width. - * @param[in] sizeX size. - * @param[in] alpha scale. - * @param[in] beta scale. - * - */ -extern void hl_CMRNorm_forward(size_t frameCnt, - const real* in, - real* scale, - real* out, - size_t channels, - size_t height, - size_t width, - size_t sizeX, - real alpha, - real beta); - -/** - * @brief Cross-map-respose normalize backward. - * - * @param[in] frameCnt batch size of input image. - * @param[in] inV input data. - * @param[in] scale buffer. - * @param[out] outV output value. - * @param[out] outDiff output grad. - * @param[out] inDiff input grad. - * @param[in] channels number of channel. - * @param[in] height image height. - * @param[in] width image width. - * @param[in] sizeX size. - * @param[in] alpha scale. - * @param[in] beta scale. 
- * - */ -extern void hl_CMRNorm_backward(size_t frameCnt, - const real* inV, - const real* scale, - const real* outV, - const real* outDiff, - real* inDiff, - size_t channels, - size_t height, - size_t width, - size_t sizeX, - real alpha, - real beta); - /** * @brief Bilinear interpolation forward. * diff --git a/paddle/cuda/include/stub/hl_cnn_stub.h b/paddle/cuda/include/stub/hl_cnn_stub.h index 52c978735279e..039551c6cc695 100644 --- a/paddle/cuda/include/stub/hl_cnn_stub.h +++ b/paddle/cuda/include/stub/hl_cnn_stub.h @@ -117,30 +117,6 @@ inline void hl_avgpool_backward(const int frameCnt, real* backGrad, const int outStride) {} -inline void hl_CMRNorm_forward(size_t frameCnt, - const real* in, - real* scale, - real* out, - size_t channels, - size_t height, - size_t width, - size_t sizeX, - real alpha, - real beta) {} - -inline void hl_CMRNorm_backward(size_t frameCnt, - const real* inV, - const real* scale, - const real* outV, - const real* outDiff, - real* inDiff, - size_t channels, - size_t height, - size_t width, - size_t sizeX, - real alpha, - real beta) {} - inline void hl_bilinear_forward(const real* inData, const size_t inImgH, const size_t inImgW, diff --git a/paddle/cuda/src/hl_cuda_cnn.cu b/paddle/cuda/src/hl_cuda_cnn.cu index 1516accaae17f..b94f4d8fe4a25 100644 --- a/paddle/cuda/src/hl_cuda_cnn.cu +++ b/paddle/cuda/src/hl_cuda_cnn.cu @@ -381,126 +381,6 @@ void hl_avgpool_backward(const int frameCnt, const real* outGrad, CHECK_SYNC("hl_avgpool_backward failed"); } -__global__ void KeCMRNormFillScale(size_t imageSize, const real* in, - real* scale, size_t channels, - size_t height, size_t width, size_t size, - real alpha) { - const int idx = threadIdx.x + blockIdx.x * blockDim.x; - if (idx < imageSize) { - const int w = idx % width; - const int h = (idx / width) % height; - const int n = idx / width / height; - const int offset = (n * channels * height + h) * width + w; - - in += offset; - scale += offset; - const int step = height * width; - const int pre_pad = (size - 1) / 2; - const int post_pad = size - pre_pad - 1; - - real accum = 0; - int index = 0; - while (index < channels + post_pad) { - if (index < channels) { - accum += in[index * step] * in[index * step]; - } - if (index >= size) { - accum -= in[(index - size) * step] * in[(index - size) * step]; - } - if (index >= post_pad) { - scale[(index - post_pad) * step] = 1. 
+ accum * alpha; - } - ++index; - } - } -} - -__global__ void KeCMRNormOutput(size_t inputSize, const real* in, - const real* scale, real negative_beta, - real* out) { - const int index = threadIdx.x + blockIdx.x * blockDim.x; - if (index < inputSize) { - out[index] = in[index] * pow(scale[index], negative_beta); - } -} - -void hl_CMRNorm_forward(size_t frameCnt, const real* in, real* scale, - real* out, size_t channels, - size_t height, size_t width, size_t sizeX, - real alpha, real beta) { - size_t imageSize = frameCnt * height * width; - int blockSize = 1024; - int gridSize = (imageSize + 1024 - 1) / 1024; - KeCMRNormFillScale<<>> - (imageSize, in, scale, channels, height, width, sizeX, alpha); - - size_t inputSize = frameCnt * height * width *channels; - blockSize = 1024; - gridSize = (inputSize + 1024 - 1) / 1024; - KeCMRNormOutput<<>> - (inputSize, in, scale, beta, out); - CHECK_SYNC("hl_CMRNorm_forward"); -} - -__global__ void KeCMRNormDiff(size_t imageSize, const real* bottom_data, - const real* top_data, const real* scale, - const real* top_diff, size_t channels, - size_t height, size_t width, size_t size, - real negative_beta, real cache_ratio, - real* bottom_diff ) { - const int idx = threadIdx.x + blockIdx.x * blockDim.x; - if (idx < imageSize) { - const int w = idx % width; - const int h = (idx / width) % height; - const int n = idx / width / height; - const int offset = (n * channels * height + h) * width + w; - bottom_data += offset; - top_data += offset; - scale += offset; - top_diff += offset; - bottom_diff += offset; - - const int step = height * width; - const int pre_pad = size - (size + 1) / 2; - const int post_pad = size - pre_pad - 1; - - int index = 0; - real accum = 0; - while (index < channels + post_pad) { - if (index < channels) { - accum += top_diff[index * step] * top_data[index * step] / - scale[index * step]; - } - if (index >= size) { - accum -= top_diff[(index - size) * step] * - top_data[(index - size) * step] / scale[(index - size) * step]; - } - if (index >= post_pad) { - bottom_diff[(index - post_pad) * step] += - top_diff[(index - post_pad) * step] * - pow(scale[(index - post_pad) * step], negative_beta) - cache_ratio * - bottom_data[(index - post_pad) * step] * accum; - } - ++index; - } - } -} - -void hl_CMRNorm_backward(size_t frameCnt, const real* inV, - const real* scale, - const real* outV, const real* outDiff, - real *inDiff, size_t channels, - size_t height, size_t width, size_t sizeX, - real alpha, real beta) { - size_t imageSize = frameCnt * height * width; - int blockSize = 1024; - int gridSize = (imageSize + 1024 - 1) / 1024; - KeCMRNormDiff <<>> - (imageSize, inV, outV, scale, outDiff, channels, - height, width, sizeX, alpha, beta, inDiff); - CHECK_SYNC("hl_CMRNorm_backward"); -} - __global__ void KeBilinearInterpFw(const real* in, const size_t inImgH, const size_t inImgW, diff --git a/paddle/gserver/layers/NormProjectionLayer.cpp b/paddle/gserver/layers/NormProjectionLayer.cpp index e69c406993054..4ff3b805fbb06 100644 --- a/paddle/gserver/layers/NormProjectionLayer.cpp +++ b/paddle/gserver/layers/NormProjectionLayer.cpp @@ -110,34 +110,5 @@ void CMRProjectionNormLayer::backward(const UpdateCallback& callback) { Tensor(denoms_->getData(), dims_)}, {Tensor(preOutGrad->getData(), dims_)}, {}); -#if 0 - if (useGpu_) { - CrossMapNormalGrad crossGrad; - crossGrad(dynamic_cast(*preOutGrad), - dynamic_cast(*preOutV), - dynamic_cast(*localGrad), - dynamic_cast(*localOutV), - dynamic_cast(*denoms_), - channels_, - imgSizeH_, - imgSizeW_, - size_, - 
scale_, - pow_); - } else { - CrossMapNormalGrad crossGrad; - crossGrad(dynamic_cast(*preOutGrad), - dynamic_cast(*preOutV), - dynamic_cast(*localGrad), - dynamic_cast(*localOutV), - dynamic_cast(*denoms_), - channels_, - imgSizeH_, - imgSizeW_, - size_, - scale_, - pow_); - } -#endif } } // namespace paddle diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp index 2cde11dd479dc..a36c31d32b7ca 100644 --- a/paddle/math/Matrix.cpp +++ b/paddle/math/Matrix.cpp @@ -1265,69 +1265,6 @@ void GpuMatrix::avgPoolBackward(Matrix& outGrad, outGrad.getStride()); } -void GpuMatrix::crossMapNormalFwd(Matrix& input, - size_t imgSizeH, - size_t imgSizeW, - Matrix& denoms, - size_t channels, - size_t sizeX, - float scale, - float pow) { - size_t num = input.getHeight(); - size_t height = imgSizeH; - size_t width = imgSizeW; - - CHECK(height * width * channels == input.getWidth()); - CHECK(denoms.getHeight() == input.getHeight() && - denoms.getWidth() == input.getWidth() && input.getHeight() == height_ && - input.getWidth() == width_); - hl_CMRNorm_forward(num, - input.getData(), - denoms.getData(), - data_, - channels, - height, - width, - sizeX, - scale, - -pow); -} - -void GpuMatrix::crossMapNormalBwd(Matrix& localGrad, - Matrix& denoms, - Matrix& preOutV, - Matrix& localOutV, - size_t channels, - size_t imgSizeH, - size_t imgSizeW, - size_t sizeX, - float scale, - float pow) { - size_t num = preOutV.getHeight(); - size_t height = imgSizeH; - size_t width = imgSizeW; - - CHECK(width * height * channels == preOutV.getWidth()); - CHECK(denoms.getHeight() == preOutV.getHeight() && - denoms.getWidth() == preOutV.getWidth() && - preOutV.getHeight() == height_ && preOutV.getWidth() == width_); - CHECK(denoms.getHeight() == localGrad.getHeight() && - denoms.getWidth() == localGrad.getWidth()); - - hl_CMRNorm_backward(num, - preOutV.getData(), - denoms.getData(), - localOutV.getData(), - localGrad.getData(), - data_, - channels, - height, - width, - sizeX, - -pow, - 2.0f * pow * scale); -} - void GpuMatrix::maxSequenceForward(Matrix& input, const IVector& sequence, IVector& index) { @@ -2219,119 +2156,6 @@ void CpuMatrix::avgPoolBackward(Matrix& input, } } -void CpuMatrix::crossMapNormalFwd(Matrix& input, - size_t imgSizeH, - size_t imgSizeW, - Matrix& denoms, - size_t channels, - size_t sizeX, - float scale, - float pow) { - CHECK(isContiguous()); - CHECK(input.isContiguous()); - CHECK(denoms.isContiguous()); - CHECK_EQ(getHeight(), input.getHeight()); - CHECK_EQ(getWidth(), input.getWidth()); - CHECK_EQ(getHeight(), denoms.getHeight()); - CHECK_EQ(getWidth(), denoms.getWidth()); - - size_t numSample = input.getHeight(); - size_t numCols = input.getWidth(); - size_t height = imgSizeH; - size_t width = imgSizeW; - CHECK(height * width * channels == numCols); - - // TODO(hedaoyuan) After commit TensorExpress code, - // Reconstruction this code to remove the temporary memory. 
- CpuMatrix tmp(channels, height * width); - CpuMatrix tmp2(tmp.getData(), 1, channels * height * width); - denoms.zero(); - const int start = -((int)sizeX - 1) / 2; - const int end = (int)sizeX + start; - for (size_t i = 0; i < numSample; i++) { - input.subMatrix(i, 1)->square2(tmp2); - CpuMatrix subDen( - denoms.subMatrix(i, 1)->getData(), channels, height * width); - for (int c = 0; c < (int)channels; c++) { - for (int s = start; s < end; s++) { - if (c + s >= 0 && c + s < (int)channels) { - subDen.subMatrix(c, 1)->add(*tmp.subMatrix(c + s, 1)); - } - } - } - } - - denoms.add(scale, (real)1); - this->pow2(denoms, -pow); - this->dotMul(input); -} - -void CpuMatrix::crossMapNormalBwd(Matrix& localGrad, - Matrix& denoms, - Matrix& preOutV, - Matrix& localOutV, - size_t channels, - size_t imgSizeH, - size_t imgSizeW, - size_t sizeX, - float scale, - float pow) { - CHECK(isContiguous()); - CHECK(localGrad.isContiguous()); - CHECK(denoms.isContiguous()); - CHECK(preOutV.isContiguous()); - CHECK(localOutV.isContiguous()); - CHECK_EQ(getHeight(), localGrad.getHeight()); - CHECK_EQ(getWidth(), localGrad.getWidth()); - CHECK_EQ(getHeight(), denoms.getHeight()); - CHECK_EQ(getWidth(), denoms.getWidth()); - CHECK_EQ(getHeight(), preOutV.getHeight()); - CHECK_EQ(getWidth(), preOutV.getWidth()); - CHECK_EQ(getHeight(), localOutV.getHeight()); - CHECK_EQ(getWidth(), localOutV.getWidth()); - - size_t numSample = getHeight(); - size_t numCols = getWidth(); - size_t height = imgSizeH; - size_t width = imgSizeW; - CHECK(height * width * channels == numCols); - - // TODO(hedaoyuan) After commit TensorExpress code, - // Reconstruction this code to remove the temporary memory. - CpuMatrix tmp(1, height * width); - - const int start = -((int)sizeX) / 2; - const int end = (int)sizeX + start; - const real ratio = -(real)2 * scale * pow; - for (size_t i = 0; i < numSample; i++) { - CpuMatrix inputDiff( - this->subMatrix(i, 1)->getData(), channels, height * width); - CpuMatrix outDiff( - localGrad.subMatrix(i, 1)->getData(), channels, height * width); - CpuMatrix input( - preOutV.subMatrix(i, 1)->getData(), channels, height * width); - CpuMatrix output( - localOutV.subMatrix(i, 1)->getData(), channels, height * width); - CpuMatrix subDen( - denoms.subMatrix(i, 1)->getData(), channels, height * width); - - for (int c = 0; c < (int)channels; c++) { - tmp.pow2(*subDen.subMatrix(c, 1), -pow); - inputDiff.subMatrix(c, 1) - ->addDotMul(tmp, *outDiff.subMatrix(c, 1), (real)1, (real)1); - for (int s = start; s < end; s++) { - if (c + s >= 0 && c + s < (int)channels) { - tmp.dotMul(*outDiff.subMatrix(c + s, 1), *output.subMatrix(c + s, 1)); - tmp.mulScalar(ratio); - tmp.dotDiv(tmp, *subDen.subMatrix(c + s, 1)); - tmp.dotMul(*input.subMatrix(c, 1)); - inputDiff.subMatrix(c, 1)->add(tmp); - } - } - } - } -} - /** * Input: one or more sequences. Each sequence contains some instances. * Output: output size is the number of input sequences (NOT input instances). diff --git a/paddle/math/Matrix.h b/paddle/math/Matrix.h index 5685cb7bcbbb6..62bc1b16fc7b6 100644 --- a/paddle/math/Matrix.h +++ b/paddle/math/Matrix.h @@ -952,31 +952,6 @@ class Matrix : public BaseMatrix { LOG(FATAL) << "Not implemeted"; } - /// normalize-operation. 
- virtual void crossMapNormalFwd(Matrix& input, - size_t imgSizeH, - size_t imgSizeW, - Matrix& denoms, - size_t channels, - size_t sizeX, - float scale, - float pow) { - LOG(FATAL) << "Not implemeted"; - } - - virtual void crossMapNormalBwd(Matrix& localGrad, - Matrix& denoms, - Matrix& preOutV, - Matrix& localOutV, - size_t channels, - size_t imgSizeH, - size_t imgSizeW, - size_t size, - float scale, - float pow) { - LOG(FATAL) << "Not implemeted"; - } - /** * Input: one or more sequences. Each sequence contains some instances. * @@ -1459,26 +1434,6 @@ class GpuMatrix : public Matrix { size_t paddingH, size_t paddingW); - void crossMapNormalFwd(Matrix& input, - size_t imgSizeH, - size_t imgSizeW, - Matrix& denoms, - size_t channels, - size_t sizeX, - float scale, - float pow); - - void crossMapNormalBwd(Matrix& localGrad, - Matrix& denoms, - Matrix& preOutV, - Matrix& localOutV, - size_t channels, - size_t imgSizeH, - size_t imgSizeW, - size_t sizeX, - float scale, - float pow); - void maxSequenceForward(Matrix& input, const IVector& sequence, IVector& index); @@ -1685,26 +1640,6 @@ class CpuMatrix : public Matrix { size_t paddingH, size_t paddingW); - void crossMapNormalFwd(Matrix& input, - size_t imgSizeH, - size_t imgSizeW, - Matrix& denoms, - size_t channels, - size_t sizeX, - float scale, - float pow); - - void crossMapNormalBwd(Matrix& localGrad, - Matrix& denoms, - Matrix& preOutV, - Matrix& localOutV, - size_t channels, - size_t imgSizeH, - size_t imgSizeW, - size_t sizeX, - float scale, - float pow); - void maxSequenceForward(Matrix& input, const IVector& sequence, IVector& index); diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index da7a585484e7e..c89b7ff490232 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -1385,33 +1385,6 @@ void testCrossMapNormalBwd( Tensor(denomsGpu.getData(), dims)}, {Tensor(inputsGradGpu.getData(), dims)}, {}); -#if 0 - CrossMapNormalGrad cpuCross; - cpuCross(inputsGrad, - inputsValue, - outputsGrad, - outputsValue, - denoms, - channels, - imgSizeH, - imgSizeW, - sizeX, - scale, - pow); - - CrossMapNormalGrad gpuCross; - gpuCross(inputsGradGpu, - inputsValueGpu, - outputsGradGpu, - outputsValueGpu, - denomsGpu, - channels, - imgSizeH, - imgSizeW, - sizeX, - scale, - pow); -#endif TensorCheckErr(inputsGrad, inputsGradGpu); } From f13aeb52e9fc666ac1e24acf5315cbdccf108402 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Thu, 15 Dec 2016 20:12:53 +0800 Subject: [PATCH 11/15] fix swig_api --- paddle/api/CMakeLists.txt | 1 + paddle/api/paddle_ld_flags.py | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/paddle/api/CMakeLists.txt b/paddle/api/CMakeLists.txt index 6ad1d79e59b11..ed69bd764f30a 100644 --- a/paddle/api/CMakeLists.txt +++ b/paddle/api/CMakeLists.txt @@ -46,6 +46,7 @@ add_custom_command(OUTPUT ${PROJ_ROOT}/paddle/dist/.timestamp WORKING_DIRECTORY ${PROJ_ROOT}/paddle DEPENDS python_swig_sources paddle_parameter + paddle_function paddle_math paddle_utils paddle_gserver diff --git a/paddle/api/paddle_ld_flags.py b/paddle/api/paddle_ld_flags.py index 51d7dfee58b78..7c8206e3fe097 100644 --- a/paddle/api/paddle_ld_flags.py +++ b/paddle/api/paddle_ld_flags.py @@ -30,8 +30,8 @@ whole_end = "" LIB_DIRS = [ - "math", 'utils', 'parameter', "gserver", "api", "cuda", "pserver", - "trainer" + "math", 'function', 'utils', 'parameter', "gserver", "api", "cuda", + "pserver", "trainer" ] PARENT_LIB_DIRS = ['proto'] @@ -75,6 +75,7 @@ def 
libs_str(self):
         libs = [
             whole_start,
             "-lpaddle_gserver",
+            "-lpaddle_function",
             whole_end,
             "-lpaddle_pserver",
             "-lpaddle_trainer_lib",

From cee934680467c50d4084dbaf7273a39a40cc832d Mon Sep 17 00:00:00 2001
From: hedaoyuan
Date: Thu, 15 Dec 2016 21:23:05 +0800
Subject: [PATCH 12/15] add some comments

---
 paddle/function/cross_map_normal_op.cpp       |  5 ++-
 paddle/function/cross_map_normal_op.h         | 34 +++++++++++++++++++
 paddle/gserver/layers/Layer.h                 |  6 ++++
 paddle/gserver/layers/NormProjectionLayer.cpp | 18 ++++------
 paddle/gserver/layers/NormProjectionLayer.h   |  3 --
 5 files changed, 50 insertions(+), 16 deletions(-)

diff --git a/paddle/function/cross_map_normal_op.cpp b/paddle/function/cross_map_normal_op.cpp
index 0391a58d89f4a..a18c0bb750acf 100644
--- a/paddle/function/cross_map_normal_op.cpp
+++ b/paddle/function/cross_map_normal_op.cpp
@@ -17,7 +17,6 @@ limitations under the License. */
 
 namespace paddle {
 
-// NCHW
 template <>
 void CrossMapNormal<DEVICE_TYPE_CPU>(real* outputs,
@@ -36,6 +35,10 @@ void CrossMapNormal<DEVICE_TYPE_CPU>(real* outputs,
   CpuVector inputsV(numSamples * oneSample, inputs);
   CpuVector denomsV(numSamples * oneSample, denoms);
 
+  // f(x) = x * ( 1 + scale * SUM((x)^2) )^(-pow)
+  // x represents inputs
+  // f(x) represents outputs
+  // denoms save the intermediate result for backward
   denomsV = denomsV.constant(1.0);
   const int start = -((int)size - 1) / 2;
   const int end = (int)size + start;
diff --git a/paddle/function/cross_map_normal_op.h b/paddle/function/cross_map_normal_op.h
index f065208084f1d..e935b26e125d3 100644
--- a/paddle/function/cross_map_normal_op.h
+++ b/paddle/function/cross_map_normal_op.h
@@ -18,6 +18,22 @@ limitations under the License. */
 
 namespace paddle {
 
+/**
+ * \brief Cross map response normalize forward.
+ *        The data structure of image data is NCHW.
+ *
+ * \param[out]  outputs     output data.
+ * \param[in]   denoms      denoms buffer.
+ * \param[in]   inputs      input data.
+ * \param[in]   numSamples  batch size of input image.
+ * \param[in]   channels    number of channel.
+ * \param[in]   height      image height.
+ * \param[in]   width       image width.
+ * \param[in]   size        size.
+ * \param[in]   scale       scale.
+ * \param[in]   pow         scale.
+ *
+ */
 template <DeviceType Device>
 void CrossMapNormal(real* outputs,
@@ -30,6 +46,24 @@ void CrossMapNormal(real* outputs,
                     real scale,
                     real pow);
 
+/**
+ * \brief Cross map response normalize backward.
+ *        The data structure of image data is NCHW.
+ *
+ * \param[out]  inputsGrad    input grad.
+ * \param[in]   inputsValue   input value.
+ * \param[out]  outputsValue  output value.
+ * \param[out]  outputsGrad   output grad.
+ * \param[in]   denoms        denoms buffer.
+ * \param[in]   numSamples    batch size of input image.
+ * \param[in]   channels      number of channel.
+ * \param[in]   height        image height.
+ * \param[in]   width         image width.
+ * \param[in]   size          size.
+ * \param[in]   scale         scale.
+ * \param[in]   pow           scale.
+ *
+ */
 template <DeviceType Device>
 void CrossMapNormalGrad(real* inputsGrad,
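[Note: the f(x) comment above documents the forward computation, where SUM
runs over a window of `size` neighboring channels at one pixel. A naive
per-pixel scalar sketch of the same computation, for illustration only (the
patched code uses CpuVector expressions over whole image planes, and in/out/
denoms here index channel values at one fixed (n, h, w)):

    const int start = -((int)size - 1) / 2;
    const int end = (int)size + start;
    for (int c = 0; c < (int)channels; c++) {
      real denom = 1.0;
      for (int s = start; s < end; s++) {
        if (c + s >= 0 && c + s < (int)channels) {
          denom += scale * in[c + s] * in[c + s];
        }
      }
      denoms[c] = denom;                       // kept for backward
      out[c] = in[c] * std::pow(denom, -pow);  // f(x)
    }
]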
diff --git a/paddle/gserver/layers/Layer.h b/paddle/gserver/layers/Layer.h
index 172e558b82945..16f66a2205f49 100644
--- a/paddle/gserver/layers/Layer.h
+++ b/paddle/gserver/layers/Layer.h
@@ -18,6 +18,7 @@ limitations under the License. */
 #include <functional>
 #include <memory>
 #include "ModelConfig.pb.h"
+#include "paddle/function/Function.h"
 #include "paddle/math/CpuSparseMatrix.h"
 #include "paddle/parameter/Parameter.h"
 #include "paddle/utils/ClassRegistrar.h"
@@ -100,6 +101,11 @@ class Layer {
   /// Mark input grad in(true) or out(false) of backward function.
   std::vector<bool> markInBackward_;
 
+  /// Layer forward function
+  FunctionBase* forward_;
+  /// Layer backward function
+  FunctionBase* backward_;
+
 public:
   /**
    * Wait until all input value ready.
diff --git a/paddle/gserver/layers/NormProjectionLayer.cpp b/paddle/gserver/layers/NormProjectionLayer.cpp
index 4ff3b805fbb06..0f6f9b91d0578 100644
--- a/paddle/gserver/layers/NormProjectionLayer.cpp
+++ b/paddle/gserver/layers/NormProjectionLayer.cpp
@@ -48,20 +48,17 @@ bool CMRProjectionNormLayer::init(const LayerMap& layerMap,
   if (useGpu_) {
     forward_ = FunctionBase::funcRegistrar_.createByType(
         FUNC_NAME(CrossMapNormal, GPU));
+    backward_ = FunctionBase::funcRegistrar_.createByType(
+        FUNC_NAME(CrossMapNormalGrad, GPU));
   } else {
     forward_ = FunctionBase::funcRegistrar_.createByType(
         FUNC_NAME(CrossMapNormal, CPU));
+    backward_ = FunctionBase::funcRegistrar_.createByType(
+        FUNC_NAME(CrossMapNormalGrad, CPU));
   }
   forward_->init(
       FuncConfig().set("size", size_).set("scale", scale_).set("pow", pow_));
 
-  if (useGpu_) {
-    backward_ = FunctionBase::funcRegistrar_.createByType(
-        FUNC_NAME(CrossMapNormalGrad, GPU));
-  } else {
-    backward_ = FunctionBase::funcRegistrar_.createByType(
-        FUNC_NAME(CrossMapNormalGrad, CPU));
-  }
   backward_->init(
       FuncConfig().set("size", size_).set("scale", scale_).set("pow", pow_));
 
@@ -74,7 +71,7 @@ void CMRProjectionNormLayer::forward(PassType passType) {
   /* malloc memory for the output_ if necessary */
   /* note: one sample correspond to one row */
   MatrixPtr input = inputLayers_[0]->getOutputValue();
-  int batchSize = input->getHeight();
+  size_t batchSize = input->getHeight();
   int size = getSize();
   resetOutput(batchSize, size);
 
@@ -82,10 +79,7 @@ void CMRProjectionNormLayer::forward(PassType passType) {
 
   Matrix::resizeOrCreate(denoms_, batchSize, size, /* trans */ false, useGpu_);
 
-  dims_ = {(size_t)batchSize,
-           (size_t)channels_,
-           (size_t)imgSizeH_,
-           (size_t)imgSizeW_};
+  dims_ = {batchSize, channels_, imgSizeH_, imgSizeW_};
   forward_->calc(
       {Tensor(input->getData(), dims_)},
       {Tensor(outV->getData(), dims_), Tensor(denoms_->getData(), dims_)},
diff --git a/paddle/gserver/layers/NormProjectionLayer.h b/paddle/gserver/layers/NormProjectionLayer.h
index 3c4876ece609e..6b2c5dde0d74d 100644
--- a/paddle/gserver/layers/NormProjectionLayer.h
+++ b/paddle/gserver/layers/NormProjectionLayer.h
@@ -16,7 +16,6 @@ limitations under the License.
 */
 
 #include <vector>
 #include "NormLayer.h"
-#include "paddle/function/Function.h"
 #include "paddle/math/Matrix.h"
 
 namespace paddle {
 
@@ -43,7 +42,5 @@ class CMRProjectionNormLayer : public ResponseNormLayer {
 
 protected:
   Dims dims_;
-  FunctionBase* forward_;
-  FunctionBase* backward_;
 };
 }  // namespace paddle

From 148bd4d0b3240d31c1c96ddac89ffd4935f71b03 Mon Sep 17 00:00:00 2001
From: hedaoyuan
Date: Mon, 19 Dec 2016 15:04:48 +0800
Subject: [PATCH 13/15] add Layer::createFunction

---
 paddle/gserver/layers/Layer.h                 | 24 +++++++++++--
 paddle/gserver/layers/NormProjectionLayer.cpp | 34 +++++++------------
 2 files changed, 35 insertions(+), 23 deletions(-)

diff --git a/paddle/gserver/layers/Layer.h b/paddle/gserver/layers/Layer.h
index 16f66a2205f49..6dfd48fb96618 100644
--- a/paddle/gserver/layers/Layer.h
+++ b/paddle/gserver/layers/Layer.h
@@ -102,9 +102,9 @@ class Layer {
   std::vector<bool> markInBackward_;
 
   /// Layer forward function
-  FunctionBase* forward_;
+  std::vector<std::shared_ptr<FunctionBase>> forward_;
   /// Layer backward function
-  FunctionBase* backward_;
+  std::vector<std::shared_ptr<FunctionBase>> backward_;
 
 public:
   /**
@@ -132,6 +132,26 @@ class Layer {
   virtual void markAllInputGrad();
 
 protected:
+  /**
+   * Create layer function. Function is called in forward or backward.
+   * \param function, Layer::forward_ or Layer::backward_
+   * \param name, function name
+   * \param config, initialization configuration for the function
+   */
+  void createFunction(std::vector<std::shared_ptr<FunctionBase>>& function,
+                      const std::string& name,
+                      const FuncConfig& config) {
+    if (useGpu_) {
+      function.emplace_back(
+          FunctionBase::funcRegistrar_.createByType(name + "-GPU"));
+    } else {
+      function.emplace_back(
+          FunctionBase::funcRegistrar_.createByType(name + "-CPU"));
+    }
+    auto& func = function.back();
+    func->init(config);
+  }
+
   /**
    * Notify specified layer the output grad ready.
    * Called in the backward function.
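[Note: with createFunction, the per-device branching moves out of every
layer. A minimal sketch of the intended call pattern from a layer's init(),
which is exactly what the NormProjectionLayer.cpp diff below switches to:

    createFunction(
        forward_,
        "CrossMapNormal",
        FuncConfig().set("size", size_).set("scale", scale_).set("pow", pow_));

    // later, e.g. in forward():
    forward_[0]->calc(inputs, outputs, inouts);

createFunction appends "-CPU" or "-GPU" to the name based on useGpu_, so the
string must match the type name used in REGISTER_TYPED_FUNC.]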
diff --git a/paddle/gserver/layers/NormProjectionLayer.cpp b/paddle/gserver/layers/NormProjectionLayer.cpp index 0f6f9b91d0578..262d757c67e10 100644 --- a/paddle/gserver/layers/NormProjectionLayer.cpp +++ b/paddle/gserver/layers/NormProjectionLayer.cpp @@ -45,21 +45,13 @@ bool CMRProjectionNormLayer::init(const LayerMap& layerMap, /* the size of inputs for norm-layer is 1 */ CHECK_EQ(config_.inputs_size(), 1); - if (useGpu_) { - forward_ = FunctionBase::funcRegistrar_.createByType( - FUNC_NAME(CrossMapNormal, GPU)); - backward_ = FunctionBase::funcRegistrar_.createByType( - FUNC_NAME(CrossMapNormalGrad, GPU)); - } else { - forward_ = FunctionBase::funcRegistrar_.createByType( - FUNC_NAME(CrossMapNormal, CPU)); - backward_ = FunctionBase::funcRegistrar_.createByType( - FUNC_NAME(CrossMapNormalGrad, CPU)); - } - forward_->init( + createFunction( + forward_, + "CrossMapNormal", FuncConfig().set("size", size_).set("scale", scale_).set("pow", pow_)); - - backward_->init( + createFunction( + backward_, + "CrossMapNormalGrad", FuncConfig().set("size", size_).set("scale", scale_).set("pow", pow_)); return true; @@ -80,7 +72,7 @@ void CMRProjectionNormLayer::forward(PassType passType) { Matrix::resizeOrCreate(denoms_, batchSize, size, /* trans */ false, useGpu_); dims_ = {batchSize, channels_, imgSizeH_, imgSizeW_}; - forward_->calc( + forward_[0]->calc( {Tensor(input->getData(), dims_)}, {Tensor(outV->getData(), dims_), Tensor(denoms_->getData(), dims_)}, {}); @@ -98,11 +90,11 @@ void CMRProjectionNormLayer::backward(const UpdateCallback& callback) { MatrixPtr localOutV = getOutputValue(); MatrixPtr preOutV = inputLayers_[0]->getOutputValue(); - backward_->calc({Tensor(preOutV->getData(), dims_), - Tensor(localOutV->getData(), dims_), - Tensor(localGrad->getData(), dims_), - Tensor(denoms_->getData(), dims_)}, - {Tensor(preOutGrad->getData(), dims_)}, - {}); + backward_[0]->calc({Tensor(preOutV->getData(), dims_), + Tensor(localOutV->getData(), dims_), + Tensor(localGrad->getData(), dims_), + Tensor(denoms_->getData(), dims_)}, + {Tensor(preOutGrad->getData(), dims_)}, + {}); } } // namespace paddle From 5fddd99e18f3920ff0d8158fd4a9800d5566943e Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Tue, 20 Dec 2016 17:20:22 +0800 Subject: [PATCH 14/15] move TEST from test_matrixCompare.cpp to cross_map_normal_op_test.cpp --- cmake/util.cmake | 1 + paddle/function/CMakeLists.txt | 35 +++-- paddle/function/FunctionTest.h | 102 +++++++++++++ paddle/function/TestMain.cpp | 22 +++ paddle/function/cross_map_normal_op_test.cpp | 71 +++++++++ paddle/math/tests/test_matrixCompare.cpp | 144 ------------------- 6 files changed, 221 insertions(+), 154 deletions(-) create mode 100644 paddle/function/FunctionTest.h create mode 100644 paddle/function/TestMain.cpp create mode 100644 paddle/function/cross_map_normal_op_test.cpp diff --git a/cmake/util.cmake b/cmake/util.cmake index 03734e7839d74..8a71b23c62d9f 100644 --- a/cmake/util.cmake +++ b/cmake/util.cmake @@ -107,6 +107,7 @@ function(link_paddle_exe TARGET_NAME) paddle_parameter paddle_proto paddle_cuda + paddle_test_main ${METRIC_LIBS} ${PROTOBUF_LIBRARY} ${LIBGLOG_LIBRARY} diff --git a/paddle/function/CMakeLists.txt b/paddle/function/CMakeLists.txt index 8fad0e3ebdfb2..0697842bbef62 100644 --- a/paddle/function/CMakeLists.txt +++ b/paddle/function/CMakeLists.txt @@ -1,12 +1,27 @@ -file(GLOB FUNCTION_HEADERS . *.h) - -if(NOT WITH_GPU) - file(GLOB FUNCTION_SOURCES . 
*.cpp)
-  add_library(paddle_function STATIC ${FUNCTION_SOURCES})
-else()
-  file(GLOB FUNCTION_SOURCES . *.cpp *.cu)
-  cuda_add_library(paddle_function ${FUNCTION_SOURCES})
+file(GLOB h_files . *_op.h)
+file(GLOB cpp_files . *_op.cpp)
+
+list(APPEND h_files Function.h)
+list(APPEND cpp_files Function.cpp)
+
+if(WITH_GPU)
+  file(GLOB cu_files . *_op_gpu.cu)
+  cuda_compile(cu_objs ${cu_files})
 endif()
-add_style_check_target(paddle_function ${FUNCTION_SOURCES})
-add_style_check_target(paddle_function ${FUNCTION_HEADERS})
+add_library(paddle_function STATIC ${cpp_files} ${cu_objs})
+
+add_library(paddle_test_main STATIC TestMain.cpp)
+
+if(WITH_GPU)
+  # TODO:
+  # file(GLOB test_files . *_op_test.cpp)
+  # add_executable(${test_bin} EXCLUDE_FROM_ALL ${test_files})
+  add_simple_unittest(cross_map_normal_op_test)
+endif()
+
+add_style_check_target(paddle_function ${h_files})
+add_style_check_target(paddle_function ${cpp_files})
+if(WITH_GPU)
+  add_style_check_target(paddle_function ${cu_files})
+endif()
diff --git a/paddle/function/FunctionTest.h b/paddle/function/FunctionTest.h
new file mode 100644
index 0000000000000..a8c5e412bd12d
--- /dev/null
+++ b/paddle/function/FunctionTest.h
@@ -0,0 +1,102 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "Function.h"
+#include "paddle/math/Vector.h"
+#include "paddle/math/tests/TensorCheck.h"
+
+namespace paddle {
+
+class FunctionCompare {
+public:
+  FunctionCompare(const std::string& name, const FuncConfig& config)
+      : cpu(FunctionBase::funcRegistrar_.createByType(name + "-CPU")),
+        gpu(FunctionBase::funcRegistrar_.createByType(name + "-GPU")) {
+    cpu->init(config);
+    gpu->init(config);
+  }
+
+  void cmpWithArg(const Arguments& inputs,
+                  const Arguments& outputs,
+                  const Arguments& inouts) {
+    // init cpu and gpu arguments
+    auto initArgs = [=](
+        Arguments& cpuArgs, Arguments& gpuArgs, const Arguments& inArgs) {
+      for (auto arg : inArgs) {
+        size_t size = sizeof(real);
+        for (auto dim : arg.dims_) {
+          size *= dim;
+        }
+        cpuMemory.emplace_back(std::make_shared<CpuMemoryHandle>(size));
+        gpuMemory.emplace_back(std::make_shared<GpuMemoryHandle>(size));
+        cpuArgs.emplace_back(
+            Tensor((real*)cpuMemory.back()->getBuf(), arg.dims_));
+        gpuArgs.emplace_back(
+            Tensor((real*)gpuMemory.back()->getBuf(), arg.dims_));
+
+        // Will use an API to refactor this code.
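+        // [Editor's note, added comment -- not in the original patch] The two
+        // Vector views below alias the buffers just wrapped by the Tensors:
+        // uniform() randomizes the CPU copy and copyFrom() mirrors it to the
+        // GPU, so both Function implementations receive identical inputs.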
+        CpuVector cpuVector(size / sizeof(real),
+                            (real*)cpuArgs.back().getData());
+        GpuVector gpuVector(size / sizeof(real),
+                            (real*)gpuArgs.back().getData());
+        cpuVector.uniform(0.001, 1);
+        gpuVector.copyFrom(cpuVector);
+      }
+    };
+    initArgs(cpuInputs, gpuInputs, inputs);
+    initArgs(cpuOutputs, gpuOutputs, outputs);
+    initArgs(cpuInouts, gpuInouts, inouts);
+
+    // function calculate
+    cpu->calc(cpuInputs, cpuOutputs, cpuInouts);
+    gpu->calc(gpuInputs, gpuOutputs, gpuInouts);
+
+    // check outputs and inouts
+    auto checkArgs = [=](const Arguments& cpuArgs, const Arguments& gpuArgs) {
+      for (size_t i = 0; i < cpuArgs.size(); i++) {
+        auto cpu = cpuArgs[i];
+        auto gpu = gpuArgs[i];
+        size_t size = 1;
+        for (auto dim : cpu.dims_) {
+          size *= dim;
+        }
+        CpuVector cpuVector(size, (real*)cpu.getData());
+        GpuVector gpuVector(size, (real*)gpu.getData());
+
+        autotest::TensorCheckErr(cpuVector, gpuVector);
+      }
+    };
+    checkArgs(cpuOutputs, gpuOutputs);
+    checkArgs(cpuInouts, gpuInouts);
+  }
+
+protected:
+  std::shared_ptr<FunctionBase> cpu;
+  std::shared_ptr<FunctionBase> gpu;
+  std::vector<CpuMemHandlePtr> cpuMemory;
+  std::vector<GpuMemHandlePtr> gpuMemory;
+  Arguments cpuInputs;
+  Arguments cpuOutputs;
+  Arguments cpuInouts;
+  Arguments gpuInputs;
+  Arguments gpuOutputs;
+  Arguments gpuInouts;
+};
+
+}  // namespace paddle
+
+using paddle::FunctionCompare;
+using paddle::FuncConfig;
+using paddle::Dims;
+using paddle::Tensor;
diff --git a/paddle/function/TestMain.cpp b/paddle/function/TestMain.cpp
new file mode 100644
index 0000000000000..3e14532d1878f
--- /dev/null
+++ b/paddle/function/TestMain.cpp
@@ -0,0 +1,22 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <gtest/gtest.h>
+#include "paddle/utils/Util.h"
+
+int main(int argc, char** argv) {
+  testing::InitGoogleTest(&argc, argv);
+  paddle::initMain(argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/paddle/function/cross_map_normal_op_test.cpp b/paddle/function/cross_map_normal_op_test.cpp
new file mode 100644
index 0000000000000..22692691bdb64
--- /dev/null
+++ b/paddle/function/cross_map_normal_op_test.cpp
@@ -0,0 +1,71 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <gtest/gtest.h>
+#include "FunctionTest.h"
+
+TEST(CrossMapNormal, real) {
+  for (size_t numSamples : {5, 32}) {
+    for (size_t channels : {1, 5, 32}) {
+      for (size_t imgSizeH : {5, 33, 100}) {
+        for (size_t imgSizeW : {5, 32, 96}) {
+          for (size_t size : {1, 2, 3, 5, 7}) {
+            VLOG(3) << " numSamples=" << numSamples << " channels=" << channels
+                    << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW
+                    << " size=" << size;
+
+            FunctionCompare compare("CrossMapNormal",
+                                    FuncConfig()
+                                        .set("size", size)
+                                        .set("scale", (real)1.5)
+                                        .set("pow", (real)0.5));
+            Dims dims{numSamples, channels, imgSizeH, imgSizeW};
+            compare.cmpWithArg({Tensor(nullptr, dims)},
+                               {Tensor(nullptr, dims), Tensor(nullptr, dims)},
+                               {});
+          }
+        }
+      }
+    }
+  }
+}
+
+TEST(CrossMapNormalGrad, real) {
+  for (size_t numSamples : {5, 32}) {
+    for (size_t channels : {1, 5, 32}) {
+      for (size_t imgSizeH : {5, 33, 100}) {
+        for (size_t imgSizeW : {5, 32, 96}) {
+          for (size_t size : {1, 2, 3, 5, 7}) {
+            VLOG(3) << " numSamples=" << numSamples << " channels=" << channels
+                    << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW
+                    << " size=" << size;
+
+            FunctionCompare compare("CrossMapNormalGrad",
+                                    FuncConfig()
+                                        .set("size", size)
+                                        .set("scale", (real)1.5)
+                                        .set("pow", (real)0.5));
+            Dims dims{numSamples, channels, imgSizeH, imgSizeW};
+            compare.cmpWithArg({Tensor(nullptr, dims),
+                                Tensor(nullptr, dims),
+                                Tensor(nullptr, dims),
+                                Tensor(nullptr, dims)},
+                               {Tensor(nullptr, dims)},
+                               {});
+          }
+        }
+      }
+    }
+  }
+}
diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp
index c89b7ff490232..440534e722700 100644
--- a/paddle/math/tests/test_matrixCompare.cpp
+++ b/paddle/math/tests/test_matrixCompare.cpp
@@ -1263,150 +1263,6 @@ TEST(Matrix, MaxOutFwdBwd) {
   }
 }

-void testCrossMapNormalFwd(
-    int numSamples, int channels, int imgSizeH, int imgSizeW, int sizeX) {
-  float scale = 1.5;
-  float pow = 0.5;
-  int width = imgSizeH * imgSizeW * channels;
-  CpuMatrix inputs(numSamples, width);
-  CpuMatrix denoms(numSamples, width);
-  CpuMatrix outputs(numSamples, width);
-  GpuMatrix inputsGpu(numSamples, width);
-  GpuMatrix denomsGpu(numSamples, width);
-  GpuMatrix outputsGpu(numSamples, width);
-
-  inputs.randomizeUniform();
-  outputs.randomizeUniform();
-  inputsGpu.copyFrom(inputs);
-  outputsGpu.copyFrom(outputs);
-
-  FunctionBase* cpu =
-      FunctionBase::funcRegistrar_.createByType(FUNC_NAME(CrossMapNormal, CPU));
-  FunctionBase* gpu =
-      FunctionBase::funcRegistrar_.createByType(FUNC_NAME(CrossMapNormal, GPU));
-  cpu->init(FuncConfig()
-                .set("size", (size_t)sizeX)
-                .set("scale", scale)
-                .set("pow", pow));
-  gpu->init(FuncConfig()
-                .set("size", (size_t)sizeX)
-                .set("scale", scale)
-                .set("pow", pow));
-
-  Dims dims{
-      (size_t)numSamples, (size_t)channels, (size_t)imgSizeH, (size_t)imgSizeW};
-  cpu->calc({Tensor(inputs.getData(), dims)},
-            {Tensor(outputs.getData(), dims), Tensor(denoms.getData(), dims)},
-            {});
-
-  gpu->calc(
-      {Tensor(inputsGpu.getData(), dims)},
-      {Tensor(outputsGpu.getData(), dims), Tensor(denomsGpu.getData(), dims)},
-      {});
-
-  TensorCheckErr(outputs, outputsGpu);
-  TensorCheckErr(denoms, denomsGpu);
-}
-
-TEST(Matrix, crossMapNormalFwd) {
-  for (auto numSamples : {5, 32}) {
-    for (auto channels : {1, 5, 32}) {
-      for (auto imgSizeH : {5, 33, 100}) {
-        for (auto imgSizeW : {5, 32, 96}) {
-          for (auto sizeX : {1, 2, 3, 5, 7}) {
-            VLOG(3) << " numSamples=" << numSamples << " channels=" << channels
-                    << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW
-                    << " sizeX=" << sizeX;
-            testCrossMapNormalFwd(
-                numSamples, channels, imgSizeH, imgSizeW, sizeX);
-          }
-        }
-      }
-    }
-  }
-}
-
-void testCrossMapNormalBwd(
-    int numSamples, int channels, int imgSizeH, int imgSizeW, int sizeX) {
-  float scale = 1.5;
-  float pow = 0.5;
-  size_t width = imgSizeH * imgSizeW * channels;
-
-  CpuMatrix inputsGrad(numSamples, width);
-  CpuMatrix inputsValue(numSamples, width);
-  CpuMatrix outputsGrad(numSamples, width);
-  CpuMatrix outputsValue(numSamples, width);
-  CpuMatrix denoms(numSamples, width);
-
-  outputsGrad.randomizeUniform();
-  denoms.randomizeUniform();
-  inputsValue.randomizeUniform();
-  outputsValue.randomizeUniform();
-  inputsGrad.randomizeUniform();
-  denoms.add(0.01);
-
-  GpuMatrix inputsGradGpu(numSamples, width);
-  GpuMatrix inputsValueGpu(numSamples, width);
-  GpuMatrix outputsGradGpu(numSamples, width);
-  GpuMatrix outputsValueGpu(numSamples, width);
-  GpuMatrix denomsGpu(numSamples, width);
-
-  outputsGradGpu.copyFrom(outputsGrad);
-  denomsGpu.copyFrom(denoms);
-  inputsValueGpu.copyFrom(inputsValue);
-  outputsValueGpu.copyFrom(outputsValue);
-  inputsGradGpu.copyFrom(inputsGrad);
-
-  FunctionBase* cpu = FunctionBase::funcRegistrar_.createByType(
-      FUNC_NAME(CrossMapNormalGrad, CPU));
-  FunctionBase* gpu = FunctionBase::funcRegistrar_.createByType(
-      FUNC_NAME(CrossMapNormalGrad, GPU));
-  cpu->init(FuncConfig()
-                .set("size", (size_t)sizeX)
-                .set("scale", scale)
-                .set("pow", pow));
-  gpu->init(FuncConfig()
-                .set("size", (size_t)sizeX)
-                .set("scale", scale)
-                .set("pow", pow));
-
-  Dims dims{
-      (size_t)numSamples, (size_t)channels, (size_t)imgSizeH, (size_t)imgSizeW};
-  cpu->calc({Tensor(inputsValue.getData(), dims),
-             Tensor(outputsValue.getData(), dims),
-             Tensor(outputsGrad.getData(), dims),
-             Tensor(denoms.getData(), dims)},
-            {Tensor(inputsGrad.getData(), dims)},
-            {});
-
-  gpu->calc({Tensor(inputsValueGpu.getData(), dims),
-             Tensor(outputsValueGpu.getData(), dims),
-             Tensor(outputsGradGpu.getData(), dims),
-             Tensor(denomsGpu.getData(), dims)},
-            {Tensor(inputsGradGpu.getData(), dims)},
-            {});
-
-  TensorCheckErr(inputsGrad, inputsGradGpu);
-}
-
-TEST(Matrix, crossMapNormalBwd) {
-  for (auto numSamples : {5, 32}) {
-    for (auto channels : {1, 5, 32}) {
-      for (auto imgSizeH : {5, 33, 100}) {
-        for (auto imgSizeW : {5, 32, 96}) {
-          for (auto sizeX : {1, 2, 3, 5, 7}) {
-            VLOG(3) << " numSamples=" << numSamples << " channels=" << channels
-                    << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW
-                    << " sizeX=" << sizeX;
-            testCrossMapNormalBwd(
-                numSamples, channels, imgSizeH, imgSizeW, sizeX);
-          }
-        }
-      }
-    }
-  }
-}
-
 int main(int argc, char** argv) {
   testing::InitGoogleTest(&argc, argv);
   initMain(argc, argv);

From f1a94e3ff7fce800f6c846da2ae6ad4312c4acfc Mon Sep 17 00:00:00 2001
From: hedaoyuan
Date: Tue, 20 Dec 2016 20:30:13 +0800
Subject: [PATCH 15/15] follow comments

---
 paddle/function/cross_map_normal_op.cpp    | 22 +++++++++++-----------
 paddle/function/cross_map_normal_op.h      | 10 +++++-----
 paddle/function/cross_map_normal_op_gpu.cu | 10 +++++-----
 paddle/math/tests/test_matrixCompare.cpp   |  1 -
 4 files changed, 21 insertions(+), 22 deletions(-)

diff --git a/paddle/function/cross_map_normal_op.cpp b/paddle/function/cross_map_normal_op.cpp
index a18c0bb750acf..a9c7693830542 100644
--- a/paddle/function/cross_map_normal_op.cpp
+++ b/paddle/function/cross_map_normal_op.cpp
@@ -20,7 +20,7 @@ namespace paddle {
 template <>
 void CrossMapNormal<DEVICE_TYPE_CPU>(real* outputs,
                                      real* denoms,
-                                     real* inputs,
+                                     const real* inputs,
                                      size_t numSamples,
                                      size_t channels,
                                      size_t height,
@@ -32,7 +32,7 @@ void CrossMapNormal<DEVICE_TYPE_CPU>(real* outputs,
   size_t oneSample = channels * oneImage;

   CpuVector outputsV(numSamples * oneSample, outputs);
-  CpuVector inputsV(numSamples * oneSample, inputs);
+  CpuVector inputsV(numSamples * oneSample, const_cast<real*>(inputs));
   CpuVector denomsV(numSamples * oneSample, denoms);

   // f(x) = x * ( 1 + scale * SUM((x)^2) )^(-pow)
@@ -44,7 +44,7 @@ void CrossMapNormal<DEVICE_TYPE_CPU>(real* outputs,
   const int end = (int)size + start;
   for (size_t i = 0; i < numSamples; i++) {
     real* oneDenom = denoms + i * oneSample;
-    real* oneInput = inputs + i * oneSample;
+    real* oneInput = const_cast<real*>(inputs) + i * oneSample;
     for (int c = 0; c < (int)channels; c++) {
       CpuVector denom(oneImage, oneDenom + c * oneImage);
       for (int s = start; s < end; s++) {
@@ -61,10 +61,10 @@ void CrossMapNormal<DEVICE_TYPE_CPU>(real* outputs,

 template <>
 void CrossMapNormalGrad<DEVICE_TYPE_CPU>(real* inputsGrad,
-                                         real* inputsValue,
-                                         real* outputsValue,
-                                         real* outputsGrad,
-                                         real* denoms,
+                                         const real* inputsValue,
+                                         const real* outputsValue,
+                                         const real* outputsGrad,
+                                         const real* denoms,
                                          size_t numSamples,
                                          size_t channels,
                                          size_t height,
@@ -84,10 +84,10 @@ void CrossMapNormalGrad<DEVICE_TYPE_CPU>(real* inputsGrad,
   for (size_t i = 0; i < numSamples; i++) {
     size_t sOffset = i * oneSample;
     real* oneInputGrad = inputsGrad + sOffset;
-    real* oneInputValue = inputsValue + sOffset;
-    real* oneDenom = denoms + sOffset;
-    real* oneOutputGrad = outputsGrad + sOffset;
-    real* oneOutputValue = outputsValue + sOffset;
+    real* oneInputValue = const_cast<real*>(inputsValue) + sOffset;
+    real* oneDenom = const_cast<real*>(denoms) + sOffset;
+    real* oneOutputGrad = const_cast<real*>(outputsGrad) + sOffset;
+    real* oneOutputValue = const_cast<real*>(outputsValue) + sOffset;

     for (int c = 0; c < (int)channels; c++) {
       size_t cOffset = c * height * width;
diff --git a/paddle/function/cross_map_normal_op.h b/paddle/function/cross_map_normal_op.h
index e935b26e125d3..b1e401ad0a2f5 100644
--- a/paddle/function/cross_map_normal_op.h
+++ b/paddle/function/cross_map_normal_op.h
@@ -37,7 +37,7 @@ namespace paddle {
 template <DeviceType Device>
 void CrossMapNormal(real* outputs,
                     real* denoms,
-                    real* inputs,
+                    const real* inputs,
                     size_t numSamples,
                     size_t channels,
                     size_t height,
@@ -66,10 +66,10 @@ void CrossMapNormal(real* outputs,
  */
 template <DeviceType Device>
 void CrossMapNormalGrad(real* inputsGrad,
-                        real* inputsValue,
-                        real* outputsValue,
-                        real* outputsGrad,
-                        real* denoms,
+                        const real* inputsValue,
+                        const real* outputsValue,
+                        const real* outputsGrad,
+                        const real* denoms,
                         size_t numSamples,
                         size_t channels,
                         size_t height,
diff --git a/paddle/function/cross_map_normal_op_gpu.cu b/paddle/function/cross_map_normal_op_gpu.cu
index 6339c04194834..aae4f461b6f57 100644
--- a/paddle/function/cross_map_normal_op_gpu.cu
+++ b/paddle/function/cross_map_normal_op_gpu.cu
@@ -63,7 +63,7 @@ __global__ void KeCMRNormOutput(size_t inputSize, const real* in,
 template <>
 void CrossMapNormal<DEVICE_TYPE_GPU>(real* outputs,
                                      real* denoms,
-                                     real* inputs,
+                                     const real* inputs,
                                      size_t numSamples,
                                      size_t channels,
                                      size_t height,
@@ -132,10 +132,10 @@ __global__ void KeCMRNormDiff(size_t imageSize, const real* bottom_data,
 template <>
 void CrossMapNormalGrad<DEVICE_TYPE_GPU>(real* inputsGrad,
-                                         real* inputsValue,
-                                         real* outputsValue,
-                                         real* outputsGrad,
-                                         real* denoms,
+                                         const real* inputsValue,
+                                         const real* outputsValue,
+                                         const real* outputsGrad,
+                                         const real* denoms,
                                          size_t numSamples,
                                          size_t channels,
                                          size_t height,
diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp
index 440534e722700..62de5b25e4cc8 100644
--- a/paddle/math/tests/test_matrixCompare.cpp
+++ b/paddle/math/tests/test_matrixCompare.cpp
@@ -18,7 +18,6 @@ limitations under the License. */

 #include <gtest/gtest.h>
 #include "TensorCheck.h"
-#include "paddle/function/Function.h"
 #include "paddle/gserver/tests/TestUtil.h"
 #include "paddle/math/Matrix.h"
 #include "paddle/math/SparseMatrix.h"
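[Editor's note] For readers who want a ground truth to compare against, below
is a minimal single-sample reference for the forward computation documented in
cross_map_normal_op.cpp above, f(x) = x * (1 + scale * SUM(x^2))^(-pow), where
the sum runs over a window of `size` neighboring channels with a leading pad of
(size - 1) / 2. This sketch is an editorial addition, not part of the patch
series, and deliberately trades speed for clarity:

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Naive cross-map normalization for one sample laid out as
    // [channels][height * width]; out must not alias in.
    void crossMapNormalRef(const std::vector<float>& in, std::vector<float>& out,
                           size_t channels, size_t height, size_t width,
                           size_t size, float scale, float powArg) {
      const size_t plane = height * width;
      const int prePad = (int)(size - 1) / 2;
      out.assign(in.size(), 0.0f);
      for (size_t c = 0; c < channels; ++c) {
        for (size_t i = 0; i < plane; ++i) {
          // Sum of squares over the cross-channel window around channel c,
          // clipped to valid channels.
          float accum = 0.0f;
          for (int cc = (int)c - prePad; cc < (int)c - prePad + (int)size; ++cc) {
            if (cc >= 0 && cc < (int)channels) {
              const float v = in[cc * plane + i];
              accum += v * v;
            }
          }
          out[c * plane + i] =
              in[c * plane + i] * std::pow(1.0f + scale * accum, -powArg);
        }
      }
    }

A harness in the spirit of FunctionCompare could pit both device
implementations against this reference in addition to comparing them with
each other.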