From 97e9dd72375258ed69fbbab39f340d23878002f5 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Wed, 8 Nov 2017 14:15:58 +0800 Subject: [PATCH 1/9] add dilation for im2col --- paddle/operators/conv_cudnn_op.cc | 2 - paddle/operators/conv_op.cc | 13 +- paddle/operators/conv_op.h | 29 +- paddle/operators/conv_transpose_op.h | 16 +- paddle/operators/math/context_project.h | 10 +- paddle/operators/math/im2col.cc | 281 +++++++++--------- paddle/operators/math/im2col.cu | 366 +++++++++++++----------- paddle/operators/math/im2col.h | 11 +- paddle/operators/math/im2col_test.cc | 18 +- 9 files changed, 395 insertions(+), 351 deletions(-) diff --git a/paddle/operators/conv_cudnn_op.cc b/paddle/operators/conv_cudnn_op.cc index 97f31bf22d707..4c65b60d2349d 100644 --- a/paddle/operators/conv_cudnn_op.cc +++ b/paddle/operators/conv_cudnn_op.cc @@ -22,8 +22,6 @@ class CudnnConvOpMaker : public Conv2DOpMaker { CudnnConvOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) : Conv2DOpMaker(proto, op_checker) { - AddAttr>("dilations", "dilations of convolution operator.") - .SetDefault(std::vector{1, 1}); AddAttr("workspace_size_MB", "workspace size for cudnn, in MB, " "workspace is a section of GPU memory which will be " diff --git a/paddle/operators/conv_op.cc b/paddle/operators/conv_op.cc index a6f65f1016592..852ac2ae37ca5 100644 --- a/paddle/operators/conv_op.cc +++ b/paddle/operators/conv_op.cc @@ -30,6 +30,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { std::vector strides = ctx->Attrs().Get>("strides"); std::vector paddings = ctx->Attrs().Get>("paddings"); int groups = ctx->Attrs().Get("groups"); + std::vector dilations = ctx->Attrs().Get>("dilations"); int input_channels = in_dims[1]; int output_channels = filter_dims[0]; @@ -54,7 +55,8 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { std::vector output_shape({in_dims[0], filter_dims[0]}); for (size_t i = 0; i < paddings.size(); ++i) { output_shape.push_back(OutputSize(in_dims[i + 2], filter_dims[i + 2], - paddings[i], strides[i])); + dilations[i], paddings[i], paddings[i], + strides[i])); } ctx->SetOutputDim("Output", framework::make_ddim(output_shape)); } @@ -90,6 +92,10 @@ Conv2DOpMaker::Conv2DOpMaker(framework::OpProto* proto, "first half of the input channels, while the second half of the filters " "is only connected to the second half of the input channels.") .SetDefault(1); + AddAttr>("dilations", + "(vector default:{1, 1}), the dilations of " + "convolution operator.") + .SetDefault(std::vector{1, 1}); AddComment(R"DOC( Convolution Operator. @@ -151,6 +157,11 @@ Conv3DOpMaker::Conv3DOpMaker(framework::OpProto* proto, "first half of the input channels, while the second half of the filters " "is only connected to the second half of the input channels.") .SetDefault(1); + AddAttr>("dilations", + "(vector default:{1, 1, 1}), the dilations of " + "convolution operator. Currently, conv3d doesn't " + "support dilation.") + .SetDefault(std::vector{1, 1, 1}); AddComment(R"DOC( Convolution3D Operator. diff --git a/paddle/operators/conv_op.h b/paddle/operators/conv_op.h index 7c1729213bf3f..2459f03a1a925 100644 --- a/paddle/operators/conv_op.h +++ b/paddle/operators/conv_op.h @@ -27,9 +27,12 @@ using Tensor = framework::Tensor; // Base convolution operator definations for other conv // like operators to reuse the implementation. 
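// A quick sanity check of the dilation-aware size formula introduced below,
// with illustrative numbers rather than values taken from the patch: a filter
// of size 3 with dilation 2 covers dilation * (filter_size - 1) + 1 = 5 input
// elements, so for input_size = 7, padding_up = padding_down = 1 and
// stride = 2 the new OutputSize returns (7 + 1 + 1 - 5) / 2 + 1 = 3, where the
// old (input_size - filter_size + 2 * padding) / stride + 1 form returned 4.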
-inline int OutputSize(int input_size, int filter_size, int padding, - int stride) { - int output_size = (input_size - filter_size + 2 * padding) / stride + 1; +inline int OutputSize(int input_size, int filter_size, int dilation, + int padding_up, int padding_down, int stride) { + int output_size = (input_size + padding_up + padding_down - + (dilation * (filter_size - 1) + 1)) / + stride + + 1; return output_size; } @@ -76,6 +79,7 @@ class GemmConvKernel : public framework::OpKernel { std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); int groups = context.Attr("groups"); + std::vector dilations = context.Attr>("dilations"); const int batch_size = static_cast(input->dims()[0]); @@ -139,9 +143,9 @@ class GemmConvKernel : public framework::OpKernel { if (filter_shape_vec.size() == 2) { // im2col math::Im2ColFunctor im2col; - im2col(context.device_context(), in_slice, col, strides[0], - strides[1], paddings[0], paddings[0], paddings[1], - paddings[1]); + im2col(context.device_context(), in_slice, col, dilations[0], + dilations[1], strides[0], strides[1], paddings[0], paddings[0], + paddings[1], paddings[1]); } else if (filter_shape_vec.size() == 3) { // vol2col math::Vol2ColFunctor vol2col; @@ -181,6 +185,7 @@ class GemmConvGradKernel : public framework::OpKernel { std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); int groups = context.Attr("groups"); + std::vector dilations = context.Attr>("dilations"); const int batch_size = static_cast(input->dims()[0]); @@ -263,9 +268,9 @@ class GemmConvGradKernel : public framework::OpKernel { if (filter_shape_vec.size() == 2) { math::Col2ImFunctor col2im; - col2im(context.device_context(), in_grad_slice, col, strides[0], - strides[1], paddings[0], paddings[0], paddings[1], - paddings[1]); + col2im(context.device_context(), in_grad_slice, col, dilations[0], + dilations[1], strides[0], strides[1], paddings[0], + paddings[0], paddings[1], paddings[1]); } else if (filter_shape_vec.size() == 3) { math::Col2VolFunctor col2vol; @@ -295,9 +300,9 @@ class GemmConvGradKernel : public framework::OpKernel { if (filter_shape_vec.size() == 2) { math::Im2ColFunctor im2col; - im2col(context.device_context(), in_slice, col, strides[0], - strides[1], paddings[0], paddings[0], paddings[1], - paddings[1]); + im2col(context.device_context(), in_slice, col, dilations[0], + dilations[1], strides[0], strides[1], paddings[0], + paddings[0], paddings[1], paddings[1]); } else if (filter_shape_vec.size() == 3) { math::Vol2ColFunctor vol2col; vol2col(context.device_context(), in_slice, col, strides[0], diff --git a/paddle/operators/conv_transpose_op.h b/paddle/operators/conv_transpose_op.h index 6c1a6220d784a..cbfad88b3982a 100644 --- a/paddle/operators/conv_transpose_op.h +++ b/paddle/operators/conv_transpose_op.h @@ -69,6 +69,9 @@ class GemmConvTransposeKernel : public framework::OpKernel { // TODO(Zhuoyuan): Paddings can be added in future. // groups will alway be disabled in conv2dtranspose. 
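// Note: the transpose kernel below does not read a "dilations" attribute yet;
// the two locals added next simply pin dilation to 1 so that the calls match
// the new im2col/col2im signatures, which now take dilation_h and dilation_w
// ahead of the strides.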
+ int dilation_h = 1; + int dilation_w = 1; + const int batch_size = static_cast(input->dims()[0]); // input_shape_vec: {h, w} or {d, h, w} @@ -140,8 +143,8 @@ class GemmConvTransposeKernel : public framework::OpKernel { // from (c * k_h * k_w, h * w) to (c, o_h, o_w) math::Col2ImFunctor col2im; - col2im(context.device_context(), output_batch, col, strides[0], - strides[1], 0, 0, 0, 0); + col2im(context.device_context(), output_batch, col, dilation_h, + dilation_w, strides[0], strides[1], 0, 0, 0, 0); } else if (filter_shape_vec.size() == 3) { // col2vol: col_matrix -> dy // from (c * k_d * k_h * k_w, d * h * w) to (c, o_d, o_h, o_w) @@ -174,6 +177,9 @@ class GemmConvTransposeGradKernel : public framework::OpKernel { // Actually, no paddings and groups allowed in conv transpose. std::vector paddings = context.Attr>("paddings"); + int dilation_h = 1; + int dilation_w = 1; + const int batch_size = static_cast(input->dims()[0]); // input_shape_vec: {h, w} or {d, h, w} @@ -248,9 +254,9 @@ class GemmConvTransposeGradKernel : public framework::OpKernel { // im2col: dy -> col matrix // from (c, o_h, o_w) to (c * k_h * k_w, h * w) math::Im2ColFunctor im2col; - im2col(context.device_context(), output_grad_batch, col, strides[0], - strides[1], paddings[0], paddings[0], paddings[1], - paddings[1]); + im2col(context.device_context(), output_grad_batch, col, dilation_h, + dilation_w, strides[0], strides[1], paddings[0], paddings[0], + paddings[1], paddings[1]); } else if (filter_shape_vec.size() == 3) { // vol2col: dy -> col_matrix // from (c, o_d, o_h, o_w) to (c * k_d * k_h * k_w, d * h * w) diff --git a/paddle/operators/math/context_project.h b/paddle/operators/math/context_project.h index e0283360414fb..c67d84528fdd3 100644 --- a/paddle/operators/math/context_project.h +++ b/paddle/operators/math/context_project.h @@ -95,6 +95,9 @@ class ContextProjectFunctor { math::Im2ColFunctor im2col_ocf; + int dilation_h = 1; + int dilation_w = 1; + int input_row_begin, input_row_end; int sequence_height, sequence_width; sequence_width = in.dims()[1]; @@ -124,7 +127,7 @@ class ContextProjectFunctor { sequence_width}); // input_channels, input_height, input_width in_t.Resize(framework::make_ddim(input_shape)); - im2col_ocf(context, in_t, out_t, + im2col_ocf(context, in_t, out_t, dilation_h, dilation_w, /*stride_height*/ context_stride, /*stride_width*/ 1, up_pad, down_pad, 0, 0); out_t.Resize({sequence_height, context_length * sequence_width}); @@ -204,6 +207,9 @@ class ContextProjectGradFunctor { math::Col2ImFunctor col2im_ocf; + int dilation_h = 1; + int dilation_w = 1; + int input_row_begin, input_row_end; int sequence_height, sequence_width; sequence_width = in.dims()[1]; @@ -234,7 +240,7 @@ class ContextProjectGradFunctor { sequence_width}); // input_channels, input_height, input_width in_t.Resize(framework::make_ddim(input_shape)); - col2im_ocf(context, in_t, out_t, + col2im_ocf(context, in_t, out_t, dilation_h, dilation_w, /*stride_height*/ context_stride, /*stride_width*/ 1, up_pad, down_pad, 0, 0); out_t.Resize({sequence_height, context_length * sequence_width}); diff --git a/paddle/operators/math/im2col.cc b/paddle/operators/math/im2col.cc index 3b1b0bd71dd37..b248863b4e96a 100644 --- a/paddle/operators/math/im2col.cc +++ b/paddle/operators/math/im2col.cc @@ -29,35 +29,36 @@ class Im2ColFunctor(); T* col_data = col.data(); @@ -66,19 +67,19 @@ class Im2ColFunctor= input_height || im_col_idx < 0 || - im_col_idx >= input_width) { - col_data[(c * output_height + h) * output_width + w] = T(0); - } else { - 
im_row_idx += c_im * input_height; - col_data[(c * output_height + h) * output_width + w] = - im_data[im_row_idx * input_width + im_col_idx]; - } + col_data[(c * col_height + h) * col_width + w] = + (im_row_idx < 0 || im_row_idx >= im_height || im_col_idx < 0 || + im_col_idx >= im_width) + ? static_cast(0) + : im_data[(im_row_idx + c_im * im_height) * im_width + + im_col_idx]; } } } @@ -95,35 +96,35 @@ class Col2ImFunctor { public: void operator()(const platform::DeviceContext& context, framework::Tensor& im, - const framework::Tensor& col, int stride_height, - int stride_width, int padding_up, int padding_down, - int padding_left, int padding_right) { + const framework::Tensor& col, int dilation_h, int dilation_w, + int stride_height, int stride_width, int padding_up, + int padding_down, int padding_left, int padding_right) { PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(col.dims().size() == 5); - int input_channels = im.dims()[0]; - int input_height = im.dims()[1]; - int input_width = im.dims()[2]; + int im_channels = im.dims()[0]; + int im_height = im.dims()[1]; + int im_width = im.dims()[2]; int filter_height = col.dims()[1]; int filter_width = col.dims()[2]; - int output_height = col.dims()[3]; - int output_width = col.dims()[4]; + int col_height = col.dims()[3]; + int col_width = col.dims()[4]; - PADDLE_ENFORCE_EQ( - (input_height + padding_up + padding_down - filter_height) / - stride_height + - 1, - output_height, - "Output_height and padding(padding_up, padding_down) are " - "inconsistent."); - PADDLE_ENFORCE_EQ( - (input_width + padding_left + padding_right - filter_width) / - stride_width + - 1, - output_width, - "output_width and padding(padding_left, padding_right) are " - "inconsistent."); + PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down - + ((dilation_h * (filter_height - 1) + 1))) / + stride_height + + 1, + col_height, + "Output_height and padding(padding_up, padding_down) are " + "inconsistent."); + PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right - + ((dilation_w * (filter_width - 1) + 1))) / + stride_width + + 1, + col_width, + "col_width and padding(padding_left, padding_right) are " + "inconsistent."); - int channels_col = input_channels * filter_height * filter_width; + int channels_col = im_channels * filter_height * filter_width; T* im_data = im.data(); const T* col_data = col.data(); @@ -132,16 +133,18 @@ class Col2ImFunctor= 0 && (im_row_idx) < input_height && - (im_col_idx) >= 0 && (im_col_idx) < input_width) { - im_row_idx += c_im * input_height; - im_data[im_row_idx * input_width + im_col_idx] += - col_data[(c * output_height + h) * output_width + w]; + if ((im_row_idx) >= 0 && (im_row_idx) < im_height && + (im_col_idx) >= 0 && (im_col_idx) < im_width) { + im_row_idx += c_im * im_height; + im_data[im_row_idx * im_width + im_col_idx] += + col_data[(c * col_height + h) * col_width + w]; } } } @@ -169,39 +172,38 @@ class Im2ColFunctor(); T* col_data = col.data(); - for (int col_row_idx = 0; col_row_idx < output_height; ++col_row_idx) { - for (int col_col_idx = 0; col_col_idx < output_width; ++col_col_idx) { - for (int channel = 0; channel < input_channels; ++channel) { + for (int col_row_idx = 0; col_row_idx < col_height; ++col_row_idx) { + for (int col_col_idx = 0; col_col_idx < col_width; ++col_col_idx) { + for (int channel = 0; channel < im_channels; ++channel) { for (int filter_row_idx = 0; filter_row_idx < filter_height; ++filter_row_idx) { for (int filter_col_idx = 0; filter_col_idx < filter_width; @@ -210,22 +212,21 @@ class 
Im2ColFunctor= input_height || - im_col_offset < 0 || im_col_offset >= input_width) { - col_data[col_offset] = T(0); - } else { - int im_offset = - (channel * input_height + im_row_offset) * input_width + - im_col_offset; - col_data[col_offset] = im_data[im_offset]; - } + int col_offset = + ((((col_row_idx)*col_width + col_col_idx) * im_channels + + channel) * + filter_height + + filter_row_idx) * + filter_width + + filter_col_idx; + + int im_offset = (channel * im_height + im_row_offset) * im_width + + im_col_offset; + col_data[col_offset] = + (im_row_offset < 0 || im_row_offset >= im_height || + im_col_offset < 0 || im_col_offset >= im_width) + ? static_cast(0) + : im_data[im_offset]; } } } @@ -244,40 +245,38 @@ class Col2ImFunctor { public: void operator()(const platform::DeviceContext& context, framework::Tensor& im, - const framework::Tensor& col, int stride_height, - int stride_width, int padding_up, int padding_down, - int padding_left, int padding_right) { + const framework::Tensor& col, int dilation_h, int dilation_w, + int stride_height, int stride_width, int padding_up, + int padding_down, int padding_left, int padding_right) { PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(col.dims().size() == 5); - int input_channels = im.dims()[0]; - int input_height = im.dims()[1]; - int input_width = im.dims()[2]; + int im_channels = im.dims()[0]; + int im_height = im.dims()[1]; + int im_width = im.dims()[2]; int filter_height = col.dims()[3]; int filter_width = col.dims()[4]; - int output_height = col.dims()[0]; - int output_width = col.dims()[1]; + int col_height = col.dims()[0]; + int col_width = col.dims()[1]; - PADDLE_ENFORCE_EQ( - (input_height + padding_up + padding_down - filter_height) / - stride_height + - 1, - output_height, - "Output_height and padding(padding_up, padding_down) are " - "inconsistent."); - PADDLE_ENFORCE_EQ( - (input_width + padding_left + padding_right - filter_width) / - stride_width + - 1, - output_width, - "output_width and padding(padding_left, padding_right) are " - "inconsistent."); + PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down - filter_height) / + stride_height + + 1, + col_height, + "Output_height and padding(padding_up, padding_down) are " + "inconsistent."); + PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right - filter_width) / + stride_width + + 1, + col_width, + "col_width and padding(padding_left, padding_right) are " + "inconsistent."); T* im_data = im.data(); const T* col_data = col.data(); - for (int col_row_idx = 0; col_row_idx < output_height; ++col_row_idx) { - for (int col_col_idx = 0; col_col_idx < output_width; ++col_col_idx) { - for (int channel = 0; channel < input_channels; ++channel) { + for (int col_row_idx = 0; col_row_idx < col_height; ++col_row_idx) { + for (int col_col_idx = 0; col_col_idx < col_width; ++col_col_idx) { + for (int channel = 0; channel < im_channels; ++channel) { for (int filter_row_idx = 0; filter_row_idx < filter_height; ++filter_row_idx) { for (int filter_col_idx = 0; filter_col_idx < filter_width; @@ -286,17 +285,17 @@ class Col2ImFunctor= 0 && im_row_offset < input_height && - im_col_offset >= 0 && im_col_offset < input_width) { + int col_offset = + (((col_row_idx * col_width + col_col_idx) * im_channels + + channel) * + filter_height + + filter_row_idx) * + filter_width + + filter_col_idx; + if (im_row_offset >= 0 && im_row_offset < im_height && + im_col_offset >= 0 && im_col_offset < im_width) { int im_offset = - (channel * input_height + im_row_offset) * input_width + + (channel * 
im_height + im_row_offset) * im_width + im_col_offset; im_data[im_offset] += col_data[col_offset]; } diff --git a/paddle/operators/math/im2col.cu b/paddle/operators/math/im2col.cu index 7b201fdbf3c5d..69e2abee03b34 100644 --- a/paddle/operators/math/im2col.cu +++ b/paddle/operators/math/im2col.cu @@ -20,36 +20,32 @@ namespace operators { namespace math { template -__global__ void im2col(const T* data_im, int num_outs, int height, int width, +__global__ void im2col(const T* data_im, int num_outs, int im_height, + int im_width, int dilation_h, int dilation_w, int filter_height, int filter_width, int stride_height, int stride_width, int padding_height, int padding_width, - int output_height, int output_width, T* data_col) { - int index = (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; + int col_height, int col_width, T* data_col) { + const int index = + (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; if (index < num_outs) { - int w_out = index % output_width; - index /= output_width; - int h_out = index % output_height; - int channel_in = index / output_height; + int w_out = index % col_width; + int h_out = (index / col_width) % col_height; + int channel_in = index / col_width / col_height; int channel_out = channel_in * filter_height * filter_width; - int h_in = h_out * stride_height; - int w_in = w_out * stride_width; + int h_in = h_out * stride_height - padding_height; + int w_in = w_out * stride_width - padding_width; - data_col += (channel_out * output_height + h_out) * output_width + w_out; + data_col += (channel_out * col_height + h_out) * col_width + w_out; + data_im += (channel_in * im_height + h_in) * im_width + w_in; for (int i = 0; i < filter_height; ++i) { for (int j = 0; j < filter_width; ++j) { - int rIdx = int(h_in + i); - int cIdx = int(w_in + j); - if ((rIdx - (int)padding_height) >= (int)height || - (rIdx - (int)padding_height) < 0 || - (cIdx - (int)padding_width) >= (int)width || - (cIdx - (int)padding_width) < 0) { - *data_col = 0; - } else { - rIdx = rIdx + channel_in * height - padding_height; - cIdx = cIdx - padding_width; - *data_col = data_im[rIdx * width + cIdx]; - } - data_col += output_height * output_width; + int rIdx = h_in + i * dilation_h; + int cIdx = w_in + j * dilation_w; + *data_col = + (rIdx >= im_height || rIdx < 0 || cIdx >= im_width || cIdx < 0) + ? 
0 + : data_im[i * dilation_h * im_width + j * dilation_w]; + data_col += col_height * col_width; } } } @@ -66,29 +62,36 @@ class Im2ColFunctor<<(context) .stream()>>>( - im.data(), num_outputs, input_height, input_width, filter_height, - filter_width, stride_height, stride_width, padding_up, padding_left, - output_height, output_width, col.data()); + im.data(), num_outputs, im_height, im_width, dilation_h, dilation_w, + filter_height, filter_width, stride_height, stride_width, padding_up, + padding_left, col_height, col_width, col.data()); } }; template -__global__ void col2im(size_t n, const T* data_col, size_t height, size_t width, - size_t channels, size_t filter_height, - size_t filter_width, size_t stride_height, - size_t stride_width, size_t padding_height, - size_t padding_width, size_t output_height, - size_t output_width, T* data_im) { - size_t index = +__global__ void col2im(int n, const T* data_col, int im_height, int im_width, + int dilation_h, int dilation_w, int filter_height, + int filter_width, int stride_height, int stride_width, + int padding_height, int padding_width, int col_height, + int col_width, T* data_im) { + const int index = (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; + + const int d_filter_height = dilation_h * (filter_height - 1) + 1; + const int d_filter_width = dilation_w * (filter_width - 1) + 1; + if (index < n) { T val = 0; - int w = int(index % width); - int h = int((index / width) % height); - int c = int(index / (width * height)); - if ((w - (int)padding_width) >= 0 && - (w - (int)padding_width) < (width - 2 * padding_width) && - (h - (int)padding_height) >= 0 && - (h - padding_height) < (height - 2 * padding_height)) { - // compute the start and end of the output - int w_col_start = (w < (int)filter_width) - ? 0 - : (w - int(filter_width)) / (int)stride_width + 1; - int w_col_end = - min((int)(w / (int)stride_width + 1), (int)(output_width)); - int h_col_start = (h < (int)filter_height) - ? 0 - : (h - (int)filter_height) / (int)stride_height + 1; - int h_col_end = min(int(h / stride_height + 1), int(output_height)); - for (int h_col = h_col_start; h_col < h_col_end; ++h_col) { - for (int w_col = w_col_start; w_col < w_col_end; ++w_col) { - // the col location: [c * width * height + h_out, w_out] - int c_col = int(c * filter_height * filter_width) + - (h - h_col * (int)stride_height) * (int)filter_width + - (w - w_col * (int)stride_width); - val += - data_col[(c_col * output_height + h_col) * output_width + w_col]; + int w = index % im_width; + int h = (index / im_width) % im_height; + int c = index / (im_width * im_height); + + // compute the start and end of the output + int w_col_start = + (w < d_filter_width) ? 0 : (w - d_filter_width) / stride_width + 1; + int w_col_end = min(w / stride_width + 1, col_width); + int h_col_start = + (h < d_filter_height) ? 
0 : (h - d_filter_height) / stride_height + 1; + int h_col_end = min(h / stride_height + 1, col_height); + + for (int h_col = h_col_start; h_col < h_col_end; ++h_col) { + for (int w_col = w_col_start; w_col < w_col_end; ++w_col) { + int h_off = (h - h_col * stride_height); + int w_off = (w - w_col * stride_width); + if (h_off % dilation_h == 0 && w_off % dilation_w == 0) { + h_off /= dilation_h; + w_off /= dilation_w; + int data_col_index = + (((c * filter_height + h_off) * filter_width + w_off) * + col_height + + h_col) * + col_width + + w_col; + val += data_col[data_col_index]; } } - h -= padding_height; - w -= padding_width; - data_im[c * ((width - 2 * padding_width) * - (height - 2 * padding_height)) + - h * (width - 2 * padding_width) + w] += val; } + data_im[index] = val; } } @@ -160,32 +163,36 @@ class Col2ImFunctor { public: void operator()(const platform::DeviceContext& context, framework::Tensor& im, - const framework::Tensor& col, int stride_height, - int stride_width, int padding_up, int padding_down, - int padding_left, int padding_right) { + const framework::Tensor& col, int dilation_h, int dilation_w, + int stride_height, int stride_width, int padding_up, + int padding_down, int padding_left, int padding_right) { PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(col.dims().size() == 5); - int input_channels = im.dims()[0]; - int input_height = im.dims()[1]; - int input_width = im.dims()[2]; + int im_channels = im.dims()[0]; + int im_height = im.dims()[1]; + int im_width = im.dims()[2]; int filter_height = col.dims()[1]; int filter_width = col.dims()[2]; - int output_height = col.dims()[3]; - int output_width = col.dims()[4]; - - PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) / - stride_height + - 1 == - output_height); - PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) / - stride_width + - 1 == - output_width); - - size_t num_kernels = input_channels * - (input_height + padding_up + padding_down) * - (input_width + padding_left + padding_right); + int col_height = col.dims()[3]; + int col_width = col.dims()[4]; + + PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down - + (dilation_h * (filter_height - 1) + 1)) / + stride_height + + 1, + col_height, + "Output_height and padding(padding_up, padding_down) are " + "inconsistent."); + PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right - + (dilation_w * (filter_width - 1) + 1)) / + stride_width + + 1, + col_width, + "col_width and padding(padding_left, padding_right) are " + "inconsistent."); + + size_t num_kernels = im_channels * im_height * im_width; size_t blocks = (num_kernels + 1024 - 1) / 1024; size_t block_x = 512; @@ -198,10 +205,9 @@ class Col2ImFunctor<<(context) .stream()>>>( - num_kernels, col.data(), input_height + padding_up + padding_down, - input_width + padding_left + padding_left, input_channels, + num_kernels, col.data(), im_height, im_width, dilation_h, dilation_w, filter_height, filter_width, stride_height, stride_width, padding_up, - padding_left, output_height, output_width, im.data()); + padding_left, col_height, col_width, im.data()); } }; @@ -215,33 +221,32 @@ template class Col2ImFunctor; template -__global__ void im2colOCF(const T* im_data, T* col_data, int input_channels, - int input_height, int input_width, int filter_height, +__global__ void im2colOCF(const T* im_data, T* col_data, int im_channels, + int im_height, int im_width, int filter_height, int filter_width, int stride_height, int stride_width, - int padding_height, int 
padding_width, - int output_height, int output_width) { + int padding_height, int padding_width, int col_height, + int col_width) { int swid = blockIdx.x; int shid = blockIdx.y; - for (int channelid = threadIdx.z; channelid < input_channels; + for (int channelid = threadIdx.z; channelid < im_channels; channelid += blockDim.z) { for (int idy = threadIdx.y; idy < filter_height; idy += blockDim.y) { for (int idx = threadIdx.x; idx < filter_width; idx += blockDim.x) { int width_offset = idx + swid * stride_width - padding_width; int height_offset = idy + shid * stride_height - padding_height; - int im_offset = width_offset + height_offset * input_width + - channelid * input_height * input_width; + int im_offset = width_offset + height_offset * im_width + + channelid * im_height * im_width; int col_offset = idx + idy * filter_width + channelid * filter_height * filter_width + - (shid * output_width + swid) * - (input_channels * filter_height * filter_width); - - if (height_offset >= input_height || height_offset < 0 || - width_offset >= input_width || width_offset < 0) { - col_data[col_offset] = T(0); - } else { - col_data[col_offset] = im_data[im_offset]; - } + (shid * col_width + swid) * + (im_channels * filter_height * filter_width); + + col_data[col_offset] = + (height_offset >= im_height || height_offset < 0 || + width_offset >= im_width || width_offset < 0) + ? T(0) + : im_data[im_offset]; } } } @@ -258,26 +263,33 @@ class Im2ColFunctor<<(context) .stream()>>>( - im.data(), col.data(), input_channels, input_height, input_width, + im.data(), col.data(), im_channels, im_height, im_width, filter_height, filter_width, stride_height, stride_width, padding_up, - padding_left, output_height, output_width); + padding_left, col_height, col_width); } }; template -__global__ void col2imOCF(T* im_data, const T* col_data, int input_channels, - int input_height, int input_width, int filter_height, +__global__ void col2imOCF(T* im_data, const T* col_data, int im_channels, + int im_height, int im_width, int filter_height, int filter_width, int stride_height, int stride_width, - int padding_height, int padding_width, - int output_height, int output_width) { + int padding_height, int padding_width, int col_height, + int col_width) { int swid = blockIdx.x; int shid = blockIdx.y; - for (int channelid = threadIdx.z; channelid < input_channels; + for (int channelid = threadIdx.z; channelid < im_channels; channelid += blockDim.z) { for (int idy = threadIdx.y; idy < filter_height; idy += blockDim.y) { for (int idx = threadIdx.x; idx < filter_width; idx += blockDim.x) { int width_offset = idx + swid * stride_width - padding_width; int height_offset = idy + shid * stride_height - padding_height; - int im_offset = width_offset + height_offset * input_width + - channelid * input_height * input_width; + int im_offset = width_offset + height_offset * im_width + + channelid * im_height * im_width; int col_offset = idx + idy * filter_width + channelid * filter_height * filter_width + - (shid * output_width + swid) * - (input_channels * filter_height * filter_width); + (shid * col_width + swid) * + (im_channels * filter_height * filter_width); - if (height_offset >= 0 && height_offset < input_height && - width_offset >= 0 && width_offset < input_width) { + if (height_offset >= 0 && height_offset < im_height && + width_offset >= 0 && width_offset < im_width) { paddle::platform::CudaAtomicAdd(im_data + im_offset, col_data[col_offset]); } @@ -350,27 +361,33 @@ class Col2ImFunctor { public: void operator()(const 
platform::DeviceContext& context, framework::Tensor& im, - const framework::Tensor& col, int stride_height, - int stride_width, int padding_up, int padding_down, - int padding_left, int padding_right) { + const framework::Tensor& col, int dilation_h, int dilation_w, + int stride_height, int stride_width, int padding_up, + int padding_down, int padding_left, int padding_right) { PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(col.dims().size() == 5); - int input_channels = im.dims()[0]; - int input_height = im.dims()[1]; - int input_width = im.dims()[2]; + int im_channels = im.dims()[0]; + int im_height = im.dims()[1]; + int im_width = im.dims()[2]; int filter_height = col.dims()[3]; int filter_width = col.dims()[4]; - int output_height = col.dims()[0]; - int output_width = col.dims()[1]; - - PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) / - stride_height + - 1 == - output_height); - PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) / - stride_width + - 1 == - output_width); + int col_height = col.dims()[0]; + int col_width = col.dims()[1]; + + PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down - + (dilation_h * (filter_height - 1) + 1)) / + stride_height + + 1, + col_height, + "Output_height and padding(padding_up, padding_down) are " + "inconsistent."); + PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right - + (dilation_w * (filter_width - 1) + 1)) / + stride_width + + 1, + col_width, + "col_width and padding(padding_left, padding_right) are " + "inconsistent."); int block_dim_x = 0; int block_dim_y = 0; @@ -389,15 +406,14 @@ class Col2ImFunctor<<(context) .stream()>>>( - im.data(), col.data(), input_channels, input_height, input_width, + im.data(), col.data(), im_channels, im_height, im_width, filter_height, filter_width, stride_height, stride_width, padding_up, - padding_left, output_height, output_width); + padding_left, col_height, col_width); } }; diff --git a/paddle/operators/math/im2col.h b/paddle/operators/math/im2col.h index c736d4fa523c2..d1c9595a328d3 100644 --- a/paddle/operators/math/im2col.h +++ b/paddle/operators/math/im2col.h @@ -74,17 +74,18 @@ class Im2ColFunctor { public: void operator()(const platform::DeviceContext& context, const framework::Tensor& im, framework::Tensor& col, - int stride_height, int stride_width, int padding_up, - int padding_down, int padding_left, int padding_right); + int dilation_h, int dilation_w, int stride_height, + int stride_width, int padding_up, int padding_down, + int padding_left, int padding_right); }; template class Col2ImFunctor { public: void operator()(const platform::DeviceContext& context, framework::Tensor& im, - const framework::Tensor& col, int stride_height, - int stride_width, int padding_up, int padding_down, - int padding_left, int padding_right); + const framework::Tensor& col, int dilation_h, int dilation_w, + int stride_height, int stride_width, int padding_up, + int padding_down, int padding_left, int padding_right); }; } // namespace math diff --git a/paddle/operators/math/im2col_test.cc b/paddle/operators/math/im2col_test.cc index 5763782c4edec..3385fe8721cb4 100644 --- a/paddle/operators/math/im2col_test.cc +++ b/paddle/operators/math/im2col_test.cc @@ -47,6 +47,8 @@ void testIm2col() { int filter_size = 2; int stride = 1; int padding = 0; + int dilation_h = 1; + int dilation_w = 1; int output_height = (input_height - filter_size + 2 * padding) / stride + 1; int output_width = (input_width - filter_size + 2 * padding) / stride + 1; float* 
input_ptr = input_tmp.mutable_data( @@ -85,10 +87,10 @@ void testIm2col() { paddle::operators::math::ColFormat::kOCF, Place, float> im2col_ocf; - im2col(*context, input, output_cfo, stride, stride, padding, padding, padding, - padding); - im2col_ocf(*context, input, output_ocf, stride, stride, padding, padding, - padding, padding); + im2col(*context, input, output_cfo, dilation_h, dilation_w, stride, stride, + padding, padding, padding, padding); + im2col_ocf(*context, input, output_ocf, dilation_h, dilation_w, stride, + stride, padding, padding, padding, padding); float out_cfo_data[] = {0, 1, 1, 2, 3, 4, 4, 5}; float out_ocf_data[] = {0, 1, 3, 4, 1, 2, 4, 5}; @@ -131,8 +133,8 @@ void testIm2col() { input.CopyFrom(input_tmp, *place, *context); } - col2im(*context, input, output_cfo, stride, stride, padding, padding, padding, - padding); + col2im(*context, input, output_cfo, dilation_h, dilation_w, stride, stride, + padding, padding, padding, padding); float* in_ptr; if (paddle::platform::is_cpu_place(*place)) { @@ -153,8 +155,8 @@ void testIm2col() { input.CopyFrom(input_tmp, *place, *context); } - col2im_ocf(*context, input, output_ocf, stride, stride, padding, padding, - padding, padding); + col2im_ocf(*context, input, output_ocf, dilation_h, dilation_w, stride, + stride, padding, padding, padding, padding); if (paddle::platform::is_cpu_place(*place)) { in_ptr = input.data(); From b6f9ba484ee285b75d40272f8a2f48267fb3284c Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Wed, 8 Nov 2017 18:19:41 +0800 Subject: [PATCH 2/9] fix conv2d doc --- paddle/operators/conv_op.cc | 14 ++++++++++---- python/paddle/v2/framework/tests/test_conv2d_op.py | 5 ++++- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/paddle/operators/conv_op.cc b/paddle/operators/conv_op.cc index 852ac2ae37ca5..a848b9b49cd2f 100644 --- a/paddle/operators/conv_op.cc +++ b/paddle/operators/conv_op.cc @@ -54,6 +54,12 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { std::vector output_shape({in_dims[0], filter_dims[0]}); for (size_t i = 0; i < paddings.size(); ++i) { + PADDLE_ENFORCE(in_dims[i + 2] + 2 * paddings[i] - + (dilations[i] * (filter_dims[i + 2] - 1) + 1) > + 0, + "Due to the settings of paddings, filter_dims and " + "dilations, the output size is less than 0, please check " + "again."); output_shape.push_back(OutputSize(in_dims[i + 2], filter_dims[i + 2], dilations[i], paddings[i], paddings[i], strides[i])); @@ -100,11 +106,11 @@ Conv2DOpMaker::Conv2DOpMaker(framework::OpProto* proto, Convolution Operator. The convolution operation calculates the output based on the input, filter -and strides, paddings, groups parameters. The size of each dimension of the +and strides, paddings, groups, dilations parameters. The size of each dimension of the parameters is checked in the infer-shape. Input(Input, Filter) and output(Output) are in NCHW format. Where N is batch size, C is the number of channels, H is the height of the feature, and W is -the width of the feature. Parameters(ksize, strides, paddings) are two elements. +the width of the feature. Parameters(ksize, strides, paddings, dilations) are two elements. These two elements represent height and width, respectively. The input(X) size and output(Out) size may be different. @@ -115,8 +121,8 @@ The input(X) size and output(Out) size may be different. 
Output: Output shape: (N, C_out, H_out, W_out) where - H_out = (H_in - filter_size[0] + 2 * paddings[0]) / strides[0] + 1; - W_out = (W_in - filter_size[1] + 2 * paddings[1]) / strides[1] + 1; + H_out = (H_in + 2 * paddings[0] - (dilations[0]*(filter_size[0] - 1) + 1)) / strides[0] + 1; + W_out = (W_in + 2 * paddings[1] - (dilations[1]*(filter_size[1] - 1) + 1)) / strides[1] + 1; )DOC"); } diff --git a/python/paddle/v2/framework/tests/test_conv2d_op.py b/python/paddle/v2/framework/tests/test_conv2d_op.py index 04ae7f294c27f..f3f3930dab00a 100644 --- a/python/paddle/v2/framework/tests/test_conv2d_op.py +++ b/python/paddle/v2/framework/tests/test_conv2d_op.py @@ -39,6 +39,7 @@ class TestConv2dOp(OpTest): def setUp(self): self.init_op_type() self.init_group() + self.init_dilation() self.init_test_case() conv2d_param = {'stride': self.stride, 'pad': self.pad} @@ -80,12 +81,14 @@ def test_check_grad_no_input(self): def init_test_case(self): self.pad = [0, 0] self.stride = [1, 1] - self.dilations = [1, 1] self.input_size = [2, 3, 5, 5] # NCHW assert np.mod(self.input_size[1], self.groups) == 0 f_c = self.input_size[1] / self.groups self.filter_size = [6, f_c, 3, 3] + def init_dilation(self): + self.dilations = [1, 1] + def init_group(self): self.groups = 1 From 21ce704247b53e08cb092a7602f351464892f528 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Thu, 9 Nov 2017 11:02:04 +0800 Subject: [PATCH 3/9] refine conv2d for filter size:(1,1) --- paddle/operators/conv_op.h | 256 ++++++++++++------ .../v2/framework/tests/test_conv2d_op.py | 19 ++ 2 files changed, 192 insertions(+), 83 deletions(-) diff --git a/paddle/operators/conv_op.h b/paddle/operators/conv_op.h index 2459f03a1a925..8e9f3b0b0e9fa 100644 --- a/paddle/operators/conv_op.h +++ b/paddle/operators/conv_op.h @@ -35,6 +35,18 @@ inline int OutputSize(int input_size, int filter_size, int dilation, 1; return output_size; } +inline bool NotExpand(std::vector& filter_dim, + std::vector& strides, std::vector& paddings, + std::vector& dilations) { + bool filter_1 = true, strides_1 = true, padding_0 = true, dilation_1 = true; + for (size_t j = 0; j < strides.size(); ++j) { + filter_1 &= (static_cast(filter_dim[j]) == 1); + strides_1 &= (strides[j] == 1); + padding_0 &= (paddings[j] == 0); + dilation_1 &= (dilations[j] == 1); + } + return filter_1 && strides_1 && padding_0 && dilation_1; +} // Define Op classes in .h file so that other conv // operator implementations can reuse the code. @@ -110,14 +122,17 @@ class GemmConvKernel : public framework::OpKernel { framework::DDim col_matrix_shape = framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1); + bool not_expand = NotExpand(filter_shape_vec, strides, paddings, dilations); Tensor col; - col.mutable_data(col_shape, context.GetPlace()); // col_matrix shares the same piece of data with col, // but will be reshaped into a two-dimensional matrix shape // to call the matrix multiplication interface. 
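// NotExpand() above is true only when every filter dimension is 1 and the
// strides, paddings and dilations are all 1, 0 and 1 respectively. In that
// case im2col is the identity mapping (each output pixel reads exactly one
// input pixel), so the branches below skip the col buffer and let col_matrix
// alias the input slice directly: a (C, H, W) slice is already the
// (C * 1 * 1, H * W) matrix the GEMM expects.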
Tensor col_matrix; - col_matrix.ShareDataWith(col); - col_matrix.Resize(col_matrix_shape); + if (!not_expand) { + col.mutable_data(col_shape, context.GetPlace()); + col_matrix.ShareDataWith(col); + col_matrix.Resize(col_matrix_shape); + } framework::DDim input_shape = framework::slice_ddim( input->dims(), 1, static_cast(input->dims().size())); @@ -134,31 +149,51 @@ class GemmConvKernel : public framework::OpKernel { int in_step = static_cast(input->dims()[1]) / groups; int out_step = static_cast(output->dims()[1]) / groups; - for (int i = 0; i < batch_size; i++) { - Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); - Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape); - for (int g = 0; g < groups; g++) { - Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); - - if (filter_shape_vec.size() == 2) { - // im2col - math::Im2ColFunctor im2col; - im2col(context.device_context(), in_slice, col, dilations[0], - dilations[1], strides[0], strides[1], paddings[0], paddings[0], - paddings[1], paddings[1]); - } else if (filter_shape_vec.size() == 3) { - // vol2col - math::Vol2ColFunctor vol2col; - vol2col(context.device_context(), in_slice, col, strides[0], - strides[1], strides[2], paddings[0], paddings[1], - paddings[2]); + if (!not_expand) { + for (int i = 0; i < batch_size; i++) { + Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); + Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape); + for (int g = 0; g < groups; g++) { + Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); + + if (filter_shape_vec.size() == 2) { + // im2col + math::Im2ColFunctor im2col; + im2col(context.device_context(), in_slice, col, dilations[0], + dilations[1], strides[0], strides[1], paddings[0], + paddings[0], paddings[1], paddings[1]); + } else if (filter_shape_vec.size() == 3) { + // vol2col + math::Vol2ColFunctor vol2col; + vol2col(context.device_context(), in_slice, col, strides[0], + strides[1], strides[2], paddings[0], paddings[1], + paddings[2]); + } + + // gemm + Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); + Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); + math::matmul(context.device_context(), filter_slice, false, + col_matrix, false, T(1.0), &out_slice, T(0.0)); } + } + } else { + for (int i = 0; i < batch_size; i++) { + Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); + Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape); + for (int g = 0; g < groups; g++) { + Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); - // gemm - Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); - Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); - math::matmul(context.device_context(), filter_slice, false, - col_matrix, false, T(1.0), &out_slice, T(0.0)); + col.ShareDataWith(in_slice); + col_matrix.ShareDataWith(col); + col_matrix.Resize(col_matrix_shape); + + // gemm + Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); + Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); + math::matmul(context.device_context(), filter_slice, false, + col_matrix, false, T(1.0), &out_slice, T(0.0)); + } } } } @@ -235,14 +270,17 @@ class GemmConvGradKernel : public framework::OpKernel { int in_step = static_cast(input->dims()[1]) / groups; int out_step = static_cast(output_grad->dims()[1]) / groups; + bool not_expand = NotExpand(filter_shape_vec, strides, paddings, dilations); Tensor col; // 
col_matrix shares the same piece of data with col, // but will be reshaped into a two-dimensional matrix shape // to call the matrix multiplication interface. Tensor col_matrix; - col.mutable_data(col_shape, context.GetPlace()); - col_matrix.ShareDataWith(col); - col_matrix.Resize(col_matrix_shape); + if (!not_expand) { + col.mutable_data(col_shape, context.GetPlace()); + col_matrix.ShareDataWith(col); + col_matrix.Resize(col_matrix_shape); + } math::SetConstant set_zero; @@ -250,33 +288,60 @@ class GemmConvGradKernel : public framework::OpKernel { input_grad->mutable_data(context.GetPlace()); set_zero(context.device_context(), input_grad, static_cast(0)); - for (int i = 0; i < batch_size; i++) { - Tensor out_grad_batch = - output_grad->Slice(i, i + 1).Resize(output_matrix_shape); - Tensor in_grad_batch = input_grad->Slice(i, i + 1).Resize(input_shape); - for (int g = 0; g < groups; g++) { - // gemm - Tensor out_grad_slice = - out_grad_batch.Slice(g * out_step, (g + 1) * out_step); - Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); - math::matmul(context.device_context(), filter_slice, true, - out_grad_slice, false, T(1.0), &col_matrix, - T(0.0)); - // col2im - Tensor in_grad_slice = - in_grad_batch.Slice(g * in_step, (g + 1) * in_step); - - if (filter_shape_vec.size() == 2) { - math::Col2ImFunctor col2im; - col2im(context.device_context(), in_grad_slice, col, dilations[0], - dilations[1], strides[0], strides[1], paddings[0], - paddings[0], paddings[1], paddings[1]); - - } else if (filter_shape_vec.size() == 3) { - math::Col2VolFunctor col2vol; - col2vol(context.device_context(), in_grad_slice, col, strides[0], - strides[1], strides[2], paddings[0], paddings[1], - paddings[2]); + if (!not_expand) { + for (int i = 0; i < batch_size; i++) { + Tensor out_grad_batch = + output_grad->Slice(i, i + 1).Resize(output_matrix_shape); + Tensor in_grad_batch = + input_grad->Slice(i, i + 1).Resize(input_shape); + for (int g = 0; g < groups; g++) { + // gemm + Tensor out_grad_slice = + out_grad_batch.Slice(g * out_step, (g + 1) * out_step); + Tensor filter_slice = + filter.Slice(g * out_step, (g + 1) * out_step); + math::matmul(context.device_context(), filter_slice, true, + out_grad_slice, false, T(1.0), &col_matrix, + T(0.0)); + Tensor in_grad_slice = + in_grad_batch.Slice(g * in_step, (g + 1) * in_step); + + if (filter_shape_vec.size() == 2) { + math::Col2ImFunctor col2im; + col2im(context.device_context(), in_grad_slice, col, dilations[0], + dilations[1], strides[0], strides[1], paddings[0], + paddings[0], paddings[1], paddings[1]); + + } else if (filter_shape_vec.size() == 3) { + math::Col2VolFunctor col2vol; + col2vol(context.device_context(), in_grad_slice, col, strides[0], + strides[1], strides[2], paddings[0], paddings[1], + paddings[2]); + } + } + } + } else { + for (int i = 0; i < batch_size; i++) { + Tensor out_grad_batch = + output_grad->Slice(i, i + 1).Resize(output_matrix_shape); + Tensor in_grad_batch = + input_grad->Slice(i, i + 1).Resize(input_shape); + for (int g = 0; g < groups; g++) { + // gemm + Tensor out_grad_slice = + out_grad_batch.Slice(g * out_step, (g + 1) * out_step); + Tensor filter_slice = + filter.Slice(g * out_step, (g + 1) * out_step); + + Tensor in_grad_slice = + in_grad_batch.Slice(g * in_step, (g + 1) * in_step); + + col_matrix.ShareDataWith(in_grad_slice); + col_matrix.Resize(col_matrix_shape); + + math::matmul(context.device_context(), filter_slice, true, + out_grad_slice, false, T(1.0), &col_matrix, + T(0.0)); } } } @@ -288,34 +353,59 @@ class 
GemmConvGradKernel : public framework::OpKernel { filter_grad_.Resize(filter_matrix_shape); set_zero(context.device_context(), filter_grad, static_cast(0)); - for (int i = 0; i < batch_size; i++) { - Tensor out_grad_batch = - output_grad->Slice(i, i + 1).Resize(output_matrix_shape); - Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); - for (int g = 0; g < groups; g++) { - // im2col - Tensor out_grad_slice = - out_grad_batch.Slice(g * out_step, (g + 1) * out_step); - Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); - - if (filter_shape_vec.size() == 2) { - math::Im2ColFunctor im2col; - im2col(context.device_context(), in_slice, col, dilations[0], - dilations[1], strides[0], strides[1], paddings[0], - paddings[0], paddings[1], paddings[1]); - } else if (filter_shape_vec.size() == 3) { - math::Vol2ColFunctor vol2col; - vol2col(context.device_context(), in_slice, col, strides[0], - strides[1], strides[2], paddings[0], paddings[1], - paddings[2]); + if (!not_expand) { + for (int i = 0; i < batch_size; i++) { + Tensor out_grad_batch = + output_grad->Slice(i, i + 1).Resize(output_matrix_shape); + Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); + for (int g = 0; g < groups; g++) { + // im2col + Tensor out_grad_slice = + out_grad_batch.Slice(g * out_step, (g + 1) * out_step); + Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); + + if (filter_shape_vec.size() == 2) { + math::Im2ColFunctor im2col; + im2col(context.device_context(), in_slice, col, dilations[0], + dilations[1], strides[0], strides[1], paddings[0], + paddings[0], paddings[1], paddings[1]); + } else if (filter_shape_vec.size() == 3) { + math::Vol2ColFunctor vol2col; + vol2col(context.device_context(), in_slice, col, strides[0], + strides[1], strides[2], paddings[0], paddings[1], + paddings[2]); + } + + // gemm + Tensor filter_grad_slice = + filter_grad_.Slice(g * out_step, (g + 1) * out_step); + math::matmul(context.device_context(), out_grad_slice, + false, col_matrix, true, T(1.0), + &filter_grad_slice, T(1.0)); + } + } + } else { + for (int i = 0; i < batch_size; i++) { + Tensor out_grad_batch = + output_grad->Slice(i, i + 1).Resize(output_matrix_shape); + Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); + for (int g = 0; g < groups; g++) { + // im2col + Tensor out_grad_slice = + out_grad_batch.Slice(g * out_step, (g + 1) * out_step); + Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); + + col.ShareDataWith(in_slice); + col_matrix.ShareDataWith(col); + col_matrix.Resize(col_matrix_shape); + + // gemm + Tensor filter_grad_slice = + filter_grad_.Slice(g * out_step, (g + 1) * out_step); + math::matmul(context.device_context(), out_grad_slice, + false, col_matrix, true, T(1.0), + &filter_grad_slice, T(1.0)); } - - // gemm - Tensor filter_grad_slice = - filter_grad_.Slice(g * out_step, (g + 1) * out_step); - math::matmul(context.device_context(), out_grad_slice, - false, col_matrix, true, T(1.0), - &filter_grad_slice, T(1.0)); } } } diff --git a/python/paddle/v2/framework/tests/test_conv2d_op.py b/python/paddle/v2/framework/tests/test_conv2d_op.py index f3f3930dab00a..4ba67cf006f70 100644 --- a/python/paddle/v2/framework/tests/test_conv2d_op.py +++ b/python/paddle/v2/framework/tests/test_conv2d_op.py @@ -104,6 +104,25 @@ def init_op_type(self): self.op_type = "conv2d" +class TestWith1x1(TestConv2dOp): + def init_test_case(self): + self.pad = [0, 0] + self.stride = [1, 1] + self.input_size = [2, 3, 5, 5] # NCHW + assert np.mod(self.input_size[1], 
self.groups) == 0 + f_c = self.input_size[1] / self.groups + self.filter_size = [6, f_c, 1, 1] + + def init_dilation(self): + self.dilations = [1, 1] + + def init_group(self): + self.groups = 3 + + def init_op_type(self): + self.op_type = "conv2d" + + #----------------Conv2dCudnn---------------- From 93551bd232dacdc4afccb392f507eb48747c2978 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Thu, 9 Nov 2017 15:00:48 +0800 Subject: [PATCH 4/9] refine unit test (Add dilation) --- paddle/operators/math/im2col.cc | 12 ++-- .../v2/framework/tests/test_conv2d_op.py | 63 +++++++++++++++---- 2 files changed, 56 insertions(+), 19 deletions(-) diff --git a/paddle/operators/math/im2col.cc b/paddle/operators/math/im2col.cc index b248863b4e96a..2af55fa71f86a 100644 --- a/paddle/operators/math/im2col.cc +++ b/paddle/operators/math/im2col.cc @@ -73,13 +73,13 @@ class Im2ColFunctor= im_height || im_col_idx < 0 || - im_col_idx >= im_width) - ? static_cast(0) - : im_data[(im_row_idx + c_im * im_height) * im_width + - im_col_idx]; + col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height || + im_col_idx < 0 || im_col_idx >= im_width) + ? static_cast(0) + : im_data[im_idx]; } } } diff --git a/python/paddle/v2/framework/tests/test_conv2d_op.py b/python/paddle/v2/framework/tests/test_conv2d_op.py index 4ba67cf006f70..907b52c405d9e 100644 --- a/python/paddle/v2/framework/tests/test_conv2d_op.py +++ b/python/paddle/v2/framework/tests/test_conv2d_op.py @@ -10,23 +10,33 @@ def conv2d_forward_naive(input, filter, group, conv_param): assert np.mod(out_c, group) == 0 sub_out_c = out_c / group - stride, pad = conv_param['stride'], conv_param['pad'] - out_h = 1 + (in_h + 2 * pad[0] - f_h) / stride[0] - out_w = 1 + (in_w + 2 * pad[1] - f_w) / stride[1] + stride, pad, dilation = conv_param['stride'], conv_param['pad'], conv_param[ + 'dilation'] + out_h = 1 + (in_h + 2 * pad[0] - (dilation[0] * (f_h - 1) + 1)) / stride[0] + out_w = 1 + (in_w + 2 * pad[1] - (dilation[1] * (f_w - 1) + 1)) / stride[1] out = np.zeros((in_n, out_c, out_h, out_w)) + d_bolck_w = (dilation[0] * (f_h - 1) + 1) + d_bolck_h = (dilation[1] * (f_w - 1) + 1) + input_pad = np.pad(input, ((0, ), (0, ), (pad[0], ), (pad[1], )), mode='constant', constant_values=0) + + filter_dilation = np.zeros((out_c, f_c, d_bolck_h, d_bolck_w)) + filter_dilation[:, :, 0:d_bolck_h:dilation[0], 0:d_bolck_w:dilation[ + 1]] = filter + for i in range(out_h): for j in range(out_w): for g in range(group): input_pad_masked = \ input_pad[:, g * f_c:(g + 1) * f_c, - i * stride[0]:i * stride[0] + f_h, - j * stride[1]:j * stride[1] + f_w] + i * stride[0]:i * stride[0] + d_bolck_h, + j * stride[1]:j * stride[1] + d_bolck_w] - f_sub = filter[g * sub_out_c:(g + 1) * sub_out_c, :, :, :] + f_sub = filter_dilation[g * sub_out_c:(g + 1) * + sub_out_c, :, :, :] for k in range(sub_out_c): out[:, g * sub_out_c + k, i, j] = \ np.sum(input_pad_masked * f_sub[k, :, :, :], @@ -42,7 +52,11 @@ def setUp(self): self.init_dilation() self.init_test_case() - conv2d_param = {'stride': self.stride, 'pad': self.pad} + conv2d_param = { + 'stride': self.stride, + 'pad': self.pad, + 'dilation': self.dilations + } input = np.random.random(self.input_size).astype("float32") filter = np.random.random(self.filter_size).astype("float32") output = conv2d_forward_naive(input, filter, self.groups, @@ -123,24 +137,47 @@ def init_op_type(self): self.op_type = "conv2d" -#----------------Conv2dCudnn---------------- +class TestWithDilation(TestConv2dOp): + def init_test_case(self): + self.pad = [0, 0] + self.stride = 
[1, 1] + self.input_size = [2, 3, 10, 10] # NCHW + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] / self.groups + self.filter_size = [6, f_c, 3, 3] + def init_dilation(self): + self.dilations = [2, 2] -class TestCudnn(TestConv2dOp): def init_group(self): - self.groups = 1 + self.groups = 3 + def init_op_type(self): + self.op_type = "conv2d" + + +#----------------Conv2dCudnn---------------- + + +class TestCudnn(TestConv2dOp): def init_op_type(self): self.op_type = "conv_cudnn" -class TestCudnnWithGroup(TestConv2dOp): - def init_group(self): - self.groups = 3 +class TestCudnnWithGroup(TestWithGroup): + def init_op_type(self): + self.op_type = "conv_cudnn" + +class TestCudnnWith1x1(TestWith1x1): def init_op_type(self): self.op_type = "conv_cudnn" +# cudnn v5 does not support dilation conv. +# class TestCudnnWithDilation(TestWithDilation): +# def init_op_type(self): +# self.op_type = "conv_cudnn" + if __name__ == '__main__': unittest.main() From 271fc9c1198e90813fee647b7020ee752aae549a Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Fri, 10 Nov 2017 10:25:44 +0800 Subject: [PATCH 5/9] Add dilation for vol2col --- paddle/operators/conv_op.h | 15 +-- paddle/operators/conv_transpose_op.h | 13 ++- paddle/operators/math/im2col.cu | 1 + paddle/operators/math/vol2col.cc | 80 ++++++++++++--- paddle/operators/math/vol2col.cu | 139 +++++++++++++++++++------- paddle/operators/math/vol2col.h | 2 + paddle/operators/math/vol2col_test.cc | 9 +- 7 files changed, 189 insertions(+), 70 deletions(-) diff --git a/paddle/operators/conv_op.h b/paddle/operators/conv_op.h index 8e9f3b0b0e9fa..af2c8fb163eca 100644 --- a/paddle/operators/conv_op.h +++ b/paddle/operators/conv_op.h @@ -165,9 +165,9 @@ class GemmConvKernel : public framework::OpKernel { } else if (filter_shape_vec.size() == 3) { // vol2col math::Vol2ColFunctor vol2col; - vol2col(context.device_context(), in_slice, col, strides[0], - strides[1], strides[2], paddings[0], paddings[1], - paddings[2]); + vol2col(context.device_context(), in_slice, col, dilations[0], + dilations[1], dilations[2], strides[0], strides[1], + strides[2], paddings[0], paddings[1], paddings[2]); } // gemm @@ -314,7 +314,8 @@ class GemmConvGradKernel : public framework::OpKernel { } else if (filter_shape_vec.size() == 3) { math::Col2VolFunctor col2vol; - col2vol(context.device_context(), in_grad_slice, col, strides[0], + col2vol(context.device_context(), in_grad_slice, col, + dilations[0], dilations[1], dilations[2], strides[0], strides[1], strides[2], paddings[0], paddings[1], paddings[2]); } @@ -371,9 +372,9 @@ class GemmConvGradKernel : public framework::OpKernel { paddings[0], paddings[1], paddings[1]); } else if (filter_shape_vec.size() == 3) { math::Vol2ColFunctor vol2col; - vol2col(context.device_context(), in_slice, col, strides[0], - strides[1], strides[2], paddings[0], paddings[1], - paddings[2]); + vol2col(context.device_context(), in_slice, col, dilations[0], + dilations[1], dilations[2], strides[0], strides[1], + strides[2], paddings[0], paddings[1], paddings[2]); } // gemm diff --git a/paddle/operators/conv_transpose_op.h b/paddle/operators/conv_transpose_op.h index cbfad88b3982a..18ca6b20e0349 100644 --- a/paddle/operators/conv_transpose_op.h +++ b/paddle/operators/conv_transpose_op.h @@ -69,6 +69,7 @@ class GemmConvTransposeKernel : public framework::OpKernel { // TODO(Zhuoyuan): Paddings can be added in future. // groups will alway be disabled in conv2dtranspose. 
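// As in the 2D transpose kernel earlier in the series, dilation stays pinned
// to 1 here; the three depth/height/width locals added next only satisfy the
// new vol2col/col2vol signatures, which take the dilations ahead of the
// strides.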
+ int dilaiton_d = 1; int dilation_h = 1; int dilation_w = 1; @@ -149,8 +150,9 @@ class GemmConvTransposeKernel : public framework::OpKernel { // col2vol: col_matrix -> dy // from (c * k_d * k_h * k_w, d * h * w) to (c, o_d, o_h, o_w) math::Col2VolFunctor col2vol; - col2vol(context.device_context(), output_batch, col, strides[0], - strides[1], strides[2], 0, 0, 0); + col2vol(context.device_context(), output_batch, col, dilaiton_d, + dilation_h, dilation_w, strides[0], strides[1], strides[2], 0, + 0, 0); } } } @@ -177,6 +179,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel { // Actually, no paddings and groups allowed in conv transpose. std::vector paddings = context.Attr>("paddings"); + int dilaiton_d = 1; int dilation_h = 1; int dilation_w = 1; @@ -261,9 +264,9 @@ class GemmConvTransposeGradKernel : public framework::OpKernel { // vol2col: dy -> col_matrix // from (c, o_d, o_h, o_w) to (c * k_d * k_h * k_w, d * h * w) math::Vol2ColFunctor vol2col; - vol2col(context.device_context(), output_grad_batch, col, strides[0], - strides[1], strides[2], paddings[0], paddings[1], - paddings[2]); + vol2col(context.device_context(), output_grad_batch, col, dilaiton_d, + dilation_h, dilation_w, strides[0], strides[1], strides[2], + paddings[0], paddings[1], paddings[2]); } if (input_grad) { diff --git a/paddle/operators/math/im2col.cu b/paddle/operators/math/im2col.cu index 69e2abee03b34..9da427fdf1477 100644 --- a/paddle/operators/math/im2col.cu +++ b/paddle/operators/math/im2col.cu @@ -145,6 +145,7 @@ __global__ void col2im(int n, const T* data_col, int im_height, int im_width, h_col) * col_width + w_col; + val += data_col[data_col_index]; } } diff --git a/paddle/operators/math/vol2col.cc b/paddle/operators/math/vol2col.cc index e9718a0473815..d383ee81526ae 100644 --- a/paddle/operators/math/vol2col.cc +++ b/paddle/operators/math/vol2col.cc @@ -29,6 +29,7 @@ class Vol2ColFunctor { public: void operator()(const platform::DeviceContext& context, const framework::Tensor& vol, framework::Tensor& col, + int dilation_d, int dilation_h, int dilation_w, int stride_depth, int stride_height, int stride_width, int padding_depth, int padding_height, int padding_width) const { @@ -48,6 +49,28 @@ class Vol2ColFunctor { int channels_col = input_channels * filter_depth * filter_height * filter_width; + PADDLE_ENFORCE_EQ((input_depth + 2 * padding_depth - + ((dilation_d * (filter_depth - 1) + 1))) / + stride_depth + + 1, + output_depth, + "input_depth and output_depth are " + "Mismatching."); + PADDLE_ENFORCE_EQ((input_height + 2 * padding_height - + ((dilation_h * (filter_height - 1) + 1))) / + stride_height + + 1, + output_height, + "input_height and output_height are " + "Mismatching."); + PADDLE_ENFORCE_EQ((input_width + 2 * padding_width - + ((dilation_w * (filter_width - 1) + 1))) / + stride_width + + 1, + output_width, + "input_width and output_width are " + "Mismatching."); + const T* vol_data = vol.data(); T* col_data = col.data(); @@ -57,24 +80,25 @@ class Vol2ColFunctor { int d_offset = (c / filter_width / filter_height) % filter_depth; int c_in = c / filter_width / filter_height / filter_depth; for (int d = 0; d < output_depth; ++d) { - int d_pad = d * stride_depth - padding_depth + d_offset; + int d_pad = d * stride_depth - padding_depth + d_offset * dilation_d; for (int h = 0; h < output_height; ++h) { - int h_pad = h * stride_height - padding_height + h_offset; + int h_pad = + h * stride_height - padding_height + h_offset * dilation_h; for (int w = 0; w < output_width; ++w) { - int 
w_pad = w * stride_width - padding_width + w_offset; + int w_pad = + w * stride_width - padding_width + w_offset * dilation_w; int col_idx = ((c * output_depth + d) * output_height + h) * output_width + w; - if (h_pad < 0 || h_pad >= input_height || w_pad < 0 || - w_pad >= input_width || d_pad < 0 || d_pad >= input_depth) { - col_data[col_idx] = static_cast(0); - } else { - int vol_idx = - ((c_in * input_depth + d_pad) * input_height + h_pad) * - input_width + - w_pad; - col_data[col_idx] = vol_data[vol_idx]; - } + int vol_idx = + ((c_in * input_depth + d_pad) * input_height + h_pad) * + input_width + + w_pad; + col_data[col_idx] = + (h_pad < 0 || h_pad >= input_height || w_pad < 0 || + w_pad >= input_width || d_pad < 0 || d_pad >= input_depth) + ? static_cast(0) + : vol_data[vol_idx]; } } } @@ -93,6 +117,7 @@ class Col2VolFunctor { public: void operator()(const platform::DeviceContext& context, framework::Tensor& vol, const framework::Tensor& col, + int dilation_d, int dilation_h, int dilation_w, int stride_depth, int stride_height, int stride_width, int padding_depth, int padding_height, int padding_width) const { @@ -112,6 +137,27 @@ class Col2VolFunctor { int channels_col = input_channels * filter_depth * filter_height * filter_width; + PADDLE_ENFORCE_EQ((input_depth + 2 * padding_depth - + ((dilation_d * (filter_depth - 1) + 1))) / + stride_depth + + 1, + output_depth, + "input_depth and output_depth are " + "Mismatching."); + PADDLE_ENFORCE_EQ((input_height + 2 * padding_height - + ((dilation_h * (filter_height - 1) + 1))) / + stride_height + + 1, + output_height, + "input_height and output_height are " + "Mismatching."); + PADDLE_ENFORCE_EQ((input_width + 2 * padding_width - + ((dilation_w * (filter_width - 1) + 1))) / + stride_width + + 1, + output_width, + "input_width and output_width are " + "Mismatching."); T* vol_data = vol.data(); const T* col_data = col.data(); @@ -121,11 +167,13 @@ class Col2VolFunctor { int d_offset = (c / filter_width / filter_height) % filter_depth; int cIm = c / filter_width / filter_height / filter_depth; for (int d = 0; d < output_depth; ++d) { - int d_pad = d * stride_depth - padding_depth + d_offset; + int d_pad = d * stride_depth - padding_depth + d_offset * dilation_d; for (int h = 0; h < output_height; ++h) { - int h_pad = h * stride_height - padding_height + h_offset; + int h_pad = + h * stride_height - padding_height + h_offset * dilation_h; for (int w = 0; w < output_width; ++w) { - int w_pad = w * stride_width - padding_width + w_offset; + int w_pad = + w * stride_width - padding_width + w_offset * dilation_w; if (h_pad >= 0 && h_pad < input_height && w_pad >= 0 && w_pad < input_width && d_pad >= 0 && d_pad < input_depth) { diff --git a/paddle/operators/math/vol2col.cu b/paddle/operators/math/vol2col.cu index 27b11fb237575..080d3e5466704 100644 --- a/paddle/operators/math/vol2col.cu +++ b/paddle/operators/math/vol2col.cu @@ -21,11 +21,12 @@ namespace math { template __global__ void vol2col(int num_kernels, const T* data_vol, int depth, - int height, int width, int filter_depth, - int filter_height, int filter_width, int stride_depth, - int stride_height, int stride_width, int padding_depth, - int padding_height, int padding_width, int output_detph, - int output_height, int output_width, T* data_col) { + int height, int width, int dilation_d, int dilation_h, + int dilation_w, int filter_depth, int filter_height, + int filter_width, int stride_depth, int stride_height, + int stride_width, int padding_depth, int padding_height, + int 
padding_width, int output_detph, int output_height, + int output_width, T* data_col) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels; index += blockDim.x * gridDim.x) { int w_out = index % output_width; @@ -44,12 +45,14 @@ __global__ void vol2col(int num_kernels, const T* data_vol, int depth, for (int k = 0; k < filter_depth; ++k) { for (int i = 0; i < filter_height; ++i) { for (int j = 0; j < filter_width; ++j) { - int d = d_in + k; - int h = h_in + i; - int w = w_in + j; + int d = d_in + k * dilation_d; + int h = h_in + i * dilation_h; + int w = w_in + j * dilation_w; + int col_idx = (k * dilation_d * height + i * dilation_h) * width + + j * dilation_w; *data_col = (d >= 0 && d < depth && h >= 0 && h < height && w >= 0 && w < width) - ? data_vol[(k * height + i) * width + j] + ? data_vol[col_idx] : 0; data_col += output_detph * output_height * output_width; } @@ -69,6 +72,7 @@ class Vol2ColFunctor { public: void operator()(const platform::DeviceContext& context, const framework::Tensor& vol, framework::Tensor& col, + int dilation_d, int dilation_h, int dilation_w, int stride_depth, int stride_height, int stride_width, int padding_depth, int padding_height, int padding_width) const { @@ -86,6 +90,28 @@ class Vol2ColFunctor { int output_height = col.dims()[5]; int output_width = col.dims()[6]; + PADDLE_ENFORCE_EQ((input_depth + 2 * padding_depth - + ((dilation_d * (filter_depth - 1) + 1))) / + stride_depth + + 1, + output_depth, + "input_depth and output_depth are " + "Mismatching."); + PADDLE_ENFORCE_EQ((input_height + 2 * padding_height - + ((dilation_h * (filter_height - 1) + 1))) / + stride_height + + 1, + output_height, + "input_height and output_height are " + "Mismatching."); + PADDLE_ENFORCE_EQ((input_width + 2 * padding_width - + ((dilation_w * (filter_width - 1) + 1))) / + stride_width + + 1, + output_width, + "input_width and output_width are " + "Mismatching."); + int num_outputs = input_channels * output_depth * output_height * output_width; @@ -95,19 +121,25 @@ class Vol2ColFunctor { reinterpret_cast(context) .stream()>>>( num_outputs, vol.data(), input_depth, input_height, input_width, - filter_depth, filter_height, filter_width, stride_depth, stride_height, - stride_width, padding_depth, padding_height, padding_width, - output_depth, output_height, output_width, col.data()); + dilation_d, dilation_h, dilation_w, filter_depth, filter_height, + filter_width, stride_depth, stride_height, stride_width, padding_depth, + padding_height, padding_width, output_depth, output_height, + output_width, col.data()); } }; template __global__ void col2vol(int num_kernels, const T* data_col, int depth, - int height, int width, int filter_depth, - int filter_height, int filter_width, int stride_depth, - int stride_height, int stride_width, int padding_depth, - int padding_height, int padding_width, int output_detph, - int output_height, int output_width, T* data_vol) { + int height, int width, int dilation_d, int dilation_h, + int dilation_w, int filter_depth, int filter_height, + int filter_width, int stride_depth, int stride_height, + int stride_width, int padding_depth, int padding_height, + int padding_width, int output_detph, int output_height, + int output_width, T* data_vol) { + const int d_filter_depth = dilation_d * (filter_depth - 1) + 1; + const int d_filter_height = dilation_h * (filter_height - 1) + 1; + const int d_filter_width = dilation_w * (filter_width - 1) + 1; + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels; 
index += blockDim.x * gridDim.x) { T src_val = 0; @@ -115,35 +147,42 @@ __global__ void col2vol(int num_kernels, const T* data_col, int depth, int h = (index / width) % height + padding_height; int d = (index / width / height) % depth + padding_depth; int c = index / width / height / depth; + // compute the start and end of the output int w_col_start = - (w < filter_width) ? 0 : (w - filter_width) / stride_width + 1; + (w < d_filter_width) ? 0 : (w - d_filter_width) / stride_width + 1; int w_col_end = min(w / stride_width + 1, output_width); int h_col_start = - (h < filter_height) ? 0 : (h - filter_height) / stride_height + 1; + (h < d_filter_height) ? 0 : (h - d_filter_height) / stride_height + 1; int h_col_end = min(h / stride_height + 1, output_height); int d_col_start = - (d < filter_depth) ? 0 : (d - filter_depth) / stride_depth + 1; + (d < d_filter_depth) ? 0 : (d - d_filter_depth) / stride_depth + 1; int d_col_end = min(d / stride_depth + 1, output_detph); - int offset = (c * filter_depth * filter_height * filter_width + - d * filter_width * filter_height + h * filter_width + w) * - output_detph * output_height * output_width; - - int coeff_d_col = - (1 - stride_depth * filter_width * filter_height * output_detph) * - output_height * output_width; - int coeff_h_col = - (1 - stride_height * filter_width * output_detph * output_height) * - output_width; - int coeff_w_col = - (1 - stride_width * output_detph * output_height * output_width); - for (int d_col = d_col_start; d_col < d_col_end; ++d_col) { for (int h_col = h_col_start; h_col < h_col_end; ++h_col) { for (int w_col = w_col_start; w_col < w_col_end; ++w_col) { - src_val += data_col[offset + d_col * coeff_d_col + - h_col * coeff_h_col + w_col * coeff_w_col]; + int d_off = (d - d_col * stride_depth); + int h_off = (h - h_col * stride_height); + int w_off = (w - w_col * stride_width); + if (d_off % dilation_d == 0 && h_off % dilation_h == 0 && + w_off % dilation_w == 0) { + d_off /= dilation_d; + h_off /= dilation_h; + w_off /= dilation_w; + + int data_col_index = + (((((c * filter_depth + d_off) * filter_height + h_off) * + filter_width + + w_off) * + output_detph + + d_col) * + output_height + + h_col) * + output_width + + w_col; + src_val += data_col[data_col_index]; + } } } } @@ -162,6 +201,7 @@ class Col2VolFunctor { public: void operator()(const platform::DeviceContext& context, framework::Tensor& vol, const framework::Tensor& col, + int dilation_d, int dilation_h, int dilation_w, int stride_depth, int stride_height, int stride_width, int padding_depth, int padding_height, int padding_width) const { @@ -179,6 +219,28 @@ class Col2VolFunctor { int output_height = col.dims()[5]; int output_width = col.dims()[6]; + PADDLE_ENFORCE_EQ((input_depth + 2 * padding_depth - + ((dilation_d * (filter_depth - 1) + 1))) / + stride_depth + + 1, + output_depth, + "input_depth and output_depth are " + "Mismatching."); + PADDLE_ENFORCE_EQ((input_height + 2 * padding_height - + ((dilation_h * (filter_height - 1) + 1))) / + stride_height + + 1, + output_height, + "input_height and output_height are " + "Mismatching."); + PADDLE_ENFORCE_EQ((input_width + 2 * padding_width - + ((dilation_w * (filter_width - 1) + 1))) / + stride_width + + 1, + output_width, + "input_width and output_width are " + "Mismatching."); + int num_kernels = input_channels * input_depth * input_height * input_width; const int threads = 1024; @@ -188,9 +250,10 @@ class Col2VolFunctor { reinterpret_cast(context) .stream()>>>( num_kernels, col.data(), input_depth, 
input_height, input_width, - filter_depth, filter_height, filter_width, stride_depth, stride_height, - stride_width, padding_depth, padding_height, padding_width, - output_depth, output_height, output_width, vol.data()); + dilation_d, dilation_h, dilation_w, filter_depth, filter_height, + filter_width, stride_depth, stride_height, stride_width, padding_depth, + padding_height, padding_width, output_depth, output_height, + output_width, vol.data()); } }; diff --git a/paddle/operators/math/vol2col.h b/paddle/operators/math/vol2col.h index f022365a16fbf..c2d8257c0ba5b 100644 --- a/paddle/operators/math/vol2col.h +++ b/paddle/operators/math/vol2col.h @@ -58,6 +58,7 @@ class Vol2ColFunctor { public: void operator()(const platform::DeviceContext& context, const framework::Tensor& vol, framework::Tensor& col, + int dilation_d, int dilation_h, int dilation_w, int stride_depth, int stride_height, int stride_width, int padding_depth, int padding_height, int padding_width) const; @@ -68,6 +69,7 @@ class Col2VolFunctor { public: void operator()(const platform::DeviceContext& context, framework::Tensor& vol, const framework::Tensor& col, + int dilation_d, int dilation_h, int dilation_w, int stride_depth, int stride_height, int stride_width, int padding_depth, int padding_height, int padding_width) const; diff --git a/paddle/operators/math/vol2col_test.cc b/paddle/operators/math/vol2col_test.cc index 74590d17cd0f9..9d673ad36cfed 100644 --- a/paddle/operators/math/vol2col_test.cc +++ b/paddle/operators/math/vol2col_test.cc @@ -64,6 +64,7 @@ void testVol2col() { int filter_size = 2; int stride = 1; int padding = 0; + int dilation = 1; int output_depth = (input_depth - filter_size + 2 * padding) / stride + 1; int output_height = (input_height - filter_size + 2 * padding) / stride + 1; int output_width = (input_width - filter_size + 2 * padding) / stride + 1; @@ -85,8 +86,8 @@ void testVol2col() { *place); paddle::operators::math::Vol2ColFunctor vol2col; - vol2col(*context, input, output, stride, stride, stride, padding, padding, - padding); + vol2col(*context, input, output, dilation, dilation, dilation, stride, stride, + stride, padding, padding, padding); float vol_2_col[] = {0, 1, 1, 2, 3, 4, 4, 5, 6, 7, 7, 8, 9, 10, 10, 11}; float* out_cfo_ptr; @@ -111,8 +112,8 @@ void testVol2col() { } paddle::operators::math::Col2VolFunctor col2vol; - col2vol(*context, input, output, stride, stride, stride, padding, padding, - padding); + col2vol(*context, input, output, dilation, dilation, dilation, stride, stride, + stride, padding, padding, padding); float* in_ptr; if (paddle::platform::is_cpu_place(*place)) { From 7d73b8fc8e7080b02167808a1a71bd4219089b88 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Fri, 10 Nov 2017 11:33:12 +0800 Subject: [PATCH 6/9] fix unit test (conv3d) --- paddle/operators/math/vol2col.cc | 1 + .../v2/framework/tests/test_conv3d_op.py | 84 ++++++++++++++----- 2 files changed, 64 insertions(+), 21 deletions(-) diff --git a/paddle/operators/math/vol2col.cc b/paddle/operators/math/vol2col.cc index d383ee81526ae..bd509a94f3fb1 100644 --- a/paddle/operators/math/vol2col.cc +++ b/paddle/operators/math/vol2col.cc @@ -181,6 +181,7 @@ class Col2VolFunctor { ((cIm * input_depth + d_pad) * input_height + h_pad) * input_width + w_pad; + int col_idx = ((c * output_depth + d) * output_height + h) * output_width + w; diff --git a/python/paddle/v2/framework/tests/test_conv3d_op.py b/python/paddle/v2/framework/tests/test_conv3d_op.py index 44c192f58d25f..934ea46437d67 100644 --- 
a/python/paddle/v2/framework/tests/test_conv3d_op.py +++ b/python/paddle/v2/framework/tests/test_conv3d_op.py @@ -10,27 +10,40 @@ def conv3d_forward_naive(input, filter, group, conv_param): assert np.mod(out_c, group) == 0 sub_out_c = out_c / group - stride, pad = conv_param['stride'], conv_param['pad'] - out_d = 1 + (in_d + 2 * pad[0] - f_h) / stride[0] - out_h = 1 + (in_h + 2 * pad[1] - f_h) / stride[1] - out_w = 1 + (in_w + 2 * pad[2] - f_w) / stride[2] + stride, pad, dilation = conv_param['stride'], conv_param['pad'], conv_param[ + 'dilations'] + + out_d = 1 + (in_d + 2 * pad[0] - (dilation[0] * (f_d - 1) + 1)) / stride[0] + out_h = 1 + (in_h + 2 * pad[1] - (dilation[1] * (f_h - 1) + 1)) / stride[1] + out_w = 1 + (in_w + 2 * pad[2] - (dilation[2] * (f_w - 1) + 1)) / stride[2] + out = np.zeros((in_n, out_c, out_d, out_h, out_w)) + d_bolck_d = (dilation[0] * (f_d - 1) + 1) + d_bolck_h = (dilation[1] * (f_h - 1) + 1) + d_bolck_w = (dilation[2] * (f_w - 1) + 1) + input_pad = np.pad(input, ((0, ), (0, ), (pad[0], ), (pad[1], ), (pad[2], )), mode='constant', constant_values=0) + + filter_dilation = np.zeros((out_c, f_c, d_bolck_d, d_bolck_h, d_bolck_w)) + filter_dilation[:, :, 0:d_bolck_d:dilation[0], 0:d_bolck_h:dilation[1], 0: + d_bolck_w:dilation[2]] = filter + for d in range(out_d): for i in range(out_h): for j in range(out_w): for g in range(group): input_pad_masked = \ input_pad[:, g * f_c:(g + 1) * f_c, - d * stride[0]:d * stride[0] + f_d, - i * stride[1]:i * stride[1] + f_h, - j * stride[2]:j * stride[2] + f_w] - f_sub = filter[g * sub_out_c:(g + 1) * - sub_out_c, :, :, :, :] + d * stride[0]:d * stride[0] + d_bolck_d, + i * stride[1]:i * stride[1] + d_bolck_h, + j * stride[2]:j * stride[2] + d_bolck_w] + + f_sub = filter_dilation[g * sub_out_c:(g + 1) * + sub_out_c, :, :, :, :] for k in range(sub_out_c): out[:, g * sub_out_c + k, d, i, j] = \ np.sum(input_pad_masked * f_sub[k, :, :, :, :], @@ -43,9 +56,14 @@ class TestConv3dOp(OpTest): def setUp(self): self.init_group() self.init_op_type() + self.init_dilation() self.init_test_case() - conv3d_param = {'stride': self.stride, 'pad': self.pad} + conv3d_param = { + 'stride': self.stride, + 'pad': self.pad, + 'dilations': self.dilations + } input = np.random.random(self.input_size).astype("float32") filter = np.random.random(self.filter_size).astype("float32") output = conv3d_forward_naive(input, filter, self.groups, @@ -55,7 +73,8 @@ def setUp(self): self.attrs = { 'strides': self.stride, 'paddings': self.pad, - 'groups': self.groups + 'groups': self.groups, + 'dilations': self.dilations } self.outputs = {'Output': output} @@ -88,6 +107,9 @@ def init_test_case(self): f_c = self.input_size[1] / self.groups self.filter_size = [6, f_c, 3, 3, 3] + def init_dilation(self): + self.dilations = [1, 1, 1] + def init_group(self): self.groups = 1 @@ -104,27 +126,47 @@ def init_test_case(self): f_c = self.input_size[1] / self.groups self.filter_size = [6, f_c, 3, 3, 3] - def init_group(self): - self.groups = 1 - def init_op_type(self): - self.op_type = "conv3d" +class TestWithGroup1(TestConv3dOp): + def init_group(self): + self.groups = 3 -class TestWithGroup1(TestConv3dOp): +class TestWithGroup2(TestCase1): def init_group(self): self.groups = 3 - def init_op_type(self): - self.op_type = "conv3d" +class TestWith1x1(TestConv3dOp): + def init_test_case(self): + self.pad = [0, 0, 0] + self.stride = [1, 1, 1] + self.input_size = [2, 3, 4, 4, 4] # NCHW + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] / self.groups + 
self.filter_size = [6, f_c, 1, 1, 1] + + def init_dilation(self): + self.dilations = [1, 1, 1] -class TestWithGroup2(TestCase1): def init_group(self): self.groups = 3 - def init_op_type(self): - self.op_type = "conv3d" + +class TestWithDilation(TestConv3dOp): + def init_test_case(self): + self.pad = [0, 0, 0] + self.stride = [1, 1, 1] + self.input_size = [2, 3, 6, 6, 6] # NCDHW + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] / self.groups + self.filter_size = [6, f_c, 2, 2, 2] + + def init_dilation(self): + self.dilations = [2, 2, 2] + + def init_group(self): + self.groups = 3 if __name__ == '__main__': From 356d6954043923d30ef8b1b116b66cbfa1dca7e1 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Tue, 14 Nov 2017 19:19:57 +0800 Subject: [PATCH 7/9] follow comments --- paddle/operators/conv_op.cc | 40 ++-- paddle/operators/conv_op.h | 257 +++++++++--------------- paddle/operators/conv_transpose_op.cc | 23 ++- paddle/operators/conv_transpose_op.h | 52 +++-- paddle/operators/math/context_project.h | 19 +- paddle/operators/math/im2col.cc | 168 ++++++++-------- paddle/operators/math/im2col.cu | 160 +++++++-------- paddle/operators/math/im2col.h | 25 ++- paddle/operators/math/im2col_test.cc | 26 ++- paddle/operators/math/vol2col.cc | 112 +++++------ paddle/operators/math/vol2col.cu | 96 ++++----- paddle/operators/math/vol2col.h | 29 ++- paddle/operators/math/vol2col_test.cc | 21 +- 13 files changed, 487 insertions(+), 541 deletions(-) diff --git a/paddle/operators/conv_op.cc b/paddle/operators/conv_op.cc index a848b9b49cd2f..e1a11a38b3e56 100644 --- a/paddle/operators/conv_op.cc +++ b/paddle/operators/conv_op.cc @@ -13,6 +13,7 @@ limitations under the License. */ #include "paddle/operators/conv_op.h" +#include namespace paddle { namespace operators { @@ -53,7 +54,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { "The number of output channels should be divided by groups."); std::vector output_shape({in_dims[0], filter_dims[0]}); - for (size_t i = 0; i < paddings.size(); ++i) { + for (size_t i = 0; i < strides.size(); ++i) { PADDLE_ENFORCE(in_dims[i + 2] + 2 * paddings[i] - (dilations[i] * (filter_dims[i + 2] - 1) + 1) > 0, @@ -61,8 +62,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { "dilations, the output size is less than 0, please check " "again."); output_shape.push_back(OutputSize(in_dims[i + 2], filter_dims[i + 2], - dilations[i], paddings[i], paddings[i], - strides[i])); + dilations[i], paddings[i], strides[i])); } ctx->SetOutputDim("Output", framework::make_ddim(output_shape)); } @@ -86,9 +86,15 @@ Conv2DOpMaker::Conv2DOpMaker(framework::OpProto* proto, AddOutput("Output", "(Tensor) The output tensor of convolution operator. 
" "The format of output tensor is also NCHW."); - AddAttr>("strides", "strides of convolution operator.") + AddAttr>("strides", + "(vector default:{1, 1}), the " + "strides(h_stride, w_stride) of " + "convolution operator.") .SetDefault({1, 1}); - AddAttr>("paddings", "paddings of convolution operator.") + AddAttr>("paddings", + "(vector default:{0, 0}), the " + "paddings(h_pad, w_pad) of " + "convolution operator.") .SetDefault({0, 0}); AddAttr( "groups", @@ -99,9 +105,10 @@ Conv2DOpMaker::Conv2DOpMaker(framework::OpProto* proto, "is only connected to the second half of the input channels.") .SetDefault(1); AddAttr>("dilations", - "(vector default:{1, 1}), the dilations of " + "(vector default:{1, 1}), the " + "dilations(h_dilation, w_dilation) of " "convolution operator.") - .SetDefault(std::vector{1, 1}); + .SetDefault({1, 1}); AddComment(R"DOC( Convolution Operator. @@ -147,13 +154,15 @@ Conv3DOpMaker::Conv3DOpMaker(framework::OpProto* proto, AddOutput("Output", "(Tensor) The output tensor of convolution operator." "The format of output tensor is also NCDHW."); - AddAttr>( - "strides", - "(vector, default:{0, 0, 0}), the strides of convolution operator.") + AddAttr>("strides", + "(vector, default:{1, 1, 1}), the " + "strides(d_stride, h_stride, w_stride) of " + "convolution operator.") .SetDefault({1, 1, 1}); - AddAttr>( - "paddings", - "(vector, default:{0, 0, 0}), the paddings of convolution operator.") + AddAttr>("paddings", + "(vector, default:{0, 0, 0}), the " + "paddings(d_pad, h_pad, w_pad) of convolution " + "operator.") .SetDefault({0, 0, 0}); AddAttr( "groups", @@ -164,10 +173,11 @@ Conv3DOpMaker::Conv3DOpMaker(framework::OpProto* proto, "is only connected to the second half of the input channels.") .SetDefault(1); AddAttr>("dilations", - "(vector default:{1, 1, 1}), the dilations of " + "(vector default:{1, 1, 1}), the " + "dilations(d_dilation, h_dilation, w_dilation) of " "convolution operator. Currently, conv3d doesn't " "support dilation.") - .SetDefault(std::vector{1, 1, 1}); + .SetDefault({1, 1, 1}); AddComment(R"DOC( Convolution3D Operator. diff --git a/paddle/operators/conv_op.h b/paddle/operators/conv_op.h index af2c8fb163eca..fac5f1d0e25fe 100644 --- a/paddle/operators/conv_op.h +++ b/paddle/operators/conv_op.h @@ -28,24 +28,22 @@ using Tensor = framework::Tensor; // Base convolution operator definations for other conv // like operators to reuse the implementation. 
inline int OutputSize(int input_size, int filter_size, int dilation, - int padding_up, int padding_down, int stride) { - int output_size = (input_size + padding_up + padding_down - - (dilation * (filter_size - 1) + 1)) / - stride + - 1; + int padding, int stride) { + const int dkernel = dilation * (filter_size - 1) + 1; + const int output_size = (input_size + 2 * padding - dkernel) / stride + 1; return output_size; } -inline bool NotExpand(std::vector& filter_dim, - std::vector& strides, std::vector& paddings, - std::vector& dilations) { +inline bool IsExpand(std::vector& filter_dim, + std::vector& strides, std::vector& paddings, + std::vector& dilations) { bool filter_1 = true, strides_1 = true, padding_0 = true, dilation_1 = true; for (size_t j = 0; j < strides.size(); ++j) { - filter_1 &= (static_cast(filter_dim[j]) == 1); - strides_1 &= (strides[j] == 1); - padding_0 &= (paddings[j] == 0); - dilation_1 &= (dilations[j] == 1); + filter_1 = filter_1 && (static_cast(filter_dim[j]) == 1); + strides_1 = strides_1 && (strides[j] == 1); + padding_0 = padding_0 && (paddings[j] == 0); + dilation_1 = dilation_1 && (dilations[j] == 1); } - return filter_1 && strides_1 && padding_0 && dilation_1; + return !(filter_1 && strides_1 && padding_0 && dilation_1); } // Define Op classes in .h file so that other conv @@ -65,14 +63,12 @@ class Conv3DOpMaker : public framework::OpProtoAndCheckerMaker { class ConvOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override; }; class ConvOpGrad : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override; }; @@ -88,9 +84,9 @@ class GemmConvKernel : public framework::OpKernel { Tensor* output = context.Output("Output"); output->mutable_data(context.GetPlace()); + int groups = context.Attr("groups"); std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); - int groups = context.Attr("groups"); std::vector dilations = context.Attr>("dilations"); const int batch_size = static_cast(input->dims()[0]); @@ -122,13 +118,13 @@ class GemmConvKernel : public framework::OpKernel { framework::DDim col_matrix_shape = framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1); - bool not_expand = NotExpand(filter_shape_vec, strides, paddings, dilations); + bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations); Tensor col; // col_matrix shares the same piece of data with col, // but will be reshaped into a two-dimensional matrix shape // to call the matrix multiplication interface. 
Tensor col_matrix; - if (!not_expand) { + if (is_expand) { col.mutable_data(col_shape, context.GetPlace()); col_matrix.ShareDataWith(col); col_matrix.Resize(col_matrix_shape); @@ -149,51 +145,37 @@ class GemmConvKernel : public framework::OpKernel { int in_step = static_cast(input->dims()[1]) / groups; int out_step = static_cast(output->dims()[1]) / groups; - if (!not_expand) { - for (int i = 0; i < batch_size; i++) { - Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); - Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape); - for (int g = 0; g < groups; g++) { - Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); + math::Vol2ColFunctor vol2col; + math::Im2ColFunctor im2col; - if (filter_shape_vec.size() == 2) { - // im2col - math::Im2ColFunctor im2col; - im2col(context.device_context(), in_slice, col, dilations[0], - dilations[1], strides[0], strides[1], paddings[0], - paddings[0], paddings[1], paddings[1]); - } else if (filter_shape_vec.size() == 3) { - // vol2col - math::Vol2ColFunctor vol2col; - vol2col(context.device_context(), in_slice, col, dilations[0], - dilations[1], dilations[2], strides[0], strides[1], - strides[2], paddings[0], paddings[1], paddings[2]); - } + for (int i = 0; i < batch_size; i++) { + Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); + Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape); - // gemm - Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); - Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); - math::matmul(context.device_context(), filter_slice, false, - col_matrix, false, T(1.0), &out_slice, T(0.0)); - } - } - } else { - for (int i = 0; i < batch_size; i++) { - Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); - Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape); - for (int g = 0; g < groups; g++) { - Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); + for (int g = 0; g < groups; g++) { + Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); + if (!is_expand) { col.ShareDataWith(in_slice); col_matrix.ShareDataWith(col); col_matrix.Resize(col_matrix_shape); - - // gemm - Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); - Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); - math::matmul(context.device_context(), filter_slice, false, - col_matrix, false, T(1.0), &out_slice, T(0.0)); + } else if (filter_shape_vec.size() == 2) { + // im2col + im2col(context.device_context(), in_slice, dilations, strides, + std::vector{paddings[0], paddings[1], paddings[0], + paddings[1]}, + &col); + } else if (filter_shape_vec.size() == 3) { + // vol2col + vol2col(context.device_context(), in_slice, dilations, strides, + paddings, &col); } + + // gemm + Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); + Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); + math::matmul(context.device_context(), filter_slice, false, + col_matrix, false, T(1.0), &out_slice, T(0.0)); } } } @@ -217,9 +199,9 @@ class GemmConvGradKernel : public framework::OpKernel { if (!input_grad && !filter_grad) return; + int groups = context.Attr("groups"); std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); - int groups = context.Attr("groups"); std::vector dilations = context.Attr>("dilations"); const int batch_size = static_cast(input->dims()[0]); @@ -270,13 +252,13 @@ class GemmConvGradKernel : public 
framework::OpKernel { int in_step = static_cast(input->dims()[1]) / groups; int out_step = static_cast(output_grad->dims()[1]) / groups; - bool not_expand = NotExpand(filter_shape_vec, strides, paddings, dilations); + bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations); Tensor col; // col_matrix shares the same piece of data with col, // but will be reshaped into a two-dimensional matrix shape // to call the matrix multiplication interface. Tensor col_matrix; - if (!not_expand) { + if (is_expand) { col.mutable_data(col_shape, context.GetPlace()); col_matrix.ShareDataWith(col); col_matrix.Resize(col_matrix_shape); @@ -288,61 +270,38 @@ class GemmConvGradKernel : public framework::OpKernel { input_grad->mutable_data(context.GetPlace()); set_zero(context.device_context(), input_grad, static_cast(0)); - if (!not_expand) { - for (int i = 0; i < batch_size; i++) { - Tensor out_grad_batch = - output_grad->Slice(i, i + 1).Resize(output_matrix_shape); - Tensor in_grad_batch = - input_grad->Slice(i, i + 1).Resize(input_shape); - for (int g = 0; g < groups; g++) { - // gemm - Tensor out_grad_slice = - out_grad_batch.Slice(g * out_step, (g + 1) * out_step); - Tensor filter_slice = - filter.Slice(g * out_step, (g + 1) * out_step); - math::matmul(context.device_context(), filter_slice, true, - out_grad_slice, false, T(1.0), &col_matrix, - T(0.0)); - Tensor in_grad_slice = - in_grad_batch.Slice(g * in_step, (g + 1) * in_step); - - if (filter_shape_vec.size() == 2) { - math::Col2ImFunctor col2im; - col2im(context.device_context(), in_grad_slice, col, dilations[0], - dilations[1], strides[0], strides[1], paddings[0], - paddings[0], paddings[1], paddings[1]); - - } else if (filter_shape_vec.size() == 3) { - math::Col2VolFunctor col2vol; - col2vol(context.device_context(), in_grad_slice, col, - dilations[0], dilations[1], dilations[2], strides[0], - strides[1], strides[2], paddings[0], paddings[1], - paddings[2]); - } - } - } - } else { - for (int i = 0; i < batch_size; i++) { - Tensor out_grad_batch = - output_grad->Slice(i, i + 1).Resize(output_matrix_shape); - Tensor in_grad_batch = - input_grad->Slice(i, i + 1).Resize(input_shape); - for (int g = 0; g < groups; g++) { - // gemm - Tensor out_grad_slice = - out_grad_batch.Slice(g * out_step, (g + 1) * out_step); - Tensor filter_slice = - filter.Slice(g * out_step, (g + 1) * out_step); - - Tensor in_grad_slice = - in_grad_batch.Slice(g * in_step, (g + 1) * in_step); + math::Col2VolFunctor col2vol; + math::Col2ImFunctor col2im; + for (int i = 0; i < batch_size; i++) { + Tensor out_grad_batch = + output_grad->Slice(i, i + 1).Resize(output_matrix_shape); + Tensor in_grad_batch = input_grad->Slice(i, i + 1).Resize(input_shape); + for (int g = 0; g < groups; g++) { + // gemm + Tensor out_grad_slice = + out_grad_batch.Slice(g * out_step, (g + 1) * out_step); + Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); + + Tensor in_grad_slice = + in_grad_batch.Slice(g * in_step, (g + 1) * in_step); + + if (!is_expand) { col_matrix.ShareDataWith(in_grad_slice); col_matrix.Resize(col_matrix_shape); - - math::matmul(context.device_context(), filter_slice, true, - out_grad_slice, false, T(1.0), &col_matrix, - T(0.0)); + } + math::matmul(context.device_context(), filter_slice, true, + out_grad_slice, false, T(1.0), &col_matrix, + T(0.0)); + + if (is_expand && filter_shape_vec.size() == 2) { + col2im(context.device_context(), col, dilations, strides, + std::vector{paddings[0], paddings[1], paddings[0], + paddings[1]}, + &in_grad_slice); 
+ } else if (is_expand && filter_shape_vec.size() == 3) { + col2vol(context.device_context(), col, dilations, strides, paddings, + &in_grad_slice); } } } @@ -353,60 +312,38 @@ class GemmConvGradKernel : public framework::OpKernel { Tensor filter_grad_ = *filter_grad; filter_grad_.Resize(filter_matrix_shape); set_zero(context.device_context(), filter_grad, static_cast(0)); + math::Im2ColFunctor im2col; + math::Vol2ColFunctor vol2col; + for (int i = 0; i < batch_size; i++) { + Tensor out_grad_batch = + output_grad->Slice(i, i + 1).Resize(output_matrix_shape); + Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); + for (int g = 0; g < groups; g++) { + // im2col + Tensor out_grad_slice = + out_grad_batch.Slice(g * out_step, (g + 1) * out_step); + Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); - if (!not_expand) { - for (int i = 0; i < batch_size; i++) { - Tensor out_grad_batch = - output_grad->Slice(i, i + 1).Resize(output_matrix_shape); - Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); - for (int g = 0; g < groups; g++) { - // im2col - Tensor out_grad_slice = - out_grad_batch.Slice(g * out_step, (g + 1) * out_step); - Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); - - if (filter_shape_vec.size() == 2) { - math::Im2ColFunctor im2col; - im2col(context.device_context(), in_slice, col, dilations[0], - dilations[1], strides[0], strides[1], paddings[0], - paddings[0], paddings[1], paddings[1]); - } else if (filter_shape_vec.size() == 3) { - math::Vol2ColFunctor vol2col; - vol2col(context.device_context(), in_slice, col, dilations[0], - dilations[1], dilations[2], strides[0], strides[1], - strides[2], paddings[0], paddings[1], paddings[2]); - } - - // gemm - Tensor filter_grad_slice = - filter_grad_.Slice(g * out_step, (g + 1) * out_step); - math::matmul(context.device_context(), out_grad_slice, - false, col_matrix, true, T(1.0), - &filter_grad_slice, T(1.0)); - } - } - } else { - for (int i = 0; i < batch_size; i++) { - Tensor out_grad_batch = - output_grad->Slice(i, i + 1).Resize(output_matrix_shape); - Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); - for (int g = 0; g < groups; g++) { - // im2col - Tensor out_grad_slice = - out_grad_batch.Slice(g * out_step, (g + 1) * out_step); - Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); - + if (!is_expand) { col.ShareDataWith(in_slice); col_matrix.ShareDataWith(col); col_matrix.Resize(col_matrix_shape); - - // gemm - Tensor filter_grad_slice = - filter_grad_.Slice(g * out_step, (g + 1) * out_step); - math::matmul(context.device_context(), out_grad_slice, - false, col_matrix, true, T(1.0), - &filter_grad_slice, T(1.0)); + } else if (filter_shape_vec.size() == 2) { + im2col(context.device_context(), in_slice, dilations, strides, + std::vector{paddings[0], paddings[1], paddings[0], + paddings[1]}, + &col); + } else if (filter_shape_vec.size() == 3) { + vol2col(context.device_context(), in_slice, dilations, strides, + paddings, &col); } + + // gemm + Tensor filter_grad_slice = + filter_grad_.Slice(g * out_step, (g + 1) * out_step); + math::matmul(context.device_context(), out_grad_slice, + false, col_matrix, true, T(1.0), + &filter_grad_slice, T(1.0)); } } } diff --git a/paddle/operators/conv_transpose_op.cc b/paddle/operators/conv_transpose_op.cc index 50081779a5ea3..6f47a6d6a0f08 100644 --- a/paddle/operators/conv_transpose_op.cc +++ b/paddle/operators/conv_transpose_op.cc @@ -51,7 +51,7 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const 
{ "as the number of filters."); std::vector output_shape({in_dims[0], filter_dims[1]}); - for (size_t i = 0; i < paddings.size(); ++i) { + for (size_t i = 0; i < strides.size(); ++i) { output_shape.push_back((in_dims[i + 2] - 1) * strides[i] + filter_dims[i + 2]); } @@ -77,13 +77,14 @@ Conv2DTransposeOpMaker::Conv2DTransposeOpMaker( AddOutput("Output", "(Tensor) The output tensor of convolution transpose operator. " "The format of output tensor is also NCHW."); - AddAttr>( - "strides", - "(vector defalut:{1, 1}), strides of convolution transpose operator.") + AddAttr>("strides", + "(vector defalut:{1, 1}), strides of " + "convolution transpose operator.") .SetDefault({1, 1}); AddAttr>( "paddings", - "(vector defalut:{0, 0}), paddings of convolution transpose operator.") + "(vector defalut:{0, 0}), paddings(h_pad, w_pad) of convolution " + "transpose operator.") .SetDefault({0, 0}); AddComment(R"DOC( Convolution2D Transpose Operator. @@ -132,13 +133,13 @@ Conv3DTransposeOpMaker::Conv3DTransposeOpMaker( "Where N is batch size, C is " "the number of channels, D is the depth of the feature, H is the " "height of the feature, and W is the width of the feature."); - AddAttr>( - "strides", - "(vector defalut:{1, 1, 1}), strides of convolution transpose operator.") + AddAttr>("strides", + "(vector defalut:{1, 1, 1}), strides of " + "convolution transpose operator.") .SetDefault({1, 1, 1}); - AddAttr>( - "paddings", - "(vector defalut:{0, 0, 0}), paddings of convolution transpose operator.") + AddAttr>("paddings", + "(vector defalut:{0, 0, 0}), paddings(d_pad, " + "h_pad, w_pad) of convolution transpose operator.") .SetDefault({0, 0, 0}); AddComment(R"DOC( Convolution3D Transpose Operator. diff --git a/paddle/operators/conv_transpose_op.h b/paddle/operators/conv_transpose_op.h index 18ca6b20e0349..4b2bd60437da8 100644 --- a/paddle/operators/conv_transpose_op.h +++ b/paddle/operators/conv_transpose_op.h @@ -43,16 +43,12 @@ class Conv3DTransposeOpMaker : public framework::OpProtoAndCheckerMaker { class ConvTransposeOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - protected: void InferShape(framework::InferShapeContext* ctx) const override; }; class ConvTransposeOpGrad : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - protected: void InferShape(framework::InferShapeContext* ctx) const override; }; @@ -66,13 +62,11 @@ class GemmConvTransposeKernel : public framework::OpKernel { Tensor* output = context.Output("Output"); std::vector strides = context.Attr>("strides"); + // Actually, no paddings and groups allowed in conv transpose. + std::vector paddings = context.Attr>("paddings"); // TODO(Zhuoyuan): Paddings can be added in future. // groups will alway be disabled in conv2dtranspose. 
- int dilaiton_d = 1; - int dilation_h = 1; - int dilation_w = 1; - const int batch_size = static_cast(input->dims()[0]); // input_shape_vec: {h, w} or {d, h, w} @@ -124,6 +118,10 @@ class GemmConvTransposeKernel : public framework::OpKernel { math::SetConstant set_zero; set_zero(context.device_context(), output, static_cast(0)); + math::Col2ImFunctor col2im; + math::Col2VolFunctor col2vol; + std::vector dilations({1, 1, 1}); + // convolution transpose: gemm + col2im or col2vol (similar to conv-backward // on input) for (int i = 0; i < batch_size; i++) { @@ -142,17 +140,16 @@ class GemmConvTransposeKernel : public framework::OpKernel { if (filter_shape_vec.size() == 2) { // col2im: col_matrix -> dy // from (c * k_h * k_w, h * w) to (c, o_h, o_w) - math::Col2ImFunctor col2im; - - col2im(context.device_context(), output_batch, col, dilation_h, - dilation_w, strides[0], strides[1], 0, 0, 0, 0); + col2im(context.device_context(), col, + std::vector{dilations[0], dilations[1]}, strides, + std::vector{paddings[0], paddings[1], paddings[0], + paddings[1]}, + &output_batch); } else if (filter_shape_vec.size() == 3) { // col2vol: col_matrix -> dy // from (c * k_d * k_h * k_w, d * h * w) to (c, o_d, o_h, o_w) - math::Col2VolFunctor col2vol; - col2vol(context.device_context(), output_batch, col, dilaiton_d, - dilation_h, dilation_w, strides[0], strides[1], strides[2], 0, - 0, 0); + col2vol(context.device_context(), col, dilations, strides, + std::vector{0, 0, 0}, &output_batch); } } } @@ -179,10 +176,6 @@ class GemmConvTransposeGradKernel : public framework::OpKernel { // Actually, no paddings and groups allowed in conv transpose. std::vector paddings = context.Attr>("paddings"); - int dilaiton_d = 1; - int dilation_h = 1; - int dilation_w = 1; - const int batch_size = static_cast(input->dims()[0]); // input_shape_vec: {h, w} or {d, h, w} @@ -237,6 +230,10 @@ class GemmConvTransposeGradKernel : public framework::OpKernel { Tensor filter_grad_; math::SetConstant set_zero; + math::Im2ColFunctor im2col; + math::Vol2ColFunctor vol2col; + std::vector dilations({1, 1, 1}); + if (input_grad) { input_grad->mutable_data(context.GetPlace()); set_zero(context.device_context(), input_grad, static_cast(0)); @@ -256,17 +253,16 @@ class GemmConvTransposeGradKernel : public framework::OpKernel { if (filter_shape_vec.size() == 2) { // im2col: dy -> col matrix // from (c, o_h, o_w) to (c * k_h * k_w, h * w) - math::Im2ColFunctor im2col; - im2col(context.device_context(), output_grad_batch, col, dilation_h, - dilation_w, strides[0], strides[1], paddings[0], paddings[0], - paddings[1], paddings[1]); + im2col(context.device_context(), output_grad_batch, + std::vector{dilations[0], dilations[1]}, strides, + std::vector{paddings[0], paddings[1], paddings[0], + paddings[1]}, + &col); } else if (filter_shape_vec.size() == 3) { // vol2col: dy -> col_matrix // from (c, o_d, o_h, o_w) to (c * k_d * k_h * k_w, d * h * w) - math::Vol2ColFunctor vol2col; - vol2col(context.device_context(), output_grad_batch, col, dilaiton_d, - dilation_h, dilation_w, strides[0], strides[1], strides[2], - paddings[0], paddings[1], paddings[2]); + vol2col(context.device_context(), output_grad_batch, dilations, + strides, paddings, &col); } if (input_grad) { diff --git a/paddle/operators/math/context_project.h b/paddle/operators/math/context_project.h index c67d84528fdd3..d9f952c387f2c 100644 --- a/paddle/operators/math/context_project.h +++ b/paddle/operators/math/context_project.h @@ -95,8 +95,9 @@ class ContextProjectFunctor { math::Im2ColFunctor 
im2col_ocf; - int dilation_h = 1; - int dilation_w = 1; + std::vector dilation({1, 1}); + std::vector padding({up_pad, 0, down_pad, 0}); + std::vector stride({context_stride, 1}); int input_row_begin, input_row_end; int sequence_height, sequence_width; @@ -126,10 +127,7 @@ class ContextProjectFunctor { {1, input_row_end - input_row_begin, sequence_width}); // input_channels, input_height, input_width in_t.Resize(framework::make_ddim(input_shape)); - - im2col_ocf(context, in_t, out_t, dilation_h, dilation_w, - /*stride_height*/ context_stride, /*stride_width*/ 1, up_pad, - down_pad, 0, 0); + im2col_ocf(context, in_t, dilation, stride, padding, &out_t); out_t.Resize({sequence_height, context_length * sequence_width}); } } @@ -207,8 +205,9 @@ class ContextProjectGradFunctor { math::Col2ImFunctor col2im_ocf; - int dilation_h = 1; - int dilation_w = 1; + std::vector dilation({1, 1}); + std::vector padding({up_pad, 0, down_pad, 0}); + std::vector stride({context_stride, 1}); int input_row_begin, input_row_end; int sequence_height, sequence_width; @@ -240,9 +239,7 @@ class ContextProjectGradFunctor { sequence_width}); // input_channels, input_height, input_width in_t.Resize(framework::make_ddim(input_shape)); - col2im_ocf(context, in_t, out_t, dilation_h, dilation_w, - /*stride_height*/ context_stride, /*stride_width*/ 1, - up_pad, down_pad, 0, 0); + col2im_ocf(context, out_t, dilation, stride, padding, &in_t); out_t.Resize({sequence_height, context_length * sequence_width}); } } diff --git a/paddle/operators/math/im2col.cc b/paddle/operators/math/im2col.cc index 2af55fa71f86a..c10c44c52076c 100644 --- a/paddle/operators/math/im2col.cc +++ b/paddle/operators/math/im2col.cc @@ -28,40 +28,39 @@ class Im2ColFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& im, framework::Tensor& col, - int dilation_h, int dilation_w, int stride_height, - int stride_width, int padding_up, int padding_down, - int padding_left, int padding_right) { + const framework::Tensor& im, const std::vector& dilation, + const std::vector& stride, + const std::vector& padding, framework::Tensor* col) { PADDLE_ENFORCE(im.dims().size() == 3); - PADDLE_ENFORCE(col.dims().size() == 5); + PADDLE_ENFORCE(col->dims().size() == 5); int im_channels = im.dims()[0]; int im_height = im.dims()[1]; int im_width = im.dims()[2]; - int filter_height = col.dims()[1]; - int filter_width = col.dims()[2]; - int col_height = col.dims()[3]; - int col_width = col.dims()[4]; + int filter_height = col->dims()[1]; + int filter_width = col->dims()[2]; + int col_height = col->dims()[3]; + int col_width = col->dims()[4]; - PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down - - ((dilation_h * (filter_height - 1) + 1))) / - stride_height + + PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] - + ((dilation[0] * (filter_height - 1) + 1))) / + stride[0] + 1, col_height, "Output_height and padding(padding_up, padding_down) are " "inconsistent."); - PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right - - ((dilation_w * (filter_width - 1) + 1))) / - stride_width + + PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] - + ((dilation[1] * (filter_width - 1) + 1))) / + stride[1] + 1, col_width, - "col_width and padding(padding_left, padding_right) are " + "Output_height and padding(padding_up, padding_down) are " "inconsistent."); int channels_col = im_channels * filter_height * filter_width; const T* im_data = im.data(); - T* col_data = col.data(); + T* col_data = col->data(); for (int c = 0; c < 
channels_col; ++c) { int w_offset = c % filter_width; @@ -69,10 +68,8 @@ class Im2ColFunctor class Col2ImFunctor { public: - void operator()(const platform::DeviceContext& context, framework::Tensor& im, - const framework::Tensor& col, int dilation_h, int dilation_w, - int stride_height, int stride_width, int padding_up, - int padding_down, int padding_left, int padding_right) { - PADDLE_ENFORCE(im.dims().size() == 3); + void operator()(const platform::DeviceContext& context, + const framework::Tensor& col, + const std::vector& dilation, + const std::vector& stride, + const std::vector& padding, framework::Tensor* im) { + PADDLE_ENFORCE(im->dims().size() == 3); PADDLE_ENFORCE(col.dims().size() == 5); - int im_channels = im.dims()[0]; - int im_height = im.dims()[1]; - int im_width = im.dims()[2]; + int im_channels = im->dims()[0]; + int im_height = im->dims()[1]; + int im_width = im->dims()[2]; int filter_height = col.dims()[1]; int filter_width = col.dims()[2]; int col_height = col.dims()[3]; int col_width = col.dims()[4]; - PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down - - ((dilation_h * (filter_height - 1) + 1))) / - stride_height + + PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] - + ((dilation[0] * (filter_height - 1) + 1))) / + stride[0] + 1, col_height, "Output_height and padding(padding_up, padding_down) are " "inconsistent."); - PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right - - ((dilation_w * (filter_width - 1) + 1))) / - stride_width + + PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] - + ((dilation[1] * (filter_width - 1) + 1))) / + stride[1] + 1, col_width, - "col_width and padding(padding_left, padding_right) are " + "Output_height and padding(padding_up, padding_down) are " "inconsistent."); int channels_col = im_channels * filter_height * filter_width; - T* im_data = im.data(); + T* im_data = im->data(); const T* col_data = col.data(); for (int c = 0; c < channels_col; ++c) { @@ -135,10 +133,8 @@ class Col2ImFunctor= 0 && (im_row_idx) < im_height && (im_col_idx) >= 0 && (im_col_idx) < im_width) { @@ -171,35 +167,32 @@ class Im2ColFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& im, framework::Tensor& col, - int dilation_h, int dilation_w, int stride_height, - int stride_width, int padding_up, int padding_down, - int padding_left, int padding_right) { + const framework::Tensor& im, const std::vector& dilation, + const std::vector& stride, + const std::vector& padding, framework::Tensor* col) { PADDLE_ENFORCE(im.dims().size() == 3); - PADDLE_ENFORCE(col.dims().size() == 5); + PADDLE_ENFORCE(col->dims().size() == 5); int im_channels = im.dims()[0]; int im_height = im.dims()[1]; int im_width = im.dims()[2]; - int filter_height = col.dims()[3]; - int filter_width = col.dims()[4]; - int col_height = col.dims()[0]; - int col_width = col.dims()[1]; + int filter_height = col->dims()[3]; + int filter_width = col->dims()[4]; + int col_height = col->dims()[0]; + int col_width = col->dims()[1]; - PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down - filter_height) / - stride_height + - 1, - col_height, - "Output_height and padding(padding_up, padding_down) are " - "inconsistent."); - PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right - filter_width) / - stride_width + - 1, - col_width, - "col_width and padding(padding_left, padding_right) are " - "inconsistent."); + PADDLE_ENFORCE_EQ( + (im_height + padding[0] + padding[2] - filter_height) / stride[0] + 1, + col_height, + "Output_height 
and padding(padding_up, padding_down) are " + "inconsistent."); + PADDLE_ENFORCE_EQ( + (im_width + padding[1] + padding[3] - filter_width) / stride[1] + 1, + col_width, + "col_width and padding(padding_left, padding_right) are " + "inconsistent."); const T* im_data = im.data(); - T* col_data = col.data(); + T* col_data = col->data(); for (int col_row_idx = 0; col_row_idx < col_height; ++col_row_idx) { for (int col_col_idx = 0; col_col_idx < col_width; ++col_col_idx) { @@ -209,9 +202,9 @@ class Im2ColFunctor class Col2ImFunctor { public: - void operator()(const platform::DeviceContext& context, framework::Tensor& im, - const framework::Tensor& col, int dilation_h, int dilation_w, - int stride_height, int stride_width, int padding_up, - int padding_down, int padding_left, int padding_right) { - PADDLE_ENFORCE(im.dims().size() == 3); + void operator()(const platform::DeviceContext& context, + const framework::Tensor& col, + const std::vector& dilation, + const std::vector& stride, + const std::vector& padding, framework::Tensor* im) { + PADDLE_ENFORCE(im->dims().size() == 3); PADDLE_ENFORCE(col.dims().size() == 5); - int im_channels = im.dims()[0]; - int im_height = im.dims()[1]; - int im_width = im.dims()[2]; + int im_channels = im->dims()[0]; + int im_height = im->dims()[1]; + int im_width = im->dims()[2]; int filter_height = col.dims()[3]; int filter_width = col.dims()[4]; int col_height = col.dims()[0]; int col_width = col.dims()[1]; - PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down - filter_height) / - stride_height + - 1, - col_height, - "Output_height and padding(padding_up, padding_down) are " - "inconsistent."); - PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right - filter_width) / - stride_width + - 1, - col_width, - "col_width and padding(padding_left, padding_right) are " - "inconsistent."); + PADDLE_ENFORCE_EQ( + (im_height + padding[0] + padding[2] - filter_height) / stride[0] + 1, + col_height, + "Output_height and padding(padding_up, padding_down) are " + "inconsistent."); + PADDLE_ENFORCE_EQ( + (im_width + padding[1] + padding[3] - filter_width) / stride[1] + 1, + col_width, + "col_width and padding(padding_left, padding_right) are " + "inconsistent."); - T* im_data = im.data(); + T* im_data = im->data(); const T* col_data = col.data(); for (int col_row_idx = 0; col_row_idx < col_height; ++col_row_idx) { @@ -282,9 +274,9 @@ class Col2ImFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& im, framework::Tensor& col, - int dilation_h, int dilation_w, int stride_height, - int stride_width, int padding_up, int padding_down, - int padding_left, int padding_right) { + const framework::Tensor& im, const std::vector& dilation, + const std::vector& stride, + const std::vector& padding, framework::Tensor* col) { PADDLE_ENFORCE(im.dims().size() == 3); - PADDLE_ENFORCE(col.dims().size() == 5); + PADDLE_ENFORCE(col->dims().size() == 5); int im_channels = im.dims()[0]; int im_height = im.dims()[1]; int im_width = im.dims()[2]; - int filter_height = col.dims()[1]; - int filter_width = col.dims()[2]; - int col_height = col.dims()[3]; - int col_width = col.dims()[4]; - - PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down - - (dilation_h * (filter_height - 1) + 1)) / - stride_height + + int filter_height = col->dims()[1]; + int filter_width = col->dims()[2]; + int col_height = col->dims()[3]; + int col_width = col->dims()[4]; + + PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] - + (dilation[0] * (filter_height - 1) + 
1)) / + stride[0] + 1, col_height, "Output_height and padding(padding_up, padding_down) are " "inconsistent."); - PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right - - (dilation_w * (filter_width - 1) + 1)) / - stride_width + + PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] - + (dilation[1] * (filter_width - 1) + 1)) / + stride[1] + 1, col_width, "col_width and padding(padding_left, padding_right) are " @@ -100,9 +99,9 @@ class Im2ColFunctor<<(context) .stream()>>>( - im.data(), num_outputs, im_height, im_width, dilation_h, dilation_w, - filter_height, filter_width, stride_height, stride_width, padding_up, - padding_left, col_height, col_width, col.data()); + im.data(), num_outputs, im_height, im_width, dilation[0], + dilation[1], filter_height, filter_width, stride[0], stride[1], + padding[0], padding[1], col_height, col_width, col->data()); } }; @@ -163,31 +162,32 @@ template class Col2ImFunctor { public: - void operator()(const platform::DeviceContext& context, framework::Tensor& im, - const framework::Tensor& col, int dilation_h, int dilation_w, - int stride_height, int stride_width, int padding_up, - int padding_down, int padding_left, int padding_right) { - PADDLE_ENFORCE(im.dims().size() == 3); + void operator()(const platform::DeviceContext& context, + const framework::Tensor& col, + const std::vector& dilation, + const std::vector& stride, + const std::vector& padding, framework::Tensor* im) { + PADDLE_ENFORCE(im->dims().size() == 3); PADDLE_ENFORCE(col.dims().size() == 5); - int im_channels = im.dims()[0]; - int im_height = im.dims()[1]; - int im_width = im.dims()[2]; + int im_channels = im->dims()[0]; + int im_height = im->dims()[1]; + int im_width = im->dims()[2]; int filter_height = col.dims()[1]; int filter_width = col.dims()[2]; int col_height = col.dims()[3]; int col_width = col.dims()[4]; - PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down - - (dilation_h * (filter_height - 1) + 1)) / - stride_height + + PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] - + (dilation[0] * (filter_height - 1) + 1)) / + stride[0] + 1, col_height, "Output_height and padding(padding_up, padding_down) are " "inconsistent."); - PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right - - (dilation_w * (filter_width - 1) + 1)) / - stride_width + + PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] - + (dilation[1] * (filter_width - 1) + 1)) / + stride[1] + 1, col_width, "col_width and padding(padding_left, padding_right) are " @@ -206,9 +206,9 @@ class Col2ImFunctor<<(context) .stream()>>>( - num_kernels, col.data(), im_height, im_width, dilation_h, dilation_w, - filter_height, filter_width, stride_height, stride_width, padding_up, - padding_left, col_height, col_width, im.data()); + num_kernels, col.data(), im_height, im_width, dilation[0], + dilation[1], filter_height, filter_width, stride[0], stride[1], + padding[0], padding[2], col_height, col_width, im->data()); } }; @@ -222,11 +222,11 @@ template class Col2ImFunctor; template -__global__ void im2colOCF(const T* im_data, T* col_data, int im_channels, - int im_height, int im_width, int filter_height, - int filter_width, int stride_height, int stride_width, +__global__ void im2colOCF(const T* im_data, int im_channels, int im_height, + int im_width, int filter_height, int filter_width, + int stride_height, int stride_width, int padding_height, int padding_width, int col_height, - int col_width) { + int col_width, T* col_data) { int swid = blockIdx.x; int shid = blockIdx.y; for (int channelid = threadIdx.z; 
channelid < im_channels; @@ -263,30 +263,29 @@ class Im2ColFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& im, framework::Tensor& col, - int dilation_h, int dilation_w, int stride_height, - int stride_width, int padding_up, int padding_down, - int padding_left, int padding_right) { + const framework::Tensor& im, const std::vector& dilation, + const std::vector& stride, + const std::vector& padding, framework::Tensor* col) { PADDLE_ENFORCE(im.dims().size() == 3); - PADDLE_ENFORCE(col.dims().size() == 5); + PADDLE_ENFORCE(col->dims().size() == 5); int im_channels = im.dims()[0]; int im_height = im.dims()[1]; int im_width = im.dims()[2]; - int filter_height = col.dims()[3]; - int filter_width = col.dims()[4]; - int col_height = col.dims()[0]; - int col_width = col.dims()[1]; - - PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down - - (dilation_h * (filter_height - 1) + 1)) / - stride_height + + int filter_height = col->dims()[3]; + int filter_width = col->dims()[4]; + int col_height = col->dims()[0]; + int col_width = col->dims()[1]; + + PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] - + (dilation[0] * (filter_height - 1) + 1)) / + stride[0] + 1, col_height, "Output_height and padding(padding_up, padding_down) are " "inconsistent."); - PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right - - (dilation_w * (filter_width - 1) + 1)) / - stride_width + + PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] - + (dilation[1] * (filter_width - 1) + 1)) / + stride[1] + 1, col_width, "col_width and padding(padding_left, padding_right) are " @@ -314,18 +313,18 @@ class Im2ColFunctor<<(context) .stream()>>>( - im.data(), col.data(), im_channels, im_height, im_width, - filter_height, filter_width, stride_height, stride_width, padding_up, - padding_left, col_height, col_width); + im.data(), im_channels, im_height, im_width, filter_height, + filter_width, stride[0], stride[1], padding[0], padding[1], col_height, + col_width, col->data()); } }; template -__global__ void col2imOCF(T* im_data, const T* col_data, int im_channels, - int im_height, int im_width, int filter_height, - int filter_width, int stride_height, int stride_width, +__global__ void col2imOCF(const T* col_data, int im_channels, int im_height, + int im_width, int filter_height, int filter_width, + int stride_height, int stride_width, int padding_height, int padding_width, int col_height, - int col_width) { + int col_width, T* im_data) { int swid = blockIdx.x; int shid = blockIdx.y; for (int channelid = threadIdx.z; channelid < im_channels; @@ -361,30 +360,31 @@ template class Col2ImFunctor { public: - void operator()(const platform::DeviceContext& context, framework::Tensor& im, - const framework::Tensor& col, int dilation_h, int dilation_w, - int stride_height, int stride_width, int padding_up, - int padding_down, int padding_left, int padding_right) { - PADDLE_ENFORCE(im.dims().size() == 3); + void operator()(const platform::DeviceContext& context, + const framework::Tensor& col, + const std::vector& dilation, + const std::vector& stride, + const std::vector& padding, framework::Tensor* im) { + PADDLE_ENFORCE(im->dims().size() == 3); PADDLE_ENFORCE(col.dims().size() == 5); - int im_channels = im.dims()[0]; - int im_height = im.dims()[1]; - int im_width = im.dims()[2]; + int im_channels = im->dims()[0]; + int im_height = im->dims()[1]; + int im_width = im->dims()[2]; int filter_height = col.dims()[3]; int filter_width = col.dims()[4]; int col_height = col.dims()[0]; int 
col_width = col.dims()[1]; - PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down - - (dilation_h * (filter_height - 1) + 1)) / - stride_height + + PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] - + (dilation[0] * (filter_height - 1) + 1)) / + stride[0] + 1, col_height, "Output_height and padding(padding_up, padding_down) are " "inconsistent."); - PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right - - (dilation_w * (filter_width - 1) + 1)) / - stride_width + + PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] - + (dilation[1] * (filter_width - 1) + 1)) / + stride[1] + 1, col_width, "col_width and padding(padding_left, padding_right) are " @@ -412,9 +412,9 @@ class Col2ImFunctor col2imOCF<T><<<grid, threads, 0, reinterpret_cast<const platform::CUDADeviceContext&>(context) .stream()>>>( - im.data<T>(), col.data<T>(), im_channels, im_height, im_width, - filter_height, filter_width, stride_height, stride_width, padding_up, - padding_left, col_height, col_width); + col.data<T>(), im_channels, im_height, im_width, filter_height, + filter_width, stride[0], stride[1], padding[0], padding[1], col_height, + col_width, im->data<T>()); } }; diff --git a/paddle/operators/math/im2col.h b/paddle/operators/math/im2col.h index d1c9595a328d3..deb60051beef5 100644 --- a/paddle/operators/math/im2col.h +++ b/paddle/operators/math/im2col.h @@ -35,6 +35,15 @@ enum class ColFormat { kCFO = 0, kOCF = 1 }; * \param colData Column data. * \param colShape The shape of colData. * + * \param dilations dilation data. + * 2-dimension [dilation_height, dilation_width]. + * + * \param strides stride data. + * 2-dimension [stride_height, stride_width]. + * + * \param paddings padding data. + * 4-dimension [up_pad, left_pad, down_pad, right_pad]. + * * If the template argument Format is kCFO, the shape of colData is: * [input_channels, filter_height, filter_width, output_height, output_width] * So, it is easy to reshape into a convolution matrix for convolution @@ -73,19 +82,19 @@ template <ColFormat Format, typename Place, typename T> class Im2ColFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& im, framework::Tensor& col, - int dilation_h, int dilation_w, int stride_height, - int stride_width, int padding_up, int padding_down, - int padding_left, int padding_right); + const framework::Tensor& im, const std::vector<int>& dilation, + const std::vector<int>& stride, + const std::vector<int>& padding, framework::Tensor* col); }; template <ColFormat Format, typename Place, typename T> class Col2ImFunctor { public: - void operator()(const platform::DeviceContext& context, framework::Tensor& im, - const framework::Tensor& col, int dilation_h, int dilation_w, - int stride_height, int stride_width, int padding_up, - int padding_down, int padding_left, int padding_right); + void operator()(const platform::DeviceContext& context, + const framework::Tensor& col, + const std::vector<int>& dilation, + const std::vector<int>& stride, + const std::vector<int>& padding, framework::Tensor* im); }; } // namespace math
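For reference, a minimal dependency-free sketch of the kCFO layout and the dilated indexing declared above (naive loops; naive_im2col is an illustrative name, not part of this patch). It reproduces the expected out_cfo_data of im2col_test.cc below; with dilation d, an f-tap filter covers d * (f - 1) + 1 input pixels, e.g. d = 2, f = 3 covers 5.

#include <cassert>
#include <vector>

// Naive kCFO im2col: col shape is [channels, fh, fw, col_h, col_w].
// Padding order follows the header comment: [up, left, down, right].
std::vector<float> naive_im2col(const std::vector<float>& im, int channels,
                                int height, int width, int fh, int fw,
                                const std::vector<int>& dilation,
                                const std::vector<int>& stride,
                                const std::vector<int>& padding) {
  int col_h = (height + padding[0] + padding[2] -
               (dilation[0] * (fh - 1) + 1)) / stride[0] + 1;
  int col_w = (width + padding[1] + padding[3] -
               (dilation[1] * (fw - 1) + 1)) / stride[1] + 1;
  std::vector<float> col(channels * fh * fw * col_h * col_w, 0.f);
  for (int c = 0; c < channels; ++c)
    for (int kh = 0; kh < fh; ++kh)
      for (int kw = 0; kw < fw; ++kw)
        for (int oh = 0; oh < col_h; ++oh)
          for (int ow = 0; ow < col_w; ++ow) {
            // Same index rule as the functors: out * stride - pad + k * dilation.
            int ih = oh * stride[0] - padding[0] + kh * dilation[0];
            int iw = ow * stride[1] - padding[1] + kw * dilation[1];
            if (ih >= 0 && ih < height && iw >= 0 && iw < width)
              col[(((c * fh + kh) * fw + kw) * col_h + oh) * col_w + ow] =
                  im[(c * height + ih) * width + iw];
          }
  return col;
}

int main() {
  // The 1x2x3 input from im2col_test.cc; expected kCFO result is
  // {0, 1, 1, 2, 3, 4, 4, 5}.
  std::vector<float> im = {0, 1, 2, 3, 4, 5};
  std::vector<float> col =
      naive_im2col(im, 1, 2, 3, 2, 2, {1, 1}, {1, 1}, {0, 0, 0, 0});
  const float expected[] = {0, 1, 1, 2, 3, 4, 4, 5};
  for (int i = 0; i < 8; ++i) assert(col[i] == expected[i]);
  return 0;
}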
diff --git a/paddle/operators/math/im2col_test.cc b/paddle/operators/math/im2col_test.cc index 3385fe8721cb4..10c28da72ba9d 100644 --- a/paddle/operators/math/im2col_test.cc +++ b/paddle/operators/math/im2col_test.cc @@ -45,12 +45,14 @@ void testIm2col() { int input_height = 2; int input_width = 3; int filter_size = 2; - int stride = 1; - int padding = 0; - int dilation_h = 1; - int dilation_w = 1; - int output_height = (input_height - filter_size + 2 * padding) / stride + 1; - int output_width = (input_width - filter_size + 2 * padding) / stride + 1; + std::vector<int> stride({1, 1}); // stride_y, stride_x + std::vector<int> padding({0, 0, 0, 0}); // up_pad, left_pad, down_pad, right_pad + std::vector<int> dilation({1, 1}); // dilation_y, dilation_x + int output_height = + (input_height - filter_size + padding[0] + padding[2]) / stride[0] + 1; + int output_width = + (input_width - filter_size + padding[1] + padding[3]) / stride[1] + 1; float* input_ptr = input_tmp.mutable_data<float>( {1, input_height, input_width}, paddle::platform::CPUPlace()); float arr[6] = {0, 1, 2, 3, 4, 5}; @@ -87,10 +89,8 @@ void testIm2col() { paddle::operators::math::ColFormat::kOCF, Place, float> im2col_ocf; - im2col(*context, input, output_cfo, dilation_h, dilation_w, stride, stride, - padding, padding, padding, padding); - im2col_ocf(*context, input, output_ocf, dilation_h, dilation_w, stride, - stride, padding, padding, padding, padding); + im2col(*context, input, dilation, stride, padding, &output_cfo); + im2col_ocf(*context, input, dilation, stride, padding, &output_ocf); float out_cfo_data[] = {0, 1, 1, 2, 3, 4, 4, 5}; float out_ocf_data[] = {0, 1, 3, 4, 1, 2, 4, 5}; @@ -133,8 +133,7 @@ void testIm2col() { input.CopyFrom(input_tmp, *place, *context); } - col2im(*context, input, output_cfo, dilation_h, dilation_w, stride, stride, - padding, padding, padding, padding); + col2im(*context, output_cfo, dilation, stride, padding, &input); float* in_ptr; if (paddle::platform::is_cpu_place(*place)) { @@ -155,8 +154,7 @@ void testIm2col() { input.CopyFrom(input_tmp, *place, *context); } - col2im_ocf(*context, input, output_ocf, dilation_h, dilation_w, stride, - stride, padding, padding, padding, padding); + col2im_ocf(*context, output_ocf, dilation, stride, padding, &input); if (paddle::platform::is_cpu_place(*place)) { in_ptr = input.data<float>(); diff --git a/paddle/operators/math/vol2col.cc b/paddle/operators/math/vol2col.cc index bd509a94f3fb1..99eb7fd46de42 100644 --- a/paddle/operators/math/vol2col.cc +++ b/paddle/operators/math/vol2col.cc @@ -28,51 +28,51 @@ template <typename Place, typename T> class Vol2ColFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& vol, framework::Tensor& col, - int dilation_d, int dilation_h, int dilation_w, - int stride_depth, int stride_height, int stride_width, - int padding_depth, int padding_height, - int padding_width) const { + const framework::Tensor& vol, + const std::vector<int>& dilations, + const std::vector<int>& strides, + const std::vector<int>& paddings, + framework::Tensor* col) const { PADDLE_ENFORCE(vol.dims().size() == 4); - PADDLE_ENFORCE(col.dims().size() == 7); + PADDLE_ENFORCE(col->dims().size() == 7); int input_channels = vol.dims()[0]; int input_depth = vol.dims()[1]; int input_height = vol.dims()[2]; int input_width = vol.dims()[3]; - int filter_depth = col.dims()[1]; - int filter_height = col.dims()[2]; - int filter_width = col.dims()[3]; - int output_depth = col.dims()[4]; - int output_height = col.dims()[5]; - int output_width = col.dims()[6]; + int filter_depth = col->dims()[1]; + int filter_height = col->dims()[2]; + int filter_width = col->dims()[3]; + int output_depth = col->dims()[4]; + int output_height = col->dims()[5]; + int output_width = col->dims()[6]; int channels_col = input_channels * filter_depth * filter_height * filter_width; - PADDLE_ENFORCE_EQ((input_depth + 2 * padding_depth - - ((dilation_d * (filter_depth - 1) + 1))) / - stride_depth + + PADDLE_ENFORCE_EQ((input_depth + 2 * paddings[0] - + ((dilations[0] * (filter_depth - 1) + 1))) / + strides[0] + 1, output_depth, "input_depth and output_depth are " - "Mismatching."); - PADDLE_ENFORCE_EQ((input_height + 2 * padding_height - -
((dilation_h * (filter_height - 1) + 1))) / - stride_height + + "mismatching."); + PADDLE_ENFORCE_EQ((input_height + 2 * paddings[1] - + ((dilations[1] * (filter_height - 1) + 1))) / + strides[1] + 1, output_height, "input_height and output_height are " - "Mismatching."); - PADDLE_ENFORCE_EQ((input_width + 2 * padding_width - - ((dilation_w * (filter_width - 1) + 1))) / - stride_width + + "mismatching."); + PADDLE_ENFORCE_EQ((input_width + 2 * paddings[2] - + ((dilations[2] * (filter_width - 1) + 1))) / + strides[2] + 1, output_width, "input_width and output_width are " - "Mismatching."); + "mismatching."); const T* vol_data = vol.data(); - T* col_data = col.data(); + T* col_data = col->data(); for (int c = 0; c < channels_col; ++c) { int w_offset = c % filter_width; @@ -80,13 +80,11 @@ class Vol2ColFunctor { int d_offset = (c / filter_width / filter_height) % filter_depth; int c_in = c / filter_width / filter_height / filter_depth; for (int d = 0; d < output_depth; ++d) { - int d_pad = d * stride_depth - padding_depth + d_offset * dilation_d; + int d_pad = d * strides[0] - paddings[0] + d_offset * dilations[0]; for (int h = 0; h < output_height; ++h) { - int h_pad = - h * stride_height - padding_height + h_offset * dilation_h; + int h_pad = h * strides[1] - paddings[1] + h_offset * dilations[1]; for (int w = 0; w < output_width; ++w) { - int w_pad = - w * stride_width - padding_width + w_offset * dilation_w; + int w_pad = w * strides[2] - paddings[2] + w_offset * dilations[2]; int col_idx = ((c * output_depth + d) * output_height + h) * output_width + w; @@ -116,18 +114,18 @@ template class Col2VolFunctor { public: void operator()(const platform::DeviceContext& context, - framework::Tensor& vol, const framework::Tensor& col, - int dilation_d, int dilation_h, int dilation_w, - int stride_depth, int stride_height, int stride_width, - int padding_depth, int padding_height, - int padding_width) const { - PADDLE_ENFORCE(vol.dims().size() == 4); + const framework::Tensor& col, + const std::vector& dilations, + const std::vector& strides, + const std::vector& paddings, + framework::Tensor* vol) const { + PADDLE_ENFORCE(vol->dims().size() == 4); PADDLE_ENFORCE(col.dims().size() == 7); - int input_channels = vol.dims()[0]; - int input_depth = vol.dims()[1]; - int input_height = vol.dims()[2]; - int input_width = vol.dims()[3]; + int input_channels = vol->dims()[0]; + int input_depth = vol->dims()[1]; + int input_height = vol->dims()[2]; + int input_width = vol->dims()[3]; int filter_depth = col.dims()[1]; int filter_height = col.dims()[2]; int filter_width = col.dims()[3]; @@ -137,28 +135,28 @@ class Col2VolFunctor { int channels_col = input_channels * filter_depth * filter_height * filter_width; - PADDLE_ENFORCE_EQ((input_depth + 2 * padding_depth - - ((dilation_d * (filter_depth - 1) + 1))) / - stride_depth + + PADDLE_ENFORCE_EQ((input_depth + 2 * paddings[0] - + ((dilations[0] * (filter_depth - 1) + 1))) / + strides[0] + 1, output_depth, "input_depth and output_depth are " - "Mismatching."); - PADDLE_ENFORCE_EQ((input_height + 2 * padding_height - - ((dilation_h * (filter_height - 1) + 1))) / - stride_height + + "mismatching."); + PADDLE_ENFORCE_EQ((input_height + 2 * paddings[1] - + ((dilations[1] * (filter_height - 1) + 1))) / + strides[1] + 1, output_height, "input_height and output_height are " - "Mismatching."); - PADDLE_ENFORCE_EQ((input_width + 2 * padding_width - - ((dilation_w * (filter_width - 1) + 1))) / - stride_width + + "mismatching."); + PADDLE_ENFORCE_EQ((input_width + 2 * 
paddings[2] - + ((dilations[2] * (filter_width - 1) + 1))) / + strides[2] + 1, output_width, "input_width and output_width are " - "Mismatching."); - T* vol_data = vol.data(); + "mismatching."); + T* vol_data = vol->data(); const T* col_data = col.data(); for (int c = 0; c < channels_col; ++c) { @@ -167,13 +165,11 @@ class Col2VolFunctor { int d_offset = (c / filter_width / filter_height) % filter_depth; int cIm = c / filter_width / filter_height / filter_depth; for (int d = 0; d < output_depth; ++d) { - int d_pad = d * stride_depth - padding_depth + d_offset * dilation_d; + int d_pad = d * strides[0] - paddings[0] + d_offset * dilations[0]; for (int h = 0; h < output_height; ++h) { - int h_pad = - h * stride_height - padding_height + h_offset * dilation_h; + int h_pad = h * strides[1] - paddings[1] + h_offset * dilations[1]; for (int w = 0; w < output_width; ++w) { - int w_pad = - w * stride_width - padding_width + w_offset * dilation_w; + int w_pad = w * strides[2] - paddings[2] + w_offset * dilations[2]; if (h_pad >= 0 && h_pad < input_height && w_pad >= 0 && w_pad < input_width && d_pad >= 0 && d_pad < input_depth) { diff --git a/paddle/operators/math/vol2col.cu b/paddle/operators/math/vol2col.cu index 080d3e5466704..addae3caf89ee 100644 --- a/paddle/operators/math/vol2col.cu +++ b/paddle/operators/math/vol2col.cu @@ -71,42 +71,42 @@ template class Vol2ColFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& vol, framework::Tensor& col, - int dilation_d, int dilation_h, int dilation_w, - int stride_depth, int stride_height, int stride_width, - int padding_depth, int padding_height, - int padding_width) const { + const framework::Tensor& vol, + const std::vector& dilations, + const std::vector& strides, + const std::vector& paddings, + framework::Tensor* col) const { PADDLE_ENFORCE(vol.dims().size() == 4); - PADDLE_ENFORCE(col.dims().size() == 7); + PADDLE_ENFORCE(col->dims().size() == 7); int input_channels = vol.dims()[0]; int input_depth = vol.dims()[1]; int input_height = vol.dims()[2]; int input_width = vol.dims()[3]; - int filter_depth = col.dims()[1]; - int filter_height = col.dims()[2]; - int filter_width = col.dims()[3]; - int output_depth = col.dims()[4]; - int output_height = col.dims()[5]; - int output_width = col.dims()[6]; + int filter_depth = col->dims()[1]; + int filter_height = col->dims()[2]; + int filter_width = col->dims()[3]; + int output_depth = col->dims()[4]; + int output_height = col->dims()[5]; + int output_width = col->dims()[6]; - PADDLE_ENFORCE_EQ((input_depth + 2 * padding_depth - - ((dilation_d * (filter_depth - 1) + 1))) / - stride_depth + + PADDLE_ENFORCE_EQ((input_depth + 2 * paddings[0] - + ((dilations[0] * (filter_depth - 1) + 1))) / + strides[0] + 1, output_depth, "input_depth and output_depth are " "Mismatching."); - PADDLE_ENFORCE_EQ((input_height + 2 * padding_height - - ((dilation_h * (filter_height - 1) + 1))) / - stride_height + + PADDLE_ENFORCE_EQ((input_height + 2 * paddings[1] - + ((dilations[1] * (filter_height - 1) + 1))) / + strides[1] + 1, output_height, "input_height and output_height are " "Mismatching."); - PADDLE_ENFORCE_EQ((input_width + 2 * padding_width - - ((dilation_w * (filter_width - 1) + 1))) / - stride_width + + PADDLE_ENFORCE_EQ((input_width + 2 * paddings[2] - + ((dilations[2] * (filter_width - 1) + 1))) / + strides[2] + 1, output_width, "input_width and output_width are " @@ -121,10 +121,10 @@ class Vol2ColFunctor { reinterpret_cast(context) .stream()>>>( num_outputs, 
vol.data<T>(), input_depth, input_height, input_width, - dilation_d, dilation_h, dilation_w, filter_depth, filter_height, - filter_width, stride_depth, stride_height, stride_width, padding_depth, - padding_height, padding_width, output_depth, output_height, - output_width, col.data<T>()); + dilations[0], dilations[1], dilations[2], filter_depth, filter_height, + filter_width, strides[0], strides[1], strides[2], paddings[0], + paddings[1], paddings[2], output_depth, output_height, output_width, + col->data<T>()); } }; @@ -200,18 +200,18 @@ template <typename Place, typename T> class Col2VolFunctor { public: void operator()(const platform::DeviceContext& context, - framework::Tensor& vol, const framework::Tensor& col, - int dilation_d, int dilation_h, int dilation_w, - int stride_depth, int stride_height, int stride_width, - int padding_depth, int padding_height, - int padding_width) const { - PADDLE_ENFORCE(vol.dims().size() == 4); + const framework::Tensor& col, + const std::vector<int>& dilations, + const std::vector<int>& strides, + const std::vector<int>& paddings, + framework::Tensor* vol) const { + PADDLE_ENFORCE(vol->dims().size() == 4); PADDLE_ENFORCE(col.dims().size() == 7); - int input_channels = vol.dims()[0]; - int input_depth = vol.dims()[1]; - int input_height = vol.dims()[2]; - int input_width = vol.dims()[3]; + int input_channels = vol->dims()[0]; + int input_depth = vol->dims()[1]; + int input_height = vol->dims()[2]; + int input_width = vol->dims()[3]; int filter_depth = col.dims()[1]; int filter_height = col.dims()[2]; int filter_width = col.dims()[3]; @@ -219,23 +219,23 @@ class Col2VolFunctor { int output_height = col.dims()[5]; int output_width = col.dims()[6]; - PADDLE_ENFORCE_EQ((input_depth + 2 * padding_depth - - ((dilation_d * (filter_depth - 1) + 1))) / - stride_depth + + PADDLE_ENFORCE_EQ((input_depth + 2 * paddings[0] - + ((dilations[0] * (filter_depth - 1) + 1))) / + strides[0] + 1, output_depth, "input_depth and output_depth are " "Mismatching."); - PADDLE_ENFORCE_EQ((input_height + 2 * padding_height - - ((dilation_h * (filter_height - 1) + 1))) / - stride_height + + PADDLE_ENFORCE_EQ((input_height + 2 * paddings[1] - + ((dilations[1] * (filter_height - 1) + 1))) / + strides[1] + 1, output_height, "input_height and output_height are " "Mismatching."); - PADDLE_ENFORCE_EQ((input_width + 2 * padding_width - - ((dilation_w * (filter_width - 1) + 1))) / - stride_width + + PADDLE_ENFORCE_EQ((input_width + 2 * paddings[2] - + ((dilations[2] * (filter_width - 1) + 1))) / + strides[2] + 1, output_width, "input_width and output_width are " @@ -250,10 +250,10 @@ class Col2VolFunctor { reinterpret_cast<const platform::CUDADeviceContext&>(context) .stream()>>>( num_kernels, col.data<T>(), input_depth, input_height, input_width, - dilation_d, dilation_h, dilation_w, filter_depth, filter_height, - filter_width, stride_depth, stride_height, stride_width, padding_depth, - padding_height, padding_width, output_depth, output_height, - output_width, vol.data<T>()); + dilations[0], dilations[1], dilations[2], filter_depth, filter_height, + filter_width, strides[0], strides[1], strides[2], paddings[0], + paddings[1], paddings[2], output_depth, output_height, output_width, + vol->data<T>()); } }; diff --git a/paddle/operators/math/vol2col.h b/paddle/operators/math/vol2col.h index c2d8257c0ba5b..cbc30bd754608 100644 --- a/paddle/operators/math/vol2col.h +++ b/paddle/operators/math/vol2col.h @@ -31,6 +31,15 @@ namespace math { * \param colData Column data. * \param colShape The shape of colData. * + * \param dilations dilation data. + * 3-dimension [dilation_depth, dilation_height, dilation_width]. + * + * \param strides stride data. + * 3-dimension [stride_depth, stride_height, stride_width]. + * + * \param paddings padding data. + * 3-dimension [d_pad, h_pad, w_pad]. + * * The shape of colData is: * [input_channels, filter_depth, filter_height, filter_width, output_depth, * output_height, output_width] @@ -57,22 +66,22 @@ template <typename Place, typename T> class Vol2ColFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& vol, framework::Tensor& col, - int dilation_d, int dilation_h, int dilation_w, - int stride_depth, int stride_height, int stride_width, - int padding_depth, int padding_height, - int padding_width) const; + const framework::Tensor& vol, + const std::vector<int>& dilations, + const std::vector<int>& strides, + const std::vector<int>& paddings, + framework::Tensor* col) const; }; template <typename Place, typename T> class Col2VolFunctor { public: void operator()(const platform::DeviceContext& context, - framework::Tensor& vol, const framework::Tensor& col, - int dilation_d, int dilation_h, int dilation_w, - int stride_depth, int stride_height, int stride_width, - int padding_depth, int padding_height, - int padding_width) const; + const framework::Tensor& col, + const std::vector<int>& dilations, + const std::vector<int>& strides, + const std::vector<int>& paddings, + framework::Tensor* vol) const; }; } // namespace math
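As with im2col above, the PADDLE_ENFORCE_EQ checks in vol2col all encode one output-size rule. A short sketch of that rule under this header's symmetric [d_pad, h_pad, w_pad] padding convention (DilatedOutputSize is an illustrative helper, not part of the patch):

#include <cassert>

// Output extent of one axis: a filter of size f with dilation d covers
// d * (f - 1) + 1 input elements.
int DilatedOutputSize(int input, int filter, int dilation, int padding,
                      int stride) {
  return (input + 2 * padding - (dilation * (filter - 1) + 1)) / stride + 1;
}

int main() {
  // vol2col_test.cc below uses input 2x2x3, filter 2, dilation 1, pad 0,
  // stride 1, so output is 1x1x2.
  assert(DilatedOutputSize(2, 2, 1, 0, 1) == 1);  // output_depth, output_height
  assert(DilatedOutputSize(3, 2, 1, 0, 1) == 2);  // output_width
  // With dilation 2, a 3-tap filter spans 2 * (3 - 1) + 1 = 5 inputs:
  assert(DilatedOutputSize(7, 3, 2, 0, 1) == 3);
  return 0;
}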
diff --git a/paddle/operators/math/vol2col_test.cc b/paddle/operators/math/vol2col_test.cc index 9d673ad36cfed..c31c716842f30 100644 --- a/paddle/operators/math/vol2col_test.cc +++ b/paddle/operators/math/vol2col_test.cc @@ -62,12 +62,15 @@ void testVol2col() { int input_height = 2; int input_width = 3; int filter_size = 2; - int stride = 1; - int padding = 0; - int dilation = 1; - int output_depth = (input_depth - filter_size + 2 * padding) / stride + 1; - int output_height = (input_height - filter_size + 2 * padding) / stride + 1; - int output_width = (input_width - filter_size + 2 * padding) / stride + 1; + std::vector<int> strides({1, 1, 1}); + std::vector<int> paddings({0, 0, 0}); + std::vector<int> dilations({1, 1, 1}); + int output_depth = + (input_depth - filter_size + 2 * paddings[0]) / strides[0] + 1; + int output_height = + (input_height - filter_size + 2 * paddings[1]) / strides[1] + 1; + int output_width = + (input_width - filter_size + 2 * paddings[2]) / strides[2] + 1; // Vol2Col test float* input_ptr = input_tmp.mutable_data<float>( {1, input_depth, input_height, input_width}, paddle::platform::CPUPlace()); float arr[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; @@ -86,8 +89,7 @@ void testVol2col() { *place); paddle::operators::math::Vol2ColFunctor<Place, float> vol2col; - vol2col(*context, input, output, dilation, dilation, dilation, stride, stride, - stride, padding, padding, padding); + vol2col(*context, input, dilations, strides, paddings, &output); float vol_2_col[] = {0, 1, 1, 2, 3, 4, 4, 5, 6, 7, 7, 8, 9, 10, 10, 11}; float* out_cfo_ptr; @@ -112,8 +114,7 @@ void testVol2col() { } paddle::operators::math::Col2VolFunctor<Place, float> col2vol; - col2vol(*context, input, output, dilation, dilation, dilation, stride, stride, - stride, padding, padding, padding); + col2vol(*context, output, dilations, strides, paddings, &input); float* in_ptr; if (paddle::platform::is_cpu_place(*place)) { From 31dc0193c958e9ba723ee89fc602a01479d0bbf1 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Wed, 15 Nov 2017 13:23:23 +0800 Subject: [PATCH 8/9] fix ContextProjectFunctor parameter order --- paddle/operators/math/context_project.h | 36 +++++++++++++------------ paddle/operators/math/vol2col.cu | 7 +++-- paddle/operators/sequence_conv_op.h | 22 +++++++-------- 3 files changed, 33 insertions(+), 32 
deletions(-) diff --git a/paddle/operators/math/context_project.h b/paddle/operators/math/context_project.h index d9f952c387f2c..845de82bbcb33 100644 --- a/paddle/operators/math/context_project.h +++ b/paddle/operators/math/context_project.h @@ -88,9 +88,10 @@ template class ContextProjectFunctor { public: void operator()(const platform::DeviceContext& context, const LoDTensor& in, - const Tensor& padding_data, Tensor& col, - bool padding_trainable, int context_start, int context_length, - int context_stride, int up_pad, int down_pad) { + const Tensor& padding_data, bool padding_trainable, + const int context_start, const int context_length, + const int context_stride, const int up_pad, + const int down_pad, Tensor* col) { auto lod_level_0 = in.lod()[0]; math::Im2ColFunctor im2col_ocf; @@ -109,8 +110,8 @@ class ContextProjectFunctor { : static_cast(lod_level_0[i]); input_row_end = static_cast(lod_level_0[i + 1]); - Tensor out_t = col.Slice(static_cast(lod_level_0[i]), - static_cast(lod_level_0[i + 1])); + Tensor out_t = col->Slice(static_cast(lod_level_0[i]), + static_cast(lod_level_0[i + 1])); sequence_height = static_cast(out_t.dims()[0]); @@ -133,8 +134,8 @@ class ContextProjectFunctor { } if (padding_trainable) { for (int i = 0; i < static_cast(lod_level_0.size()) - 1; ++i) { - Tensor out_t = col.Slice(static_cast(lod_level_0[i]), - static_cast(lod_level_0[i + 1])); + Tensor out_t = col->Slice(static_cast(lod_level_0[i]), + static_cast(lod_level_0[i + 1])); sequence_height = static_cast(out_t.dims()[0]); @@ -197,10 +198,11 @@ class ContextProjectFunctor { template class ContextProjectGradFunctor { public: - void operator()(const platform::DeviceContext& context, LoDTensor& in, - Tensor& padding_data, Tensor& col, bool padding_trainable, - int context_start, int context_length, int context_stride, - int up_pad, int down_pad, bool input_grad, bool pad_grad) { + void operator()(const platform::DeviceContext& context, const LoDTensor& in, + bool padding_trainable, const int context_start, + const int context_length, const int context_stride, + const int up_pad, const int down_pad, bool pad_grad, + bool input_grad, Tensor* padding_data, Tensor* col) { auto lod_level_0 = in.lod()[0]; math::Col2ImFunctor col2im_ocf; @@ -220,8 +222,8 @@ class ContextProjectGradFunctor { : static_cast(lod_level_0[i]); input_row_end = static_cast(lod_level_0[i + 1]); - Tensor out_t = col.Slice(static_cast(lod_level_0[i]), - static_cast(lod_level_0[i + 1])); + Tensor out_t = col->Slice(static_cast(lod_level_0[i]), + static_cast(lod_level_0[i + 1])); sequence_height = static_cast(out_t.dims()[0]); @@ -247,8 +249,8 @@ class ContextProjectGradFunctor { if (pad_grad) { if (padding_trainable) { for (int i = 0; i < static_cast(lod_level_0.size()) - 1; ++i) { - Tensor out_t = col.Slice(static_cast(lod_level_0[i]), - static_cast(lod_level_0[i + 1])); + Tensor out_t = col->Slice(static_cast(lod_level_0[i]), + static_cast(lod_level_0[i + 1])); sequence_height = static_cast(out_t.dims()[0]); out_t.Resize({sequence_height * context_length, sequence_width}); @@ -262,7 +264,7 @@ class ContextProjectGradFunctor { k + context_length < up_pad ? 
context_length : up_pad - k; Tensor out_t_sub = out_t.Slice(k * context_length, k * context_length + padding_size); - Tensor w_sub = padding_data.Slice(k, k + padding_size); + Tensor w_sub = padding_data->Slice(k, k + padding_size); auto out_t_sub_e = EigenMatrix<T>::From(out_t_sub); auto w_sub_e = EigenMatrix<T>::From(w_sub); w_sub_e.device(*context.GetEigenDevice<Place>()) = @@ -295,7 +297,7 @@ class ContextProjectGradFunctor { Tensor out_t_sub = out_t.Slice( (down_pad_begin_row + t) * context_length - padding_size, (down_pad_begin_row + t) * context_length); - Tensor w_sub = padding_data.Slice( + Tensor w_sub = padding_data->Slice( up_pad + padding_idx, up_pad + padding_idx + padding_size); auto out_t_sub_e = EigenMatrix<T>::From(out_t_sub); auto w_sub_e = EigenMatrix<T>::From(w_sub); diff --git a/paddle/operators/math/vol2col.cu b/paddle/operators/math/vol2col.cu index addae3caf89ee..dae3be858e9f4 100644 --- a/paddle/operators/math/vol2col.cu +++ b/paddle/operators/math/vol2col.cu @@ -174,10 +174,9 @@ __global__ void col2vol(int num_kernels, const T* data_col, int depth, int data_col_index = (((((c * filter_depth + d_off) * filter_height + h_off) * filter_width + - w_off) * - output_detph + - d_col) * - output_height + + w_off))); + data_col_index = + ((data_col_index * output_detph + d_col) * output_height + h_col) * output_width + w_col; diff --git a/paddle/operators/sequence_conv_op.h b/paddle/operators/sequence_conv_op.h index a57e1752bb8ed..adee8d760e1d6 100644 --- a/paddle/operators/sequence_conv_op.h +++ b/paddle/operators/sequence_conv_op.h @@ -62,9 +62,9 @@ class SequenceConvKernel : public framework::OpKernel<T> { math::ContextProjectFunctor<Place, T> seq_project_functor; - seq_project_functor(context.device_context(), *in, *padding_data, col, + seq_project_functor(context.device_context(), *in, *padding_data, padding_trainable, context_start, context_length, - context_stride, up_pad, down_pad); + context_stride, up_pad, down_pad, &col); math::matmul<Place, T>(context.device_context(), col, false, filter, false, static_cast<T>(1.0), out, static_cast<T>(0.0)); @@ -117,10 +117,10 @@ class SequenceConvGradKernel : public framework::OpKernel<T> { in_g->set_lod(in->lod()); set_zero(context.device_context(), in_g, static_cast<T>(0)); - seq_project_grad_functor(context.device_context(), *in_g, *padding_data_g, - col, padding_trainable, context_start, - context_length, context_stride, up_pad, down_pad, - true, false); + seq_project_grad_functor(context.device_context(), *in_g, + padding_trainable, context_start, context_length, + context_stride, up_pad, down_pad, false, true, + padding_data_g, &col); } if (padding_trainable && padding_data_g) { @@ -129,9 +129,9 @@ class SequenceConvGradKernel : public framework::OpKernel<T> { LoDTensor* input = const_cast<LoDTensor*>(in); seq_project_grad_functor(context.device_context(), *input, - *padding_data_g, col, padding_trainable, - context_start, context_length, context_stride, - up_pad, down_pad, false, true); + padding_trainable, context_start, context_length, + context_stride, up_pad, down_pad, true, false, + padding_data_g, &col); } if (filter_g) { @@ -146,9 +146,9 @@ class SequenceConvGradKernel : public framework::OpKernel<T> { padding_data = context.Input<Tensor>("PaddingData"); } - seq_project_functor(context.device_context(), *in, *padding_data, col, + seq_project_functor(context.device_context(), *in, *padding_data, padding_trainable, context_start, context_length, - context_stride, up_pad, down_pad); + context_stride, up_pad, down_pad, &col); math::matmul<Place, T>(context.device_context(), col, true, out_grad, false, T(1.0), &filter_grad, T(1.0));
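The reordering above settles the functor calling convention: const inputs by reference first, then scalar attributes, then mutable outputs by pointer in the trailing positions. A toy functor in that style, for illustration only (none of these names are Paddle code):

#include <cstddef>
#include <vector>

struct ScaleFunctor {
  // Inputs (const ref), then an attribute, then the output (pointer, last).
  void operator()(const std::vector<float>& in, float scale,
                  std::vector<float>* out) const {
    out->resize(in.size());
    for (std::size_t i = 0; i < in.size(); ++i) (*out)[i] = scale * in[i];
  }
};

int main() {
  std::vector<float> in = {1.f, 2.f, 3.f};
  std::vector<float> out;
  ScaleFunctor()(in, 2.f, &out);  // call site reads left to right: inputs -> outputs
  return (out[2] == 6.f) ? 0 : 1;
}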
From 00e0881bfb1fa3d633a360032ce85e80e966a0b3 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Wed, 15 Nov 2017 19:58:39 +0800 Subject: [PATCH 9/9] remove conflict --- python/paddle/v2/framework/proto/__init__.py | 0 .../v2/framework/proto/framework_pb2.py | 1076 ----------------- 2 files changed, 1076 deletions(-) delete mode 100644 python/paddle/v2/framework/proto/__init__.py delete mode 100644 python/paddle/v2/framework/proto/framework_pb2.py diff --git a/python/paddle/v2/framework/proto/__init__.py b/python/paddle/v2/framework/proto/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/python/paddle/v2/framework/proto/framework_pb2.py b/python/paddle/v2/framework/proto/framework_pb2.py deleted file mode 100644 index 950cd2290724d..0000000000000 --- a/python/paddle/v2/framework/proto/framework_pb2.py +++ /dev/null @@ -1,1076 +0,0 @@ -# Generated by the protocol buffer compiler. DO NOT EDIT! -# source: framework.proto
\x01(\x0b\x32\x1f.paddle.framework.LoDTensorDesc\"|\n\tBlockDesc\x12\x0b\n\x03idx\x18\x01 \x02(\x05\x12\x12\n\nparent_idx\x18\x02 \x02(\x05\x12\'\n\x04vars\x18\x03 \x03(\x0b\x32\x19.paddle.framework.VarDesc\x12%\n\x03ops\x18\x04 \x03(\x0b\x32\x18.paddle.framework.OpDesc\":\n\x0bProgramDesc\x12+\n\x06\x62locks\x18\x01 \x03(\x0b\x32\x1b.paddle.framework.BlockDesc*s\n\x08\x41ttrType\x12\x07\n\x03INT\x10\x00\x12\t\n\x05\x46LOAT\x10\x01\x12\n\n\x06STRING\x10\x02\x12\x08\n\x04INTS\x10\x03\x12\n\n\x06\x46LOATS\x10\x04\x12\x0b\n\x07STRINGS\x10\x05\x12\x0b\n\x07\x42OOLEAN\x10\x06\x12\x0c\n\x08\x42OOLEANS\x10\x07\x12\t\n\x05\x42LOCK\x10\x08*S\n\x08\x44\x61taType\x12\x08\n\x04\x42OOL\x10\x00\x12\t\n\x05INT16\x10\x01\x12\t\n\x05INT32\x10\x02\x12\t\n\x05INT64\x10\x03\x12\x08\n\x04\x46P16\x10\x04\x12\x08\n\x04\x46P32\x10\x05\x12\x08\n\x04\x46P64\x10\x06' - )) -_sym_db.RegisterFileDescriptor(DESCRIPTOR) - -_ATTRTYPE = _descriptor.EnumDescriptor( - name='AttrType', - full_name='paddle.framework.AttrType', - filename=None, - file=DESCRIPTOR, - values=[ - _descriptor.EnumValueDescriptor( - name='INT', index=0, number=0, options=None, type=None), - _descriptor.EnumValueDescriptor( - name='FLOAT', index=1, number=1, options=None, type=None), - _descriptor.EnumValueDescriptor( - name='STRING', index=2, number=2, options=None, type=None), - _descriptor.EnumValueDescriptor( - name='INTS', index=3, number=3, options=None, type=None), - _descriptor.EnumValueDescriptor( - name='FLOATS', index=4, number=4, options=None, type=None), - _descriptor.EnumValueDescriptor( - name='STRINGS', index=5, number=5, options=None, type=None), - _descriptor.EnumValueDescriptor( - name='BOOLEAN', index=6, number=6, options=None, type=None), - _descriptor.EnumValueDescriptor( - name='BOOLEANS', index=7, number=7, options=None, type=None), - _descriptor.EnumValueDescriptor( - name='BLOCK', index=8, number=8, options=None, type=None), - ], - containing_type=None, - options=None, - serialized_start=1218, - serialized_end=1333, ) -_sym_db.RegisterEnumDescriptor(_ATTRTYPE) - -AttrType = enum_type_wrapper.EnumTypeWrapper(_ATTRTYPE) -_DATATYPE = _descriptor.EnumDescriptor( - name='DataType', - full_name='paddle.framework.DataType', - filename=None, - file=DESCRIPTOR, - values=[ - _descriptor.EnumValueDescriptor( - name='BOOL', index=0, number=0, options=None, type=None), - _descriptor.EnumValueDescriptor( - name='INT16', index=1, number=1, options=None, type=None), - _descriptor.EnumValueDescriptor( - name='INT32', index=2, number=2, options=None, type=None), - _descriptor.EnumValueDescriptor( - name='INT64', index=3, number=3, options=None, type=None), - _descriptor.EnumValueDescriptor( - name='FP16', index=4, number=4, options=None, type=None), - _descriptor.EnumValueDescriptor( - name='FP32', index=5, number=5, options=None, type=None), - _descriptor.EnumValueDescriptor( - name='FP64', index=6, number=6, options=None, type=None), - ], - containing_type=None, - options=None, - serialized_start=1335, - serialized_end=1418, ) -_sym_db.RegisterEnumDescriptor(_DATATYPE) - -DataType = enum_type_wrapper.EnumTypeWrapper(_DATATYPE) -INT = 0 -FLOAT = 1 -STRING = 2 -INTS = 3 -FLOATS = 4 -STRINGS = 5 -BOOLEAN = 6 -BOOLEANS = 7 -BLOCK = 8 -BOOL = 0 -INT16 = 1 -INT32 = 2 -INT64 = 3 -FP16 = 4 -FP32 = 5 -FP64 = 6 - -_OPDESC_ATTR = _descriptor.Descriptor( - name='Attr', - full_name='paddle.framework.OpDesc.Attr', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='name', - 
full_name='paddle.framework.OpDesc.Attr.name', - index=0, - number=1, - type=9, - cpp_type=9, - label=2, - has_default_value=False, - default_value=_b("").decode('utf-8'), - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='type', - full_name='paddle.framework.OpDesc.Attr.type', - index=1, - number=2, - type=14, - cpp_type=8, - label=2, - has_default_value=False, - default_value=0, - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='i', - full_name='paddle.framework.OpDesc.Attr.i', - index=2, - number=3, - type=5, - cpp_type=1, - label=1, - has_default_value=False, - default_value=0, - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='f', - full_name='paddle.framework.OpDesc.Attr.f', - index=3, - number=4, - type=2, - cpp_type=6, - label=1, - has_default_value=False, - default_value=float(0), - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='s', - full_name='paddle.framework.OpDesc.Attr.s', - index=4, - number=5, - type=9, - cpp_type=9, - label=1, - has_default_value=False, - default_value=_b("").decode('utf-8'), - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='ints', - full_name='paddle.framework.OpDesc.Attr.ints', - index=5, - number=6, - type=5, - cpp_type=1, - label=3, - has_default_value=False, - default_value=[], - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='floats', - full_name='paddle.framework.OpDesc.Attr.floats', - index=6, - number=7, - type=2, - cpp_type=6, - label=3, - has_default_value=False, - default_value=[], - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='strings', - full_name='paddle.framework.OpDesc.Attr.strings', - index=7, - number=8, - type=9, - cpp_type=9, - label=3, - has_default_value=False, - default_value=[], - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='b', - full_name='paddle.framework.OpDesc.Attr.b', - index=8, - number=10, - type=8, - cpp_type=7, - label=1, - has_default_value=False, - default_value=False, - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='bools', - full_name='paddle.framework.OpDesc.Attr.bools', - index=9, - number=11, - type=8, - cpp_type=7, - label=3, - has_default_value=False, - default_value=[], - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='block_idx', - full_name='paddle.framework.OpDesc.Attr.block_idx', - index=10, - number=12, - type=5, - cpp_type=1, - label=1, - has_default_value=False, - default_value=0, - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - 
extension_scope=None, - options=None), - ], - extensions=[], - nested_types=[], - enum_types=[], - options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[], - serialized_start=202, - serialized_end=389, ) - -_OPDESC_VAR = _descriptor.Descriptor( - name='Var', - full_name='paddle.framework.OpDesc.Var', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='parameter', - full_name='paddle.framework.OpDesc.Var.parameter', - index=0, - number=1, - type=9, - cpp_type=9, - label=2, - has_default_value=False, - default_value=_b("").decode('utf-8'), - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='arguments', - full_name='paddle.framework.OpDesc.Var.arguments', - index=1, - number=2, - type=9, - cpp_type=9, - label=3, - has_default_value=False, - default_value=[], - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - ], - extensions=[], - nested_types=[], - enum_types=[], - options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[], - serialized_start=391, - serialized_end=434, ) - -_OPDESC = _descriptor.Descriptor( - name='OpDesc', - full_name='paddle.framework.OpDesc', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='type', - full_name='paddle.framework.OpDesc.type', - index=0, - number=3, - type=9, - cpp_type=9, - label=2, - has_default_value=False, - default_value=_b("").decode('utf-8'), - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='inputs', - full_name='paddle.framework.OpDesc.inputs', - index=1, - number=1, - type=11, - cpp_type=10, - label=3, - has_default_value=False, - default_value=[], - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='outputs', - full_name='paddle.framework.OpDesc.outputs', - index=2, - number=2, - type=11, - cpp_type=10, - label=3, - has_default_value=False, - default_value=[], - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='attrs', - full_name='paddle.framework.OpDesc.attrs', - index=3, - number=4, - type=11, - cpp_type=10, - label=3, - has_default_value=False, - default_value=[], - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - ], - extensions=[], - nested_types=[ - _OPDESC_ATTR, - _OPDESC_VAR, - ], - enum_types=[], - options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[], - serialized_start=38, - serialized_end=434, ) - -_OPPROTO_VAR = _descriptor.Descriptor( - name='Var', - full_name='paddle.framework.OpProto.Var', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='name', - full_name='paddle.framework.OpProto.Var.name', - index=0, - number=1, - type=9, - cpp_type=9, - label=2, - has_default_value=False, - default_value=_b("").decode('utf-8'), - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - 
_descriptor.FieldDescriptor( - name='comment', - full_name='paddle.framework.OpProto.Var.comment', - index=1, - number=2, - type=9, - cpp_type=9, - label=2, - has_default_value=False, - default_value=_b("").decode('utf-8'), - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='duplicable', - full_name='paddle.framework.OpProto.Var.duplicable', - index=2, - number=3, - type=8, - cpp_type=7, - label=1, - has_default_value=True, - default_value=False, - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='intermediate', - full_name='paddle.framework.OpProto.Var.intermediate', - index=3, - number=4, - type=8, - cpp_type=7, - label=1, - has_default_value=True, - default_value=False, - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='not_in_gradient', - full_name='paddle.framework.OpProto.Var.not_in_gradient', - index=4, - number=5, - type=8, - cpp_type=7, - label=1, - has_default_value=True, - default_value=False, - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - ], - extensions=[], - nested_types=[], - enum_types=[], - options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[], - serialized_start=621, - serialized_end=745, ) - -_OPPROTO_ATTR = _descriptor.Descriptor( - name='Attr', - full_name='paddle.framework.OpProto.Attr', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='name', - full_name='paddle.framework.OpProto.Attr.name', - index=0, - number=1, - type=9, - cpp_type=9, - label=2, - has_default_value=False, - default_value=_b("").decode('utf-8'), - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='type', - full_name='paddle.framework.OpProto.Attr.type', - index=1, - number=2, - type=14, - cpp_type=8, - label=2, - has_default_value=False, - default_value=0, - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='comment', - full_name='paddle.framework.OpProto.Attr.comment', - index=2, - number=3, - type=9, - cpp_type=9, - label=2, - has_default_value=False, - default_value=_b("").decode('utf-8'), - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='generated', - full_name='paddle.framework.OpProto.Attr.generated', - index=3, - number=4, - type=8, - cpp_type=7, - label=1, - has_default_value=True, - default_value=False, - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - ], - extensions=[], - nested_types=[], - enum_types=[], - options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[], - serialized_start=747, - serialized_end=852, ) - -_OPPROTO = _descriptor.Descriptor( - name='OpProto', - full_name='paddle.framework.OpProto', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='type', - 
[Elided: the remainder of this hunk removes flattened, machine-generated protoc output from framework_pb2 — the _descriptor.Descriptor/FieldDescriptor tables, the enum/message wiring, and the _sym_db.RegisterMessage registrations for paddle.framework.OpProto (fields: type, inputs, outputs, attrs, comment), LoDTensorDesc (data_type, dims, lod_level), VarDesc (name, lod_tensor), BlockDesc (idx, parent_idx, vars, ops), and ProgramDesc (blocks), ending at the module-scope @@protoc_insertion_point.]
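For reference, the schema encoded by the deleted descriptor tables can be read back from their FieldDescriptor entries (name, number, type, label, default). Below is a minimal proto2 sketch reconstructed from those entries — not the original framework.proto source. The DataType enum values and the bodies of OpDesc, OpProto.Var, and OpProto.Attr sit in parts of the file not shown in this hunk, so they appear only as clearly-marked placeholders.

syntax = "proto2";
package paddle.framework;

// Placeholders: the real definitions are elided in earlier hunks.
enum DataType { DATA_TYPE_PLACEHOLDER = 0; }  // actual value names not visible here
message OpDesc {}                             // full definition elided above

message OpProto {
  message Var {}   // body not visible in this hunk
  message Attr {}  // body not visible; wiring above gives it an AttrType 'type' field
  required string type = 1;
  repeated Var inputs = 2;
  repeated Var outputs = 3;
  repeated Attr attrs = 4;
  required string comment = 5;
}

message LoDTensorDesc {
  required DataType data_type = 1;
  repeated int64 dims = 2;
  optional int32 lod_level = 3 [default = 0];
}

message VarDesc {
  required string name = 1;
  optional LoDTensorDesc lod_tensor = 2;
}

message BlockDesc {
  required int32 idx = 1;
  required int32 parent_idx = 2;
  repeated VarDesc vars = 3;
  repeated OpDesc ops = 4;
}

message ProgramDesc {
  repeated BlockDesc blocks = 1;
}

Field numbers and required/optional/repeated labels come directly from the deleted descriptors; for instance, lod_level is optional with default 0, matching has_default_value=True, default_value=0 in the removed FieldDescriptor.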