Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon 5th No.41】为 Paddle 新增 Rprop API 中文文档 #6388

Merged
merged 6 commits into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/paddle/optimizer/Overview_cn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ paddle.optimizer 目录下包含飞桨框架支持的优化器算法相关的 AP
" :ref:`Optimizer <cn_api_paddle_optimizer_Optimizer>` ", "飞桨框架优化器基类"
" :ref:`RMSProp <cn_api_paddle_optimizer_RMSProp>` ", "RMSProp 优化器"
" :ref:`SGD <cn_api_paddle_optimizer_SGD>` ", "SGD 优化器"
" :ref:`Rprop <cn_api_paddle_optimizer_Rprop>` ", "Rprop 优化器"

.. _about_lr:

Expand Down
127 changes: 127 additions & 0 deletions docs/api/paddle/optimizer/Rprop_cn.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
.. _cn_api_paddle_optimizer_Rprop:

Rprop
-------------------------------

.. py:class:: paddle.optimizer.Rprop(learning_rate=0.001, learning_rate_range=(1e-5, 50), parameters=None, etas=(0.5, 1.2), grad_clip=None, name=None)


.. note::
此优化器仅适用于 full-batch 训练。

Rprop算法的优化器。有关详细信息,请参阅:

`A direct adaptive method for faster backpropagation learning : The RPROP algorithm <https://ieeexplore.ieee.org/document/298623>`_ 。


.. math::

\begin{aligned}
&\hspace{0mm} For\ all\ weights\ and\ biases\{ \\
&\hspace{5mm} \textbf{if} \: (\frac{\partial E}{\partial w_{ij}}(t-1)*\frac{\partial E}{\partial w_{ij}}(t)> 0)\ \textbf{then} \: \{ \\
&\hspace{10mm} learning\_rate_{ij}(t)=\mathrm{minimum}(learning\_rate_{ij}(t-1)*\eta^{+},learning\_rate_{max}) \\
&\hspace{10mm} \Delta w_{ij}(t)=-sign(\frac{\partial E}{\partial w_{ij}}(t))*learning\_rate_{ij}(t) \\
&\hspace{10mm} w_{ij}(t+1)=w_{ij}(t)+\Delta w_{ij}(t) \\
&\hspace{5mm} \} \\
&\hspace{5mm} \textbf{else if} \: (\frac{\partial E}{\partial w_{ij}}(t-1)*\frac{\partial E}{\partial w_{ij}}(t)< 0)\ \textbf{then} \: \{ \\
&\hspace{10mm} learning\_rate_{ij}(t)=\mathrm{maximum}(learning\_rate_{ij}(t-1)*\eta^{-},learning\_rate_{min}) \\
&\hspace{10mm} w_{ij}(t+1)=w_{ij}(t) \\
&\hspace{10mm} \frac{\partial E}{\partial w_{ij}}(t)=0 \\
&\hspace{5mm} \} \\
&\hspace{5mm} \textbf{else if} \: (\frac{\partial E}{\partial w_{ij}}(t-1)*\frac{\partial E}{\partial w_{ij}}(t)= 0)\ \textbf{then} \: \{ \\
&\hspace{10mm} \Delta w_{ij}(t)=-sign(\frac{\partial E}{\partial w_{ij}}(t))*learning\_rate_{ij}(t) \\
&\hspace{10mm} w_{ij}(t+1)=w_{ij}(t)+\Delta w_{ij}(t) \\
&\hspace{5mm} \} \\
&\hspace{0mm} \} \\
\end{aligned}


参数
::::::::::::

- **learning_rate** (float|_LRScheduleri,可选) - 初始学习率,用于参数更新的计算。可以是一个浮点型值或者一个_LRScheduler 类。默认值为 0.001。
- **learning_rate_range** (tuple,可选) - 学习率的范围。学习率不能小于元组的第一个元素;学习率不能大于元组的第二个元素。默认值为 (1e-5, 50)。
- **parameters** (list,可选) - 指定优化器需要优化的参数。在动态图模式下必须提供该参数;在静态图模式下默认值为 None,这时所有的参数都将被优化。
- **etas** (tuple,可选) - 用于更新学习率的元组。元组的第一个元素是乘法递减因子;元组的第二个元素是乘法增加因子。默认值为 (0.5, 1.2)。
- **grad_clip** (GradientClipBase,可选) – 梯度裁剪的策略,支持三种裁剪策略::ref:`paddle.nn.ClipGradByGlobalNorm <cn_api_paddle_nn_ClipGradByGlobalNorm>` 、 :ref:`paddle.nn.ClipGradByNorm <cn_api_paddle_nn_ClipGradByNorm>` 、 :ref:`paddle.nn.ClipGradByValue <cn_api_paddle_nn_ClipGradByValue>` 。
默认值为 None,此时将不进行梯度裁剪。
- **name** (str,可选) - 具体用法请参见 :ref:`api_guide_Name`,一般无需设置,默认值为 None。


代码示例
::::::::::::

COPY-FROM: paddle.optimizer.Rprop


方法
::::::::::::
step()
'''''''''

.. note::

该 API 只在 `Dygraph <../../user_guides/howto/dygraph/DyGraph.html>`_ 模式下生效。

执行一次优化器并进行参数更新。

**返回**

无。

**代码示例**

COPY-FROM: paddle.optimizer.Rprop.step

minimize(loss, startup_program=None, parameters=None, no_grad_set=None)
'''''''''

为网络添加反向计算过程,并根据反向计算所得的梯度,更新 parameters 中的 Parameters,最小化网络损失值 loss。

**参数**

- **loss** (Tensor) - 需要最小化的损失值变量
- **startup_program** (Program,可选) - 用于初始化 parameters 中参数的 :ref:`cn_api_paddle_static_Program`,默认值为 None,此时将使用 :ref:`cn_api_paddle_static_default_startup_program` 。
- **parameters** (list,可选) - 待更新的 Parameter 或者 Parameter.name 组成的列表,默认值为 None,此时将更新所有的 Parameter。
- **no_grad_set** (set,可选) - 不需要更新的 Parameter 或者 Parameter.name 组成的集合,默认值为 None。

**返回**

tuple(optimize_ops, params_grads),其中 optimize_ops 为参数优化 OP 列表;param_grads 为由(param, param_grad)组成的列表,其中 param 和 param_grad 分别为参数和参数的梯度。在静态图模式下,该返回值可以加入到 ``Executor.run()`` 接口的 ``fetch_list`` 参数中,若加入,则会重写 ``use_prune`` 参数为 True,并根据 ``feed`` 和 ``fetch_list`` 进行剪枝,详见 ``Executor`` 的文档。


**代码示例**

COPY-FROM: paddle.optimizer.Rprop.minimize

clear_grad()
'''''''''

.. note::

该 API 只在 `Dygraph <../../user_guides/howto/dygraph/DyGraph.html>`_ 模式下生效。


清除需要优化的参数的梯度。

**代码示例**

COPY-FROM: paddle.optimizer.Rprop.clear_grad

get_lr()
'''''''''

.. note::

该 API 只在 `Dygraph <../../user_guides/howto/dygraph/DyGraph.html>`_ 模式下生效。

获取当前步骤的学习率。当不使用_LRScheduler 时,每次调用的返回值都相同,否则返回当前步骤的学习率。

**返回**

float,当前步骤的学习率。


**代码示例**

COPY-FROM: paddle.optimizer.Rprop.get_lr
10 changes: 10 additions & 0 deletions docs/api_guides/low_level/optimizer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,13 @@ API Reference 请参考 :ref:`cn_api_fluid_optimizer_FtrlOptimizer`
:code:`ModelAverage` 优化器,在训练中通过窗口来累计历史 parameter,在预测时使用取平均值后的 paramet,整体提高预测的精度。

API Reference 请参考 :ref:`cn_api_fluid_optimizer_ModelAverage`




10.Rprop/RpropOptimizer
-----------------

:code:`Rprop` 优化器,该方法考虑到不同权值参数的梯度的数量级可能相差很大,因此很难找到一个全局的学习步长。因此创新性地提出靠参数梯度的符号,动态的调节学习步长以加速优化过程的方法。

API Reference 请参考 :ref:`cn_api_fluid_optimizer_Rprop`
10 changes: 10 additions & 0 deletions docs/api_guides/low_level/optimizer_en.rst
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,13 @@ API Reference: :ref:`api_fluid_optimizer_FtrlOptimizer`
:code:`ModelAverage` Optimizer accumulates history parameters through sliding window during the model training. We use averaged parameters at inference time to upgrade general accuracy of inference.

API Reference: :ref:`api_fluid_optimizer_ModelAverage`




10.Rprop/RpropOptimizer
-----------------

:code:`Rprop` Optimizer, this method considers that the magnitude of gradients for different weight parameters may vary greatly, making it difficult to find a global learning step size. Therefore, an innovative method is proposed to accelerate the optimization process by dynamically adjusting the learning step size through the use of parameter gradient symbols.

API Reference: :ref:`api_fluid_optimizer_Rprop`
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
## [ torch 参数更多 ]torch.optim.Rprop

### [torch.optim.Rprop](https://pytorch.org/docs/stable/generated/torch.optim.Rprop.html)

```python
torch.optim.Rprop(params,
lr=0.01,
etas=(0.5, 1.2),
step_sizes=(1e-06, 50),
foreach=None,
maximize=False,
differentiable=False)
```

### [paddle.optimizer.Rprop](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/optimizer/Rprop_cn.html#cn-api-paddle-optimizer-rprop)

```python
paddle.optimizer.Rprop(learning_rate=0.001,
learning_rate_range=(1e-5, 50),
parameters=None,
etas=(0.5, 1.2),
grad_clip=None,
name=None)
```

Pytorch 相比 Paddle 支持更多其他参数,具体如下:

### 参数映射

| PyTorch | PaddlePaddle | 备注 |
| ------------- | ------------------- | ----------------------------------------------------------------------------------------------------------------------- |
| params | parameters | 表示指定优化器需要优化的参数,仅参数名不一致。 |
| lr | learning_rate | 初始学习率,用于参数更新的计算。参数默认值不一致, Pytorch 默认为`0.01`, Paddle 默认为`0.001`,Paddle 需保持与 Pytorch 一致。 |
| etas | etas | 用于更新学习率。参数一致。 |
| step_sizes | learning_rate_range | 学习率的范围,参数默认值不一致, Pytorch 默认为`(1e-06, 50)`, Paddle 默认为`(1e-5, 50)`,Paddle 需保持与 Pytorch 一致。 |
| foreach | - | 是否使用优化器的 foreach 实现。Paddle 无此参数,一般对网络训练结果影响不大,可直接删除。 |
| maximize | - | 根据目标最大化参数,而不是最小化。Paddle 无此参数,暂无转写方式。 |
| differentiable| - | 是否应通过训练中的优化器步骤进行自动微分。Paddle 无此参数,一般对网络训练结果影响不大,可直接删除。 |
| - | grad_clip | 梯度裁剪的策略。 PyTorch 无此参数,Paddle 保持默认即可。 |