From b28b2f172b2763dd8917833c2708309f98299a0a Mon Sep 17 00:00:00 2001 From: QI JUN Date: Mon, 27 Nov 2017 18:35:57 +0800 Subject: [PATCH] refine test_recognize_digits_mlp and format codes (#5937) --- paddle/capi/Matrix.cpp | 4 +- paddle/capi/matrix.h | 8 +- paddle/framework/tensor_util.h | 9 +- paddle/operators/math/maxouting.cc | 31 ++- paddle/operators/math/maxouting.cu | 80 ++++--- paddle/operators/math/maxouting.h | 8 +- paddle/operators/maxout_op.cc | 38 ++- paddle/operators/maxout_op.cu.cc | 8 +- paddle/operators/maxout_op.h | 2 +- paddle/operators/roi_pool_op.cc | 24 +- paddle/operators/roi_pool_op.cu | 216 ++++++++---------- paddle/operators/roi_pool_op.h | 3 +- paddle/operators/sequence_slice_op.cc | 5 +- python/paddle/v2/dataset/uci_housing.py | 4 +- .../tests/book/test_recognize_digits_mlp.py | 12 +- .../paddle/v2/fluid/tests/test_maxout_op.py | 4 +- .../paddle/v2/fluid/tests/test_roi_pool_op.py | 48 ++-- 17 files changed, 231 insertions(+), 273 deletions(-) mode change 100755 => 100644 paddle/operators/roi_pool_op.cc mode change 100755 => 100644 paddle/operators/roi_pool_op.cu mode change 100755 => 100644 paddle/operators/roi_pool_op.h mode change 100755 => 100644 paddle/operators/sequence_slice_op.cc diff --git a/paddle/capi/Matrix.cpp b/paddle/capi/Matrix.cpp index d5b55e1c95f24..30f3a766f0c65 100644 --- a/paddle/capi/Matrix.cpp +++ b/paddle/capi/Matrix.cpp @@ -55,7 +55,7 @@ paddle_error paddle_matrix_set_row(paddle_matrix mat, } PD_API paddle_error paddle_matrix_set_value(paddle_matrix mat, - paddle_real* value) { + paddle_real* value) { if (mat == nullptr || value == nullptr) return kPD_NULLPTR; auto ptr = cast(mat); if (ptr->mat == nullptr) return kPD_NULLPTR; @@ -75,7 +75,7 @@ PD_API paddle_error paddle_matrix_set_value(paddle_matrix mat, } PD_API paddle_error paddle_matrix_get_value(paddle_matrix mat, - paddle_real* result) { + paddle_real* result) { if (mat == nullptr || result == nullptr) return kPD_NULLPTR; auto ptr = cast(mat); if (ptr->mat == nullptr) return kPD_NULLPTR; diff --git a/paddle/capi/matrix.h b/paddle/capi/matrix.h index 01b8bad2ee9f5..8cc3e0034e058 100644 --- a/paddle/capi/matrix.h +++ b/paddle/capi/matrix.h @@ -79,7 +79,7 @@ PD_API paddle_error paddle_matrix_set_row(paddle_matrix mat, * @note value should contain enough element of data to init the mat */ PD_API paddle_error paddle_matrix_set_value(paddle_matrix mat, - paddle_real* value); + paddle_real* value); /** * @brief PDMatGetRow Get raw row buffer from matrix @@ -93,14 +93,14 @@ PD_API paddle_error paddle_matrix_get_row(paddle_matrix mat, paddle_real** rawRowBuffer); /** - * @brief copy data from the matrix + * @brief copy data from the matrix * @param [in] mat Target matrix - * @param [out] result pointer to store the matrix data + * @param [out] result pointer to store the matrix data * @return paddle_error * @note the space of the result should allocated before invoke this API */ PD_API paddle_error paddle_matrix_get_value(paddle_matrix mat, - paddle_real* result); + paddle_real* result); /** * @brief PDMatCreateNone Create None Matrix * @return diff --git a/paddle/framework/tensor_util.h b/paddle/framework/tensor_util.h index 8ee2e15a59113..4e34b90d57eed 100644 --- a/paddle/framework/tensor_util.h +++ b/paddle/framework/tensor_util.h @@ -135,18 +135,17 @@ inline void CopyToVector(const Tensor& src, const platform::DeviceContext& ctx, auto dst_ptr = static_cast(dst->data()); if (platform::is_cpu_place(src.place())) { - memory::Copy(dst_place, dst_ptr, boost::get(src.place()), - src_ptr, size); + memory::Copy(dst_place, dst_ptr, + boost::get(src.place()), src_ptr, size); } #ifdef PADDLE_WITH_CUDA else if (platform::is_gpu_place(src.place())) { // NOLINT memory::Copy( - dst_place, dst_ptr, boost::get(src.place()), src_ptr, - size, + dst_place, dst_ptr, boost::get(src.place()), + src_ptr, size, reinterpret_cast(ctx).stream()); } #endif - } } // namespace framework diff --git a/paddle/operators/math/maxouting.cc b/paddle/operators/math/maxouting.cc index e5168ce7afd41..c9003962d33b7 100644 --- a/paddle/operators/math/maxouting.cc +++ b/paddle/operators/math/maxouting.cc @@ -23,8 +23,7 @@ template class MaxOutFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, - framework::Tensor * output, + const framework::Tensor& input, framework::Tensor* output, int groups) { const int batch_size = input.dims()[0]; const int input_height = input.dims()[2]; @@ -37,34 +36,30 @@ class MaxOutFunctor { T* output_data = output->mutable_data(context.GetPlace()); for (int i = 0; i < batch_size; ++i) { - int new_bindex = c_size * i; + int new_bindex = c_size * i; for (int c = 0; c < output_channels; ++c) { int new_cindex = fea_size * c; for (int f = 0; f < fea_size; ++f) { T ele = static_cast(-FLT_MAX); for (int ph = 0; ph < groups; ++ph) { - T x = input_data[(new_bindex + new_cindex) * groups - + ph * fea_size + f]; + T x = input_data[(new_bindex + new_cindex) * groups + + ph * fea_size + f]; ele = ele > x ? ele : x; } - output_data[(new_bindex+new_cindex+f)] = ele; + output_data[(new_bindex + new_cindex + f)] = ele; } } } } }; - - template class MaxOutGradFunctor { -public: + public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, - framework::Tensor * input_grad, + const framework::Tensor& input, framework::Tensor* input_grad, const framework::Tensor& output, - const framework::Tensor& output_grad, - int groups) { + const framework::Tensor& output_grad, int groups) { const int batch_size = input.dims()[0]; const int input_height = input.dims()[2]; const int input_width = input.dims()[3]; @@ -84,11 +79,11 @@ class MaxOutGradFunctor { bool continue_match = true; int output_idx = blen + clen + f; for (int g = 0; g < groups && continue_match; ++g) { - int input_idx = input_idx0 + fea_size * g; - if (input_data[input_idx] == output_data[output_idx]) { - input_grad_data[input_idx] += output_grad_data[output_idx]; - continue_match = false; - } + int input_idx = input_idx0 + fea_size * g; + if (input_data[input_idx] == output_data[output_idx]) { + input_grad_data[input_idx] += output_grad_data[output_idx]; + continue_match = false; + } } } } diff --git a/paddle/operators/math/maxouting.cu b/paddle/operators/math/maxouting.cu index 7c698577b8a82..c3fabcae081e2 100644 --- a/paddle/operators/math/maxouting.cu +++ b/paddle/operators/math/maxouting.cu @@ -21,9 +21,9 @@ namespace math { template __global__ void KernelMaxOut(const int nthreads, const T* input_data, - const int channels, - const int input_height, const int input_width, - int groups, T* output_data ) { + const int channels, const int input_height, + const int input_width, int groups, + T* output_data) { const int size = input_height * input_width * channels / groups; const int feat_len = input_height * input_width; int index = blockIdx.x * blockDim.x + threadIdx.x; @@ -34,7 +34,7 @@ __global__ void KernelMaxOut(const int nthreads, const T* input_data, int channel_idx = batch_offset / feat_len; int feat_idx = batch_offset % feat_len; int data_idx = - (batch_idx * size + channel_idx * feat_len) * groups + feat_idx; + (batch_idx * size + channel_idx * feat_len) * groups + feat_idx; T ele = static_cast(-FLT_MAX); for (int g = 0; g < groups; ++g) { T x = input_data[data_idx + g * feat_len]; @@ -44,34 +44,35 @@ __global__ void KernelMaxOut(const int nthreads, const T* input_data, } } template -__global__ void KernelMaxoutGrad( - const int nthreads, const T* input_data, const T* output_data, - const T* output_grad, T* input_grad, const int channels, - const int input_height, const int input_width, int groups) { - const int size = input_height * input_width * channels / groups; - const int feat_len = input_height * input_width; - int index = blockIdx.x * blockDim.x + threadIdx.x; - int offset = blockDim.x * gridDim.x; - for (int i = index; i < nthreads; i += offset) { - int batch_idx = i / size; - int batch_offset = i % size; - int channel_idx = batch_offset / feat_len; - int feat_idx = batch_offset % feat_len; - int data_idx = +__global__ void KernelMaxoutGrad(const int nthreads, const T* input_data, + const T* output_data, const T* output_grad, + T* input_grad, const int channels, + const int input_height, const int input_width, + int groups) { + const int size = input_height * input_width * channels / groups; + const int feat_len = input_height * input_width; + int index = blockIdx.x * blockDim.x + threadIdx.x; + int offset = blockDim.x * gridDim.x; + for (int i = index; i < nthreads; i += offset) { + int batch_idx = i / size; + int batch_offset = i % size; + int channel_idx = batch_offset / feat_len; + int feat_idx = batch_offset % feat_len; + int data_idx = (batch_idx * size + channel_idx * feat_len) * groups + feat_idx; - int max_index = -1; - bool continue_match = true; - for (int g = 0; g < groups && continue_match; ++g) { - if (input_data[data_idx + g * feat_len] == output_data[i]) { - max_index = data_idx + g * feat_len; - continue_match = false; - break; - } - } - if (max_index != -1) { - input_grad[max_index] += output_grad[index]; + int max_index = -1; + bool continue_match = true; + for (int g = 0; g < groups && continue_match; ++g) { + if (input_data[data_idx + g * feat_len] == output_data[i]) { + max_index = data_idx + g * feat_len; + continue_match = false; + break; } } + if (max_index != -1) { + input_grad[max_index] += output_grad[index]; + } + } } /* * All tensors are in NCHW format. @@ -80,7 +81,7 @@ template class MaxOutFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor * output, + const framework::Tensor& input, framework::Tensor* output, int groups) { const int batch_size = input.dims()[0]; const int input_channels = input.dims()[1]; @@ -92,7 +93,7 @@ class MaxOutFunctor { const T* input_data = input.data(); T* output_data = output->mutable_data(context.GetPlace()); - int nthreads = output->numel(); + int nthreads = output->numel(); int blocks = (nthreads + 1024 - 1) / 1024; dim3 threads(1024, 1); dim3 grid(blocks, 1); @@ -101,8 +102,7 @@ class MaxOutFunctor { T><<(context) .stream()>>>(nthreads, input_data, input_channels, - input_height, input_width, groups, - output_data); + input_height, input_width, groups, output_data); } }; /* @@ -112,11 +112,9 @@ template class MaxOutGradFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, - framework::Tensor * input_grad, + const framework::Tensor& input, framework::Tensor* input_grad, const framework::Tensor& output, - const framework::Tensor& output_grad, - int groups) { + const framework::Tensor& output_grad, int groups) { const int batch_size = input.dims()[0]; const int input_channels = input.dims()[1]; const int input_height = input.dims()[2]; @@ -129,7 +127,7 @@ class MaxOutGradFunctor { const T* output_data = output.data(); const T* output_grad_data = output_grad.data(); T* input_grad_data = input_grad->mutable_data(context.GetPlace()); - int nthreads = output.numel(); + int nthreads = output.numel(); int blocks = (nthreads + 1024 - 1) / 1024; dim3 threads(1024, 1); dim3 grid(blocks, 1); @@ -137,9 +135,9 @@ class MaxOutGradFunctor { KernelMaxoutGrad< T><<(context) - .stream()>>>( - nthreads, input_data, output_data, output_grad_data, input_grad_data, - input_channels, input_height, input_width, groups); + .stream()>>>(nthreads, input_data, output_data, + output_grad_data, input_grad_data, input_channels, + input_height, input_width, groups); } }; diff --git a/paddle/operators/math/maxouting.h b/paddle/operators/math/maxouting.h index d4c9da38ab8f8..2d9069b0b3ca3 100644 --- a/paddle/operators/math/maxouting.h +++ b/paddle/operators/math/maxouting.h @@ -21,15 +21,14 @@ namespace paddle { namespace operators { namespace math { -#define FLT_MAX \ - __FLT_MAX__ +#define FLT_MAX __FLT_MAX__ template class MaxOutFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor * output, + const framework::Tensor& input, framework::Tensor* output, int groups); }; @@ -37,8 +36,7 @@ template class MaxOutGradFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, - framework::Tensor * input_grad, + const framework::Tensor& input, framework::Tensor* input_grad, const framework::Tensor& output, const framework::Tensor& output_grad, int groups); }; diff --git a/paddle/operators/maxout_op.cc b/paddle/operators/maxout_op.cc index 95467f2e69093..e203a25d54437 100644 --- a/paddle/operators/maxout_op.cc +++ b/paddle/operators/maxout_op.cc @@ -22,16 +22,17 @@ class MaxOutOpMaker : public framework::OpProtoAndCheckerMaker { public: MaxOutOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", + AddInput( + "X", "(Tensor) The input tensor of maxout operator. " "The format of input tensor is NCHW. Where N is batch size, C is the " "number of channels, H and W is the height and width of feature."); AddOutput("Out", - "(Tensor) The output tensor of maxout operator." - "The format of output tensor is also NCHW." - "Where N is batch size, C is " - "the number of channels, H and W is the height and " - "width of feature."); + "(Tensor) The output tensor of maxout operator." + "The format of output tensor is also NCHW." + "Where N is batch size, C is " + "the number of channels, H and W is the height and " + "width of feature."); AddAttr( "groups", R"DOC("Specifies how many groups the input tensor will be split" @@ -59,21 +60,19 @@ class MaxOutOpMaker : public framework::OpProtoAndCheckerMaker { } }; - class MaxOutOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of MaxoutOp" + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of MaxoutOp" "should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of MaxoutOp should not be null."); auto in_x_dims = ctx->GetInputDim("X"); int groups = ctx->Attrs().Get("groups"); // check groups > 1 - PADDLE_ENFORCE_GT( - groups, 1, - "groups should be larger than 1 in maxoutop"); + PADDLE_ENFORCE_GT(groups, 1, "groups should be larger than 1 in maxoutop"); std::vector output_shape({in_x_dims[0], in_x_dims[1] / groups}); output_shape.push_back(in_x_dims[2]); output_shape.push_back(in_x_dims[3]); @@ -87,18 +86,17 @@ class MaxOutOpGrad : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null."); PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), - "Input(X@GRAD) should not be null."); + "Input(X@GRAD) should not be null."); ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); } }; -} // namespace operators -} // namespace paddle +} // namespace operators +} // namespace paddle namespace ops = paddle::operators; REGISTER_OP(maxout, ops::MaxOutOp, ops::MaxOutOpMaker, maxout_grad, - ops::MaxOutOpGrad); -REGISTER_OP_CPU_KERNEL(maxout, ops::MaxOutKernel); -REGISTER_OP_CPU_KERNEL(maxout_grad, - ops::MaxOutGradKernel); + ops::MaxOutOpGrad); +REGISTER_OP_CPU_KERNEL(maxout, + ops::MaxOutKernel); +REGISTER_OP_CPU_KERNEL( + maxout_grad, ops::MaxOutGradKernel); diff --git a/paddle/operators/maxout_op.cu.cc b/paddle/operators/maxout_op.cu.cc index a5823fba6848a..decd43913d69d 100644 --- a/paddle/operators/maxout_op.cu.cc +++ b/paddle/operators/maxout_op.cu.cc @@ -18,8 +18,6 @@ namespace ops = paddle::operators; REGISTER_OP_GPU_KERNEL(maxout, ops::MaxOutKernel, ops::MaxOutKernel); -REGISTER_OP_GPU_KERNEL(maxout_grad, - ops::MaxOutGradKernel, - ops::MaxOutGradKernel); +REGISTER_OP_GPU_KERNEL( + maxout_grad, ops::MaxOutGradKernel, + ops::MaxOutGradKernel); diff --git a/paddle/operators/maxout_op.h b/paddle/operators/maxout_op.h index c404cd16a9b23..44a0d073dda64 100644 --- a/paddle/operators/maxout_op.h +++ b/paddle/operators/maxout_op.h @@ -53,7 +53,7 @@ class MaxOutGradKernel : public framework::OpKernel { zero(device_ctx, in_x_grad, static_cast(0.0)); math::MaxOutGradFunctor maxout_backward; maxout_backward(context.device_context(), *in_x, in_x_grad, *out, - *out_grad, groups); + *out_grad, groups); } } }; diff --git a/paddle/operators/roi_pool_op.cc b/paddle/operators/roi_pool_op.cc old mode 100755 new mode 100644 index 156db9358689c..2b5e66c96b726 --- a/paddle/operators/roi_pool_op.cc +++ b/paddle/operators/roi_pool_op.cc @@ -43,8 +43,8 @@ class ROIPoolOp : public framework::OperatorWithKernel { "ROIs should be a 2-D tensor of shape (num_rois, 5)" "given as [[batch_id, x1, y1, x2, y2], …]."); PADDLE_ENFORCE(rois_dims[1] == kROISize, - "ROIs should be a 2-D tensor of shape (num_rois, 5)" - "given as [[batch_id, x1, y1, x2, y2], …]."); + "ROIs should be a 2-D tensor of shape (num_rois, 5)" + "given as [[batch_id, x1, y1, x2, y2], …]."); int pooled_height = ctx->Attrs().Get("pooled_height"); int pooled_width = ctx->Attrs().Get("pooled_width"); @@ -65,7 +65,7 @@ class ROIPoolOp : public framework::OperatorWithKernel { ctx->SetOutputDim("Out", out_dims); ctx->SetOutputDim("Argmax", out_dims); - } + } protected: framework::OpKernelType GetKernelType( @@ -100,7 +100,7 @@ class ROIPoolGradOp : public framework::OperatorWithKernel { class ROIPoolOpMaker : public framework::OpProtoAndCheckerMaker { public: ROIPoolOpMaker(framework::OpProto* proto, - framework::OpAttrChecker* op_checker) + framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "(Tensor), " @@ -125,21 +125,22 @@ class ROIPoolOpMaker : public framework::OpProtoAndCheckerMaker { "(Tensor), " "Argmaxes corresponding to indices in X used " "for gradient computation. Only output " - "if arg “is_test” is false.").AsIntermediate(); + "if arg “is_test” is false.") + .AsIntermediate(); AddAttr("spatial_scale", "(float, default 1.0), " "Multiplicative spatial scale factor " "to translate ROI coords from their input scale " "to the scale used when pooling.") - .SetDefault(1.0); + .SetDefault(1.0); AddAttr("pooled_height", "(int, default 1), " "The pooled output height.") - .SetDefault(1); + .SetDefault(1); AddAttr("pooled_width", "(int, default 1), " "The pooled output width.") - .SetDefault(1); + .SetDefault(1); AddComment(R"DOC( ROIPool operator @@ -153,11 +154,10 @@ ROI Pooling for Faster-RCNN. The link below is a further introduction: } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(roi_pool, ops::ROIPoolOp, ops::ROIPoolOpMaker, - roi_pool_grad, ops::ROIPoolGradOp); +REGISTER_OP(roi_pool, ops::ROIPoolOp, ops::ROIPoolOpMaker, roi_pool_grad, + ops::ROIPoolGradOp); REGISTER_OP_CPU_KERNEL( - roi_pool, - ops::CPUROIPoolOpKernel, + roi_pool, ops::CPUROIPoolOpKernel, ops::CPUROIPoolOpKernel); REGISTER_OP_CPU_KERNEL( roi_pool_grad, diff --git a/paddle/operators/roi_pool_op.cu b/paddle/operators/roi_pool_op.cu old mode 100755 new mode 100644 index 97df45f1b5779..9a4c8ca752bb7 --- a/paddle/operators/roi_pool_op.cu +++ b/paddle/operators/roi_pool_op.cu @@ -29,101 +29,95 @@ static inline int NumBlocks(const int N) { kNumMaxinumNumBlocks); } - template - __global__ void GPUROIPoolForward( - const int nthreads, const T* input_data, const int64_t* input_rois, - const float spatial_scale, const int channels, const int height, - const int width, const int pooled_height, const int pooled_width, - T* output_data, int64_t* argmax_data) { - int index = blockIdx.x * blockDim.x + threadIdx.x; - int offset = blockDim.x * gridDim.x; - for (size_t i = index; i < nthreads; i += offset) { - int pw = index % pooled_width; - int ph = (index / pooled_width) % pooled_height; - int c = (index / pooled_width / pooled_height) % channels; - int n = index / pooled_width / pooled_height / channels; - - const int64_t* offset_input_rois = input_rois + n * kROISize; - int roi_batch_ind = offset_input_rois[0]; - int roi_start_w = round(offset_input_rois[1] * spatial_scale); - int roi_start_h = round(offset_input_rois[2] * spatial_scale); - int roi_end_w = round(offset_input_rois[3] * spatial_scale); - int roi_end_h = round(offset_input_rois[4] * spatial_scale); - - int roi_width = max(roi_end_w - roi_start_w + 1, 1); - int roi_height = max(roi_end_h - roi_start_h + 1, 1); - T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); - T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); - - int hstart = static_cast(floor(static_cast(ph) * bin_size_h)); - int wstart = static_cast(floor(static_cast(pw) * bin_size_w)); - int hend = static_cast(ceil(static_cast(ph + 1) * bin_size_h)); - int wend = static_cast(ceil(static_cast(pw + 1) * bin_size_w)); - - hstart = min(max(hstart + roi_start_h, 0), height); - hend = min(max(hend + roi_start_h, 0), height); - wstart = min(max(wstart + roi_start_w, 0), width); - wend = min(max(wend + roi_start_w, 0), width); - bool is_empty = (hend <= hstart) || (wend <= wstart); - - T maxval = is_empty ? 0 : -std::numeric_limits::max(); - int maxidx = -1; - const T* offset_input_data = - input_data + (roi_batch_ind * channels + c) * height * width; - for (int h = hstart; h < hend; ++h) { - for (int w = wstart; w < wend; ++w) { - int input_data_index = h * width + w; - if (offset_input_data[input_data_index] > maxval) { - maxval = offset_input_data[input_data_index]; - maxidx = input_data_index; - } +template +__global__ void GPUROIPoolForward(const int nthreads, const T* input_data, + const int64_t* input_rois, + const float spatial_scale, const int channels, + const int height, const int width, + const int pooled_height, + const int pooled_width, T* output_data, + int64_t* argmax_data) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + int offset = blockDim.x * gridDim.x; + for (size_t i = index; i < nthreads; i += offset) { + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + const int64_t* offset_input_rois = input_rois + n * kROISize; + int roi_batch_ind = offset_input_rois[0]; + int roi_start_w = round(offset_input_rois[1] * spatial_scale); + int roi_start_h = round(offset_input_rois[2] * spatial_scale); + int roi_end_w = round(offset_input_rois[3] * spatial_scale); + int roi_end_h = round(offset_input_rois[4] * spatial_scale); + + int roi_width = max(roi_end_w - roi_start_w + 1, 1); + int roi_height = max(roi_end_h - roi_start_h + 1, 1); + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + int hstart = static_cast(floor(static_cast(ph) * bin_size_h)); + int wstart = static_cast(floor(static_cast(pw) * bin_size_w)); + int hend = static_cast(ceil(static_cast(ph + 1) * bin_size_h)); + int wend = static_cast(ceil(static_cast(pw + 1) * bin_size_w)); + + hstart = min(max(hstart + roi_start_h, 0), height); + hend = min(max(hend + roi_start_h, 0), height); + wstart = min(max(wstart + roi_start_w, 0), width); + wend = min(max(wend + roi_start_w, 0), width); + bool is_empty = (hend <= hstart) || (wend <= wstart); + + T maxval = is_empty ? 0 : -std::numeric_limits::max(); + int maxidx = -1; + const T* offset_input_data = + input_data + (roi_batch_ind * channels + c) * height * width; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + int input_data_index = h * width + w; + if (offset_input_data[input_data_index] > maxval) { + maxval = offset_input_data[input_data_index]; + maxidx = input_data_index; } } - output_data[index] = maxval; - if (argmax_data) { - argmax_data[index] = maxidx; - } + } + output_data[index] = maxval; + if (argmax_data) { + argmax_data[index] = maxidx; } } +} template __global__ void GPUROIPoolBackward( - const int nthreads, - const int64_t* input_rois, - const T* output_grad, - const int64_t* argmax_data, - const int num_rois, - const float spatial_scale, - const int channels, - const int height, - const int width, - const int pooled_height, - const int pooled_width, - T* input_grad) { - int index = blockIdx.x * blockDim.x + threadIdx.x; - int offset = blockDim.x * gridDim.x; - for (int i = index; i < nthreads; i += offset) { - int pw = index % pooled_width; - int ph = (index / pooled_width) % pooled_height; - int c = (index / pooled_width / pooled_height) % channels; - int n = index / pooled_width / pooled_height / channels; - - const int64_t* offset_input_rois = input_rois + n * kROISize; - int roi_batch_ind = offset_input_rois[0]; - int input_offset = (roi_batch_ind * channels + c) * height * width; - int output_offset = (n * channels + c) * pooled_height * pooled_width; - const T* offset_output_grad = output_grad + output_offset; - T* offset_input_grad = input_grad + input_offset; - const int64_t* offset_argmax_data = argmax_data + output_offset; - - int argmax = offset_argmax_data[ph * pooled_width + pw]; - if (argmax != -1) { - platform::CudaAtomicAdd(offset_input_grad + argmax, + const int nthreads, const int64_t* input_rois, const T* output_grad, + const int64_t* argmax_data, const int num_rois, const float spatial_scale, + const int channels, const int height, const int width, + const int pooled_height, const int pooled_width, T* input_grad) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + int offset = blockDim.x * gridDim.x; + for (int i = index; i < nthreads; i += offset) { + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + const int64_t* offset_input_rois = input_rois + n * kROISize; + int roi_batch_ind = offset_input_rois[0]; + int input_offset = (roi_batch_ind * channels + c) * height * width; + int output_offset = (n * channels + c) * pooled_height * pooled_width; + const T* offset_output_grad = output_grad + output_offset; + T* offset_input_grad = input_grad + input_offset; + const int64_t* offset_argmax_data = argmax_data + output_offset; + + int argmax = offset_argmax_data[ph * pooled_width + pw]; + if (argmax != -1) { + platform::CudaAtomicAdd( + offset_input_grad + argmax, static_cast(offset_output_grad[ph * pooled_width + pw])); - } } } - +} template class GPUROIPoolOpKernel : public framework::OpKernel { @@ -145,25 +139,18 @@ class GPUROIPoolOpKernel : public framework::OpKernel { int width = in_dims[3]; size_t rois_num = rois->dims()[0]; - if (rois_num== 0) return; + if (rois_num == 0) return; int output_size = out->numel(); int blocks = NumBlocks(output_size); int threads = kNumCUDAThreads; - GPUROIPoolForward - <<>>( - output_size, - in->data(), - rois->data(), - spatial_scale, - channels, - height, - width, - pooled_height, - pooled_width, - out->mutable_data(ctx.GetPlace()), - argmax->mutable_data(ctx.GetPlace())); + GPUROIPoolForward< + T><<>>( + output_size, in->data(), rois->data(), spatial_scale, + channels, height, width, pooled_height, pooled_width, + out->mutable_data(ctx.GetPlace()), + argmax->mutable_data(ctx.GetPlace())); } }; @@ -175,10 +162,8 @@ class GPUROIPoolGradOpKernel : public framework::OpKernel { auto* rois = ctx.Input("ROIs"); auto* argmax = ctx.Input("Argmax"); - auto* out_grad = - ctx.Input(framework::GradVarName("Out")); - auto* x_grad = - ctx.Output(framework::GradVarName("X")); + auto* out_grad = ctx.Input(framework::GradVarName("Out")); + auto* x_grad = ctx.Output(framework::GradVarName("X")); auto pooled_height = ctx.Attr("pooled_height"); auto pooled_width = ctx.Attr("pooled_width"); @@ -199,21 +184,13 @@ class GPUROIPoolGradOpKernel : public framework::OpKernel { int threads = kNumCUDAThreads; if (output_grad_size > 0) { - GPUROIPoolBackward - <<>>( - output_grad_size, - rois->data(), - out_grad->data(), - argmax->data(), - rois_num, - spatial_scale, - channels, - height, - width, - pooled_height, - pooled_width, - x_grad->mutable_data(ctx.GetPlace())); - } + GPUROIPoolBackward< + T><<>>( + output_grad_size, rois->data(), out_grad->data(), + argmax->data(), rois_num, spatial_scale, channels, height, + width, pooled_height, pooled_width, + x_grad->mutable_data(ctx.GetPlace())); + } } } }; @@ -223,8 +200,7 @@ class GPUROIPoolGradOpKernel : public framework::OpKernel { namespace ops = paddle::operators; REGISTER_OP_GPU_KERNEL( - roi_pool, - ops::GPUROIPoolOpKernel, + roi_pool, ops::GPUROIPoolOpKernel, ops::GPUROIPoolOpKernel); REGISTER_OP_GPU_KERNEL( roi_pool_grad, diff --git a/paddle/operators/roi_pool_op.h b/paddle/operators/roi_pool_op.h old mode 100755 new mode 100644 index bd7736d63125f..1691eb482b03e --- a/paddle/operators/roi_pool_op.h +++ b/paddle/operators/roi_pool_op.h @@ -136,8 +136,7 @@ class CPUROIPoolGradOpKernel : public framework::OpKernel { auto* out_grad = ctx.Input(framework::GradVarName("Out")); - auto* x_grad = - ctx.Output(framework::GradVarName("X")); + auto* x_grad = ctx.Output(framework::GradVarName("X")); auto pooled_height = ctx.Attr("pooled_height"); auto pooled_width = ctx.Attr("pooled_width"); diff --git a/paddle/operators/sequence_slice_op.cc b/paddle/operators/sequence_slice_op.cc old mode 100755 new mode 100644 index cbe0b4233160d..255683a572c0e --- a/paddle/operators/sequence_slice_op.cc +++ b/paddle/operators/sequence_slice_op.cc @@ -45,7 +45,7 @@ class SequenceSliceOp : public framework::OperatorWithKernel { // Initialize the output's dims to maximum, // and re-set to real dims by the value of Offset and Length at kernel ctx->SetOutputDim("Out", input_dims); - } + } protected: framework::OpKernelType GetKernelType( @@ -93,8 +93,7 @@ class SequenceSliceOpMaker : public framework::OpProtoAndCheckerMaker { "(Tensor), " "a vector to describe the length of every input sequence for " "sub sequence item."); - AddOutput("Out", - "(LoDTensor), the output of SequenceSliceOp."); + AddOutput("Out", "(LoDTensor), the output of SequenceSliceOp."); AddComment(R"DOC( Sequence slice operator diff --git a/python/paddle/v2/dataset/uci_housing.py b/python/paddle/v2/dataset/uci_housing.py index 98b97c75ca72f..f10bf7e42a1ea 100644 --- a/python/paddle/v2/dataset/uci_housing.py +++ b/python/paddle/v2/dataset/uci_housing.py @@ -38,6 +38,7 @@ URL_MODEL = 'https://github.com/PaddlePaddle/book/raw/develop/01.fit_a_line/fit_a_line.tar' MD5_MODEL = '52fc3da8ef3937822fcdd87ee05c0c9b' + def feature_range(maximums, minimums): import matplotlib matplotlib.use('Agg') @@ -114,7 +115,8 @@ def reader(): def model(): - tar_file = paddle.v2.dataset.common.download(URL_MODEL, 'fit_a_line.tar', MD5_MODEL) + tar_file = paddle.v2.dataset.common.download(URL_MODEL, 'fit_a_line.tar', + MD5_MODEL) with open(tar_file, 'r') as f: parameters = Parameters.from_tar(f) return parameters diff --git a/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py b/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py index c96d186ffe8d9..8ca45134dc01e 100644 --- a/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py +++ b/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py @@ -35,6 +35,13 @@ accuracy = fluid.evaluator.Accuracy(input=predict, label=label) +inference_program = fluid.default_main_program().clone() +test_accuracy = fluid.evaluator.Accuracy( + input=predict, label=label, main_program=inference_program) +test_target = [avg_cost] + test_accuracy.metrics + test_accuracy.states +inference_program = fluid.io.get_inference_program( + test_target, main_program=inference_program) + train_reader = paddle.batch( paddle.reader.shuffle( paddle.dataset.mnist.train(), buf_size=8192), @@ -69,11 +76,6 @@ acc = np.array(outs[1]) pass_acc = accuracy.eval(exe) - test_accuracy = fluid.evaluator.Accuracy(input=predict, label=label) - - test_target = [avg_cost] + test_accuracy.metrics + test_accuracy.states - inference_program = fluid.io.get_inference_program(test_target) - test_accuracy.reset(exe) for data in test_reader(): x_data = np.array(map(lambda x: x[0], data)).astype("float32") diff --git a/python/paddle/v2/fluid/tests/test_maxout_op.py b/python/paddle/v2/fluid/tests/test_maxout_op.py index 05e42f315833c..5fbed43e254b8 100644 --- a/python/paddle/v2/fluid/tests/test_maxout_op.py +++ b/python/paddle/v2/fluid/tests/test_maxout_op.py @@ -30,9 +30,7 @@ def test_check_grad(self): def init_test_case(self): self.MaxOut_forward_naive = maxout_forward_naive self.shape = [100, 6, 2, 2] - self.groups=2 - - + self.groups = 2 if __name__ == '__main__': diff --git a/python/paddle/v2/fluid/tests/test_roi_pool_op.py b/python/paddle/v2/fluid/tests/test_roi_pool_op.py index 7cedb930ca861..a28d9c7f82d37 100644 --- a/python/paddle/v2/fluid/tests/test_roi_pool_op.py +++ b/python/paddle/v2/fluid/tests/test_roi_pool_op.py @@ -4,24 +4,22 @@ import sys from op_test import OpTest + class TestROIPoolOp(OpTest): def set_data(self): self.init_test_case() self.make_rois() self.calc_roi_pool() - self.inputs = { - 'X': self.x, - 'ROIs': self.rois} - + self.inputs = {'X': self.x, 'ROIs': self.rois} + self.attrs = { 'spatial_scale': self.spatial_scale, 'pooled_height': self.pooled_height, - 'pooled_width': self.pooled_width} + 'pooled_width': self.pooled_width + } - self.outputs = { - 'Out': self.outs, - 'Argmax': self.argmaxes} + self.outputs = {'Out': self.outs, 'Argmax': self.argmaxes} def init_test_case(self): self.batch_size = 5 @@ -30,10 +28,9 @@ def init_test_case(self): self.width = 4 # n, c, h, w - self.x_dim = (self.batch_size, self.channels, - self.height, self.width) + self.x_dim = (self.batch_size, self.channels, self.height, self.width) - self.spatial_scale = 1.0/4.0 + self.spatial_scale = 1.0 / 4.0 self.pooled_height = 2 self.pooled_width = 2 self.rois_num = 2 @@ -41,13 +38,11 @@ def init_test_case(self): self.x = np.random.random(self.x_dim).astype('float32') def calc_roi_pool(self): - out_data = np.zeros( - (self.rois_num, self.channels, - self.pooled_height, self.pooled_width)) - argmax_data = np.zeros( - (self.rois_num, self.channels, - self.pooled_height, self.pooled_width)) - + out_data = np.zeros((self.rois_num, self.channels, self.pooled_height, + self.pooled_width)) + argmax_data = np.zeros((self.rois_num, self.channels, + self.pooled_height, self.pooled_width)) + for i in range(self.rois_num): roi = self.rois[i] roi_batch_id = roi[0] @@ -56,8 +51,8 @@ def calc_roi_pool(self): roi_end_w = int(round(roi[3] * self.spatial_scale)) roi_end_h = int(round(roi[4] * self.spatial_scale)) - roi_height = int(max(roi_end_h - roi_start_h + 1, 1)); - roi_width = int(max(roi_end_w - roi_start_w + 1, 1)); + roi_height = int(max(roi_end_h - roi_start_h + 1, 1)) + roi_width = int(max(roi_end_w - roi_start_w + 1, 1)) x_i = self.x[roi_batch_id] @@ -84,7 +79,7 @@ def calc_roi_pool(self): out_data[i, c, ph, pw] = -sys.float_info.max argmax_data[i, c, ph, pw] = -1 - + for h in range(hstart, hend): for w in range(wstart, wend): if x_i[c, h, w] > out_data[i, c, ph, pw]: @@ -104,11 +99,11 @@ def make_rois(self): y1 = np.random.random_integers( 0, self.height / self.spatial_scale - self.pooled_height) - x2 = np.random.random_integers( - x1 + self.pooled_width, self.width / self.spatial_scale) - y2 = np.random.random_integers( - y1 + self.pooled_height, self.height / self.spatial_scale) - + x2 = np.random.random_integers(x1 + self.pooled_width, + self.width / self.spatial_scale) + y2 = np.random.random_integers(y1 + self.pooled_height, + self.height / self.spatial_scale) + roi = [batch_ids[i], x1, y1, x2, y2] rois.append(roi) self.rois = np.array(rois).astype("int64") @@ -123,5 +118,6 @@ def test_check_output(self): def test_check_grad(self): self.check_grad(['X'], 'Out') + if __name__ == '__main__': unittest.main()