New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor: remove module not required, call function directly #5754
Conversation
unsqueeze, where, transpose, triu
return Eq()(input, other) | ||
|
||
if isinstance(other, flow.Tensor) or isinstance( | ||
other, flow._oneflow_internal.Tensor |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
可以写成if isinstance(other, (flow.Tensor,flow._oneflow_internal.Tensor)):
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
flow.Tensor 和 flow._oneflow_internal.Tensor 已经合二为一了。我去掉了 flow._oneflow_internal.Tensor
), "The second tensor's shape should broadcastable with the first argument." | ||
if input.dtype != other.dtype: | ||
other = other.to(dtype=input.dtype) | ||
elif isinstance(other, int) or isinstance(other, float): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已经修改
|
||
|
||
@register_tensor_op("argwhere") | ||
def argwhere_tebsor_op(x, dtype: Optional[flow.dtype] = None): | ||
def argwhere_tebsor_op(input, dtype: Optional[flow.dtype] = None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里 dtype 默认改成 flow.int32? argwhere_op 里面没处理 dtype == none 的情况
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已经修改(连带这里有个原有的拼写错误……)
return Cat(dim=dim)(inputs) | ||
if len(inputs) == 1: | ||
return inputs[0] | ||
axis = dim |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这些python代码太耗时了,安排人都迁移到c++里去?其他地方也一样(比如前面的chunk)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
嗯嗯,是打算迁移。PR描述里备注了要搬运C++的,都要搬运……
CI failed, removing label automerge |
Speed stats:
|
Speed stats:
|
把不必要的通过
XXXModule()()
先实例化 module 对象,再调用的方式。改为直接调用函数。修改结果如下:
flow.F.xxx
其它没有移除的 Module 及说明在:https://github.com/Oneflow-Inc/OneTeam/issues/521