Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon 5th No.32】为 Paddle 新增 tensor_split / hsplit / dsplit API -part #58917

Merged
merged 15 commits into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 6 additions & 0 deletions python/paddle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,13 +269,15 @@
concat,
crop,
diagonal_scatter,
dsplit,
expand,
expand_as,
flatten,
flip,
flip as reverse,
gather,
gather_nd,
hsplit,
index_add,
index_add_,
index_fill,
Expand Down Expand Up @@ -309,6 +311,7 @@
row_stack,
strided_slice,
take_along_axis,
tensor_split,
tensordot,
tile,
tolist,
Expand Down Expand Up @@ -631,6 +634,9 @@
'searchsorted',
'bucketize',
'split',
'tensor_split',
'hsplit',
'dsplit',
'vsplit',
'logical_and',
'logical_and_',
Expand Down
6 changes: 6 additions & 0 deletions python/paddle/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@
column_stack,
concat,
diagonal_scatter,
dsplit,
dstack,
expand,
expand_as,
Expand All @@ -158,6 +159,7 @@
flip as reverse,
gather,
gather_nd,
hsplit,
hstack,
index_add,
index_add_,
Expand Down Expand Up @@ -189,6 +191,7 @@
stack,
strided_slice,
take_along_axis,
tensor_split,
tensordot,
tile,
unbind,
Expand Down Expand Up @@ -608,6 +611,9 @@
'shard_index',
'slice',
'split',
'tensor_split',
'hsplit',
'dsplit',
'vsplit',
'chunk',
'tensordot',
Expand Down
241 changes: 223 additions & 18 deletions python/paddle/tensor/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2571,17 +2571,227 @@ def _get_SectionsTensorList(one_list):
return outs


def vsplit(x, num_or_sections, name=None):
def tensor_split(x, num_or_indices, axis=0, name=None):
"""
Split the input tensor into multiple sub-Tensors along the vertical axis, which is equivalent to ``paddle.split`` with ``axis=0``.
Split the input tensor into multiple sub-Tensors along ``axis``, allowing not being of equal size.

Args:
x (Tensor): A Tensor whose dimension must be greater than 1. The data type is bool, float16, float32, float64, uint8, int8, int32 or int64.
num_or_sections (int|list|tuple): If ``num_or_sections`` is an int, then ``num_or_sections``
indicates the number of equal sized sub-Tensors that the ``x`` will be divided into.
If ``num_or_sections`` is a list or tuple, the length of it indicates the number of
sub-Tensors and the elements in it indicate the sizes of sub-Tensors' dimension orderly.
The length of the list must not be larger than the ``x`` 's size of axis 0.
x (Tensor): A Tensor whose dimension must be greater than 0. The data type is bool, bfloat16, float16, float32, float64, uint8, int32 or int64.
num_or_indices (int|list|tuple): If ``num_or_indices`` is an int ``n``, ``x`` is split into ``n`` sections along ``axis``.
If ``x`` is divisible by ``n``, each section will be ``x.shape[axis] / n``. If ``x`` is not divisible by ``n``, the first
``int(x.shape[axis] % n)`` sections will have size ``int(x.shape[axis] / n) + 1``, and the rest will be ``int(x.shape[axis] / n).
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
``int(x.shape[axis] % n)`` sections will have size ``int(x.shape[axis] / n) + 1``, and the rest will be ``int(x.shape[axis] / n).
``int(x.shape[axis] % n)`` sections will have size ``int(x.shape[axis] / n) + 1``, and the rest will be ``int(x.shape[axis] / n)``.

少了(其实应该在 CI 检查里加一下关于 的检查...顺师傅有意向做不🐶)
image

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不如顺师傅下一个PR单独改一下这个吧

If ``num_or_indices`` is a list or tuple of integter indices, ``x`` is split along ``axis`` at each of the indices. For instance,
``num_or_indices=[2, 4]`` with ``axis=0`` would split ``x`` into ``x[:2]``, ``x[2:4]`` and ``x[4:]`` along axis 0.
axis (int|Tensor, optional): The axis along which to split, it can be a integer or a ``0-D Tensor``
with shape [] and data type ``int32`` or ``int64``.
If :math::`axis < 0`, the axis to split along is :math:`rank(x) + axis`. Default is 0.
name (str, optional): The default value is None. Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name` .
Returns:
list[Tensor], The list of segmented Tensors.

Examples:
.. code-block:: python

>>> import paddle

>>> # x is a Tensor of shape [8]
>>> # evenly split
>>> x = paddle.rand([8])
>>> out0, out1 = paddle.tensor_split(x, num_or_indices=2)
>>> print(out0.shape)
[4]
>>> print(out1.shape)
[4]

>>> # not evenly split
>>> out0, out1, out2 = paddle.tensor_split(x, num_or_indices=3)
>>> print(out0.shape)
[3]
>>> print(out1.shape)
[3]
>>> print(out2.shape)
[2]

>>> # split with indices
>>> out0, out1, out2 = paddle.tensor_split(x, num_or_indices=[2, 3])
>>> print(out0.shape)
[2]
>>> print(out1.shape)
[1]
>>> print(out2.shape)
[5]

>>> # split along axis
>>> # x is a Tensor of shape [7, 8]
>>> x = paddle.rand([7, 8])
>>> out0, out1 = paddle.tensor_split(x, num_or_indices=2, axis=1)
>>> print(out0.shape)
[7, 4]
>>> print(out1.shape)
[7, 4]

>>> out0, out1, out2 = paddle.tensor_split(x, num_or_indices=[2, 3], axis=1)
>>> print(out0.shape)
[7, 2]
>>> print(out1.shape)
[7, 1]
>>> print(out2.shape)
[7, 5]

"""
if x.ndim <= 0 or x.ndim <= axis:
raise ValueError(
f"The input tensor's dimension must be greater than 0 or axis which is {axis}, but got {x.ndim}"
)

total_n = x.shape[axis]

def _tensor_split_indices(x, total_n, indices, axis):
splits = []

starts = 0
ends = 0
for idx in list(indices) + [total_n]:
ends = idx
# convert index < 0 to positive
starts_index = starts if starts >= 0 else total_n + starts
ends_index = ends if ends >= 0 else total_n + ends
# ends index should equal or larger than starts
ends_index = max(starts_index, ends_index)

sub_array = paddle.slice(
x, axes=[axis], starts=[starts_index], ends=[ends_index]
)
splits.append(sub_array)
starts = ends

return splits

def _tensor_split_sections(x, total_n, sections, axis):
if sections <= 0:
raise ValueError('num_or_indices must be larger than 0.')

base, mod = divmod(total_n, sections)
num_or_sections = [base + 1] * mod + [base] * (sections - mod)
return split(x, num_or_sections, axis)

if isinstance(num_or_indices, int):
return _tensor_split_sections(x, total_n, num_or_indices, axis)

elif isinstance(num_or_indices, (list, tuple)):
return _tensor_split_indices(x, total_n, num_or_indices, axis)

else:
raise ValueError(
f"The num_or_indices should be int, list or tuple of ints, but got {type(num_or_indices)}"
)


def hsplit(x, num_or_indices, name=None):
"""
Split the input tensor into multiple sub-Tensors along the horizontal axis, which is equivalent to ``paddle.tensor_split`` with ``axis=1``
when ``x`` 's dimension is larger than 1, or equivalent to ``paddle.tensor_split`` with ``axis=0`` when ``x`` 's dimension is 1.

Args:
x (Tensor): A Tensor whose dimension must be greater than 0. The data type is bool, bfloat16, float16, float32, float64, uint8, int32 or int64.
num_or_indices (int|list|tuple): If ``num_or_indices`` is an int ``n``, ``x`` is split into ``n`` sections.
If ``num_or_indices`` is a list or tuple of integter indices, ``x`` is split at each of the indices.
name (str, optional): The default value is None. Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name` .
Returns:
list[Tensor], The list of segmented Tensors.

Examples:
.. code-block:: python

>>> import paddle

>>> # x is a Tensor of shape [8]
>>> x = paddle.rand([8])
>>> out0, out1 = paddle.hsplit(x, num_or_indices=2)
>>> print(out0.shape)
[4]
>>> print(out1.shape)
[4]

>>> # x is a Tensor of shape [7, 8]
>>> x = paddle.rand([7, 8])
>>> out0, out1 = paddle.hsplit(x, num_or_indices=2)
>>> print(out0.shape)
[7, 4]
>>> print(out1.shape)
[7, 4]

>>> out0, out1, out2 = paddle.hsplit(x, num_or_indices=[1, 4])
>>> print(out0.shape)
[7, 1]
>>> print(out1.shape)
[7, 3]
>>> print(out2.shape)
[7, 4]

"""
if x.ndim < 1:
raise ValueError(
f"The input tensor's dimension must be greater than 0, but got {x.ndim}"
)
if x.ndim > 1:
return tensor_split(x, num_or_indices, axis=1, name=name)
else:
return tensor_split(x, num_or_indices, axis=0, name=name)


def dsplit(x, num_or_indices, name=None):
"""
Split the input tensor into multiple sub-Tensors along the depth axis, which is equivalent to ``paddle.tensor_split`` with ``axis=2``.

Args:
x (Tensor): A Tensor whose dimension must be greater than 2. The data type is bool, bfloat16, float16, float32, float64, uint8, int32 or int64.
num_or_indices (int|list|tuple): If ``num_or_indices`` is an int ``n``, ``x`` is split into ``n`` sections.
If ``num_or_indices`` is a list or tuple of integter indices, ``x`` is split at each of the indices.
name (str, optional): The default value is None. Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name` .
Returns:
list[Tensor], The list of segmented Tensors.

Examples:
.. code-block:: python

>>> import paddle

>>> # x is a Tensor of shape [7, 6, 8]
>>> x = paddle.rand([7, 6, 8])
>>> out0, out1 = paddle.dsplit(x, num_or_indices=2)
>>> print(out0.shape)
[7, 6, 4]
>>> print(out1.shape)
[7, 6, 4]

>>> out0, out1, out2 = paddle.dsplit(x, num_or_indices=[1, 4])
>>> print(out0.shape)
[7, 6, 1]
>>> print(out1.shape)
[7, 6, 3]
>>> print(out2.shape)
[7, 6, 4]

"""
if x.ndim < 3:
raise ValueError(
f"The input tensor's dimension must be greater than 2, but got {x.ndim}"
)
return tensor_split(x, num_or_indices, axis=2, name=name)


def vsplit(x, num_or_indices, name=None):
"""
Split the input tensor into multiple sub-Tensors along the vertical axis, which is equivalent to ``paddle.tensor_split`` with ``axis=0``.

Args:
x (Tensor): A Tensor whose dimension must be greater than 1. The data type is bool, bfloat16, float16, float32, float64, uint8, int32 or int64.
num_or_indices (int|list|tuple): If ``num_or_indices`` is an int ``n``, ``x`` is split into ``n`` sections.
If ``num_or_indices`` is a list or tuple of integter indices, ``x`` is split at each of the indices.
name (str, optional): The default value is None. Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name` .
Returns:
Expand All @@ -2594,31 +2804,26 @@ def vsplit(x, num_or_sections, name=None):

>>> # x is a Tensor of shape [8, 6, 7]
>>> x = paddle.rand([8, 6, 7])
>>> out0, out1 = paddle.vsplit(x, num_or_sections=2)
>>> out0, out1 = paddle.vsplit(x, num_or_indices=2)
>>> print(out0.shape)
[4, 6, 7]
>>> print(out1.shape)
[4, 6, 7]
>>> out0, out1, out2 = paddle.vsplit(x, num_or_sections=[1, 3, 4])

>>> out0, out1, out2 = paddle.vsplit(x, num_or_indices=[1, 4])
>>> print(out0.shape)
[1, 6, 7]
>>> print(out1.shape)
[3, 6, 7]
>>> print(out2.shape)
[4, 6, 7]
>>> out0, out1, out2 = paddle.vsplit(x, num_or_sections=[2, 3, -1])
>>> print(out0.shape)
[2, 6, 7]
>>> print(out1.shape)
[3, 6, 7]
>>> print(out2.shape)
[3, 6, 7]

"""
if x.ndim < 2:
raise ValueError(
f"The input tensor's dimension must be greater than 1, but got {x.ndim}"
)
return split(x, num_or_sections, axis=0, name=name)
return tensor_split(x, num_or_indices, axis=0, name=name)


def squeeze(x, axis=None, name=None):
Expand Down