Skip to content

Commit

Permalink
【Hackathon 5th No.32】为 Paddle 新增 tensor_split / hsplit / dsplit API (#…
Browse files Browse the repository at this point in the history
…6389)

* [Add] split extension api docs

* [Change] h v d split using tensor_split

* Update docs/api/paddle/tensor_split_cn.rst

Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com>

* Update docs/api/paddle/dsplit_cn.rst

Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com>

* Update docs/api/paddle/hsplit_cn.rst

Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com>

* Update docs/guides/model_convert/convert_from_pytorch/api_difference/Tensor/torch.Tensor.dsplit.md

Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com>

* Update docs/guides/model_convert/convert_from_pytorch/api_difference/Tensor/torch.Tensor.tensor_split.md

Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com>

* Update docs/guides/model_convert/convert_from_pytorch/api_difference/Tensor/torch.Tensor.vsplit.md

Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com>

* Update docs/guides/model_convert/convert_from_pytorch/api_difference/ops/torch.dsplit.md

Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com>

* Update docs/guides/model_convert/convert_from_pytorch/api_difference/ops/torch.hsplit.md

Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com>

* Update docs/guides/model_convert/convert_from_pytorch/api_difference/ops/torch.tensor_split.md

Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com>

* Update docs/guides/model_convert/convert_from_pytorch/api_difference/ops/torch.vsplit.md

Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com>

* [Fix] blankline

* [Fix] ref tensor_split

---------

Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com>
  • Loading branch information
megemini and sunzhongkai588 committed Dec 14, 2023
1 parent ca33083 commit 816c816
Show file tree
Hide file tree
Showing 15 changed files with 291 additions and 22 deletions.
3 changes: 3 additions & 0 deletions docs/api/paddle/Overview_cn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,10 @@ tensor 元素操作相关(如:转置,reshape 等)
" :ref:`paddle.shard_index <cn_api_paddle_shard_index>` ", "根据分片(shard)的偏移量重新计算分片的索引"
" :ref:`paddle.slice <cn_api_paddle_slice>` ", "沿多个轴生成 input 的切片"
" :ref:`paddle.split <cn_api_paddle_split>` ", "将输入 Tensor 分割成多个子 Tensor"
" :ref:`paddle.tensor_split <cn_api_paddle_tensor_split>` ", "将输入 Tensor 分割成多个子 Tensor,允许不等分"
" :ref:`paddle.hsplit <cn_api_paddle_hsplit>` ", "将输入 Tensor 沿第零个维度分割成多个子 Tensor"
" :ref:`paddle.vsplit <cn_api_paddle_vsplit>` ", "将输入 Tensor 沿第一个维度分割成多个子 Tensor"
" :ref:`paddle.dsplit <cn_api_paddle_dsplit>` ", "将输入 Tensor 沿第二个维度分割成多个子 Tensor"
" :ref:`paddle.squeeze <cn_api_paddle_squeeze>` ", "删除输入 Tensor 的 Shape 中尺寸为 1 的维度"
" :ref:`paddle.stack <cn_api_paddle_stack>` ", "沿 axis 轴对输入 x 进行堆叠操作"
" :ref:`paddle.strided_slice <cn_api_paddle_strided_slice>` ", "沿多个轴生成 x 的切片"
Expand Down
29 changes: 28 additions & 1 deletion docs/api/paddle/Tensor_cn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2220,7 +2220,34 @@ split(num_or_sections, axis=0, name=None)

请参考 :ref:`cn_api_paddle_split`

vsplit(num_or_sections, name=None)
tensor_split(num_or_indices, axis=0, name=None)
:::::::::

返回:计算后的 Tensor

返回类型:Tensor

请参考 :ref:`cn_api_paddle_tensor_split`

dsplit(num_or_indices, name=None)
:::::::::

返回:计算后的 Tensor

返回类型:Tensor

请参考 :ref:`cn_api_paddle_dsplit`

hsplit(num_or_indices, name=None)
:::::::::

返回:计算后的 Tensor

返回类型:Tensor

请参考 :ref:`cn_api_paddle_hsplit`

vsplit(num_or_indices, name=None)
:::::::::

返回:计算后的 Tensor
Expand Down
27 changes: 27 additions & 0 deletions docs/api/paddle/dsplit_cn.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
.. _cn_api_paddle_dsplit:

dsplit
-------------------------------

.. py:function:: paddle.dsplit(x, num_or_indices, name=None)
将输入 Tensor 沿着深度轴分割成多个子 Tensor,等价于将 :ref:`cn_api_paddle_tensor_split` API 的参数 axis 固定为 2。

参数
:::::::::
- **x** (Tensor) - 输入变量,数据类型为 bool、bfloat16、float16、float32、float64、uint8、int8、int32、int64 的多维 Tensor,其维度必须大于 2。
- **num_or_indices** (int|list|tuple) - 如果 ``num_or_indices`` 是一个整数 ``n`` ,则 ``x`` 拆分为 ``n`` 部分。如果 ``num_or_indices`` 是整数索引的列表或元组,则在每个索引处分割 ``x`` 。
- **name** (str,可选) - 具体用法请参见 :ref:`api_guide_Name`,一般无需设置,默认值为 None。

返回
:::::::::

list[Tensor],分割后的 Tensor 列表。


代码示例
:::::::::

COPY-FROM: paddle.dsplit
27 changes: 27 additions & 0 deletions docs/api/paddle/hsplit_cn.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
.. _cn_api_paddle_hsplit:

hsplit
-------------------------------

.. py:function:: paddle.hsplit(x, num_or_indices, name=None)
将输入 Tensor 沿着水平轴分割成多个子 Tensor。当 x 的维度大于 1 时等价于将 :ref:`cn_api_paddle_tensor_split` API 的参数 axis 固定为 1,当 x 的维度等于 1 时等价于将 paddle.tensor_split API 的参数 axis 固定为 0。

参数
:::::::::
- **x** (Tensor) - 输入变量,数据类型为 bool、bfloat16、float16、float32、float64、uint8、int8、int32、int64 的多维 Tensor,其维度必须大于 0。
- **num_or_indices** (int|list|tuple) - 如果 ``num_or_indices`` 是一个整数 ``n`` ,则 ``x`` 拆分为 ``n`` 部分。如果 ``num_or_indices`` 是整数索引的列表或元组,则在每个索引处分割 ``x`` 。
- **name** (str,可选) - 具体用法请参见 :ref:`api_guide_Name`,一般无需设置,默认值为 None。

返回
:::::::::

list[Tensor],分割后的 Tensor 列表。


代码示例
:::::::::

COPY-FROM: paddle.hsplit
28 changes: 28 additions & 0 deletions docs/api/paddle/tensor_split_cn.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
.. _cn_api_paddle_tensor_split:

tensor_split
-------------------------------

.. py:function:: paddle.tensor_split(x, num_or_indices, axis=0, name=None)
将输入 Tensor 沿着轴 ``axis`` 分割成多个子 Tensor,允许进行不等长地分割。

参数
:::::::::
- **x** (Tensor) - 输入变量,数据类型为 bool、bfloat16、float16、float32、float64、uint8、int8、int32、int64 的多维 Tensor,其维度必须大于 0。
- **num_or_indices** (int|list|tuple) - 如果 ``num_or_indices`` 是一个整数 ``n`` ,则 ``x`` 沿 ``axis`` 拆分为 ``n`` 部分。如果 ``x`` 可被 ``n`` 整除,则每个部分都是 ``x.shape[axis]/n`` 。如果 ``x`` 不能被 ``n`` 整除,则第一个 ``int(x.shape[axis]%n)`` 分割大小将为 ``int(x.shape[axis]/n)+1`` ,其余部分的大小将是 ``int(x.shape[axis]/n)`` 。如果 ``num_or_indices`` 是整数索引的列表或元组,则在每个索引处沿 ``axis`` 分割 ``x`` 。例如, ``num_or_indices=[2, 4]`` 在 ``axis=0`` 时将沿轴 0 将 ``x`` 拆分为 ``x[:2]`` 、 ``x[2:4]`` 和 ``x[4:]`` 。
- **axis** (int|Tensor,可选) - 整数或者形状为[]的 0-D Tensor,数据类型为 int32 或 int64。表示需要分割的维度。如果 ``axis < 0``,则划分的维度为 ``rank(x) + axis`` 。默认值为 0。
- **name** (str,可选) - 具体用法请参见 :ref:`api_guide_Name`,一般无需设置,默认值为 None。

返回
:::::::::

list[Tensor],分割后的 Tensor 列表。


代码示例
:::::::::

COPY-FROM: paddle.tensor_split
8 changes: 4 additions & 4 deletions docs/api/paddle/vsplit_cn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@
vsplit
-------------------------------

.. py:function:: paddle.vsplit(x, num_or_sections, name=None)
.. py:function:: paddle.vsplit(x, num_or_indices, name=None)
将输入 Tensor 沿着垂直轴分割成多个子 Tensor,等价于将 paddle.split API 的参数 axis 固定为 0。
将输入 Tensor 沿着垂直轴分割成多个子 Tensor,等价于将 :ref:`cn_api_paddle_tensor_split` API 的参数 axis 固定为 0。

参数
:::::::::
- **x** (Tensor) - 输入变量,数据类型为 bool、float16、float32、float64、uint8、int8、int32、int64 的多维 Tensor,其维度必须大于 1。
- **num_or_sections** (int|list|tuple) - 如果 ``num_or_sections`` 是一个整数,则表示 Tensor 平均划分为相同大小子 Tensor 的数量。如果 ``num_or_sections`` 是一个 list 或 tuple,那么它的长度代表子 Tensor 的数量,它的元素可以是整数或者形状为[]的 0-D Tensor,依次代表子 Tensor 需要分割成的维度的大小。list 或 tuple 的长度不能超过输入 Tensor 第一个维度的大小。在 list 或 tuple 中,至多有一个元素值为-1,表示该值是由 ``x`` 的维度和其他 ``num_or_sections`` 中元素推断出来的
- **x** (Tensor) - 输入变量,数据类型为 bool、bfloat16、float16、float32、float64、uint8、int8、int32、int64 的多维 Tensor,其维度必须大于 1。
- **num_or_indices** (int|list|tuple) - 如果 ``num_or_indices`` 是一个整数 ``n`` ,则 ``x`` 拆分为 ``n`` 部分。如果 ``num_or_indices`` 是整数索引的列表或元组,则在每个索引处分割 ``x`` 。
- **name** (str,可选) - 具体用法请参见 :ref:`api_guide_Name`,一般无需设置,默认值为 None。

返回
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
## [ 仅参数名不一致 ]torch.Tensor.split_size_or_sections

### [torch.Tensor.的 split](https://pytorch.org/docs/stable/generated/torch.Tensor.dsplit.html)

```python
torch.Tensor.dsplit(split_size_or_sections)
```

### [paddle.Tensor.dsplit](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/Tensor_cn.html#dsplit-num_or_indices-name-none)

```python
paddle.Tensor.dsplit(num_or_indices, name=None)
```

其中 Paddle 相比 Pytorch 仅参数名不一致,具体如下:

### 参数映射

| PyTorch | PaddlePaddle | 备注 |
| ------------- | ------------ | ------------------------------------------------------ |
| split_size_or_sections | num_or_indices | 表示分割的数量或索引,仅参数名不一致。 |
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
## [ 仅参数名不一致 ]torch.Tensor.hsplit

### [torch.Tensor.hsplit](https://pytorch.org/docs/stable/generated/torch.Tensor.hsplit.html)

```python
torch.Tensor.hsplit(split_size_or_sections)
```

### [paddle.Tensor.hsplit](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/Tensor_cn.html#hsplit-num_or_indices-name-none)

```python
paddle.Tensor.hsplit(num_or_indices, name=None)
```

其中 Paddle 相比 Pytorch 仅参数名不一致,具体如下:

### 参数映射

| PyTorch | PaddlePaddle | 备注 |
| ------------- | ------------ | ------------------------------------------------------ |
| split_size_or_sections | num_or_indices | 表示分割的数量或索引,仅参数名不一致。 |
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
## [ 仅参数名不一致 ]torch.tensor_split
### [torch.tensor_split](https://pytorch.org/docs/stable/generated/torch.Tensor.tensor_split.html)

```python
torch.tensor_split(indices_or_sections, dim=0)
```

### [paddle.tensor_split](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/Tensor_cn.html#tensor_split-num_or_indices-axis-0-name-none)

```python
paddle.tensor_split(num_or_indices, axis=0, name=None)
```

其中 Paddle 相比 Pytorch 仅参数名不一致,具体如下:

### 参数映射

| PyTorch | PaddlePaddle | 备注 |
| ------------- | ------------ | ------------------------------------------------------ |
| indices_or_sections | num_or_indices | 表示分割的数量或索引,仅参数名不一致。 |
| dim | axis | 表示需要分割的维度,仅参数名不一致。 |
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
## [ 仅参数名不一致 ]torch.Tensor.split_size_or_sections

### [torch.Tensor.vsplit](https://pytorch.org/docs/stable/generated/torch.Tensor.vsplit.html)

```python
torch.Tensor.vsplit(split_size_or_sections)
```

### [paddle.Tensor.vsplit](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/Tensor_cn.html#vsplit-num_or_indices-name-none)

```python
paddle.Tensor.vsplit(num_or_indices, name=None)
```

其中 Paddle 相比 Pytorch 仅参数名不一致,具体如下:

### 参数映射

| PyTorch | PaddlePaddle | 备注 |
| ------------- | ------------ | ------------------------------------------------------ |
| split_size_or_sections | num_or_indices | 表示分割的数量或索引,仅参数名不一致。 |
Original file line number Diff line number Diff line change
@@ -1,24 +1,24 @@
## [ 仅 paddle 参数更多 ]torch.dsplit
## [ 仅参数名不一致 ]torch.dsplit
### [torch.dsplit](https://pytorch.org/docs/stable/generated/torch.dsplit.html#torch.dsplit)

```python
torch.dsplit(input,
indices_or_sections)
```

### [paddle.split](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/split_cn.html)
### [paddle.dsplit](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/dsplit_cn.html)

```python
paddle.split(x,
num_or_sections,
axis=0,
paddle.dsplit(x,
num_or_indices,
name=None)
```

Paddle 相比 PyTorch 支持更多其他参数,具体如下:
其中 Paddle 相比 Pytorch 仅参数名不一致,具体如下:

### 参数映射

| PyTorch | PaddlePaddle | 备注 |
| ------------- | ------------ | ------------------------------------------------------ |
| input | x | 输入多维 Tensor ,仅参数名不一致。 |
| indices_or_sections | num_or_sections | 用于分割的 int 或 list 或 tuple ,仅参数名不一致。 |
| - | axis | 表示需要分割的维度,PyTorch 无此参数,Paddle 需要设置为 2。 |
| indices_or_sections | num_or_indices | 表示分割的数量或索引,仅参数名不一致。 |
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
## [ 仅参数名不一致 ]torch.hsplit
### [torch.hsplit](https://pytorch.org/docs/stable/generated/torch.hsplit.html#torch.hsplit)

```python
torch.hsplit(input,
indices_or_sections)
```

### [paddle.hsplit](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/hsplit_cn.html)

```python
paddle.hsplit(x,
num_or_indices,
name=None)
```

其中 Paddle 相比 Pytorch 仅参数名不一致,具体如下:

### 参数映射

| PyTorch | PaddlePaddle | 备注 |
| ------------- | ------------ | ------------------------------------------------------ |
| input | x | 输入多维 Tensor ,仅参数名不一致。 |
| indices_or_sections | num_or_indices | 表示分割的数量或索引,仅参数名不一致。 |
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
## [ 仅参数名不一致 ]torch.tensor_split
### [torch.tensor_split](https://pytorch.org/docs/stable/generated/torch.tensor_split.html)

```python
torch.tensor_split(input, indices_or_sections, dim=0)
```

### [paddle.tensor_split](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/tensor_split_cn.html)

```python
paddle.tensor_split(x, num_or_indices, axis=0, name=None)
```

其中 Paddle 相比 Pytorch 仅参数名不一致,具体如下:

### 参数映射

| PyTorch | PaddlePaddle | 备注 |
| ------------- | ------------ | ------------------------------------------------------ |
| input | x | 表示输入的 Tensor ,仅参数名不一致。 |
| indices_or_sections | num_or_indices | 表示分割的数量或索引,仅参数名不一致。 |
| dim | axis | 表示需要分割的维度,仅参数名不一致。 |
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
## [ 仅参数名不一致 ]torch.vsplit
### [torch.vsplit](https://pytorch.org/docs/stable/generated/torch.vsplit.html#torch.vsplit)

```python
torch.vsplit(input,
indices_or_sections)
```

### [paddle.vsplit](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/vsplit_cn.html)

```python
paddle.vsplit(x,
num_or_indices,
name=None)
```

其中 Paddle 相比 Pytorch 仅参数名不一致,具体如下:

### 参数映射

| PyTorch | PaddlePaddle | 备注 |
| ------------- | ------------ | ------------------------------------------------------ |
| input | x | 输入多维 Tensor ,仅参数名不一致。 |
| indices_or_sections | num_or_indices | 表示分割的数量或索引,仅参数名不一致。 |

0 comments on commit 816c816

Please sign in to comment.