
Cmrnorm #854

Merged: 17 commits, Dec 20, 2016
2 changes: 2 additions & 0 deletions cmake/util.cmake
@@ -96,6 +96,7 @@ function(link_paddle_exe TARGET_NAME)
target_circle_link_libraries(${TARGET_NAME}
ARCHIVE_START
paddle_gserver
paddle_function
${METRIC_LIBS}
ARCHIVE_END
paddle_pserver
@@ -106,6 +107,7 @@ function(link_paddle_exe TARGET_NAME)
paddle_parameter
paddle_proto
paddle_cuda
paddle_test_main
${METRIC_LIBS}
${PROTOBUF_LIBRARY}
${LIBGLOG_LIBRARY}
1 change: 1 addition & 0 deletions paddle/CMakeLists.txt
@@ -1,4 +1,5 @@
add_subdirectory(cuda)
add_subdirectory(function)
add_subdirectory(utils)
add_subdirectory(math)
add_subdirectory(parameter)
1 change: 1 addition & 0 deletions paddle/api/CMakeLists.txt
@@ -46,6 +46,7 @@ add_custom_command(OUTPUT ${PROJ_ROOT}/paddle/dist/.timestamp
WORKING_DIRECTORY ${PROJ_ROOT}/paddle
DEPENDS python_swig_sources
paddle_parameter
paddle_function
paddle_math
paddle_utils
paddle_gserver
5 changes: 3 additions & 2 deletions paddle/api/paddle_ld_flags.py
@@ -30,8 +30,8 @@
whole_end = ""

LIB_DIRS = [
"math", 'utils', 'parameter', "gserver", "api", "cuda", "pserver",
"trainer"
"math", 'function', 'utils', 'parameter', "gserver", "api", "cuda",
"pserver", "trainer"
]
PARENT_LIB_DIRS = ['proto']

@@ -75,6 +75,7 @@ def libs_str(self):
libs = [
whole_start,
"-lpaddle_gserver",
"-lpaddle_function",
whole_end,
"-lpaddle_pserver",
"-lpaddle_trainer_lib",
56 changes: 0 additions & 56 deletions paddle/cuda/include/hl_cnn.h
@@ -240,62 +240,6 @@ extern void hl_avgpool_backward(const int frameCnt,
real* backGrad,
const int outStride);

/**
* @brief Cross-map-response normalization forward.
*
* @param[in] frameCnt batch size of input image.
* @param[in] in input data.
* @param[in] scale buffer.
* @param[out] out output data.
* @param[in] channels number of channels.
* @param[in] height image height.
* @param[in] width image width.
* @param[in] sizeX normalization window size across channels.
* @param[in] alpha scaling parameter.
* @param[in] beta exponent of the normalization denominator.
*
*/
extern void hl_CMRNorm_forward(size_t frameCnt,
const real* in,
real* scale,
real* out,
size_t channels,
size_t height,
size_t width,
size_t sizeX,
real alpha,
real beta);
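
For reference, a sketch of what these routines compute, read off the KeCMRNormFillScale and KeCMRNormOutput kernels removed from hl_cuda_cnn.cu below (with k = sizeX, C = channels, and the window clipped to valid channels; the asymmetric bounds follow pre_pad = (k-1)/2):

```latex
% Per sample and spatial position (h, w), across the channel axis:
s_c = 1 + \alpha \sum_{c' = \max(0,\; c - \lfloor (k-1)/2 \rfloor)}^{\min(C-1,\; c + \lceil (k-1)/2 \rceil)} x_{c'}^{2},
\qquad
y_c = x_c \, s_c^{-\beta}
```

(The output kernel receives its exponent as negative_beta, so the sign of the exponent is supplied by the caller.)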

/**
* @brief Cross-map-response normalization backward.
*
* @param[in] frameCnt batch size of input image.
* @param[in] inV input data.
* @param[in] scale buffer.
* @param[out] outV output value.
* @param[out] outDiff output grad.
* @param[out] inDiff input grad.
* @param[in] channels number of channels.
* @param[in] height image height.
* @param[in] width image width.
* @param[in] sizeX normalization window size across channels.
* @param[in] alpha scaling parameter.
* @param[in] beta exponent of the normalization denominator.
*
*/
extern void hl_CMRNorm_backward(size_t frameCnt,
const real* inV,
const real* scale,
const real* outV,
const real* outDiff,
real* inDiff,
size_t channels,
size_t height,
size_t width,
size_t sizeX,
real alpha,
real beta);

/**
* @brief Bilinear interpolation forward.
*
24 changes: 0 additions & 24 deletions paddle/cuda/include/stub/hl_cnn_stub.h
@@ -117,30 +117,6 @@ inline void hl_avgpool_backward(const int frameCnt,
real* backGrad,
const int outStride) {}

inline void hl_CMRNorm_forward(size_t frameCnt,
const real* in,
real* scale,
real* out,
size_t channels,
size_t height,
size_t width,
size_t sizeX,
real alpha,
real beta) {}

inline void hl_CMRNorm_backward(size_t frameCnt,
const real* inV,
const real* scale,
const real* outV,
const real* outDiff,
real* inDiff,
size_t channels,
size_t height,
size_t width,
size_t sizeX,
real alpha,
real beta) {}

inline void hl_bilinear_forward(const real* inData,
const size_t inImgH,
const size_t inImgW,
158 changes: 0 additions & 158 deletions paddle/cuda/src/hl_cuda_cnn.cu
@@ -381,164 +381,6 @@ void hl_avgpool_backward(const int frameCnt, const real* outGrad,
CHECK_SYNC("hl_avgpool_backward failed");
}

__global__ void KeCMRNormFillScale(size_t nthreads, const real* in,
real* scale, size_t channels,
size_t height, size_t width, size_t size,
real alpha) {
size_t index = threadIdx.x + blockIdx.x * blockDim.x;
if (index < nthreads) {
// find out the local offset
size_t w = index % width;
size_t h = (index / width) % height;
size_t n = index / width / height;
size_t offset = (n * channels * height + h) * width + w;
size_t step = height * width;
in += offset;
scale += offset;
size_t head = 0;
size_t pre_pad = (size - 1) / 2;
size_t post_pad = size - pre_pad - 1;
real accum_scale = 0;
// fill the scale at [n, :, h, w]
// accumulate values
while (head < post_pad) {
accum_scale += in[head * step] * in[head * step];
++head;
}
// until we reach size, nothing needs to be subtracted
while (head < size) {
accum_scale += in[head * step] * in[head * step];
scale[(head - post_pad) * step] = 1. + accum_scale * alpha;
++head;
}
// both add and subtract
while (head < channels) {
accum_scale += in[head * step] * in[head * step];
accum_scale -= in[(head - size) * step] * in[(head - size) * step];
scale[(head - post_pad) * step] = 1. + accum_scale * alpha;
++head;
}
// subtract only
while (head < channels + post_pad) {
accum_scale -= in[(head - size) * step] * in[(head - size) * step];
scale[(head - post_pad) * step] = 1. + accum_scale * alpha;
++head;
}
}
}
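
To make the window bookkeeping above concrete, here is a minimal CPU sketch of the same running-sum pass over one channel column; it is illustrative only and not part of the PR:

```cpp
#include <cstddef>
#include <vector>

// Computes scale[c] = 1 + alpha * sum of x[c']^2 over the k-channel window
// around c (prePad = (k-1)/2 channels before, the rest after), clipped to
// [0, C), exactly as KeCMRNormFillScale does along the channel axis.
std::vector<float> fillScaleColumn(const std::vector<float>& x, std::size_t k,
                                   float alpha) {
  const std::size_t C = x.size();
  const std::size_t prePad = (k - 1) / 2;
  const std::size_t postPad = k - prePad - 1;
  std::vector<float> scale(C);
  float accum = 0.f;
  for (std::size_t head = 0; head < C + postPad; ++head) {
    if (head < C) accum += x[head] * x[head];           // channel entering the window
    if (head >= k) accum -= x[head - k] * x[head - k];  // channel leaving it
    if (head >= postPad) scale[head - postPad] = 1.f + alpha * accum;
  }
  return scale;
}
```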

__global__ void KeCMRNormOutput(size_t nthreads, const real* in,
const real* scale, real negative_beta,
real* out) {
size_t index = threadIdx.x + blockIdx.x * blockDim.x;
if (index < nthreads) {
out[index] = in[index] * pow(scale[index], negative_beta);
}
}

void hl_CMRNorm_forward(size_t frameCnt, const real* in, real* scale,
real* out, size_t channels,
size_t height, size_t width, size_t sizeX,
real alpha, real beta) {
size_t threadsNum = frameCnt * height * width;
size_t blocksX = (threadsNum + 1024 - 1) / 1024;
size_t blocksY = 1;
dim3 threads(1024, 1);
dim3 grid(blocksX, blocksY);

KeCMRNormFillScale<<<grid, threads, 0, STREAM_DEFAULT>>>
(threadsNum, in, scale, channels, height, width, sizeX, alpha);

threadsNum = frameCnt * height * width *channels;
blocksX = (threadsNum + 1024 -1) / 1024;
dim3 threads2(1024, 1);
dim3 grid2(blocksX, blocksY);
KeCMRNormOutput<<<grid2, threads2, 0, STREAM_DEFAULT>>>
(threadsNum, in, scale, beta, out);
CHECK_SYNC("hl_CMRNorm_forward");
}

__global__ void KeCMRNormDiff(size_t nthreads, const real* bottom_data,
const real* top_data, const real* scale,
const real* top_diff, size_t channels,
size_t height, size_t width, size_t size,
real negative_beta, real cache_ratio,
real* bottom_diff ) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
if (index < nthreads) {
// find out the local offset
size_t w = index % width;
size_t h = (index / width) % height;
size_t n = index / width / height;
size_t offset = (n * channels * height + h) * width + w;
size_t step = height * width;
bottom_data += offset;
top_data += offset;
scale += offset;
top_diff += offset;
bottom_diff += offset;
int head = 0;
int pre_pad = size - (size + 1) / 2;
int post_pad = size - pre_pad - 1;
real accum_ratio = 0;
// accumulate values
while (head < post_pad) {
accum_ratio += top_diff[head * step] *
top_data[head * step] / scale[head * step];
++head;
}
// until we reach size, nothing needs to be subtracted
while (head < size) {
accum_ratio += top_diff[head * step] *
top_data[head * step] / scale[head * step];
bottom_diff[(head - post_pad) * step] +=
top_diff[(head - post_pad) * step] *
pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio *
bottom_data[(head - post_pad) * step] * accum_ratio;
++head;
}
// both add and subtract
while (head < channels) {
accum_ratio += top_diff[head * step] * top_data[head * step] /
scale[head * step];
accum_ratio -= top_diff[(head - size) * step] *
top_data[(head - size) * step] / scale[(head - size) * step];
bottom_diff[(head - post_pad) * step] +=
top_diff[(head - post_pad) * step] *
pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio *
bottom_data[(head - post_pad) * step] * accum_ratio;
++head;
}
// subtract only
while (head < channels + post_pad) {
accum_ratio -= top_diff[(head - size) * step] *
top_data[(head - size) * step] / scale[(head - size) * step];
bottom_diff[(head - post_pad) * step] +=
top_diff[(head - post_pad) * step] *
pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio *
bottom_data[(head - post_pad) * step] * accum_ratio;
++head;
}
}
}
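
Reading off the accumulation above, a sketch of the gradient this kernel computes, where N(c) is the same k-channel window, s is the scale buffer, and \kappa is the kernel's cache_ratio argument (for the standard LRN gradient the caller would supply \kappa = 2\alpha\beta; that correspondence is inferred from the parameter names, not stated in this diff):

```latex
r_c = \sum_{c' \in \mathcal{N}(c)} \frac{\partial L}{\partial y_{c'}} \cdot \frac{y_{c'}}{s_{c'}},
\qquad
\frac{\partial L}{\partial x_c} \mathrel{+}= \frac{\partial L}{\partial y_c}\, s_c^{-\beta} - \kappa\, x_c\, r_c
```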

void hl_CMRNorm_backward(size_t frameCnt, const real* inV,
const real* scale,
const real* outV, const real* outDiff,
real *inDiff, size_t channels,
size_t height, size_t width, size_t sizeX,
real alpha, real beta) {
size_t threadsNum = frameCnt * height * width;
size_t blocksX = (threadsNum + 1024 - 1) / 1024;
size_t blocksY = 1;
dim3 threads(1024, 1);
dim3 grid(blocksX, blocksY);
KeCMRNormDiff <<<grid, threads, 0, STREAM_DEFAULT>>>
(threadsNum, inV, outV, scale, outDiff, channels,
height, width, sizeX, alpha, beta, inDiff);
CHECK_SYNC("hl_CMRNorm_backward");
}

__global__ void KeBilinearInterpFw(const real* in,
const size_t inImgH,
const size_t inImgW,
27 changes: 27 additions & 0 deletions paddle/function/CMakeLists.txt
@@ -0,0 +1,27 @@
file(GLOB h_files . *_op.h)
file(GLOB cpp_files . *_op.cpp)

list(APPEND h_files Function.h)
list(APPEND cpp_files Function.cpp)

if(WITH_GPU)
file(GLOB cu_files . *_op_gpu.cu)
cuda_compile(cu_objs ${cu_files})
endif()

add_library(paddle_function STATIC ${cpp_files} ${cu_objs})

add_library(paddle_test_main STATIC TestMain.cpp)

if(WITH_GPU)
# TODO:
# file(GLOB test_files . *_op_test.cpp)
# add_executable(${test_bin} EXCLUDE_FROM_ALL ${test_files})
add_simple_unittest(cross_map_normal_op_test)
endif()

add_style_check_target(paddle_function ${h_files})
add_style_check_target(paddle_function ${cpp_files})
if(WITH_GPU)
add_style_check_target(paddle_function ${cu_files})
endif()
49 changes: 49 additions & 0 deletions paddle/function/Function.cpp
@@ -0,0 +1,49 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "Function.h"

namespace paddle {

template <>
size_t FuncConfig::get<size_t>(const std::string& key) const {
auto it = valueMap_.find(key);
CHECK(it != valueMap_.end()) << "Cannot find value: '" << key << "'";
Review (Collaborator): CHECK ==> CHECK_NE

Reply (Author): CHECK_NE does not support the type of `it`.

return it->second.s;
}

template <>
real FuncConfig::get<real>(const std::string& key) const {
auto it = valueMap_.find(key);
CHECK(it != valueMap_.end()) << "Cannot find value: '" << key << "'";
Review (Collaborator): CHECK ==> CHECK_NE

Reply (Author): CHECK_NE does not support the type of `it`.

return it->second.r;
}

template <>
FuncConfig& FuncConfig::set<size_t>(const std::string& key, size_t v) {
CHECK(valueMap_.count(key) == 0) << "Duplicated value: " << key;
Review (Contributor): Should we use count > 0 to indicate there is already a value for the key? CHECK(valueMap_.count(key) > 0) << "Duplicated value: " << key;

Reply (Author): CHECK(true) passes; here we need valueMap_.count(key) == 0 to be true.

Review (Collaborator): This should be CHECK_EQ(0, valueMap_.count(key)), so that if count returns a value other than 0, glog can print the returned value, which helps debugging.

Review (Collaborator): CHECK ==> CHECK_EQ
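
For context, a minimal glog sketch of the difference the reviewers are pointing at; this is illustrative and not part of the PR:

```cpp
#include <glog/logging.h>
#include <map>
#include <string>

void checkStyles(const std::map<std::string, int>& m, const std::string& key) {
  // On failure this logs only the expression text:
  //   Check failed: m.count(key) == 0
  CHECK(m.count(key) == 0) << "Duplicated value: " << key;
  // On failure CHECK_EQ also logs both evaluated operands, e.g.
  //   Check failed: 0u == m.count(key) (0 vs. 1)
  CHECK_EQ(0u, m.count(key)) << "Duplicated value: " << key;
}
```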

valueMap_[key].s = v;
return *this;
}

template <>
FuncConfig& FuncConfig::set<real>(const std::string& key, real v) {
CHECK(valueMap_.count(key) == 0) << "Duplicated value: " << key;
Review (Contributor): Same as line 35.

Reply (Author): Same as above.

Review (Collaborator): CHECK_EQ

Review (Collaborator): CHECK ==> CHECK_EQ

valueMap_[key].r = v;
return *this;
}

ClassRegistrar<FunctionBase> FunctionBase::funcRegistrar_;

} // namespace paddle
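
A hypothetical usage sketch of the FuncConfig API defined above; the key names "size" and "scale" are illustrative, and paddle::real is assumed to be the library's float/double typedef:

```cpp
#include "paddle/function/Function.h"

void funcConfigExample() {
  paddle::FuncConfig config;
  // set<T>() returns *this, so calls chain; setting the same key twice
  // would trip the duplicate-key CHECK above.
  config.set<size_t>("size", 5).set<paddle::real>("scale", 0.01);
  size_t size = config.get<size_t>("size");  // CHECK-fails if the key is absent
  paddle::real scale = config.get<paddle::real>("scale");
  (void)size;
  (void)scale;
}
```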