Modules chunk #5324

Merged
merged 25 commits into from
Jul 6, 2021

Conversation

@wanghongsheng01 (Contributor) commented Jun 28, 2021

import doctest

doctest.testmod(raise_on_error=False)

Contributor: Please delete all the blank lines below this.

Contributor Author: OK, will do.

def test_chunk(test_case):
    arg_dict = OrderedDict()
    arg_dict["test_fun"] = [
        _test_chunk_forward,
Contributor: Does this op have no backward in torch either?

Contributor: It does have a backward. hongsheng, please add a backward test as well (because the op is assembled from slice, you do not need to implement the backward yourself, but a backward test is still required).

Contributor Author: Right, it does have a backward. Adding the test now.

from test_util import GenArgList


def _test_chunk_forward(test_case, device):
@Flowingsun007 (Contributor), Jul 1, 2021: Please add a couple more test cases, e.g. a 4-D input; dim=2/3; different values of chunks ...

Contributor Author: OK, adding them now.
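One way to enumerate the extra cases requested above is sketched below. It is only an illustration under assumptions: the shape, the `dims` and `chunk_counts` lists, and the helper `expected_sizes` are hypothetical, `itertools.product` stands in for the project's own argument-list helper, and the expected sizes follow the sizing rule the author describes later in this thread (the last chunk absorbs the remainder).

```python
from itertools import product

# Hypothetical test grid: one 4-D input shape, two dims, three chunk counts.
shape = (5, 3, 6, 9)
dims = [2, 3]
chunk_counts = [2, 3, 4]


def expected_sizes(dim_size, chunks):
    """Sizes along the split dimension; the last chunk takes the remainder."""
    chunk_size = dim_size // chunks
    return [chunk_size] * (chunks - 1) + [dim_size - chunk_size * (chunks - 1)]


cases = []
for dim, chunks in product(dims, chunk_counts):
    sizes = expected_sizes(shape[dim], chunks)
    # Only the split dimension changes; every other dimension keeps its length.
    expected_shapes = [shape[:dim] + (s,) + shape[dim + 1:] for s in sizes]
    cases.append((dim, chunks, expected_shapes))

for case in cases:
    print(case)
```

Each tuple in `cases` pairs a `(dim, chunks)` setting with the output shapes a forward test could assert against.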


@oneflow_export("chunk")
@experimental_api
def chunk_op(input, chunks, dim):
Contributor: Tensor.chunk needs to be exported as well.

Contributor Author: OK, will do.

@wanghongsheng01 wanghongsheng01 requested review from oneflow-ci-bot and removed request for oneflow-ci-bot July 6, 2021 02:12
    else:
        tup_list.append(v_chunk)
splits.append(
    flow.experimental.slice(input, slice_tup_list=tup_list)
Contributor: This could be rewritten with the functional API (flow.F.slice(xxx)).

Contributor Author: OK. Is there an already-merged module I could use as a reference?

Contributor Author: Done 👌

@wanghongsheng01 (Contributor Author):

Approach for flow.chunk

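  1. Compute the size of each chunk; when dim_size is not evenly divisible by chunks, also compute the size of the last chunk.
# Number of dimensions of the input tensor
channel = input.dim()
# Length of the dimension being split
dim_size = input.shape[dim]
# Integer division keeps chunk_size an int in Python 3
chunk_size = dim_size // chunks
# The last chunk absorbs the remainder (equals chunk_size when evenly divisible)
last_chunk_size = dim_size - chunk_size * (chunks - 1)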
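As a small worked example of step 1, the sketch below computes the size of every chunk in plain Python. The function name `chunk_sizes` is hypothetical and not part of this PR, and it assumes `chunks` is a positive integer no larger than `dim_size`.

```python
def chunk_sizes(dim_size, chunks):
    """Return the size of each chunk along a dimension of length dim_size.

    Every chunk gets dim_size // chunks elements, and the last chunk
    also absorbs whatever remainder is left over.
    """
    chunk_size = dim_size // chunks
    last_chunk_size = dim_size - chunk_size * (chunks - 1)
    return [chunk_size] * (chunks - 1) + [last_chunk_size]


print(chunk_sizes(6, 3))   # [2, 2, 2]
print(chunk_sizes(10, 3))  # [3, 3, 4]
```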
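  2. Because the result is assembled with slice, each chunk needs the tuple_list, covering the whole input tensor, that slice uses to extract that chunk.
  3. How is the whole-tensor tuple_list for each chunk computed?
    For the dimension being split, the entry is [start, stop, step]; for every dimension that is not split, the entry is [None, None, None].
  • Store the entries for the split dimension in a one-key, multi-value dict = {
    'split dimension' : [ [start, stop, step] for chunk1, [start, stop, step] for chunk2, ... ]
    }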
chunk_dim_dict = {}
tup_ndim = []
splits = []

for chunk in range(0, chunks): 
    if dim_size % chunks == 0:
        start = chunk * chunk_size
        stop = (chunk + 1) * chunk_size
    else:
        start = chunk * chunk_size if chunk < chunks - 1 else chunk_size * (chunks - 1)
        stop = (chunk + 1) * chunk_size if chunk < chunks - 1 else dim_size
    step = 1
    chunk_dim_dict.setdefault(dim, []).append([int(start), int(stop), step])
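The start/stop computation above can be checked in isolation with a small sketch. The helper name `chunk_bounds` is hypothetical; it collapses the two branches of the loop into one rule, assuming `chunks` is a positive integer no larger than `dim_size`.

```python
def chunk_bounds(dim_size, chunks):
    """Return one [start, stop, step] triple per chunk along the split dimension."""
    chunk_size = dim_size // chunks
    bounds = []
    for chunk in range(chunks):
        start = chunk * chunk_size
        # The last chunk runs to the end of the dimension, so it picks up
        # any remainder when dim_size is not evenly divisible.
        stop = dim_size if chunk == chunks - 1 else start + chunk_size
        bounds.append([start, stop, 1])
    return bounds


print(chunk_bounds(10, 3))  # [[0, 3, 1], [3, 6, 1], [6, 10, 1]]
```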
  • Whole-tensor tuple_list:
    iterate over the values stored for the split dimension and, for each chunk, add [None, None, None] for every dimension that is not split.
  • For each chunk, call slice with that chunk's whole-tensor tuple_list to extract the chunk's tensor.
for k, v in chunk_dim_dict.items():
    for v_chunk in v:
        tup_list = []
        for i in range(0, channel):
            if i != dim:
                tup_list.append([None, None, None])
            else:
                tup_list.append(v_chunk)
        splits.append(
            flow.experimental.slice(input, slice_tup_list=tup_list)
        )
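To see the whole approach end to end without a tensor library, the sketch below applies the same idea to nested Python lists. It is an illustration under assumptions, not the PR's implementation: `slice_nested` and `chunk_nested` are hypothetical helpers, nested lists stand in for tensors, and the shape is passed in explicitly instead of being read from the input.

```python
def slice_nested(data, tup_list):
    """Apply one [start, stop, step] triple per dimension to a nested list."""
    if not tup_list:
        return data  # No dimensions left: data is a single element.
    start, stop, step = tup_list[0]
    # [None, None, None] becomes slice(None, None, None), i.e. the full range.
    return [slice_nested(item, tup_list[1:]) for item in data[slice(start, stop, step)]]


def chunk_nested(data, shape, chunks, dim):
    """Split data into chunks along dim; the last chunk takes the remainder."""
    ndim = len(shape)
    dim_size = shape[dim]
    chunk_size = dim_size // chunks
    splits = []
    for chunk in range(chunks):
        start = chunk * chunk_size
        stop = dim_size if chunk == chunks - 1 else start + chunk_size
        # One triple per dimension: real bounds for dim, full range elsewhere.
        tup_list = [
            [start, stop, 1] if i == dim else [None, None, None]
            for i in range(ndim)
        ]
        splits.append(slice_nested(data, tup_list))
    return splits


data = [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]  # shape (2, 5)
parts = chunk_nested(data, (2, 5), chunks=2, dim=1)
print(parts)  # [[[0, 1], [5, 6]], [[2, 3, 4], [7, 8, 9]]]
```

Splitting the same data along `dim=0` instead leaves each row whole and returns one row per chunk.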

@wanghongsheng01 wanghongsheng01 requested review from oneflow-ci-bot and removed request for oneflow-ci-bot July 6, 2021 07:51
@oneflow-ci-bot oneflow-ci-bot removed their request for review July 6, 2021 08:13
@oneflow-ci-bot oneflow-ci-bot self-requested a review July 6, 2021 08:13
@wanghongsheng01 wanghongsheng01 requested review from oneflow-ci-bot and removed request for oneflow-ci-bot July 6, 2021 08:21
@oneflow-ci-bot oneflow-ci-bot requested review from oneflow-ci-bot and removed request for oneflow-ci-bot July 6, 2021 09:23
@oneflow-ci-bot oneflow-ci-bot requested review from oneflow-ci-bot and removed request for oneflow-ci-bot July 6, 2021 10:20
@oneflow-ci-bot oneflow-ci-bot merged commit 3e8d3a1 into master Jul 6, 2021
@oneflow-ci-bot oneflow-ci-bot deleted the modules_chunk branch July 6, 2021 11:19