[Hackathon 5th No.32] Add tensor_split / hsplit / dsplit APIs to Paddle -part #58917

Merged
merged 15 commits into from
Dec 13, 2023
Changes from 5 commits
6 changes: 6 additions & 0 deletions python/paddle/__init__.py
@@ -221,6 +221,9 @@
slice,
crop,
split,
tensor_split,
hsplit,
dsplit,
vsplit,
squeeze,
squeeze_,
@@ -634,6 +637,9 @@
'searchsorted',
'bucketize',
'split',
'tensor_split',
'hsplit',
'dsplit',
'vsplit',
'logical_and',
'logical_and_',
6 changes: 6 additions & 0 deletions python/paddle/tensor/__init__.py
@@ -134,6 +134,9 @@
from .manipulation import shard_index # noqa: F401
from .manipulation import slice # noqa: F401
from .manipulation import split # noqa: F401
from .manipulation import tensor_split # noqa: F401
from .manipulation import hsplit # noqa: F401
from .manipulation import dsplit # noqa: F401
from .manipulation import vsplit # noqa: F401
from .manipulation import squeeze # noqa: F401
from .manipulation import squeeze_ # noqa: F401
@@ -571,6 +574,9 @@
'shard_index',
'slice',
'split',
'tensor_split',
'hsplit',
'dsplit',
'vsplit',
'chunk',
'tensordot',
251 changes: 249 additions & 2 deletions python/paddle/tensor/manipulation.py
@@ -2194,17 +2194,262 @@ def _get_SectionsTensorList(one_list):
return outs


def tensor_split(x, indices_or_sections, axis=0, name=None):
"""
Split the input tensor into multiple sub-Tensors along ``axis``, allowing sections of unequal size.

Args:
x (Tensor): A Tensor whose dimension must be greater than 0. The data type is bool, bfloat16, float16, float32, float64, uint8, int8, int16, int32, complex64, complex128 or int64.
indices_or_sections (int|list|tuple): If ``indices_or_sections`` is an int ``n``, ``x`` is split into ``n`` sections along ``axis``.
If ``x.shape[axis]`` is divisible by ``n``, each section will have size ``x.shape[axis] / n``. If it is not divisible by ``n``, the first
``int(x.shape[axis] % n)`` sections will have size ``int(x.shape[axis] / n) + 1``, and the rest will be ``int(x.shape[axis] / n).
Contributor

Suggested change
``int(x.shape[axis] % n)`` sections will have size ``int(x.shape[axis] / n) + 1``, and the rest will be ``int(x.shape[axis] / n).
``int(x.shape[axis] % n)`` sections will have size ``int(x.shape[axis] / n) + 1``, and the rest will be ``int(x.shape[axis] / n)``.

The closing `` is missing (we really should add a check for unbalanced `` to CI... 顺师傅, are you interested in doing that? 🐶)

Contributor

How about 顺师傅 handles this separately in the next PR?

If ``indices_or_sections`` is a list or tuple of integer indices, ``x`` is split along ``axis`` at each of the indices.
axis (int|Tensor, optional): The axis along which to split. It can be an integer or a ``0-D Tensor``
with shape [] and data type ``int32`` or ``int64``.
If :math:`axis < 0`, the axis to split along is :math:`rank(x) + axis`. Default is 0.
name (str, optional): The default value is None. Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name` .
Returns:
list[Tensor], The list of segmented Tensors.

Examples:
.. code-block:: python

>>> import paddle

>>> # x is a Tensor of shape [8]
>>> # evenly split
>>> x = paddle.rand([8])
>>> out0, out1 = paddle.tensor_split(x, indices_or_sections=2)
>>> print(out0.shape)
[4]
>>> print(out1.shape)
[4]

>>> # not evenly split
>>> out0, out1, out2 = paddle.tensor_split(x, indices_or_sections=3)
>>> print(out0.shape)
[3]
>>> print(out1.shape)
[3]
>>> print(out2.shape)
[2]

>>> # split with indices
>>> out0, out1, out2 = paddle.tensor_split(x, indices_or_sections=[2, 3])
>>> print(out0.shape)
[2]
>>> print(out1.shape)
[1]
>>> print(out2.shape)
[5]

>>> # split with indices out of range
>>> out0, out1, out2, out3 = paddle.tensor_split(x, indices_or_sections=[2, 3, 10])
>>> print(out0.shape)
[2]
>>> print(out1.shape)
[1]
>>> print(out2.shape)
[5]
>>> print(out3.shape)
[0]

>>> # split along axis
>>> # x is a Tensor of shape [7, 8]
>>> x = paddle.rand([7, 8])
>>> out0, out1 = paddle.tensor_split(x, indices_or_sections=2, axis=1)
>>> print(out0.shape)
[7, 4]
>>> print(out1.shape)
[7, 4]

>>> out0, out1, out2 = paddle.tensor_split(x, indices_or_sections=[2, 3], axis=1)
>>> print(out0.shape)
[7, 2]
>>> print(out1.shape)
[7, 1]
>>> print(out2.shape)
[7, 5]

"""
if x.ndim <= 0 or x.ndim <= axis:
raise ValueError(
f"The input tensor's dimension must be greater than 0 or axis which is {axis}, but got {x.ndim}"
)

total_n = x.shape[axis]

def _tensor_split_indices(total_n, indices, axis):
splits = []

starts = 0
ends = 0
for idx in indices:
ends = idx
sub_array = paddle.slice(
x, axes=[axis], starts=[starts], ends=[ends]
)
splits.append(sub_array)
starts = ends

starts = ends
ends = total_n
sub_array = paddle.slice(x, axes=[axis], starts=[starts], ends=[ends])
splits.append(sub_array)
Contributor

This implementation calls slice repeatedly and appends each result to a list, which is expected to be slow.
Could indices_or_sections be converted to num_or_sections so that a single split call does the whole job?

Contributor Author

Good idea! I'll give it a try~
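
One way to read the reviewer's suggestion is to turn the split indices into section sizes first, so that a single split call can consume them. The following is a rough sketch in plain Python, not code from this PR: the helper name `indices_to_sections` is hypothetical, indices are assumed to be non-negative, and clamping out-of-range indices to the axis length is an assumption chosen to reproduce the zero-size trailing piece in the docstring example. Whether `paddle.split` accepts zero-size sections is not confirmed here.

```python
def indices_to_sections(total_n, indices):
    """Convert split indices along an axis of length total_n into section sizes.

    Hypothetical helper for illustration; assumes non-negative indices.
    """
    sizes = []
    start = 0
    for idx in indices:
        # Clamp so that a boundary never moves backwards or past the axis end.
        end = min(max(idx, start), total_n)
        sizes.append(end - start)
        start = end
    # The last section runs from the final boundary to the end of the axis.
    sizes.append(total_n - start)
    return sizes


print(indices_to_sections(8, [2, 3]))      # [2, 1, 5]
print(indices_to_sections(8, [2, 3, 10]))  # [2, 1, 5, 0]
```

The sizes always add up to `total_n`, which is the property a sizes-based split call would need.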


return splits

def _tensor_split_sections(total_n, sections, axis):
if sections <= 0:
raise ValueError('indices_or_sections must be larger than 0.')

base, mod = divmod(total_n, sections)
section_array = [base + 1] * mod + [base] * (sections - mod)
section_array = np.cumsum(section_array[:-1], dtype=int)

return _tensor_split_indices(total_n, section_array, axis)

if isinstance(indices_or_sections, int):
return _tensor_split_sections(total_n, indices_or_sections, axis)

elif isinstance(indices_or_sections, (list, tuple)):
return _tensor_split_indices(total_n, indices_or_sections, axis)

else:
raise ValueError(
f"The indices_or_sections should be int, list or tuple of ints, but got {type(indices_or_sections)}"
)
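
The section-size rule from the docstring can be checked without Paddle. The sketch below mirrors the `divmod` step of `_tensor_split_sections` in plain Python. The helper names `section_sizes` and `split_points` are hypothetical, and `itertools.accumulate` stands in for `np.cumsum`.

```python
from itertools import accumulate


def section_sizes(total_n, sections):
    """Sizes of the pieces when an axis of length total_n is cut into `sections` parts."""
    if sections <= 0:
        raise ValueError("indices_or_sections must be larger than 0.")
    base, mod = divmod(total_n, sections)
    # The first `mod` pieces each get one extra element.
    return [base + 1] * mod + [base] * (sections - mod)


def split_points(total_n, sections):
    """Cumulative boundaries between pieces; the last piece needs no boundary."""
    return list(accumulate(section_sizes(total_n, sections)[:-1]))


print(section_sizes(8, 3))  # [3, 3, 2]
print(split_points(8, 3))   # [3, 6]
```

For an axis of length 8 and 3 sections, the boundaries `[3, 6]` give pieces of sizes 3, 3 and 2, matching the "not evenly split" docstring example.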


def hsplit(x, num_or_sections, name=None):
"""
Split the input tensor into multiple sub-Tensors along the horizontal axis, which is equivalent to ``paddle.split`` with ``axis=1``
when ``x`` 's dimension is larger than 1, or equivalent to ``paddle.split`` with ``axis=0`` when ``x`` 's dimension is 1.

Args:
x (Tensor): A Tensor whose dimension must be greater than 0. The data type is bool, bfloat16, float16, float32, float64, uint8, int8, int32 or int64.
num_or_sections (int|list|tuple): If ``num_or_sections`` is an int, it is
the number of equal-sized sub-Tensors that ``x`` will be divided into.
If ``num_or_sections`` is a list or tuple, its length is the number of
sub-Tensors and its elements give, in order, the size of each sub-Tensor along the split axis.
The length of the list must not be larger than the size of ``x`` along axis 1 when ``x`` has more than one dimension,
or along axis 0 when ``x`` is 1-D.
name (str, optional): The default value is None. Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name` .
Returns:
list[Tensor], The list of segmented Tensors.

Examples:
.. code-block:: python

>>> import paddle

>>> # x is a Tensor of shape [8]
>>> x = paddle.rand([8])
>>> out0, out1 = paddle.hsplit(x, num_or_sections=2)
>>> print(out0.shape)
[4]
>>> print(out1.shape)
[4]

>>> # x is a Tensor of shape [7, 8]
>>> x = paddle.rand([7, 8])
>>> out0, out1 = paddle.hsplit(x, num_or_sections=2)
>>> print(out0.shape)
[7, 4]
>>> print(out1.shape)
[7, 4]

>>> out0, out1, out2 = paddle.hsplit(x, num_or_sections=[1, 3, 4])
>>> print(out0.shape)
[7, 1]
>>> print(out1.shape)
[7, 3]
>>> print(out2.shape)
[7, 4]

>>> out0, out1, out2 = paddle.hsplit(x, num_or_sections=[2, 3, -1])
>>> print(out0.shape)
[7, 2]
>>> print(out1.shape)
[7, 3]
>>> print(out2.shape)
[7, 3]
"""
if x.ndim < 1:
raise ValueError(
f"The input tensor's dimension must be greater than 0, but got {x.ndim}"
)
if x.ndim > 1:
return split(x, num_or_sections, axis=1, name=name)
else:
return split(x, num_or_sections, axis=0, name=name)
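
The two rules `hsplit` relies on, which axis to split and what a `-1` entry in `num_or_sections` means, can be illustrated in plain Python. This is a sketch under assumptions: the helper names `hsplit_axis` and `resolve_sections` are hypothetical, and the `-1` handling is inferred from the docstring example, where `[2, 3, -1]` on an axis of length 8 yields a last piece of size 3.

```python
def hsplit_axis(ndim):
    """Axis that hsplit cuts along: axis 1 for 2-D and above, axis 0 for 1-D input."""
    if ndim < 1:
        raise ValueError(
            f"The input tensor's dimension must be greater than 0, but got {ndim}"
        )
    return 1 if ndim > 1 else 0


def resolve_sections(axis_len, sections):
    """Replace a single -1 entry with whatever length the other sections leave over."""
    if sections.count(-1) > 1:
        raise ValueError("at most one section may be -1")
    known = sum(s for s in sections if s != -1)
    return [axis_len - known if s == -1 else s for s in sections]


print(hsplit_axis(1))                    # 0
print(hsplit_axis(2))                    # 1
print(resolve_sections(8, [2, 3, -1]))   # [2, 3, 3]
```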


def dsplit(x, num_or_sections, name=None):
"""
Split the input tensor into multiple sub-Tensors along the depth axis, which is equivalent to ``paddle.split`` with ``axis=2``.

Args:
x (Tensor): A Tensor whose dimension must be greater than 2. The data type is bool, bfloat16, float16, float32, float64, uint8, int8, int32 or int64.
num_or_sections (int|list|tuple): If ``num_or_sections`` is an int, it is
the number of equal-sized sub-Tensors that ``x`` will be divided into.
If ``num_or_sections`` is a list or tuple, its length is the number of
sub-Tensors and its elements give, in order, the size of each sub-Tensor along axis 2.
The length of the list must not be larger than the size of ``x`` along axis 2.
name (str, optional): The default value is None. Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name` .
Returns:
list[Tensor], The list of segmented Tensors.

Examples:
.. code-block:: python

>>> import paddle

>>> # x is a Tensor of shape [7, 6, 8]
>>> x = paddle.rand([7, 6, 8])
>>> out0, out1 = paddle.dsplit(x, num_or_sections=2)
>>> print(out0.shape)
[7, 6, 4]
>>> print(out1.shape)
[7, 6, 4]

>>> out0, out1, out2 = paddle.dsplit(x, num_or_sections=[1, 3, 4])
>>> print(out0.shape)
[7, 6, 1]
>>> print(out1.shape)
[7, 6, 3]
>>> print(out2.shape)
[7, 6, 4]

>>> out0, out1, out2 = paddle.dsplit(x, num_or_sections=[2, 3, -1])
>>> print(out0.shape)
[7, 6, 2]
>>> print(out1.shape)
[7, 6, 3]
>>> print(out2.shape)
[7, 6, 3]
"""
if x.ndim < 3:
raise ValueError(
f"The input tensor's dimension must be greater than 2, but got {x.ndim}"
)
return split(x, num_or_sections, axis=2, name=name)
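
The output shapes in the `dsplit` docstring follow from one observation: only axis 2 changes. A small shape-only sketch in plain Python, using a hypothetical helper `dsplit_shapes` that works on shape lists rather than tensors and assumes the sizes are already fully specified (no `-1`):

```python
def dsplit_shapes(shape, sizes):
    """Shapes of the pieces produced by cutting axis 2 into the given sizes."""
    if len(shape) < 3:
        raise ValueError(
            f"The input tensor's dimension must be greater than 2, but got {len(shape)}"
        )
    if sum(sizes) != shape[2]:
        raise ValueError("sizes must add up to the length of axis 2")
    # Only axis 2 changes; every other axis keeps its length.
    return [list(shape[:2]) + [size] + list(shape[3:]) for size in sizes]


print(dsplit_shapes([7, 6, 8], [1, 3, 4]))  # [[7, 6, 1], [7, 6, 3], [7, 6, 4]]
```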


def vsplit(x, num_or_sections, name=None):
"""
Split the input tensor into multiple sub-Tensors along the vertical axis, which is equivalent to ``paddle.split`` with ``axis=0``.

Args:
x (Tensor): A Tensor whose dimension must be greater than 1. The data type is bool, float16, float32, float64, uint8, int8, int32 or int64.
x (Tensor): A Tensor whose dimension must be greater than 1. The data type is bool, bfloat16, float16, float32, float64, uint8, int8, int32 or int64.
num_or_sections (int|list|tuple): If ``num_or_sections`` is an int, then ``num_or_sections``
indicates the number of equal sized sub-Tensors that the ``x`` will be divided into.
If ``num_or_sections`` is a list or tuple, the length of it indicates the number of
sub-Tensors and the elements in it indicate the sizes of sub-Tensors' dimension orderly.
The length of the list must not be larger than the ``x`` 's size of axis 0.
The length of the list must not be larger than the ``x`` 's size of axis 0.
name (str, optional): The default value is None. Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name` .
Returns:
@@ -2222,13 +2467,15 @@ def vsplit(x, num_or_sections, name=None):
[4, 6, 7]
>>> print(out1.shape)
[4, 6, 7]

>>> out0, out1, out2 = paddle.vsplit(x, num_or_sections=[1, 3, 4])
>>> print(out0.shape)
[1, 6, 7]
>>> print(out1.shape)
[3, 6, 7]
>>> print(out2.shape)
[4, 6, 7]

>>> out0, out1, out2 = paddle.vsplit(x, num_or_sections=[2, 3, -1])
>>> print(out0.shape)
[2, 6, 7]