Skip to content

Commit

Permalink
【Hackthon 6th No. 27】为 paddle.io.RandomSampler/random_split /Layer.cl…
Browse files Browse the repository at this point in the history
…ear_gradients 进行功能增强 (#6594)

* update docs

* update docs
  • Loading branch information
NKNaN committed Apr 10, 2024
1 parent dbf8cf1 commit f7683fc
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 29 deletions.
2 changes: 1 addition & 1 deletion docs/api/paddle/io/RandomSampler_cn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ RandomSampler
:::::::::
- **data_source** (Dataset) - 此参数必须是 :ref:`cn_api_paddle_io_Dataset` 或 :ref:`cn_api_paddle_io_IterableDataset` 的一个子类实例或实现了 ``__len__`` 的 Python 对象,用于生成样本下标。默认值为 None。
- **replacement** (bool,可选) - 如果为 ``False`` 则会采样整个数据集,如果为 ``True`` 则会按 ``num_samples`` 指定的样本数采集。默认值为 ``False`` 。
- **num_samples** (int,可选) - 如果 ``replacement`` 设置为 ``True`` 则按此参数采集对应的样本数。默认值为 None,不启用
- **num_samples** (int,可选) - 按此参数采集对应的样本数。默认值为 None,此时设为 ``data_source`` 的长度
- **generator** (Generator,可选) - 指定采样 ``data_source`` 的采样器。默认值为 None,不启用。

返回
Expand Down
2 changes: 1 addition & 1 deletion docs/api/paddle/io/random_split_cn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ random_split
::::::::::::

- **dataset** (Dataset) - 此参数必须是 ``paddle.io.Dataset`` 或 ``paddle.io.IterableDataset`` 的一个子类实例或实现了 ``__len__`` 的 Python 对象,用于生成样本下标。默认值为 None。
- **lengths** (list) - 总和为原数组长度的,子集合长度数组
- **lengths** (list) - 总和为原数组长度,表示子集合长度数组;或总和为 1.0,表示子集合长度占比的数组
- **generator** (Generator,可选) - 指定采样 ``data_source`` 的采样器。默认值为 None。

返回
Expand Down
6 changes: 5 additions & 1 deletion docs/api/paddle/nn/Layer_cn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -228,11 +228,15 @@ sublayers(include_self=False)

COPY-FROM: paddle.nn.Layer.sublayers

clear_gradients()
clear_gradients(set_to_zero=True)
'''''''''

清除该层所有参数的梯度。

**参数**

- **set_to_zero** (bool,可选) - 是否将可训练参数的梯度设置为 0 ,若为 False 则设为 None。默认值:True。

**返回**

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
## [torch 参数更多]torch.nn.Module.zero_grad
## [参数不一致]torch.nn.Module.zero_grad

### [torch.nn.Module.zero_grad](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.zero_grad)

Expand All @@ -9,13 +9,23 @@ torch.nn.Module.zero_grad(set_to_none=True)
### [paddle.nn.Layer.clear_gradients](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/nn/Layer_cn.html#clear-gradients)

```python
paddle.nn.Layer.clear_gradients()
paddle.nn.Layer.clear_gradients(set_to_zero=True)
```

PyTorch 相比 Paddle 支持更多其他参数,具体如下:
PyTorch `Module.zero_grad` 参数与 Paddle `Layer.clear_gradients` 参数用法刚好相反,具体如下:

### 参数映射

| PyTorch | PaddlePaddle | 备注 |
| ----------- | ------------ | ------------------------------------------------ |
| set_to_none | - | 是否设置为 None,Paddle 无此参数,暂无转写方式。 |
| set_to_none | set_to_zero | 设置如何清空梯度,PyTorch 默认 set_to_none 为 True,Paddle 默认 set_to_zero 为 True,两者功能刚好相反,Paddle 需设置为 False。 |

### 转写示例

```python
# PyTorch 写法
torch.nn.Module.zero_grad(set_to_none=True)

# Paddle 写法
paddle.nn.Layer.clear_gradients(set_to_zero=False)
```
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,5 @@ paddle.io.RandomSampler(data_source, replacement=False, num_samples=None, genera
| ----------- | ------------ | -------------------------------------------------------------------- |
| data_source | data_source | Dataset 或 IterableDataset 的一个子类实例或实现了 `__len__` 的 Python 对象。 |
| replacement | replacement | 如果为 False 则会采样整个数据集。 |
| num_samples | num_samples | 如果 replacement 设置为 True 则按此参数采集对应的样本数|
| num_samples | num_samples | 按此参数采集对应的样本数|
| generator | generator | 指定采样 data_source 的采样器。 |
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
## [ 参数不一致 ]torch.utils.data.random_split
## [ 参数完全一致 ]torch.utils.data.random_split
### [torch.utils.data.random_split](https://pytorch.org/docs/stable/data.html?highlight=torch+utils+data+random_split#torch.utils.data.random_split)

```python
Expand All @@ -15,27 +15,12 @@ paddle.io.random_split(dataset,
generator=None)
```

两者参数除 lengths 外用法一致,具体如下:
### 参数差异
两者功能一致,参数完全一致,具体如下:

### 参数映射

| PyTorch | PaddlePaddle | 备注 |
| ------------- | ------------ |---------------------------------------------------------------------|
| dataset | dataset | 表示可迭代数据集。 |
| lengths | lengths | PyTorch:可为子集合长度列表,列表总和为原数组长度。也可为子集合所占比例列表,列表总和为 1.0。PaddlePaddle: 子集合长度列表,列表总和为原数组长度 |
| lengths | lengths | 可为子集合长度列表,列表总和为原数组长度。也可为子集合所占比例列表,列表总和为 1.0。 |
| generator | generator | 指定采样 data_source 的采样器。默认值为 None。 |

### 转写示例
#### lenghts: 子集合长度列表
```python
# PyTorch 写法
lengths = [0.3, 0.3, 0.4]
datasets = torch.utils.data.random_split(dataset,
lengths,
generator=torch.manual_seed(0))

# Paddle 写法
lengths = [0.3, 0.3, 0.4]
lengths = [length * len(dataset) for length in lengths]
datasets = paddle.io.random_split(dataset,
lengths,
generator=paddle.seed(0))
```

0 comments on commit f7683fc

Please sign in to comment.