Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implementation of constantpad-3d op #5529

Merged
merged 15 commits into from
Jul 20, 2021
Merged

Conversation

Flowingsun007
Copy link
Contributor

@Flowingsun007 Flowingsun007 commented Jul 17, 2021

constantpad 3d 的op实现包括:

  • cpu/cuda kernel
  • functional api
  • gradient func
  • api doc
  • doctest
  • random tests vs pytorch

constantpad3d

Snip20210717_1

@@ -0,0 +1,66 @@
/*
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个系列的文件名,叫做 pad3d_xxxx ,改作全称的 constantpad3d_xxxx 会不会更好。
还有就是,依照命名的约定,我们一般不用 xxx_kernels_util.h(cpp,cu) 而是 xxx_kernel_util.h(cpp,cu) 也就是 util 里的 kernel 不用复数。
(目前合到 master 里的只有一个用了 xxx_kernels_util,应该是review时的漏网之鱼

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的,命名已修改

return static_cast<int8_t>(integral);
}

template<>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个空模板有什么意义吗

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

和L26-27的配合使用

const int64_t d_idx = 2;
const int64_t h_idx = 3;
const int64_t w_idx = 4;
// padding vector: [left, right, top, bottom, font, back]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

font->front

const NdIndexOffsetHelper<int64_t, 5>& index_helper, const ShapeView& x_shape,
const ShapeView& y_shape, const std::vector<int64_t>& padding,
IN_T constant_value) {
// for NCDHW format input tensor, index of n,c,d,h,w is 0,1,2,3,4
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

注释是不是可以直接去掉

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我留一处吧(其余的删掉

const int64_t d_idx = 2;
const int64_t h_idx = 3;
const int64_t w_idx = 4;
// padding vector: [left, right, top, bottom, font, back]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

.Input("dy", op.GetGradTensorWithOpOutput("y", 0))
.Output("dx")
.Attr("padding", op.attr<std::vector<int64_t>>("padding"))
.Attr("floating_value", op.attr<double>("floating_value"))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

带这两个attr的意义是什么呢?在实现的时候可以做自动类型推断吧。

Copy link
Contributor Author

@Flowingsun007 Flowingsun007 Jul 19, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个主要和2d系列的保持了对齐(后面要改的话,提另一个统一改吧

index_helper.OffsetToNdIndex(num, n, c, d, h, w);

const int64_t src_num = n_channel * x_depth * x_height * x_width;
if (pad_font <= d && d < pad_font + x_depth && w >= pad_left && w < x_width + pad_left
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里变量的拼写错误也fix一下吧,pad_front

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的

@oneflow-ci-bot oneflow-ci-bot self-requested a review July 19, 2021 14:13
@oneflow-ci-bot oneflow-ci-bot requested review from oneflow-ci-bot and removed request for oneflow-ci-bot July 19, 2021 15:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants