From 529f24c262850974dd8ba4c5b7ad1a4e3e0230fc Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Mon, 12 Dec 2016 18:17:27 +0800 Subject: [PATCH 01/55] cpu cmrnorm --- paddle/cuda/src/hl_cuda_cnn.cu | 192 +++++++++-------------- paddle/gserver/tests/test_LayerGrad.cpp | 3 +- paddle/math/Matrix.cpp | 137 ++++++++++------ paddle/math/tests/test_matrixCompare.cpp | 115 ++++++++++++++ 4 files changed, 279 insertions(+), 168 deletions(-) diff --git a/paddle/cuda/src/hl_cuda_cnn.cu b/paddle/cuda/src/hl_cuda_cnn.cu index 0992286f360fb..1516accaae17f 100644 --- a/paddle/cuda/src/hl_cuda_cnn.cu +++ b/paddle/cuda/src/hl_cuda_cnn.cu @@ -381,57 +381,45 @@ void hl_avgpool_backward(const int frameCnt, const real* outGrad, CHECK_SYNC("hl_avgpool_backward failed"); } -__global__ void KeCMRNormFillScale(size_t nthreads, const real* in, +__global__ void KeCMRNormFillScale(size_t imageSize, const real* in, real* scale, size_t channels, size_t height, size_t width, size_t size, real alpha) { - size_t index = threadIdx.x + blockIdx.x * blockDim.x; - if (index < nthreads) { - // find out the local offset - size_t w = index % width; - size_t h = (index / width) % height; - size_t n = index / width / height; - size_t offset = (n * channels * height + h) * width + w; - size_t step = height * width; + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < imageSize) { + const int w = idx % width; + const int h = (idx / width) % height; + const int n = idx / width / height; + const int offset = (n * channels * height + h) * width + w; + in += offset; scale += offset; - size_t head = 0; - size_t pre_pad = (size - 1) / 2; - size_t post_pad = size - pre_pad - 1; - real accum_scale = 0; - // fill the scale at [n, :, h, w] - // accumulate values - while (head < post_pad) { - accum_scale += in[head * step] * in[head * step]; - ++head; - } - // until we reach size, nothing needs to be subtracted - while (head < size) { - accum_scale += in[head * step] * in[head * step]; - scale[(head - post_pad) * step] = 1. + accum_scale * alpha; - ++head; - } - // both add and subtract - while (head < channels) { - accum_scale += in[head * step] * in[head * step]; - accum_scale -= in[(head - size) * step] * in[(head - size) * step]; - scale[(head - post_pad) * step] = 1. + accum_scale * alpha; - ++head; - } - // subtract only - while (head < channels + post_pad) { - accum_scale -= in[(head - size) * step] * in[(head - size) * step]; - scale[(head - post_pad) * step] = 1. + accum_scale * alpha; - ++head; + const int step = height * width; + const int pre_pad = (size - 1) / 2; + const int post_pad = size - pre_pad - 1; + + real accum = 0; + int index = 0; + while (index < channels + post_pad) { + if (index < channels) { + accum += in[index * step] * in[index * step]; + } + if (index >= size) { + accum -= in[(index - size) * step] * in[(index - size) * step]; + } + if (index >= post_pad) { + scale[(index - post_pad) * step] = 1. 
+ accum * alpha; + } + ++index; } } } - __global__ void KeCMRNormOutput(size_t nthreads, const real* in, - const real* scale, real negative_beta, - real* out) { - size_t index = threadIdx.x + blockIdx.x * blockDim.x; - if (index < nthreads) { +__global__ void KeCMRNormOutput(size_t inputSize, const real* in, + const real* scale, real negative_beta, + real* out) { + const int index = threadIdx.x + blockIdx.x * blockDim.x; + if (index < inputSize) { out[index] = in[index] * pow(scale[index], negative_beta); } } @@ -440,84 +428,60 @@ void hl_CMRNorm_forward(size_t frameCnt, const real* in, real* scale, real* out, size_t channels, size_t height, size_t width, size_t sizeX, real alpha, real beta) { - size_t threadsNum = frameCnt * height * width; - size_t blocksX = (threadsNum + 1024 - 1) / 1024; - size_t blocksY = 1; - dim3 threads(1024, 1); - dim3 grid(blocksX, blocksY); - - KeCMRNormFillScale<<>> - (threadsNum, in, scale, channels, height, width, sizeX, alpha); - - threadsNum = frameCnt * height * width *channels; - blocksX = (threadsNum + 1024 -1) / 1024; - dim3 threads2(1024, 1); - dim3 grid2(blocksX, blocksY); - KeCMRNormOutput<<>> - (threadsNum, in, scale, beta, out); + size_t imageSize = frameCnt * height * width; + int blockSize = 1024; + int gridSize = (imageSize + 1024 - 1) / 1024; + KeCMRNormFillScale<<>> + (imageSize, in, scale, channels, height, width, sizeX, alpha); + + size_t inputSize = frameCnt * height * width *channels; + blockSize = 1024; + gridSize = (inputSize + 1024 - 1) / 1024; + KeCMRNormOutput<<>> + (inputSize, in, scale, beta, out); CHECK_SYNC("hl_CMRNorm_forward"); } -__global__ void KeCMRNormDiff(size_t nthreads, const real* bottom_data, +__global__ void KeCMRNormDiff(size_t imageSize, const real* bottom_data, const real* top_data, const real* scale, const real* top_diff, size_t channels, size_t height, size_t width, size_t size, real negative_beta, real cache_ratio, real* bottom_diff ) { - int index = threadIdx.x + blockIdx.x * blockDim.x; - if (index < nthreads) { - // find out the local offset - size_t w = index % width; - size_t h = (index / width) % height; - size_t n = index / width / height; - size_t offset = (n * channels * height + h) * width + w; - size_t step = height * width; + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < imageSize) { + const int w = idx % width; + const int h = (idx / width) % height; + const int n = idx / width / height; + const int offset = (n * channels * height + h) * width + w; bottom_data += offset; top_data += offset; scale += offset; top_diff += offset; bottom_diff += offset; - int head = 0; - int pre_pad = size - (size + 1) / 2; - int post_pad = size - pre_pad - 1; - real accum_ratio = 0; - // accumulate values - while (head < post_pad) { - accum_ratio += top_diff[head * step] * - top_data[head * step] / scale[head * step]; - ++head; - } - // until we reach size, nothing needs to be subtracted - while (head < size) { - accum_ratio += top_diff[head * step] * - top_data[head * step] / scale[head * step]; - bottom_diff[(head - post_pad) * step] += - top_diff[(head - post_pad) * step] * - pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio * - bottom_data[(head - post_pad) * step] * accum_ratio; - ++head; - } - // both add and subtract - while (head < channels) { - accum_ratio += top_diff[head * step] * top_data[head * step] / - scale[head * step]; - accum_ratio -= top_diff[(head - size) * step] * - top_data[(head - size) * step] / scale[(head - size) * step]; - bottom_diff[(head - post_pad) * 
step] += - top_diff[(head - post_pad) * step] * - pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio * - bottom_data[(head - post_pad) * step] * accum_ratio; - ++head; - } - // subtract only - while (head < channels + post_pad) { - accum_ratio -= top_diff[(head - size) * step] * - top_data[(head - size) * step] / scale[(head - size) * step]; - bottom_diff[(head - post_pad) * step] += - top_diff[(head - post_pad) * step] * - pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio * - bottom_data[(head - post_pad) * step] * accum_ratio; - ++head; + + const int step = height * width; + const int pre_pad = size - (size + 1) / 2; + const int post_pad = size - pre_pad - 1; + + int index = 0; + real accum = 0; + while (index < channels + post_pad) { + if (index < channels) { + accum += top_diff[index * step] * top_data[index * step] / + scale[index * step]; + } + if (index >= size) { + accum -= top_diff[(index - size) * step] * + top_data[(index - size) * step] / scale[(index - size) * step]; + } + if (index >= post_pad) { + bottom_diff[(index - post_pad) * step] += + top_diff[(index - post_pad) * step] * + pow(scale[(index - post_pad) * step], negative_beta) - cache_ratio * + bottom_data[(index - post_pad) * step] * accum; + } + ++index; } } } @@ -528,14 +492,12 @@ void hl_CMRNorm_backward(size_t frameCnt, const real* inV, real *inDiff, size_t channels, size_t height, size_t width, size_t sizeX, real alpha, real beta) { - size_t threadsNum = frameCnt * height * width; - size_t blocksX = (threadsNum + 1024 - 1) / 1024; - size_t blocksY = 1; - dim3 threads(1024, 1); - dim3 grid(blocksX, blocksY); - KeCMRNormDiff <<>> - (threadsNum, inV, outV, scale, outDiff, channels, - height, width, sizeX, alpha, beta, inDiff); + size_t imageSize = frameCnt * height * width; + int blockSize = 1024; + int gridSize = (imageSize + 1024 - 1) / 1024; + KeCMRNormDiff <<>> + (imageSize, inV, outV, scale, outDiff, channels, + height, width, sizeX, alpha, beta, inDiff); CHECK_SYNC("hl_CMRNorm_backward"); } diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp index 7983d9fe64c61..8ade15daac860 100644 --- a/paddle/gserver/tests/test_LayerGrad.cpp +++ b/paddle/gserver/tests/test_LayerGrad.cpp @@ -1021,11 +1021,10 @@ void testNormLayer(const string& normType, bool trans, bool useGpu) { testLayerGrad(config, "norm", 100, trans, useGpu); } -#ifndef PADDLE_ONLY_CPU TEST(Layer, NormLayer) { testNormLayer("cmrnorm-projection", /* trans= */ false, /* useGpu= */ true); + testNormLayer("cmrnorm-projection", /* trans= */ false, /* useGpu= */ false); } -#endif void setPoolConfig(TestConfig* config, PoolConfig* pool, diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp index c69e074a76399..2cde11dd479dc 100644 --- a/paddle/math/Matrix.cpp +++ b/paddle/math/Matrix.cpp @@ -2227,52 +2227,43 @@ void CpuMatrix::crossMapNormalFwd(Matrix& input, size_t sizeX, float scale, float pow) { - size_t num = input.getHeight(); + CHECK(isContiguous()); + CHECK(input.isContiguous()); + CHECK(denoms.isContiguous()); + CHECK_EQ(getHeight(), input.getHeight()); + CHECK_EQ(getWidth(), input.getWidth()); + CHECK_EQ(getHeight(), denoms.getHeight()); + CHECK_EQ(getWidth(), denoms.getWidth()); + + size_t numSample = input.getHeight(); + size_t numCols = input.getWidth(); size_t height = imgSizeH; size_t width = imgSizeW; - size_t numCols = input.getWidth(); - CHECK(height * width * channels == input.getWidth()); - CHECK(denoms.getHeight() == input.getHeight() && - denoms.getWidth() == 
input.getWidth() && input.getHeight() == height_ && - input.getWidth() == width_); - real* imgData = input.getData(); - real* diffData = input.getData(); - real* targetData = getData(); - size_t halfSize = sizeX / 2; - size_t imgPixels = height * width; - - // use integral vector to implement the sum in local window - real* integralData = - (real*)malloc((channels + sizeX + 1) * sizeof(real)); // NOLINT // TODO: - for (size_t i = 0; i <= halfSize; i++) { - integralData[i] = 0; - } - for (size_t i = 0; i < num; i++) { - real* targetPtr = targetData + i * numCols; - real* imgPtr = imgData + i * numCols; - real* diffPtr = diffData + i * numCols; - for (size_t m = 0; m < height; m++) { - for (size_t n = 0; n < width; n++) { - for (size_t c = 0; c < channels; c++) { - integralData[c + halfSize + 1] = - integralData[c + halfSize] + _square(*(diffPtr + c * imgPixels)); - } - for (size_t k = channels + halfSize + 1; k <= channels + sizeX; k++) { - integralData[k] = integralData[channels + halfSize]; + CHECK(height * width * channels == numCols); + + // TODO(hedaoyuan) After commit TensorExpress code, + // Reconstruction this code to remove the temporary memory. + CpuMatrix tmp(channels, height * width); + CpuMatrix tmp2(tmp.getData(), 1, channels * height * width); + denoms.zero(); + const int start = -((int)sizeX - 1) / 2; + const int end = (int)sizeX + start; + for (size_t i = 0; i < numSample; i++) { + input.subMatrix(i, 1)->square2(tmp2); + CpuMatrix subDen( + denoms.subMatrix(i, 1)->getData(), channels, height * width); + for (int c = 0; c < (int)channels; c++) { + for (int s = start; s < end; s++) { + if (c + s >= 0 && c + s < (int)channels) { + subDen.subMatrix(c, 1)->add(*tmp.subMatrix(c + s, 1)); } - for (size_t k = 0; k < channels; k += 1) { - real a = integralData[k + sizeX] - integralData[k]; - a = scale * a + 1; - targetPtr[k * imgPixels] = imgPtr[k * imgPixels] * _pow(a, -pow); - } - diffPtr++; - targetPtr++; - imgPtr++; } } } - free(integralData); - integralData = NULL; + + denoms.add(scale, (real)1); + this->pow2(denoms, -pow); + this->dotMul(input); } void CpuMatrix::crossMapNormalBwd(Matrix& localGrad, @@ -2282,19 +2273,63 @@ void CpuMatrix::crossMapNormalBwd(Matrix& localGrad, size_t channels, size_t imgSizeH, size_t imgSizeW, - size_t size, + size_t sizeX, float scale, float pow) { - LOG(FATAL) << "Not implemented"; - - CHECK(imgSizeH * imgSizeW * channels == preOutV.getWidth()); - CHECK(denoms.getHeight() == preOutV.getHeight() && - denoms.getWidth() == preOutV.getWidth() && - preOutV.getHeight() == height_ && preOutV.getWidth() == width_); - CHECK(denoms.getHeight() == localGrad.getHeight() && - denoms.getWidth() == localGrad.getWidth()); - - // NOLINT // TODO: + CHECK(isContiguous()); + CHECK(localGrad.isContiguous()); + CHECK(denoms.isContiguous()); + CHECK(preOutV.isContiguous()); + CHECK(localOutV.isContiguous()); + CHECK_EQ(getHeight(), localGrad.getHeight()); + CHECK_EQ(getWidth(), localGrad.getWidth()); + CHECK_EQ(getHeight(), denoms.getHeight()); + CHECK_EQ(getWidth(), denoms.getWidth()); + CHECK_EQ(getHeight(), preOutV.getHeight()); + CHECK_EQ(getWidth(), preOutV.getWidth()); + CHECK_EQ(getHeight(), localOutV.getHeight()); + CHECK_EQ(getWidth(), localOutV.getWidth()); + + size_t numSample = getHeight(); + size_t numCols = getWidth(); + size_t height = imgSizeH; + size_t width = imgSizeW; + CHECK(height * width * channels == numCols); + + // TODO(hedaoyuan) After commit TensorExpress code, + // Reconstruction this code to remove the temporary memory. 
+ CpuMatrix tmp(1, height * width); + + const int start = -((int)sizeX) / 2; + const int end = (int)sizeX + start; + const real ratio = -(real)2 * scale * pow; + for (size_t i = 0; i < numSample; i++) { + CpuMatrix inputDiff( + this->subMatrix(i, 1)->getData(), channels, height * width); + CpuMatrix outDiff( + localGrad.subMatrix(i, 1)->getData(), channels, height * width); + CpuMatrix input( + preOutV.subMatrix(i, 1)->getData(), channels, height * width); + CpuMatrix output( + localOutV.subMatrix(i, 1)->getData(), channels, height * width); + CpuMatrix subDen( + denoms.subMatrix(i, 1)->getData(), channels, height * width); + + for (int c = 0; c < (int)channels; c++) { + tmp.pow2(*subDen.subMatrix(c, 1), -pow); + inputDiff.subMatrix(c, 1) + ->addDotMul(tmp, *outDiff.subMatrix(c, 1), (real)1, (real)1); + for (int s = start; s < end; s++) { + if (c + s >= 0 && c + s < (int)channels) { + tmp.dotMul(*outDiff.subMatrix(c + s, 1), *output.subMatrix(c + s, 1)); + tmp.mulScalar(ratio); + tmp.dotDiv(tmp, *subDen.subMatrix(c + s, 1)); + tmp.dotMul(*input.subMatrix(c, 1)); + inputDiff.subMatrix(c, 1)->add(tmp); + } + } + } + } } /** diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index 713792d82b3c5..5233a9af40155 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -1261,6 +1261,121 @@ TEST(Matrix, MaxOutFwdBwd) { } } } +void testCrossMapNormalFwd( + int numSamples, int channels, int imgSizeH, int imgSizeW, int sizeX) { + float scale = 1.5; + float pow = 0.5; + int width = imgSizeH * imgSizeW * channels; + MatrixPtr input = CpuMatrix::create(numSamples, width, false, false); + MatrixPtr denorms = CpuMatrix::create(numSamples, width, false, false); + MatrixPtr target = CpuMatrix::create(numSamples, width, false, false); + MatrixPtr inputGpu = GpuMatrix::create(numSamples, width, false, true); + MatrixPtr denormsGpu = GpuMatrix::create(numSamples, width, false, true); + MatrixPtr targetGpu = GpuMatrix::create(numSamples, width, false, true); + + input->randomizeUniform(); + target->randomizeUniform(); + inputGpu->copyFrom(*input); + targetGpu->copyFrom(*target); + + target->crossMapNormalFwd( + *input, imgSizeH, imgSizeW, *denorms, channels, sizeX, scale, pow); + targetGpu->crossMapNormalFwd( + *inputGpu, imgSizeH, imgSizeW, *denormsGpu, channels, sizeX, scale, pow); + + TensorCheckErr(*target, *targetGpu); + TensorCheckErr(*denorms, *denormsGpu); +} + +TEST(Matrix, crossMapNormalFwd) { + for (auto numSamples : {5, 32}) { + for (auto channels : {1, 5, 32}) { + for (auto imgSizeH : {5, 33, 100}) { + for (auto imgSizeW : {5, 32, 96}) { + for (auto sizeX : {1, 2, 3, 5, 7}) { + VLOG(3) << " numSamples=" << numSamples << " channels=" << channels + << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW + << " sizeX=" << sizeX; + testCrossMapNormalFwd( + numSamples, channels, imgSizeH, imgSizeW, sizeX); + } + } + } + } + } +} + +void testCrossMapNormalBwd( + int numSamples, int channels, int imgSizeH, int imgSizeW, int sizeX) { + float scale = 1.5; + float pow = 0.5; + size_t width = imgSizeH * imgSizeW * channels; + MatrixPtr localGrad = CpuMatrix::create(numSamples, width, false, false); + MatrixPtr denoms = CpuMatrix::create(numSamples, width, false, false); + MatrixPtr output = CpuMatrix::create(numSamples, width, false, false); + MatrixPtr preOutV = CpuMatrix::create(numSamples, width, false, false); + MatrixPtr localOutV = CpuMatrix::create(numSamples, width, false, false); + + 
localGrad->randomizeUniform(); + denoms->randomizeUniform(); + preOutV->randomizeUniform(); + localOutV->randomizeUniform(); + output->randomizeUniform(); + denoms->add(0.01); + + MatrixPtr localGradGpu = GpuMatrix::create(numSamples, width, false, true); + MatrixPtr denomsGpu = GpuMatrix::create(numSamples, width, false, true); + MatrixPtr outputGpu = GpuMatrix::create(numSamples, width, false, true); + MatrixPtr preOutVGpu = GpuMatrix::create(numSamples, width, false, true); + MatrixPtr localOutVGpu = GpuMatrix::create(numSamples, width, false, true); + + localGradGpu->copyFrom(*localGrad); + denomsGpu->copyFrom(*denoms); + preOutVGpu->copyFrom(*preOutV); + localOutVGpu->copyFrom(*localOutV); + outputGpu->copyFrom(*output); + + output->crossMapNormalBwd(*localGrad, + *denoms, + *preOutV, + *localOutV, + channels, + imgSizeH, + imgSizeW, + sizeX, + scale, + pow); + outputGpu->crossMapNormalBwd(*localGradGpu, + *denomsGpu, + *preOutVGpu, + *localOutVGpu, + channels, + imgSizeH, + imgSizeW, + sizeX, + scale, + pow); + + TensorCheckErr(*output, *outputGpu); +} + +TEST(Matrix, crossMapNormalBwd) { + for (auto numSamples : {5, 32}) { + for (auto channels : {1, 5, 32}) { + for (auto imgSizeH : {5, 33, 100}) { + for (auto imgSizeW : {5, 32, 96}) { + for (auto sizeX : {1, 2, 3, 5, 7}) { + VLOG(3) << " numSamples=" << numSamples << " channels=" << channels + << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW + << " sizeX=" << sizeX; + testCrossMapNormalBwd( + numSamples, channels, imgSizeH, imgSizeW, sizeX); + } + } + } + } + } +} int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); From 95035908b4f47e61bad12d0ed49bf62a1734b2cf Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Tue, 13 Dec 2016 14:27:42 +0800 Subject: [PATCH 02/55] add CrossMapNormal --- paddle/math/cross_map_normal_op.cpp | 129 +++++++++++++++++++++ paddle/math/cross_map_normal_op.h | 47 ++++++++ paddle/math/tests/test_matrixCompare.cpp | 137 ++++++++++++----------- 3 files changed, 248 insertions(+), 65 deletions(-) create mode 100644 paddle/math/cross_map_normal_op.cpp create mode 100644 paddle/math/cross_map_normal_op.h diff --git a/paddle/math/cross_map_normal_op.cpp b/paddle/math/cross_map_normal_op.cpp new file mode 100644 index 0000000000000..3eb51b5998fc4 --- /dev/null +++ b/paddle/math/cross_map_normal_op.cpp @@ -0,0 +1,129 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "cross_map_normal_op.h" + +namespace paddle { + +// NCHW +void CrossMapNormal::operator()(CpuMatrix& outputs, + CpuMatrix& denoms, + CpuMatrix& inputs, + size_t channels, + size_t imgSizeH, + size_t imgSizeW, + size_t sizeX, + real scale, + real pow) { + CHECK(outputs.isContiguous()); + CHECK(inputs.isContiguous()); + CHECK(denoms.isContiguous()); + CHECK_EQ(outputs.getHeight(), inputs.getHeight()); + CHECK_EQ(outputs.getWidth(), inputs.getWidth()); + CHECK_EQ(outputs.getHeight(), denoms.getHeight()); + CHECK_EQ(outputs.getWidth(), denoms.getWidth()); + + size_t numSample = inputs.getHeight(); + size_t numCols = inputs.getWidth(); + size_t imageSize = imgSizeH * imgSizeW; + CHECK(imageSize * channels == numCols); + + denoms = denoms.constant(1.0); + const int start = -((int)sizeX - 1) / 2; + const int end = (int)sizeX + start; + for (size_t i = 0; i < numSample; i++) { + real* denomsData = denoms.getData() + i * numCols; + real* inputData = inputs.getData() + i * numCols; + for (int c = 0; c < (int)channels; c++) { + CpuVector denom(imageSize, denomsData + c * imageSize); + for (int s = start; s < end; s++) { + if (c + s >= 0 && c + s < (int)channels) { + CpuVector input(imageSize, inputData + (c + s) * imageSize); + denom += input.square() * scale; + } + } + } + } + outputs = inputs * denoms.pow(-pow); +} + +void CrossMapNormalGrad::operator()(CpuMatrix& inputsGrad, + CpuMatrix& inputsValue, + CpuMatrix& outputsGrad, + CpuMatrix& outputsValue, + CpuMatrix& denoms, + size_t channels, + size_t imgSizeH, + size_t imgSizeW, + size_t sizeX, + real scale, + real pow) { + CHECK(inputsGrad.isContiguous()); + CHECK(outputsGrad.isContiguous()); + CHECK(denoms.isContiguous()); + CHECK(inputsValue.isContiguous()); + CHECK(outputsValue.isContiguous()); + CHECK_EQ(inputsGrad.getHeight(), outputsGrad.getHeight()); + CHECK_EQ(inputsGrad.getWidth(), outputsGrad.getWidth()); + CHECK_EQ(inputsGrad.getHeight(), denoms.getHeight()); + CHECK_EQ(inputsGrad.getWidth(), denoms.getWidth()); + CHECK_EQ(inputsGrad.getHeight(), inputsValue.getHeight()); + CHECK_EQ(inputsGrad.getWidth(), inputsValue.getWidth()); + CHECK_EQ(inputsGrad.getHeight(), outputsValue.getHeight()); + CHECK_EQ(inputsGrad.getWidth(), outputsValue.getWidth()); + + size_t numSample = inputsGrad.getHeight(); + size_t numCols = inputsGrad.getWidth(); + size_t imageSize = imgSizeH * imgSizeW; + CHECK(imageSize * channels == numCols); + + std::function oneImage = [=](real* data, + size_t offset) { + return CpuVector(imageSize, data + offset); + }; + + const int start = -((int)sizeX) / 2; + const int end = (int)sizeX + start; + const real ratio = -(real)2 * scale * pow; + for (size_t i = 0; i < numSample; i++) { + size_t sOffset = i * numCols; + real* inputGradData = inputsGrad.getData() + sOffset; + real* inputData = inputsValue.getData() + sOffset; + real* denomData = denoms.getData() + sOffset; + real* outputGradData = outputsGrad.getData() + sOffset; + real* outputData = outputsValue.getData() + sOffset; + + for (int c = 0; c < (int)channels; c++) { + size_t cOffset = c * imageSize; + CpuVector inputGrad = oneImage(inputGradData, cOffset); + CpuVector inputValue = oneImage(inputData, cOffset); + CpuVector denom = oneImage(denomData, cOffset); + CpuVector outputGrad = oneImage(outputGradData, cOffset); + + inputGrad = inputGrad + denom.pow(-pow) * outputGrad; + for (int s = start; s < end; s++) { + if (c + s >= 0 && c + s < (int)channels) { + size_t offset = (c + s) * imageSize; + CpuVector output = oneImage(outputData, offset); + 
CpuVector outputGrad = oneImage(outputGradData, offset); + CpuVector denom = oneImage(denomData, offset); + + inputGrad += ((outputGrad * output * ratio) / denom) * inputValue; + } + } + } + } +} + +} // namespace paddle diff --git a/paddle/math/cross_map_normal_op.h b/paddle/math/cross_map_normal_op.h new file mode 100644 index 0000000000000..2f996072528a0 --- /dev/null +++ b/paddle/math/cross_map_normal_op.h @@ -0,0 +1,47 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/math/Matrix.h" + +namespace paddle { + +struct CrossMapNormal { + void operator()(CpuMatrix& outputs, + CpuMatrix& denoms, + CpuMatrix& inputs, + size_t channels, + size_t imgSizeH, + size_t imgSizeW, + size_t sizeX, + real scale, + real pow); +}; + +struct CrossMapNormalGrad { + void operator()(CpuMatrix& inputsGrad, + CpuMatrix& inputsValue, + CpuMatrix& outputsGrad, + CpuMatrix& outputsValue, + CpuMatrix& denoms, + size_t channels, + size_t imgSizeH, + size_t imgSizeW, + size_t sizeX, + real scale, + real pow); +}; + +} // namespace paddle diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index 5233a9af40155..9bb1fdbdab83a 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -23,6 +23,7 @@ limitations under the License. 
*/ #include "paddle/gserver/tests/TestUtil.h" #include "paddle/utils/Stat.h" #include "TensorCheck.h" +#include "paddle/math/cross_map_normal_op.h" using namespace paddle; // NOLINT using namespace std; // NOLINT @@ -1261,30 +1262,32 @@ TEST(Matrix, MaxOutFwdBwd) { } } } + void testCrossMapNormalFwd( int numSamples, int channels, int imgSizeH, int imgSizeW, int sizeX) { float scale = 1.5; float pow = 0.5; int width = imgSizeH * imgSizeW * channels; - MatrixPtr input = CpuMatrix::create(numSamples, width, false, false); - MatrixPtr denorms = CpuMatrix::create(numSamples, width, false, false); - MatrixPtr target = CpuMatrix::create(numSamples, width, false, false); - MatrixPtr inputGpu = GpuMatrix::create(numSamples, width, false, true); - MatrixPtr denormsGpu = GpuMatrix::create(numSamples, width, false, true); - MatrixPtr targetGpu = GpuMatrix::create(numSamples, width, false, true); - - input->randomizeUniform(); - target->randomizeUniform(); - inputGpu->copyFrom(*input); - targetGpu->copyFrom(*target); - - target->crossMapNormalFwd( - *input, imgSizeH, imgSizeW, *denorms, channels, sizeX, scale, pow); - targetGpu->crossMapNormalFwd( - *inputGpu, imgSizeH, imgSizeW, *denormsGpu, channels, sizeX, scale, pow); - - TensorCheckErr(*target, *targetGpu); - TensorCheckErr(*denorms, *denormsGpu); + CpuMatrix inputs(numSamples, width); + CpuMatrix denoms(numSamples, width); + CpuMatrix outputs(numSamples, width); + GpuMatrix inputsGpu(numSamples, width); + GpuMatrix denomsGpu(numSamples, width); + GpuMatrix outputsGpu(numSamples, width); + + inputs.randomizeUniform(); + outputs.randomizeUniform(); + inputsGpu.copyFrom(inputs); + outputsGpu.copyFrom(outputs); + + CrossMapNormal cross; + cross( + outputs, denoms, inputs, channels, imgSizeH, imgSizeW, sizeX, scale, pow); + outputsGpu.crossMapNormalFwd( + inputsGpu, imgSizeH, imgSizeW, denomsGpu, channels, sizeX, scale, pow); + + TensorCheckErr(outputs, outputsGpu); + TensorCheckErr(denoms, denomsGpu); } TEST(Matrix, crossMapNormalFwd) { @@ -1310,53 +1313,57 @@ void testCrossMapNormalBwd( float scale = 1.5; float pow = 0.5; size_t width = imgSizeH * imgSizeW * channels; - MatrixPtr localGrad = CpuMatrix::create(numSamples, width, false, false); - MatrixPtr denoms = CpuMatrix::create(numSamples, width, false, false); - MatrixPtr output = CpuMatrix::create(numSamples, width, false, false); - MatrixPtr preOutV = CpuMatrix::create(numSamples, width, false, false); - MatrixPtr localOutV = CpuMatrix::create(numSamples, width, false, false); - - localGrad->randomizeUniform(); - denoms->randomizeUniform(); - preOutV->randomizeUniform(); - localOutV->randomizeUniform(); - output->randomizeUniform(); - denoms->add(0.01); - - MatrixPtr localGradGpu = GpuMatrix::create(numSamples, width, false, true); - MatrixPtr denomsGpu = GpuMatrix::create(numSamples, width, false, true); - MatrixPtr outputGpu = GpuMatrix::create(numSamples, width, false, true); - MatrixPtr preOutVGpu = GpuMatrix::create(numSamples, width, false, true); - MatrixPtr localOutVGpu = GpuMatrix::create(numSamples, width, false, true); - - localGradGpu->copyFrom(*localGrad); - denomsGpu->copyFrom(*denoms); - preOutVGpu->copyFrom(*preOutV); - localOutVGpu->copyFrom(*localOutV); - outputGpu->copyFrom(*output); - output->crossMapNormalBwd(*localGrad, - *denoms, - *preOutV, - *localOutV, - channels, - imgSizeH, - imgSizeW, - sizeX, - scale, - pow); - outputGpu->crossMapNormalBwd(*localGradGpu, - *denomsGpu, - *preOutVGpu, - *localOutVGpu, - channels, - imgSizeH, - imgSizeW, - sizeX, - scale, - pow); - - 
TensorCheckErr(*output, *outputGpu); + CpuMatrix inputsGrad(numSamples, width); + CpuMatrix inputsValue(numSamples, width); + CpuMatrix outputsGrad(numSamples, width); + CpuMatrix outputsValue(numSamples, width); + CpuMatrix denoms(numSamples, width); + + outputsGrad.randomizeUniform(); + denoms.randomizeUniform(); + inputsValue.randomizeUniform(); + outputsValue.randomizeUniform(); + inputsGrad.randomizeUniform(); + denoms.add(0.01); + + GpuMatrix inputsGradGpu(numSamples, width); + GpuMatrix inputsValueGpu(numSamples, width); + GpuMatrix outputsGradGpu(numSamples, width); + GpuMatrix outputsValueGpu(numSamples, width); + GpuMatrix denomsGpu(numSamples, width); + + outputsGradGpu.copyFrom(outputsGrad); + denomsGpu.copyFrom(denoms); + inputsValueGpu.copyFrom(inputsValue); + outputsValueGpu.copyFrom(outputsValue); + inputsGradGpu.copyFrom(inputsGrad); + + CrossMapNormalGrad cross; + cross(inputsGrad, + inputsValue, + outputsGrad, + outputsValue, + denoms, + channels, + imgSizeH, + imgSizeW, + sizeX, + scale, + pow); + + inputsGradGpu.crossMapNormalBwd(outputsGradGpu, + denomsGpu, + inputsValueGpu, + outputsValueGpu, + channels, + imgSizeH, + imgSizeW, + sizeX, + scale, + pow); + + TensorCheckErr(inputsGrad, inputsGradGpu); } TEST(Matrix, crossMapNormalBwd) { From e357f2715843cd531ce0b0143647ed5561d2fceb Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Tue, 13 Dec 2016 17:55:31 +0800 Subject: [PATCH 03/55] add GPU CrossMapNormal --- paddle/math/cross_map_normal_op.cpp | 42 ++--- paddle/math/cross_map_normal_op.h | 37 ++++- paddle/math/cross_map_normal_op_gpu.cu | 194 +++++++++++++++++++++++ paddle/math/tests/test_matrixCompare.cpp | 66 +++++--- 4 files changed, 286 insertions(+), 53 deletions(-) create mode 100644 paddle/math/cross_map_normal_op_gpu.cu diff --git a/paddle/math/cross_map_normal_op.cpp b/paddle/math/cross_map_normal_op.cpp index 3eb51b5998fc4..be242926aff16 100644 --- a/paddle/math/cross_map_normal_op.cpp +++ b/paddle/math/cross_map_normal_op.cpp @@ -17,15 +17,16 @@ limitations under the License. 
*/ namespace paddle { // NCHW -void CrossMapNormal::operator()(CpuMatrix& outputs, - CpuMatrix& denoms, - CpuMatrix& inputs, - size_t channels, - size_t imgSizeH, - size_t imgSizeW, - size_t sizeX, - real scale, - real pow) { +template <> +void CrossMapNormal::operator()(CpuMatrix& outputs, + CpuMatrix& denoms, + CpuMatrix& inputs, + size_t channels, + size_t imgSizeH, + size_t imgSizeW, + size_t sizeX, + real scale, + real pow) { CHECK(outputs.isContiguous()); CHECK(inputs.isContiguous()); CHECK(denoms.isContiguous()); @@ -58,17 +59,18 @@ void CrossMapNormal::operator()(CpuMatrix& outputs, outputs = inputs * denoms.pow(-pow); } -void CrossMapNormalGrad::operator()(CpuMatrix& inputsGrad, - CpuMatrix& inputsValue, - CpuMatrix& outputsGrad, - CpuMatrix& outputsValue, - CpuMatrix& denoms, - size_t channels, - size_t imgSizeH, - size_t imgSizeW, - size_t sizeX, - real scale, - real pow) { +template <> +void CrossMapNormalGrad::operator()(CpuMatrix& inputsGrad, + CpuMatrix& inputsValue, + CpuMatrix& outputsGrad, + CpuMatrix& outputsValue, + CpuMatrix& denoms, + size_t channels, + size_t imgSizeH, + size_t imgSizeW, + size_t sizeX, + real scale, + real pow) { CHECK(inputsGrad.isContiguous()); CHECK(outputsGrad.isContiguous()); CHECK(denoms.isContiguous()); diff --git a/paddle/math/cross_map_normal_op.h b/paddle/math/cross_map_normal_op.h index 2f996072528a0..c2bb95f6b11fb 100644 --- a/paddle/math/cross_map_normal_op.h +++ b/paddle/math/cross_map_normal_op.h @@ -18,10 +18,30 @@ limitations under the License. */ namespace paddle { +enum DeviceType { + DEVICE_TYPE_UNSPECIFIED = 0, + DEVICE_TYPE_CPU = 1, + DEVICE_TYPE_GPU = 2, +}; + +template +struct MatrixT; + +template <> +struct MatrixT { + using type = CpuMatrix; +}; + +template <> +struct MatrixT { + using type = GpuMatrix; +}; + +template struct CrossMapNormal { - void operator()(CpuMatrix& outputs, - CpuMatrix& denoms, - CpuMatrix& inputs, + void operator()(typename MatrixT::type& outputs, + typename MatrixT::type& denoms, + typename MatrixT::type& inputs, size_t channels, size_t imgSizeH, size_t imgSizeW, @@ -30,12 +50,13 @@ struct CrossMapNormal { real pow); }; +template struct CrossMapNormalGrad { - void operator()(CpuMatrix& inputsGrad, - CpuMatrix& inputsValue, - CpuMatrix& outputsGrad, - CpuMatrix& outputsValue, - CpuMatrix& denoms, + void operator()(typename MatrixT::type& inputsGrad, + typename MatrixT::type& inputsValue, + typename MatrixT::type& outputsGrad, + typename MatrixT::type& outputsValue, + typename MatrixT::type& denoms, size_t channels, size_t imgSizeH, size_t imgSizeW, diff --git a/paddle/math/cross_map_normal_op_gpu.cu b/paddle/math/cross_map_normal_op_gpu.cu new file mode 100644 index 0000000000000..0a154d97ac02f --- /dev/null +++ b/paddle/math/cross_map_normal_op_gpu.cu @@ -0,0 +1,194 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "hl_base.h" +#include "cross_map_normal_op.h" + +namespace paddle { + +__global__ void KeCMRNormFillScale(size_t imageSize, const real* in, + real* scale, size_t channels, + size_t height, size_t width, size_t size, + real alpha) { + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < imageSize) { + const int w = idx % width; + const int h = (idx / width) % height; + const int n = idx / width / height; + const int offset = (n * channels * height + h) * width + w; + + in += offset; + scale += offset; + const int step = height * width; + const int pre_pad = (size - 1) / 2; + const int post_pad = size - pre_pad - 1; + + real accum = 0; + int index = 0; + while (index < channels + post_pad) { + if (index < channels) { + accum += in[index * step] * in[index * step]; + } + if (index >= size) { + accum -= in[(index - size) * step] * in[(index - size) * step]; + } + if (index >= post_pad) { + scale[(index - post_pad) * step] = 1. + accum * alpha; + } + ++index; + } + } +} + +__global__ void KeCMRNormOutput(size_t inputSize, const real* in, + const real* scale, real negative_beta, + real* out) { + const int index = threadIdx.x + blockIdx.x * blockDim.x; + if (index < inputSize) { + out[index] = in[index] * pow(scale[index], negative_beta); + } +} + +template <> +void CrossMapNormal::operator()(GpuMatrix& outputs, + GpuMatrix& denoms, + GpuMatrix& inputs, + size_t channels, + size_t imgSizeH, + size_t imgSizeW, + size_t sizeX, + real scale, + real pow) { + CHECK(outputs.isContiguous()); + CHECK(inputs.isContiguous()); + CHECK(denoms.isContiguous()); + CHECK_EQ(outputs.getHeight(), inputs.getHeight()); + CHECK_EQ(outputs.getWidth(), inputs.getWidth()); + CHECK_EQ(outputs.getHeight(), denoms.getHeight()); + CHECK_EQ(outputs.getWidth(), denoms.getWidth()); + + size_t numSample = inputs.getHeight(); + size_t numCols = inputs.getWidth(); + CHECK(imgSizeH * imgSizeW * channels == numCols); + + real* inputsData = inputs.getData(); + real* denomsData = denoms.getData(); + real* outputsData = outputs.getData(); + + size_t imageSize = numSample * imgSizeH * imgSizeW; + int blockSize = 1024; + int gridSize = (imageSize + 1024 - 1) / 1024; + KeCMRNormFillScale<<>> + (imageSize, inputsData, denomsData, + channels, imgSizeH, imgSizeW, sizeX, scale); + + size_t inputSize = numSample * imgSizeH * imgSizeW *channels; + blockSize = 1024; + gridSize = (inputSize + 1024 - 1) / 1024; + KeCMRNormOutput<<>> + (inputSize, inputsData, denomsData, -pow, outputsData); + + CHECK_SYNC("CrossMapNormalFwd"); +} + +__global__ void KeCMRNormDiff(size_t imageSize, const real* bottom_data, + const real* top_data, const real* scale, + const real* top_diff, size_t channels, + size_t height, size_t width, size_t size, + real negative_beta, real cache_ratio, + real* bottom_diff ) { + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < imageSize) { + const int w = idx % width; + const int h = (idx / width) % height; + const int n = idx / width / height; + const int offset = (n * channels * height + h) * width + w; + bottom_data += offset; + top_data += offset; + scale += offset; + top_diff += offset; + bottom_diff += offset; + + const int step = height * width; + const int pre_pad = size - (size + 1) / 2; + const int post_pad = size - pre_pad - 1; + + int index = 0; + real accum = 0; + while (index < channels + post_pad) { + if (index < channels) { + accum += top_diff[index * step] * top_data[index * step] / + scale[index * step]; + } + if (index >= size) { + accum -= top_diff[(index - size) * 
step] * + top_data[(index - size) * step] / scale[(index - size) * step]; + } + if (index >= post_pad) { + bottom_diff[(index - post_pad) * step] += + top_diff[(index - post_pad) * step] * + pow(scale[(index - post_pad) * step], negative_beta) - cache_ratio * + bottom_data[(index - post_pad) * step] * accum; + } + ++index; + } + } +} + +template <> +void CrossMapNormalGrad::operator()(GpuMatrix& inputsGrad, + GpuMatrix& inputsValue, + GpuMatrix& outputsGrad, + GpuMatrix& outputsValue, + GpuMatrix& denoms, + size_t channels, + size_t imgSizeH, + size_t imgSizeW, + size_t sizeX, + real scale, + real pow) { + CHECK(inputsGrad.isContiguous()); + CHECK(outputsGrad.isContiguous()); + CHECK(denoms.isContiguous()); + CHECK(inputsValue.isContiguous()); + CHECK(outputsValue.isContiguous()); + CHECK_EQ(inputsGrad.getHeight(), outputsGrad.getHeight()); + CHECK_EQ(inputsGrad.getWidth(), outputsGrad.getWidth()); + CHECK_EQ(inputsGrad.getHeight(), denoms.getHeight()); + CHECK_EQ(inputsGrad.getWidth(), denoms.getWidth()); + CHECK_EQ(inputsGrad.getHeight(), inputsValue.getHeight()); + CHECK_EQ(inputsGrad.getWidth(), inputsValue.getWidth()); + CHECK_EQ(inputsGrad.getHeight(), outputsValue.getHeight()); + CHECK_EQ(inputsGrad.getWidth(), outputsValue.getWidth()); + + size_t numSample = inputsGrad.getHeight(); + size_t numCols = inputsGrad.getWidth(); + CHECK(imgSizeH * imgSizeW * channels == numCols); + + size_t imageSize = numSample * imgSizeH * imgSizeW; + real* inputsGradData = inputsGrad.getData(); + real* inputsData = inputsValue.getData(); + real* denomsData = denoms.getData(); + real* outputsGradData = outputsGrad.getData(); + real* outputsData = outputsValue.getData(); + + int blockSize = 1024; + int gridSize = (imageSize + 1024 - 1) / 1024; + KeCMRNormDiff <<>> + (imageSize, inputsData, outputsData, denomsData, outputsGradData, channels, + imgSizeH, imgSizeW, sizeX, -pow, 2.0f * pow * scale, inputsGradData); + CHECK_SYNC("KeCMRNormDiff"); +} + +} // namespace paddle diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index 9bb1fdbdab83a..8d7a4fb94d0a1 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -1280,11 +1280,25 @@ void testCrossMapNormalFwd( inputsGpu.copyFrom(inputs); outputsGpu.copyFrom(outputs); - CrossMapNormal cross; - cross( + CrossMapNormal cpuCross; + cpuCross( outputs, denoms, inputs, channels, imgSizeH, imgSizeW, sizeX, scale, pow); + + CrossMapNormal gpuCross; + gpuCross(outputsGpu, + denomsGpu, + inputsGpu, + channels, + imgSizeH, + imgSizeW, + sizeX, + scale, + pow); + +#if 0 outputsGpu.crossMapNormalFwd( inputsGpu, imgSizeH, imgSizeW, denomsGpu, channels, sizeX, scale, pow); +#endif TensorCheckErr(outputs, outputsGpu); TensorCheckErr(denoms, denomsGpu); @@ -1339,29 +1353,31 @@ void testCrossMapNormalBwd( outputsValueGpu.copyFrom(outputsValue); inputsGradGpu.copyFrom(inputsGrad); - CrossMapNormalGrad cross; - cross(inputsGrad, - inputsValue, - outputsGrad, - outputsValue, - denoms, - channels, - imgSizeH, - imgSizeW, - sizeX, - scale, - pow); - - inputsGradGpu.crossMapNormalBwd(outputsGradGpu, - denomsGpu, - inputsValueGpu, - outputsValueGpu, - channels, - imgSizeH, - imgSizeW, - sizeX, - scale, - pow); + CrossMapNormalGrad cpuCross; + cpuCross(inputsGrad, + inputsValue, + outputsGrad, + outputsValue, + denoms, + channels, + imgSizeH, + imgSizeW, + sizeX, + scale, + pow); + + CrossMapNormalGrad gpuCross; + gpuCross(inputsGradGpu, + inputsValueGpu, + outputsGradGpu, + outputsValueGpu, 
+ denomsGpu, + channels, + imgSizeH, + imgSizeW, + sizeX, + scale, + pow); TensorCheckErr(inputsGrad, inputsGradGpu); } From 0eac39928090c44fc3b8b4edc18604ff7b662f91 Mon Sep 17 00:00:00 2001 From: yuan Date: Tue, 13 Dec 2016 20:57:59 +0800 Subject: [PATCH 04/55] priorbox layer for ssd --- paddle/gserver/layers/PriorBox.cpp | 137 ++++++++++++++++++ proto/ModelConfig.proto | 10 ++ python/paddle/trainer/config_parser.py | 13 ++ .../paddle/trainer_config_helpers/layers.py | 36 +++++ 4 files changed, 196 insertions(+) create mode 100644 paddle/gserver/layers/PriorBox.cpp diff --git a/paddle/gserver/layers/PriorBox.cpp b/paddle/gserver/layers/PriorBox.cpp new file mode 100644 index 0000000000000..b0d59cd145ce8 --- /dev/null +++ b/paddle/gserver/layers/PriorBox.cpp @@ -0,0 +1,137 @@ +/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "Layer.h" +#include "paddle/math/Matrix.h" +#include "paddle/math/BaseMatrix.h" + +namespace paddle { + +class PriorBoxLayer : public Layer { +public: + explicit PriorBoxLayer(const LayerConfig& config) : Layer(config) {} + bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); + void forward(PassType passType); + void backward(const UpdateCallback& callback) {} + int numPriors_; + std::vector minSize_; + std::vector maxSize_; + std::vector aspectRatio_; + std::vector variance_; + MatrixPtr buffer_; +}; + +bool PriorBoxLayer::init(const LayerMap& layerMap, + const ParameterMap& parameterMap) { + Layer::init(layerMap, parameterMap); + std::copy(config_.inputs(0).priorbox_conf().min_size().begin(), + config_.inputs(0).priorbox_conf().min_size().end(), + std::back_inserter(minSize_)); + std::copy(config_.inputs(0).priorbox_conf().max_size().begin(), + config_.inputs(0).priorbox_conf().max_size().end(), + std::back_inserter(maxSize_)); + std::copy(config_.inputs(0).priorbox_conf().aspect_ratio().begin(), + config_.inputs(0).priorbox_conf().aspect_ratio().end(), + std::back_inserter(aspectRatio_)); + std::copy(config_.inputs(0).priorbox_conf().variance().begin(), + config_.inputs(0).priorbox_conf().variance().end(), + std::back_inserter(variance_)); + // flip + int input_ratio_length = aspectRatio_.size(); + for (int index = 0; index < input_ratio_length; index++) + aspectRatio_.push_back(1 / aspectRatio_[index]); + aspectRatio_.push_back(1.); + numPriors_ = aspectRatio_.size(); + if (maxSize_.size() > 0) + numPriors_++; + buffer_ = Matrix::create(1, 1, false, false); + return true; +} + +void PriorBoxLayer::forward(PassType passType) { + Layer::forward(passType); + auto input = getInput(0); + int layer_width = input.getFrameWidth(); + int layer_height = input.getFrameHeight(); + + MatrixPtr inV1 = getInputValue(1); + int image_width = inV1->getElement(0, 0); + int image_height = inV1->getElement(0, 1); + float step_w = static_cast(image_width) / layer_width; + float step_h = static_cast(image_height) / layer_height; + int dim = layer_height * layer_width * numPriors_ * 4; + reserveOutput(1, dim * 2); + // use a 
cpu buffer to compute + Matrix::resizeOrCreate(buffer_, 1, dim * 2, false, false); + auto* tmp_ptr = buffer_->getData(); + + int idx = 0; + for (int h = 0; h < layer_height; ++h) { + for (int w = 0; w < layer_width; ++w) { + float center_x = (w + 0.5) * step_w; + float center_y = (h + 0.5) * step_h; + int min_size = 0; + for (size_t s = 0; s < minSize_.size(); s++) { + // first prior. + min_size = minSize_[s]; + int box_width = min_size; + int box_height = min_size; + // xmin, ymin, xmax, ymax. + tmp_ptr[idx++] = (center_x - box_width / 2.) / image_width; + tmp_ptr[idx++] = (center_y - box_height / 2.) / image_height; + tmp_ptr[idx++] = (center_x + box_width / 2.) / image_width; + tmp_ptr[idx++] = (center_y + box_height / 2.) / image_height; + + if (maxSize_.size() > 0) { + CHECK_EQ(minSize_.size(), maxSize_.size()); + // second prior. + for (size_t s = 0; s < maxSize_.size(); s++) { + int max_size = maxSize_[s]; + box_width = box_height = sqrt(min_size * max_size); + tmp_ptr[idx++] = (center_x - box_width / 2.) / image_width; + tmp_ptr[idx++] = (center_y - box_height / 2.) / image_height; + tmp_ptr[idx++] = (center_x + box_width / 2.) / image_width; + tmp_ptr[idx++] = (center_y + box_height / 2.) / image_height; + } + } + } + // rest of priors. + for (size_t r = 0; r < aspectRatio_.size(); r++) { + float ar = aspectRatio_[r]; + if (fabs(ar - 1.) < 1e-6) + continue; + float box_width = min_size * sqrt(ar); + float box_height = min_size / sqrt(ar); + tmp_ptr[idx++] = (center_x - box_width / 2.) / image_width; + tmp_ptr[idx++] = (center_y - box_height / 2.) / image_height; + tmp_ptr[idx++] = (center_x + box_width / 2.) / image_width; + tmp_ptr[idx++] = (center_y + box_height / 2.) / image_height; + } + } + } + // clip the prior's coordidate such that it is within [0, 1] + for (int d = 0; d < dim; ++d) + tmp_ptr[d] = std::min(std::max(tmp_ptr[d], (float)0.), (float)1.); + // set the variance. 
+ for (int h = 0; h < layer_height; h++) + for (int w = 0; w < layer_width; w++) + for (int i = 0; i < numPriors_; i++) + for (int j = 0; j < 4; j++) + tmp_ptr[idx++] = variance_[j]; + MatrixPtr outV = getOutputValue(); + outV->copyFrom(buffer_->data_, dim * 2); +} +REGISTER_LAYER(priorbox, PriorBoxLayer); + +} // namespace paddle diff --git a/proto/ModelConfig.proto b/proto/ModelConfig.proto index b34e1ebdedab1..460a39275fbe4 100644 --- a/proto/ModelConfig.proto +++ b/proto/ModelConfig.proto @@ -248,6 +248,15 @@ message ImageConfig { required uint32 img_size_y = 9; } +message PriorBoxConfig { + repeated uint32 min_size = 1; + repeated uint32 max_size = 2; + repeated float aspect_ratio = 3; + repeated float variance = 4; + optional bool flip = 5 [default = true]; + optional bool clip = 6 [default = true]; +} + message LayerInputConfig { required string input_layer_name = 1; optional string input_parameter_name = 2; @@ -263,6 +272,7 @@ message LayerInputConfig { optional BilinearInterpConfig bilinear_interp_conf = 10; optional MaxOutConfig maxout_conf = 11; optional SppConfig spp_conf = 12; + optional PriorBoxConfig priorbox_conf = 13; } message LayerConfig { diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index 5b7f4d85e2c33..5de524e507bdf 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -1577,6 +1577,19 @@ class PrintLayer(LayerBase): def __init__(self, name, inputs): super(PrintLayer, self).__init__(name, 'print', 0, inputs) +@config_layer('priorbox') +class PriorBoxLayer(LayerBase): + def __init__(self, name, inputs, size, min_size, max_size, aspect_ratio, variance): + super(PriorBoxLayer, self).__init__(name, 'priorbox', 0, inputs) + config_assert(len(inputs) == 2, 'PriorBoxLayer must have 2 input') + self.config.inputs[0].priorbox_conf.min_size.extend(min_size) + self.config.inputs[0].priorbox_conf.max_size.extend(max_size) + self.config.inputs[0].priorbox_conf.aspect_ratio.extend(aspect_ratio) + self.config.inputs[0].priorbox_conf.variance.extend(variance) + self.config.size = size + input_layer0 = self.get_input_layer(0) + input_layer1 = self.get_input_layer(1) + @config_layer('data') class DataLayer(LayerBase): diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index 8dd6b7b7d28f8..f04b5646aab03 100644 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -106,6 +106,7 @@ 'maxout_layer', 'out_prod_layer', 'print_layer', + 'priorbox_layer', 'spp_layer', ] @@ -171,6 +172,7 @@ class LayerType(object): SPP_LAYER = "spp" PRINT_LAYER = "print" + PRIORBOX_LAYER = "priorbox" CTC_LAYER = "ctc" WARP_CTC_LAYER = "warp_ctc" @@ -933,6 +935,40 @@ def print_layer(input, name=None): inputs=[l.name for l in input], ) # this layer don't return anything, can not be input of other layer. +@wrap_name_default("priorbox") +def priorbox_layer(input, img_shape, aspect_ratio, variance, min_size, max_size=[], name=None): + """ + Compute the priorbox and set the variance. This layer is necessary for ssd. + + :param name: The Layer Name. + :type name: basestring + :param input: The input layer. + :type input: LayerOutput + :param img_shape: The width and height of the network input image. + :type img_shape: LayerOutput + :param aspect_ratio: The aspect ratio. + :type aspect_ratio: list + :param variance: The bounding box variance. + :type min_size: The min size of the priorbox width/height. 
+ :param min_size: list + :type max_size: The max size of the priorbox width/height. Could be NULL. + :param max_size: list + :return: LayerOutput + """ + # plus one for ratio 1. + num_filters = (len(aspect_ratio) * 2 + 1 + len(max_size)) * 4 + size=(input.size / input.num_filters) * num_filters * 2 + Layer( + name=name, + type=LayerType.PRIORBOX_LAYER, + inputs=[input.name, img_shape.name], + size=size, + min_size=min_size, + max_size=max_size, + aspect_ratio=aspect_ratio, + variance=variance) + return LayerOutput( + name, LayerType.PRIORBOX_LAYER, parents=[input, img_shape], num_filters=num_filters, size=size) @wrap_name_default("seq_pooling") @wrap_bias_attr_default(has_bias=False) From 39d689e2536c0838f01e9f87fc232f0822273557 Mon Sep 17 00:00:00 2001 From: gaoyuan Date: Wed, 14 Dec 2016 15:19:57 +0800 Subject: [PATCH 05/55] Format the priorbox code --- paddle/gserver/layers/PriorBox.cpp | 34 ++++++++++++++---------------- 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/paddle/gserver/layers/PriorBox.cpp b/paddle/gserver/layers/PriorBox.cpp index b0d59cd145ce8..994f7c20384ff 100644 --- a/paddle/gserver/layers/PriorBox.cpp +++ b/paddle/gserver/layers/PriorBox.cpp @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve. +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -33,28 +33,28 @@ class PriorBoxLayer : public Layer { }; bool PriorBoxLayer::init(const LayerMap& layerMap, - const ParameterMap& parameterMap) { + const ParameterMap& parameterMap) { Layer::init(layerMap, parameterMap); - std::copy(config_.inputs(0).priorbox_conf().min_size().begin(), - config_.inputs(0).priorbox_conf().min_size().end(), + auto pb_conf = config_.inputs(0).priorbox_conf(); + std::copy(pb_conf.min_size().begin(), + pb_conf.min_size().end(), std::back_inserter(minSize_)); - std::copy(config_.inputs(0).priorbox_conf().max_size().begin(), - config_.inputs(0).priorbox_conf().max_size().end(), + std::copy(pb_conf.max_size().begin(), + pb_conf.max_size().end(), std::back_inserter(maxSize_)); - std::copy(config_.inputs(0).priorbox_conf().aspect_ratio().begin(), - config_.inputs(0).priorbox_conf().aspect_ratio().end(), + std::copy(pb_conf.aspect_ratio().begin(), + pb_conf.aspect_ratio().end(), std::back_inserter(aspectRatio_)); - std::copy(config_.inputs(0).priorbox_conf().variance().begin(), - config_.inputs(0).priorbox_conf().variance().end(), + std::copy(pb_conf.variance().begin(), + pb_conf.variance().end(), std::back_inserter(variance_)); // flip int input_ratio_length = aspectRatio_.size(); for (int index = 0; index < input_ratio_length; index++) - aspectRatio_.push_back(1 / aspectRatio_[index]); + aspectRatio_.push_back(1 / aspectRatio_[index]); aspectRatio_.push_back(1.); numPriors_ = aspectRatio_.size(); - if (maxSize_.size() > 0) - numPriors_++; + if (maxSize_.size() > 0) numPriors_++; buffer_ = Matrix::create(1, 1, false, false); return true; } @@ -79,7 +79,7 @@ void PriorBoxLayer::forward(PassType passType) { int idx = 0; for (int h = 0; h < layer_height; ++h) { for (int w = 0; w < layer_width; ++w) { - float center_x = (w + 0.5) * step_w; + float center_x = (w + 0.5) * step_w; float center_y = (h + 0.5) * step_h; int min_size = 0; for (size_t s = 0; s < minSize_.size(); s++) { @@ -109,8 +109,7 @@ void PriorBoxLayer::forward(PassType passType) { // rest of priors. 
for (size_t r = 0; r < aspectRatio_.size(); r++) { float ar = aspectRatio_[r]; - if (fabs(ar - 1.) < 1e-6) - continue; + if (fabs(ar - 1.) < 1e-6) continue; float box_width = min_size * sqrt(ar); float box_height = min_size / sqrt(ar); tmp_ptr[idx++] = (center_x - box_width / 2.) / image_width; @@ -127,8 +126,7 @@ void PriorBoxLayer::forward(PassType passType) { for (int h = 0; h < layer_height; h++) for (int w = 0; w < layer_width; w++) for (int i = 0; i < numPriors_; i++) - for (int j = 0; j < 4; j++) - tmp_ptr[idx++] = variance_[j]; + for (int j = 0; j < 4; j++) tmp_ptr[idx++] = variance_[j]; MatrixPtr outV = getOutputValue(); outV->copyFrom(buffer_->data_, dim * 2); } From 438a70477fcae07a31a1c9a414190c9fb0a8acea Mon Sep 17 00:00:00 2001 From: livc Date: Wed, 14 Dec 2016 15:20:32 +0800 Subject: [PATCH 06/55] add rnn_cn.md --- doc/howto/deep_model/rnn/rnn_cn.md | 226 +++++++++++++++++++++++++++++ 1 file changed, 226 insertions(+) create mode 100644 doc/howto/deep_model/rnn/rnn_cn.md diff --git a/doc/howto/deep_model/rnn/rnn_cn.md b/doc/howto/deep_model/rnn/rnn_cn.md new file mode 100644 index 0000000000000..496a54d0113f8 --- /dev/null +++ b/doc/howto/deep_model/rnn/rnn_cn.md @@ -0,0 +1,226 @@ +RNN 配置 +================= + +本教程将指导你如何在 PaddlePaddle 中配置循环神经网络(RNN)。PaddlePaddle 高度支持灵活和高效的循环神经网络配置。 在本教程中,您将了解如何: + +- 准备用来学习循环神经网络的序列数据。 +- 配置循环神经网络架构。 +- 使用学习完成的循环神经网络模型生成序列。 + +我们将使用 vanilla 循环神经网络和 sequence to sequence 模型来指导你完成这些步骤。sequence to sequence 模型的代码可以在`demo / seqToseq`找到。 + +准备序列数据 +--------------------- + +PaddlePaddle 不需要对序列数据进行任何预处理,例如填充。唯一需要做的是将相应类型设置为输入。例如,以下代码段定义了三个输入。 它们都是序列,它们的大小是`src_dict`,`trg_dict`和`trg_dict`: + +``` sourceCode +settings.input_types = [ + integer_value_sequence(len(settings.src_dict)), + integer_value_sequence(len(settings.trg_dict)), + integer_value_sequence(len(settings.trg_dict))] +``` + +在`process`函数中,每个`yield`函数将返回三个整数列表。每个整数列表被视为一个整数序列: + +``` sourceCode +yield src_ids, trg_ids, trg_ids_next +``` + +有关如何编写数据提供程序的更多细节描述,请参考 [PyDataProvider2](../../ui/data_provider/index.html)。完整的数据提供文件在 `demo/seqToseq/dataprovider.py`。 + +配置循环神经网络架构 +----------------------------------------------- + +### 简单门控(Simple Gated)循环神经网络 + +循环神经网络在每个时间步骤顺序地处理序列。下面列出了 LSTM 的架构的示例。 + +![image](../../../tutorials/sentiment_analysis/bi_lstm.jpg) + +一般来说,循环网络从 *t* = 1 到 *t* = *T* 或者相反从 *t* = *T* 到 *t* = 1 执行以下操作。 + +*x**t* + 1 = *f**x*(*x**t*),*y**t* = *f**y*(*x**t*) + +其中 *f**x*(.) 称为**阶跃函数**,而 *f**y*(.) 
称为**输出函数**。在 vanilla 循环神经网络中,阶跃函数和输出函数都非常简单。然而,PaddlePaddle 支持通过修改这两个函数来配置非常复杂的架构。 我们将使用 sequence to sequence 模型演示如何配置复杂的循环神经网络模型。在本节中,我们将使用简单的 vanilla 循环神经网络作为使用`recurrent_group`配置简单循环神经网络的例子。 注意,如果你只需要使用简单的RNN,GRU或LSTM,那么推荐使用`grumemory`和`lstmemory`,因为它们的计算效率比`recurrent_group`更高。 + +对于 vanilla RNN,在每个时间步长,**阶跃函数**为: + +*x**t* + 1 = *W**x**x**t* + *W**i**I**t* + *b* + +其中 *x**t* 是RNN状态,并且 *I**t* 是输入,*W**x* 和 *W**i* 分别是RNN状态和输入的变换矩阵。*b* 是偏差。它的**输出函数**只需要*x**t*作为输出。 + +`recurrent_group`是构建循环神经网络的最重要的工具。 它定义了**阶跃函数**,**输出函数**和循环神经网络的输入。注意,这个函数的`step`参数执行了`step function`(阶跃函数)和`output function`(输出函数): + + +``` sourceCode +def simple_rnn(input, + size=None, + name=None, + reverse=False, + rnn_bias_attr=None, + act=None, + rnn_layer_attr=None): + def __rnn_step__(ipt): + out_mem = memory(name=name, size=size) + rnn_out = mixed_layer(input = [full_matrix_projection(ipt), + full_matrix_projection(out_mem)], + name = name, + bias_attr = rnn_bias_attr, + act = act, + layer_attr = rnn_layer_attr, + size = size) + return rnn_out + return recurrent_group(name='%s_recurrent_group' % name, + step=__rnn_step__, + reverse=reverse, + input=input) +``` + +PaddlePaddle 使用“记忆”构造阶跃函数。**记忆(Memory)**是在PaddlePaddle中构造循环神经网络时最重要的概念。 记忆是在阶跃函数中循环使用的状态,例如*x**t* + 1 = *f**x*(*x**t*)。 一个记忆包含**输出**和**输入**。当前时间步处的记忆的输出作为下一时间步记忆的输入。记忆也可以具有**引导层**,其输出被用作记忆的初始值。 在我们的例子中,门控循环单元的输出被用作输出记忆。请注意,`rnn_out`层的名称与`out_mem`的名称相同。这意味着`rnn_out` (*x**t* + 1)的输出被用作`out_mem`记忆的**输出**。 + +记忆也可以是序列。在这种情况下,在每个时间步中,我们有一个序列作为循环神经网络的状态。这在构造非常复杂的循环神经网络时是有用的。 其他高级功能包括定义多个记忆,以及使用子序列来定义分级循环神经网络架构。 + +我们在函数的结尾返回`rnn_out`。 这意味着 `rnn_out` 层的输出被用作门控循环神经网络的**输出**函数。 + +### Sequence to Sequence Model with Attention + +我们将使用 sequence to sequence model with attention 作为例子演示如何配置复杂的循环神经网络模型。该模型的说明如下图所示。 + +![image](../../../tutorials/text_generation/encoder-decoder-attention-model.png) + +在这个模型中,源序列 *S* = {*s*1, …, *s**T*} 用双向门控循环神经网络编码。双向门控循环神经网络的隐藏状态 *H**S* = {*H*1, …, *H**T*} 被称为 *编码向量*。解码器是门控循环神经网络。当解读每一个*y**t*时, 这个门控循环神经网络生成一系列权重 *W**S**t* = {*W*1*t*, …, *W**T**t*}, 用于计算编码向量的加权和。加权和用来鉴定符号 *y**t* 的生成。 + +模型的编码器部分如下所示。它叫做`grumemory`来表示门控循环神经网络。如果网络架构简单,那么推荐使用循环神经网络的方法,因为它比 `recurrent_group` 更快。我们已经实现了大多数常用的循环神经网络架构,可以参考 [Layers](../../ui/api/trainer_config_helpers/layers_index.html) 了解更多细节。 + +我们还将编码向量投射到`decoder_size`维空间,获得反向循环网络的第一个实例,并将其投射到`decoder_size`维空间: + +``` sourceCode +# 定义源语句的数据层 +src_word_id = data_layer(name='source_language_word', size=source_dict_dim) +# 计算每个词的词向量 +src_embedding = embedding_layer( + input=src_word_id, + size=word_vector_dim, + param_attr=ParamAttr(name='_source_language_embedding')) +# 应用前向循环神经网络 +src_forward = grumemory(input=src_embedding, size=encoder_size) +# 应用反向递归神经网络(reverse=True表示反向循环神经网络) +src_backward = grumemory(input=src_embedding, + size=encoder_size, + reverse=True) +# 将循环神经网络的前向和反向部分混合在一起 +encoded_vector = concat_layer(input=[src_forward, src_backward]) + +# 投射编码向量到 decoder_size +encoder_proj = mixed_layer(input = [full_matrix_projection(encoded_vector)], + size = decoder_size) + +# 计算反向RNN的第一个实例 +backward_first = first_seq(input=src_backward) + +# 投射反向RNN的第一个实例到 decoder size +decoder_boot = mixed_layer(input=[full_matrix_projection(backward_first)], size=decoder_size, act=TanhActivation()) +``` + +解码器使用 `recurrent_group` 来定义循环神经网络。阶跃函数和输出函数在 `gru_decoder_with_attention` 中定义: + +``` sourceCode +group_inputs=[StaticInput(input=encoded_vector,is_seq=True), + StaticInput(input=encoded_proj,is_seq=True)] +trg_embedding = embedding_layer( + input=data_layer(name='target_language_word', + size=target_dict_dim), + 
size=word_vector_dim, + param_attr=ParamAttr(name='_target_language_embedding')) +group_inputs.append(trg_embedding) + +# 对于配备有注意力机制的解码器,在训练中, +# 目标向量(groudtruth)是数据输入, +# 而编码源序列作为无界存储器被访问。 +# StaticInput 意味着不同时间步的相同值, +# 否则它是一个序列的输入,不同时间步的输入是不同的。 +# 所有输入序列应该有相同的长度。 +decoder = recurrent_group(name=decoder_group_name, + step=gru_decoder_with_attention, + input=group_inputs) +``` + +阶跃函数的实现如下所示。首先,它定义解码网络的**记忆**。然后定义 attention,门控循环单元阶跃函数和输出函数: + +``` sourceCode +def gru_decoder_with_attention(enc_vec, enc_proj, current_word): + # 定义解码器的记忆 + # 记忆的输出定义在 gru_step 内 + # 注意 gru_step 应该与它的记忆名字相同 + decoder_mem = memory(name='gru_decoder', + size=decoder_size, + boot_layer=decoder_boot) + # 计算 attention 加权编码向量 + context = simple_attention(encoded_sequence=enc_vec, + encoded_proj=enc_proj, + decoder_state=decoder_mem) + # 混合当前词向量和attention加权编码向量 + decoder_inputs = mixed_layer(inputs = [full_matrix_projection(context), + full_matrix_projection(current_word)], + size = decoder_size * 3) + # 定义门控循环单元循环神经网络阶跃函数 + gru_step = gru_step_layer(name='gru_decoder', + input=decoder_inputs, + output_mem=decoder_mem, + size=decoder_size) + # 定义输出函数 + out = mixed_layer(input=[full_matrix_projection(input=gru_step)], + size=target_dict_dim, + bias_attr=True, + act=SoftmaxActivation()) + return out +``` + +生成序列 +----------------- + +训练模型后,我们可以使用它来生成序列。通常的做法是使用**柱搜索(beam search** 生成序列。以下代码片段定义柱搜索算法。注意,`beam_search`函数假设`step`的输出函数返回下一个标志的 softmax 归一化概率向量。我们对模型进行了以下更改。 + +- 使用 `GeneratedInput` 来 trg\_embedding。 `GeneratedInput` 计算上一次时间步生成的标记的向量来作为当前时间步的输入。 +- 使用 `beam_search` 函数。这个函数需要设置: + - `bos_id`: 开始标记。每个句子都以开始标记开头。 + - `eos_id`: 结束标记。每个句子都以结束标记结尾。 + - `beam_size`: 柱搜索算法中的柱大小。 + - `max_length`: 生成序列的最大长度。 +- 使用 `seqtext_printer_evaluator` 根据索引矩阵和字典打印文本。这个函数需要设置: + - `id_input`: 数据的整数ID,用于标识生成的文件中的相应输出。 + - `dict_file`: 用于将词ID转换为词的字典文件。 + - `result_file`: 生成结果文件的路径。 + +代码如下: + +``` sourceCode +group_inputs=[StaticInput(input=encoded_vector,is_seq=True), + StaticInput(input=encoded_proj,is_seq=True)] +# 在一代中,解码器预测下一目标词基于编码源序列和最后生成的目标词。 +# 编码源序列(编码器输出)必须由只读记忆的 StaticInput 指定。 +# 这里, GeneratedInputs 自动获取上一个被一个开始符号初始化的生成词,例如 。 +trg_embedding = GeneratedInput( + size=target_dict_dim, + embedding_name='_target_language_embedding', + embedding_size=word_vector_dim) +group_inputs.append(trg_embedding) +beam_gen = beam_search(name=decoder_group_name, + step=gru_decoder_with_attention, + input=group_inputs, + bos_id=0, # Beginnning token. + eos_id=1, # End of sentence token. 
+ beam_size=beam_size, + max_length=max_length) + +seqtext_printer_evaluator(input=beam_gen, + id_input=data_layer(name="sent_id", size=1), + dict_file=trg_dict_path, + result_file=gen_trans_file) +outputs(beam_gen) +``` + +注意,这种生成技术只用于类似解码器的生成过程。如果你正在处理序列标记任务,请参阅 [Semantic Role Labeling Demo](../../demo/semantic_role_labeling/index.html) 了解更多详细信息。 + +完整的配置文件在`demo/seqToseq/seqToseq_net.py`。 From 96009326504de2149c9fcd978b769ae9ba21843a Mon Sep 17 00:00:00 2001 From: gaoyuan Date: Wed, 14 Dec 2016 16:33:01 +0800 Subject: [PATCH 07/55] Add fake gpu support of the priorbox layer for the moment --- paddle/gserver/layers/PriorBox.cpp | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/paddle/gserver/layers/PriorBox.cpp b/paddle/gserver/layers/PriorBox.cpp index 994f7c20384ff..4b8573f05817a 100644 --- a/paddle/gserver/layers/PriorBox.cpp +++ b/paddle/gserver/layers/PriorBox.cpp @@ -24,11 +24,13 @@ class PriorBoxLayer : public Layer { bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); void forward(PassType passType); void backward(const UpdateCallback& callback) {} + void forwardImp(const Argument& featureMap, const Argument& imageShape); int numPriors_; std::vector minSize_; std::vector maxSize_; std::vector aspectRatio_; std::vector variance_; + std::vector tmpCpuInput_; MatrixPtr buffer_; }; @@ -56,16 +58,35 @@ bool PriorBoxLayer::init(const LayerMap& layerMap, numPriors_ = aspectRatio_.size(); if (maxSize_.size() > 0) numPriors_++; buffer_ = Matrix::create(1, 1, false, false); + if (useGpu_) { + tmpCpuInput_.reserve(inputLayers_.size()); + for (size_t i = 0; i < inputLayers_.size(); i++) { + tmpCpuInput_.push_back(Argument()); + } + } return true; } void PriorBoxLayer::forward(PassType passType) { Layer::forward(passType); - auto input = getInput(0); - int layer_width = input.getFrameWidth(); - int layer_height = input.getFrameHeight(); + if (useGpu_) { + for (size_t i = 0; i < inputLayers_.size(); i++) { + tmpCpuInput_[i].resizeAndCopyFrom( + getInput(i), false, HPPL_STREAM_DEFAULT); + hl_stream_synchronize(HPPL_STREAM_DEFAULT); + forwardImp(tmpCpuInput_[0], tmpCpuInput_[1]); + } + } else { + forwardImp(getInput(0), getInput(1)); + } +} + +void PriorBoxLayer::forwardImp(const Argument& featureMap, + const Argument& imageShape) { + int layer_width = featureMap.getFrameWidth(); + int layer_height = featureMap.getFrameHeight(); - MatrixPtr inV1 = getInputValue(1); + MatrixPtr inV1 = imageShape.value; int image_width = inV1->getElement(0, 0); int image_height = inV1->getElement(0, 1); float step_w = static_cast(image_width) / layer_width; @@ -130,6 +151,7 @@ void PriorBoxLayer::forward(PassType passType) { MatrixPtr outV = getOutputValue(); outV->copyFrom(buffer_->data_, dim * 2); } + REGISTER_LAYER(priorbox, PriorBoxLayer); } // namespace paddle From c0076084e24175cfe729f085c7feaf286270dfe8 Mon Sep 17 00:00:00 2001 From: gaoyuan Date: Wed, 14 Dec 2016 16:50:38 +0800 Subject: [PATCH 08/55] Format the python file. 
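
A minimal usage sketch of the priorbox layer touched by the last two patches, written against the `priorbox_layer` signature as it stands at this point in the series (input, img_shape, aspect_ratio, variance, min_size, max_size). Everything below is illustrative only: the image size, filter counts and box settings are assumptions rather than values taken from any patch, and the snippet presumes the usual `from paddle.trainer_config_helpers import *` config setup.

    from paddle.trainer_config_helpers import *

    # Illustrative sketch only -- sizes and box settings are made up for this example.
    # First input: a convolution feature map (priorbox_layer uses its size and
    # num_filters to derive its own output size).
    image = data_layer(name='image', size=3 * 300 * 300)
    conv_feat = img_conv_layer(input=image, num_channels=3, num_filters=256,
                               filter_size=3, stride=1, padding=1,
                               act=ReluActivation())
    # Second input: a size-2 data layer holding the input image's width and height,
    # matching how the C++ PriorBoxLayer reads the image shape at this point in
    # the series.
    img_shape = data_layer(name='image_shape', size=2)
    prior = priorbox_layer(input=conv_feat,
                           img_shape=img_shape,
                           min_size=[30.0],
                           max_size=[60.0],
                           aspect_ratio=[2.0],
                           variance=[0.1, 0.1, 0.2, 0.2])

The resulting output holds the generated box coordinates together with their variances, as produced by PriorBoxLayer::forward above.
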
--- proto/ModelConfig.proto | 2 -- python/paddle/trainer/config_parser.py | 4 +++- python/paddle/trainer_config_helpers/layers.py | 18 +++++++++++++++--- 3 files changed, 18 insertions(+), 6 deletions(-) diff --git a/proto/ModelConfig.proto b/proto/ModelConfig.proto index 460a39275fbe4..f28f69641b6e4 100644 --- a/proto/ModelConfig.proto +++ b/proto/ModelConfig.proto @@ -253,8 +253,6 @@ message PriorBoxConfig { repeated uint32 max_size = 2; repeated float aspect_ratio = 3; repeated float variance = 4; - optional bool flip = 5 [default = true]; - optional bool clip = 6 [default = true]; } message LayerInputConfig { diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index 5de524e507bdf..8a82e5d667aa3 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -1577,9 +1577,11 @@ class PrintLayer(LayerBase): def __init__(self, name, inputs): super(PrintLayer, self).__init__(name, 'print', 0, inputs) + @config_layer('priorbox') class PriorBoxLayer(LayerBase): - def __init__(self, name, inputs, size, min_size, max_size, aspect_ratio, variance): + def __init__(self, name, inputs, size, min_size, max_size, aspect_ratio, + variance): super(PriorBoxLayer, self).__init__(name, 'priorbox', 0, inputs) config_assert(len(inputs) == 2, 'PriorBoxLayer must have 2 input') self.config.inputs[0].priorbox_conf.min_size.extend(min_size) diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index f04b5646aab03..80c421aa2ec3b 100644 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -935,8 +935,15 @@ def print_layer(input, name=None): inputs=[l.name for l in input], ) # this layer don't return anything, can not be input of other layer. + @wrap_name_default("priorbox") -def priorbox_layer(input, img_shape, aspect_ratio, variance, min_size, max_size=[], name=None): +def priorbox_layer(input, + img_shape, + aspect_ratio, + variance, + min_size, + max_size=[], + name=None): """ Compute the priorbox and set the variance. This layer is necessary for ssd. @@ -957,7 +964,7 @@ def priorbox_layer(input, img_shape, aspect_ratio, variance, min_size, max_size= """ # plus one for ratio 1. 
num_filters = (len(aspect_ratio) * 2 + 1 + len(max_size)) * 4 - size=(input.size / input.num_filters) * num_filters * 2 + size = (input.size / input.num_filters) * num_filters * 2 Layer( name=name, type=LayerType.PRIORBOX_LAYER, @@ -968,7 +975,12 @@ def priorbox_layer(input, img_shape, aspect_ratio, variance, min_size, max_size= aspect_ratio=aspect_ratio, variance=variance) return LayerOutput( - name, LayerType.PRIORBOX_LAYER, parents=[input, img_shape], num_filters=num_filters, size=size) + name, + LayerType.PRIORBOX_LAYER, + parents=[input, img_shape], + num_filters=num_filters, + size=size) + @wrap_name_default("seq_pooling") @wrap_bias_attr_default(has_bias=False) From a1d2abc16d9c7b42af6dcb41902423ae2904ee9a Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Wed, 14 Dec 2016 18:46:40 +0800 Subject: [PATCH 09/55] add Function --- paddle/math/Function.cpp | 47 +++++++++++++ paddle/math/Function.h | 84 ++++++++++++++++++++++++ paddle/math/cross_map_normal_op.cpp | 46 +++++++++++++ paddle/math/cross_map_normal_op.h | 20 +----- paddle/math/tests/test_matrixCompare.cpp | 15 +++-- 5 files changed, 188 insertions(+), 24 deletions(-) create mode 100644 paddle/math/Function.cpp create mode 100644 paddle/math/Function.h diff --git a/paddle/math/Function.cpp b/paddle/math/Function.cpp new file mode 100644 index 0000000000000..21d2719172870 --- /dev/null +++ b/paddle/math/Function.cpp @@ -0,0 +1,47 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "Function.h" + +namespace paddle { + +template <> +size_t FuncConfig::get(const std::string& key) const { + auto it = valueMap_.find(key); + CHECK(it != valueMap_.end()) << "Cannot find value: '" << key << "'"; + return it->second.s; +} + +template <> +real FuncConfig::get(const std::string& key) const { + auto it = valueMap_.find(key); + CHECK(it != valueMap_.end()) << "Cannot find value: '" << key << "'"; + return it->second.r; +} + +template <> +void FuncConfig::set(const std::string& key, size_t v) { + CHECK(valueMap_.count(key) == 0) << "Duplicated value: " << key; + valueMap_[key].s = v; +} + +template <> +void FuncConfig::set(const std::string& key, real v) { + CHECK(valueMap_.count(key) == 0) << "Duplicated value: " << key; + valueMap_[key].r = v; +} + +ClassRegistrar FunctionBase::funcRegistrar_; + +} // namespace paddle diff --git a/paddle/math/Function.h b/paddle/math/Function.h new file mode 100644 index 0000000000000..b41ba2a13d377 --- /dev/null +++ b/paddle/math/Function.h @@ -0,0 +1,84 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include "paddle/utils/ClassRegistrar.h" +#include "paddle/math/Matrix.h" + +namespace paddle { + +enum DeviceType { + DEVICE_TYPE_UNSPECIFIED = 0, + DEVICE_TYPE_CPU = 1, + DEVICE_TYPE_GPU = 2, +}; + +template +struct MatrixT; + +template <> +struct MatrixT { + using type = CpuMatrix; +}; + +template <> +struct MatrixT { + using type = GpuMatrix; +}; + +typedef std::vector Arguments; + +class FuncConfig { +public: + union value { + size_t s; + real r; + }; + + template + T get(const std::string& key) const; + + template + void set(const std::string& key, T v); + +protected: + std::map valueMap_; +}; + +class FunctionBase { +public: + virtual ~FunctionBase() {} + + virtual void init(const FuncConfig& config) {} + + virtual void calc(const Arguments& inputs, + const Arguments& outputs, + const Arguments& inouts) {} + + static ClassRegistrar funcRegistrar_; +}; + +#define FUNC_NAME(typeName, deviceName) #typeName "-" #deviceName + +#define REGISTER_TYPED_FUNC(typeName, deviceName, className) \ + static InitFunction __reg_type_##typeName([]() { \ + FunctionBase::funcRegistrar_ \ + .registerClass>( \ + FUNC_NAME(typeName, deviceName)); \ + }) + +} // namespace paddle diff --git a/paddle/math/cross_map_normal_op.cpp b/paddle/math/cross_map_normal_op.cpp index be242926aff16..0b7273206381d 100644 --- a/paddle/math/cross_map_normal_op.cpp +++ b/paddle/math/cross_map_normal_op.cpp @@ -128,4 +128,50 @@ void CrossMapNormalGrad::operator()(CpuMatrix& inputsGrad, } } +template +class CrossMapNormalFunc : public FunctionBase { +public: + void init(const FuncConfig& config) override { + size_ = config.get("size"); + scale_ = config.get("scale"); + pow_ = config.get("pow"); + } + + void calc(const Arguments& inputs, + const Arguments& outputs, + const Arguments& inouts) override { + CHECK_EQ(1, inputs.size()); + CHECK_EQ(2, outputs.size()); + CHECK_EQ(0, inouts.size()); + + auto input = dynamic_cast::type&>(inputs[0]); + auto output = + dynamic_cast::type&>(outputs[0]); + auto denom = + dynamic_cast::type&>(outputs[1]); + + CHECK(input.isContiguous()); + CHECK(output.isContiguous()); + CHECK(denom.isContiguous()); + CHECK_EQ(output.getHeight(), input.getHeight()); + CHECK_EQ(output.getWidth(), input.getWidth()); + CHECK_EQ(output.getHeight(), denom.getHeight()); + CHECK_EQ(output.getWidth(), denom.getWidth()); + + // CrossMapNormal cross; + // need: + // size_t channels, + // size_t imgSizeH, + // size_t imgSizeW, + // cross(output, denom, input, ); + } + +private: + size_t size_; + real scale_; + real pow_; +}; + +REGISTER_TYPED_FUNC(CrossMapNormal, CPU, CrossMapNormalFunc); + } // namespace paddle diff --git a/paddle/math/cross_map_normal_op.h b/paddle/math/cross_map_normal_op.h index c2bb95f6b11fb..86f54abde108d 100644 --- a/paddle/math/cross_map_normal_op.h +++ b/paddle/math/cross_map_normal_op.h @@ -14,29 +14,11 @@ limitations under the License. 
*/ #pragma once +#include "Function.h" #include "paddle/math/Matrix.h" namespace paddle { -enum DeviceType { - DEVICE_TYPE_UNSPECIFIED = 0, - DEVICE_TYPE_CPU = 1, - DEVICE_TYPE_GPU = 2, -}; - -template -struct MatrixT; - -template <> -struct MatrixT { - using type = CpuMatrix; -}; - -template <> -struct MatrixT { - using type = GpuMatrix; -}; - template struct CrossMapNormal { void operator()(typename MatrixT::type& outputs, diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index 8d7a4fb94d0a1..0b75785528f5d 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -24,6 +24,7 @@ limitations under the License. */ #include "paddle/utils/Stat.h" #include "TensorCheck.h" #include "paddle/math/cross_map_normal_op.h" +#include "paddle/math/Function.h" using namespace paddle; // NOLINT using namespace std; // NOLINT @@ -1280,6 +1281,15 @@ void testCrossMapNormalFwd( inputsGpu.copyFrom(inputs); outputsGpu.copyFrom(outputs); + FuncConfig config; + config.set("size", (size_t)sizeX); + config.set("scale", scale); + config.set("pow", pow); + FunctionBase* cpu = + FunctionBase::funcRegistrar_.createByType(FUNC_NAME(CrossMapNormal, CPU)); + cpu->init(config); + // cpu->calc(); + CrossMapNormal cpuCross; cpuCross( outputs, denoms, inputs, channels, imgSizeH, imgSizeW, sizeX, scale, pow); @@ -1295,11 +1305,6 @@ void testCrossMapNormalFwd( scale, pow); -#if 0 - outputsGpu.crossMapNormalFwd( - inputsGpu, imgSizeH, imgSizeW, denomsGpu, channels, sizeX, scale, pow); -#endif - TensorCheckErr(outputs, outputsGpu); TensorCheckErr(denoms, denomsGpu); } From ce1d98e083017afadac9fcd9f94f5c59aceaf6c0 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Thu, 15 Dec 2016 10:31:45 +0800 Subject: [PATCH 10/55] Add a Tensor to use as a Function argument --- paddle/math/Function.h | 12 +++++++- paddle/math/cross_map_normal_op.cpp | 37 +++++++++++------------- paddle/math/tests/test_matrixCompare.cpp | 9 ++++-- 3 files changed, 35 insertions(+), 23 deletions(-) diff --git a/paddle/math/Function.h b/paddle/math/Function.h index b41ba2a13d377..539759782be3a 100644 --- a/paddle/math/Function.h +++ b/paddle/math/Function.h @@ -40,7 +40,17 @@ struct MatrixT { using type = GpuMatrix; }; -typedef std::vector Arguments; +typedef std::vector Dims; + +class Tensor { +public: + Tensor(real* data, const Dims& dim) : buf_(data), dims_(dim) {} + + real* buf_; + Dims dims_; +}; + +typedef std::vector Arguments; class FuncConfig { public: diff --git a/paddle/math/cross_map_normal_op.cpp b/paddle/math/cross_map_normal_op.cpp index 0b7273206381d..d55bd78c628f7 100644 --- a/paddle/math/cross_map_normal_op.cpp +++ b/paddle/math/cross_map_normal_op.cpp @@ -144,26 +144,23 @@ class CrossMapNormalFunc : public FunctionBase { CHECK_EQ(2, outputs.size()); CHECK_EQ(0, inouts.size()); - auto input = dynamic_cast::type&>(inputs[0]); - auto output = - dynamic_cast::type&>(outputs[0]); - auto denom = - dynamic_cast::type&>(outputs[1]); - - CHECK(input.isContiguous()); - CHECK(output.isContiguous()); - CHECK(denom.isContiguous()); - CHECK_EQ(output.getHeight(), input.getHeight()); - CHECK_EQ(output.getWidth(), input.getWidth()); - CHECK_EQ(output.getHeight(), denom.getHeight()); - CHECK_EQ(output.getWidth(), denom.getWidth()); - - // CrossMapNormal cross; - // need: - // size_t channels, - // size_t imgSizeH, - // size_t imgSizeW, - // cross(output, denom, input, ); + CHECK_EQ(inputs[0].dims_.size(), 4); + for (size_t i = 0; i < inputs[0].dims_.size(); i++) { + 
CHECK_EQ(inputs[0].dims_[i], outputs[0].dims_[i]); + CHECK_EQ(inputs[0].dims_[i], outputs[1].dims_[i]); + } + + size_t samples = inputs[0].dims_[0]; + size_t channels = inputs[0].dims_[1]; + size_t height = inputs[0].dims_[2]; + size_t width = inputs[0].dims_[3]; + size_t imageSize = channels * height * width; + CpuMatrix input(inputs[0].buf_, samples, imageSize); + CpuMatrix output(outputs[0].buf_, samples, imageSize); + CpuMatrix denom(outputs[1].buf_, samples, imageSize); + + CrossMapNormal cross; + cross(output, denom, input, channels, height, width, size_, scale_, pow_); } private: diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index 0b75785528f5d..cd34ea18a70ea 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -1288,12 +1288,17 @@ void testCrossMapNormalFwd( FunctionBase* cpu = FunctionBase::funcRegistrar_.createByType(FUNC_NAME(CrossMapNormal, CPU)); cpu->init(config); - // cpu->calc(); + Dims dims{ + (size_t)numSamples, (size_t)channels, (size_t)imgSizeH, (size_t)imgSizeW}; + cpu->calc({Tensor(inputs.getData(), dims)}, + {Tensor(outputs.getData(), dims), Tensor(denoms.getData(), dims)}, + {}); +#if 0 CrossMapNormal cpuCross; cpuCross( outputs, denoms, inputs, channels, imgSizeH, imgSizeW, sizeX, scale, pow); - +#endif CrossMapNormal gpuCross; gpuCross(outputsGpu, denomsGpu, From 214343af349f585fc775acb522bbf93bd5427b10 Mon Sep 17 00:00:00 2001 From: livc Date: Thu, 15 Dec 2016 10:54:48 +0800 Subject: [PATCH 11/55] modify details --- doc/howto/deep_model/rnn/rnn_cn.md | 48 +++++++++++++++--------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/doc/howto/deep_model/rnn/rnn_cn.md b/doc/howto/deep_model/rnn/rnn_cn.md index 496a54d0113f8..78779eb39d702 100644 --- a/doc/howto/deep_model/rnn/rnn_cn.md +++ b/doc/howto/deep_model/rnn/rnn_cn.md @@ -32,25 +32,25 @@ yield src_ids, trg_ids, trg_ids_next 配置循环神经网络架构 ----------------------------------------------- -### 简单门控(Simple Gated)循环神经网络 +### 简单门控循环神经网络(Gated Recurrent Neural Network) 循环神经网络在每个时间步骤顺序地处理序列。下面列出了 LSTM 的架构的示例。 ![image](../../../tutorials/sentiment_analysis/bi_lstm.jpg) -一般来说,循环网络从 *t* = 1 到 *t* = *T* 或者相反从 *t* = *T* 到 *t* = 1 执行以下操作。 +一般来说,循环网络从 *t* = 1 到 *t* = *T* 或者反向地从 *t* = *T* 到 *t* = 1 执行以下操作。 *x**t* + 1 = *f**x*(*x**t*),*y**t* = *f**y*(*x**t*) -其中 *f**x*(.) 称为**阶跃函数**,而 *f**y*(.) 称为**输出函数**。在 vanilla 循环神经网络中,阶跃函数和输出函数都非常简单。然而,PaddlePaddle 支持通过修改这两个函数来配置非常复杂的架构。 我们将使用 sequence to sequence 模型演示如何配置复杂的循环神经网络模型。在本节中,我们将使用简单的 vanilla 循环神经网络作为使用`recurrent_group`配置简单循环神经网络的例子。 注意,如果你只需要使用简单的RNN,GRU或LSTM,那么推荐使用`grumemory`和`lstmemory`,因为它们的计算效率比`recurrent_group`更高。 +其中 *f**x*(.) 称为**单步函数**(即单时间步执行的函数,step function),而 *f**y*(.) 
称为**输出函数**。在 vanilla 循环神经网络中,单步函数和输出函数都非常简单。然而,PaddlePaddle 可以通过修改这两个函数来实现复杂的网络配置。我们将使用 sequence to sequence 模型演示如何配置复杂的循环神经网络模型。在本节中,我们将使用简单的 vanilla 循环神经网络作为使用`recurrent_group`配置简单循环神经网络的例子。 注意,如果你只需要使用简单的RNN,GRU或LSTM,那么推荐使用`grumemory`和`lstmemory`,因为它们的计算效率比`recurrent_group`更高。 -对于 vanilla RNN,在每个时间步长,**阶跃函数**为: +对于 vanilla RNN,在每个时间步长,**单步函数**为: *x**t* + 1 = *W**x**x**t* + *W**i**I**t* + *b* 其中 *x**t* 是RNN状态,并且 *I**t* 是输入,*W**x* 和 *W**i* 分别是RNN状态和输入的变换矩阵。*b* 是偏差。它的**输出函数**只需要*x**t*作为输出。 -`recurrent_group`是构建循环神经网络的最重要的工具。 它定义了**阶跃函数**,**输出函数**和循环神经网络的输入。注意,这个函数的`step`参数执行了`step function`(阶跃函数)和`output function`(输出函数): +`recurrent_group`是构建循环神经网络的最重要的工具。 它定义了**单步函数**,**输出函数**和循环神经网络的输入。注意,这个函数的`step`参数需要实现`step function`(单步函数)和`output function`(输出函数): ``` sourceCode @@ -77,9 +77,9 @@ def simple_rnn(input, input=input) ``` -PaddlePaddle 使用“记忆”构造阶跃函数。**记忆(Memory)**是在PaddlePaddle中构造循环神经网络时最重要的概念。 记忆是在阶跃函数中循环使用的状态,例如*x**t* + 1 = *f**x*(*x**t*)。 一个记忆包含**输出**和**输入**。当前时间步处的记忆的输出作为下一时间步记忆的输入。记忆也可以具有**引导层**,其输出被用作记忆的初始值。 在我们的例子中,门控循环单元的输出被用作输出记忆。请注意,`rnn_out`层的名称与`out_mem`的名称相同。这意味着`rnn_out` (*x**t* + 1)的输出被用作`out_mem`记忆的**输出**。 +PaddlePaddle 使用“Memory”(记忆模块)实现单步函数。**Memory**是在PaddlePaddle中构造循环神经网络时最重要的概念。 Memory是在单步函数中循环使用的状态,例如*x**t* + 1 = *f**x*(*x**t*)。 一个Memory包含**输出**和**输入**。当前时间步处的Memory的输出作为下一时间步Memory的输入。Memory也可以具有**boot layer(引导层)**,其输出被用作Memory的初始值。 在我们的例子中,门控循环单元的输出被用作输出Memory。请注意,`rnn_out`层的名称与`out_mem`的名称相同。这意味着`rnn_out` (*x**t* + 1)的输出被用作`out_mem`Memory的**输出**。 -记忆也可以是序列。在这种情况下,在每个时间步中,我们有一个序列作为循环神经网络的状态。这在构造非常复杂的循环神经网络时是有用的。 其他高级功能包括定义多个记忆,以及使用子序列来定义分级循环神经网络架构。 +Memory也可以是序列。在这种情况下,在每个时间步中,我们有一个序列作为循环神经网络的状态。这在构造非常复杂的循环神经网络时是有用的。 其他高级功能包括定义多个Memory,以及使用子序列来定义分级循环神经网络架构。 我们在函数的结尾返回`rnn_out`。 这意味着 `rnn_out` 层的输出被用作门控循环神经网络的**输出**函数。 @@ -89,11 +89,11 @@ PaddlePaddle 使用“记忆”构造阶跃函数。**记忆(Memory)**是在 ![image](../../../tutorials/text_generation/encoder-decoder-attention-model.png) -在这个模型中,源序列 *S* = {*s*1, …, *s**T*} 用双向门控循环神经网络编码。双向门控循环神经网络的隐藏状态 *H**S* = {*H*1, …, *H**T*} 被称为 *编码向量*。解码器是门控循环神经网络。当解读每一个*y**t*时, 这个门控循环神经网络生成一系列权重 *W**S**t* = {*W*1*t*, …, *W**T**t*}, 用于计算编码向量的加权和。加权和用来鉴定符号 *y**t* 的生成。 +在这个模型中,源序列 *S* = {*s*1, …, *s**T*} 用双向门控循环神经网络编码。双向门控循环神经网络的隐藏状态 *H**S* = {*H*1, …, *H**T*} 被称为 *编码向量*。解码器是门控循环神经网络。当解读每一个*y**t*时, 这个门控循环神经网络生成一系列权重 *W**S**t* = {*W*1*t*, …, *W**T**t*}, 用于计算编码向量的加权和。加权和用来生成*y**t*。 模型的编码器部分如下所示。它叫做`grumemory`来表示门控循环神经网络。如果网络架构简单,那么推荐使用循环神经网络的方法,因为它比 `recurrent_group` 更快。我们已经实现了大多数常用的循环神经网络架构,可以参考 [Layers](../../ui/api/trainer_config_helpers/layers_index.html) 了解更多细节。 -我们还将编码向量投射到`decoder_size`维空间,获得反向循环网络的第一个实例,并将其投射到`decoder_size`维空间: +我们还将编码向量投射到 `decoder_size` 维空间。这通过获得反向循环网络的第一个实例,并将其投射到 `decoder_size` 维空间完成: ``` sourceCode # 定义源语句的数据层 @@ -123,7 +123,7 @@ backward_first = first_seq(input=src_backward) decoder_boot = mixed_layer(input=[full_matrix_projection(backward_first)], size=decoder_size, act=TanhActivation()) ``` -解码器使用 `recurrent_group` 来定义循环神经网络。阶跃函数和输出函数在 `gru_decoder_with_attention` 中定义: +解码器使用 `recurrent_group` 来定义循环神经网络。单步函数和输出函数在 `gru_decoder_with_attention` 中定义: ``` sourceCode group_inputs=[StaticInput(input=encoded_vector,is_seq=True), @@ -137,22 +137,22 @@ group_inputs.append(trg_embedding) # 对于配备有注意力机制的解码器,在训练中, # 目标向量(groudtruth)是数据输入, -# 而编码源序列作为无界存储器被访问。 -# StaticInput 意味着不同时间步的相同值, -# 否则它是一个序列的输入,不同时间步的输入是不同的。 +# 而源序列的编码向量可以被无边界的memory访问 +# StaticInput 意味着不同时间步的输入都是相同的值, +# 否则它以一个序列输入,不同时间步的输入是不同的。 # 所有输入序列应该有相同的长度。 decoder = recurrent_group(name=decoder_group_name, step=gru_decoder_with_attention, input=group_inputs) ``` 
-阶跃函数的实现如下所示。首先,它定义解码网络的**记忆**。然后定义 attention,门控循环单元阶跃函数和输出函数: +单步函数的实现如下所示。首先,它定义解码网络的**Memory**。然后定义 attention,门控循环单元单步函数和输出函数: ``` sourceCode def gru_decoder_with_attention(enc_vec, enc_proj, current_word): - # 定义解码器的记忆 - # 记忆的输出定义在 gru_step 内 - # 注意 gru_step 应该与它的记忆名字相同 + # 定义解码器的Memory + # Memory的输出定义在 gru_step 内 + # 注意 gru_step 应该与它的Memory名字相同 decoder_mem = memory(name='gru_decoder', size=decoder_size, boot_layer=decoder_boot) @@ -164,7 +164,7 @@ def gru_decoder_with_attention(enc_vec, enc_proj, current_word): decoder_inputs = mixed_layer(inputs = [full_matrix_projection(context), full_matrix_projection(current_word)], size = decoder_size * 3) - # 定义门控循环单元循环神经网络阶跃函数 + # 定义门控循环单元循环神经网络单步函数 gru_step = gru_step_layer(name='gru_decoder', input=decoder_inputs, output_mem=decoder_mem, @@ -180,13 +180,13 @@ def gru_decoder_with_attention(enc_vec, enc_proj, current_word): 生成序列 ----------------- -训练模型后,我们可以使用它来生成序列。通常的做法是使用**柱搜索(beam search** 生成序列。以下代码片段定义柱搜索算法。注意,`beam_search`函数假设`step`的输出函数返回下一个标志的 softmax 归一化概率向量。我们对模型进行了以下更改。 +训练模型后,我们可以使用它来生成序列。通常的做法是使用**beam search** 生成序列。以下代码片段定义柱搜索算法。注意,`beam_search` 函数假设 `step` 的输出函数返回的是下一个时刻输出词的 softmax 归一化概率向量。我们对模型进行了以下更改。 -- 使用 `GeneratedInput` 来 trg\_embedding。 `GeneratedInput` 计算上一次时间步生成的标记的向量来作为当前时间步的输入。 +- 使用 `GeneratedInput` 来表示 trg\_embedding。 `GeneratedInput` 将上一时间步所生成的词的向量来作为当前时间步的输入。 - 使用 `beam_search` 函数。这个函数需要设置: - `bos_id`: 开始标记。每个句子都以开始标记开头。 - `eos_id`: 结束标记。每个句子都以结束标记结尾。 - - `beam_size`: 柱搜索算法中的柱大小。 + - `beam_size`: beam search 算法中的beam大小。 - `max_length`: 生成序列的最大长度。 - 使用 `seqtext_printer_evaluator` 根据索引矩阵和字典打印文本。这个函数需要设置: - `id_input`: 数据的整数ID,用于标识生成的文件中的相应输出。 @@ -198,9 +198,9 @@ def gru_decoder_with_attention(enc_vec, enc_proj, current_word): ``` sourceCode group_inputs=[StaticInput(input=encoded_vector,is_seq=True), StaticInput(input=encoded_proj,is_seq=True)] -# 在一代中,解码器预测下一目标词基于编码源序列和最后生成的目标词。 -# 编码源序列(编码器输出)必须由只读记忆的 StaticInput 指定。 -# 这里, GeneratedInputs 自动获取上一个被一个开始符号初始化的生成词,例如 。 +# 在生成时,解码器基于编码源序列和最后生成的目标词预测下一目标词。 +# 编码源序列(编码器输出)必须由只读Memory的 StaticInput 指定。 +# 这里, GeneratedInputs 自动获取上一个生成的词,并在最开始初始化为起始词,如 。 trg_embedding = GeneratedInput( size=target_dict_dim, embedding_name='_target_language_embedding', From 660b310f8f246af820c2dd440804501ec12ee955 Mon Sep 17 00:00:00 2001 From: livc Date: Thu, 15 Dec 2016 10:58:53 +0800 Subject: [PATCH 12/55] modify line183 beam search --- doc/howto/deep_model/rnn/rnn_cn.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/howto/deep_model/rnn/rnn_cn.md b/doc/howto/deep_model/rnn/rnn_cn.md index 78779eb39d702..5ec05b2cab9ba 100644 --- a/doc/howto/deep_model/rnn/rnn_cn.md +++ b/doc/howto/deep_model/rnn/rnn_cn.md @@ -180,7 +180,7 @@ def gru_decoder_with_attention(enc_vec, enc_proj, current_word): 生成序列 ----------------- -训练模型后,我们可以使用它来生成序列。通常的做法是使用**beam search** 生成序列。以下代码片段定义柱搜索算法。注意,`beam_search` 函数假设 `step` 的输出函数返回的是下一个时刻输出词的 softmax 归一化概率向量。我们对模型进行了以下更改。 +训练模型后,我们可以使用它来生成序列。通常的做法是使用**beam search** 生成序列。以下代码片段定义 beam search 算法。注意,`beam_search` 函数假设 `step` 的输出函数返回的是下一个时刻输出词的 softmax 归一化概率向量。我们对模型进行了以下更改。 - 使用 `GeneratedInput` 来表示 trg\_embedding。 `GeneratedInput` 将上一时间步所生成的词的向量来作为当前时间步的输入。 - 使用 `beam_search` 函数。这个函数需要设置: From 4ebb3eb759903bf95968b578eec99b1364d3bd10 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Thu, 15 Dec 2016 11:55:35 +0800 Subject: [PATCH 13/55] imporve Function --- paddle/gserver/layers/NormProjectionLayer.cpp | 60 +++++++++++---- paddle/gserver/layers/NormProjectionLayer.h | 4 + paddle/math/Function.cpp | 6 +- paddle/math/Function.h | 14 ++-- 
paddle/math/cross_map_normal_op.cpp | 75 ++++++++++--------- paddle/math/cross_map_normal_op.h | 13 ++++ paddle/math/cross_map_normal_op_gpu.cu | 46 ++++-------- paddle/math/tests/test_matrixCompare.cpp | 21 +++++- 8 files changed, 147 insertions(+), 92 deletions(-) diff --git a/paddle/gserver/layers/NormProjectionLayer.cpp b/paddle/gserver/layers/NormProjectionLayer.cpp index ea301292e0dcc..5dda7ee205f40 100644 --- a/paddle/gserver/layers/NormProjectionLayer.cpp +++ b/paddle/gserver/layers/NormProjectionLayer.cpp @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/utils/Logging.h" #include "paddle/utils/Stat.h" +#include "paddle/math/cross_map_normal_op.h" #include "NormProjectionLayer.h" namespace paddle { @@ -45,6 +46,16 @@ bool CMRProjectionNormLayer::init(const LayerMap& layerMap, /* the size of inputs for norm-layer is 1 */ CHECK_EQ(config_.inputs_size(), 1); + if (useGpu_) { + normal_ = FunctionBase::funcRegistrar_.createByType( + FUNC_NAME(CrossMapNormal, GPU)); + } else { + normal_ = FunctionBase::funcRegistrar_.createByType( + FUNC_NAME(CrossMapNormal, CPU)); + } + normal_->init( + FuncConfig().set("size", size_).set("scale", scale_).set("pow", pow_)); + return true; } @@ -62,10 +73,14 @@ void CMRProjectionNormLayer::forward(PassType passType) { Matrix::resizeOrCreate(denoms_, batchSize, size, /* trans */ false, useGpu_); - denoms_->zeroMem(); - - outV->crossMapNormalFwd( - *input, imgSizeH_, imgSizeW_, *denoms_, channels_, size_, scale_, pow_); + Dims dims{(size_t)batchSize, + (size_t)channels_, + (size_t)imgSizeH_, + (size_t)imgSizeW_}; + normal_->calc( + {Tensor(input->getData(), dims)}, + {Tensor(outV->getData(), dims), Tensor(denoms_->getData(), dims)}, + {}); } void CMRProjectionNormLayer::backward(const UpdateCallback& callback) { @@ -80,15 +95,32 @@ void CMRProjectionNormLayer::backward(const UpdateCallback& callback) { MatrixPtr localOutV = getOutputValue(); MatrixPtr preOutV = inputLayers_[0]->getOutputValue(); - preOutGrad->crossMapNormalBwd(*localGrad, - *denoms_, - *preOutV, - *localOutV, - channels_, - imgSizeH_, - imgSizeW_, - size_, - scale_, - pow_); + if (useGpu_) { + CrossMapNormalGrad crossGrad; + crossGrad(dynamic_cast(*preOutGrad), + dynamic_cast(*preOutV), + dynamic_cast(*localGrad), + dynamic_cast(*localOutV), + dynamic_cast(*denoms_), + channels_, + imgSizeH_, + imgSizeW_, + size_, + scale_, + pow_); + } else { + CrossMapNormalGrad crossGrad; + crossGrad(dynamic_cast(*preOutGrad), + dynamic_cast(*preOutV), + dynamic_cast(*localGrad), + dynamic_cast(*localOutV), + dynamic_cast(*denoms_), + channels_, + imgSizeH_, + imgSizeW_, + size_, + scale_, + pow_); + } } } // namespace paddle diff --git a/paddle/gserver/layers/NormProjectionLayer.h b/paddle/gserver/layers/NormProjectionLayer.h index 0db8e2551f06d..ea44669be3f8b 100644 --- a/paddle/gserver/layers/NormProjectionLayer.h +++ b/paddle/gserver/layers/NormProjectionLayer.h @@ -16,6 +16,7 @@ limitations under the License. 
*/ #include "NormLayer.h" #include "paddle/math/Matrix.h" +#include "paddle/math/Function.h" #include namespace paddle { @@ -39,5 +40,8 @@ class CMRProjectionNormLayer : public ResponseNormLayer { bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); void forward(PassType passType); void backward(const UpdateCallback& callback = nullptr); + +protected: + FunctionBase* normal_; }; } // namespace paddle diff --git a/paddle/math/Function.cpp b/paddle/math/Function.cpp index 21d2719172870..02880e5ea1acb 100644 --- a/paddle/math/Function.cpp +++ b/paddle/math/Function.cpp @@ -31,15 +31,17 @@ real FuncConfig::get(const std::string& key) const { } template <> -void FuncConfig::set(const std::string& key, size_t v) { +FuncConfig& FuncConfig::set(const std::string& key, size_t v) { CHECK(valueMap_.count(key) == 0) << "Duplicated value: " << key; valueMap_[key].s = v; + return *this; } template <> -void FuncConfig::set(const std::string& key, real v) { +FuncConfig& FuncConfig::set(const std::string& key, real v) { CHECK(valueMap_.count(key) == 0) << "Duplicated value: " << key; valueMap_[key].r = v; + return *this; } ClassRegistrar FunctionBase::funcRegistrar_; diff --git a/paddle/math/Function.h b/paddle/math/Function.h index 539759782be3a..f8fab972a6902 100644 --- a/paddle/math/Function.h +++ b/paddle/math/Function.h @@ -46,6 +46,8 @@ class Tensor { public: Tensor(real* data, const Dims& dim) : buf_(data), dims_(dim) {} + real* getData() const { return buf_; } + real* buf_; Dims dims_; }; @@ -63,7 +65,7 @@ class FuncConfig { T get(const std::string& key) const; template - void set(const std::string& key, T v); + FuncConfig& set(const std::string& key, T v); protected: std::map valueMap_; @@ -84,11 +86,11 @@ class FunctionBase { #define FUNC_NAME(typeName, deviceName) #typeName "-" #deviceName -#define REGISTER_TYPED_FUNC(typeName, deviceName, className) \ - static InitFunction __reg_type_##typeName([]() { \ - FunctionBase::funcRegistrar_ \ - .registerClass>( \ - FUNC_NAME(typeName, deviceName)); \ +#define REGISTER_TYPED_FUNC(typeName, deviceName, className) \ + static InitFunction __reg_type_##typeName##deviceName([]() { \ + FunctionBase::funcRegistrar_ \ + .registerClass>( \ + FUNC_NAME(typeName, deviceName)); \ }) } // namespace paddle diff --git a/paddle/math/cross_map_normal_op.cpp b/paddle/math/cross_map_normal_op.cpp index d55bd78c628f7..e520351d2e3b8 100644 --- a/paddle/math/cross_map_normal_op.cpp +++ b/paddle/math/cross_map_normal_op.cpp @@ -18,45 +18,41 @@ namespace paddle { // NCHW template <> -void CrossMapNormal::operator()(CpuMatrix& outputs, - CpuMatrix& denoms, - CpuMatrix& inputs, - size_t channels, - size_t imgSizeH, - size_t imgSizeW, - size_t sizeX, - real scale, - real pow) { - CHECK(outputs.isContiguous()); - CHECK(inputs.isContiguous()); - CHECK(denoms.isContiguous()); - CHECK_EQ(outputs.getHeight(), inputs.getHeight()); - CHECK_EQ(outputs.getWidth(), inputs.getWidth()); - CHECK_EQ(outputs.getHeight(), denoms.getHeight()); - CHECK_EQ(outputs.getWidth(), denoms.getWidth()); - - size_t numSample = inputs.getHeight(); - size_t numCols = inputs.getWidth(); - size_t imageSize = imgSizeH * imgSizeW; - CHECK(imageSize * channels == numCols); - - denoms = denoms.constant(1.0); - const int start = -((int)sizeX - 1) / 2; - const int end = (int)sizeX + start; - for (size_t i = 0; i < numSample; i++) { - real* denomsData = denoms.getData() + i * numCols; - real* inputData = inputs.getData() + i * numCols; +void CrossMapNormal(real* outputs, + real* denoms, + real* 
inputs, + size_t numSamples, + size_t channels, + size_t height, + size_t width, + size_t size, + real scale, + real pow) { + size_t oneImage = height * width; + size_t oneSample = channels * oneImage; + + CpuVector outputsV(numSamples * oneSample, outputs); + CpuVector inputsV(numSamples * oneSample, inputs); + CpuVector denomsV(numSamples * oneSample, denoms); + + denomsV = denomsV.constant(1.0); + const int start = -((int)size - 1) / 2; + const int end = (int)size + start; + for (size_t i = 0; i < numSamples; i++) { + real* oneDenom = denoms + i * oneSample; + real* oneInput = inputs + i * oneSample; for (int c = 0; c < (int)channels; c++) { - CpuVector denom(imageSize, denomsData + c * imageSize); + CpuVector denom(oneImage, oneDenom + c * oneImage); for (int s = start; s < end; s++) { if (c + s >= 0 && c + s < (int)channels) { - CpuVector input(imageSize, inputData + (c + s) * imageSize); + CpuVector input(oneImage, oneInput + (c + s) * oneImage); denom += input.square() * scale; } } } } - outputs = inputs * denoms.pow(-pow); + + outputsV = inputsV * denomsV.pow(-pow); } template <> @@ -154,13 +150,17 @@ class CrossMapNormalFunc : public FunctionBase { size_t channels = inputs[0].dims_[1]; size_t height = inputs[0].dims_[2]; size_t width = inputs[0].dims_[3]; - size_t imageSize = channels * height * width; - CpuMatrix input(inputs[0].buf_, samples, imageSize); - CpuMatrix output(outputs[0].buf_, samples, imageSize); - CpuMatrix denom(outputs[1].buf_, samples, imageSize); - CrossMapNormal cross; - cross(output, denom, input, channels, height, width, size_, scale_, pow_); + CrossMapNormal(outputs[0].getData(), + outputs[1].getData(), + inputs[0].getData(), + samples, + channels, + height, + width, + size_, + scale_, + pow_); } private: @@ -170,5 +170,6 @@ class CrossMapNormalFunc : public FunctionBase { }; REGISTER_TYPED_FUNC(CrossMapNormal, CPU, CrossMapNormalFunc); +REGISTER_TYPED_FUNC(CrossMapNormal, GPU, CrossMapNormalFunc); } // namespace paddle diff --git a/paddle/math/cross_map_normal_op.h b/paddle/math/cross_map_normal_op.h index 86f54abde108d..ef9533485ec9c 100644 --- a/paddle/math/cross_map_normal_op.h +++ b/paddle/math/cross_map_normal_op.h @@ -19,6 +19,18 @@ limitations under the License. 
*/ namespace paddle { +template +void CrossMapNormal(real* outputs, + real* denoms, + real* inputs, + size_t numSamples, + size_t channels, + size_t height, + size_t width, + size_t size, + real scale, + real pow); +#if 0 template struct CrossMapNormal { void operator()(typename MatrixT::type& outputs, @@ -31,6 +43,7 @@ struct CrossMapNormal { real scale, real pow); }; +#endif template struct CrossMapNormalGrad { diff --git a/paddle/math/cross_map_normal_op_gpu.cu b/paddle/math/cross_map_normal_op_gpu.cu index 0a154d97ac02f..9b92974344955 100644 --- a/paddle/math/cross_map_normal_op_gpu.cu +++ b/paddle/math/cross_map_normal_op_gpu.cu @@ -61,45 +61,29 @@ __global__ void KeCMRNormOutput(size_t inputSize, const real* in, } template <> -void CrossMapNormal::operator()(GpuMatrix& outputs, - GpuMatrix& denoms, - GpuMatrix& inputs, - size_t channels, - size_t imgSizeH, - size_t imgSizeW, - size_t sizeX, - real scale, - real pow) { - CHECK(outputs.isContiguous()); - CHECK(inputs.isContiguous()); - CHECK(denoms.isContiguous()); - CHECK_EQ(outputs.getHeight(), inputs.getHeight()); - CHECK_EQ(outputs.getWidth(), inputs.getWidth()); - CHECK_EQ(outputs.getHeight(), denoms.getHeight()); - CHECK_EQ(outputs.getWidth(), denoms.getWidth()); - - size_t numSample = inputs.getHeight(); - size_t numCols = inputs.getWidth(); - CHECK(imgSizeH * imgSizeW * channels == numCols); - - real* inputsData = inputs.getData(); - real* denomsData = denoms.getData(); - real* outputsData = outputs.getData(); - - size_t imageSize = numSample * imgSizeH * imgSizeW; +void CrossMapNormal(real* outputs, + real* denoms, + real* inputs, + size_t numSamples, + size_t channels, + size_t height, + size_t width, + size_t size, + real scale, + real pow) { + size_t imageSize = numSamples * height * width; int blockSize = 1024; int gridSize = (imageSize + 1024 - 1) / 1024; KeCMRNormFillScale<<>> - (imageSize, inputsData, denomsData, - channels, imgSizeH, imgSizeW, sizeX, scale); + (imageSize, inputs, denoms, channels, height, width, size, scale); - size_t inputSize = numSample * imgSizeH * imgSizeW *channels; + size_t inputSize = numSamples * height * width *channels; blockSize = 1024; gridSize = (inputSize + 1024 - 1) / 1024; KeCMRNormOutput<<>> - (inputSize, inputsData, denomsData, -pow, outputsData); + (inputSize, inputs, denoms, -pow, outputs); - CHECK_SYNC("CrossMapNormalFwd"); + CHECK_SYNC("CrossMapNormal"); } __global__ void KeCMRNormDiff(size_t imageSize, const real* bottom_data, diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index cd34ea18a70ea..aac3f757996eb 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -1281,24 +1281,40 @@ void testCrossMapNormalFwd( inputsGpu.copyFrom(inputs); outputsGpu.copyFrom(outputs); +#if 0 FuncConfig config; config.set("size", (size_t)sizeX); config.set("scale", scale); config.set("pow", pow); +#endif FunctionBase* cpu = FunctionBase::funcRegistrar_.createByType(FUNC_NAME(CrossMapNormal, CPU)); - cpu->init(config); + FunctionBase* gpu = + FunctionBase::funcRegistrar_.createByType(FUNC_NAME(CrossMapNormal, GPU)); + cpu->init(FuncConfig() + .set("size", (size_t)sizeX) + .set("scale", scale) + .set("pow", pow)); + gpu->init(FuncConfig() + .set("size", (size_t)sizeX) + .set("scale", scale) + .set("pow", pow)); Dims dims{ (size_t)numSamples, (size_t)channels, (size_t)imgSizeH, (size_t)imgSizeW}; cpu->calc({Tensor(inputs.getData(), dims)}, {Tensor(outputs.getData(), dims), Tensor(denoms.getData(), 
dims)}, {}); + + gpu->calc( + {Tensor(inputsGpu.getData(), dims)}, + {Tensor(outputsGpu.getData(), dims), Tensor(denomsGpu.getData(), dims)}, + {}); #if 0 CrossMapNormal cpuCross; cpuCross( outputs, denoms, inputs, channels, imgSizeH, imgSizeW, sizeX, scale, pow); -#endif + CrossMapNormal gpuCross; gpuCross(outputsGpu, denomsGpu, @@ -1309,6 +1325,7 @@ void testCrossMapNormalFwd( sizeX, scale, pow); +#endif TensorCheckErr(outputs, outputsGpu); TensorCheckErr(denoms, denomsGpu); From 707a9c9bbd67e936efeea134cc6eaf2f5fffe464 Mon Sep 17 00:00:00 2001 From: gaoyuan Date: Thu, 15 Dec 2016 13:33:36 +0800 Subject: [PATCH 14/55] Fix variable name and add the annotation --- paddle/gserver/layers/PriorBox.cpp | 130 ++++++++---------- python/paddle/trainer/config_parser.py | 2 - .../paddle/trainer_config_helpers/layers.py | 10 +- 3 files changed, 63 insertions(+), 79 deletions(-) diff --git a/paddle/gserver/layers/PriorBox.cpp b/paddle/gserver/layers/PriorBox.cpp index 4b8573f05817a..c9194235fd14b 100644 --- a/paddle/gserver/layers/PriorBox.cpp +++ b/paddle/gserver/layers/PriorBox.cpp @@ -17,6 +17,15 @@ limitations under the License. */ #include "paddle/math/BaseMatrix.h" namespace paddle { +/** + * @brief A layer for generate prior box locations and variances. + * - Input: Two and only two input layer are accepted. The input layer must be + * be a data output layer and a convolution output layer. + * - Output: The prior box locations and variances of the input data. + * Reference: + * Wei Liu, Dragomir Anguelov, Dumitru Erhan, Christian Szegedy, Scott Reed, + * Cheng-Yang Fu, Alexander C. Berg. SSD: Single Shot MultiBox Detector + */ class PriorBoxLayer : public Layer { public: @@ -24,106 +33,84 @@ class PriorBoxLayer : public Layer { bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); void forward(PassType passType); void backward(const UpdateCallback& callback) {} - void forwardImp(const Argument& featureMap, const Argument& imageShape); int numPriors_; std::vector minSize_; std::vector maxSize_; std::vector aspectRatio_; std::vector variance_; - std::vector tmpCpuInput_; MatrixPtr buffer_; }; bool PriorBoxLayer::init(const LayerMap& layerMap, const ParameterMap& parameterMap) { Layer::init(layerMap, parameterMap); - auto pb_conf = config_.inputs(0).priorbox_conf(); - std::copy(pb_conf.min_size().begin(), - pb_conf.min_size().end(), + auto pbConf = config_.inputs(0).priorbox_conf(); + std::copy(pbConf.min_size().begin(), + pbConf.min_size().end(), std::back_inserter(minSize_)); - std::copy(pb_conf.max_size().begin(), - pb_conf.max_size().end(), + std::copy(pbConf.max_size().begin(), + pbConf.max_size().end(), std::back_inserter(maxSize_)); - std::copy(pb_conf.aspect_ratio().begin(), - pb_conf.aspect_ratio().end(), + std::copy(pbConf.aspect_ratio().begin(), + pbConf.aspect_ratio().end(), std::back_inserter(aspectRatio_)); - std::copy(pb_conf.variance().begin(), - pb_conf.variance().end(), + std::copy(pbConf.variance().begin(), + pbConf.variance().end(), std::back_inserter(variance_)); // flip - int input_ratio_length = aspectRatio_.size(); - for (int index = 0; index < input_ratio_length; index++) + int inputRatioLength = aspectRatio_.size(); + for (int index = 0; index < inputRatioLength; index++) aspectRatio_.push_back(1 / aspectRatio_[index]); aspectRatio_.push_back(1.); numPriors_ = aspectRatio_.size(); if (maxSize_.size() > 0) numPriors_++; - buffer_ = Matrix::create(1, 1, false, false); - if (useGpu_) { - tmpCpuInput_.reserve(inputLayers_.size()); - for (size_t i = 0; i < 
inputLayers_.size(); i++) { - tmpCpuInput_.push_back(Argument()); - } - } return true; } void PriorBoxLayer::forward(PassType passType) { Layer::forward(passType); - if (useGpu_) { - for (size_t i = 0; i < inputLayers_.size(); i++) { - tmpCpuInput_[i].resizeAndCopyFrom( - getInput(i), false, HPPL_STREAM_DEFAULT); - hl_stream_synchronize(HPPL_STREAM_DEFAULT); - forwardImp(tmpCpuInput_[0], tmpCpuInput_[1]); - } - } else { - forwardImp(getInput(0), getInput(1)); - } -} - -void PriorBoxLayer::forwardImp(const Argument& featureMap, - const Argument& imageShape) { - int layer_width = featureMap.getFrameWidth(); - int layer_height = featureMap.getFrameHeight(); + auto input = getInput(0); + int layerWidth = input.getFrameWidth(); + int layerHeight = input.getFrameHeight(); - MatrixPtr inV1 = imageShape.value; - int image_width = inV1->getElement(0, 0); - int image_height = inV1->getElement(0, 1); - float step_w = static_cast(image_width) / layer_width; - float step_h = static_cast(image_height) / layer_height; - int dim = layer_height * layer_width * numPriors_ * 4; + auto image = getInput(1); + int imageWidth = image.getFrameWidth(); + int imageHeight = image.getFrameHeight(); + float stepW = static_cast(imageWidth) / layerWidth; + float stepH = static_cast(imageHeight) / layerHeight; + int dim = layerHeight * layerWidth * numPriors_ * 4; reserveOutput(1, dim * 2); // use a cpu buffer to compute Matrix::resizeOrCreate(buffer_, 1, dim * 2, false, false); - auto* tmp_ptr = buffer_->getData(); + auto* tmpPtr = buffer_->getData(); int idx = 0; - for (int h = 0; h < layer_height; ++h) { - for (int w = 0; w < layer_width; ++w) { - float center_x = (w + 0.5) * step_w; - float center_y = (h + 0.5) * step_h; - int min_size = 0; + for (int h = 0; h < layerHeight; ++h) { + for (int w = 0; w < layerWidth; ++w) { + float centerX = (w + 0.5) * stepW; + float centerY = (h + 0.5) * stepH; + int minSize = 0; for (size_t s = 0; s < minSize_.size(); s++) { // first prior. - min_size = minSize_[s]; - int box_width = min_size; - int box_height = min_size; + minSize = minSize_[s]; + int boxWidth = minSize; + int boxHeight = minSize; // xmin, ymin, xmax, ymax. - tmp_ptr[idx++] = (center_x - box_width / 2.) / image_width; - tmp_ptr[idx++] = (center_y - box_height / 2.) / image_height; - tmp_ptr[idx++] = (center_x + box_width / 2.) / image_width; - tmp_ptr[idx++] = (center_y + box_height / 2.) / image_height; + tmpPtr[idx++] = (centerX - boxWidth / 2.) / imageWidth; + tmpPtr[idx++] = (centerY - boxHeight / 2.) / imageHeight; + tmpPtr[idx++] = (centerX + boxWidth / 2.) / imageWidth; + tmpPtr[idx++] = (centerY + boxHeight / 2.) / imageHeight; if (maxSize_.size() > 0) { CHECK_EQ(minSize_.size(), maxSize_.size()); // second prior. for (size_t s = 0; s < maxSize_.size(); s++) { - int max_size = maxSize_[s]; - box_width = box_height = sqrt(min_size * max_size); - tmp_ptr[idx++] = (center_x - box_width / 2.) / image_width; - tmp_ptr[idx++] = (center_y - box_height / 2.) / image_height; - tmp_ptr[idx++] = (center_x + box_width / 2.) / image_width; - tmp_ptr[idx++] = (center_y + box_height / 2.) / image_height; + int maxSize = maxSize_[s]; + boxWidth = boxHeight = sqrt(minSize * maxSize); + tmpPtr[idx++] = (centerX - boxWidth / 2.) / imageWidth; + tmpPtr[idx++] = (centerY - boxHeight / 2.) / imageHeight; + tmpPtr[idx++] = (centerX + boxWidth / 2.) / imageWidth; + tmpPtr[idx++] = (centerY + boxHeight / 2.) 
/ imageHeight; } } } @@ -131,27 +118,26 @@ void PriorBoxLayer::forwardImp(const Argument& featureMap, for (size_t r = 0; r < aspectRatio_.size(); r++) { float ar = aspectRatio_[r]; if (fabs(ar - 1.) < 1e-6) continue; - float box_width = min_size * sqrt(ar); - float box_height = min_size / sqrt(ar); - tmp_ptr[idx++] = (center_x - box_width / 2.) / image_width; - tmp_ptr[idx++] = (center_y - box_height / 2.) / image_height; - tmp_ptr[idx++] = (center_x + box_width / 2.) / image_width; - tmp_ptr[idx++] = (center_y + box_height / 2.) / image_height; + float boxWidth = minSize * sqrt(ar); + float boxHeight = minSize / sqrt(ar); + tmpPtr[idx++] = (centerX - boxWidth / 2.) / imageWidth; + tmpPtr[idx++] = (centerY - boxHeight / 2.) / imageHeight; + tmpPtr[idx++] = (centerX + boxWidth / 2.) / imageWidth; + tmpPtr[idx++] = (centerY + boxHeight / 2.) / imageHeight; } } } // clip the prior's coordidate such that it is within [0, 1] for (int d = 0; d < dim; ++d) - tmp_ptr[d] = std::min(std::max(tmp_ptr[d], (float)0.), (float)1.); + tmpPtr[d] = std::min(std::max(tmpPtr[d], (float)0.), (float)1.); // set the variance. - for (int h = 0; h < layer_height; h++) - for (int w = 0; w < layer_width; w++) + for (int h = 0; h < layerHeight; h++) + for (int w = 0; w < layerWidth; w++) for (int i = 0; i < numPriors_; i++) - for (int j = 0; j < 4; j++) tmp_ptr[idx++] = variance_[j]; + for (int j = 0; j < 4; j++) tmpPtr[idx++] = variance_[j]; MatrixPtr outV = getOutputValue(); outV->copyFrom(buffer_->data_, dim * 2); } - REGISTER_LAYER(priorbox, PriorBoxLayer); } // namespace paddle diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index 8a82e5d667aa3..0f7c601fe0d29 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -1589,8 +1589,6 @@ def __init__(self, name, inputs, size, min_size, max_size, aspect_ratio, self.config.inputs[0].priorbox_conf.aspect_ratio.extend(aspect_ratio) self.config.inputs[0].priorbox_conf.variance.extend(variance) self.config.size = size - input_layer0 = self.get_input_layer(0) - input_layer1 = self.get_input_layer(1) @config_layer('data') diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index 80c421aa2ec3b..4bcdb9f35e2f0 100644 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -938,7 +938,7 @@ def print_layer(input, name=None): @wrap_name_default("priorbox") def priorbox_layer(input, - img_shape, + image, aspect_ratio, variance, min_size, @@ -951,8 +951,8 @@ def priorbox_layer(input, :type name: basestring :param input: The input layer. :type input: LayerOutput - :param img_shape: The width and height of the network input image. - :type img_shape: LayerOutput + :param image: The network input image. + :type image: LayerOutput :param aspect_ratio: The aspect ratio. :type aspect_ratio: list :param variance: The bounding box variance. 
@@ -968,7 +968,7 @@ def priorbox_layer(input, Layer( name=name, type=LayerType.PRIORBOX_LAYER, - inputs=[input.name, img_shape.name], + inputs=[input.name, image.name], size=size, min_size=min_size, max_size=max_size, @@ -977,7 +977,7 @@ def priorbox_layer(input, return LayerOutput( name, LayerType.PRIORBOX_LAYER, - parents=[input, img_shape], + parents=[input, image], num_filters=num_filters, size=size) From 520342ed9179a727cbe05ccfaa80cc491acd9eef Mon Sep 17 00:00:00 2001 From: gaoyuan Date: Thu, 15 Dec 2016 15:35:44 +0800 Subject: [PATCH 15/55] Fix code format --- paddle/gserver/layers/PriorBox.cpp | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/paddle/gserver/layers/PriorBox.cpp b/paddle/gserver/layers/PriorBox.cpp index c9194235fd14b..dd52f61c30ece 100644 --- a/paddle/gserver/layers/PriorBox.cpp +++ b/paddle/gserver/layers/PriorBox.cpp @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "Layer.h" -#include "paddle/math/Matrix.h" #include "paddle/math/BaseMatrix.h" +#include "paddle/math/Matrix.h" namespace paddle { /** @@ -100,6 +100,8 @@ void PriorBoxLayer::forward(PassType passType) { tmpPtr[idx++] = (centerY - boxHeight / 2.) / imageHeight; tmpPtr[idx++] = (centerX + boxWidth / 2.) / imageWidth; tmpPtr[idx++] = (centerY + boxHeight / 2.) / imageHeight; + // set the variance. + for (int t = 0; t < 4; t++) tmpPtr[idx++] = variance_[t]; if (maxSize_.size() > 0) { CHECK_EQ(minSize_.size(), maxSize_.size()); @@ -111,6 +113,8 @@ void PriorBoxLayer::forward(PassType passType) { tmpPtr[idx++] = (centerY - boxHeight / 2.) / imageHeight; tmpPtr[idx++] = (centerX + boxWidth / 2.) / imageWidth; tmpPtr[idx++] = (centerY + boxHeight / 2.) / imageHeight; + // set the variance. + for (int t = 0; t < 4; t++) tmpPtr[idx++] = variance_[t]; } } } @@ -124,17 +128,15 @@ void PriorBoxLayer::forward(PassType passType) { tmpPtr[idx++] = (centerY - boxHeight / 2.) / imageHeight; tmpPtr[idx++] = (centerX + boxWidth / 2.) / imageWidth; tmpPtr[idx++] = (centerY + boxHeight / 2.) / imageHeight; + // set the variance. + for (int t = 0; t < 4; t++) tmpPtr[idx++] = variance_[t]; } } } // clip the prior's coordidate such that it is within [0, 1] - for (int d = 0; d < dim; ++d) - tmpPtr[d] = std::min(std::max(tmpPtr[d], (float)0.), (float)1.); - // set the variance. 
- for (int h = 0; h < layerHeight; h++) - for (int w = 0; w < layerWidth; w++) - for (int i = 0; i < numPriors_; i++) - for (int j = 0; j < 4; j++) tmpPtr[idx++] = variance_[j]; + for (int d = 0; d < dim * 2; ++d) + if ((d % 8) < 4) + tmpPtr[d] = std::min(std::max(tmpPtr[d], (float)0.), (float)1.); MatrixPtr outV = getOutputValue(); outV->copyFrom(buffer_->data_, dim * 2); } From d2d0010609b6ba621360973b6c6972b836607de3 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Thu, 15 Dec 2016 16:19:10 +0800 Subject: [PATCH 16/55] add CrossMapNormalGradFunc --- paddle/gserver/layers/NormProjectionLayer.cpp | 41 +++-- paddle/gserver/layers/NormProjectionLayer.h | 7 +- paddle/math/Function.h | 2 +- paddle/math/cross_map_normal_op.cpp | 145 ++++++++++++------ paddle/math/cross_map_normal_op.h | 40 ++--- paddle/math/cross_map_normal_op_gpu.cu | 54 ++----- paddle/math/tests/test_matrixCompare.cpp | 57 ++++--- 7 files changed, 190 insertions(+), 156 deletions(-) diff --git a/paddle/gserver/layers/NormProjectionLayer.cpp b/paddle/gserver/layers/NormProjectionLayer.cpp index 03c6952c30b0c..d6923c2192cf1 100644 --- a/paddle/gserver/layers/NormProjectionLayer.cpp +++ b/paddle/gserver/layers/NormProjectionLayer.cpp @@ -13,10 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "NormProjectionLayer.h" +#include "paddle/math/cross_map_normal_op.h" #include "paddle/utils/Logging.h" #include "paddle/utils/Stat.h" -#include "paddle/math/cross_map_normal_op.h" -#include "NormProjectionLayer.h" namespace paddle { size_t CMRProjectionNormLayer::getSize() { @@ -48,13 +47,23 @@ bool CMRProjectionNormLayer::init(const LayerMap& layerMap, CHECK_EQ(config_.inputs_size(), 1); if (useGpu_) { - normal_ = FunctionBase::funcRegistrar_.createByType( + forward_ = FunctionBase::funcRegistrar_.createByType( FUNC_NAME(CrossMapNormal, GPU)); } else { - normal_ = FunctionBase::funcRegistrar_.createByType( + forward_ = FunctionBase::funcRegistrar_.createByType( FUNC_NAME(CrossMapNormal, CPU)); } - normal_->init( + forward_->init( + FuncConfig().set("size", size_).set("scale", scale_).set("pow", pow_)); + + if (useGpu_) { + backward_ = FunctionBase::funcRegistrar_.createByType( + FUNC_NAME(CrossMapNormalGrad, GPU)); + } else { + backward_ = FunctionBase::funcRegistrar_.createByType( + FUNC_NAME(CrossMapNormalGrad, CPU)); + } + backward_->init( FuncConfig().set("size", size_).set("scale", scale_).set("pow", pow_)); return true; @@ -74,13 +83,13 @@ void CMRProjectionNormLayer::forward(PassType passType) { Matrix::resizeOrCreate(denoms_, batchSize, size, /* trans */ false, useGpu_); - Dims dims{(size_t)batchSize, - (size_t)channels_, - (size_t)imgSizeH_, - (size_t)imgSizeW_}; - normal_->calc( - {Tensor(input->getData(), dims)}, - {Tensor(outV->getData(), dims), Tensor(denoms_->getData(), dims)}, + dims_ = {(size_t)batchSize, + (size_t)channels_, + (size_t)imgSizeH_, + (size_t)imgSizeW_}; + forward_->calc( + {Tensor(input->getData(), dims_)}, + {Tensor(outV->getData(), dims_), Tensor(denoms_->getData(), dims_)}, {}); } @@ -96,6 +105,13 @@ void CMRProjectionNormLayer::backward(const UpdateCallback& callback) { MatrixPtr localOutV = getOutputValue(); MatrixPtr preOutV = inputLayers_[0]->getOutputValue(); + backward_->calc({Tensor(preOutV->getData(), dims_), + Tensor(localOutV->getData(), dims_), + Tensor(localGrad->getData(), dims_), + Tensor(denoms_->getData(), dims_)}, + {Tensor(preOutGrad->getData(), dims_)}, + {}); +#if 0 if (useGpu_) { CrossMapNormalGrad crossGrad; 
crossGrad(dynamic_cast(*preOutGrad), @@ -123,5 +139,6 @@ void CMRProjectionNormLayer::backward(const UpdateCallback& callback) { scale_, pow_); } +#endif } } // namespace paddle diff --git a/paddle/gserver/layers/NormProjectionLayer.h b/paddle/gserver/layers/NormProjectionLayer.h index 1dc3921283ca7..82aa427f8d425 100644 --- a/paddle/gserver/layers/NormProjectionLayer.h +++ b/paddle/gserver/layers/NormProjectionLayer.h @@ -16,9 +16,8 @@ limitations under the License. */ #include #include "NormLayer.h" -#include "paddle/math/Matrix.h" #include "paddle/math/Function.h" -#include +#include "paddle/math/Matrix.h" namespace paddle { @@ -43,6 +42,8 @@ class CMRProjectionNormLayer : public ResponseNormLayer { void backward(const UpdateCallback& callback = nullptr); protected: - FunctionBase* normal_; + Dims dims_; + FunctionBase* forward_; + FunctionBase* backward_; }; } // namespace paddle diff --git a/paddle/math/Function.h b/paddle/math/Function.h index f8fab972a6902..095584c0b19f7 100644 --- a/paddle/math/Function.h +++ b/paddle/math/Function.h @@ -16,8 +16,8 @@ limitations under the License. */ #include #include -#include "paddle/utils/ClassRegistrar.h" #include "paddle/math/Matrix.h" +#include "paddle/utils/ClassRegistrar.h" namespace paddle { diff --git a/paddle/math/cross_map_normal_op.cpp b/paddle/math/cross_map_normal_op.cpp index e520351d2e3b8..8547978c99160 100644 --- a/paddle/math/cross_map_normal_op.cpp +++ b/paddle/math/cross_map_normal_op.cpp @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "cross_map_normal_op.h" +#include "paddle/math/Vector.h" namespace paddle { @@ -56,66 +57,49 @@ void CrossMapNormal(real* outputs, } template <> -void CrossMapNormalGrad::operator()(CpuMatrix& inputsGrad, - CpuMatrix& inputsValue, - CpuMatrix& outputsGrad, - CpuMatrix& outputsValue, - CpuMatrix& denoms, - size_t channels, - size_t imgSizeH, - size_t imgSizeW, - size_t sizeX, - real scale, - real pow) { - CHECK(inputsGrad.isContiguous()); - CHECK(outputsGrad.isContiguous()); - CHECK(denoms.isContiguous()); - CHECK(inputsValue.isContiguous()); - CHECK(outputsValue.isContiguous()); - CHECK_EQ(inputsGrad.getHeight(), outputsGrad.getHeight()); - CHECK_EQ(inputsGrad.getWidth(), outputsGrad.getWidth()); - CHECK_EQ(inputsGrad.getHeight(), denoms.getHeight()); - CHECK_EQ(inputsGrad.getWidth(), denoms.getWidth()); - CHECK_EQ(inputsGrad.getHeight(), inputsValue.getHeight()); - CHECK_EQ(inputsGrad.getWidth(), inputsValue.getWidth()); - CHECK_EQ(inputsGrad.getHeight(), outputsValue.getHeight()); - CHECK_EQ(inputsGrad.getWidth(), outputsValue.getWidth()); - - size_t numSample = inputsGrad.getHeight(); - size_t numCols = inputsGrad.getWidth(); - size_t imageSize = imgSizeH * imgSizeW; - CHECK(imageSize * channels == numCols); - +void CrossMapNormalGrad(real* inputsGrad, + real* inputsValue, + real* outputsValue, + real* outputsGrad, + real* denoms, + size_t numSamples, + size_t channels, + size_t height, + size_t width, + size_t size, + real scale, + real pow) { + size_t oneSample = channels * height * width; std::function oneImage = [=](real* data, size_t offset) { - return CpuVector(imageSize, data + offset); + return CpuVector(height * width, data + offset); }; - const int start = -((int)sizeX) / 2; - const int end = (int)sizeX + start; + const int start = -((int)size) / 2; + const int end = (int)size + start; const real ratio = -(real)2 * scale * pow; - for (size_t i = 0; i < numSample; i++) { - size_t sOffset = i * numCols; - 
real* inputGradData = inputsGrad.getData() + sOffset; - real* inputData = inputsValue.getData() + sOffset; - real* denomData = denoms.getData() + sOffset; - real* outputGradData = outputsGrad.getData() + sOffset; - real* outputData = outputsValue.getData() + sOffset; + for (size_t i = 0; i < numSamples; i++) { + size_t sOffset = i * oneSample; + real* oneInputGrad = inputsGrad + sOffset; + real* oneInputValue = inputsValue + sOffset; + real* oneDenom = denoms + sOffset; + real* oneOutputGrad = outputsGrad + sOffset; + real* oneOutputValue = outputsValue + sOffset; for (int c = 0; c < (int)channels; c++) { - size_t cOffset = c * imageSize; - CpuVector inputGrad = oneImage(inputGradData, cOffset); - CpuVector inputValue = oneImage(inputData, cOffset); - CpuVector denom = oneImage(denomData, cOffset); - CpuVector outputGrad = oneImage(outputGradData, cOffset); + size_t cOffset = c * height * width; + CpuVector inputGrad = oneImage(oneInputGrad, cOffset); + CpuVector inputValue = oneImage(oneInputValue, cOffset); + CpuVector denom = oneImage(oneDenom, cOffset); + CpuVector outputGrad = oneImage(oneOutputGrad, cOffset); inputGrad = inputGrad + denom.pow(-pow) * outputGrad; for (int s = start; s < end; s++) { if (c + s >= 0 && c + s < (int)channels) { - size_t offset = (c + s) * imageSize; - CpuVector output = oneImage(outputData, offset); - CpuVector outputGrad = oneImage(outputGradData, offset); - CpuVector denom = oneImage(denomData, offset); + size_t offset = (c + s) * height * width; + CpuVector output = oneImage(oneOutputValue, offset); + CpuVector outputGrad = oneImage(oneOutputGrad, offset); + CpuVector denom = oneImage(oneDenom, offset); inputGrad += ((outputGrad * output * ratio) / denom) * inputValue; } @@ -124,6 +108,11 @@ void CrossMapNormalGrad::operator()(CpuMatrix& inputsGrad, } } +/** + * \param inputs[0] input value. + * \param outputs[0] output value. + * \param outputs[1] denoms. + */ template class CrossMapNormalFunc : public FunctionBase { public: @@ -169,7 +158,65 @@ class CrossMapNormalFunc : public FunctionBase { real pow_; }; +/** + * \param inputs[0] input value. + * \param inputs[1] output value. + * \param inputs[2] output grad. + * \param inputs[3] denoms. + * \param outputs[0] input grad. 
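For reference, a minimal sketch (not part of the patch; the sizes and config values below are invented) of how a caller wires these inputs/outputs through the function registrar, mirroring the usage in NormProjectionLayer and test_matrixCompare.cpp:

    // Illustrative only: build and invoke the CPU gradient function.
    size_t numSamples = 2, channels = 8, height = 16, width = 16;
    CpuMatrix inputsValue(numSamples, channels * height * width);
    CpuMatrix outputsValue(numSamples, channels * height * width);
    CpuMatrix outputsGrad(numSamples, channels * height * width);
    CpuMatrix denoms(numSamples, channels * height * width);
    CpuMatrix inputsGrad(numSamples, channels * height * width);

    FunctionBase* grad = FunctionBase::funcRegistrar_.createByType(
        FUNC_NAME(CrossMapNormalGrad, CPU));
    grad->init(FuncConfig()
                   .set("size", (size_t)5)
                   .set("scale", (real)0.0001)
                   .set("pow", (real)0.75));

    Dims dims{numSamples, channels, height, width};
    grad->calc({Tensor(inputsValue.getData(), dims),   // inputs[0]: input value
                Tensor(outputsValue.getData(), dims),  // inputs[1]: output value
                Tensor(outputsGrad.getData(), dims),   // inputs[2]: output grad
                Tensor(denoms.getData(), dims)},       // inputs[3]: denoms
               {Tensor(inputsGrad.getData(), dims)},   // outputs[0]: input grad
               {});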
+ */ +template +class CrossMapNormalGradFunc : public FunctionBase { +public: + void init(const FuncConfig& config) override { + size_ = config.get("size"); + scale_ = config.get("scale"); + pow_ = config.get("pow"); + } + + void calc(const Arguments& inputs, + const Arguments& outputs, + const Arguments& inouts) override { + CHECK_EQ(4, inputs.size()); + CHECK_EQ(1, outputs.size()); + CHECK_EQ(0, inouts.size()); + + CHECK_EQ(inputs[0].dims_.size(), 4); + for (size_t i = 0; i < inputs[0].dims_.size(); i++) { + CHECK_EQ(inputs[0].dims_[i], inputs[1].dims_[i]); + CHECK_EQ(inputs[0].dims_[i], inputs[2].dims_[i]); + CHECK_EQ(inputs[0].dims_[i], inputs[3].dims_[i]); + CHECK_EQ(inputs[0].dims_[i], outputs[0].dims_[i]); + } + + size_t samples = inputs[0].dims_[0]; + size_t channels = inputs[0].dims_[1]; + size_t height = inputs[0].dims_[2]; + size_t width = inputs[0].dims_[3]; + + CrossMapNormalGrad(outputs[0].getData(), + inputs[0].getData(), + inputs[1].getData(), + inputs[2].getData(), + inputs[3].getData(), + samples, + channels, + height, + width, + size_, + scale_, + pow_); + } + +private: + size_t size_; + real scale_; + real pow_; +}; + REGISTER_TYPED_FUNC(CrossMapNormal, CPU, CrossMapNormalFunc); REGISTER_TYPED_FUNC(CrossMapNormal, GPU, CrossMapNormalFunc); +REGISTER_TYPED_FUNC(CrossMapNormalGrad, CPU, CrossMapNormalGradFunc); +REGISTER_TYPED_FUNC(CrossMapNormalGrad, GPU, CrossMapNormalGradFunc); } // namespace paddle diff --git a/paddle/math/cross_map_normal_op.h b/paddle/math/cross_map_normal_op.h index ef9533485ec9c..f065208084f1d 100644 --- a/paddle/math/cross_map_normal_op.h +++ b/paddle/math/cross_map_normal_op.h @@ -15,7 +15,6 @@ limitations under the License. */ #pragma once #include "Function.h" -#include "paddle/math/Matrix.h" namespace paddle { @@ -30,34 +29,19 @@ void CrossMapNormal(real* outputs, size_t size, real scale, real pow); -#if 0 -template -struct CrossMapNormal { - void operator()(typename MatrixT::type& outputs, - typename MatrixT::type& denoms, - typename MatrixT::type& inputs, - size_t channels, - size_t imgSizeH, - size_t imgSizeW, - size_t sizeX, - real scale, - real pow); -}; -#endif template -struct CrossMapNormalGrad { - void operator()(typename MatrixT::type& inputsGrad, - typename MatrixT::type& inputsValue, - typename MatrixT::type& outputsGrad, - typename MatrixT::type& outputsValue, - typename MatrixT::type& denoms, - size_t channels, - size_t imgSizeH, - size_t imgSizeW, - size_t sizeX, - real scale, - real pow); -}; +void CrossMapNormalGrad(real* inputsGrad, + real* inputsValue, + real* outputsValue, + real* outputsGrad, + real* denoms, + size_t numSamples, + size_t channels, + size_t height, + size_t width, + size_t size, + real scale, + real pow); } // namespace paddle diff --git a/paddle/math/cross_map_normal_op_gpu.cu b/paddle/math/cross_map_normal_op_gpu.cu index 9b92974344955..6339c04194834 100644 --- a/paddle/math/cross_map_normal_op_gpu.cu +++ b/paddle/math/cross_map_normal_op_gpu.cu @@ -131,48 +131,26 @@ __global__ void KeCMRNormDiff(size_t imageSize, const real* bottom_data, } template <> -void CrossMapNormalGrad::operator()(GpuMatrix& inputsGrad, - GpuMatrix& inputsValue, - GpuMatrix& outputsGrad, - GpuMatrix& outputsValue, - GpuMatrix& denoms, - size_t channels, - size_t imgSizeH, - size_t imgSizeW, - size_t sizeX, - real scale, - real pow) { - CHECK(inputsGrad.isContiguous()); - CHECK(outputsGrad.isContiguous()); - CHECK(denoms.isContiguous()); - CHECK(inputsValue.isContiguous()); - CHECK(outputsValue.isContiguous()); - 
CHECK_EQ(inputsGrad.getHeight(), outputsGrad.getHeight()); - CHECK_EQ(inputsGrad.getWidth(), outputsGrad.getWidth()); - CHECK_EQ(inputsGrad.getHeight(), denoms.getHeight()); - CHECK_EQ(inputsGrad.getWidth(), denoms.getWidth()); - CHECK_EQ(inputsGrad.getHeight(), inputsValue.getHeight()); - CHECK_EQ(inputsGrad.getWidth(), inputsValue.getWidth()); - CHECK_EQ(inputsGrad.getHeight(), outputsValue.getHeight()); - CHECK_EQ(inputsGrad.getWidth(), outputsValue.getWidth()); - - size_t numSample = inputsGrad.getHeight(); - size_t numCols = inputsGrad.getWidth(); - CHECK(imgSizeH * imgSizeW * channels == numCols); - - size_t imageSize = numSample * imgSizeH * imgSizeW; - real* inputsGradData = inputsGrad.getData(); - real* inputsData = inputsValue.getData(); - real* denomsData = denoms.getData(); - real* outputsGradData = outputsGrad.getData(); - real* outputsData = outputsValue.getData(); +void CrossMapNormalGrad(real* inputsGrad, + real* inputsValue, + real* outputsValue, + real* outputsGrad, + real* denoms, + size_t numSamples, + size_t channels, + size_t height, + size_t width, + size_t size, + real scale, + real pow) { + size_t imageSize = numSamples * height * width; int blockSize = 1024; int gridSize = (imageSize + 1024 - 1) / 1024; KeCMRNormDiff <<>> - (imageSize, inputsData, outputsData, denomsData, outputsGradData, channels, - imgSizeH, imgSizeW, sizeX, -pow, 2.0f * pow * scale, inputsGradData); - CHECK_SYNC("KeCMRNormDiff"); + (imageSize, inputsValue, outputsValue, denoms, outputsGrad, channels, + height, width, size, -pow, 2.0f * pow * scale, inputsGrad); + CHECK_SYNC("CrossMapNormalGrad"); } } // namespace paddle diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index 0341d757f31b9..bc146514572a4 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -19,12 +19,11 @@ limitations under the License. 
*/ #include #include "TensorCheck.h" #include "paddle/gserver/tests/TestUtil.h" +#include "paddle/math/Function.h" #include "paddle/math/Matrix.h" #include "paddle/math/SparseMatrix.h" -#include "paddle/utils/Stat.h" -#include "TensorCheck.h" #include "paddle/math/cross_map_normal_op.h" -#include "paddle/math/Function.h" +#include "paddle/utils/Stat.h" #include "paddle/utils/Util.h" using namespace paddle; // NOLINT @@ -1282,12 +1281,6 @@ void testCrossMapNormalFwd( inputsGpu.copyFrom(inputs); outputsGpu.copyFrom(outputs); -#if 0 - FuncConfig config; - config.set("size", (size_t)sizeX); - config.set("scale", scale); - config.set("pow", pow); -#endif FunctionBase* cpu = FunctionBase::funcRegistrar_.createByType(FUNC_NAME(CrossMapNormal, CPU)); FunctionBase* gpu = @@ -1311,22 +1304,6 @@ void testCrossMapNormalFwd( {Tensor(inputsGpu.getData(), dims)}, {Tensor(outputsGpu.getData(), dims), Tensor(denomsGpu.getData(), dims)}, {}); -#if 0 - CrossMapNormal cpuCross; - cpuCross( - outputs, denoms, inputs, channels, imgSizeH, imgSizeW, sizeX, scale, pow); - - CrossMapNormal gpuCross; - gpuCross(outputsGpu, - denomsGpu, - inputsGpu, - channels, - imgSizeH, - imgSizeW, - sizeX, - scale, - pow); -#endif TensorCheckErr(outputs, outputsGpu); TensorCheckErr(denoms, denomsGpu); @@ -1381,6 +1358,35 @@ void testCrossMapNormalBwd( outputsValueGpu.copyFrom(outputsValue); inputsGradGpu.copyFrom(inputsGrad); + FunctionBase* cpu = FunctionBase::funcRegistrar_.createByType( + FUNC_NAME(CrossMapNormalGrad, CPU)); + FunctionBase* gpu = FunctionBase::funcRegistrar_.createByType( + FUNC_NAME(CrossMapNormalGrad, GPU)); + cpu->init(FuncConfig() + .set("size", (size_t)sizeX) + .set("scale", scale) + .set("pow", pow)); + gpu->init(FuncConfig() + .set("size", (size_t)sizeX) + .set("scale", scale) + .set("pow", pow)); + + Dims dims{ + (size_t)numSamples, (size_t)channels, (size_t)imgSizeH, (size_t)imgSizeW}; + cpu->calc({Tensor(inputsValue.getData(), dims), + Tensor(outputsValue.getData(), dims), + Tensor(outputsGrad.getData(), dims), + Tensor(denoms.getData(), dims)}, + {Tensor(inputsGrad.getData(), dims)}, + {}); + + gpu->calc({Tensor(inputsValueGpu.getData(), dims), + Tensor(outputsValueGpu.getData(), dims), + Tensor(outputsGradGpu.getData(), dims), + Tensor(denomsGpu.getData(), dims)}, + {Tensor(inputsGradGpu.getData(), dims)}, + {}); +#if 0 CrossMapNormalGrad cpuCross; cpuCross(inputsGrad, inputsValue, @@ -1406,6 +1412,7 @@ void testCrossMapNormalBwd( sizeX, scale, pow); +#endif TensorCheckErr(inputsGrad, inputsGradGpu); } From 22a5e478f3b6ecc0e43d31abce39a686b6331165 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Thu, 15 Dec 2016 16:36:51 +0800 Subject: [PATCH 17/55] move Function to function dir --- paddle/{math => function}/Function.cpp | 0 paddle/{math => function}/Function.h | 0 paddle/{math => function}/cross_map_normal_op.cpp | 0 paddle/{math => function}/cross_map_normal_op.h | 0 paddle/{math => function}/cross_map_normal_op_gpu.cu | 0 paddle/gserver/layers/NormProjectionLayer.cpp | 1 - paddle/gserver/layers/NormProjectionLayer.h | 2 +- paddle/math/tests/test_matrixCompare.cpp | 3 +-- 8 files changed, 2 insertions(+), 4 deletions(-) rename paddle/{math => function}/Function.cpp (100%) rename paddle/{math => function}/Function.h (100%) rename paddle/{math => function}/cross_map_normal_op.cpp (100%) rename paddle/{math => function}/cross_map_normal_op.h (100%) rename paddle/{math => function}/cross_map_normal_op_gpu.cu (100%) diff --git a/paddle/math/Function.cpp b/paddle/function/Function.cpp similarity index 
100% rename from paddle/math/Function.cpp rename to paddle/function/Function.cpp diff --git a/paddle/math/Function.h b/paddle/function/Function.h similarity index 100% rename from paddle/math/Function.h rename to paddle/function/Function.h diff --git a/paddle/math/cross_map_normal_op.cpp b/paddle/function/cross_map_normal_op.cpp similarity index 100% rename from paddle/math/cross_map_normal_op.cpp rename to paddle/function/cross_map_normal_op.cpp diff --git a/paddle/math/cross_map_normal_op.h b/paddle/function/cross_map_normal_op.h similarity index 100% rename from paddle/math/cross_map_normal_op.h rename to paddle/function/cross_map_normal_op.h diff --git a/paddle/math/cross_map_normal_op_gpu.cu b/paddle/function/cross_map_normal_op_gpu.cu similarity index 100% rename from paddle/math/cross_map_normal_op_gpu.cu rename to paddle/function/cross_map_normal_op_gpu.cu diff --git a/paddle/gserver/layers/NormProjectionLayer.cpp b/paddle/gserver/layers/NormProjectionLayer.cpp index d6923c2192cf1..e69c406993054 100644 --- a/paddle/gserver/layers/NormProjectionLayer.cpp +++ b/paddle/gserver/layers/NormProjectionLayer.cpp @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "NormProjectionLayer.h" -#include "paddle/math/cross_map_normal_op.h" #include "paddle/utils/Logging.h" #include "paddle/utils/Stat.h" diff --git a/paddle/gserver/layers/NormProjectionLayer.h b/paddle/gserver/layers/NormProjectionLayer.h index 82aa427f8d425..3c4876ece609e 100644 --- a/paddle/gserver/layers/NormProjectionLayer.h +++ b/paddle/gserver/layers/NormProjectionLayer.h @@ -16,7 +16,7 @@ limitations under the License. */ #include #include "NormLayer.h" -#include "paddle/math/Function.h" +#include "paddle/function/Function.h" #include "paddle/math/Matrix.h" namespace paddle { diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index bc146514572a4..da7a585484e7e 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -18,11 +18,10 @@ limitations under the License. 
*/ #include #include "TensorCheck.h" +#include "paddle/function/Function.h" #include "paddle/gserver/tests/TestUtil.h" -#include "paddle/math/Function.h" #include "paddle/math/Matrix.h" #include "paddle/math/SparseMatrix.h" -#include "paddle/math/cross_map_normal_op.h" #include "paddle/utils/Stat.h" #include "paddle/utils/Util.h" From 558e86927caa2bbe0bc97b287f9d1abe73cfaaa3 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Thu, 15 Dec 2016 17:12:22 +0800 Subject: [PATCH 18/55] add CMakeLists --- cmake/util.cmake | 1 + paddle/CMakeLists.txt | 1 + paddle/function/CMakeLists.txt | 12 ++++++++++++ paddle/function/cross_map_normal_op.cpp | 4 +++- paddle/gserver/CMakeLists.txt | 8 ++------ 5 files changed, 19 insertions(+), 7 deletions(-) create mode 100644 paddle/function/CMakeLists.txt diff --git a/cmake/util.cmake b/cmake/util.cmake index 38366373c6dbc..03734e7839d74 100644 --- a/cmake/util.cmake +++ b/cmake/util.cmake @@ -96,6 +96,7 @@ function(link_paddle_exe TARGET_NAME) target_circle_link_libraries(${TARGET_NAME} ARCHIVE_START paddle_gserver + paddle_function ${METRIC_LIBS} ARCHIVE_END paddle_pserver diff --git a/paddle/CMakeLists.txt b/paddle/CMakeLists.txt index fb3af8ea92fee..2daea052b01ad 100644 --- a/paddle/CMakeLists.txt +++ b/paddle/CMakeLists.txt @@ -1,4 +1,5 @@ add_subdirectory(cuda) +add_subdirectory(function) add_subdirectory(utils) add_subdirectory(math) add_subdirectory(parameter) diff --git a/paddle/function/CMakeLists.txt b/paddle/function/CMakeLists.txt new file mode 100644 index 0000000000000..8fad0e3ebdfb2 --- /dev/null +++ b/paddle/function/CMakeLists.txt @@ -0,0 +1,12 @@ +file(GLOB FUNCTION_HEADERS . *.h) + +if(NOT WITH_GPU) + file(GLOB FUNCTION_SOURCES . *.cpp) + add_library(paddle_function STATIC ${FUNCTION_SOURCES}) +else() + file(GLOB FUNCTION_SOURCES . 
*.cpp *.cu) + cuda_add_library(paddle_function ${FUNCTION_SOURCES}) +endif() + +add_style_check_target(paddle_function ${FUNCTION_SOURCES}) +add_style_check_target(paddle_function ${FUNCTION_HEADERS}) diff --git a/paddle/function/cross_map_normal_op.cpp b/paddle/function/cross_map_normal_op.cpp index 8547978c99160..0391a58d89f4a 100644 --- a/paddle/function/cross_map_normal_op.cpp +++ b/paddle/function/cross_map_normal_op.cpp @@ -215,8 +215,10 @@ class CrossMapNormalGradFunc : public FunctionBase { }; REGISTER_TYPED_FUNC(CrossMapNormal, CPU, CrossMapNormalFunc); -REGISTER_TYPED_FUNC(CrossMapNormal, GPU, CrossMapNormalFunc); REGISTER_TYPED_FUNC(CrossMapNormalGrad, CPU, CrossMapNormalGradFunc); +#ifndef PADDLE_ONLY_CPU +REGISTER_TYPED_FUNC(CrossMapNormal, GPU, CrossMapNormalFunc); REGISTER_TYPED_FUNC(CrossMapNormalGrad, GPU, CrossMapNormalGradFunc); +#endif } // namespace paddle diff --git a/paddle/gserver/CMakeLists.txt b/paddle/gserver/CMakeLists.txt index a066f80c221ee..4f92150ec84d6 100644 --- a/paddle/gserver/CMakeLists.txt +++ b/paddle/gserver/CMakeLists.txt @@ -27,16 +27,12 @@ if(NOT WITH_GPU) list(REMOVE_ITEM GSERVER_HEADER layers/CudnnConvLayer.h layers/CudnnPoolLayer.h - layers/CudnnBatchNormLayer.h - layers/NormProjectionLayer.h - layers/NormLayer.h) + layers/CudnnBatchNormLayer.h) list(REMOVE_ITEM GSERVER_SOURCES layers/CudnnConvLayer.cpp layers/CudnnPoolLayer.cpp - layers/CudnnBatchNormLayer.cpp - layers/NormProjectionLayer.cpp - layers/NormLayer.cpp) + layers/CudnnBatchNormLayer.cpp) compile_cu_as_cpp(layers/LstmCompute.cu) compile_cu_as_cpp(layers/GruCompute.cu) endif() From d11e2b401348c147b20507863a43b8952f17d6a1 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Thu, 15 Dec 2016 17:33:01 +0800 Subject: [PATCH 19/55] Remove some useless code --- paddle/cuda/include/hl_cnn.h | 56 ------ paddle/cuda/include/stub/hl_cnn_stub.h | 24 --- paddle/cuda/src/hl_cuda_cnn.cu | 120 ------------ paddle/gserver/layers/NormProjectionLayer.cpp | 29 --- paddle/math/Matrix.cpp | 176 ------------------ paddle/math/Matrix.h | 65 ------- paddle/math/tests/test_matrixCompare.cpp | 27 --- 7 files changed, 497 deletions(-) diff --git a/paddle/cuda/include/hl_cnn.h b/paddle/cuda/include/hl_cnn.h index 06ee3b3654b57..c5787630abbe1 100644 --- a/paddle/cuda/include/hl_cnn.h +++ b/paddle/cuda/include/hl_cnn.h @@ -240,62 +240,6 @@ extern void hl_avgpool_backward(const int frameCnt, real* backGrad, const int outStride); -/** - * @brief Cross-map-respose normalize forward. - * - * @param[in] frameCnt batch size of input image. - * @param[in] in input data. - * @param[in] scale buffer. - * @param[out] out output data. - * @param[in] channels number of channel. - * @param[in] height image height. - * @param[in] width image width. - * @param[in] sizeX size. - * @param[in] alpha scale. - * @param[in] beta scale. - * - */ -extern void hl_CMRNorm_forward(size_t frameCnt, - const real* in, - real* scale, - real* out, - size_t channels, - size_t height, - size_t width, - size_t sizeX, - real alpha, - real beta); - -/** - * @brief Cross-map-respose normalize backward. - * - * @param[in] frameCnt batch size of input image. - * @param[in] inV input data. - * @param[in] scale buffer. - * @param[out] outV output value. - * @param[out] outDiff output grad. - * @param[out] inDiff input grad. - * @param[in] channels number of channel. - * @param[in] height image height. - * @param[in] width image width. - * @param[in] sizeX size. - * @param[in] alpha scale. - * @param[in] beta scale. 
- * - */ -extern void hl_CMRNorm_backward(size_t frameCnt, - const real* inV, - const real* scale, - const real* outV, - const real* outDiff, - real* inDiff, - size_t channels, - size_t height, - size_t width, - size_t sizeX, - real alpha, - real beta); - /** * @brief Bilinear interpolation forward. * diff --git a/paddle/cuda/include/stub/hl_cnn_stub.h b/paddle/cuda/include/stub/hl_cnn_stub.h index 52c978735279e..039551c6cc695 100644 --- a/paddle/cuda/include/stub/hl_cnn_stub.h +++ b/paddle/cuda/include/stub/hl_cnn_stub.h @@ -117,30 +117,6 @@ inline void hl_avgpool_backward(const int frameCnt, real* backGrad, const int outStride) {} -inline void hl_CMRNorm_forward(size_t frameCnt, - const real* in, - real* scale, - real* out, - size_t channels, - size_t height, - size_t width, - size_t sizeX, - real alpha, - real beta) {} - -inline void hl_CMRNorm_backward(size_t frameCnt, - const real* inV, - const real* scale, - const real* outV, - const real* outDiff, - real* inDiff, - size_t channels, - size_t height, - size_t width, - size_t sizeX, - real alpha, - real beta) {} - inline void hl_bilinear_forward(const real* inData, const size_t inImgH, const size_t inImgW, diff --git a/paddle/cuda/src/hl_cuda_cnn.cu b/paddle/cuda/src/hl_cuda_cnn.cu index 1516accaae17f..b94f4d8fe4a25 100644 --- a/paddle/cuda/src/hl_cuda_cnn.cu +++ b/paddle/cuda/src/hl_cuda_cnn.cu @@ -381,126 +381,6 @@ void hl_avgpool_backward(const int frameCnt, const real* outGrad, CHECK_SYNC("hl_avgpool_backward failed"); } -__global__ void KeCMRNormFillScale(size_t imageSize, const real* in, - real* scale, size_t channels, - size_t height, size_t width, size_t size, - real alpha) { - const int idx = threadIdx.x + blockIdx.x * blockDim.x; - if (idx < imageSize) { - const int w = idx % width; - const int h = (idx / width) % height; - const int n = idx / width / height; - const int offset = (n * channels * height + h) * width + w; - - in += offset; - scale += offset; - const int step = height * width; - const int pre_pad = (size - 1) / 2; - const int post_pad = size - pre_pad - 1; - - real accum = 0; - int index = 0; - while (index < channels + post_pad) { - if (index < channels) { - accum += in[index * step] * in[index * step]; - } - if (index >= size) { - accum -= in[(index - size) * step] * in[(index - size) * step]; - } - if (index >= post_pad) { - scale[(index - post_pad) * step] = 1. 
+ accum * alpha; - } - ++index; - } - } -} - -__global__ void KeCMRNormOutput(size_t inputSize, const real* in, - const real* scale, real negative_beta, - real* out) { - const int index = threadIdx.x + blockIdx.x * blockDim.x; - if (index < inputSize) { - out[index] = in[index] * pow(scale[index], negative_beta); - } -} - -void hl_CMRNorm_forward(size_t frameCnt, const real* in, real* scale, - real* out, size_t channels, - size_t height, size_t width, size_t sizeX, - real alpha, real beta) { - size_t imageSize = frameCnt * height * width; - int blockSize = 1024; - int gridSize = (imageSize + 1024 - 1) / 1024; - KeCMRNormFillScale<<>> - (imageSize, in, scale, channels, height, width, sizeX, alpha); - - size_t inputSize = frameCnt * height * width *channels; - blockSize = 1024; - gridSize = (inputSize + 1024 - 1) / 1024; - KeCMRNormOutput<<>> - (inputSize, in, scale, beta, out); - CHECK_SYNC("hl_CMRNorm_forward"); -} - -__global__ void KeCMRNormDiff(size_t imageSize, const real* bottom_data, - const real* top_data, const real* scale, - const real* top_diff, size_t channels, - size_t height, size_t width, size_t size, - real negative_beta, real cache_ratio, - real* bottom_diff ) { - const int idx = threadIdx.x + blockIdx.x * blockDim.x; - if (idx < imageSize) { - const int w = idx % width; - const int h = (idx / width) % height; - const int n = idx / width / height; - const int offset = (n * channels * height + h) * width + w; - bottom_data += offset; - top_data += offset; - scale += offset; - top_diff += offset; - bottom_diff += offset; - - const int step = height * width; - const int pre_pad = size - (size + 1) / 2; - const int post_pad = size - pre_pad - 1; - - int index = 0; - real accum = 0; - while (index < channels + post_pad) { - if (index < channels) { - accum += top_diff[index * step] * top_data[index * step] / - scale[index * step]; - } - if (index >= size) { - accum -= top_diff[(index - size) * step] * - top_data[(index - size) * step] / scale[(index - size) * step]; - } - if (index >= post_pad) { - bottom_diff[(index - post_pad) * step] += - top_diff[(index - post_pad) * step] * - pow(scale[(index - post_pad) * step], negative_beta) - cache_ratio * - bottom_data[(index - post_pad) * step] * accum; - } - ++index; - } - } -} - -void hl_CMRNorm_backward(size_t frameCnt, const real* inV, - const real* scale, - const real* outV, const real* outDiff, - real *inDiff, size_t channels, - size_t height, size_t width, size_t sizeX, - real alpha, real beta) { - size_t imageSize = frameCnt * height * width; - int blockSize = 1024; - int gridSize = (imageSize + 1024 - 1) / 1024; - KeCMRNormDiff <<>> - (imageSize, inV, outV, scale, outDiff, channels, - height, width, sizeX, alpha, beta, inDiff); - CHECK_SYNC("hl_CMRNorm_backward"); -} - __global__ void KeBilinearInterpFw(const real* in, const size_t inImgH, const size_t inImgW, diff --git a/paddle/gserver/layers/NormProjectionLayer.cpp b/paddle/gserver/layers/NormProjectionLayer.cpp index e69c406993054..4ff3b805fbb06 100644 --- a/paddle/gserver/layers/NormProjectionLayer.cpp +++ b/paddle/gserver/layers/NormProjectionLayer.cpp @@ -110,34 +110,5 @@ void CMRProjectionNormLayer::backward(const UpdateCallback& callback) { Tensor(denoms_->getData(), dims_)}, {Tensor(preOutGrad->getData(), dims_)}, {}); -#if 0 - if (useGpu_) { - CrossMapNormalGrad crossGrad; - crossGrad(dynamic_cast(*preOutGrad), - dynamic_cast(*preOutV), - dynamic_cast(*localGrad), - dynamic_cast(*localOutV), - dynamic_cast(*denoms_), - channels_, - imgSizeH_, - imgSizeW_, - size_, - 
scale_, - pow_); - } else { - CrossMapNormalGrad crossGrad; - crossGrad(dynamic_cast(*preOutGrad), - dynamic_cast(*preOutV), - dynamic_cast(*localGrad), - dynamic_cast(*localOutV), - dynamic_cast(*denoms_), - channels_, - imgSizeH_, - imgSizeW_, - size_, - scale_, - pow_); - } -#endif } } // namespace paddle diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp index 2cde11dd479dc..a36c31d32b7ca 100644 --- a/paddle/math/Matrix.cpp +++ b/paddle/math/Matrix.cpp @@ -1265,69 +1265,6 @@ void GpuMatrix::avgPoolBackward(Matrix& outGrad, outGrad.getStride()); } -void GpuMatrix::crossMapNormalFwd(Matrix& input, - size_t imgSizeH, - size_t imgSizeW, - Matrix& denoms, - size_t channels, - size_t sizeX, - float scale, - float pow) { - size_t num = input.getHeight(); - size_t height = imgSizeH; - size_t width = imgSizeW; - - CHECK(height * width * channels == input.getWidth()); - CHECK(denoms.getHeight() == input.getHeight() && - denoms.getWidth() == input.getWidth() && input.getHeight() == height_ && - input.getWidth() == width_); - hl_CMRNorm_forward(num, - input.getData(), - denoms.getData(), - data_, - channels, - height, - width, - sizeX, - scale, - -pow); -} - -void GpuMatrix::crossMapNormalBwd(Matrix& localGrad, - Matrix& denoms, - Matrix& preOutV, - Matrix& localOutV, - size_t channels, - size_t imgSizeH, - size_t imgSizeW, - size_t sizeX, - float scale, - float pow) { - size_t num = preOutV.getHeight(); - size_t height = imgSizeH; - size_t width = imgSizeW; - - CHECK(width * height * channels == preOutV.getWidth()); - CHECK(denoms.getHeight() == preOutV.getHeight() && - denoms.getWidth() == preOutV.getWidth() && - preOutV.getHeight() == height_ && preOutV.getWidth() == width_); - CHECK(denoms.getHeight() == localGrad.getHeight() && - denoms.getWidth() == localGrad.getWidth()); - - hl_CMRNorm_backward(num, - preOutV.getData(), - denoms.getData(), - localOutV.getData(), - localGrad.getData(), - data_, - channels, - height, - width, - sizeX, - -pow, - 2.0f * pow * scale); -} - void GpuMatrix::maxSequenceForward(Matrix& input, const IVector& sequence, IVector& index) { @@ -2219,119 +2156,6 @@ void CpuMatrix::avgPoolBackward(Matrix& input, } } -void CpuMatrix::crossMapNormalFwd(Matrix& input, - size_t imgSizeH, - size_t imgSizeW, - Matrix& denoms, - size_t channels, - size_t sizeX, - float scale, - float pow) { - CHECK(isContiguous()); - CHECK(input.isContiguous()); - CHECK(denoms.isContiguous()); - CHECK_EQ(getHeight(), input.getHeight()); - CHECK_EQ(getWidth(), input.getWidth()); - CHECK_EQ(getHeight(), denoms.getHeight()); - CHECK_EQ(getWidth(), denoms.getWidth()); - - size_t numSample = input.getHeight(); - size_t numCols = input.getWidth(); - size_t height = imgSizeH; - size_t width = imgSizeW; - CHECK(height * width * channels == numCols); - - // TODO(hedaoyuan) After commit TensorExpress code, - // Reconstruction this code to remove the temporary memory. 
- CpuMatrix tmp(channels, height * width); - CpuMatrix tmp2(tmp.getData(), 1, channels * height * width); - denoms.zero(); - const int start = -((int)sizeX - 1) / 2; - const int end = (int)sizeX + start; - for (size_t i = 0; i < numSample; i++) { - input.subMatrix(i, 1)->square2(tmp2); - CpuMatrix subDen( - denoms.subMatrix(i, 1)->getData(), channels, height * width); - for (int c = 0; c < (int)channels; c++) { - for (int s = start; s < end; s++) { - if (c + s >= 0 && c + s < (int)channels) { - subDen.subMatrix(c, 1)->add(*tmp.subMatrix(c + s, 1)); - } - } - } - } - - denoms.add(scale, (real)1); - this->pow2(denoms, -pow); - this->dotMul(input); -} - -void CpuMatrix::crossMapNormalBwd(Matrix& localGrad, - Matrix& denoms, - Matrix& preOutV, - Matrix& localOutV, - size_t channels, - size_t imgSizeH, - size_t imgSizeW, - size_t sizeX, - float scale, - float pow) { - CHECK(isContiguous()); - CHECK(localGrad.isContiguous()); - CHECK(denoms.isContiguous()); - CHECK(preOutV.isContiguous()); - CHECK(localOutV.isContiguous()); - CHECK_EQ(getHeight(), localGrad.getHeight()); - CHECK_EQ(getWidth(), localGrad.getWidth()); - CHECK_EQ(getHeight(), denoms.getHeight()); - CHECK_EQ(getWidth(), denoms.getWidth()); - CHECK_EQ(getHeight(), preOutV.getHeight()); - CHECK_EQ(getWidth(), preOutV.getWidth()); - CHECK_EQ(getHeight(), localOutV.getHeight()); - CHECK_EQ(getWidth(), localOutV.getWidth()); - - size_t numSample = getHeight(); - size_t numCols = getWidth(); - size_t height = imgSizeH; - size_t width = imgSizeW; - CHECK(height * width * channels == numCols); - - // TODO(hedaoyuan) After commit TensorExpress code, - // Reconstruction this code to remove the temporary memory. - CpuMatrix tmp(1, height * width); - - const int start = -((int)sizeX) / 2; - const int end = (int)sizeX + start; - const real ratio = -(real)2 * scale * pow; - for (size_t i = 0; i < numSample; i++) { - CpuMatrix inputDiff( - this->subMatrix(i, 1)->getData(), channels, height * width); - CpuMatrix outDiff( - localGrad.subMatrix(i, 1)->getData(), channels, height * width); - CpuMatrix input( - preOutV.subMatrix(i, 1)->getData(), channels, height * width); - CpuMatrix output( - localOutV.subMatrix(i, 1)->getData(), channels, height * width); - CpuMatrix subDen( - denoms.subMatrix(i, 1)->getData(), channels, height * width); - - for (int c = 0; c < (int)channels; c++) { - tmp.pow2(*subDen.subMatrix(c, 1), -pow); - inputDiff.subMatrix(c, 1) - ->addDotMul(tmp, *outDiff.subMatrix(c, 1), (real)1, (real)1); - for (int s = start; s < end; s++) { - if (c + s >= 0 && c + s < (int)channels) { - tmp.dotMul(*outDiff.subMatrix(c + s, 1), *output.subMatrix(c + s, 1)); - tmp.mulScalar(ratio); - tmp.dotDiv(tmp, *subDen.subMatrix(c + s, 1)); - tmp.dotMul(*input.subMatrix(c, 1)); - inputDiff.subMatrix(c, 1)->add(tmp); - } - } - } - } -} - /** * Input: one or more sequences. Each sequence contains some instances. * Output: output size is the number of input sequences (NOT input instances). diff --git a/paddle/math/Matrix.h b/paddle/math/Matrix.h index 5685cb7bcbbb6..62bc1b16fc7b6 100644 --- a/paddle/math/Matrix.h +++ b/paddle/math/Matrix.h @@ -952,31 +952,6 @@ class Matrix : public BaseMatrix { LOG(FATAL) << "Not implemeted"; } - /// normalize-operation. 
- virtual void crossMapNormalFwd(Matrix& input, - size_t imgSizeH, - size_t imgSizeW, - Matrix& denoms, - size_t channels, - size_t sizeX, - float scale, - float pow) { - LOG(FATAL) << "Not implemeted"; - } - - virtual void crossMapNormalBwd(Matrix& localGrad, - Matrix& denoms, - Matrix& preOutV, - Matrix& localOutV, - size_t channels, - size_t imgSizeH, - size_t imgSizeW, - size_t size, - float scale, - float pow) { - LOG(FATAL) << "Not implemeted"; - } - /** * Input: one or more sequences. Each sequence contains some instances. * @@ -1459,26 +1434,6 @@ class GpuMatrix : public Matrix { size_t paddingH, size_t paddingW); - void crossMapNormalFwd(Matrix& input, - size_t imgSizeH, - size_t imgSizeW, - Matrix& denoms, - size_t channels, - size_t sizeX, - float scale, - float pow); - - void crossMapNormalBwd(Matrix& localGrad, - Matrix& denoms, - Matrix& preOutV, - Matrix& localOutV, - size_t channels, - size_t imgSizeH, - size_t imgSizeW, - size_t sizeX, - float scale, - float pow); - void maxSequenceForward(Matrix& input, const IVector& sequence, IVector& index); @@ -1685,26 +1640,6 @@ class CpuMatrix : public Matrix { size_t paddingH, size_t paddingW); - void crossMapNormalFwd(Matrix& input, - size_t imgSizeH, - size_t imgSizeW, - Matrix& denoms, - size_t channels, - size_t sizeX, - float scale, - float pow); - - void crossMapNormalBwd(Matrix& localGrad, - Matrix& denoms, - Matrix& preOutV, - Matrix& localOutV, - size_t channels, - size_t imgSizeH, - size_t imgSizeW, - size_t sizeX, - float scale, - float pow); - void maxSequenceForward(Matrix& input, const IVector& sequence, IVector& index); diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index da7a585484e7e..c89b7ff490232 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -1385,33 +1385,6 @@ void testCrossMapNormalBwd( Tensor(denomsGpu.getData(), dims)}, {Tensor(inputsGradGpu.getData(), dims)}, {}); -#if 0 - CrossMapNormalGrad cpuCross; - cpuCross(inputsGrad, - inputsValue, - outputsGrad, - outputsValue, - denoms, - channels, - imgSizeH, - imgSizeW, - sizeX, - scale, - pow); - - CrossMapNormalGrad gpuCross; - gpuCross(inputsGradGpu, - inputsValueGpu, - outputsGradGpu, - outputsValueGpu, - denomsGpu, - channels, - imgSizeH, - imgSizeW, - sizeX, - scale, - pow); -#endif TensorCheckErr(inputsGrad, inputsGradGpu); } From f13aeb52e9fc666ac1e24acf5315cbdccf108402 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Thu, 15 Dec 2016 20:12:53 +0800 Subject: [PATCH 20/55] fix swig_api --- paddle/api/CMakeLists.txt | 1 + paddle/api/paddle_ld_flags.py | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/paddle/api/CMakeLists.txt b/paddle/api/CMakeLists.txt index 6ad1d79e59b11..ed69bd764f30a 100644 --- a/paddle/api/CMakeLists.txt +++ b/paddle/api/CMakeLists.txt @@ -46,6 +46,7 @@ add_custom_command(OUTPUT ${PROJ_ROOT}/paddle/dist/.timestamp WORKING_DIRECTORY ${PROJ_ROOT}/paddle DEPENDS python_swig_sources paddle_parameter + paddle_function paddle_math paddle_utils paddle_gserver diff --git a/paddle/api/paddle_ld_flags.py b/paddle/api/paddle_ld_flags.py index 51d7dfee58b78..7c8206e3fe097 100644 --- a/paddle/api/paddle_ld_flags.py +++ b/paddle/api/paddle_ld_flags.py @@ -30,8 +30,8 @@ whole_end = "" LIB_DIRS = [ - "math", 'utils', 'parameter', "gserver", "api", "cuda", "pserver", - "trainer" + "math", 'function', 'utils', 'parameter', "gserver", "api", "cuda", + "pserver", "trainer" ] PARENT_LIB_DIRS = ['proto'] @@ -75,6 +75,7 @@ def 
libs_str(self): libs = [ whole_start, "-lpaddle_gserver", + "-lpaddle_function", whole_end, "-lpaddle_pserver", "-lpaddle_trainer_lib", From 1048aee0f7b32b27538f175112e42c9632642648 Mon Sep 17 00:00:00 2001 From: gaoyuan Date: Thu, 15 Dec 2016 20:25:31 +0800 Subject: [PATCH 21/55] Add input layer check --- python/paddle/trainer/config_parser.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index 0f7c601fe0d29..83fda9f709223 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -1584,6 +1584,13 @@ def __init__(self, name, inputs, size, min_size, max_size, aspect_ratio, variance): super(PriorBoxLayer, self).__init__(name, 'priorbox', 0, inputs) config_assert(len(inputs) == 2, 'PriorBoxLayer must have 2 input') + input_layer = self.get_input_layer(1) + config_assert( + input_layer.type == 'data', + 'Expecting the second input layer of an priorbox layer to be ' + 'a data layer') + config_assert(input_layer.width > 0, 'The data layer must set width') + config_assert(input_layer.height > 0, 'The data layer must set height') self.config.inputs[0].priorbox_conf.min_size.extend(min_size) self.config.inputs[0].priorbox_conf.max_size.extend(max_size) self.config.inputs[0].priorbox_conf.aspect_ratio.extend(aspect_ratio) From cee934680467c50d4084dbaf7273a39a40cc832d Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Thu, 15 Dec 2016 21:23:05 +0800 Subject: [PATCH 22/55] add some comments --- paddle/function/cross_map_normal_op.cpp | 5 ++- paddle/function/cross_map_normal_op.h | 34 +++++++++++++++++++ paddle/gserver/layers/Layer.h | 6 ++++ paddle/gserver/layers/NormProjectionLayer.cpp | 18 ++++------ paddle/gserver/layers/NormProjectionLayer.h | 3 -- 5 files changed, 50 insertions(+), 16 deletions(-) diff --git a/paddle/function/cross_map_normal_op.cpp b/paddle/function/cross_map_normal_op.cpp index 0391a58d89f4a..a18c0bb750acf 100644 --- a/paddle/function/cross_map_normal_op.cpp +++ b/paddle/function/cross_map_normal_op.cpp @@ -17,7 +17,6 @@ limitations under the License. */ namespace paddle { -// NCHW template <> void CrossMapNormal(real* outputs, real* denoms, @@ -36,6 +35,10 @@ void CrossMapNormal(real* outputs, CpuVector inputsV(numSamples * oneSample, inputs); CpuVector denomsV(numSamples * oneSample, denoms); + // f(x) = x * ( 1 + scale * SUM((x)^2) )^(-pow) + // x represents inputs + // f(x) represents outputs + // denoms save the intermediate result for backward denomsV = denomsV.constant(1.0); const int start = -((int)size - 1) / 2; const int end = (int)size + start; diff --git a/paddle/function/cross_map_normal_op.h b/paddle/function/cross_map_normal_op.h index f065208084f1d..e935b26e125d3 100644 --- a/paddle/function/cross_map_normal_op.h +++ b/paddle/function/cross_map_normal_op.h @@ -18,6 +18,22 @@ limitations under the License. */ namespace paddle { +/** + * \brief Cross map respose normalize forward. + * The data structure of image data is NCHW. + * + * \param[out] outputs output data. + * \param[in] denoms denoms buffer. + * \param[in] inputs input data. + * \param[in] numSamples batch size of input image. + * \param[in] channels number of channel. + * \param[in] height image height. + * \param[in] width image width. + * \param[in] size size. + * \param[in] scale scale. + * \param[in] pow scale. 
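As a plain-C++ reference for the f(x) formula noted in the comments above (an illustrative sketch only, not part of the patch; float stands in for Paddle's real, and the NCHW indexing is written out explicitly):

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Naive cross-map normalization: for every (n, h, w) position, square-sum
    // the neighboring channels, form denom = 1 + scale * sum, then scale the
    // input by denom^(-pow). The window convention follows the code above.
    void crossMapNormalRef(const std::vector<float>& in, std::vector<float>& out,
                           std::vector<float>& denoms, size_t N, size_t C,
                           size_t H, size_t W, size_t size, float scale,
                           float power) {
      const int start = -((int)size - 1) / 2;
      const int end = (int)size + start;
      for (size_t n = 0; n < N; ++n)
        for (size_t c = 0; c < C; ++c)
          for (size_t i = 0; i < H * W; ++i) {
            float sum = 0;
            for (int s = start; s < end; ++s) {
              int ch = (int)c + s;
              if (ch >= 0 && ch < (int)C) {
                float v = in[(n * C + ch) * H * W + i];
                sum += v * v;  // SUM((x)^2) over the cross-channel window
              }
            }
            size_t idx = (n * C + c) * H * W + i;
            denoms[idx] = 1 + scale * sum;  // intermediate result kept for backward
            out[idx] = in[idx] * std::pow(denoms[idx], -power);
          }
    }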
+ * + */ template void CrossMapNormal(real* outputs, real* denoms, @@ -30,6 +46,24 @@ void CrossMapNormal(real* outputs, real scale, real pow); +/** + * \brief Cross map respose normalize backward. + * The data structure of image data is NCHW. + * + * \param[out] inputsGrad input grad. + * \param[in] inputsValue input value. + * \param[out] outputsValue output value. + * \param[out] outputsGrad output grad. + * \param[in] denoms denoms buffer. + * \param[in] numSamples batch size of input image. + * \param[in] channels number of channel. + * \param[in] height image height. + * \param[in] width image width. + * \param[in] size size. + * \param[in] scale scale. + * \param[in] pow scale. + * + */ template void CrossMapNormalGrad(real* inputsGrad, real* inputsValue, diff --git a/paddle/gserver/layers/Layer.h b/paddle/gserver/layers/Layer.h index 172e558b82945..16f66a2205f49 100644 --- a/paddle/gserver/layers/Layer.h +++ b/paddle/gserver/layers/Layer.h @@ -18,6 +18,7 @@ limitations under the License. */ #include #include #include "ModelConfig.pb.h" +#include "paddle/function/Function.h" #include "paddle/math/CpuSparseMatrix.h" #include "paddle/parameter/Parameter.h" #include "paddle/utils/ClassRegistrar.h" @@ -100,6 +101,11 @@ class Layer { /// Mark input grad in(true) or out(false) of backward function. std::vector markInBackward_; + /// Layer forward function + FunctionBase* forward_; + /// Layer backward function + FunctionBase* backward_; + public: /** * Wait until all input value ready. diff --git a/paddle/gserver/layers/NormProjectionLayer.cpp b/paddle/gserver/layers/NormProjectionLayer.cpp index 4ff3b805fbb06..0f6f9b91d0578 100644 --- a/paddle/gserver/layers/NormProjectionLayer.cpp +++ b/paddle/gserver/layers/NormProjectionLayer.cpp @@ -48,20 +48,17 @@ bool CMRProjectionNormLayer::init(const LayerMap& layerMap, if (useGpu_) { forward_ = FunctionBase::funcRegistrar_.createByType( FUNC_NAME(CrossMapNormal, GPU)); + backward_ = FunctionBase::funcRegistrar_.createByType( + FUNC_NAME(CrossMapNormalGrad, GPU)); } else { forward_ = FunctionBase::funcRegistrar_.createByType( FUNC_NAME(CrossMapNormal, CPU)); + backward_ = FunctionBase::funcRegistrar_.createByType( + FUNC_NAME(CrossMapNormalGrad, CPU)); } forward_->init( FuncConfig().set("size", size_).set("scale", scale_).set("pow", pow_)); - if (useGpu_) { - backward_ = FunctionBase::funcRegistrar_.createByType( - FUNC_NAME(CrossMapNormalGrad, GPU)); - } else { - backward_ = FunctionBase::funcRegistrar_.createByType( - FUNC_NAME(CrossMapNormalGrad, CPU)); - } backward_->init( FuncConfig().set("size", size_).set("scale", scale_).set("pow", pow_)); @@ -74,7 +71,7 @@ void CMRProjectionNormLayer::forward(PassType passType) { /* malloc memory for the output_ if necessary */ /* note: one sample correspond to one row */ MatrixPtr input = inputLayers_[0]->getOutputValue(); - int batchSize = input->getHeight(); + size_t batchSize = input->getHeight(); int size = getSize(); resetOutput(batchSize, size); @@ -82,10 +79,7 @@ void CMRProjectionNormLayer::forward(PassType passType) { Matrix::resizeOrCreate(denoms_, batchSize, size, /* trans */ false, useGpu_); - dims_ = {(size_t)batchSize, - (size_t)channels_, - (size_t)imgSizeH_, - (size_t)imgSizeW_}; + dims_ = {batchSize, channels_, imgSizeH_, imgSizeW_}; forward_->calc( {Tensor(input->getData(), dims_)}, {Tensor(outV->getData(), dims_), Tensor(denoms_->getData(), dims_)}, diff --git a/paddle/gserver/layers/NormProjectionLayer.h b/paddle/gserver/layers/NormProjectionLayer.h index 3c4876ece609e..6b2c5dde0d74d 
100644 --- a/paddle/gserver/layers/NormProjectionLayer.h +++ b/paddle/gserver/layers/NormProjectionLayer.h @@ -16,7 +16,6 @@ limitations under the License. */ #include #include "NormLayer.h" -#include "paddle/function/Function.h" #include "paddle/math/Matrix.h" namespace paddle { @@ -43,7 +42,5 @@ class CMRProjectionNormLayer : public ResponseNormLayer { protected: Dims dims_; - FunctionBase* forward_; - FunctionBase* backward_; }; } // namespace paddle From 5222b586e2db3a4dc46cacf884afae9e4d6e51f2 Mon Sep 17 00:00:00 2001 From: yangwenbo02 Date: Fri, 16 Dec 2016 15:43:40 +0800 Subject: [PATCH 23/55] support UBUNTU MIRROR and modify doc --- .../build_and_install/docker_install_en.rst | 16 ++++++++++++++++ paddle/scripts/docker/Dockerfile | 2 ++ paddle/scripts/docker/Dockerfile.gpu | 2 ++ 3 files changed, 20 insertions(+) diff --git a/doc/getstarted/build_and_install/docker_install_en.rst b/doc/getstarted/build_and_install/docker_install_en.rst index 7633bf4d576ee..1252ff3974e02 100644 --- a/doc/getstarted/build_and_install/docker_install_en.rst +++ b/doc/getstarted/build_and_install/docker_install_en.rst @@ -142,6 +142,22 @@ to install CUDA driver and let Docker knows about it: export DEVICES=$(\ls /dev/nvidia* | xargs -I{} echo '--device {}:{}') docker run ${CUDA_SO} ${DEVICES} -it paddledev/paddle:gpu-latest + +UBUNTU MIRROR +------------- + +Building Paddle Docker image hits some wrong with apt-get update, you +can use other UBUNTU MIRROR instead of the default + +.. code-block:: bash + + cd ~ + git clone https://github.com/PaddlePaddle/Paddle.git + cd Paddle + git submodule update --init --recursive + docker build --build-arg UBUNTU_MIRROR="http://mirrors.163.com" -t paddle:cpu-avx -f paddle/scripts/docker/Dockerfile . + docker build --build-arg UBUNTU_MIRROR="http://mirrors.163.com" -t paddle:gpu-avx -f paddle/scripts/docker/Dockerfile.gpu . 
+ Non-AVX Images -------------- diff --git a/paddle/scripts/docker/Dockerfile b/paddle/scripts/docker/Dockerfile index 207f97c4a69e6..f26055d0d4c99 100644 --- a/paddle/scripts/docker/Dockerfile +++ b/paddle/scripts/docker/Dockerfile @@ -2,6 +2,8 @@ FROM ubuntu:14.04 MAINTAINER PaddlePaddle Authors ARG DEBIAN_FRONTEND=noninteractive +ARG UBUNTU_MIRROR +RUN /bin/bash -c 'if [[ -n ${UBUNTU_MIRROR} ]]; then sed -i 's#http://archive.ubuntu.com#${UBUNTU_MIRROR}#g' /etc/apt/sources.list; fi' RUN apt-get update \ && apt-get install -y cmake libprotobuf-dev protobuf-compiler git \ libgoogle-glog-dev libgflags-dev libgtest-dev \ diff --git a/paddle/scripts/docker/Dockerfile.gpu b/paddle/scripts/docker/Dockerfile.gpu index 33f6adfea2a60..d13b97714727a 100644 --- a/paddle/scripts/docker/Dockerfile.gpu +++ b/paddle/scripts/docker/Dockerfile.gpu @@ -2,6 +2,8 @@ FROM nvidia/cuda:7.5-cudnn5-devel-ubuntu14.04 MAINTAINER PaddlePaddle Authors ARG DEBIAN_FRONTEND=noninteractive +ARG UBUNTU_MIRROR +RUN /bin/bash -c 'if [[ -n ${UBUNTU_MIRROR} ]]; then sed -i 's#http://archive.ubuntu.com#${UBUNTU_MIRROR}#g' /etc/apt/sources.list; fi' RUN apt-get update \ && apt-get install -y cmake libprotobuf-dev protobuf-compiler git \ libgoogle-glog-dev libgflags-dev libgtest-dev \ From 5b746fb183572bc04a0697f3ef9d043849506862 Mon Sep 17 00:00:00 2001 From: yangwenbo02 Date: Fri, 16 Dec 2016 17:23:24 +0800 Subject: [PATCH 24/55] modify doc doc/getstarted/build_and_install/docker_install_en.rst --- .../build_and_install/docker_install_en.rst | 24 ++++++------------- 1 file changed, 7 insertions(+), 17 deletions(-) diff --git a/doc/getstarted/build_and_install/docker_install_en.rst b/doc/getstarted/build_and_install/docker_install_en.rst index 1252ff3974e02..ffda7964702e7 100644 --- a/doc/getstarted/build_and_install/docker_install_en.rst +++ b/doc/getstarted/build_and_install/docker_install_en.rst @@ -39,12 +39,18 @@ The general development workflow with Docker and Bazel is as follows: code. This image contains all the development tools and dependencies of PaddlePaddle. - .. code-block:: bash cd paddle docker build -t paddle:dev -f paddle/scripts/docker/Dockerfile . + Apt-get source errors may occur when building paddle docker image. + **You can specify the UBUNTU MIRROR with :code:`--build-arg UBUNTU_MIRROR` like the example below.** + + .. code-block:: bash + + docker build --build-arg UBUNTU_MIRROR="http://mirrors.163.com" -t paddle:dev -f paddle/scripts/docker/Dockerfile . + 3. Run the image as a container and mounting local source code directory into the container. This allows us to change the code on @@ -142,22 +148,6 @@ to install CUDA driver and let Docker knows about it: export DEVICES=$(\ls /dev/nvidia* | xargs -I{} echo '--device {}:{}') docker run ${CUDA_SO} ${DEVICES} -it paddledev/paddle:gpu-latest - -UBUNTU MIRROR -------------- - -Building Paddle Docker image hits some wrong with apt-get update, you -can use other UBUNTU MIRROR instead of the default - -.. code-block:: bash - - cd ~ - git clone https://github.com/PaddlePaddle/Paddle.git - cd Paddle - git submodule update --init --recursive - docker build --build-arg UBUNTU_MIRROR="http://mirrors.163.com" -t paddle:cpu-avx -f paddle/scripts/docker/Dockerfile . - docker build --build-arg UBUNTU_MIRROR="http://mirrors.163.com" -t paddle:gpu-avx -f paddle/scripts/docker/Dockerfile.gpu . 
- Non-AVX Images -------------- From 36af605a2d13f7be0a8d326144b88d7d2ed5d242 Mon Sep 17 00:00:00 2001 From: yangwenbo02 Date: Fri, 16 Dec 2016 17:33:14 +0800 Subject: [PATCH 25/55] modify doc --- doc/getstarted/build_and_install/docker_install_en.rst | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/doc/getstarted/build_and_install/docker_install_en.rst b/doc/getstarted/build_and_install/docker_install_en.rst index ffda7964702e7..1cc23ac3aa989 100644 --- a/doc/getstarted/build_and_install/docker_install_en.rst +++ b/doc/getstarted/build_and_install/docker_install_en.rst @@ -45,11 +45,14 @@ The general development workflow with Docker and Bazel is as follows: docker build -t paddle:dev -f paddle/scripts/docker/Dockerfile . Apt-get source errors may occur when building paddle docker image. - **You can specify the UBUNTU MIRROR with :code:`--build-arg UBUNTU_MIRROR` like the example below.** + **You can specify the UBUNTU MIRROR with** :code:`--build-arg UBUNTU_MIRROR` **like the example below.** .. code-block:: bash - docker build --build-arg UBUNTU_MIRROR="http://mirrors.163.com" -t paddle:dev -f paddle/scripts/docker/Dockerfile . + docker build \ + --build-arg UBUNTU_MIRROR="http://mirrors.163.com" \ + -t paddle:dev \ + -f paddle/scripts/docker/Dockerfile . 3. Run the image as a container and mounting local source code From 9f990d9059dcf1b3536c3060670121a2fe67ce66 Mon Sep 17 00:00:00 2001 From: gaoyuan Date: Fri, 16 Dec 2016 19:15:03 +0800 Subject: [PATCH 26/55] Add unittest of the priorbox layer --- paddle/gserver/layers/PriorBox.cpp | 1 + paddle/gserver/tests/CMakeLists.txt | 8 ++ paddle/gserver/tests/test_PriorBox.cpp | 160 +++++++++++++++++++++++++ python/paddle/trainer/config_parser.py | 3 +- 4 files changed, 171 insertions(+), 1 deletion(-) create mode 100644 paddle/gserver/tests/test_PriorBox.cpp diff --git a/paddle/gserver/layers/PriorBox.cpp b/paddle/gserver/layers/PriorBox.cpp index dd52f61c30ece..ca61dfec5faa0 100644 --- a/paddle/gserver/layers/PriorBox.cpp +++ b/paddle/gserver/layers/PriorBox.cpp @@ -76,6 +76,7 @@ void PriorBoxLayer::forward(PassType passType) { auto image = getInput(1); int imageWidth = image.getFrameWidth(); int imageHeight = image.getFrameHeight(); + float stepW = static_cast(imageWidth) / layerWidth; float stepH = static_cast(imageHeight) / layerHeight; int dim = layerHeight * layerWidth * numPriors_ * 4; diff --git a/paddle/gserver/tests/CMakeLists.txt b/paddle/gserver/tests/CMakeLists.txt index 34dc375f21a54..c26a2a7f06bc1 100644 --- a/paddle/gserver/tests/CMakeLists.txt +++ b/paddle/gserver/tests/CMakeLists.txt @@ -34,6 +34,14 @@ add_unittest_without_exec(test_ConvTrans add_test(NAME test_ConvTrans COMMAND test_ConvTrans) +################# test_PriorBox ####################### +add_unittest_without_exec(test_PriorBox + test_PriorBox.cpp + LayerGradUtil.cpp + TestUtil.cpp) + +add_test(NAME test_PriorBox + COMMAND test_PriorBox) ################# test_ConvUnify ####################### add_unittest_without_exec(test_ConvUnify test_ConvUnify.cpp diff --git a/paddle/gserver/tests/test_PriorBox.cpp b/paddle/gserver/tests/test_PriorBox.cpp new file mode 100644 index 0000000000000..fd63be2f8e4e1 --- /dev/null +++ b/paddle/gserver/tests/test_PriorBox.cpp @@ -0,0 +1,160 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +#include "LayerGradUtil.h" +#include "TestUtil.h" + +using namespace paddle; // NOLINT +using namespace std; // NOLINT + +P_DECLARE_bool(use_gpu); +P_DECLARE_int32(gpu_id); +P_DECLARE_bool(thread_local_rand_use_global_seed); + +// Do one forward pass of priorBox layer and check to see if its output +// matches the given result +void doOnePriorBoxTest(size_t featureMapWidth, + size_t featureMapHeight, + size_t imageWidth, + size_t imageHeight, + vector minSize, + vector maxSize, + vector aspectRatio, + vector variance, + MatrixPtr& result) { + // Setting up the priorbox layer + TestConfig configt; + configt.layerConfig.set_type("priorbox"); + + configt.inputDefs.push_back({INPUT_DATA, "featureMap", 1, 0}); + LayerInputConfig* input = configt.layerConfig.add_inputs(); + configt.inputDefs.push_back({INPUT_DATA, "image", 1, 0}); + configt.layerConfig.add_inputs(); + PriorBoxConfig* pb = input->mutable_priorbox_conf(); + for (size_t i = 0; i < minSize.size(); i++) pb->add_min_size(minSize[i]); + for (size_t i = 0; i < maxSize.size(); i++) pb->add_max_size(maxSize[i]); + for (size_t i = 0; i < aspectRatio.size(); i++) + pb->add_aspect_ratio(aspectRatio[i]); + for (size_t i = 0; i < variance.size(); i++) pb->add_variance(variance[i]); + + // data layer initialize + std::vector dataLayers; + LayerMap layerMap; + vector datas; + initDataLayer( + configt, &dataLayers, &datas, &layerMap, "priorbox", 1, false, true); + dataLayers[0]->getOutput().setFrameHeight(featureMapHeight); + dataLayers[0]->getOutput().setFrameWidth(featureMapWidth); + dataLayers[1]->getOutput().setFrameHeight(imageHeight); + dataLayers[1]->getOutput().setFrameWidth(imageWidth); + + // test layer initialize + std::vector parameters; + LayerPtr priorboxLayer; + initTestLayer(configt, &layerMap, ¶meters, &priorboxLayer); + + priorboxLayer->forward(PASS_GC); + checkMatrixEqual(priorboxLayer->getOutputValue(), result); +} + +TEST(Layer, priorBoxLayerFwd) { + vector minSize; + vector maxSize; + vector aspectRatio; + vector variance; + + minSize.push_back(276); + maxSize.push_back(330); + variance.push_back(0.1); + variance.push_back(0.1); + variance.push_back(0.2); + variance.push_back(0.2); + + MatrixPtr result; + result = Matrix::create(1, 2 * 8, false, false); + + float resultData[] = {0.04, + 0.04, + 0.96, + 0.96, + 0.1, + 0.1, + 0.2, + 0.2, + 0, + 0, + 1, + 1, + 0.1, + 0.1, + 0.2, + 0.2}; + result->setData(resultData); + doOnePriorBoxTest(/* featureMapWidth */ 1, + /* featureMapHeight */ 1, + /* imageWidth */ 300, + /* imageHeight */ 300, + minSize, + maxSize, + aspectRatio, + variance, + result); + + variance[1] = 0.2; + variance[3] = 0.1; + maxSize.pop_back(); + Matrix::resizeOrCreate(result, 1, 4 * 8, false, false); + float resultData2[] = {0, 0, 0.595, 0.595, 0.1, 0.2, 0.2, 0.1, + 0.405, 0, 1, 0.595, 0.1, 0.2, 0.2, 0.1, + 0, 0.405, 0.595, 1, 0.1, 0.2, 0.2, 0.1, + 0.405, 0.405, 1, 1, 0.1, 0.2, 0.2, 0.1}; + result->setData(resultData2); + doOnePriorBoxTest(/* featureMapWidth */ 2, + /* featureMapHeight */ 2, + /* imageWidth */ 400, + /* imageHeight */ 400, + minSize, + maxSize, + aspectRatio, + variance, + 
result); + + aspectRatio.push_back(2); + Matrix::resizeOrCreate(result, 1, 3 * 8, false, false); + float resultData3[] = {0.04, 0.04, 0.96, 0.96, 0.1, 0.2, + 0.2, 0.1, 0, 0.17473088, 1, 0.825269, + 0.1, 0.2, 0.2, 0.1, 0.17473088, 0, + 0.825269, 1, 0.1, 0.2, 0.2, 0.1}; + result->setData(resultData3); + doOnePriorBoxTest(/* featureMapWidth */ 1, + /* featureMapHeight */ 1, + /* imageWidth */ 300, + /* imageHeight */ 300, + minSize, + maxSize, + aspectRatio, + variance, + result); +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + initMain(argc, argv); + FLAGS_thread_local_rand_use_global_seed = true; + srand(1); + return RUN_ALL_TESTS(); +} diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index 83fda9f709223..4056ecdf795ae 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -1583,7 +1583,7 @@ class PriorBoxLayer(LayerBase): def __init__(self, name, inputs, size, min_size, max_size, aspect_ratio, variance): super(PriorBoxLayer, self).__init__(name, 'priorbox', 0, inputs) - config_assert(len(inputs) == 2, 'PriorBoxLayer must have 2 input') + config_assert(len(inputs) == 2, 'PriorBoxLayer must have 2 inputs') input_layer = self.get_input_layer(1) config_assert( input_layer.type == 'data', @@ -1591,6 +1591,7 @@ def __init__(self, name, inputs, size, min_size, max_size, aspect_ratio, 'a data layer') config_assert(input_layer.width > 0, 'The data layer must set width') config_assert(input_layer.height > 0, 'The data layer must set height') + config_assert(len(variance) == 4, 'The variance must have 4 inputs') self.config.inputs[0].priorbox_conf.min_size.extend(min_size) self.config.inputs[0].priorbox_conf.max_size.extend(max_size) self.config.inputs[0].priorbox_conf.aspect_ratio.extend(aspect_ratio) From 8d9f67591022655fb62401c470825b319573920c Mon Sep 17 00:00:00 2001 From: gaoyuan Date: Fri, 16 Dec 2016 19:54:09 +0800 Subject: [PATCH 27/55] Add header files --- paddle/gserver/tests/test_PriorBox.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/paddle/gserver/tests/test_PriorBox.cpp b/paddle/gserver/tests/test_PriorBox.cpp index fd63be2f8e4e1..8aabb1ef97a21 100644 --- a/paddle/gserver/tests/test_PriorBox.cpp +++ b/paddle/gserver/tests/test_PriorBox.cpp @@ -12,8 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ +#include #include #include +#include "ModelConfig.pb.h" +#include "paddle/gserver/layers/DataLayer.h" +#include "paddle/math/MathUtils.h" +#include "paddle/trainer/Trainer.h" +#include "paddle/utils/GlobalConstants.h" #include "LayerGradUtil.h" #include "TestUtil.h" From d40bb72da8345a25de087ea4eea07984f675d241 Mon Sep 17 00:00:00 2001 From: livc Date: Fri, 16 Dec 2016 20:38:32 +0800 Subject: [PATCH 28/55] modify rnn_config_cn.rst --- doc/howto/deep_model/rnn_config_cn.rst | 287 +++++++++++++++++++++++++ 1 file changed, 287 insertions(+) create mode 100644 doc/howto/deep_model/rnn_config_cn.rst diff --git a/doc/howto/deep_model/rnn_config_cn.rst b/doc/howto/deep_model/rnn_config_cn.rst new file mode 100644 index 0000000000000..e6d8c1133a5e8 --- /dev/null +++ b/doc/howto/deep_model/rnn_config_cn.rst @@ -0,0 +1,287 @@ +RNN 配置 +======== + +本教程将指导你如何在 PaddlePaddle +中配置循环神经网络(RNN)。PaddlePaddle +高度支持灵活和高效的循环神经网络配置。 在本教程中,您将了解如何: + +- 准备用来学习循环神经网络的序列数据。 +- 配置循环神经网络架构。 +- 使用学习完成的循环神经网络模型生成序列。 + +我们将使用 vanilla 循环神经网络和 sequence to sequence +模型来指导你完成这些步骤。sequence to sequence +模型的代码可以在\ ``demo / seqToseq``\ 找到。 + +准备序列数据 +------------ + +PaddlePaddle +不需要对序列数据进行任何预处理,例如填充。唯一需要做的是将相应类型设置为输入。例如,以下代码段定义了三个输入。 +它们都是序列,它们的大小是\ ``src_dict``\ ,\ ``trg_dict``\ 和\ ``trg_dict``\ : + +.. code:: sourcecode + + settings.input_types = [ + integer_value_sequence(len(settings.src_dict)), + integer_value_sequence(len(settings.trg_dict)), + integer_value_sequence(len(settings.trg_dict))] + +在\ ``process``\ 函数中,每个\ ``yield``\ 函数将返回三个整数列表。每个整数列表被视为一个整数序列: + +.. code:: sourcecode + + yield src_ids, trg_ids, trg_ids_next + +有关如何编写数据提供程序的更多细节描述,请参考 +`PyDataProvider2 <../../ui/data_provider/index.html>`__\ 。完整的数据提供文件在 +``demo/seqToseq/dataprovider.py``\ 。 + +配置循环神经网络架构 +-------------------- + +简单门控循环神经网络(Gated Recurrent Neural Network) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +循环神经网络在每个时间步骤顺序地处理序列。下面列出了 LSTM 的架构的示例。 + +.. figure:: ../../../tutorials/sentiment_analysis/bi_lstm.jpg + :alt: image + + image + +一般来说,循环网络从 *t* = 1 到 *t* = *T* 或者反向地从 *t* = *T* 到 *t* += 1 执行以下操作。 + +*x*\ \ *t* + 1 = *f*\ \ *x*\ (*x*\ \ *t*\ ),\ *y*\ \ *t*\  = *f*\ \ *y*\ (*x*\ \ *t*\ ) + +其中 *f*\ \ *x*\ (.) 称为\ **单步函数**\ (即单时间步执行的函数,step +function),而 *f*\ \ *y*\ (.) 称为\ **输出函数**\ 。在 vanilla +循环神经网络中,单步函数和输出函数都非常简单。然而,PaddlePaddle +可以通过修改这两个函数来实现复杂的网络配置。我们将使用 sequence to +sequence +模型演示如何配置复杂的循环神经网络模型。在本节中,我们将使用简单的 +vanilla +循环神经网络作为使用\ ``recurrent_group``\ 配置简单循环神经网络的例子。 +注意,如果你只需要使用简单的RNN,GRU或LSTM,那么推荐使用\ ``grumemory``\ 和\ ``lstmemory``\ ,因为它们的计算效率比\ ``recurrent_group``\ 更高。 + +对于 vanilla RNN,在每个时间步长,\ **单步函数**\ 为: + +*x*\ \ *t* + 1 = *W*\ \ *x*\ \ *x*\ \ *t*\  + *W*\ \ *i*\ \ *I*\ \ *t*\  + *b* + +其中 *x*\ \ *t*\ 是RNN状态,并且 *I*\ \ *t*\ 是输入,\ *W*\ \ *x*\ 和 +*W*\ \ *i*\ 分别是RNN状态和输入的变换矩阵。\ *b* +是偏差。它的\ **输出函数**\ 只需要\ *x*\ \ *t*\ 作为输出。 + +``recurrent_group``\ 是构建循环神经网络的最重要的工具。 +它定义了\ **单步函数**\ ,\ **输出函数**\ 和循环神经网络的输入。注意,这个函数的\ ``step``\ 参数需要实现\ ``step function``\ (单步函数)和\ ``output function``\ (输出函数): + +.. 
code:: sourcecode + + def simple_rnn(input, + size=None, + name=None, + reverse=False, + rnn_bias_attr=None, + act=None, + rnn_layer_attr=None): + def __rnn_step__(ipt): + out_mem = memory(name=name, size=size) + rnn_out = mixed_layer(input = [full_matrix_projection(ipt), + full_matrix_projection(out_mem)], + name = name, + bias_attr = rnn_bias_attr, + act = act, + layer_attr = rnn_layer_attr, + size = size) + return rnn_out + return recurrent_group(name='%s_recurrent_group' % name, + step=__rnn_step__, + reverse=reverse, + input=input) + +PaddlePaddle +使用“Memory”(记忆模块)实现单步函数。\ **Memory**\ 是在PaddlePaddle中构造循环神经网络时最重要的概念。 +Memory是在单步函数中循环使用的状态,例如\ *x*\ \ *t* + 1 = *f*\ \ *x*\ (*x*\ \ *t*\ )。 +一个Memory包含\ **输出**\ 和\ **输入**\ 。当前时间步处的Memory的输出作为下一时间步Memory的输入。Memory也可以具有\ **boot +layer(引导层)**\ ,其输出被用作Memory的初始值。 +在我们的例子中,门控循环单元的输出被用作输出Memory。请注意,\ ``rnn_out``\ 层的名称与\ ``out_mem``\ 的名称相同。这意味着\ ``rnn_out`` +(*x*\ \ *t* + 1)的输出被用作\ ``out_mem``\ Memory的\ **输出**\ 。 + +Memory也可以是序列。在这种情况下,在每个时间步中,我们有一个序列作为循环神经网络的状态。这在构造非常复杂的循环神经网络时是有用的。 +其他高级功能包括定义多个Memory,以及使用子序列来定义分级循环神经网络架构。 + +我们在函数的结尾返回\ ``rnn_out``\ 。 这意味着 ``rnn_out`` +层的输出被用作门控循环神经网络的\ **输出**\ 函数。 + +Sequence to Sequence Model with Attention +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +我们将使用 sequence to sequence model with attention +作为例子演示如何配置复杂的循环神经网络模型。该模型的说明如下图所示。 + +.. figure:: ../../../tutorials/text_generation/encoder-decoder-attention-model.png + :alt: image + + image + +在这个模型中,源序列 *S* = {*s*\ 1, …, \ *s*\ \ *T*\ } +用双向门控循环神经网络编码。双向门控循环神经网络的隐藏状态 +*H*\ \ *S*\  = {*H*\ 1, …, \ *H*\ \ *T*\ } 被称为 +*编码向量*\ 。解码器是门控循环神经网络。当解读每一个\ *y*\ \ *t*\ 时, +这个门控循环神经网络生成一系列权重 +*W*\ \ *S*\ \ *t*\  = {*W*\ 1\ *t*\ , …, \ *W*\ \ *T*\ \ *t*\ }, +用于计算编码向量的加权和。加权和用来生成\ *y*\ \ *t*\ 。 + +模型的编码器部分如下所示。它叫做\ ``grumemory``\ 来表示门控循环神经网络。如果网络架构简单,那么推荐使用循环神经网络的方法,因为它比 +``recurrent_group`` +更快。我们已经实现了大多数常用的循环神经网络架构,可以参考 +`Layers <../../ui/api/trainer_config_helpers/layers_index.html>`__ +了解更多细节。 + +我们还将编码向量投射到 ``decoder_size`` +维空间。这通过获得反向循环网络的第一个实例,并将其投射到 +``decoder_size`` 维空间完成: + +.. code:: sourcecode + + # 定义源语句的数据层 + src_word_id = data_layer(name='source_language_word', size=source_dict_dim) + # 计算每个词的词向量 + src_embedding = embedding_layer( + input=src_word_id, + size=word_vector_dim, + param_attr=ParamAttr(name='_source_language_embedding')) + # 应用前向循环神经网络 + src_forward = grumemory(input=src_embedding, size=encoder_size) + # 应用反向递归神经网络(reverse=True表示反向循环神经网络) + src_backward = grumemory(input=src_embedding, + size=encoder_size, + reverse=True) + # 将循环神经网络的前向和反向部分混合在一起 + encoded_vector = concat_layer(input=[src_forward, src_backward]) + + # 投射编码向量到 decoder_size + encoder_proj = mixed_layer(input = [full_matrix_projection(encoded_vector)], + size = decoder_size) + + # 计算反向RNN的第一个实例 + backward_first = first_seq(input=src_backward) + + # 投射反向RNN的第一个实例到 decoder size + decoder_boot = mixed_layer(input=[full_matrix_projection(backward_first)], size=decoder_size, act=TanhActivation()) + +解码器使用 ``recurrent_group`` 来定义循环神经网络。单步函数和输出函数在 +``gru_decoder_with_attention`` 中定义: + +.. 
code:: sourcecode + + group_inputs=[StaticInput(input=encoded_vector,is_seq=True), + StaticInput(input=encoded_proj,is_seq=True)] + trg_embedding = embedding_layer( + input=data_layer(name='target_language_word', + size=target_dict_dim), + size=word_vector_dim, + param_attr=ParamAttr(name='_target_language_embedding')) + group_inputs.append(trg_embedding) + + # 对于配备有注意力机制的解码器,在训练中, + # 目标向量(groudtruth)是数据输入, + # 而源序列的编码向量可以被无边界的memory访问 + # StaticInput 意味着不同时间步的输入都是相同的值, + # 否则它以一个序列输入,不同时间步的输入是不同的。 + # 所有输入序列应该有相同的长度。 + decoder = recurrent_group(name=decoder_group_name, + step=gru_decoder_with_attention, + input=group_inputs) + +单步函数的实现如下所示。首先,它定义解码网络的\ **Memory**\ 。然后定义 +attention,门控循环单元单步函数和输出函数: + +.. code:: sourcecode + + def gru_decoder_with_attention(enc_vec, enc_proj, current_word): + # 定义解码器的Memory + # Memory的输出定义在 gru_step 内 + # 注意 gru_step 应该与它的Memory名字相同 + decoder_mem = memory(name='gru_decoder', + size=decoder_size, + boot_layer=decoder_boot) + # 计算 attention 加权编码向量 + context = simple_attention(encoded_sequence=enc_vec, + encoded_proj=enc_proj, + decoder_state=decoder_mem) + # 混合当前词向量和attention加权编码向量 + decoder_inputs = mixed_layer(inputs = [full_matrix_projection(context), + full_matrix_projection(current_word)], + size = decoder_size * 3) + # 定义门控循环单元循环神经网络单步函数 + gru_step = gru_step_layer(name='gru_decoder', + input=decoder_inputs, + output_mem=decoder_mem, + size=decoder_size) + # 定义输出函数 + out = mixed_layer(input=[full_matrix_projection(input=gru_step)], + size=target_dict_dim, + bias_attr=True, + act=SoftmaxActivation()) + return out + +生成序列 +-------- + +训练模型后,我们可以使用它来生成序列。通常的做法是使用\ **beam search** +生成序列。以下代码片段定义 beam search 算法。注意,\ ``beam_search`` +函数假设 ``step`` 的输出函数返回的是下一个时刻输出词的 softmax +归一化概率向量。我们对模型进行了以下更改。 + +- 使用 ``GeneratedInput`` 来表示 trg\_embedding。 ``GeneratedInput`` + 将上一时间步所生成的词的向量来作为当前时间步的输入。 +- 使用 ``beam_search`` 函数。这个函数需要设置: + + - ``bos_id``: 开始标记。每个句子都以开始标记开头。 + - ``eos_id``: 结束标记。每个句子都以结束标记结尾。 + - ``beam_size``: beam search 算法中的beam大小。 + - ``max_length``: 生成序列的最大长度。 + +- 使用 ``seqtext_printer_evaluator`` + 根据索引矩阵和字典打印文本。这个函数需要设置: + + - ``id_input``: 数据的整数ID,用于标识生成的文件中的相应输出。 + - ``dict_file``: 用于将词ID转换为词的字典文件。 + - ``result_file``: 生成结果文件的路径。 + +代码如下: + +.. code:: sourcecode + + group_inputs=[StaticInput(input=encoded_vector,is_seq=True), + StaticInput(input=encoded_proj,is_seq=True)] + # 在生成时,解码器基于编码源序列和最后生成的目标词预测下一目标词。 + # 编码源序列(编码器输出)必须由只读Memory的 StaticInput 指定。 + # 这里, GeneratedInputs 自动获取上一个生成的词,并在最开始初始化为起始词,如 。 + trg_embedding = GeneratedInput( + size=target_dict_dim, + embedding_name='_target_language_embedding', + embedding_size=word_vector_dim) + group_inputs.append(trg_embedding) + beam_gen = beam_search(name=decoder_group_name, + step=gru_decoder_with_attention, + input=group_inputs, + bos_id=0, # Beginnning token. + eos_id=1, # End of sentence token. 
+ beam_size=beam_size, + max_length=max_length) + + seqtext_printer_evaluator(input=beam_gen, + id_input=data_layer(name="sent_id", size=1), + dict_file=trg_dict_path, + result_file=gen_trans_file) + outputs(beam_gen) + +注意,这种生成技术只用于类似解码器的生成过程。如果你正在处理序列标记任务,请参阅 +`Semantic Role Labeling +Demo <../../demo/semantic_role_labeling/index.html>`__ +了解更多详细信息。 + +完整的配置文件在\ ``demo/seqToseq/seqToseq_net.py``\ 。 From cad325f09ae8c1e2272b79a0c0b30298e891350e Mon Sep 17 00:00:00 2001 From: gaoyuan Date: Fri, 16 Dec 2016 21:03:05 +0800 Subject: [PATCH 29/55] Add header file --- paddle/gserver/tests/test_PriorBox.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/paddle/gserver/tests/test_PriorBox.cpp b/paddle/gserver/tests/test_PriorBox.cpp index 8aabb1ef97a21..d37c0bb70209a 100644 --- a/paddle/gserver/tests/test_PriorBox.cpp +++ b/paddle/gserver/tests/test_PriorBox.cpp @@ -15,8 +15,10 @@ limitations under the License. */ #include #include #include +#include "./paddle/utils/CommandLineParser.h" #include "ModelConfig.pb.h" #include "paddle/gserver/layers/DataLayer.h" +#include "paddle/gserver/layers/ExpandConvTransLayer.h" #include "paddle/math/MathUtils.h" #include "paddle/trainer/Trainer.h" #include "paddle/utils/GlobalConstants.h" @@ -29,7 +31,9 @@ using namespace std; // NOLINT P_DECLARE_bool(use_gpu); P_DECLARE_int32(gpu_id); +P_DECLARE_double(checkgrad_eps); P_DECLARE_bool(thread_local_rand_use_global_seed); +P_DECLARE_bool(prev_batch_state); // Do one forward pass of priorBox layer and check to see if its output // matches the given result From 38723e778dbdec32d98b6a191da0e2ea94f0f3c5 Mon Sep 17 00:00:00 2001 From: gaoyuan Date: Mon, 19 Dec 2016 10:56:22 +0800 Subject: [PATCH 30/55] remove random flag --- paddle/gserver/tests/test_PriorBox.cpp | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/paddle/gserver/tests/test_PriorBox.cpp b/paddle/gserver/tests/test_PriorBox.cpp index d37c0bb70209a..1a7217ab943f7 100644 --- a/paddle/gserver/tests/test_PriorBox.cpp +++ b/paddle/gserver/tests/test_PriorBox.cpp @@ -15,13 +15,6 @@ limitations under the License. */ #include #include #include -#include "./paddle/utils/CommandLineParser.h" -#include "ModelConfig.pb.h" -#include "paddle/gserver/layers/DataLayer.h" -#include "paddle/gserver/layers/ExpandConvTransLayer.h" -#include "paddle/math/MathUtils.h" -#include "paddle/trainer/Trainer.h" -#include "paddle/utils/GlobalConstants.h" #include "LayerGradUtil.h" #include "TestUtil.h" @@ -29,12 +22,6 @@ limitations under the License. 
*/ using namespace paddle; // NOLINT using namespace std; // NOLINT -P_DECLARE_bool(use_gpu); -P_DECLARE_int32(gpu_id); -P_DECLARE_double(checkgrad_eps); -P_DECLARE_bool(thread_local_rand_use_global_seed); -P_DECLARE_bool(prev_batch_state); - // Do one forward pass of priorBox layer and check to see if its output // matches the given result void doOnePriorBoxTest(size_t featureMapWidth, @@ -164,7 +151,5 @@ TEST(Layer, priorBoxLayerFwd) { int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); initMain(argc, argv); - FLAGS_thread_local_rand_use_global_seed = true; - srand(1); return RUN_ALL_TESTS(); } From 7dfe3bdf7a2e7d39ee76c356c1894e72e84bc464 Mon Sep 17 00:00:00 2001 From: gaoyuan Date: Mon, 19 Dec 2016 11:46:57 +0800 Subject: [PATCH 31/55] remove gpu memory alloc --- paddle/gserver/tests/test_PriorBox.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/gserver/tests/test_PriorBox.cpp b/paddle/gserver/tests/test_PriorBox.cpp index 1a7217ab943f7..1dab21218e49d 100644 --- a/paddle/gserver/tests/test_PriorBox.cpp +++ b/paddle/gserver/tests/test_PriorBox.cpp @@ -53,7 +53,7 @@ void doOnePriorBoxTest(size_t featureMapWidth, LayerMap layerMap; vector datas; initDataLayer( - configt, &dataLayers, &datas, &layerMap, "priorbox", 1, false, true); + configt, &dataLayers, &datas, &layerMap, "priorbox", 1, false, false); dataLayers[0]->getOutput().setFrameHeight(featureMapHeight); dataLayers[0]->getOutput().setFrameWidth(featureMapWidth); dataLayers[1]->getOutput().setFrameHeight(imageHeight); From 148bd4d0b3240d31c1c96ddac89ffd4935f71b03 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Mon, 19 Dec 2016 15:04:48 +0800 Subject: [PATCH 32/55] add Layer::createFunction --- paddle/gserver/layers/Layer.h | 24 +++++++++++-- paddle/gserver/layers/NormProjectionLayer.cpp | 34 +++++++------------ 2 files changed, 35 insertions(+), 23 deletions(-) diff --git a/paddle/gserver/layers/Layer.h b/paddle/gserver/layers/Layer.h index 16f66a2205f49..6dfd48fb96618 100644 --- a/paddle/gserver/layers/Layer.h +++ b/paddle/gserver/layers/Layer.h @@ -102,9 +102,9 @@ class Layer { std::vector markInBackward_; /// Layer forward function - FunctionBase* forward_; + std::vector> forward_; /// Layer backward function - FunctionBase* backward_; + std::vector> backward_; public: /** @@ -132,6 +132,26 @@ class Layer { virtual void markAllInputGrad(); protected: + /** + * Create layer function. Function is called in forward or backward. + * \param function, Layer::forward_ or Layer::backward_ + * \param name, function name + * \param config, initialization configuration for the function + */ + void createFunction(std::vector>& function, + const std::string& name, + const FuncConfig& config) { + if (useGpu_) { + function.emplace_back( + FunctionBase::funcRegistrar_.createByType(name + "-GPU")); + } else { + function.emplace_back( + FunctionBase::funcRegistrar_.createByType(name + "-CPU")); + } + auto& func = function.back(); + func->init(config); + } + /** * Notify specified layer the output grad ready. * Called in the backward function. 
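For context, a minimal sketch of how a derived layer is expected to use the new createFunction helper: the name passed in is resolved against FunctionBase::funcRegistrar_ as "<name>-GPU" or "<name>-CPU" depending on useGpu_, the FuncConfig is applied once through the function's init(), and the created function is later invoked through forward_[0]->calc(). The layer name MyNormLayer and its members here are illustrative only; the actual CrossMapNormal wiring is the NormProjectionLayer change below.

// Sketch only -- MyNormLayer, size_, scale_, pow_, dims_ and denoms_ are
// assumed members, mirroring CMRProjectionNormLayer in the next diff.
bool MyNormLayer::init(const LayerMap& layerMap,
                       const ParameterMap& parameterMap) {
  Layer::init(layerMap, parameterMap);
  // Looks up CrossMapNormal-GPU or CrossMapNormal-CPU in the registrar,
  // appends it to forward_, and calls func->init(config) once.
  createFunction(
      forward_,
      "CrossMapNormal",
      FuncConfig().set("size", size_).set("scale", scale_).set("pow", pow_));
  return true;
}

void MyNormLayer::forward(PassType passType) {
  // One input tensor, two output tensors (value and denoms), no extra args.
  forward_[0]->calc(
      {Tensor(getInputValue(0)->getData(), dims_)},
      {Tensor(getOutputValue()->getData(), dims_),
       Tensor(denoms_->getData(), dims_)},
      {});
}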
diff --git a/paddle/gserver/layers/NormProjectionLayer.cpp b/paddle/gserver/layers/NormProjectionLayer.cpp index 0f6f9b91d0578..262d757c67e10 100644 --- a/paddle/gserver/layers/NormProjectionLayer.cpp +++ b/paddle/gserver/layers/NormProjectionLayer.cpp @@ -45,21 +45,13 @@ bool CMRProjectionNormLayer::init(const LayerMap& layerMap, /* the size of inputs for norm-layer is 1 */ CHECK_EQ(config_.inputs_size(), 1); - if (useGpu_) { - forward_ = FunctionBase::funcRegistrar_.createByType( - FUNC_NAME(CrossMapNormal, GPU)); - backward_ = FunctionBase::funcRegistrar_.createByType( - FUNC_NAME(CrossMapNormalGrad, GPU)); - } else { - forward_ = FunctionBase::funcRegistrar_.createByType( - FUNC_NAME(CrossMapNormal, CPU)); - backward_ = FunctionBase::funcRegistrar_.createByType( - FUNC_NAME(CrossMapNormalGrad, CPU)); - } - forward_->init( + createFunction( + forward_, + "CrossMapNormal", FuncConfig().set("size", size_).set("scale", scale_).set("pow", pow_)); - - backward_->init( + createFunction( + backward_, + "CrossMapNormalGrad", FuncConfig().set("size", size_).set("scale", scale_).set("pow", pow_)); return true; @@ -80,7 +72,7 @@ void CMRProjectionNormLayer::forward(PassType passType) { Matrix::resizeOrCreate(denoms_, batchSize, size, /* trans */ false, useGpu_); dims_ = {batchSize, channels_, imgSizeH_, imgSizeW_}; - forward_->calc( + forward_[0]->calc( {Tensor(input->getData(), dims_)}, {Tensor(outV->getData(), dims_), Tensor(denoms_->getData(), dims_)}, {}); @@ -98,11 +90,11 @@ void CMRProjectionNormLayer::backward(const UpdateCallback& callback) { MatrixPtr localOutV = getOutputValue(); MatrixPtr preOutV = inputLayers_[0]->getOutputValue(); - backward_->calc({Tensor(preOutV->getData(), dims_), - Tensor(localOutV->getData(), dims_), - Tensor(localGrad->getData(), dims_), - Tensor(denoms_->getData(), dims_)}, - {Tensor(preOutGrad->getData(), dims_)}, - {}); + backward_[0]->calc({Tensor(preOutV->getData(), dims_), + Tensor(localOutV->getData(), dims_), + Tensor(localGrad->getData(), dims_), + Tensor(denoms_->getData(), dims_)}, + {Tensor(preOutGrad->getData(), dims_)}, + {}); } } // namespace paddle From 1a0669753e5fe2af475905d084149b5d928c9b6a Mon Sep 17 00:00:00 2001 From: Luo Tao Date: Mon, 19 Dec 2016 13:11:48 +0800 Subject: [PATCH 33/55] travis for check broken links --- .travis.yml | 2 +- paddle/scripts/travis/docs.sh | 17 +++++++++++++++-- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/.travis.yml b/.travis.yml index 5b14f8e61e614..047ca6ffe79bd 100644 --- a/.travis.yml +++ b/.travis.yml @@ -56,7 +56,7 @@ before_install: - if [[ "$TRAVIS_OS_NAME" == "linux" ]]; then sudo paddle/scripts/travis/before_install.linux.sh; fi - if [[ "$TRAVIS_OS_NAME" == "osx" ]]; then paddle/scripts/travis/before_install.osx.sh; fi - if [[ "$JOB" == "PRE_COMMIT" ]]; then sudo ln -s /usr/bin/clang-format-3.8 /usr/bin/clang-format; fi - - pip install wheel protobuf sphinx recommonmark virtualenv numpy sphinx_rtd_theme pre-commit + - pip install wheel protobuf sphinx recommonmark virtualenv numpy sphinx_rtd_theme pre-commit requests==2.9.2 LinkChecker script: - paddle/scripts/travis/main.sh notifications: diff --git a/paddle/scripts/travis/docs.sh b/paddle/scripts/travis/docs.sh index 0bbb76a8a3caa..4ab1746b5af81 100755 --- a/paddle/scripts/travis/docs.sh +++ b/paddle/scripts/travis/docs.sh @@ -7,6 +7,19 @@ source ./common.sh cmake .. 
-DCMAKE_BUILD_TYPE=Debug -DWITH_GPU=OFF -DWITH_DOC=ON make paddle_docs paddle_docs_cn +# check websites for broken links +set +e +linkchecker doc/cn/html/index.html > doc_cn.out +linkchecker doc/en/html/index.html > doc_en.out +for i in doc_cn.out doc_en.out; do + echo $i + grep " 0 errors found" $i + if [ $? -ne 0 ]; then + cat $i + exit 1 + fi +done + # Parse Github URL REPO=`git config remote.origin.url` SSH_REPO=${REPO/https:\/\/github.com\//git@github.com:} @@ -35,8 +48,8 @@ git checkout $TARGET_BRANCH || git checkout --orphan $TARGET_BRANCH # remove old docs. mv new docs. rm -rf doc doc_cn -mv ../doc_cn/html doc_cn -mv ../doc/html doc +mv ../doc/cn/html doc_cn +mv ../doc/en/html doc # Check is there anything changed. set +e From 706c572424b6f273fd948d60675c25c378e7021a Mon Sep 17 00:00:00 2001 From: xutianbing Date: Fri, 16 Dec 2016 15:14:02 -0800 Subject: [PATCH 34/55] Matrix API refactor, when passing parameters, convert shared_ptr (MatrixPtr) to reference or raw matrix (Matrix & or Matrix *) contextProjectionForward contextProjectionBackward contextProjectionBackwardData contextProjectionBackwardWeight classificationError The mul functions would be updated later. --- paddle/gserver/evaluators/Evaluator.cpp | 2 +- paddle/gserver/layers/ContextProjection.cpp | 12 +- paddle/math/Matrix.cpp | 171 ++++++++------------ paddle/math/Matrix.h | 34 ++-- paddle/math/tests/test_matrixCompare.cpp | 20 +-- 5 files changed, 103 insertions(+), 136 deletions(-) diff --git a/paddle/gserver/evaluators/Evaluator.cpp b/paddle/gserver/evaluators/Evaluator.cpp index 2f9928191170a..ae7508e2bb117 100644 --- a/paddle/gserver/evaluators/Evaluator.cpp +++ b/paddle/gserver/evaluators/Evaluator.cpp @@ -78,7 +78,7 @@ class ClassificationErrorEvaluator : public Evaluator { useGpu(arguments[0].deviceId)); errorMat->zeroMem(); if (label != nullptr) { - errorMat->classificationError(output, label); + errorMat->classificationError(*output, *label); } else if (dynamic_cast(multiBinaryLabel.get()) || dynamic_cast(multiBinaryLabel.get())) { errorMat->classificationErrorMulti( diff --git a/paddle/gserver/layers/ContextProjection.cpp b/paddle/gserver/layers/ContextProjection.cpp index 7ac56e3a2ab2a..51c0ae5cc9523 100644 --- a/paddle/gserver/layers/ContextProjection.cpp +++ b/paddle/gserver/layers/ContextProjection.cpp @@ -90,8 +90,8 @@ void ContextProjection::forward() { REGISTER_TIMER_INFO("ContextProjectionForward", getName().c_str()); bool isPadding = config_.trainable_padding(); out_->value->contextProjectionForward( - in_->value, - state_ ? state_ : isPadding ? weight_->getW() : nullptr, + *(in_->value), + state_ ? state_.get() : isPadding ? weight_->getW().get() : nullptr, *startPositions, config_.context_length(), config_.context_start(), @@ -128,8 +128,8 @@ void ContextProjection::backward(const UpdateCallback& callback) { bool isPadding = config_.trainable_padding(); if (!out_->grad->useGpu()) { out_->grad->contextProjectionBackward( - in_->grad, - isPadding ? weight_->getWGrad() : nullptr, + in_->grad.get(), + isPadding ? 
weight_->getWGrad().get() : nullptr, *startPositions, config_.context_length(), config_.context_start(), @@ -137,7 +137,7 @@ void ContextProjection::backward(const UpdateCallback& callback) { isPadding); } else { if (in_->grad) { - out_->grad->contextProjectionBackwardData(in_->grad, + out_->grad->contextProjectionBackwardData(*(in_->grad), *startPositions, config_.context_length(), config_.context_start()); @@ -145,7 +145,7 @@ void ContextProjection::backward(const UpdateCallback& callback) { if (isPadding && weight_->getWGrad()) { out_->grad->contextProjectionBackwardWeight( - weight_->getWGrad(), + *(weight_->getWGrad()), *startPositions, config_.context_length(), config_.context_start(), diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp index c69e074a76399..3b3c1d7d48a28 100644 --- a/paddle/math/Matrix.cpp +++ b/paddle/math/Matrix.cpp @@ -766,20 +766,19 @@ void GpuMatrix::maxoutBackward(Matrix& a, } /*calulate the error of classification */ -void GpuMatrix::classificationError(MatrixPtr output, IVectorPtr label) { - GpuMatrixPtr output_ptr = std::dynamic_pointer_cast(output); - GpuIVectorPtr label_ptr = std::dynamic_pointer_cast(label); - +void GpuMatrix::classificationError(Matrix& output, IVector& label) { + auto output_ptr = dynamic_cast(&output); + auto label_ptr = dynamic_cast(&label); CHECK(output_ptr && label_ptr) << "Invalid argument pointer"; CHECK(height_ == output_ptr->height_ && width_ == 1) << "Matrix dimensions are not equal"; - real* output_d = output_ptr->data_; - real* recResult_d = data_; - int* label_d = label_ptr->getData(); - hl_matrix_classification_error( - output_d, label_d, recResult_d, height_, output_ptr->width_); + hl_matrix_classification_error((real*)output_ptr->data_, + (int*)label_ptr->getData(), + data_, + height_, + output_ptr->width_); } /* copy -log(output[i * width + label]) to this->data[i] */ @@ -1370,86 +1369,62 @@ void GpuMatrix::maxSequenceBackward(Matrix& outputGrad, hl_max_sequence_backward(outGrad, maxIndex, inputGrad, numSequences, dim); } -void GpuMatrix::contextProjectionForward(MatrixPtr input, - MatrixPtr weight, +void GpuMatrix::contextProjectionForward(Matrix& input, + Matrix* weight, const IVector& sequence, int contextLength, int contextStart, size_t beginPad, bool isPadding) { - CHECK(dynamic_cast(input.get())); + CHECK(dynamic_cast(&input)); CHECK(dynamic_cast(&sequence)); - if (weight) CHECK(dynamic_cast(weight.get())); - - size_t numSequences = sequence.getSize() - 1; - int64_t inputDim = input->getWidth(); - int64_t dim = getWidth(); - CHECK_EQ(dim, inputDim * contextLength); - - real* outData = getData(); - real* inputData = input->getData(); - const int* starts = sequence.getData(); + if (weight) CHECK(dynamic_cast(weight)); + CHECK_EQ(getWidth(), input.getWidth() * contextLength); - hl_context_projection_forward(inputData, - starts, + hl_context_projection_forward(input.getData(), + sequence.getData(), isPadding ? 
weight->getData() : NULL, - outData, - numSequences, - inputDim, + getData(), + sequence.getSize() - 1, + input.getWidth(), contextLength, contextStart, beginPad, isPadding); } -void GpuMatrix::contextProjectionBackwardData(MatrixPtr inputGrad, +void GpuMatrix::contextProjectionBackwardData(Matrix& inputGrad, const IVector& sequence, int contextLength, int contextStart) { - CHECK(dynamic_cast(inputGrad.get())); + CHECK(dynamic_cast(&inputGrad)); CHECK(dynamic_cast(&sequence)); + CHECK_EQ(getWidth(), inputGrad.getWidth() * contextLength); - size_t numSequences = sequence.getSize() - 1; - int64_t inputDim = inputGrad->getWidth(); - int64_t dim = getWidth(); - CHECK_EQ(dim, inputDim * contextLength); - - real* outGrad = getData(); - real* inGrad = inputGrad->getData(); - const int* starts = sequence.getData(); - - hl_context_projection_backward_data(outGrad, - starts, - inGrad, - numSequences, - inputDim, + hl_context_projection_backward_data(getData(), + sequence.getData(), + inputGrad.getData(), + sequence.getSize() - 1, + inputGrad.getWidth(), contextLength, contextStart); } -void GpuMatrix::contextProjectionBackwardWeight(MatrixPtr weightGrad, +void GpuMatrix::contextProjectionBackwardWeight(Matrix& weightGrad, const IVector& sequence, int contextLength, int contextStart, int totalPad, size_t beginPad) { - CHECK(dynamic_cast(weightGrad.get())); + CHECK(dynamic_cast(&weightGrad)); CHECK(dynamic_cast(&sequence)); + CHECK_EQ(getWidth(), weightGrad.getWidth() * contextLength); - size_t numSequences = sequence.getSize() - 1; - int64_t weightDim = weightGrad->getWidth(); - int64_t dim = getWidth(); - CHECK_EQ(dim, weightDim * contextLength); - - real* outGrad = getData(); - real* wtGrad = weightGrad->getData(); - const int* starts = sequence.getData(); - - hl_context_projection_backward_weight(outGrad, - starts, - wtGrad, - numSequences, - weightDim, + hl_context_projection_backward_weight(getData(), + sequence.getData(), + weightGrad.getData(), + sequence.getSize() - 1, + weightGrad.getWidth(), totalPad, contextLength, contextStart, @@ -2371,23 +2346,21 @@ void CpuMatrix::maxSequenceBackward(Matrix& outputGrad, } } -void CpuMatrix::contextProjectionForward(MatrixPtr input, - MatrixPtr weight, +void CpuMatrix::contextProjectionForward(Matrix& input, + Matrix* weight, const IVector& sequence, int contextLength, int contextStart, size_t beginPad, bool isPadding) { - CHECK(dynamic_cast(input.get())); - CHECK(dynamic_cast(&sequence)); - if (weight) CHECK(dynamic_cast(weight.get())); - - size_t numSequences = sequence.getSize() - 1; - int64_t inputDim = input->getWidth(); - int64_t dim = getWidth(); - CHECK_EQ(dim, inputDim * contextLength); - const int* starts = sequence.getData(); - + auto input_ptr = dynamic_cast(&input); + auto seq_ptr = dynamic_cast(&sequence); + CHECK(input_ptr && seq_ptr); + if (weight) CHECK(dynamic_cast(weight)); + CHECK_EQ(getWidth(), input_ptr->getWidth() * contextLength); + + const int* starts = seq_ptr->getData(); + size_t numSequences = seq_ptr->getSize() - 1; for (size_t i = 0; i < numSequences; ++i) { for (int j = 0; j < contextLength; ++j) { int begin = starts[i] + contextStart + j; @@ -2400,7 +2373,7 @@ void CpuMatrix::contextProjectionForward(MatrixPtr input, MatrixPtr mat = this->subMatrix(starts[i], padSize); if (isPadding) { MatrixPtr sub = weight->subMatrix(j, padSize); - mat->addAtOffset(*sub, j * inputDim); + mat->addAtOffset(*sub, j * input_ptr->getWidth()); } dstBegin = starts[i] + padSize; begin = starts[i]; @@ -2412,41 +2385,36 @@ void 
CpuMatrix::contextProjectionForward(MatrixPtr input, if (isPadding) { MatrixPtr sub = weight->subMatrix(beginPad + contextStart + j - padSize, padSize); - mat->addAtOffset(*sub, j * inputDim); + mat->addAtOffset(*sub, j * input_ptr->getWidth()); } dstEnd = starts[i + 1] - padSize; end = starts[i + 1]; } if (end <= begin) continue; - MatrixPtr src = input->subMatrix(begin, end - begin); + MatrixPtr src = input_ptr->subMatrix(begin, end - begin); MatrixPtr dst = this->subMatrix(dstBegin, dstEnd - dstBegin); - dst->addAtOffset(*src, j * inputDim); + dst->addAtOffset(*src, j * input_ptr->getWidth()); } } } -void CpuMatrix::contextProjectionBackward(MatrixPtr inputGrad, - MatrixPtr weightGrad, +void CpuMatrix::contextProjectionBackward(Matrix* inputGrad, + Matrix* weightGrad, const IVector& sequence, int contextLength, int contextStart, size_t beginPad, bool isPadding) { - if (inputGrad) CHECK(dynamic_cast(inputGrad.get())); - if (weightGrad) CHECK(dynamic_cast(weightGrad.get())); + if (inputGrad) CHECK(dynamic_cast(inputGrad)); + if (weightGrad) CHECK(dynamic_cast(weightGrad)); CHECK(dynamic_cast(&sequence)); - int64_t inputDim = 0; - int64_t dim = getWidth(); - size_t numSequences = sequence.getSize() - 1; - const int* starts = sequence.getData(); - if (inputGrad) { - inputDim = inputGrad->getWidth(); - } else { - inputDim = weightGrad->getWidth(); - } - CHECK_EQ(dim, inputDim * contextLength); + int64_t inputDim = inputGrad ? inputGrad->getWidth() + : weightGrad ? weightGrad->getWidth() : 0; + CHECK_EQ(getWidth(), inputDim * contextLength); + const int* starts = sequence.getData(); + size_t numSequences = sequence.getSize() - 1; for (size_t i = 0; i < numSequences; ++i) { for (int j = 0; j < contextLength; ++j) { int begin = starts[i] + contextStart + j; @@ -3544,21 +3512,20 @@ void CpuMatrix::rowNormalizeL1(Matrix& out) { } /* calulate classification error */ -void CpuMatrix::classificationError(MatrixPtr output, IVectorPtr label) { - CHECK(dynamic_cast(output.get())); - CHECK(dynamic_cast(label.get())); +void CpuMatrix::classificationError(Matrix& output, IVector& label) { + CHECK(dynamic_cast(&output)); + CHECK(dynamic_cast(&label)); - size_t numSamples = getHeight(); - size_t dim = output->getWidth(); - CHECK_EQ(label->getSize(), numSamples); - CHECK_EQ(output->getHeight(), numSamples); CHECK_EQ(getWidth(), (size_t)1); + size_t numSamples = getHeight(); + CHECK_EQ(label.getSize(), numSamples); + CHECK_EQ(output.getHeight(), numSamples); - real* out = output->getData(); - real* result = getData(); - int* lbl = label->getData(); - real maxData; - int maxIndex; + size_t dim = output.getWidth(); + real* out = output.getData(); + int* lbl = label.getData(); + real maxData = 0.0; + int maxIndex = -1; for (size_t i = 0; i < numSamples; ++i) { CHECK_GE(lbl[i], 0); CHECK_LT((size_t)lbl[i], dim); @@ -3570,7 +3537,7 @@ void CpuMatrix::classificationError(MatrixPtr output, IVectorPtr label) { maxData = out[i * dim + j]; } } - result[i] = (maxIndex != lbl[i]); + getData()[i] = (maxIndex != lbl[i]); } } diff --git a/paddle/math/Matrix.h b/paddle/math/Matrix.h index 1cfb90a9dbf19..b8c7adf9486be 100644 --- a/paddle/math/Matrix.h +++ b/paddle/math/Matrix.h @@ -835,7 +835,7 @@ class Matrix : public BaseMatrix { * * output[i] = 0 if row i is correct. 
*/ - virtual void classificationError(MatrixPtr output, IVectorPtr label) { + virtual void classificationError(Matrix& output, IVector& label) { LOG(FATAL) << "Not implemented"; } @@ -997,8 +997,8 @@ class Matrix : public BaseMatrix { LOG(FATAL) << "Not implemeted"; } - virtual void contextProjectionForward(MatrixPtr input, - MatrixPtr weight, + virtual void contextProjectionForward(Matrix& input, + Matrix* weight, const IVector& sequence, int contextLength, int contextStart, @@ -1007,8 +1007,8 @@ class Matrix : public BaseMatrix { LOG(FATAL) << "Not implemeted"; } - virtual void contextProjectionBackward(MatrixPtr inputGrad, - MatrixPtr weightGrad, + virtual void contextProjectionBackward(Matrix* inputGrad, + Matrix* weightGrad, const IVector& sequence, int contextLength, int contextStart, @@ -1017,14 +1017,14 @@ class Matrix : public BaseMatrix { LOG(FATAL) << "Not implemeted"; } - virtual void contextProjectionBackwardData(MatrixPtr inputGrad, + virtual void contextProjectionBackwardData(Matrix& inputGrad, const IVector& sequence, int contextLength, int contextStart) { LOG(FATAL) << "Not implemeted"; } - virtual void contextProjectionBackwardWeight(MatrixPtr weightGrad, + virtual void contextProjectionBackwardWeight(Matrix& weightGrad, const IVector& sequence, int contextLength, int contextStart, @@ -1373,7 +1373,7 @@ class GpuMatrix : public Matrix { void check(std::ostream& os, Matrix& refMat, bool printDiff = true); void randomizeUniform(); - void classificationError(MatrixPtr output, IVectorPtr label); + void classificationError(Matrix& output, IVector& label); void convExpand(Matrix& feature, int feaImgHeight, @@ -1487,20 +1487,20 @@ class GpuMatrix : public Matrix { const IVector& sequence, IVector& index); - void contextProjectionForward(MatrixPtr input, - MatrixPtr weight, + void contextProjectionForward(Matrix& input, + Matrix* weight, const IVector& sequence, int contextLength, int contextStart, size_t beginPad, bool isPadding); - void contextProjectionBackwardData(MatrixPtr inputGrad, + void contextProjectionBackwardData(Matrix& inputGrad, const IVector& sequence, int contextLength, int contextStart); - void contextProjectionBackwardWeight(MatrixPtr weightGrad, + void contextProjectionBackwardWeight(Matrix& weightGrad, const IVector& sequence, int contextLength, int contextStart, @@ -1713,16 +1713,16 @@ class CpuMatrix : public Matrix { const IVector& sequence, IVector& index); - void contextProjectionForward(MatrixPtr input, - MatrixPtr weight, + void contextProjectionForward(Matrix& input, + Matrix* weight, const IVector& sequence, int contextLength, int contextStart, size_t beginPad, bool isPadding); - void contextProjectionBackward(MatrixPtr inputGrad, - MatrixPtr weightGrad, + void contextProjectionBackward(Matrix* inputGrad, + Matrix* weightGrad, const IVector& sequence, int contextLength, int contextStart, @@ -1881,7 +1881,7 @@ class CpuMatrix : public Matrix { void randomizeUniform(); - void classificationError(MatrixPtr output, IVectorPtr label); + void classificationError(Matrix& output, IVector& label); void addByBitCode(size_t numClasses, const IVector& codes, const Matrix& vec); diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index 62de5b25e4cc8..10289940a4c5f 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -65,16 +65,16 @@ void testMatrixProjectionForward(int contextStart, // calculate int beginPad = std::max(0, -contextStart); - 
cpuOutput->contextProjectionForward(cpuInput, - cpuWeight, + cpuOutput->contextProjectionForward(*cpuInput, + cpuWeight.get(), *cpuSequence, contextLength, contextStart, beginPad, padding); - gpuOutput->contextProjectionForward(gpuInput, - gpuWeight, + gpuOutput->contextProjectionForward(*gpuInput, + gpuWeight.get(), *gpuSequence, contextLength, contextStart, @@ -120,17 +120,17 @@ void testMatrixProjectionBackward(int contextStart, // calculate int beginPad = std::max(0, -contextStart); - cpuOutputGrad->contextProjectionBackward(cpuInputGrad, - cpuWeightGrad, + cpuOutputGrad->contextProjectionBackward(cpuInputGrad.get(), + cpuWeightGrad.get(), *cpuSequence, contextLength, contextStart, beginPad, padding); gpuOutputGrad->contextProjectionBackwardData( - gpuInputGrad, *gpuSequence, contextLength, contextStart); + *gpuInputGrad, *gpuSequence, contextLength, contextStart); if (padding) { - gpuOutputGrad->contextProjectionBackwardWeight(gpuWeightGrad, + gpuOutputGrad->contextProjectionBackwardWeight(*gpuWeightGrad, *gpuSequence, contextLength, contextStart, @@ -939,8 +939,8 @@ void testClassificationError(int numSamples, int dim) { gpuOutput->copyFrom(*cpuOutput); gpuLabel->copyFrom(*cpuLabel); - cpuError->classificationError(cpuOutput, cpuLabel); - gpuError->classificationError(gpuOutput, gpuLabel); + cpuError->classificationError(*cpuOutput, *cpuLabel); + gpuError->classificationError(*gpuOutput, *gpuLabel); TensorCheckEqual(*cpuError, *gpuError); } From 4fbf94993b0699bb06c5347612a7b97d692a2625 Mon Sep 17 00:00:00 2001 From: xutianbing Date: Mon, 19 Dec 2016 17:21:06 -0800 Subject: [PATCH 35/55] Refactor MUL functions, pass object reference instead of shared_ptr. --- .../gserver/layers/ConvexCombinationLayer.cpp | 6 +- paddle/gserver/layers/ExpandConvBaseLayer.cpp | 6 +- .../gserver/layers/FullMatrixProjection.cpp | 7 ++- paddle/gserver/layers/FullyConnectedLayer.cpp | 8 +-- paddle/gserver/layers/LinearChainCRF.cpp | 2 +- paddle/gserver/layers/LstmLayer.cpp | 26 ++++----- paddle/gserver/layers/MDLstmLayer.cpp | 8 +-- paddle/gserver/layers/OuterProdLayer.cpp | 6 +- paddle/gserver/layers/RecurrentLayer.cpp | 32 +++++------ .../layers/SelectiveFullyConnectedLayer.cpp | 10 ++-- paddle/gserver/layers/TensorLayer.cpp | 8 +-- .../layers/TransposedFullMatrixProjection.cpp | 7 ++- paddle/math/CpuSparseMatrix.cpp | 15 ++--- paddle/math/CpuSparseMatrix.h | 2 +- paddle/math/Matrix.cpp | 49 +++++++---------- paddle/math/Matrix.h | 14 ++--- paddle/math/SparseMatrix.cpp | 55 +++++++++---------- paddle/math/SparseMatrix.h | 7 +-- paddle/math/tests/test_SparseMatrix.cpp | 14 ++--- paddle/math/tests/test_matrixCompare.cpp | 12 ++-- .../math/tests/test_sparseMatrixCompare.cpp | 4 +- 21 files changed, 144 insertions(+), 154 deletions(-) diff --git a/paddle/gserver/layers/ConvexCombinationLayer.cpp b/paddle/gserver/layers/ConvexCombinationLayer.cpp index 3f4d77a2fe069..ed57f2af3c645 100644 --- a/paddle/gserver/layers/ConvexCombinationLayer.cpp +++ b/paddle/gserver/layers/ConvexCombinationLayer.cpp @@ -113,7 +113,7 @@ void ConvexCombinationLayer::forward(PassType passType) { tmpRow0->setData(inV0->getData() + i * weightDim); tmpRow1->setData(outV->getData() + i * dataDim); - tmpRow1->mul(tmpRow0, tmpMtx0, 1, 0); + tmpRow1->mul(*tmpRow0, *tmpMtx0, 1, 0); } } @@ -136,7 +136,7 @@ void ConvexCombinationLayer::backward(const UpdateCallback& callback) { tmpRow1->setData(outG->getData() + i * dataDim); tmpMtx0->setData(inV1->getData() + i * weightDim * dataDim); - tmpRow0->mul(tmpRow1, tmpMtx0->getTranspose(), 1, 1); + 
tmpRow0->mul(*tmpRow1, *(tmpMtx0->getTranspose()), 1, 1); } } @@ -146,7 +146,7 @@ void ConvexCombinationLayer::backward(const UpdateCallback& callback) { tmpRow1->setData(outG->getData() + i * dataDim); tmpMtx0->setData(inG1->getData() + i * weightDim * dataDim); - tmpMtx0->mul(tmpRow0->getTranspose(), tmpRow1, 1, 1); + tmpMtx0->mul(*(tmpRow0->getTranspose()), *tmpRow1, 1, 1); } } } diff --git a/paddle/gserver/layers/ExpandConvBaseLayer.cpp b/paddle/gserver/layers/ExpandConvBaseLayer.cpp index 25948747fe93e..9ddccc202705c 100644 --- a/paddle/gserver/layers/ExpandConvBaseLayer.cpp +++ b/paddle/gserver/layers/ExpandConvBaseLayer.cpp @@ -150,7 +150,7 @@ void ExpandConvBaseLayer::expandFwdOnce(MatrixPtr image, Matrix::create(wgtData, subM, subK, false, useGpu_); // mark transpose MatrixPtr B = Matrix::create(expInData, subK, subN, false, useGpu_); MatrixPtr C = Matrix::create(outData, subM, subN, false, useGpu_); - C->mul(A, B, 1, 1); + C->mul(*A, *B, 1, 1); A->clear(); B->clear(); @@ -185,7 +185,7 @@ void ExpandConvBaseLayer::bpropActs(MatrixPtr out, MatrixPtr C = Matrix::create(expandInData, subK, subN, false, useGpu_); MatrixPtr B = Matrix::create(localGradData, subM, subN, false, useGpu_); MatrixPtr A = Matrix::create(wgtData, subM, subK, true, useGpu_); - C->mul(A, B); // mul + C->mul(*A, *B); // mul // clear the temporary matrix A->clear(); @@ -252,7 +252,7 @@ void ExpandConvBaseLayer::bpropWeights(MatrixPtr image, MatrixPtr A = Matrix::create(expandInData, subK, subN, true, useGpu_); MatrixPtr B = Matrix::create(gradData, subM, subN, false, useGpu_); MatrixPtr C = Matrix::create(wGradData, subM, subK, false, useGpu_); - C->mul(B, A, 1, 1); + C->mul(*B, *A, 1, 1); A->clear(); B->clear(); diff --git a/paddle/gserver/layers/FullMatrixProjection.cpp b/paddle/gserver/layers/FullMatrixProjection.cpp index 9e72a33a3c6f4..b8b6f403d6a02 100644 --- a/paddle/gserver/layers/FullMatrixProjection.cpp +++ b/paddle/gserver/layers/FullMatrixProjection.cpp @@ -28,7 +28,7 @@ FullMatrixProjection::FullMatrixProjection(const ProjectionConfig& config, void FullMatrixProjection::forward() { REGISTER_TIMER_INFO("FwMulTimer", getName().c_str()); - out_->value->mul(in_->value, weight_->getW(), 1, 1); + out_->value->mul(*(in_->value), *(weight_->getW()), 1, 1); } void FullMatrixProjection::backward(const UpdateCallback& callback) { @@ -37,7 +37,8 @@ void FullMatrixProjection::backward(const UpdateCallback& callback) { /* Calculate the W-gradient for the current layer */ if (weight_->getWGrad()) { REGISTER_TIMER_INFO("GradMulTimer", getName().c_str()); - weight_->getWGrad()->mul(in_->value->getTranspose(), out_->grad, 1, 1); + weight_->getWGrad()->mul( + *(in_->value->getTranspose()), *(out_->grad), 1, 1); } // If callback does not change value, backward propagation error @@ -47,7 +48,7 @@ void FullMatrixProjection::backward(const UpdateCallback& callback) { /* Calculate the input layers error */ if (in_->grad) { REGISTER_TIMER_INFO("BpMulTimer", getName().c_str()); - in_->grad->mul(out_->grad, weight_->getW()->getTranspose(), 1, 1); + in_->grad->mul(*(out_->grad), *(weight_->getW()->getTranspose()), 1, 1); } hl_set_sync_flag(syncFlag); diff --git a/paddle/gserver/layers/FullyConnectedLayer.cpp b/paddle/gserver/layers/FullyConnectedLayer.cpp index 89afe33c36697..d8a667ff8dc02 100644 --- a/paddle/gserver/layers/FullyConnectedLayer.cpp +++ b/paddle/gserver/layers/FullyConnectedLayer.cpp @@ -84,8 +84,8 @@ void FullyConnectedLayer::forward(PassType passType) { auto input = getInput(i); CHECK(input.value) << "The input 
of 'fc' layer must be matrix"; REGISTER_TIMER_INFO("FwMulTimer", getName().c_str()); - i == 0 ? outV->mul(input.value, weights_[i]->getW(), 1, 0) - : outV->mul(input.value, weights_[i]->getW(), 1, 1); + i == 0 ? outV->mul(*input.value, *weights_[i]->getW(), 1, 0) + : outV->mul(*input.value, *weights_[i]->getW(), 1, 1); } /* add the bias-vector */ @@ -123,7 +123,7 @@ void FullyConnectedLayer::backward(const UpdateCallback& callback) { MatrixPtr oGrad = getOutputGrad(); { REGISTER_TIMER_INFO("GradMulTimer", getName().c_str()); - weights_[i]->getWGrad()->mul(input_T, oGrad, 1, 1); + weights_[i]->getWGrad()->mul(*input_T, *oGrad, 1, 1); } } @@ -136,7 +136,7 @@ void FullyConnectedLayer::backward(const UpdateCallback& callback) { if (NULL != preGrad) { MatrixPtr weights_T = weights_[i]->getW()->getTranspose(); REGISTER_TIMER_INFO("BpMulTimer", getName().c_str()); - preGrad->mul(getOutputGrad(), weights_T, 1, 1); + preGrad->mul(*getOutputGrad(), *weights_T, 1, 1); } hl_set_sync_flag(syncFlag); diff --git a/paddle/gserver/layers/LinearChainCRF.cpp b/paddle/gserver/layers/LinearChainCRF.cpp index af550c7a01548..b7f748f3bb8a4 100644 --- a/paddle/gserver/layers/LinearChainCRF.cpp +++ b/paddle/gserver/layers/LinearChainCRF.cpp @@ -59,7 +59,7 @@ real LinearChainCRF::forward(real* x, int* s, int length) { matX->rowMax(*maxX_); expX_->assign(*matX); // subtract max to avoid overflow or underflow - expX_->mul(maxX_, ones_, (real)-1, (real)1); + expX_->mul(*maxX_, *ones_, (real)-1, (real)1); expX_->exp2(); real* a = a_->getData(); diff --git a/paddle/gserver/layers/LstmLayer.cpp b/paddle/gserver/layers/LstmLayer.cpp index 2543d1b49a801..01cc5fec8b970 100644 --- a/paddle/gserver/layers/LstmLayer.cpp +++ b/paddle/gserver/layers/LstmLayer.cpp @@ -316,7 +316,7 @@ void LstmLayer::forwardSequence(int batchSize, } if (prevOutput_) { frameGate->setData(lstmValue.gateValue); - frameGate->mul(prevOutput_, weight_->getW(), 1, 1); + frameGate->mul(*prevOutput_, *weight_->getW(), 1, 1); } } AsyncGpuBlock asyncGpuBlock; @@ -338,7 +338,7 @@ void LstmLayer::forwardSequence(int batchSize, frameOutput->setData(lstmValue.outputValue); nextFrame(reversed_, getSize()); frameGate->setData(lstmValue.gateValue); - frameGate->mul(frameOutput, weight_->getW(), 1, 1); + frameGate->mul(*frameOutput, *weight_->getW(), 1, 1); } } if (n != numSequences - 1) { @@ -348,7 +348,7 @@ void LstmLayer::forwardSequence(int batchSize, if (!reversed_) { if (!prevState_) lstmValue.prevStateValue = nullptr; if (prevOutput_) { - frameGate->mul(frameOutput, weight_->getW(), 1, 1); + frameGate->mul(*frameOutput, *weight_->getW(), 1, 1); } } else { lstmValue.prevStateValue = nullptr; @@ -470,7 +470,7 @@ void LstmLayer::backwardSequence(int batchSize, frameGate->setData(lstmGrad.gateGrad); nextFrame(reversed_, getSize()); frameOutput->setData(lstmGrad.outputGrad); - frameOutput->mul(frameGate, weightT, 1, 1); + frameOutput->mul(*frameGate, *weightT, 1, 1); } else { nextFrame(reversed_, getSize()); } @@ -479,14 +479,14 @@ void LstmLayer::backwardSequence(int batchSize, if (weight_->getWGrad()) { if (!reversed_) { weight_->getWGrad()->mul( - output_.value->subMatrix(start, length - 1)->getTranspose(), - gate_.grad->subMatrix(start + 1, length - 1), + *output_.value->subMatrix(start, length - 1)->getTranspose(), + *gate_.grad->subMatrix(start + 1, length - 1), 1, 1); } else { weight_->getWGrad()->mul( - output_.value->subMatrix(start + 1, length - 1)->getTranspose(), - gate_.grad->subMatrix(start, length - 1), + *output_.value->subMatrix(start + 1, length - 
1)->getTranspose(), + *gate_.grad->subMatrix(start, length - 1), 1, 1); } @@ -541,7 +541,7 @@ void LstmLayer::forwardBatch(int batchSize, if (n != 0) { MatrixPtr batch1 = batchValue_->getBatchValue(n - 1, batchSize); - gateValue->mul(batch1, weight_->getW(), 1, 1); + gateValue->mul(*batch1, *weight_->getW(), 1, 1); } else if (prevOutput_) { Matrix::resizeOrCreate(prevBatchOutput2_, gateValue->getHeight(), @@ -549,7 +549,7 @@ void LstmLayer::forwardBatch(int batchSize, false, useGpu_); batchValue_->prevOutput2Batch(*prevOutput_, *prevBatchOutput2_); - gateValue->mul(prevBatchOutput2_, weight_->getW(), 1, 1); + gateValue->mul(*prevBatchOutput2_, *weight_->getW(), 1, 1); batchValue_->prevOutput2Batch(*prevState_, *totalState_->subMatrix(0, numSequences)); @@ -672,16 +672,16 @@ void LstmLayer::backwardBatch(int batchSize, if (n != 0) { MatrixPtr tmp = batchGrad_->getBatchValue(n - 1, batchSize); - tmp->mul(gateGrad, weightT, 1, 1); + tmp->mul(*gateGrad, *weightT, 1, 1); } if (n != 0 && weight_->getWGrad()) { /* backward weight */ MatrixPtr outputValue = batchValue_->getBatchValue(n - 1, batchSize); - weight_->getWGrad()->mul(outputValue->getTranspose(), gateGrad, 1, 1); + weight_->getWGrad()->mul(*outputValue->getTranspose(), *gateGrad, 1, 1); } else if (prevOutput_ && weight_->getWGrad()) { weight_->getWGrad()->mul( - prevBatchOutput2_->getTranspose(), gateGrad, 1, 1); + *prevBatchOutput2_->getTranspose(), *gateGrad, 1, 1); } } } diff --git a/paddle/gserver/layers/MDLstmLayer.cpp b/paddle/gserver/layers/MDLstmLayer.cpp index 1243c12889542..fb41af5631954 100644 --- a/paddle/gserver/layers/MDLstmLayer.cpp +++ b/paddle/gserver/layers/MDLstmLayer.cpp @@ -547,7 +547,7 @@ void MDLstmLayer::forwardOneSequence(int start, CoordIterator& coordIter) { if (coordIter.getPrePos(delays_, i, prePos)) { int preOffset = coordIter.offset(prePos); frameGate_[start + offset].value->mul( - frameOutput_[start + preOffset].value, weight_->getW(), 1.0, 1.0); + *frameOutput_[start + preOffset].value, *weight_->getW(), 1.0, 1.0); } } forwardGate2OutputSequence(start, coordIter); @@ -747,11 +747,11 @@ void MDLstmLayer::backwardOneSequence(int start, CoordIterator& coordIter) { if (coordIter.getPrePos(delays_, i, prePos)) { int preOffset = coordIter.offset(prePos); frameOutput_[start + preOffset].grad->mul( - frameGate_[start + offset].grad, weightT, 1.0, 1.0); + *frameGate_[start + offset].grad, *weightT, 1.0, 1.0); if (weight_->getWGrad()) { weight_->getWGrad()->mul( - frameOutput_[start + preOffset].value->getTranspose(), - frameGate_[start + offset].grad, + *frameOutput_[start + preOffset].value->getTranspose(), + *frameGate_[start + offset].grad, 1.0, 1.0); } diff --git a/paddle/gserver/layers/OuterProdLayer.cpp b/paddle/gserver/layers/OuterProdLayer.cpp index cf9a008318e9d..b606e4436567e 100644 --- a/paddle/gserver/layers/OuterProdLayer.cpp +++ b/paddle/gserver/layers/OuterProdLayer.cpp @@ -96,7 +96,7 @@ void OuterProdLayer::forward(PassType passType) { tmpRow0->setData(inV0->getData() + i * dim0); tmpRow1->setData(inV1->getData() + i * dim1); - tmpMtx0->mul(tmpRow0->getTranspose(), tmpRow1); + tmpMtx0->mul(*tmpRow0->getTranspose(), *tmpRow1); } } } @@ -121,7 +121,7 @@ void OuterProdLayer::backward(const UpdateCallback& callback) { tmpRow0->setData(inG0->getData() + i * dim0); tmpRow1->setData(inV1->getData() + i * dim1); - tmpRow0->mul(tmpRow1, tmpMtx0->getTranspose(), 1, 1); + tmpRow0->mul(*tmpRow1, *tmpMtx0->getTranspose(), 1, 1); } } @@ -131,7 +131,7 @@ void OuterProdLayer::backward(const UpdateCallback& 
callback) { tmpRow0->setData(inV0->getData() + i * dim0); tmpRow1->setData(inG1->getData() + i * dim1); - tmpRow1->mul(tmpRow0, tmpMtx0, 1, 1); + tmpRow1->mul(*tmpRow0, *tmpMtx0, 1, 1); } } } diff --git a/paddle/gserver/layers/RecurrentLayer.cpp b/paddle/gserver/layers/RecurrentLayer.cpp index 85812c9d660e0..94b16996a86d2 100644 --- a/paddle/gserver/layers/RecurrentLayer.cpp +++ b/paddle/gserver/layers/RecurrentLayer.cpp @@ -215,12 +215,12 @@ void RecurrentLayer::forwardSequence(int batchSize, void RecurrentLayer::forwardOneSequence(int start, int length) { if (!reversed_) { if (prevOutput_) { - frameOutput_[start].value->mul(prevOutput_, weight_->getW(), 1, 1); + frameOutput_[start].value->mul(*prevOutput_, *weight_->getW(), 1, 1); } activation_->forward(frameOutput_[start]); for (int i = 1; i < length; ++i) { frameOutput_[start + i].value->mul( - frameOutput_[start + i - 1].value, weight_->getW(), 1, 1); + *frameOutput_[start + i - 1].value, *weight_->getW(), 1, 1); activation_->forward(frameOutput_[start + i]); } if (prevOutput_) { @@ -230,7 +230,7 @@ void RecurrentLayer::forwardOneSequence(int start, int length) { activation_->forward(frameOutput_[start + length - 1]); for (int i = length - 2; i >= 0; --i) { frameOutput_[start + i].value->mul( - frameOutput_[start + i + 1].value, weight_->getW(), 1, 1); + *frameOutput_[start + i + 1].value, *weight_->getW(), 1, 1); activation_->forward(frameOutput_[start + i]); } } @@ -282,13 +282,13 @@ void RecurrentLayer::backwardOneSequence(int start, int length) { for (int i = length - 1; i > 0; --i) { activation_->backward(frameOutput_[start + i]); frameOutput_[start + i - 1].grad->mul( - frameOutput_[start + i].grad, weightT, 1, 1); + *frameOutput_[start + i].grad, *weightT, 1, 1); } activation_->backward(frameOutput_[start]); if (weight_->getWGrad()) { weight_->getWGrad()->mul( - output_.value->subMatrix(start, length - 1)->getTranspose(), - output_.grad->subMatrix(start + 1, length - 1), + *output_.value->subMatrix(start, length - 1)->getTranspose(), + *output_.grad->subMatrix(start + 1, length - 1), 1, 1); } @@ -296,13 +296,13 @@ void RecurrentLayer::backwardOneSequence(int start, int length) { for (int i = 0; i < length - 1; ++i) { activation_->backward(frameOutput_[start + i]); frameOutput_[start + i + 1].grad->mul( - frameOutput_[start + i].grad, weightT, 1, 1); + *frameOutput_[start + i].grad, *weightT, 1, 1); } activation_->backward(frameOutput_[start + length - 1]); if (weight_->getWGrad()) { weight_->getWGrad()->mul( - output_.value->subMatrix(start + 1, length - 1)->getTranspose(), - output_.grad->subMatrix(start, length - 1), + *output_.value->subMatrix(start + 1, length - 1)->getTranspose(), + *output_.grad->subMatrix(start, length - 1), 1, 1); } @@ -329,7 +329,7 @@ void RecurrentLayer::forwardBatch(int batchSize, if (n != 0) { MatrixPtr batch1 = batchValue_->getBatchValue(n - 1, batch2->getHeight()); - batch2->mul(batch1, weight_->getW(), 1, 1); + batch2->mul(*batch1, *weight_->getW(), 1, 1); } Argument arg; arg.value = batch2; @@ -367,14 +367,14 @@ void RecurrentLayer::backwardBatch(int batchSize, if (n != 0) { batch1 = batchGrad_->getBatchValue(n - 1, batch2->getHeight()); - batch1->mul(batch2, weightT, 1, 1); + batch1->mul(*batch2, *weightT, 1, 1); } if (backwardByBatch && weight_->getWGrad()) { if (n != 0) { /* backward weight */ batch1 = batchValue_->getBatchValue(n - 1, batch2->getHeight()); - weight_->getWGrad()->mul(batch1->getTranspose(), batch2, 1, 1); + weight_->getWGrad()->mul(*batch1->getTranspose(), *batch2, 1, 1); } } 
} @@ -389,14 +389,14 @@ void RecurrentLayer::backwardBatch(int batchSize, int len = starts[seq + 1] - starts[seq]; if (!reversed_) { weight_->getWGrad()->mul( - output_.value->subMatrix(starts[seq], len - 1)->getTranspose(), - output_.grad->subMatrix(starts[seq] + 1, len - 1), + *output_.value->subMatrix(starts[seq], len - 1)->getTranspose(), + *output_.grad->subMatrix(starts[seq] + 1, len - 1), 1, 1); } else { weight_->getWGrad()->mul( - output_.value->subMatrix(starts[seq] + 1, len - 1)->getTranspose(), - output_.grad->subMatrix(starts[seq], len - 1), + *output_.value->subMatrix(starts[seq] + 1, len - 1)->getTranspose(), + *output_.grad->subMatrix(starts[seq], len - 1), 1, 1); } diff --git a/paddle/gserver/layers/SelectiveFullyConnectedLayer.cpp b/paddle/gserver/layers/SelectiveFullyConnectedLayer.cpp index 9200a01eee3be..5eacff6b71439 100644 --- a/paddle/gserver/layers/SelectiveFullyConnectedLayer.cpp +++ b/paddle/gserver/layers/SelectiveFullyConnectedLayer.cpp @@ -155,20 +155,20 @@ void SelectiveFullyConnectedLayer::forward(PassType passType) { // manully compute the multiplication of // the input vector and the selected rows. REGISTER_TIMER("selective.plain"); - interOutput_->mul(input, weight->getTranspose(), 1, scaleT); + interOutput_->mul(*input, *weight->getTranspose(), 1, scaleT); } else { // if the indecies is not sparse enough, // use full mul instead REGISTER_TIMER("selective.mul"); if (fullOutput_) { - interOutput_->mul(input, weight->getTranspose(), 1, scaleT); + interOutput_->mul(*input, *weight->getTranspose(), 1, scaleT); } else { Matrix::resizeOrCreate(mmat_, hsize, wsize, /*trans=*/false, /*useGpu=*/useGpu_); - mmat_->mul(input, weight->getTranspose()); + mmat_->mul(*input, *weight->getTranspose()); interOutput_->add3(mmat_); } } @@ -242,14 +242,14 @@ void SelectiveFullyConnectedLayer::backward(const UpdateCallback& callback) { MatrixPtr preGrad = getInputGrad(i); if (preGrad) { REGISTER_TIMER_INFO("BpMulTimer", getName().c_str()); - preGrad->mul(interOutGrad_, weights_[i]->getW(), 1, 1); + preGrad->mul(*interOutGrad_, *weights_[i]->getW(), 1, 1); } MatrixPtr wGrad = weights_[i]->getWGrad(); if (wGrad) { REGISTER_TIMER_INFO("GradMulTimer", getName().c_str()); MatrixPtr input = getInputValue(i); - wGrad->mul(interOutGrad_->getTranspose(), input, 1, 1); + wGrad->mul(*interOutGrad_->getTranspose(), *input, 1, 1); } { diff --git a/paddle/gserver/layers/TensorLayer.cpp b/paddle/gserver/layers/TensorLayer.cpp index 642eb1bdd31c0..5be88d7c05dae 100644 --- a/paddle/gserver/layers/TensorLayer.cpp +++ b/paddle/gserver/layers/TensorLayer.cpp @@ -77,7 +77,7 @@ void TensorLayer::forward(PassType passType) { REGISTER_TIMER_INFO("TensorFwMulTimer", getName().c_str()); for (size_t i = 0; i < getSize(); ++i) { MatrixPtr weights = weights_[i]->getW(); - tmpMat->mul(input1, weights, 1, 0); + tmpMat->mul(*input1, *weights, 1, 0); outV->rowDotMul(i, *tmpMat, *input2); } } @@ -112,7 +112,7 @@ void TensorLayer::backward(const UpdateCallback& callback) { if (weights_[i]->getWGrad()) { tmpMat->rowScale(i, *input1, *oGrad); MatrixPtr input1_T = tmpMat->getTranspose(); - weights_[i]->getWGrad()->mul(input1_T, input2, 1, 1); + weights_[i]->getWGrad()->mul(*input1_T, *input2, 1, 1); } } } @@ -130,11 +130,11 @@ void TensorLayer::backward(const UpdateCallback& callback) { if (NULL != preGrad1) { /* (grad * e2) * trans(W) */ tmpMat->rowScale(i, *input2, *oGrad); MatrixPtr weights_T = weights->getTranspose(); - preGrad1->mul(tmpMat, weights_T, 1, 1); + preGrad1->mul(*tmpMat, *weights_T, 1, 1); } if 
(NULL != preGrad2) { /* (grad * e1) * W */ tmpMat->rowScale(i, *input1, *oGrad); - preGrad2->mul(tmpMat, weights, 1, 1); + preGrad2->mul(*tmpMat, *weights, 1, 1); } } } diff --git a/paddle/gserver/layers/TransposedFullMatrixProjection.cpp b/paddle/gserver/layers/TransposedFullMatrixProjection.cpp index 3f7ff04882075..2a12499e5b5f1 100644 --- a/paddle/gserver/layers/TransposedFullMatrixProjection.cpp +++ b/paddle/gserver/layers/TransposedFullMatrixProjection.cpp @@ -46,7 +46,7 @@ TransposedFullMatrixProjection::TransposedFullMatrixProjection( void TransposedFullMatrixProjection::forward() { REGISTER_TIMER_INFO("FwMulTimer", getName().c_str()); - out_->value->mul(in_->value, weight_->getW()->getTranspose(), 1, 1); + out_->value->mul(*(in_->value), *(weight_->getW()->getTranspose()), 1, 1); } void TransposedFullMatrixProjection::backward(const UpdateCallback& callback) { @@ -55,7 +55,8 @@ void TransposedFullMatrixProjection::backward(const UpdateCallback& callback) { /* Calculate the W-gradient for the current layer */ if (weight_->getWGrad()) { REGISTER_TIMER_INFO("GradMulTimer", getName().c_str()); - weight_->getWGrad()->mul(out_->grad->getTranspose(), in_->value, 1, 1); + weight_->getWGrad()->mul( + *(out_->grad->getTranspose()), *(in_->value), 1, 1); } // If callback does not change value, backprop error asynchronously so that @@ -69,7 +70,7 @@ void TransposedFullMatrixProjection::backward(const UpdateCallback& callback) { /* Calculate the input layers error */ if (in_->grad) { REGISTER_TIMER_INFO("BpMulTimer", getName().c_str()); - in_->grad->mul(out_->grad, weight_->getW(), 1, 1); + in_->grad->mul(*(out_->grad), *(weight_->getW()), 1, 1); } hl_set_sync_flag(syncFlag); diff --git a/paddle/math/CpuSparseMatrix.cpp b/paddle/math/CpuSparseMatrix.cpp index b5d5b6ef61582..82a482f701481 100644 --- a/paddle/math/CpuSparseMatrix.cpp +++ b/paddle/math/CpuSparseMatrix.cpp @@ -163,15 +163,16 @@ MatrixPtr CpuSparseMatrix::getTranspose() { SparseValueType CpuSparseMatrix::getValueType() { return valueType_; } -void CpuSparseMatrix::mul(MatrixPtr a, MatrixPtr b, real scaleAB, real scaleT) { +void CpuSparseMatrix::mul(const Matrix& a, + const Matrix& b, + real scaleAB, + real scaleT) { CHECK(!isTransposed()) << "Not supported"; + const auto a_ptr = dynamic_cast(&a); + const auto b_ptr = dynamic_cast(&b); - if (dynamic_cast(a.get()) && dynamic_cast(b.get())) { - CpuMatrix::mul(dynamic_cast(a.get()), - dynamic_cast(b.get()), - this, - scaleAB, - scaleT); + if (a_ptr && b_ptr) { + CpuMatrix::mul((CpuMatrix*)a_ptr, (CpuMatrix*)b_ptr, this, scaleAB, scaleT); } else { LOG(FATAL) << "not supported"; } diff --git a/paddle/math/CpuSparseMatrix.h b/paddle/math/CpuSparseMatrix.h index 9676f8864f845..d3e8871cb5b32 100644 --- a/paddle/math/CpuSparseMatrix.h +++ b/paddle/math/CpuSparseMatrix.h @@ -203,7 +203,7 @@ class CpuSparseMatrix : public Matrix { /// mem MUST be alloced outside (memAlloc=false) void transpose(MatrixPtr matTrans, bool memAlloc); - void mul(MatrixPtr A, MatrixPtr B, real alpha, real beta); + void mul(const Matrix& A, const Matrix& B, real alpha, real beta); /** * @brief sparseMatrix += denseMatrix diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp index 3b3c1d7d48a28..0193f2f997303 100644 --- a/paddle/math/Matrix.cpp +++ b/paddle/math/Matrix.cpp @@ -582,18 +582,16 @@ void GpuMatrix::mul(const GpuMatrix& a, } /* this = a*b */ -void GpuMatrix::mul(const MatrixPtr a, const MatrixPtr b) { - mul(a, b, 1.0, 0.0); -} +void GpuMatrix::mul(const Matrix& a, const Matrix& b) { mul(a, b, 1.0, 
0.0); } -void GpuMatrix::mul(const MatrixPtr a, - const MatrixPtr b, +void GpuMatrix::mul(const Matrix& a, + const Matrix& b, real scaleAB, real scaleT) { - GpuMatrixPtr a_ptr = std::dynamic_pointer_cast(a); - GpuMatrixPtr b_ptr = std::dynamic_pointer_cast(b); - GpuSparseMatrixPtr a_ptr_s = std::dynamic_pointer_cast(a); - GpuSparseMatrixPtr b_ptr_s = std::dynamic_pointer_cast(b); + const auto a_ptr = dynamic_cast(&a); + const auto b_ptr = dynamic_cast(&b); + const auto a_ptr_s = dynamic_cast(&a); + const auto b_ptr_s = dynamic_cast(&b); if (a_ptr && b_ptr) { mul(*a_ptr, *b_ptr, scaleAB, scaleT); @@ -2598,29 +2596,22 @@ void CpuMatrix::sequenceAvgForward(Matrix& a, } /* this = scaleAB*(a*b) + scaleT*this*/ -void CpuMatrix::mul(const MatrixPtr a, - const MatrixPtr b, +void CpuMatrix::mul(const Matrix& a, + const Matrix& b, real scaleAB, real scaleT) { CHECK(!isTransposed()) << "Not supported"; + const auto a_ptr = dynamic_cast(&a); + const auto b_ptr = dynamic_cast(&b); + const auto a_ptr_s = dynamic_cast(&a); + const auto b_ptr_s = dynamic_cast(&b); - if (dynamic_cast(a.get()) && dynamic_cast(b.get())) { - mul(dynamic_cast(a.get()), - dynamic_cast(b.get()), - scaleAB, - scaleT); - } else if (dynamic_cast(a.get()) && - dynamic_cast(b.get())) { - mul(dynamic_cast(a.get()), - dynamic_cast(b.get()), - scaleAB, - scaleT); - } else if (dynamic_cast(a.get()) && - dynamic_cast(b.get())) { - mul(dynamic_cast(a.get()), - dynamic_cast(b.get()), - scaleAB, - scaleT); + if (a_ptr && b_ptr) { + mul((CpuMatrix*)a_ptr, (CpuMatrix*)b_ptr, scaleAB, scaleT); + } else if (a_ptr_s && b_ptr) { + mul((CpuSparseMatrix*)a_ptr_s, (CpuMatrix*)b_ptr, scaleAB, scaleT); + } else if (a_ptr && b_ptr_s) { + mul((CpuMatrix*)a_ptr, (CpuSparseMatrix*)b_ptr_s, scaleAB, scaleT); } else { LOG(FATAL) << "Not supported"; } @@ -3289,7 +3280,7 @@ void CpuMatrix::addColumnVector(const Matrix& b) { } /* this = a*b */ -void CpuMatrix::mul(const MatrixPtr a, const MatrixPtr b) { +void CpuMatrix::mul(const Matrix& a, const Matrix& b) { return mul(a, b, 1.0, 0.0); } diff --git a/paddle/math/Matrix.h b/paddle/math/Matrix.h index b8c7adf9486be..dfcb0853df37c 100644 --- a/paddle/math/Matrix.h +++ b/paddle/math/Matrix.h @@ -444,8 +444,8 @@ class Matrix : public BaseMatrix { * this = scaleAB*(a*b) + scaleT*this * @endcode */ - virtual void mul(const MatrixPtr a, - const MatrixPtr b, + virtual void mul(const Matrix& a, + const Matrix& b, real scaleAB, real scaleT) { LOG(FATAL) << "Not implemented"; @@ -643,7 +643,7 @@ class Matrix : public BaseMatrix { * this = a*b * @endcode */ - virtual void mul(const MatrixPtr a, const MatrixPtr b) { + virtual void mul(const Matrix& a, const Matrix& b) { LOG(FATAL) << "Not implemented"; } @@ -1272,14 +1272,14 @@ class GpuMatrix : public Matrix { * this = scaleAB*(a*b) + scaleT*this * @endcode */ - void mul(const MatrixPtr a, const MatrixPtr b, real scaleAB, real scaleT); + void mul(const Matrix& a, const Matrix& b, real scaleAB, real scaleT); /** * @code * this = a*b * @endcode */ - void mul(const MatrixPtr a, const MatrixPtr b); + void mul(const Matrix& a, const Matrix& b); void mul(const GpuMatrix& a, const GpuMatrix& b, real scaleAB, real scaleT); @@ -1784,7 +1784,7 @@ class CpuMatrix : public Matrix { void addColumnVector(const Matrix& b); - void mul(const MatrixPtr a, const MatrixPtr b, real scaleAB, real scaleT); + void mul(const Matrix& a, const Matrix& b, real scaleAB, real scaleT); void mul(CpuMatrix* a, CpuMatrix* b, real scaleAB, real scaleT); void mul(CpuMatrix* a, CpuSparseMatrix* b, real 
scaleAB, real scaleT); @@ -1807,7 +1807,7 @@ class CpuMatrix : public Matrix { virtual void mul(CpuSparseMatrix* a, CpuMatrix* b, real scaleAB, real scaleT); - void mul(const MatrixPtr a, const MatrixPtr b); + void mul(const Matrix& a, const Matrix& b); void rightMul(Matrix& b, real scaleAB, real scaleT); void rightMul(Matrix& b); diff --git a/paddle/math/SparseMatrix.cpp b/paddle/math/SparseMatrix.cpp index 9154503c2132a..720a035ecbd26 100644 --- a/paddle/math/SparseMatrix.cpp +++ b/paddle/math/SparseMatrix.cpp @@ -571,49 +571,48 @@ void GpuSparseMatrix::transpose(MatrixPtr matTrans, bool memAlloc) { hl_stream_synchronize(stream); } -void GpuSparseMatrix::mul(const GpuMatrixPtr a, - const GpuMatrixPtr b, +void GpuSparseMatrix::mul(const GpuMatrix& a, + const GpuMatrix& b, real scaleAB, real scaleT) { - CHECK(a->useGpu_ && b->useGpu_) << "type not match"; + CHECK(a.useGpu_ && b.useGpu_) << "type not match"; CHECK(!trans_) << "trans not supported"; - real* A_d = a->getData(); - real* B_d = b->getData(); + real* A_d = (real*)a.getData(); + real* B_d = (real*)b.getData(); hl_sparse_matrix_s C_d = sMatrix_.get(); - hl_trans_op_t a_trans = a->trans_ ? HPPL_OP_T : HPPL_OP_N; - hl_trans_op_t b_trans = b->trans_ ? HPPL_OP_T : HPPL_OP_N; - - if (!a->trans_ && !b->trans_) { - CHECK(height_ == a->getHeight()); - CHECK(width_ == b->getWidth()); - CHECK(a->getWidth() == b->getHeight()); - } else if (a->trans_ && !b->trans_) { - CHECK(height_ == a->getWidth()); - CHECK(width_ == b->getWidth()); - CHECK(a->getHeight() == b->getHeight()); - } else if (!a->trans_ && b->trans_) { - CHECK(height_ == a->getHeight()); - CHECK(width_ == b->getHeight()); - CHECK(a->getWidth() == b->getWidth()); + hl_trans_op_t a_trans = a.trans_ ? HPPL_OP_T : HPPL_OP_N; + hl_trans_op_t b_trans = b.trans_ ? HPPL_OP_T : HPPL_OP_N; + + if (!a.trans_ && !b.trans_) { + CHECK(height_ == a.getHeight()); + CHECK(width_ == b.getWidth()); + CHECK(a.getWidth() == b.getHeight()); + } else if (a.trans_ && !b.trans_) { + CHECK(height_ == a.getWidth()); + CHECK(width_ == b.getWidth()); + CHECK(a.getHeight() == b.getHeight()); + } else if (!a.trans_ && b.trans_) { + CHECK(height_ == a.getHeight()); + CHECK(width_ == b.getHeight()); + CHECK(a.getWidth() == b.getWidth()); } else { LOG(INFO) << "Not support"; } int dimM = height_; int dimN = width_; - int dimK = !b->trans_ ? b->getHeight() : b->getWidth(); + int dimK = !b.trans_ ? 
b.getHeight() : b.getWidth(); hl_sparse_matrix_mul( A_d, a_trans, B_d, b_trans, C_d, dimM, dimN, dimK, scaleAB, scaleT); } -void GpuSparseMatrix::mul(const MatrixPtr a, - const MatrixPtr b, +void GpuSparseMatrix::mul(const Matrix& a, + const Matrix& b, real scaleAB, real scaleT) { - if (std::dynamic_pointer_cast(a) && - std::dynamic_pointer_cast(b)) { - GpuMatrixPtr a_ptr = std::dynamic_pointer_cast(a); - GpuMatrixPtr b_ptr = std::dynamic_pointer_cast(b); - mul(a_ptr, b_ptr, scaleAB, scaleT); + const auto a_ptr = dynamic_cast(&a); + const auto b_ptr = dynamic_cast(&b); + if (a_ptr && b_ptr) { + mul(*a_ptr, *b_ptr, scaleAB, scaleT); } else { LOG(FATAL) << "not supported"; } diff --git a/paddle/math/SparseMatrix.h b/paddle/math/SparseMatrix.h index bd96a3301ded2..1d3801548e03a 100644 --- a/paddle/math/SparseMatrix.h +++ b/paddle/math/SparseMatrix.h @@ -104,10 +104,7 @@ class GpuSparseMatrix : public Matrix { size_t newNnz, SparseValueType valueType); - void mul(const GpuMatrixPtr a, - const GpuMatrixPtr b, - real scaleAB, - real scaleT); + void mul(const GpuMatrix& a, const GpuMatrix& b, real scaleAB, real scaleT); /// B = A , B.trans = !A.trans MatrixPtr getTranspose(); @@ -218,7 +215,7 @@ class GpuSparseMatrix : public Matrix { void copyRow(int offsets, size_t colNum, const sparse_float_value_t* row); public: - void mul(const MatrixPtr a, const MatrixPtr b, real scaleAB, real scaleT); + void mul(const Matrix& a, const Matrix& b, real scaleAB, real scaleT); void copyFrom(CpuSparseMatrix& src, hl_stream_t stream); void copyFrom(GpuSparseMatrix& src, hl_stream_t stream); diff --git a/paddle/math/tests/test_SparseMatrix.cpp b/paddle/math/tests/test_SparseMatrix.cpp index 88b75b6d83612..0949ab7ffba42 100644 --- a/paddle/math/tests/test_SparseMatrix.cpp +++ b/paddle/math/tests/test_SparseMatrix.cpp @@ -33,8 +33,8 @@ TEST(Matrix, CopyCpuMatrixToSparseMatrix) { ret2(new CpuMatrix(HEIGHT, WIDTH_TEST)); ret1->zeroMem(); ret2->zeroMem(); - ret1->mul(testMatrix, mulCpuMatrix, 1.0, 1.0); - ret2->mul(testCpuMatrix, mulCpuMatrix, 1.0, 1.0); + ret1->mul(*testMatrix, *mulCpuMatrix, 1.0, 1.0); + ret2->mul(*testCpuMatrix, *mulCpuMatrix, 1.0, 1.0); checkMatrixEqual(ret1, ret2); } @@ -147,9 +147,9 @@ void test_sparse_matrix_mul(MatrixPara paraA, hl_stream_synchronize(stream); /*matrix mul*/ - cpuMatrixC->mul(cpuMatrixA, cpuMatrixB, 1.0, 1.0); - gpuMatrixC->mul(gpuMatrixA, gpuMatrixB, 1.0, 1.0); - cpuDenseC->mul(cpuDenseA, cpuDenseB, 1.0, 1.0); + cpuMatrixC->mul(*cpuMatrixA, *cpuMatrixB, 1.0, 1.0); + gpuMatrixC->mul(*gpuMatrixA, *gpuMatrixB, 1.0, 1.0); + cpuDenseC->mul(*cpuDenseA, *cpuDenseB, 1.0, 1.0); gpuMatrixC_d2h->copyFrom(*gpuMatrixC, stream); hl_stream_synchronize(stream); @@ -224,8 +224,8 @@ TEST(Matrix, CopySparseMatrixToGpuSparseMatrix) { MatrixPtr ret2(new GpuMatrix(HEIGHT, WIDTH_TEST)); ret1->zeroMem(); ret2->zeroMem(); - ret1->mul(testMatrix, mulCpuMatrix, 1.0, 1.0); - ret2->mul(testGpuMatrix, mulGpuMatrix, 1.0, 1.0); + ret1->mul(*testMatrix, *mulCpuMatrix, 1.0, 1.0); + ret2->mul(*testGpuMatrix, *mulGpuMatrix, 1.0, 1.0); checkMatrixEqual(ret1, ret2); } diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index 10289940a4c5f..c6fc849ba0328 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -318,7 +318,7 @@ void testMatrixInverse(int height) { cpu->randomizeUniform(); MatrixPtr cpuT = cpu->getTranspose(); MatrixPtr outputCheck = std::make_shared(height, height); - outputCheck->mul(cpu, cpuT); + 
outputCheck->mul(*cpu, *cpuT); cpu->setDiag(1.0); cpu->add(*outputCheck); @@ -328,7 +328,7 @@ void testMatrixInverse(int height) { TensorCheckErr(*cpuI, *gpuI); - outputCheck->mul(cpu, cpuI); + outputCheck->mul(*cpu, *cpuI); cpu->setDiag(1.0); TensorCheckErr(*cpu, *outputCheck); } @@ -509,8 +509,8 @@ void testMatrixMul(bool transa, bool transb, int dimM, int dimN, int dimK) { gpuB->copyFrom(*cpuB); gpuC->copyFrom(*cpuC); - cpuC->mul(cpuA, cpuB, alpha, beta); - gpuC->mul(gpuA, gpuB, alpha, beta); + cpuC->mul(*cpuA, *cpuB, alpha, beta); + gpuC->mul(*gpuA, *gpuB, alpha, beta); TensorCheckErr(*cpuC, *gpuC); } @@ -581,8 +581,8 @@ void testSubMatrixMul(bool transa, bool transb, int dimM, int dimN, int dimK) { MatrixPtr subCpuC = cpuC->subMatrix(startM, endM, startN, endN); MatrixPtr subGpuC = gpuC->subMatrix(startM, endM, startN, endN); - subCpuC->mul(subCpuA, subCpuB, alpha, beta); - subGpuC->mul(subGpuA, subGpuB, alpha, beta); + subCpuC->mul(*subCpuA, *subCpuB, alpha, beta); + subGpuC->mul(*subGpuA, *subGpuB, alpha, beta); TensorCheckErr(*cpuC, *gpuC); } diff --git a/paddle/math/tests/test_sparseMatrixCompare.cpp b/paddle/math/tests/test_sparseMatrixCompare.cpp index 6f6de238bacaa..dcdbccffc3a19 100644 --- a/paddle/math/tests/test_sparseMatrixCompare.cpp +++ b/paddle/math/tests/test_sparseMatrixCompare.cpp @@ -102,8 +102,8 @@ void testSpMatrixMul(int M, int N, int K, real rate) { gpuC->copyFrom(*cpuC, stream); hl_stream_synchronize(stream); - cpuC->mul(cpuA, cpuB->getTranspose(), 1, 1); - gpuC->mul(gpuA, gpuB->getTranspose(), 1, 1); + cpuC->mul(*cpuA, *cpuB->getTranspose(), 1, 1); + gpuC->mul(*gpuA, *gpuB->getTranspose(), 1, 1); MatrixPtr outputCheck(new CpuSparseMatrix(M, N, nnz)); outputCheck->copyFrom(*gpuC, stream); From 204152c76ea90100f286067f9ff298a8c79d33a3 Mon Sep 17 00:00:00 2001 From: Luo Tao Date: Tue, 20 Dec 2016 12:54:19 +0800 Subject: [PATCH 36/55] set -e for docs.sh --- paddle/scripts/travis/docs.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/scripts/travis/docs.sh b/paddle/scripts/travis/docs.sh index 4ab1746b5af81..cd331522a910a 100755 --- a/paddle/scripts/travis/docs.sh +++ b/paddle/scripts/travis/docs.sh @@ -12,13 +12,13 @@ set +e linkchecker doc/cn/html/index.html > doc_cn.out linkchecker doc/en/html/index.html > doc_en.out for i in doc_cn.out doc_en.out; do - echo $i grep " 0 errors found" $i if [ $? -ne 0 ]; then cat $i exit 1 fi done +set -e # Parse Github URL REPO=`git config remote.origin.url` From bf26679c3214f2c0c24f02218d3c15e720557a38 Mon Sep 17 00:00:00 2001 From: yangwenbo02 Date: Tue, 20 Dec 2016 13:51:56 +0800 Subject: [PATCH 37/55] update docker_install_en.rst --- doc/getstarted/build_and_install/docker_install_en.rst | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/doc/getstarted/build_and_install/docker_install_en.rst b/doc/getstarted/build_and_install/docker_install_en.rst index 1cc23ac3aa989..57725c0d85997 100644 --- a/doc/getstarted/build_and_install/docker_install_en.rst +++ b/doc/getstarted/build_and_install/docker_install_en.rst @@ -44,8 +44,7 @@ The general development workflow with Docker and Bazel is as follows: cd paddle docker build -t paddle:dev -f paddle/scripts/docker/Dockerfile . - Apt-get source errors may occur when building paddle docker image. - **You can specify the UBUNTU MIRROR with** :code:`--build-arg UBUNTU_MIRROR` **like the example below.** + Sometimes docker build might suffer from a slow network connection to the official Ubuntu apt-source servers. 
In such case, we can specify an apt-source mirror server that is geologically nearer to us. In the following example, we specified an apt-source server that responds fast in China.You can specify the UBUNTU MIRROR with :code:`--build-arg UBUNTU_MIRROR` like the example below. .. code-block:: bash From 6f8f468fdbfafafda1661b8002cfda76263cf9af Mon Sep 17 00:00:00 2001 From: gaoyuan Date: Tue, 20 Dec 2016 16:28:41 +0800 Subject: [PATCH 38/55] Add priorbox layer gpu unit test. --- paddle/gserver/tests/test_PriorBox.cpp | 127 ++++++++++++++++++------- 1 file changed, 92 insertions(+), 35 deletions(-) diff --git a/paddle/gserver/tests/test_PriorBox.cpp b/paddle/gserver/tests/test_PriorBox.cpp index 1dab21218e49d..19dfd0f065da2 100644 --- a/paddle/gserver/tests/test_PriorBox.cpp +++ b/paddle/gserver/tests/test_PriorBox.cpp @@ -24,14 +24,15 @@ using namespace std; // NOLINT // Do one forward pass of priorBox layer and check to see if its output // matches the given result -void doOnePriorBoxTest(size_t featureMapWidth, - size_t featureMapHeight, - size_t imageWidth, - size_t imageHeight, - vector minSize, - vector maxSize, - vector aspectRatio, +void doOnePriorBoxTest(size_t feature_map_width, + size_t feature_map_height, + size_t image_width, + size_t image_height, + vector min_size, + vector max_size, + vector aspect_ratio, vector variance, + bool use_gpu, MatrixPtr& result) { // Setting up the priorbox layer TestConfig configt; @@ -42,28 +43,27 @@ void doOnePriorBoxTest(size_t featureMapWidth, configt.inputDefs.push_back({INPUT_DATA, "image", 1, 0}); configt.layerConfig.add_inputs(); PriorBoxConfig* pb = input->mutable_priorbox_conf(); - for (size_t i = 0; i < minSize.size(); i++) pb->add_min_size(minSize[i]); - for (size_t i = 0; i < maxSize.size(); i++) pb->add_max_size(maxSize[i]); - for (size_t i = 0; i < aspectRatio.size(); i++) - pb->add_aspect_ratio(aspectRatio[i]); + for (size_t i = 0; i < min_size.size(); i++) pb->add_min_size(min_size[i]); + for (size_t i = 0; i < max_size.size(); i++) pb->add_max_size(max_size[i]); for (size_t i = 0; i < variance.size(); i++) pb->add_variance(variance[i]); + for (size_t i = 0; i < aspect_ratio.size(); i++) + pb->add_aspect_ratio(aspect_ratio[i]); // data layer initialize std::vector dataLayers; LayerMap layerMap; vector datas; initDataLayer( - configt, &dataLayers, &datas, &layerMap, "priorbox", 1, false, false); - dataLayers[0]->getOutput().setFrameHeight(featureMapHeight); - dataLayers[0]->getOutput().setFrameWidth(featureMapWidth); - dataLayers[1]->getOutput().setFrameHeight(imageHeight); - dataLayers[1]->getOutput().setFrameWidth(imageWidth); + configt, &dataLayers, &datas, &layerMap, "priorbox", 1, false, use_gpu); + dataLayers[0]->getOutput().setFrameHeight(feature_map_height); + dataLayers[0]->getOutput().setFrameWidth(feature_map_width); + dataLayers[1]->getOutput().setFrameHeight(image_height); + dataLayers[1]->getOutput().setFrameWidth(image_width); // test layer initialize std::vector parameters; LayerPtr priorboxLayer; initTestLayer(configt, &layerMap, ¶meters, &priorboxLayer); - priorboxLayer->forward(PASS_GC); checkMatrixEqual(priorboxLayer->getOutputValue(), result); } @@ -73,6 +73,7 @@ TEST(Layer, priorBoxLayerFwd) { vector maxSize; vector aspectRatio; vector variance; + bool useGpu = false; minSize.push_back(276); maxSize.push_back(330); @@ -81,9 +82,8 @@ TEST(Layer, priorBoxLayerFwd) { variance.push_back(0.2); variance.push_back(0.2); + // CPU case 1. 
MatrixPtr result; - result = Matrix::create(1, 2 * 8, false, false); - float resultData[] = {0.04, 0.04, 0.96, @@ -100,52 +100,109 @@ TEST(Layer, priorBoxLayerFwd) { 0.1, 0.2, 0.2}; + result = Matrix::create(1, 2 * 8, false, useGpu); result->setData(resultData); - doOnePriorBoxTest(/* featureMapWidth */ 1, - /* featureMapHeight */ 1, - /* imageWidth */ 300, - /* imageHeight */ 300, + doOnePriorBoxTest(/* feature_map_width */ 1, + /* feature_map_height */ 1, + /* image_width */ 300, + /* image_height */ 300, minSize, maxSize, aspectRatio, variance, + useGpu, result); - + // CPU case 2. variance[1] = 0.2; variance[3] = 0.1; maxSize.pop_back(); - Matrix::resizeOrCreate(result, 1, 4 * 8, false, false); float resultData2[] = {0, 0, 0.595, 0.595, 0.1, 0.2, 0.2, 0.1, 0.405, 0, 1, 0.595, 0.1, 0.2, 0.2, 0.1, 0, 0.405, 0.595, 1, 0.1, 0.2, 0.2, 0.1, 0.405, 0.405, 1, 1, 0.1, 0.2, 0.2, 0.1}; + Matrix::resizeOrCreate(result, 1, 4 * 8, false, useGpu); result->setData(resultData2); - doOnePriorBoxTest(/* featureMapWidth */ 2, - /* featureMapHeight */ 2, - /* imageWidth */ 400, - /* imageHeight */ 400, + doOnePriorBoxTest(/* feature_map_width */ 2, + /* feature_map_height */ 2, + /* image_width */ 400, + /* image_height */ 400, minSize, maxSize, aspectRatio, variance, + useGpu, result); - + // CPU case 3. aspectRatio.push_back(2); - Matrix::resizeOrCreate(result, 1, 3 * 8, false, false); float resultData3[] = {0.04, 0.04, 0.96, 0.96, 0.1, 0.2, 0.2, 0.1, 0, 0.17473088, 1, 0.825269, 0.1, 0.2, 0.2, 0.1, 0.17473088, 0, 0.825269, 1, 0.1, 0.2, 0.2, 0.1}; + Matrix::resizeOrCreate(result, 1, 3 * 8, false, useGpu); result->setData(resultData3); - doOnePriorBoxTest(/* featureMapWidth */ 1, - /* featureMapHeight */ 1, - /* imageWidth */ 300, - /* imageHeight */ 300, + doOnePriorBoxTest(/* feature_map_width */ 1, + /* feature_map_height */ 1, + /* image_width */ 300, + /* image_height */ 300, minSize, maxSize, aspectRatio, variance, + useGpu, result); + +#ifndef PADDLE_ONLY_CPU + // reset the input parameters + variance[1] = 0.1; + variance[3] = 0.2; + maxSize.push_back(330); + aspectRatio.pop_back(); + MatrixPtr resultGpu; + useGpu = true; + // GPU case 1. + resultGpu = Matrix::create(1, 2 * 8, false, useGpu); + resultGpu->copyFrom(resultData, 2 * 8); + doOnePriorBoxTest(/* feature_map_width */ 1, + /* feature_map_height */ 1, + /* image_width */ 300, + /* image_height */ 300, + minSize, + maxSize, + aspectRatio, + variance, + useGpu, + resultGpu); + // GPU case 2. + variance[1] = 0.2; + variance[3] = 0.1; + maxSize.pop_back(); + Matrix::resizeOrCreate(resultGpu, 1, 4 * 8, false, useGpu); + resultGpu->copyFrom(resultData2, 4 * 8); + doOnePriorBoxTest(/* feature_map_width */ 2, + /* feature_map_height */ 2, + /* image_width */ 400, + /* image_height */ 400, + minSize, + maxSize, + aspectRatio, + variance, + useGpu, + resultGpu); + // GPU case 3. 
+ aspectRatio.push_back(2); + Matrix::resizeOrCreate(resultGpu, 1, 3 * 8, false, useGpu); + resultGpu->copyFrom(resultData3, 3 * 8); + doOnePriorBoxTest(/* feature_map_width */ 1, + /* feature_map_height */ 1, + /* image_width */ 300, + /* image_height */ 300, + minSize, + maxSize, + aspectRatio, + variance, + useGpu, + resultGpu); +#endif } int main(int argc, char** argv) { From 5fddd99e18f3920ff0d8158fd4a9800d5566943e Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Tue, 20 Dec 2016 17:20:22 +0800 Subject: [PATCH 39/55] move TEST from test_matrixCompare.cpp to cross_map_normal_op_test.cpp --- cmake/util.cmake | 1 + paddle/function/CMakeLists.txt | 35 +++-- paddle/function/FunctionTest.h | 102 +++++++++++++ paddle/function/TestMain.cpp | 22 +++ paddle/function/cross_map_normal_op_test.cpp | 71 +++++++++ paddle/math/tests/test_matrixCompare.cpp | 144 ------------------- 6 files changed, 221 insertions(+), 154 deletions(-) create mode 100644 paddle/function/FunctionTest.h create mode 100644 paddle/function/TestMain.cpp create mode 100644 paddle/function/cross_map_normal_op_test.cpp diff --git a/cmake/util.cmake b/cmake/util.cmake index 03734e7839d74..8a71b23c62d9f 100644 --- a/cmake/util.cmake +++ b/cmake/util.cmake @@ -107,6 +107,7 @@ function(link_paddle_exe TARGET_NAME) paddle_parameter paddle_proto paddle_cuda + paddle_test_main ${METRIC_LIBS} ${PROTOBUF_LIBRARY} ${LIBGLOG_LIBRARY} diff --git a/paddle/function/CMakeLists.txt b/paddle/function/CMakeLists.txt index 8fad0e3ebdfb2..0697842bbef62 100644 --- a/paddle/function/CMakeLists.txt +++ b/paddle/function/CMakeLists.txt @@ -1,12 +1,27 @@ -file(GLOB FUNCTION_HEADERS . *.h) - -if(NOT WITH_GPU) - file(GLOB FUNCTION_SOURCES . *.cpp) - add_library(paddle_function STATIC ${FUNCTION_SOURCES}) -else() - file(GLOB FUNCTION_SOURCES . *.cpp *.cu) - cuda_add_library(paddle_function ${FUNCTION_SOURCES}) +file(GLOB h_files . *_op.h) +file(GLOB cpp_files . *_op.cpp) + +list(APPEND h_files Function.h) +list(APPEND cpp_files Function.cpp) + +if(WITH_GPU) + file(GLOB cu_files . *_op_gpu.cu) + cuda_compile(cu_objs ${cu_files}) endif() -add_style_check_target(paddle_function ${FUNCTION_SOURCES}) -add_style_check_target(paddle_function ${FUNCTION_HEADERS}) +add_library(paddle_function STATIC ${cpp_files} ${cu_objs}) + +add_library(paddle_test_main STATIC TestMain.cpp) + +if(WITH_GPU) + # TODO: + # file(GLOB test_files . *_op_test.cpp) + # add_executable(${test_bin} EXCLUDE_FROM_ALL ${test_files}) + add_simple_unittest(cross_map_normal_op_test) +endif() + +add_style_check_target(paddle_function ${h_files}) +add_style_check_target(paddle_function ${cpp_files}) +if(WITH_GPU) + add_style_check_target(paddle_function ${cu_files}) +endif() diff --git a/paddle/function/FunctionTest.h b/paddle/function/FunctionTest.h new file mode 100644 index 0000000000000..a8c5e412bd12d --- /dev/null +++ b/paddle/function/FunctionTest.h @@ -0,0 +1,102 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "Function.h" +#include "paddle/math/Vector.h" +#include "paddle/math/tests/TensorCheck.h" + +namespace paddle { + +class FunctionCompare { +public: + FunctionCompare(const std::string& name, const FuncConfig& config) + : cpu(FunctionBase::funcRegistrar_.createByType(name + "-CPU")), + gpu(FunctionBase::funcRegistrar_.createByType(name + "-GPU")) { + cpu->init(config); + gpu->init(config); + } + + void cmpWithArg(const Arguments& inputs, + const Arguments& outputs, + const Arguments& inouts) { + // init cpu and gpu arguments + auto initArgs = [=]( + Arguments& cpuArgs, Arguments& gpuArgs, const Arguments& inArgs) { + for (auto arg : inArgs) { + size_t size = sizeof(real); + for (auto dim : arg.dims_) { + size *= dim; + } + cpuMemory.emplace_back(std::make_shared(size)); + gpuMemory.emplace_back(std::make_shared(size)); + cpuArgs.emplace_back( + Tensor((real*)cpuMemory.back()->getBuf(), arg.dims_)); + gpuArgs.emplace_back( + Tensor((real*)gpuMemory.back()->getBuf(), arg.dims_)); + + // will use an api to refactor this code. + CpuVector cpuVector(size / sizeof(real), + (real*)cpuArgs.back().getData()); + GpuVector gpuVector(size / sizeof(real), + (real*)gpuArgs.back().getData()); + cpuVector.uniform(0.001, 1); + gpuVector.copyFrom(cpuVector); + } + }; + initArgs(cpuInputs, gpuInputs, inputs); + initArgs(cpuOutputs, gpuOutputs, outputs); + initArgs(cpuInouts, gpuInouts, inouts); + + // function calculate + cpu->calc(cpuInputs, cpuOutputs, cpuInouts); + gpu->calc(gpuInputs, gpuOutputs, gpuInouts); + + // check outputs and inouts + auto checkArgs = [=](const Arguments& cpuArgs, const Arguments& gpuArgs) { + for (size_t i = 0; i < cpuArgs.size(); i++) { + auto cpu = cpuArgs[i]; + auto gpu = gpuArgs[i]; + size_t size = 1; + for (auto dim : cpu.dims_) { + size *= dim; + } + CpuVector cpuVector(size, (real*)cpu.getData()); + GpuVector gpuVector(size, (real*)gpu.getData()); + + autotest::TensorCheckErr(cpuVector, gpuVector); + } + }; + checkArgs(cpuOutputs, gpuOutputs); + checkArgs(cpuInouts, gpuInouts); + } + +protected: + std::shared_ptr cpu; + std::shared_ptr gpu; + std::vector cpuMemory; + std::vector gpuMemory; + Arguments cpuInputs; + Arguments cpuOutputs; + Arguments cpuInouts; + Arguments gpuInputs; + Arguments gpuOutputs; + Arguments gpuInouts; +}; + +} // namespace paddle + +using paddle::FunctionCompare; +using paddle::FuncConfig; +using paddle::Dims; +using paddle::Tensor; diff --git a/paddle/function/TestMain.cpp b/paddle/function/TestMain.cpp new file mode 100644 index 0000000000000..3e14532d1878f --- /dev/null +++ b/paddle/function/TestMain.cpp @@ -0,0 +1,22 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include +#include "paddle/utils/Util.h" + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + paddle::initMain(argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/paddle/function/cross_map_normal_op_test.cpp b/paddle/function/cross_map_normal_op_test.cpp new file mode 100644 index 0000000000000..22692691bdb64 --- /dev/null +++ b/paddle/function/cross_map_normal_op_test.cpp @@ -0,0 +1,71 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "FunctionTest.h" + +TEST(CrossMapNormal, real) { + for (size_t numSamples : {5, 32}) { + for (size_t channels : {1, 5, 32}) { + for (size_t imgSizeH : {5, 33, 100}) { + for (size_t imgSizeW : {5, 32, 96}) { + for (size_t size : {1, 2, 3, 5, 7}) { + VLOG(3) << " numSamples=" << numSamples << " channels=" << channels + << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW + << " size=" << size; + + FunctionCompare compare("CrossMapNormal", + FuncConfig() + .set("size", size) + .set("scale", (real)1.5) + .set("pow", (real)0.5)); + Dims dims{numSamples, channels, imgSizeH, imgSizeW}; + compare.cmpWithArg({Tensor(nullptr, dims)}, + {Tensor(nullptr, dims), Tensor(nullptr, dims)}, + {}); + } + } + } + } + } +} + +TEST(CrossMapNormalGrad, real) { + for (size_t numSamples : {5, 32}) { + for (size_t channels : {1, 5, 32}) { + for (size_t imgSizeH : {5, 33, 100}) { + for (size_t imgSizeW : {5, 32, 96}) { + for (size_t size : {1, 2, 3, 5, 7}) { + VLOG(3) << " numSamples=" << numSamples << " channels=" << channels + << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW + << " size=" << size; + + FunctionCompare compare("CrossMapNormalGrad", + FuncConfig() + .set("size", size) + .set("scale", (real)1.5) + .set("pow", (real)0.5)); + Dims dims{numSamples, channels, imgSizeH, imgSizeW}; + compare.cmpWithArg({Tensor(nullptr, dims), + Tensor(nullptr, dims), + Tensor(nullptr, dims), + Tensor(nullptr, dims)}, + {Tensor(nullptr, dims)}, + {}); + } + } + } + } + } +} diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index c89b7ff490232..440534e722700 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -1263,150 +1263,6 @@ TEST(Matrix, MaxOutFwdBwd) { } } -void testCrossMapNormalFwd( - int numSamples, int channels, int imgSizeH, int imgSizeW, int sizeX) { - float scale = 1.5; - float pow = 0.5; - int width = imgSizeH * imgSizeW * channels; - CpuMatrix inputs(numSamples, width); - CpuMatrix denoms(numSamples, width); - CpuMatrix outputs(numSamples, width); - GpuMatrix inputsGpu(numSamples, width); - GpuMatrix denomsGpu(numSamples, width); - GpuMatrix outputsGpu(numSamples, width); - - inputs.randomizeUniform(); - outputs.randomizeUniform(); - inputsGpu.copyFrom(inputs); - outputsGpu.copyFrom(outputs); - - FunctionBase* cpu = - FunctionBase::funcRegistrar_.createByType(FUNC_NAME(CrossMapNormal, CPU)); - FunctionBase* gpu = - 
FunctionBase::funcRegistrar_.createByType(FUNC_NAME(CrossMapNormal, GPU)); - cpu->init(FuncConfig() - .set("size", (size_t)sizeX) - .set("scale", scale) - .set("pow", pow)); - gpu->init(FuncConfig() - .set("size", (size_t)sizeX) - .set("scale", scale) - .set("pow", pow)); - - Dims dims{ - (size_t)numSamples, (size_t)channels, (size_t)imgSizeH, (size_t)imgSizeW}; - cpu->calc({Tensor(inputs.getData(), dims)}, - {Tensor(outputs.getData(), dims), Tensor(denoms.getData(), dims)}, - {}); - - gpu->calc( - {Tensor(inputsGpu.getData(), dims)}, - {Tensor(outputsGpu.getData(), dims), Tensor(denomsGpu.getData(), dims)}, - {}); - - TensorCheckErr(outputs, outputsGpu); - TensorCheckErr(denoms, denomsGpu); -} - -TEST(Matrix, crossMapNormalFwd) { - for (auto numSamples : {5, 32}) { - for (auto channels : {1, 5, 32}) { - for (auto imgSizeH : {5, 33, 100}) { - for (auto imgSizeW : {5, 32, 96}) { - for (auto sizeX : {1, 2, 3, 5, 7}) { - VLOG(3) << " numSamples=" << numSamples << " channels=" << channels - << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW - << " sizeX=" << sizeX; - testCrossMapNormalFwd( - numSamples, channels, imgSizeH, imgSizeW, sizeX); - } - } - } - } - } -} - -void testCrossMapNormalBwd( - int numSamples, int channels, int imgSizeH, int imgSizeW, int sizeX) { - float scale = 1.5; - float pow = 0.5; - size_t width = imgSizeH * imgSizeW * channels; - - CpuMatrix inputsGrad(numSamples, width); - CpuMatrix inputsValue(numSamples, width); - CpuMatrix outputsGrad(numSamples, width); - CpuMatrix outputsValue(numSamples, width); - CpuMatrix denoms(numSamples, width); - - outputsGrad.randomizeUniform(); - denoms.randomizeUniform(); - inputsValue.randomizeUniform(); - outputsValue.randomizeUniform(); - inputsGrad.randomizeUniform(); - denoms.add(0.01); - - GpuMatrix inputsGradGpu(numSamples, width); - GpuMatrix inputsValueGpu(numSamples, width); - GpuMatrix outputsGradGpu(numSamples, width); - GpuMatrix outputsValueGpu(numSamples, width); - GpuMatrix denomsGpu(numSamples, width); - - outputsGradGpu.copyFrom(outputsGrad); - denomsGpu.copyFrom(denoms); - inputsValueGpu.copyFrom(inputsValue); - outputsValueGpu.copyFrom(outputsValue); - inputsGradGpu.copyFrom(inputsGrad); - - FunctionBase* cpu = FunctionBase::funcRegistrar_.createByType( - FUNC_NAME(CrossMapNormalGrad, CPU)); - FunctionBase* gpu = FunctionBase::funcRegistrar_.createByType( - FUNC_NAME(CrossMapNormalGrad, GPU)); - cpu->init(FuncConfig() - .set("size", (size_t)sizeX) - .set("scale", scale) - .set("pow", pow)); - gpu->init(FuncConfig() - .set("size", (size_t)sizeX) - .set("scale", scale) - .set("pow", pow)); - - Dims dims{ - (size_t)numSamples, (size_t)channels, (size_t)imgSizeH, (size_t)imgSizeW}; - cpu->calc({Tensor(inputsValue.getData(), dims), - Tensor(outputsValue.getData(), dims), - Tensor(outputsGrad.getData(), dims), - Tensor(denoms.getData(), dims)}, - {Tensor(inputsGrad.getData(), dims)}, - {}); - - gpu->calc({Tensor(inputsValueGpu.getData(), dims), - Tensor(outputsValueGpu.getData(), dims), - Tensor(outputsGradGpu.getData(), dims), - Tensor(denomsGpu.getData(), dims)}, - {Tensor(inputsGradGpu.getData(), dims)}, - {}); - - TensorCheckErr(inputsGrad, inputsGradGpu); -} - -TEST(Matrix, crossMapNormalBwd) { - for (auto numSamples : {5, 32}) { - for (auto channels : {1, 5, 32}) { - for (auto imgSizeH : {5, 33, 100}) { - for (auto imgSizeW : {5, 32, 96}) { - for (auto sizeX : {1, 2, 3, 5, 7}) { - VLOG(3) << " numSamples=" << numSamples << " channels=" << channels - << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW - << " 
sizeX=" << sizeX; - testCrossMapNormalBwd( - numSamples, channels, imgSizeH, imgSizeW, sizeX); - } - } - } - } - } -} - int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); initMain(argc, argv); From f1a94e3ff7fce800f6c846da2ae6ad4312c4acfc Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Tue, 20 Dec 2016 20:30:13 +0800 Subject: [PATCH 40/55] follow comments --- paddle/function/cross_map_normal_op.cpp | 22 +++++++++++----------- paddle/function/cross_map_normal_op.h | 10 +++++----- paddle/function/cross_map_normal_op_gpu.cu | 10 +++++----- paddle/math/tests/test_matrixCompare.cpp | 1 - 4 files changed, 21 insertions(+), 22 deletions(-) diff --git a/paddle/function/cross_map_normal_op.cpp b/paddle/function/cross_map_normal_op.cpp index a18c0bb750acf..a9c7693830542 100644 --- a/paddle/function/cross_map_normal_op.cpp +++ b/paddle/function/cross_map_normal_op.cpp @@ -20,7 +20,7 @@ namespace paddle { template <> void CrossMapNormal(real* outputs, real* denoms, - real* inputs, + const real* inputs, size_t numSamples, size_t channels, size_t height, @@ -32,7 +32,7 @@ void CrossMapNormal(real* outputs, size_t oneSample = channels * oneImage; CpuVector outputsV(numSamples * oneSample, outputs); - CpuVector inputsV(numSamples * oneSample, inputs); + CpuVector inputsV(numSamples * oneSample, const_cast(inputs)); CpuVector denomsV(numSamples * oneSample, denoms); // f(x) = x * ( 1 + scale * SUM((x)^2) )^(-pow) @@ -44,7 +44,7 @@ void CrossMapNormal(real* outputs, const int end = (int)size + start; for (size_t i = 0; i < numSamples; i++) { real* oneDenom = denoms + i * oneSample; - real* oneInput = inputs + i * oneSample; + real* oneInput = const_cast(inputs) + i * oneSample; for (int c = 0; c < (int)channels; c++) { CpuVector denom(oneImage, oneDenom + c * oneImage); for (int s = start; s < end; s++) { @@ -61,10 +61,10 @@ void CrossMapNormal(real* outputs, template <> void CrossMapNormalGrad(real* inputsGrad, - real* inputsValue, - real* outputsValue, - real* outputsGrad, - real* denoms, + const real* inputsValue, + const real* outputsValue, + const real* outputsGrad, + const real* denoms, size_t numSamples, size_t channels, size_t height, @@ -84,10 +84,10 @@ void CrossMapNormalGrad(real* inputsGrad, for (size_t i = 0; i < numSamples; i++) { size_t sOffset = i * oneSample; real* oneInputGrad = inputsGrad + sOffset; - real* oneInputValue = inputsValue + sOffset; - real* oneDenom = denoms + sOffset; - real* oneOutputGrad = outputsGrad + sOffset; - real* oneOutputValue = outputsValue + sOffset; + real* oneInputValue = const_cast(inputsValue) + sOffset; + real* oneDenom = const_cast(denoms) + sOffset; + real* oneOutputGrad = const_cast(outputsGrad) + sOffset; + real* oneOutputValue = const_cast(outputsValue) + sOffset; for (int c = 0; c < (int)channels; c++) { size_t cOffset = c * height * width; diff --git a/paddle/function/cross_map_normal_op.h b/paddle/function/cross_map_normal_op.h index e935b26e125d3..b1e401ad0a2f5 100644 --- a/paddle/function/cross_map_normal_op.h +++ b/paddle/function/cross_map_normal_op.h @@ -37,7 +37,7 @@ namespace paddle { template void CrossMapNormal(real* outputs, real* denoms, - real* inputs, + const real* inputs, size_t numSamples, size_t channels, size_t height, @@ -66,10 +66,10 @@ void CrossMapNormal(real* outputs, */ template void CrossMapNormalGrad(real* inputsGrad, - real* inputsValue, - real* outputsValue, - real* outputsGrad, - real* denoms, + const real* inputsValue, + const real* outputsValue, + const real* outputsGrad, + const real* denoms, size_t 
numSamples, size_t channels, size_t height, diff --git a/paddle/function/cross_map_normal_op_gpu.cu b/paddle/function/cross_map_normal_op_gpu.cu index 6339c04194834..aae4f461b6f57 100644 --- a/paddle/function/cross_map_normal_op_gpu.cu +++ b/paddle/function/cross_map_normal_op_gpu.cu @@ -63,7 +63,7 @@ __global__ void KeCMRNormOutput(size_t inputSize, const real* in, template <> void CrossMapNormal(real* outputs, real* denoms, - real* inputs, + const real* inputs, size_t numSamples, size_t channels, size_t height, @@ -132,10 +132,10 @@ __global__ void KeCMRNormDiff(size_t imageSize, const real* bottom_data, template <> void CrossMapNormalGrad(real* inputsGrad, - real* inputsValue, - real* outputsValue, - real* outputsGrad, - real* denoms, + const real* inputsValue, + const real* outputsValue, + const real* outputsGrad, + const real* denoms, size_t numSamples, size_t channels, size_t height, diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index 440534e722700..62de5b25e4cc8 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -18,7 +18,6 @@ limitations under the License. */ #include #include "TensorCheck.h" -#include "paddle/function/Function.h" #include "paddle/gserver/tests/TestUtil.h" #include "paddle/math/Matrix.h" #include "paddle/math/SparseMatrix.h" From f4f0f2daeb3bd0bffd8302a4388098e0ab1ffed6 Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Tue, 20 Dec 2016 20:30:37 +0800 Subject: [PATCH 41/55] Fix bug in config_parse.py when batch_norm layer is used in RecurrentLayerGroup. --- python/paddle/trainer/config_parser.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index 39892d0533aab..0308d9df94839 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -498,9 +498,12 @@ def __init__( is_static=None, is_shared=None, update_hooks=None, - input_layer_argument=None, ): + input_layer_argument=None, + not_make_layer_name_in_submodel=None, ): self.add_keys(locals()) self.input_layer_name = MakeLayerNameInSubmodel(input_layer_name) + if not_make_layer_name_in_submodel: + self.input_layer_name = input_layer_name # Define a projection for iexed layer @@ -1848,7 +1851,8 @@ def __init__(self, initial_std=0.0, initial_mean=0.0, is_static=True, - is_shared=is_shared, )) + is_shared=is_shared, + not_make_layer_name_in_submodel=True, )) parallel_nn = bool(int(g_command_config_args.get("parallel_nn", 0))) cudnn_version = int(g_command_config_args.get("cudnn_version", 0)) From 35bbb4fb01a2172c867cd7c27a3c805e87f1ea69 Mon Sep 17 00:00:00 2001 From: Peng Li Date: Tue, 20 Dec 2016 21:04:53 +0800 Subject: [PATCH 42/55] change float to real in two test Change float in test_ConvTrans and test_ConvUnify to real. 
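The motivation, roughly: real follows the precision the library is built with, so hard-coded float arrays in tests stop compiling against APIs such as Matrix::setData(real*) in a double-precision build. A minimal sketch of the idea — the PADDLE_TYPE_DOUBLE switch and the setData stand-in below are simplifying assumptions for illustration, not part of this patch:

    #include <cstdio>

    // Assumed precision switch behind `real` (selected by the build configuration).
    #ifdef PADDLE_TYPE_DOUBLE
    typedef double real;
    #else
    typedef float real;
    #endif

    // Hypothetical stand-in for Matrix::setData(real*).
    static void setData(real* data) { std::printf("%f\n", static_cast<double>(data[0])); }

    int main() {
      float f32[] = {1, 2, 2, 2, 1};
      real r[] = {1, 2, 2, 2, 1};
      // setData(f32);  // fails when real is double: float* does not convert to real*
      setData(r);       // matches the matrix element type for either precision
      (void)f32;
      return 0;
    }

That is why the expected-result arrays in test_ConvTrans.cpp and test_ConvUnify.cpp are declared as real below.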
--- paddle/gserver/tests/test_ConvTrans.cpp | 12 +++---- paddle/gserver/tests/test_ConvUnify.cpp | 46 ++++++++++++------------- 2 files changed, 29 insertions(+), 29 deletions(-) diff --git a/paddle/gserver/tests/test_ConvTrans.cpp b/paddle/gserver/tests/test_ConvTrans.cpp index 99202c2d5702a..dd3378304b433 100644 --- a/paddle/gserver/tests/test_ConvTrans.cpp +++ b/paddle/gserver/tests/test_ConvTrans.cpp @@ -206,8 +206,8 @@ TEST(Layer, convTransLayerFwd2) { /* filter_size */ 5, result); - float resultData[] = {1, 2, 2, 2, 1, 2, 4, 4, 4, 2, 2, 4, 4, - 4, 2, 2, 4, 4, 4, 2, 1, 2, 2, 2, 1}; + real resultData[] = {1, 2, 2, 2, 1, 2, 4, 4, 4, 2, 2, 4, 4, + 4, 2, 2, 4, 4, 4, 2, 1, 2, 2, 2, 1}; result->setData(resultData); doOneConvtTest(/* imgSize */ 5, /* output_x */ 2, @@ -216,8 +216,8 @@ TEST(Layer, convTransLayerFwd2) { /* filter_size */ 4, result); - float resultData2[] = {1, 2, 2, 2, 1, 2, 4, 4, 4, 2, 2, 4, 4, - 4, 2, 2, 4, 4, 4, 2, 1, 2, 2, 2, 1}; + real resultData2[] = {1, 2, 2, 2, 1, 2, 4, 4, 4, 2, 2, 4, 4, + 4, 2, 2, 4, 4, 4, 2, 1, 2, 2, 2, 1}; result->setData(resultData2); doOneConvtTest(/* imgSize */ 5, /* output_x */ 2, @@ -226,8 +226,8 @@ TEST(Layer, convTransLayerFwd2) { /* filter_size */ 5, result); - float resultData3[] = {1, 1, 2, 1, 1, 1, 1, 2, 1, 1, 2, 2, 4, - 2, 2, 1, 1, 2, 1, 1, 1, 1, 2, 1, 1}; + real resultData3[] = {1, 1, 2, 1, 1, 1, 1, 2, 1, 1, 2, 2, 4, + 2, 2, 1, 1, 2, 1, 1, 1, 1, 2, 1, 1}; result->setData(resultData3); doOneConvtTest(/* imgSize */ 5, /* output_x */ 2, diff --git a/paddle/gserver/tests/test_ConvUnify.cpp b/paddle/gserver/tests/test_ConvUnify.cpp index 2ab18f886848d..072a886a198f3 100644 --- a/paddle/gserver/tests/test_ConvUnify.cpp +++ b/paddle/gserver/tests/test_ConvUnify.cpp @@ -106,8 +106,8 @@ TEST(Layer, convParaUnified) { #ifndef PADDLE_ONLY_CPU MatrixPtr input, resultCpu, resultGpu; input = Matrix::create(1, 4 * 4, false, false); - float inputData[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; - float param[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 8, 7, 6, 5, 4, 3, 2, 1}; + real inputData[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + real param[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 8, 7, 6, 5, 4, 3, 2, 1}; input->setData(inputData); @@ -137,26 +137,26 @@ TEST(Layer, convParaUnified) { checkMatrixEqual(resultCpu, resultGpu); input = Matrix::create(1, 3 * 3 * 2, false, false); - float inputData2[] = {1, - 2, - 3, - 4, - 5, - 6, - 7, - 8, - 9, - - 10, - 11, - 12, - 13, - 14, - 15, - 16, - 17, - 18}; - float param2[] = {1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1}; + real inputData2[] = {1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18}; + real param2[] = {1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1}; input->setData(inputData2); @@ -185,7 +185,7 @@ TEST(Layer, convParaUnified) { true); checkMatrixEqual(resultCpu, resultGpu); - float param3[] = {1, 2, 3, 4, 4, 3, 2, 1}; + real param3[] = {1, 2, 3, 4, 4, 3, 2, 1}; resultCpu = doOneConvTest(/* imgSize */ 3, /* output_x */ 2, From 84ad724f99164eba9c45dfc10e280f8f8104689a Mon Sep 17 00:00:00 2001 From: xuwei06 Date: Tue, 20 Dec 2016 16:31:11 -0800 Subject: [PATCH 43/55] Adding namespace in timing macros Sometime those macros are used under different namespaces. We need to use namespace ::paddle to make it compile correctly. 
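A minimal sketch of the failure mode, with simplified stand-ins for StatPtr and the timer macros (assumed for illustration only): a macro body is pasted into whatever namespace it is invoked from, so unqualified names resolve only when the call site happens to sit inside namespace paddle, while the ::paddle::-qualified form expands correctly anywhere:

    #include <memory>

    namespace paddle {
    struct Stat {};
    using StatPtr = std::shared_ptr<Stat>;
    }  // namespace paddle

    // Unqualified: only compiles when expanded inside namespace paddle.
    #define TIMER_UNQUALIFIED(name) StatPtr __stat = std::make_shared<paddle::Stat>();
    // Qualified: compiles when expanded from any namespace.
    #define TIMER_QUALIFIED(name) \
      ::paddle::StatPtr __stat = std::make_shared< ::paddle::Stat >();

    namespace other {
    inline void forward() {
      // TIMER_UNQUALIFIED("fwd");  // error: 'StatPtr' was not declared in this scope
      TIMER_QUALIFIED("fwd");       // fine
      (void)__stat;
    }
    }  // namespace other

    int main() { other::forward(); return 0; }

The change below applies the same qualification to REGISTER_TIMER, REGISTER_TIMER_SET, REGISTER_TIMER_DYNAMIC, REGISTER_TIMER_DYNAMIC_SET and REGISTER_TIMER_INFO.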
Change-Id: I57a6d6ec8cd0d680b584aab62d72a35c226a24a4 --- paddle/utils/Stat.cpp | 3 +++ paddle/utils/Stat.h | 49 +++++++++++++++++++++++++++---------------- 2 files changed, 34 insertions(+), 18 deletions(-) diff --git a/paddle/utils/Stat.cpp b/paddle/utils/Stat.cpp index 44acee249554e..c7194d3bf1271 100644 --- a/paddle/utils/Stat.cpp +++ b/paddle/utils/Stat.cpp @@ -137,6 +137,9 @@ void StatSet::printSegTimerStatus() { void StatSet::printBarrierTimerStatus() { ReadLockGuard guard(lock_); + if (barrierStatSet_.empty()) { + return; + } // control barrierAbstact in runtime, so enable compliation LOG(INFO) << std::setiosflags(std::ios::left) << std::setfill(' ') << "======= BarrierStatSet status ======" << std::endl; diff --git a/paddle/utils/Stat.h b/paddle/utils/Stat.h index 9be79e8859a3b..d9cc6e413a741 100644 --- a/paddle/utils/Stat.h +++ b/paddle/utils/Stat.h @@ -258,28 +258,41 @@ inline StatSet& registerTimerArg2(uint64_t threshold = -1, // The default arguments are shown in the following line: // REGISTER_TIMER(statName, threshold = -1, statSet = globalStat) // TODO(yuyang18,wangyanfei01): if UNIQUE_NAME is needed -#define REGISTER_TIMER(statName, ...) \ - static StatPtr __stat = registerTimerArg2(__VA_ARGS__).getStat(statName); \ - TimerOnce __timerOnce(__stat.get(), "", registerTimerArg1(__VA_ARGS__)); +#define REGISTER_TIMER(statName, ...) \ + static ::paddle::StatPtr __stat = \ + ::paddle::registerTimerArg2(__VA_ARGS__).getStat(statName); \ + ::paddle::TimerOnce __timerOnce( \ + __stat.get(), "", ::paddle::registerTimerArg1(__VA_ARGS__)); #define REGISTER_TIMER_SET(statName, start, ...) \ - static StatPtr __stat = registerTimerArg2(__VA_ARGS__).getStat(statName); \ - TimerOnce __timerOnce( \ - __stat.get(), "", registerTimerArg1(__VA_ARGS__), false, start); + static ::paddle::StatPtr __stat = \ + ::paddle::registerTimerArg2(__VA_ARGS__).getStat(statName); \ + ::paddle::TimerOnce __timerOnce(__stat.get(), \ + "", \ + ::paddle::registerTimerArg1(__VA_ARGS__), \ + false, \ + start); // dynmaic timer, support to discriminate runtime entity, used in pserver -#define REGISTER_TIMER_DYNAMIC(statName, ...) \ - StatPtr __stat = registerTimerArg2(__VA_ARGS__).getStat(statName); \ - TimerOnce __timerOnce(__stat.get(), "", registerTimerArg1(__VA_ARGS__)); - -#define REGISTER_TIMER_DYNAMIC_SET(statName, start, ...) \ - StatPtr __stat = registerTimerArg2(__VA_ARGS__).getStat(statName); \ - TimerOnce __timerOnce( \ - __stat.get(), "", registerTimerArg1(__VA_ARGS__), false, start); - -#define REGISTER_TIMER_INFO(statName, info) \ - static StatPtr __stat = globalStat.getStat(statName); \ - TimerOnce __timerOnce(__stat.get(), info, 10 * 1000000LU /*threshold*/); +#define REGISTER_TIMER_DYNAMIC(statName, ...) \ + ::paddle::StatPtr __stat = \ + ::paddle::registerTimerArg2(__VA_ARGS__).getStat(statName); \ + ::paddle::TimerOnce __timerOnce( \ + __stat.get(), "", ::paddle::registerTimerArg1(__VA_ARGS__)); + +#define REGISTER_TIMER_DYNAMIC_SET(statName, start, ...) 
\ + ::paddle::StatPtr __stat = \ + ::paddle::registerTimerArg2(__VA_ARGS__).getStat(statName); \ + ::paddle::TimerOnce __timerOnce(__stat.get(), \ + "", \ + ::paddle::registerTimerArg1(__VA_ARGS__), \ + false, \ + start); + +#define REGISTER_TIMER_INFO(statName, info) \ + static ::paddle::StatPtr __stat = ::paddle::globalStat.getStat(statName); \ + ::paddle::TimerOnce __timerOnce( \ + __stat.get(), info, 10 * 1000000LU /*threshold*/); #endif // DISABLE_TIMER From 5bb29ece7fd5352b93100a20b4bf904c5b5bc2f0 Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Wed, 21 Dec 2016 09:55:09 +0800 Subject: [PATCH 44/55] close log info in BN. --- python/paddle/trainer/config_parser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index 0308d9df94839..8389476e6a5a5 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -1884,7 +1884,7 @@ def __init__(self, # when either of it is non-zero. if input_layer.width != 0 or input_layer.height != 0: self.set_cnn_layer(name, image_conf.img_size_y, image_conf.img_size, - image_conf.channels, True) + image_conf.channels, False) else: self.set_layer_size(input_layer.size) From 67fcd898c5f17bf7a61237351e9087257a8a34f2 Mon Sep 17 00:00:00 2001 From: Peng Li Date: Wed, 21 Dec 2016 10:14:10 +0800 Subject: [PATCH 45/55] fix array style problem --- paddle/gserver/tests/test_ConvUnify.cpp | 21 ++------------------- 1 file changed, 2 insertions(+), 19 deletions(-) diff --git a/paddle/gserver/tests/test_ConvUnify.cpp b/paddle/gserver/tests/test_ConvUnify.cpp index 072a886a198f3..ad99b50245cf5 100644 --- a/paddle/gserver/tests/test_ConvUnify.cpp +++ b/paddle/gserver/tests/test_ConvUnify.cpp @@ -137,25 +137,8 @@ TEST(Layer, convParaUnified) { checkMatrixEqual(resultCpu, resultGpu); input = Matrix::create(1, 3 * 3 * 2, false, false); - real inputData2[] = {1, - 2, - 3, - 4, - 5, - 6, - 7, - 8, - 9, - - 10, - 11, - 12, - 13, - 14, - 15, - 16, - 17, - 18}; + real inputData2[] = { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}; real param2[] = {1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1}; input->setData(inputData2); From de8927ebe10a2ce8f7eb05b45e07794998040270 Mon Sep 17 00:00:00 2001 From: Luo Tao Date: Wed, 21 Dec 2016 10:30:58 +0800 Subject: [PATCH 46/55] refine docs.sh --- paddle/scripts/travis/docs.sh | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/paddle/scripts/travis/docs.sh b/paddle/scripts/travis/docs.sh index cd331522a910a..1b05dce29a12f 100755 --- a/paddle/scripts/travis/docs.sh +++ b/paddle/scripts/travis/docs.sh @@ -9,12 +9,9 @@ make paddle_docs paddle_docs_cn # check websites for broken links set +e -linkchecker doc/cn/html/index.html > doc_cn.out -linkchecker doc/en/html/index.html > doc_en.out -for i in doc_cn.out doc_en.out; do - grep " 0 errors found" $i +for i in cn en; do + linkchecker doc/$i/html/index.html if [ $? -ne 0 ]; then - cat $i exit 1 fi done From f2029298a7d44d396e4e87bef07c55d10a06e498 Mon Sep 17 00:00:00 2001 From: gaoyuan Date: Wed, 21 Dec 2016 10:35:43 +0800 Subject: [PATCH 47/55] Change type float to real. 
--- paddle/gserver/layers/PriorBox.cpp | 20 ++++----- paddle/gserver/tests/test_PriorBox.cpp | 56 +++++++++++++------------- 2 files changed, 38 insertions(+), 38 deletions(-) diff --git a/paddle/gserver/layers/PriorBox.cpp b/paddle/gserver/layers/PriorBox.cpp index ca61dfec5faa0..abaeaf3c1c3d0 100644 --- a/paddle/gserver/layers/PriorBox.cpp +++ b/paddle/gserver/layers/PriorBox.cpp @@ -36,8 +36,8 @@ class PriorBoxLayer : public Layer { int numPriors_; std::vector minSize_; std::vector maxSize_; - std::vector aspectRatio_; - std::vector variance_; + std::vector aspectRatio_; + std::vector variance_; MatrixPtr buffer_; }; @@ -77,8 +77,8 @@ void PriorBoxLayer::forward(PassType passType) { int imageWidth = image.getFrameWidth(); int imageHeight = image.getFrameHeight(); - float stepW = static_cast(imageWidth) / layerWidth; - float stepH = static_cast(imageHeight) / layerHeight; + real stepW = static_cast(imageWidth) / layerWidth; + real stepH = static_cast(imageHeight) / layerHeight; int dim = layerHeight * layerWidth * numPriors_ * 4; reserveOutput(1, dim * 2); // use a cpu buffer to compute @@ -88,8 +88,8 @@ void PriorBoxLayer::forward(PassType passType) { int idx = 0; for (int h = 0; h < layerHeight; ++h) { for (int w = 0; w < layerWidth; ++w) { - float centerX = (w + 0.5) * stepW; - float centerY = (h + 0.5) * stepH; + real centerX = (w + 0.5) * stepW; + real centerY = (h + 0.5) * stepH; int minSize = 0; for (size_t s = 0; s < minSize_.size(); s++) { // first prior. @@ -121,10 +121,10 @@ void PriorBoxLayer::forward(PassType passType) { } // rest of priors. for (size_t r = 0; r < aspectRatio_.size(); r++) { - float ar = aspectRatio_[r]; + real ar = aspectRatio_[r]; if (fabs(ar - 1.) < 1e-6) continue; - float boxWidth = minSize * sqrt(ar); - float boxHeight = minSize / sqrt(ar); + real boxWidth = minSize * sqrt(ar); + real boxHeight = minSize / sqrt(ar); tmpPtr[idx++] = (centerX - boxWidth / 2.) / imageWidth; tmpPtr[idx++] = (centerY - boxHeight / 2.) / imageHeight; tmpPtr[idx++] = (centerX + boxWidth / 2.) / imageWidth; @@ -137,7 +137,7 @@ void PriorBoxLayer::forward(PassType passType) { // clip the prior's coordidate such that it is within [0, 1] for (int d = 0; d < dim * 2; ++d) if ((d % 8) < 4) - tmpPtr[d] = std::min(std::max(tmpPtr[d], (float)0.), (float)1.); + tmpPtr[d] = std::min(std::max(tmpPtr[d], (real)0.), (real)1.); MatrixPtr outV = getOutputValue(); outV->copyFrom(buffer_->data_, dim * 2); } diff --git a/paddle/gserver/tests/test_PriorBox.cpp b/paddle/gserver/tests/test_PriorBox.cpp index 19dfd0f065da2..a6d6a24269663 100644 --- a/paddle/gserver/tests/test_PriorBox.cpp +++ b/paddle/gserver/tests/test_PriorBox.cpp @@ -30,8 +30,8 @@ void doOnePriorBoxTest(size_t feature_map_width, size_t image_height, vector min_size, vector max_size, - vector aspect_ratio, - vector variance, + vector aspect_ratio, + vector variance, bool use_gpu, MatrixPtr& result) { // Setting up the priorbox layer @@ -71,8 +71,8 @@ void doOnePriorBoxTest(size_t feature_map_width, TEST(Layer, priorBoxLayerFwd) { vector minSize; vector maxSize; - vector aspectRatio; - vector variance; + vector aspectRatio; + vector variance; bool useGpu = false; minSize.push_back(276); @@ -84,22 +84,22 @@ TEST(Layer, priorBoxLayerFwd) { // CPU case 1. 
MatrixPtr result; - float resultData[] = {0.04, - 0.04, - 0.96, - 0.96, - 0.1, - 0.1, - 0.2, - 0.2, - 0, - 0, - 1, - 1, - 0.1, - 0.1, - 0.2, - 0.2}; + real resultData[] = {0.04, + 0.04, + 0.96, + 0.96, + 0.1, + 0.1, + 0.2, + 0.2, + 0, + 0, + 1, + 1, + 0.1, + 0.1, + 0.2, + 0.2}; result = Matrix::create(1, 2 * 8, false, useGpu); result->setData(resultData); doOnePriorBoxTest(/* feature_map_width */ 1, @@ -116,10 +116,10 @@ TEST(Layer, priorBoxLayerFwd) { variance[1] = 0.2; variance[3] = 0.1; maxSize.pop_back(); - float resultData2[] = {0, 0, 0.595, 0.595, 0.1, 0.2, 0.2, 0.1, - 0.405, 0, 1, 0.595, 0.1, 0.2, 0.2, 0.1, - 0, 0.405, 0.595, 1, 0.1, 0.2, 0.2, 0.1, - 0.405, 0.405, 1, 1, 0.1, 0.2, 0.2, 0.1}; + real resultData2[] = {0, 0, 0.595, 0.595, 0.1, 0.2, 0.2, 0.1, + 0.405, 0, 1, 0.595, 0.1, 0.2, 0.2, 0.1, + 0, 0.405, 0.595, 1, 0.1, 0.2, 0.2, 0.1, + 0.405, 0.405, 1, 1, 0.1, 0.2, 0.2, 0.1}; Matrix::resizeOrCreate(result, 1, 4 * 8, false, useGpu); result->setData(resultData2); doOnePriorBoxTest(/* feature_map_width */ 2, @@ -134,10 +134,10 @@ TEST(Layer, priorBoxLayerFwd) { result); // CPU case 3. aspectRatio.push_back(2); - float resultData3[] = {0.04, 0.04, 0.96, 0.96, 0.1, 0.2, - 0.2, 0.1, 0, 0.17473088, 1, 0.825269, - 0.1, 0.2, 0.2, 0.1, 0.17473088, 0, - 0.825269, 1, 0.1, 0.2, 0.2, 0.1}; + real resultData3[] = {0.04, 0.04, 0.96, 0.96, 0.1, 0.2, + 0.2, 0.1, 0, 0.17473088, 1, 0.825269, + 0.1, 0.2, 0.2, 0.1, 0.17473088, 0, + 0.825269, 1, 0.1, 0.2, 0.2, 0.1}; Matrix::resizeOrCreate(result, 1, 3 * 8, false, useGpu); result->setData(resultData3); doOnePriorBoxTest(/* feature_map_width */ 1, From 1b8e151fa2b1d79b3e145600f136e6d3d556fe70 Mon Sep 17 00:00:00 2001 From: Peng Li Date: Wed, 21 Dec 2016 10:49:04 +0800 Subject: [PATCH 48/55] Support user specified label input in tests --- paddle/gserver/tests/LayerGradUtil.cpp | 36 ++++++++++++++++++++++---- paddle/gserver/tests/LayerGradUtil.h | 19 ++++++++++++++ 2 files changed, 50 insertions(+), 5 deletions(-) diff --git a/paddle/gserver/tests/LayerGradUtil.cpp b/paddle/gserver/tests/LayerGradUtil.cpp index 1d5e7de1ba624..57c176810fddf 100644 --- a/paddle/gserver/tests/LayerGradUtil.cpp +++ b/paddle/gserver/tests/LayerGradUtil.cpp @@ -303,13 +303,31 @@ void initDataLayer(TestConfig testConf, ICpuGpuVectorPtr sequenceStartPositions; ICpuGpuVectorPtr subSequenceStartPositions; IVectorPtr cpuSequenceDims; - for (size_t i = 0; i < testConf.inputDefs.size(); i++) { + for (size_t i = 0; i < testConf.inputDefs.size(); ++i) { + if (testConf.inputDefs[i].inputType != INPUT_SEQUENCE_LABEL) continue; + + const std::vector& labelSeqStartPositions = + testConf.inputDefs[i].labelSeqStartPositions; + if (labelSeqStartPositions.size() != 0) { + CHECK(!sequenceStartPositions); + CHECK_GE(labelSeqStartPositions.size(), 2); + + sequenceStartPositions = + ICpuGpuVector::create(labelSeqStartPositions.size(), useGpu); + sequenceStartPositions->copyFrom( + labelSeqStartPositions.data(), labelSeqStartPositions.size(), useGpu); + } + } + + for (size_t i = 0; i < testConf.inputDefs.size(); ++i) { LayerConfig config; config.set_name(testConf.inputDefs[i].name); config.set_type("data"); config.set_size(testConf.inputDefs[i].dim); LayerPtr layer = LayerPtr(new DataLayer(config)); - size_t numSequence = batchSize / 10 + 1; + size_t numSequence = sequenceStartPositions + ? 
sequenceStartPositions->getSize() - 1 + : batchSize / 10 + 1; Argument data; auto fillData = [&](bool trans, int height, int width) { @@ -336,9 +354,17 @@ void initDataLayer(TestConfig testConf, break; case INPUT_LABEL: case INPUT_SEQUENCE_LABEL: - data.ids = VectorT::create(batchSize, useGpu); - // now rand number can be 0 to inputDefs[i].dim - data.ids->rand(testConf.inputDefs[i].dim); + if (testConf.inputDefs[i].labelInitValue.size() != 0) { + const std::vector& labelInitValue = + testConf.inputDefs[i].labelInitValue; + CHECK_EQ(labelInitValue.size(), batchSize); + data.ids = VectorT::create(batchSize, useGpu); + data.ids->copyFrom(labelInitValue.data(), batchSize); + } else { + data.ids = VectorT::create(batchSize, useGpu); + // now rand number can be 0 to inputDefs[i].dim + data.ids->rand(testConf.inputDefs[i].dim); + } break; case INPUT_SPARSE_NON_VALUE_DATA: data.value = makeRandomSparseMatrix( diff --git a/paddle/gserver/tests/LayerGradUtil.h b/paddle/gserver/tests/LayerGradUtil.h index 62ac2d160fd91..46cfcd29e0f4d 100644 --- a/paddle/gserver/tests/LayerGradUtil.h +++ b/paddle/gserver/tests/LayerGradUtil.h @@ -64,6 +64,8 @@ struct InputDef { size_t paraSize; ParaSparse sparse; bool isStatic; + std::vector labelInitValue; + std::vector labelSeqStartPositions; InputDef(InputType type, string nameIn, size_t dimIn, size_t sizeIn) { inputType = type; name = nameIn; @@ -72,6 +74,23 @@ struct InputDef { sparse = {""}; isStatic = false; } + + InputDef(InputType type, + string nameIn, + size_t dimIn, + size_t sizeIn, + std::vector labelInitValue, + std::vector labelSeqStartPositions) + : labelInitValue(labelInitValue), + labelSeqStartPositions(labelSeqStartPositions) { + inputType = type; + name = nameIn; + dim = dimIn; + paraSize = sizeIn; + sparse = {""}; + isStatic = false; + } + InputDef(InputType type, string nameIn, size_t dimIn, From 39a547741cb953bc92095ff74b3962336acab3f8 Mon Sep 17 00:00:00 2001 From: Luo Tao Date: Wed, 21 Dec 2016 10:59:37 +0800 Subject: [PATCH 49/55] refine docs.sh --- paddle/scripts/travis/docs.sh | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/paddle/scripts/travis/docs.sh b/paddle/scripts/travis/docs.sh index 1b05dce29a12f..8690fe1d40c93 100755 --- a/paddle/scripts/travis/docs.sh +++ b/paddle/scripts/travis/docs.sh @@ -8,14 +8,8 @@ cmake .. -DCMAKE_BUILD_TYPE=Debug -DWITH_GPU=OFF -DWITH_DOC=ON make paddle_docs paddle_docs_cn # check websites for broken links -set +e -for i in cn en; do - linkchecker doc/$i/html/index.html - if [ $? -ne 0 ]; then - exit 1 - fi -done -set -e +linkchecker doc/en/html/index.html +linkchecker doc/cn/html/index.html # Parse Github URL REPO=`git config remote.origin.url` From e4c492d3b8d6dc7b700aca16db7c410cf1961f23 Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Wed, 21 Dec 2016 11:21:45 +0800 Subject: [PATCH 50/55] change type to bool. --- python/paddle/trainer/config_parser.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index 8389476e6a5a5..29704391f2be6 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -499,11 +499,15 @@ def __init__( is_shared=None, update_hooks=None, input_layer_argument=None, - not_make_layer_name_in_submodel=None, ): + make_layer_name_in_submodel=True, ): + """ + @param make_layer_name_in_submodel True by defalut, you might need to + set it carefully when adding Input in config_parser.py. 
+ """ self.add_keys(locals()) - self.input_layer_name = MakeLayerNameInSubmodel(input_layer_name) - if not_make_layer_name_in_submodel: - self.input_layer_name = input_layer_name + self.input_layer_name = MakeLayerNameInSubmodel( + input_layer_name + ) if make_layer_name_in_submodel else input_layer_name # Define a projection for iexed layer @@ -1852,7 +1856,7 @@ def __init__(self, initial_mean=0.0, is_static=True, is_shared=is_shared, - not_make_layer_name_in_submodel=True, )) + make_layer_name_in_submodel=False, )) parallel_nn = bool(int(g_command_config_args.get("parallel_nn", 0))) cudnn_version = int(g_command_config_args.get("cudnn_version", 0)) From d09564b73f77e248748f5f07738b07e708275194 Mon Sep 17 00:00:00 2001 From: Peng Li Date: Wed, 21 Dec 2016 15:10:42 +0800 Subject: [PATCH 51/55] change std::vector to const reference --- paddle/gserver/tests/LayerGradUtil.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/paddle/gserver/tests/LayerGradUtil.h b/paddle/gserver/tests/LayerGradUtil.h index 46cfcd29e0f4d..4e88ac0e81ef2 100644 --- a/paddle/gserver/tests/LayerGradUtil.h +++ b/paddle/gserver/tests/LayerGradUtil.h @@ -66,6 +66,7 @@ struct InputDef { bool isStatic; std::vector labelInitValue; std::vector labelSeqStartPositions; + InputDef(InputType type, string nameIn, size_t dimIn, size_t sizeIn) { inputType = type; name = nameIn; @@ -79,8 +80,8 @@ struct InputDef { string nameIn, size_t dimIn, size_t sizeIn, - std::vector labelInitValue, - std::vector labelSeqStartPositions) + const std::vector& labelInitValue, + const std::vector& labelSeqStartPositions) : labelInitValue(labelInitValue), labelSeqStartPositions(labelSeqStartPositions) { inputType = type; From 8d24931588ff2152d90bb4eff2c14bcbfc7733c6 Mon Sep 17 00:00:00 2001 From: gaoyuan Date: Wed, 21 Dec 2016 20:44:06 +0800 Subject: [PATCH 52/55] Change member variables from public to protected --- paddle/gserver/layers/PriorBox.cpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/paddle/gserver/layers/PriorBox.cpp b/paddle/gserver/layers/PriorBox.cpp index abaeaf3c1c3d0..36ace7597cd66 100644 --- a/paddle/gserver/layers/PriorBox.cpp +++ b/paddle/gserver/layers/PriorBox.cpp @@ -18,10 +18,10 @@ limitations under the License. */ namespace paddle { /** - * @brief A layer for generate prior box locations and variances. + * @brief A layer for generating priorbox locations and variances. * - Input: Two and only two input layer are accepted. The input layer must be * be a data output layer and a convolution output layer. - * - Output: The prior box locations and variances of the input data. + * - Output: The priorbox locations and variances of the input data. * Reference: * Wei Liu, Dragomir Anguelov, Dumitru Erhan, Christian Szegedy, Scott Reed, * Cheng-Yang Fu, Alexander C. Berg. 
SSD: Single Shot MultiBox Detector @@ -31,8 +31,11 @@ class PriorBoxLayer : public Layer { public: explicit PriorBoxLayer(const LayerConfig& config) : Layer(config) {} bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); + void forward(PassType passType); void backward(const UpdateCallback& callback) {} + +protected: int numPriors_; std::vector minSize_; std::vector maxSize_; From e031f0c4e80981f4242901041a125cecfd3c321c Mon Sep 17 00:00:00 2001 From: Peng Li Date: Thu, 22 Dec 2016 10:57:05 +0800 Subject: [PATCH 53/55] Fix typo in PyDataProvider2.py --- python/paddle/trainer/PyDataProvider2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/trainer/PyDataProvider2.py b/python/paddle/trainer/PyDataProvider2.py index de266bb5d3d07..a20dc6ee42443 100644 --- a/python/paddle/trainer/PyDataProvider2.py +++ b/python/paddle/trainer/PyDataProvider2.py @@ -278,7 +278,7 @@ def process(settings, file_name): custom calculate one sample's batch_size. It is very danger to set it to false and use - calc_batch_size together. Default is false. + calc_batch_size together. Default is true. :type can_over_batch_size: bool :param calc_batch_size: a method to calculate each sample's batch size. From 9baf7fc479fbf1bf3295e18571098ac31cac5e98 Mon Sep 17 00:00:00 2001 From: Peng Li Date: Thu, 22 Dec 2016 11:02:39 +0800 Subject: [PATCH 54/55] Fix data provider bug in srl demo Once encoutering a single sample whose size is larger than batch size, the provider will yield empty batch and terminate the current pass unexpectedly if can_over_batch_size=False. --- demo/semantic_role_labeling/dataprovider.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/demo/semantic_role_labeling/dataprovider.py b/demo/semantic_role_labeling/dataprovider.py index 042cd4e7a9e25..360c57ea6283c 100644 --- a/demo/semantic_role_labeling/dataprovider.py +++ b/demo/semantic_role_labeling/dataprovider.py @@ -43,7 +43,7 @@ def get_batch_size(yeild_data): init_hook=hook, should_shuffle=True, calc_batch_size=get_batch_size, - can_over_batch_size=False, + can_over_batch_size=True, cache=CacheType.CACHE_PASS_IN_MEM) def process(settings, file_name): with open(file_name, 'r') as fdata: From 89bf2e44f06f9a597b6ad55f65bcb1970b6c2044 Mon Sep 17 00:00:00 2001 From: Peng Li Date: Thu, 22 Dec 2016 11:31:28 +0800 Subject: [PATCH 55/55] Change float to real in NormLayer.h --- paddle/gserver/layers/NormLayer.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/gserver/layers/NormLayer.h b/paddle/gserver/layers/NormLayer.h index 86255b231b1ee..011bab8fdedab 100644 --- a/paddle/gserver/layers/NormLayer.h +++ b/paddle/gserver/layers/NormLayer.h @@ -50,7 +50,7 @@ class NormLayer : public Layer { class ResponseNormLayer : public NormLayer { protected: size_t channels_, size_, outputX_, imgSize_, outputY_, imgSizeY_; - float scale_, pow_; + real scale_, pow_; MatrixPtr denoms_; public: