Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dilations for conv2d and optimize conv2d code #5472

Merged
merged 11 commits into from
Nov 15, 2017

Conversation

chengduoZH
Copy link
Contributor

@chengduoZH chengduoZH commented Nov 8, 2017

fix #5495
fix #5507
fix #5550

  • Add dilations for im2col(ColFormat::kCFO).
  • Fix conv2d doc (Add dilation)
  • Refine conv2d_op unit test.(Add dilation)
  • Refine conv_op for filter size:(1,1), (when filter_size, paddings, strides, dilations are (1,1), (0, 0), (1,1), (1,1), conv_op just likes fc_op)
  • Add dilations for vol2col(ColFormat::kCFO).
  • Fix conv3d doc (Add dilation)
  • Refine conv3d_op unit test.(Add dilation)

@chengduoZH chengduoZH force-pushed the refine_im2col branch 4 times, most recently from 10101ea to 27805c2 Compare November 8, 2017 06:53
@chengduoZH chengduoZH force-pushed the refine_im2col branch 3 times, most recently from 607e0e8 to 520dec7 Compare November 8, 2017 10:04
@chengduoZH chengduoZH changed the title add dilations for im2col add dilations for conv2d Nov 8, 2017
@chengduoZH chengduoZH changed the title add dilations for conv2d Add dilations for conv2d and optimize conv2d code Nov 9, 2017
@chengduoZH chengduoZH force-pushed the refine_im2col branch 2 times, most recently from b2d5245 to caf24f3 Compare November 9, 2017 08:39
@chengduoZH chengduoZH force-pushed the refine_im2col branch 2 times, most recently from a5bbe8d to 3e60b6b Compare November 10, 2017 05:55
@chengduoZH chengduoZH force-pushed the refine_im2col branch 2 times, most recently from a3e15fd to 7d73b8f Compare November 14, 2017 11:34
filter_1 &= (static_cast<int>(filter_dim[j]) == 1);
strides_1 &= (strides[j] == 1);
padding_0 &= (paddings[j] == 0);
dilation_1 &= (dilations[j] == 1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

&= -> &&= ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

没有这种表示&&=,我改成了strides_1 = strides_1 && (strides[j] == 1)

vol2col(context.device_context(), in_slice, col, strides[0],
strides[1], strides[2], paddings[0], paddings[1],
paddings[2]);
if (!not_expand) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

! not expand = expand, the logic is a little complex,

How about rename NotExpand to IsExpand? Then return True, means that it needs to expand, ortherwise, not expand.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

for (int i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
for (int g = 0; g < groups; g++) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for (int i = 0; i < batch_size; i++) {
 // ....
  for (int g = 0; g < groups; g++) {
    if(!IsExpand) {
      ShareDataWith();
    } else if () {
      im2col();  
    } else if () {
      im2vol();
    }
  }
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


math::matmul<Place, T>(context.device_context(), filter_slice, true,
out_grad_slice, false, T(1.0), &col_matrix,
T(0.0));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code structure is same as above.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

1,
col_width,
"col_width and padding(padding_left, padding_right) are "
"inconsistent.");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

写functor的时候,我也在考虑functor里面是否还有必要再次check shape的正确性,因为要么是在Op里计算得到的,要么InferShape里也已经check过。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是的,我也有想过这些,比如在Op里面检测过了,在GradOp中就不用检测了吧

int padding_down, int padding_left, int padding_right) {
int dilation_h, int dilation_w, int stride_height,
int stride_width, int padding_up, int padding_down,
int padding_left, int padding_right) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe std::vector<int>& dilations, std::vector<int>& strides, std::vector<int>& paddings are short. And the op also uses std::vector<int>.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

int output_size = (input_size + padding_up + padding_down -
(dilation * (filter_size - 1) + 1)) /
stride +
1;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

const int dkernel = dilation * (filter_size - 1) + 1;
const int output_size = (input_size + padding_up + padding_down - dkernel)/stride + 1;

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

1,
output_width,
"input_width and output_width are "
"Mismatching.");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above, whether it needs to check again?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can write in this way first, and discuss it later. Because other functors also have similar problem.

output_height +
h_col) *
output_width +
w_col;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

data_col_index的计算太长,不容易看清楚。一些计算可以移到各自循环里。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

把这个公式分成了两个,现在可能会好一点

@chengduoZH chengduoZH force-pushed the refine_im2col branch 2 times, most recently from 1607de4 to 8fffa9e Compare November 15, 2017 07:21
@chengduoZH chengduoZH merged commit 4fc9f55 into PaddlePaddle:develop Nov 15, 2017
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
2 participants