Add flip module #5541
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
oneflow/python/nn/modules/flip.py
Outdated
class Flip(Module):
    def __init__(self, dims) -> None:
        super().__init__()
        assert isinstance(dims, list) or isinstance(
This could be written as: assert isinstance(dims, (list, tuple))
OK.
class Flip(Module):
    def __init__(self, dims) -> None:
There is no type check on the dims parameter here.
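A minimal sketch of the type check raised in the two threads above, assuming dims is meant to be a list or tuple of plain ints. The helper name `check_dims_type` is hypothetical and is not part of this PR.

```python
def check_dims_type(dims):
    """Validate that dims is a list or tuple of ints (hypothetical helper)."""
    # isinstance accepts a tuple of types, so one call covers both containers.
    if not isinstance(dims, (list, tuple)):
        raise TypeError(
            f"dims must be a list or tuple, but got {type(dims).__name__}"
        )
    for d in dims:
        # bool is a subclass of int in Python, so reject it explicitly.
        if isinstance(d, bool) or not isinstance(d, int):
            raise TypeError(
                f"every entry of dims must be an int, but got {type(d).__name__}"
            )
    return list(dims)
```

Raising TypeError instead of using a bare assert keeps the check active when Python runs with -O, which strips assert statements.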
oneflow/user/kernels/flip_kernel.cu
Outdated
template<typename T>
__global__ void FlipGpuForward(const int32_t element, const int64_t total_dims,
                               const STRIDE_CONTIGUOUS_V stride_contiguous_v, const SIZE_V sizes_v,
Why not pass const std::vector<int32_t> directly here, instead of constructing STRIDE_CONTIGUOUS_V, SIZE_V, and STRIDE?
CUDA can't use STL structures like these, so I followed the approach used in other earlier PRs.
OK, but STRIDE_CONTIGUOUS_V, SIZE_V, and STRIDE all have the same structure, so reusing a single struct should be enough (there is no need to define it three times).
OK.
} // namespace

template<typename T>
class FlipCpuKernel final : public user_op::OpKernel {
Please wrap the kernel in the user_op namespace.
 public:
  FlipFunctor() { op_ = CHECK_JUST(one::OpBuilder("flip").Input("x").Output("y").Build()); }
  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,
                           const std::vector<int32_t> dims) const {
Suggested change:
- const std::vector<int32_t> dims) const {
+ const std::vector<int32_t>& dims) const {
OK.
    op_ = CHECK_JUST(one::OpBuilder("flip_grad").Input("dy").Output("dx").Build());
  }
  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,
                           const std::vector<int32_t> dims) const {
Suggested change:
- const std::vector<int32_t> dims) const {
+ const std::vector<int32_t>& dims) const {
OK.
.SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
  const user_op::TensorDesc* x_desc = ctx->TensorDesc4ArgNameAndIndex("x", 0);
  user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0);
  *y_desc->mut_shape() = x_desc->shape();
We need to check here whether the dims argument is <= the number of dimensions of the input tensor.
OK. It should be strictly less than; it cannot be equal.
torch allows it to be equal.
Both should probably be checked.
OK.
oneflow/python/nn/modules/flip.py
Outdated
        self.dims = dims

    def forward(self, x):
        return flow.F.flip(x, self.dims)
The Python layer also needs a dims check, for example: IndexError: Dimension out of range (expected to be in range of xxx, but got xxx)
OK.
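One possible shape for the Python-layer check, using the torch-style message quoted above. This is a sketch under assumptions: the helper name `normalize_flip_dims` is hypothetical, negative dims are assumed to wrap around as they do in torch, and repeated dims are rejected. The PR may settle on different rules.

```python
def normalize_flip_dims(dims, ndim):
    """Map each dim into [0, ndim) or raise (hypothetical helper)."""
    # The number of flipped dims may equal ndim, but must not exceed it.
    if len(dims) > ndim:
        raise IndexError(
            f"got {len(dims)} dims for a tensor with {ndim} dimension(s)"
        )
    normalized = []
    for d in dims:
        # Each individual dim must be strictly less than ndim.
        if d < -ndim or d >= ndim:
            raise IndexError(
                "Dimension out of range (expected to be in range of "
                f"[{-ndim}, {ndim - 1}], but got {d})"
            )
        # Assumption: negative dims count from the last dimension.
        d = d + ndim if d < 0 else d
        if d in normalized:
            raise ValueError(f"dim {d} appears multiple times in dims")
        normalized.append(d)
    return normalized
```

This covers both conditions from the earlier thread: the length of dims is allowed to equal the tensor rank, while each entry must be strictly smaller than it.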
oneflow/user/kernels/flip_kernel.cpp
Outdated
SIZE_V sizes_v;
for (int32_t i = 0; i < total_dims; i++) { sizes_v.val[i] = y_tensor->shape().At(i); }

STRIDE strides_v;
Please mark this with a TODO (once stride is supported, this logic will no longer be needed).
OK.
int32_t temp = cur_indices;
cur_indices = cur_indices / stride_contiguous_v.val[d];
rem = temp - cur_indices * stride_contiguous_v.val[d];
dst_offset += vis.val[d] ? (sizes_v.val[d] - 1 - cur_indices) * strides_v.val[d]
This logic doesn't seem to be aligned with torch?
Why isn't it aligned?
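To compare the index arithmetic against torch.flip semantics, a small pure-Python reference may help. The sketch below is not the PR's kernel: the names `contiguous_strides` and `flip_flat` are hypothetical, the input is assumed to be a dense row-major buffer, and dims are assumed to be already normalized to non-negative values.

```python
def contiguous_strides(sizes):
    """Row-major strides for a dense tensor of the given sizes."""
    strides = [1] * len(sizes)
    for d in range(len(sizes) - 2, -1, -1):
        strides[d] = strides[d + 1] * sizes[d + 1]
    return strides


def flip_flat(data, sizes, dims):
    """Flip a row-major flat buffer along dims (reference sketch only)."""
    strides = contiguous_strides(sizes)
    flipped = [d in dims for d in range(len(sizes))]
    out = [None] * len(data)
    for linear in range(len(data)):
        rem = linear
        src_offset = 0
        for d in range(len(sizes)):
            # Split the linear index into a per-dimension coordinate.
            coord, rem = divmod(rem, strides[d])
            # Mirror the coordinate on flipped dimensions.
            if flipped[d]:
                coord = sizes[d] - 1 - coord
            src_offset += coord * strides[d]
        out[linear] = data[src_offset]
    return out
```

For a 2x3 buffer, `flip_flat(list(range(6)), [2, 3], [1])` returns `[2, 1, 0, 5, 4, 3]`, which matches flipping each row of the corresponding 2x3 tensor with torch.flip. Running the kernel and a reference like this over a few small shapes is one way to settle whether the two are aligned.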
Add the flow.flip module, aligned with torch.flip.
Docs screenshot:
Doctest screenshot: