-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【Hackathon 5th No.32】为 Paddle 新增 tensor_split / hsplit / dsplit API -part #58917
Merged
Merged
Changes from 5 commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
a64a033
[Init] add more split api
megemini bd3d0e9
[Update] update unittest
megemini b00983a
[Add] add docstrings
megemini caeb2e4
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
megemini d65310d
[Fix] fix merge
megemini 4da1e21
[Change] tensor_split with split
megemini 758192d
[Fix] remove out of range example
megemini fc8fdbd
[Fix] tensor_split docstring of supported data type
megemini 2d21c80
[Change] _tensor_split_indices with slice
megemini bbd301e
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
megemini acee8e2
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
megemini a43bcb8
[Change] resolve conflict
megemini 7ef61f8
[Change] h v d -split like tensor_split
megemini 59218c2
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
megemini faa8ac3
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
megemini File filter
Filter by extension
Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2194,17 +2194,262 @@ def _get_SectionsTensorList(one_list): | |
return outs | ||
|
||
|
||
def tensor_split(x, indices_or_sections, axis=0, name=None): | ||
""" | ||
Split the input tensor into multiple sub-Tensors along ``axis``, allowing not equally size. | ||
|
||
Args: | ||
x (Tensor): A Tensor whose dimension must be greater than 0. The data type is bool, bfloat16, float16, float32, float64, uint8, int8, int16, int32, complex64, complex128 or int64. | ||
indices_or_sections (int|list|tuple): If ``indices_or_sections`` is an int ``n``, ``x`` is split into ``n`` sections along ``axis``. | ||
If ``x`` is divisible by ``n``, each section will be ``x.shape[axis] / n``. If ``x`` is not divisible by ``n``, the first | ||
``int(x.shape[axis] % n)`` sections will have size ``int(x.shape[axis] / n) + 1``, and the rest will be ``int(x.shape[axis] / n). | ||
If ``indices_or_sections`` is a list or tuple of integter indices, ``x`` is split along ``axis`` at each of the indices. | ||
axis (int|Tensor, optional): The axis along which to split, it can be a integer or a ``0-D Tensor`` | ||
with shape [] and data type ``int32`` or ``int64``. | ||
If :math::`axis < 0`, the axis to split along is :math:`rank(x) + axis`. Default is 0. | ||
name (str, optional): The default value is None. Normally there is no need for user to set this property. | ||
For more information, please refer to :ref:`api_guide_Name` . | ||
Returns: | ||
list[Tensor], The list of segmented Tensors. | ||
|
||
Examples: | ||
.. code-block:: python | ||
|
||
>>> import paddle | ||
|
||
>>> # x is a Tensor of shape [8] | ||
>>> # evenly split | ||
>>> x = paddle.rand([8]) | ||
>>> out0, out1 = paddle.tensor_split(x, indices_or_sections=2) | ||
>>> print(out0.shape) | ||
[4] | ||
>>> print(out1.shape) | ||
[4] | ||
|
||
>>> # not evenly split | ||
>>> out0, out1, out2 = paddle.tensor_split(x, indices_or_sections=3) | ||
>>> print(out0.shape) | ||
[3] | ||
>>> print(out1.shape) | ||
[3] | ||
>>> print(out2.shape) | ||
[2] | ||
|
||
>>> # split with indices | ||
>>> out0, out1, out2 = paddle.tensor_split(x, indices_or_sections=[2, 3]) | ||
>>> print(out0.shape) | ||
[2] | ||
>>> print(out1.shape) | ||
[1] | ||
>>> print(out2.shape) | ||
[5] | ||
|
||
>>> # split with indices out of range | ||
>>> out0, out1, out2, out3 = paddle.tensor_split(x, indices_or_sections=[2, 3, 10]) | ||
>>> print(out0.shape) | ||
[2] | ||
>>> print(out1.shape) | ||
[1] | ||
>>> print(out2.shape) | ||
[5] | ||
>>> print(out3.shape) | ||
[0] | ||
|
||
>>> # split along axis | ||
>>> # x is a Tensor of shape [7, 8] | ||
>>> x = paddle.rand([7, 8]) | ||
>>> out0, out1 = paddle.tensor_split(x, indices_or_sections=2, axis=1) | ||
>>> print(out0.shape) | ||
[7, 4] | ||
>>> print(out1.shape) | ||
[7, 4] | ||
|
||
>>> out0, out1, out2 = paddle.tensor_split(x, indices_or_sections=[2, 3], axis=1) | ||
>>> print(out0.shape) | ||
[7, 2] | ||
>>> print(out1.shape) | ||
[7, 1] | ||
>>> print(out2.shape) | ||
[7, 5] | ||
|
||
""" | ||
if x.ndim <= 0 or x.ndim <= axis: | ||
raise ValueError( | ||
f"The input tensor's dimension must be greater than 0 or axis which is {axis}, but got {x.ndim}" | ||
) | ||
|
||
total_n = x.shape[axis] | ||
|
||
def _tensor_split_indices(total_n, indices, axis): | ||
splits = [] | ||
|
||
starts = 0 | ||
ends = 0 | ||
for idx in indices: | ||
ends = idx | ||
sub_array = paddle.slice( | ||
x, axes=[axis], starts=[starts], ends=[ends] | ||
) | ||
splits.append(sub_array) | ||
starts = ends | ||
|
||
starts = ends | ||
ends = total_n | ||
sub_array = paddle.slice(x, axes=[axis], starts=[starts], ends=[ends]) | ||
splits.append(sub_array) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 此处的实现方式为多次调用 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 好想法!我试试看~ |
||
|
||
return splits | ||
|
||
def _tensor_split_sections(total_n, sections, axis): | ||
if sections <= 0: | ||
raise ValueError('indices_or_sections must be larger than 0.') | ||
|
||
base, mod = divmod(total_n, sections) | ||
section_array = [base + 1] * mod + [base] * (sections - mod) | ||
section_array = np.cumsum(section_array[:-1], dtype=int) | ||
|
||
return _tensor_split_indices(total_n, section_array, axis) | ||
|
||
if isinstance(indices_or_sections, int): | ||
return _tensor_split_sections(total_n, indices_or_sections, axis) | ||
|
||
elif isinstance(indices_or_sections, (list, tuple)): | ||
return _tensor_split_indices(total_n, indices_or_sections, axis) | ||
|
||
else: | ||
raise ValueError( | ||
f"The indices_or_sections should be int, list or tuple of ints, but got {type(indices_or_sections)}" | ||
) | ||
|
||
|
||
def hsplit(x, num_or_sections, name=None): | ||
""" | ||
Split the input tensor into multiple sub-Tensors along the horizontal axis, which is equivalent to ``paddle.split`` with ``axis=1`` | ||
when ``x`` 's dimension is larger than 1, or equivalent to ``paddle.split`` with ``axis=0`` when ``x`` 's dimension is 1. | ||
|
||
Args: | ||
x (Tensor): A Tensor whose dimension must be greater than 0. The data type is bool, bfloat16, float16, float32, float64, uint8, int8, int32 or int64. | ||
num_or_sections (int|list|tuple): If ``num_or_sections`` is an int, then ``num_or_sections`` | ||
indicates the number of equal sized sub-Tensors that the ``x`` will be divided into. | ||
If ``num_or_sections`` is a list or tuple, the length of it indicates the number of | ||
sub-Tensors and the elements in it indicate the sizes of sub-Tensors' dimension orderly. | ||
The length of the list must not be larger than the ``x`` 's size of axis 1 when ``x`` 's dimension is larger than 1, | ||
or axis 0 when ``x`` 's dimension is 1. | ||
name (str, optional): The default value is None. Normally there is no need for user to set this property. | ||
For more information, please refer to :ref:`api_guide_Name` . | ||
Returns: | ||
list[Tensor], The list of segmented Tensors. | ||
|
||
Examples: | ||
.. code-block:: python | ||
|
||
>>> import paddle | ||
|
||
>>> # x is a Tensor of shape [8] | ||
>>> x = paddle.rand([8]) | ||
>>> out0, out1 = paddle.hsplit(x, num_or_sections=2) | ||
>>> print(out0.shape) | ||
[4] | ||
>>> print(out1.shape) | ||
[4] | ||
|
||
>>> # x is a Tensor of shape [7, 8] | ||
>>> x = paddle.rand([7, 8]) | ||
>>> out0, out1 = paddle.hsplit(x, num_or_sections=2) | ||
>>> print(out0.shape) | ||
[7, 4] | ||
>>> print(out1.shape) | ||
[7, 4] | ||
|
||
>>> out0, out1, out2 = paddle.hsplit(x, num_or_sections=[1, 3, 4]) | ||
>>> print(out0.shape) | ||
[7, 1] | ||
>>> print(out1.shape) | ||
[7, 3] | ||
>>> print(out2.shape) | ||
[7, 4] | ||
|
||
>>> out0, out1, out2 = paddle.hsplit(x, num_or_sections=[2, 3, -1]) | ||
>>> print(out0.shape) | ||
[7, 2] | ||
>>> print(out1.shape) | ||
[7, 3] | ||
>>> print(out2.shape) | ||
[7, 3] | ||
""" | ||
if x.ndim < 1: | ||
raise ValueError( | ||
f"The input tensor's dimension must be greater than 0, but got {x.ndim}" | ||
) | ||
if x.ndim > 1: | ||
return split(x, num_or_sections, axis=1, name=name) | ||
else: | ||
return split(x, num_or_sections, axis=0, name=name) | ||
|
||
|
||
def dsplit(x, num_or_sections, name=None): | ||
""" | ||
Split the input tensor into multiple sub-Tensors along the depth axis, which is equivalent to ``paddle.split`` with ``axis=2``. | ||
|
||
Args: | ||
x (Tensor): A Tensor whose dimension must be greater than 2. The data type is bool, bfloat16, float16, float32, float64, uint8, int8, int32 or int64. | ||
num_or_sections (int|list|tuple): If ``num_or_sections`` is an int, then ``num_or_sections`` | ||
indicates the number of equal sized sub-Tensors that the ``x`` will be divided into. | ||
If ``num_or_sections`` is a list or tuple, the length of it indicates the number of | ||
sub-Tensors and the elements in it indicate the sizes of sub-Tensors' dimension orderly. | ||
The length of the list must not be larger than the ``x`` 's size of axis 2. | ||
name (str, optional): The default value is None. Normally there is no need for user to set this property. | ||
For more information, please refer to :ref:`api_guide_Name` . | ||
Returns: | ||
list[Tensor], The list of segmented Tensors. | ||
|
||
Examples: | ||
.. code-block:: python | ||
|
||
>>> import paddle | ||
|
||
>>> # x is a Tensor of shape [7, 6, 8] | ||
>>> x = paddle.rand([7, 6, 8]) | ||
>>> out0, out1 = paddle.dsplit(x, num_or_sections=2) | ||
>>> print(out0.shape) | ||
[7, 6, 4] | ||
>>> print(out1.shape) | ||
[7, 6, 4] | ||
|
||
>>> out0, out1, out2 = paddle.dsplit(x, num_or_sections=[1, 3, 4]) | ||
>>> print(out0.shape) | ||
[7, 6, 1] | ||
>>> print(out1.shape) | ||
[7, 6, 3] | ||
>>> print(out2.shape) | ||
[7, 6, 4] | ||
|
||
>>> out0, out1, out2 = paddle.dsplit(x, num_or_sections=[2, 3, -1]) | ||
>>> print(out0.shape) | ||
[7, 6, 2] | ||
>>> print(out1.shape) | ||
[7, 6, 3] | ||
>>> print(out2.shape) | ||
[7, 6, 3] | ||
""" | ||
if x.ndim < 3: | ||
raise ValueError( | ||
f"The input tensor's dimension must be greater than 2, but got {x.ndim}" | ||
) | ||
return split(x, num_or_sections, axis=2, name=name) | ||
|
||
|
||
def vsplit(x, num_or_sections, name=None): | ||
""" | ||
Split the input tensor into multiple sub-Tensors along the vertical axis, which is equivalent to ``paddle.split`` with ``axis=0``. | ||
|
||
Args: | ||
x (Tensor): A Tensor whose dimension must be greater than 1. The data type is bool, float16, float32, float64, uint8, int8, int32 or int64. | ||
x (Tensor): A Tensor whose dimension must be greater than 1. The data type is bool, bfloat16, float16, float32, float64, uint8, int8, int32 or int64. | ||
num_or_sections (int|list|tuple): If ``num_or_sections`` is an int, then ``num_or_sections`` | ||
indicates the number of equal sized sub-Tensors that the ``x`` will be divided into. | ||
If ``num_or_sections`` is a list or tuple, the length of it indicates the number of | ||
sub-Tensors and the elements in it indicate the sizes of sub-Tensors' dimension orderly. | ||
The length of the list must not be larger than the ``x`` 's size of axis 0. | ||
The length of the list must not be larger than the ``x`` 's size of axis 0. | ||
name (str, optional): The default value is None. Normally there is no need for user to set this property. | ||
For more information, please refer to :ref:`api_guide_Name` . | ||
Returns: | ||
|
@@ -2222,13 +2467,15 @@ def vsplit(x, num_or_sections, name=None): | |
[4, 6, 7] | ||
>>> print(out1.shape) | ||
[4, 6, 7] | ||
|
||
>>> out0, out1, out2 = paddle.vsplit(x, num_or_sections=[1, 3, 4]) | ||
>>> print(out0.shape) | ||
[1, 6, 7] | ||
>>> print(out1.shape) | ||
[3, 6, 7] | ||
>>> print(out2.shape) | ||
[4, 6, 7] | ||
|
||
>>> out0, out1, out2 = paddle.vsplit(x, num_or_sections=[2, 3, -1]) | ||
>>> print(out0.shape) | ||
[2, 6, 7] | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
少了
![image](https://private-user-images.githubusercontent.com/70642955/289775460-f260825a-dbdf-4e11-95d9-1432ab867c14.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3MTg0MTY5MzYsIm5iZiI6MTcxODQxNjYzNiwicGF0aCI6Ii83MDY0Mjk1NS8yODk3NzU0NjAtZjI2MDgyNWEtZGJkZi00ZTExLTk1ZDktMTQzMmFiODY3YzE0LnBuZz9YLUFtei1BbGdvcml0aG09QVdTNC1ITUFDLVNIQTI1NiZYLUFtei1DcmVkZW50aWFsPUFLSUFWQ09EWUxTQTUzUFFLNFpBJTJGMjAyNDA2MTUlMkZ1cy1lYXN0LTElMkZzMyUyRmF3czRfcmVxdWVzdCZYLUFtei1EYXRlPTIwMjQwNjE1VDAxNTcxNlomWC1BbXotRXhwaXJlcz0zMDAmWC1BbXotU2lnbmF0dXJlPTBhMGQyNmJmY2QyZjU1MzA1MDU5OWM0MTllNDRkOGZiNzhhNjI2ZWEyM2JiZGNmNGRmM2RiYjI4YjVjMzg1YzQmWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0JmFjdG9yX2lkPTAma2V5X2lkPTAmcmVwb19pZD0wIn0.dwSudGcbUp7v2gURnffs-F9WPk4gaP38C9p4X7UE2XQ)
(其实应该在 CI 检查里加一下关于
的检查...顺师傅有意向做不🐶)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不如顺师傅下一个PR单独改一下这个吧