-
Notifications
You must be signed in to change notification settings - Fork 790
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Modules chunk #5324
Modules chunk #5324
Conversation
wanghongsheng01
commented
Jun 28, 2021
•
edited
Loading
edited
oneflow/python/nn/modules/chunk.py
Outdated
import doctest | ||
|
||
doctest.testmod(raise_on_error=False) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这下面的空行都删了吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
嗯嗯,好的
def test_chunk(test_case): | ||
arg_dict = OrderedDict() | ||
arg_dict["test_fun"] = [ | ||
_test_chunk_forward, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个操作torch也是没有反向吗
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
有反向的,hongsheng 还是加入反向的测试吧(因为是用 slice 拼的,所以不用自己实现反向,但是反向的测试还是要有的)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
嗯嗯,有反向的,马上添加
from test_util import GenArgList | ||
|
||
|
||
def _test_chunk_forward(test_case, device): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
test case多加两个吧,如:4-d input;dim=2/3;不同的chunks ...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的,马上添加
|
||
@oneflow_export("chunk") | ||
@experimental_api | ||
def chunk_op(input, chunks, dim): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
需要导出下Tensor.chunk
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
嗯嗯,好的
oneflow/python/nn/modules/chunk.py
Outdated
else: | ||
tup_list.append(v_chunk) | ||
splits.append( | ||
flow.experimental.slice(input, slice_tup_list=tup_list) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
可以用functional api改一下(flow.F.slice(xxx))
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的,这个有没有其他已 merge 的 module,可供参考呢?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已改好👌
flow.chunk 思路
channel = input.dim()
dim_size = input.shape[dim]
chunk_size = dim_size / chunks if dim_size % chunks == 0 else (int)(dim_size / chunks)
last_chunk_size = dim_size / chunks if dim_size % chunks == 0 else dim_size - (chunk_size * (chunks - 1))
chunk_dim_dict = {}
tup_ndim = []
splits = []
for chunk in range(0, chunks):
if dim_size % chunks == 0:
start = chunk * chunk_size
stop = (chunk + 1) * chunk_size
else:
start = chunk * chunk_size if chunk < chunks - 1 else chunk_size * (chunks - 1)
stop = (chunk + 1) * chunk_size if chunk < chunks - 1 else dim_size
step = 1
chunk_dim_dict.setdefault(dim, []).append([int(start), int(stop), step])
for k, v in chunk_dim_dict.items():
for v_chunk in v:
tup_list = []
for i in range(0, channel):
if i != dim:
tup_list.append([None, None, None])
else:
tup_list.append(v_chunk)
splits.append(
flow.experimental.slice(input, slice_tup_list=tup_list)
) |