Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add flip module #5541

Merged
merged 26 commits into from
Jul 20, 2021
Merged

Add flip module #5541

merged 26 commits into from
Jul 20, 2021

Conversation

BBuf
Copy link
Contributor

@BBuf BBuf commented Jul 19, 2021

添加flow.flip module,对齐torch.flip。

  • 添加flip op和kernel
  • 添加functor
  • 添加gradient_funcs的后向实现
  • 实现module,导出flow.flip和tensor.flip
  • 添加docstting和单元测试以及doctest
  • 修改review意见

docs截图:

图片

doctest截图:

图片

@BBuf BBuf requested review from doombeaker and hjchen2 July 19, 2021 09:27
@BBuf BBuf requested a review from Flowingsun007 July 19, 2021 10:05
class Flip(Module):
def __init__(self, dims) -> None:
super().__init__()
assert isinstance(dims, list) or isinstance(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

可以写成:assert isinstance(dims, (list,tuple))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的



class Flip(Module):
def __init__(self, dims) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里dims没有参数类型检查


template<typename T>
__global__ void FlipGpuForward(const int32_t element, const int64_t total_dims,
const STRIDE_CONTIGUOUS_V stride_contiguous_v, const SIZE_V sizes_v,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里为啥不直接传const std::vector<int32_t>,而要构造STRIDE_CONTIGUOUS_V、SIZE_V、STRIDE?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cuda不能用stl这些结构,我沿用之前其它pr的做法

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的,不过STRIDE_CONTIGUOUS_V、SIZE_V、STRIDE的结构都一样,应该复用一个结构体就行(不需要构造3次)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的

} // namespace

template<typename T>
class FlipCpuKernel final : public user_op::OpKernel {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

kernel包在user_op的namespace下吧

public:
FlipFunctor() { op_ = CHECK_JUST(one::OpBuilder("flip").Input("x").Output("y").Build()); }
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,
const std::vector<int32_t> dims) const {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
const std::vector<int32_t> dims) const {
const std::vector<int32_t>& dims) const {

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的

op_ = CHECK_JUST(one::OpBuilder("flip_grad").Input("dy").Output("dx").Build());
}
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,
const std::vector<int32_t> dims) const {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
const std::vector<int32_t> dims) const {
const std::vector<int32_t>& dims) const {

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的

.SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
const user_op::TensorDesc* x_desc = ctx->TensorDesc4ArgNameAndIndex("x", 0);
user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0);
*y_desc->mut_shape() = x_desc->shape();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里需要check一下dims参数是否<=input tensor的维度

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的,应该是小于不能等于。

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

torch的可以等于

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

图片

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我意思是dims列表的长度<=x的维度数;你说的是具体到某个dim
WechatIMG82

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

两者应该都需要检查

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的,

self.dims = dims

def forward(self, x):
return flow.F.flip(x, self.dims)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

python层同样需要加一下dims检查,譬如:IndexError: Dimension out of range (expected to be in range of xxx, but got xxx)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的

SIZE_V sizes_v;
for (int32_t i = 0; i < total_dims; i++) { sizes_v.val[i] = y_tensor->shape().At(i); }

STRIDE strides_v;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

标记一下TODO吧(等stride支持以后,这段逻辑就可以不用了)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的

int32_t temp = cur_indices;
cur_indices = cur_indices / stride_contiguous_v.val[d];
rem = temp - cur_indices * stride_contiguous_v.val[d];
dst_offset += vis.val[d] ? (sizes_v.val[d] - 1 - cur_indices) * strides_v.val[d]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这块逻辑和torch好像没对齐?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

为什么没对齐呢

@BBuf BBuf requested review from oneflow-ci-bot and removed request for oneflow-ci-bot July 20, 2021 02:47
@oneflow-ci-bot oneflow-ci-bot removed their request for review July 20, 2021 04:26
@oneflow-ci-bot oneflow-ci-bot requested review from oneflow-ci-bot and removed request for oneflow-ci-bot July 20, 2021 04:26
@oneflow-ci-bot oneflow-ci-bot merged commit 2b208ec into master Jul 20, 2021
@oneflow-ci-bot oneflow-ci-bot deleted the add_flip_module branch July 20, 2021 07:29
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants