PPO 中的广义优势估计（Generalized Advantage Estimation，GAE）通过在时间步上对 TD Error 进行加权累加，提供了一个在偏差和方差之间可调的优势估计器。其定义为
$$
\begin{aligned}A_{t}^{\mathrm{GAE}(\gamma, \lambda)} & =\delta_{t}+(\gamma \lambda) \delta_{t+1}+(\gamma \lambda)^{2} \delta_{t+2}+\cdots \\& =\sum_{l=0}^{\infty}(\gamma \lambda)^{l} \delta_{t+l}\end{aligned}
$$
其中
$\delta_{t}=r_{t}+\gamma V\left(s_{t+1}\right)-V\left(s_{t}\right)$ 是第 $t$ 步的 TD Error（价值估计残差），$V$ 为价值函数。GAE 的衰减系数 $\lambda \in [0, 1]$ 控制偏差和方差之间的平衡：$\lambda$ 越大，远处时间步的 TD Error 权重越大，偏差越小而方差越大。
- $\lambda=0$ 只考虑一步的 TD Error，偏差较大，但方差较小；
- $\lambda=1$ 优势估计等价于蒙特卡洛回报减去当前状态的价值估计，偏差较小，但方差较大；

通常奖励是稀疏且延迟的：最朴素的 reward model 只对整个序列给出一个最终分数。因此下面的示例只在最后一个时间步放置非零奖励，其余时间步的奖励为 0。

In [7]:
import torch

# Toy rollout dimensions: B sequences, T time steps each.
B, T = 4, 20

# Sparse, delayed reward: one random score at the final step, zeros before it.
# The final scores are sampled before `values`, so the random draws happen in
# the same order as the one-line version of this cell.
final_scores = torch.randn(B, 1)
leading_zeros = torch.zeros(B, T - 1)
rewards = torch.cat((leading_zeros, final_scores), dim=1)

# Random stand-in for per-step value estimates, one per (sequence, step).
values = torch.randn(B, T)

rewards, values.shape

(tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000, -1.9131],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000, -0.7905],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000, -0.7630],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000, -1.0393]]),
 torch.Size([4, 20]))

从最后一个时间步开始反向递推计算 GAE：$A_{t}=\delta_{t}+\gamma \lambda A_{t+1}$，其中序列结束之后的价值估计和优势都取 0。

In [35]:
gamma = 0.95  # discount factor
lambd = 0.99  # GAE decay coefficient (lambda)

# Running advantage A_{t+1}; zero past the end of the sequence.
gae = 0.
# One slot per time step, filled by index so no reversal is needed afterwards.
per_step = [None] * T

# Walk backwards in time: A_t = delta_t + gamma * lambda * A_{t+1}.
for t in range(T - 1, -1, -1):
    if t == T - 1:
        # No next state after the final step, so its value is taken as 0.
        next_value = 0.0
    else:
        next_value = values[:, t + 1]
    # One-step TD error at step t.
    delta = rewards[:, t] + gamma * next_value - values[:, t]
    gae = delta + gamma * lambd * gae
    per_step[t] = gae

# Stack the per-step columns along dim 1 so the time axis matches `rewards`.
advantages = torch.stack(per_step, dim=1)
advantages

tensor([[ 0.3177,  0.8067, -0.2291, -1.2326,  1.1936, -1.6586, -0.5858, -1.3243,
          0.6070, -1.9644, -1.7141, -1.1202, -1.0013, -1.3996, -2.3914, -1.0959,
         -0.2817, -1.5057, -1.5323,  0.1711],
        [ 1.8244, -0.5336,  1.5302, -0.8799, -0.4558, -0.4538, -0.0670, -0.7079,
         -0.6576,  1.5259, -1.2568, -0.0608, -0.1926, -1.9756, -1.5727, -0.9659,
         -0.4553, -0.1435, -0.7183, -0.8768],
        [-0.1662, -1.2120,  0.1726,  0.1674, -0.8651, -0.2279, -0.0876,  0.6933,
         -0.3429, -0.8923,  0.9746,  0.1910, -0.6030, -0.5632, -1.2377, -0.0467,
         -1.1971, -1.8571, -0.5677, -2.5791],
        [ 1.7116, -1.4577, -2.6082,  0.6714,  0.0295, -1.1375, -1.6134, -1.6412,
         -1.1371, -2.4929, -0.7455, -0.4193, -0.2312, -1.3275,  0.2358,  1.3057,
         -1.1331,  0.6836, -0.9357, -0.1647]])