Skip to content

Commit

Permalink
【Hackathon 5th No.4】为 Paddle 新增 masked_scatter API (#6405)
Browse files Browse the repository at this point in the history
* add masked_scatter docs

* fix

* fix
  • Loading branch information
yangguohao committed Dec 21, 2023
1 parent 2854d27 commit 5f5e8c1
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/api/paddle/Overview_cn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ tensor 数学操作原位(inplace)版本
" :ref:`paddle.hypot_ <cn_api_paddle_hypot_>` ", "Inplace 版本的 hypot API,对输入 x 采用 Inplace 策略"
" :ref:`paddle.multigammaln_ <cn_api_paddle_multigammaln_>` ", "Inplace 版本的 multigammaln API,对输入 x 采用 Inplace 策略"
" :ref:`paddle.masked_fill_ <cn_api_paddle_masked_fill_>` ", "Inplace 版本的 masked_fill API,对输入 x 采用 Inplace 策略"
" :ref:`paddle.masked_scatter_ <cn_api_paddle_masked_scatter_>` ", "Inplace 版本的 masked_scatter API,对输入 x 采用 Inplace 策略"
" :ref:`paddle.index_fill_ <cn_api_paddle_index_fill_>` ", "Inplace 版本的 index_fill API,对输入 x 采用 Inplace 策略"
" :ref:`paddle.sin_ <cn_api_paddle_sin_>` ", "Inplace 版本的 sin API,对输入 x 采用 Inplace 策略"

Expand Down Expand Up @@ -402,6 +403,7 @@ tensor 元素操作相关(如:转置,reshape 等)
" :ref:`paddle.view_as <cn_api_paddle_view_as>` ", "使用 other 的 shape,返回 x 的一个 view Tensor"
" :ref:`paddle.unfold <cn_api_paddle_unfold>` ", "返回 x 的一个 view Tensor。以滑动窗口式提取 x 的值"
" :ref:`paddle.masked_fill <cn_api_paddle_masked_fill>` ", "根据 mask 信息,将 value 中的值填充到 x 中 mask 对应为 True 的位置。"
" :ref:`paddle.masked_scatter <cn_api_paddle_masked_scatter>` ", "根据 mask 信息,将 value 中的值逐个填充到 x 中 mask 对应为 True 的位置。"
" :ref:`paddle.diagonal_scatter <cn_api_paddle_diagonal_scatter>` ", "根据给定的轴 axis 和偏移量 offset,将张量 y 的值填充到张量 x 中"
" :ref:`paddle.index_fill <cn_api_paddle_index_fill>` ", "沿着指定轴 axis 将 index 中指定位置的 x 的值填充为 value"
" :ref:`paddle.column_stack <cn_api_paddle_column_stack>` ", "沿水平轴堆叠输入 x 中的所有张量。"
Expand Down
13 changes: 13 additions & 0 deletions docs/api/paddle/Tensor_cn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3121,6 +3121,19 @@ masked_fill_(x, mask, value, name=None)

Inplace 版本的 :ref:`cn_api_paddle_masked_fill` API,对输入 `x` 采用 Inplace 策略。

masked_scatter(x, mask, value, name=None)
:::::::::
根据 mask 信息,将 value 中的值逐个填充到 x 中 mask 对应为 True 的位置。

返回一个根据 mask 将对应位置填充为 value 中元素的 Tensor。

请参考 :ref:`cn_api_paddle_masked_scatter`

masked_scatter_(x, mask, value, name=None)
:::::::::

Inplace 版本的 :ref:`cn_api_paddle_masked_scatter` API,对输入 `x` 采用 Inplace 策略。

atleast_1d(name=None)
:::::::::
将输入转换为张量并返回至少为 ``1`` 维的视图。 ``1`` 维或更高维的输入会被保留。
Expand Down
11 changes: 11 additions & 0 deletions docs/api/paddle/masked_scatter__cn.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
.. _cn_api_paddle_masked_scatter_:

masked_scatter\_
-------------------------------

.. py:function:: paddle.masked_scatter_(x, mask, value, name=None)
Inplace 版本的 :ref:`cn_api_paddle_masked_scatter` API,对输入 x 采用 Inplace 策略。

更多关于 inplace 操作的介绍请参考 `3.1.3 原位(Inplace)操作和非原位操作的区别`_ 了解详情。

.. _3.1.3 原位(Inplace)操作和非原位操作的区别: https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/guides/beginner/tensor_cn.html#id3
28 changes: 28 additions & 0 deletions docs/api/paddle/masked_scatter_cn.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
.. _cn_api_paddle_masked_scatter:

masked_scatter
-------------------------------

.. py:function:: paddle.masked_scatter(x, mask, value, name=None)
返回一个 N-D 的 Tensor,Tensor 的值是根据 ``mask`` 信息,将 ``value`` 中的值逐个填充到 ``x`` 中 ``mask`` 对应为 ``True`` 的位置,``mask`` 的数据类型是 bool。

参数
::::::::::::

- **x** (Tensor) - 输入 Tensor,数据类型为 float,double,int,int64_t,float16 或者 bfloat16。
- **mask** (Tensor) - 布尔张量,表示要填充的位置。mask 的数据类型必须为 bool。
- **value** (Tensor) - 用于填充目标张量的值,数据类型为 float,double,int,int64_t,float16 或者 bfloat16。
- **name** (str,可选) - 具体用法请参见 :ref:`api_guide_Name`,一般无需设置,默认值为 None。

返回
::::::::::::
返回一个根据 ``mask`` 将对应位置逐个填充 ``value`` 中的 Tensor,数据类型与 ``x`` 相同。


代码示例
::::::::::::

COPY-FROM: paddle.masked_scatter
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
## [ 参数完全一致 ] torch.Tensor.masked_scatter

### [torch.Tensor.masked_scatter](https://pytorch.org/docs/stable/generated/torch.Tensor.masked_scatter.html?highlight=masked_scatter#torch.Tensor.masked_scatter)

```python
torch.Tensor.masked_scatter(mask, value)
```

### [paddle.Tensor.masked_scatter](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/Tensor_cn.html#masked-scatter-mask-value-name-non)

```python
paddle.Tensor.masked_scatter(mask, value, name=None)
```

两者功能一致,参数完全一致,具体如下:

### 参数映射

| PyTorch | PaddlePaddle | 备注 |
|---------|--------------| -------------------------------------------------- |
| mask | mask | 布尔张量,表示要填充的位置 |
| value | value | 用于填充目标张量的值 |
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
## [ 参数完全一致 ] torch.Tensor.masked_scatter_

### [torch.Tensor.masked_scatter_](https://pytorch.org/docs/stable/generated/torch.Tensor.masked_scatter_.html?highlight=masked_scatter#torch.Tensor.masked_scatter_)

```python
torch.Tensor.masked_scatter_(mask, value)
```

### [paddle.Tensor.masked_scatter_]()

```python
paddle.Tensor.masked_scatter_(mask, value, name=None)
```

两者功能一致,参数完全一致,具体如下:

### 参数映射

| PyTorch | PaddlePaddle | 备注 |
|---------|--------------| -------------------------------------------------- |
| mask | mask | 布尔张量,表示要填充的位置 |
| value | value | 用于填充目标张量的值 |
Original file line number Diff line number Diff line change
Expand Up @@ -932,6 +932,8 @@
| 304 | [torch.Tensor.resize_](https://pytorch.org/docs/stable/generated/torch.Tensor.resize_.html?highlight=resize#torch.Tensor.resize_) | | 功能缺失 |
| 305 | [torch.Tensor.masked_fill_](https://pytorch.org/docs/stable/generated/torch.Tensor.masked_fill_.html?highlight=resize#torch.Tensor.masked_fill_) | [paddle.Tensor.masked_fill_](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/Tensor_cn.html#id25) | 功能完全一致 |
| 306 | [torch.Tensor.tensor_split](https://pytorch.org/docs/stable/generated/torch.Tensor.tensor_split.html) | [paddle.Tensor.tensor_split](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/Tensor_cn.html#tensor_split-indices_or_sections-axis-0-name-none) | 功能完全一致,仅参数名不一致 [差异对比](https://github.com/PaddlePaddle/docs/tree/develop/docs/guides/model_convert/convert_from_pytorch/api_difference/Tensor/torch.Tensor.tensor_split.md) |
| 307 | [torch.Tensor.masked_scatter](https://pytorch.org/docs/stable/generated/torch.Tensor.masked_scatter.html?highlight=resize#torch.Tensor.masked_scatter) | [paddle.Tensor.masked_scatter](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/Tensor_cn.html#id25) | 功能完全一致 |
| 308 | [torch.Tensor.masked_scatter_](https://pytorch.org/docs/stable/generated/torch.Tensor.masked_scatter_.html?highlight=resize#torch.Tensor.masked_scatter_) | [paddle.Tensor.masked_scatter_](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/Tensor_cn.html#id25) | 功能完全一致 |


| 序号 | PyTorch API | PaddlePaddle API | 备注 |
Expand Down

0 comments on commit 5f5e8c1

Please sign in to comment.