Skip to content

Commit

Permalink
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Browse files Browse the repository at this point in the history
… add_cudnn_pool3d
  • Loading branch information
chengduoZH committed Nov 15, 2017
2 parents ec1e2fc + 313f845 commit e825a49
Show file tree
Hide file tree
Showing 252 changed files with 5,208 additions and 1,662 deletions.
3 changes: 1 addition & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,10 @@ third_party/
cmake-build-*

# generated while compiling
python/paddle/v2/framework/core.so
python/paddle/v2/fluid/core.so
paddle/pybind/pybind.h
CMakeFiles
cmake_install.cmake
paddle/.timestamp
python/paddlepaddle.egg-info/
paddle/pybind/pybind.h
python/paddle/v2/framework/tests/tmp/*
58 changes: 58 additions & 0 deletions doc/design/evaluator.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
## Evaluator Design

### The Problem

During training or serving, we provide the evaluation function to measure the model performance, e.g., accuracy, precision. In the operator based framework design, the data go through the network pipeline batch by batch. As a result, inside the operator, we only can calculate one minibatch metrics. We need to provide a mechanism to calculate the metrics for each N pass/batch the user wanted.

### Evaluator Design
Currently, every operation is expressed in the graph. we divide the evaluator process into three steps.

1. Initialize the metric state and add it into the block.

2. Calculate the statistic of the metric state in every mini-batch. The single operator is only responsible for calculating necessary statistics for one mini-batch. For example, accuracy operator only calculate a minibatch data if run once.


3. Merge the mini-batch statistics to form the evaluation result for multiple mini-batches. When it comes to distributed training/Multi-GPU training, aggregate the value from different devices.

### Implementation
This design is shown in python API.
Each metric operator need to caculate the metric statistic and return the batch aware states, Python side responsible for accumulate the states for each pass.


```python
class Evaluator(object):
"""
Evaluator Base class.
"""
def __init__(self, name, **kwargs):
"""
Different evaluator may has different metric states. E.g, Accuracy need two variables, total and right sample counts.
Auc need four variables, `true_positives`,
`true_negatives`, `false_positives` and `false_negatives`. So every evaluator should create its needed variables and append to main_program
The initialization of Evaluator should be responsible for:
create metric states and append to the main_program
"""
pass

def _update_ops(self, input, label, **kwargs)
"""
Add mini-batch evaluator caculate operators to the main_program.
Add increment operator to accumulate the metric states.
"""


def reset(self, executor, reset_program=None):
"""
Reset metric states at the begin of each pass/user specified batch number.
Execute the reset_program to reset the states.
"""


def eval(self, executor, eval_program=None):
"""
Merge the mini-batch statistics to form the evaluation result for multiple mini-batches.
Execute the eval_program and return the result.
"""
return eval_result
```
8 changes: 8 additions & 0 deletions paddle/capi/Matrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ paddle_error paddle_matrix_get_shape(paddle_matrix mat,

paddle_matrix paddle_matrix_create_sparse(
uint64_t height, uint64_t width, uint64_t nnz, bool isBinary, bool useGpu) {
#ifndef PADDLE_MOBILE_INFERENCE
auto ptr = new paddle::capi::CMatrix();
ptr->mat = paddle::Matrix::createSparseMatrix(
height,
Expand All @@ -131,6 +132,9 @@ paddle_matrix paddle_matrix_create_sparse(
false,
useGpu);
return ptr;
#else
return nullptr;
#endif
}

paddle_error paddle_matrix_sparse_copy_from(paddle_matrix mat,
Expand All @@ -140,6 +144,7 @@ paddle_error paddle_matrix_sparse_copy_from(paddle_matrix mat,
uint64_t colSize,
float* valueArray,
uint64_t valueSize) {
#ifndef PADDLE_MOBILE_INFERENCE
if (mat == nullptr) return kPD_NULLPTR;
auto ptr = cast(mat);
if (rowArray == nullptr || colArray == nullptr ||
Expand All @@ -160,4 +165,7 @@ paddle_error paddle_matrix_sparse_copy_from(paddle_matrix mat,
} else {
return kPD_NOT_SUPPORTED;
}
#else
return kPD_NOT_SUPPORTED;
#endif
}
2 changes: 2 additions & 0 deletions paddle/capi/matrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ PD_API paddle_matrix paddle_matrix_create(uint64_t height,
* @param isBinary is binary (either 1 or 0 in matrix) or not.
* @param useGpu is using GPU or not.
* @return paddle_matrix.
* @note Mobile inference does not support this interface.
*/
PD_API paddle_matrix paddle_matrix_create_sparse(
uint64_t height, uint64_t width, uint64_t nnz, bool isBinary, bool useGpu);
Expand Down Expand Up @@ -129,6 +130,7 @@ PD_API paddle_error paddle_matrix_get_shape(paddle_matrix mat,
* NULL if the matrix is binary.
* @param [in] valueSize length of value array. Zero if the matrix is binary.
* @return paddle_error
* @note Mobile inference does not support this interface.
*/
PD_API paddle_error paddle_matrix_sparse_copy_from(paddle_matrix mat,
int* rowArray,
Expand Down
2 changes: 2 additions & 0 deletions paddle/cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ if(WITH_GPU)
set_source_files_properties(${CUDA_CXX_SOURCES}
PROPERTIES COMPILE_FLAGS "-D__NVCC__")
else()
if (NOT MOBILE_INFERENCE)
set(CUDA_CXX_SOURCES src/hl_warpctc_wrap.cc)
endif()
endif()

set(CUDA_CU_SOURCES
Expand Down
7 changes: 4 additions & 3 deletions paddle/cuda/include/hl_cnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ limitations under the License. */
#include "hl_base.h"

/**
* @brief Maximum pool forward.
* @brief Maximum pool forward with Mask output.
*
* @param[in] frameCnt batch size of input image.
* @param[in] inputData input data.
Expand All @@ -35,7 +35,7 @@ limitations under the License. */
* @param[in] paddingW padding width.
* @param[out] tgtData output data.
* @param[in] tgtStride stride between output data samples.
*
* @param[out] maskData the location indices of select max data.
*/
extern void hl_maxpool_forward(const int frameCnt,
const real* inputData,
Expand All @@ -51,7 +51,8 @@ extern void hl_maxpool_forward(const int frameCnt,
const int paddingH,
const int paddingW,
real* tgtData,
const int tgtStride);
const int tgtStride,
real* maskData = NULL);

/**
* @brief Maximum pool backward.
Expand Down
3 changes: 2 additions & 1 deletion paddle/cuda/include/stub/hl_cnn_stub.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ inline void hl_maxpool_forward(const int frameCnt,
const int paddingH,
const int paddingW,
real* tgtData,
const int tgtStride) {}
const int tgtStride,
real* MaskData) {}

inline void hl_maxpool_backward(const int frameCnt,
const real* inputData,
Expand Down
19 changes: 14 additions & 5 deletions paddle/cuda/src/hl_cuda_cnn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ __global__ void KeMaxPoolForward(const int nthreads,
const int offsetH,
const int offsetW,
real* tgtData,
const int tgtStride) {
const int tgtStride,
real* maskData) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < nthreads) {
int pw = index % pooledW;
Expand All @@ -45,16 +46,22 @@ __global__ void KeMaxPoolForward(const int nthreads,
hstart = max(hstart, 0);
wstart = max(wstart, 0);
real maxval = -FLT_MAX;
int max_index = -1;
inputData += (frameNum * channels + c) * height * width;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
if (maxval < inputData[h * width + w])
maxval = inputData[h * width + w];
if (maxval < inputData[h * width + w]) {
max_index = h * width + w;
maxval = inputData[max_index];
}
}
}
int tgtIndex =
index % (pooledW * pooledH * channels) + frameNum * tgtStride;
tgtData[tgtIndex] = maxval;
if (maskData != NULL) {
maskData[tgtIndex] = max_index;
}
}
}

Expand All @@ -72,7 +79,8 @@ void hl_maxpool_forward(const int frameCnt,
const int paddingH,
const int paddingW,
real* tgtData,
const int tgtStride) {
const int tgtStride,
real* maskData) {
int num_kernels = pooledH * pooledW * channels * frameCnt;
int blocks = (num_kernels + 1024 - 1) / 1024;
dim3 threads(1024, 1);
Expand All @@ -92,7 +100,8 @@ void hl_maxpool_forward(const int frameCnt,
paddingH,
paddingW,
tgtData,
tgtStride);
tgtStride,
maskData);
CHECK_SYNC("hl_maxpool_forward failed");
}

Expand Down
6 changes: 3 additions & 3 deletions paddle/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ py_proto_compile(framework_py_proto SRCS framework.proto)
add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
add_dependencies(framework_py_proto framework_py_proto_init)
add_custom_command(TARGET framework_py_proto POST_BUILD
COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_SOURCE_DIR}/python/paddle/v2/framework/proto
COMMAND cp *.py ${PADDLE_SOURCE_DIR}/python/paddle/v2/framework/proto/
COMMENT "Copy generated python proto into directory paddle/v2/framework/proto."
COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_SOURCE_DIR}/python/paddle/v2/fluid/proto
COMMAND cp *.py ${PADDLE_SOURCE_DIR}/python/paddle/v2/fluid/proto/
COMMENT "Copy generated python proto into directory paddle/v2/fluid/proto."
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})

cc_library(backward SRCS backward.cc DEPS net_op)
Expand Down
33 changes: 27 additions & 6 deletions paddle/framework/backward.cc
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,12 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
return grad_op_descs;
}

static BlockDescBind* CreateStepBlock(
ProgramDescBind& program_desc,
std::unordered_set<std::string>* no_grad_vars,
std::unordered_map<std::string, std::string>* grad_to_var,
int step_block_idx);

std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
ProgramDescBind& program_desc, int block_idx,
std::unordered_set<std::string>* no_grad_vars,
Expand All @@ -392,13 +398,13 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(

if ((*it)->Type() == "recurrent") {
int step_block_idx = (*it)->GetBlockAttr("step_block");
auto backward_block_op_descs = MakeBlockBackward(
program_desc, step_block_idx, no_grad_vars, grad_to_var);
BlockDescBind* backward_block = CreateStepBlock(
program_desc, no_grad_vars, grad_to_var, step_block_idx);
op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var, {backward_block});
} else if ((*it)->Type() == "conditional_block") {
BlockDescBind* backward_block =
program_desc.AppendBlock(*program_desc.MutableBlock(step_block_idx));
for (auto& ptr : backward_block_op_descs) {
backward_block->AppendAllocatedOp(std::move(ptr));
}
CreateStepBlock(program_desc, no_grad_vars, grad_to_var,
(*it)->GetBlockAttr("block"));
op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var, {backward_block});
} else {
op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var);
Expand Down Expand Up @@ -449,6 +455,21 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
return backward_descs;
}

static BlockDescBind* CreateStepBlock(
ProgramDescBind& program_desc,
std::unordered_set<std::string>* no_grad_vars,
std::unordered_map<std::string, std::string>* grad_to_var,
int step_block_idx) {
auto backward_block_op_descs = MakeBlockBackward(program_desc, step_block_idx,
no_grad_vars, grad_to_var);
BlockDescBind* backward_block =
program_desc.AppendBlock(*program_desc.MutableBlock(step_block_idx));
for (auto& ptr : backward_block_op_descs) {
backward_block->AppendAllocatedOp(move(ptr));
}
return backward_block;
}

ParamGradInfoMap AppendBackward(
ProgramDescBind& program_desc, const VarDescBind& target,
const std::unordered_set<std::string>& no_grad_vars) {
Expand Down
22 changes: 22 additions & 0 deletions paddle/framework/var_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,32 @@ inline VarDesc::VarType ToVarType(std::type_index type) {
return VarDesc_VarType_LOD_RANK_TABLE;
} else if (type.hash_code() == typeid(LoDTensorArray).hash_code()) {
return VarDesc_VarType_LOD_TENSOR_ARRAY;
} else if (type.hash_code() == typeid(SelectedRows).hash_code()) {
return VarDesc_VarType_SELECTED_ROWS;
} else {
PADDLE_THROW("ToVarType:Unsupported type %s", type.name());
}
}

template <typename Visitor>
inline void VisitVarType(const Variable& var, Visitor visitor) {
switch (ToVarType(var.Type())) {
case VarDesc_VarType_LOD_TENSOR:
visitor(var.Get<framework::LoDTensor>());
return;
case VarDesc_VarType_LOD_RANK_TABLE:
visitor(var.Get<LoDRankTable>());
return;
case VarDesc_VarType_LOD_TENSOR_ARRAY:
visitor(var.Get<LoDTensorArray>());
return;
case VarDesc_VarType_SELECTED_ROWS:
visitor(var.Get<SelectedRows>());
return;
default:
PADDLE_THROW("Not supported visit type, %d", ToVarType(var.Type()));
}
}

} // namespace framework
} // namespace paddle
6 changes: 6 additions & 0 deletions paddle/function/ConvOp.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class ConvFunctionBase : public FunctionBase {
// function arguments
strides_ = config.get<std::vector<size_t>>("strides");
paddings_ = config.get<std::vector<size_t>>("paddings");
dilations_ = config.get<std::vector<size_t>>("dilations");
groups_ = config.get<size_t>("groups");

// number of inputs and outputs
Expand Down Expand Up @@ -118,6 +119,7 @@ class ConvFunctionBase : public FunctionBase {

std::vector<size_t> strides_;
std::vector<size_t> paddings_;
std::vector<size_t> dilations_;

/// Group size, refer to grouped convolution in
/// Alex Krizhevsky's paper: when group=2, the first half of the
Expand All @@ -133,6 +135,10 @@ class ConvFunctionBase : public FunctionBase {

inline int paddingW() const { return paddings_[1]; }

inline int dilationH() const { return dilations_[0]; }

inline int dilationW() const { return dilations_[1]; }

// A temporary memory in convolution calculation.
MemoryHandlePtr memory_;

Expand Down
Loading

0 comments on commit e825a49

Please sign in to comment.