基于上文介绍的 TRPO 数学原理，下面实现一个最小化版本的 TRPO（Minimal TRPO）。

实现依赖三个核心数学工具：
- 共轭梯度法 (Conjugate Gradient)：近似求解 Hx = g
- Hessian–向量积 (HVP)：基于 KL 散度的二阶近似
- 线搜索 + KL 约束

下面假定已经采集了一批 rollouts，并能提供以下张量：
```
- states:      Tensor [N, obs_dim]
- actions:     Tensor [N, act_dim]
- advantages:  Tensor [N]
- old_logp:    Tensor [N]
```

In [1]:
import math
from dataclasses import dataclass
from typing import Callable, Tuple

import torch
import torch.nn as nn
from torch.distributions import Normal

如下是一个简单的高斯策略网络 (连续动作)。

In [4]:
class GaussianPolicy(nn.Module):
    """Diagonal-Gaussian policy for continuous actions.

    A stack of ``Linear -> Tanh`` layers feeds a linear head that produces the
    per-dimension action mean. The log standard deviation is a learned
    parameter vector of length ``act_dim`` (initialised to zero) that does not
    depend on the observation.
    """

    def __init__(self, obs_dim: int, act_dim: int, hidden_sizes=(64, 64)):
        super().__init__()
        # widths[i] -> widths[i + 1] describes the i-th hidden layer.
        widths = [obs_dim, *hidden_sizes]
        trunk = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            trunk.append(nn.Linear(fan_in, fan_out))
            trunk.append(nn.Tanh())

        self.net = nn.Sequential(*trunk)
        self.mu = nn.Linear(widths[-1], act_dim)
        self.log_std = nn.Parameter(torch.zeros(act_dim))

    def forward(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(mean, log_std)``, with ``log_std`` broadcast to the shape of ``mean``."""
        mean = self.mu(self.net(obs))
        return mean, self.log_std.expand_as(mean)

    def dist(self, obs: torch.Tensor) -> Normal:
        """Return the per-dimension ``Normal`` action distribution for ``obs``."""
        mean, log_std = self(obs)
        return Normal(mean, log_std.exp())

In [3]:
def flat_params(model: nn.Module) -> torch.Tensor:
    """Return all parameters of ``model`` concatenated into a single 1-D tensor.

    Parameters are visited in ``model.parameters()`` order. The result is a
    new tensor (``torch.cat`` copies), so modifying it does not touch the model.
    """
    pieces = []
    for param in model.parameters():
        pieces.append(param.data.view(-1))
    return torch.cat(pieces)


def set_flat_params(model: nn.Module, flat: torch.Tensor) -> None:
    """Overwrite the parameters of ``model`` in place from the 1-D tensor ``flat``.

    ``flat`` is consumed in ``model.parameters()`` order; each parameter takes
    the next ``numel()`` entries, reshaped to the parameter's own shape.
    """
    offset = 0
    for param in model.parameters():
        end = offset + param.numel()
        param.data.copy_(flat[offset:end].view_as(param))
        offset = end


def flat_grad(y: torch.Tensor, model: nn.Module, retain_graph=False, create_graph=False) -> torch.Tensor:
    """Return the gradient of scalar ``y`` w.r.t. the trainable parameters of ``model`` as one 1-D tensor.

    Args:
        y: Scalar tensor to differentiate.
        model: Module whose parameters with ``requires_grad=True`` are the
            differentiation targets, in ``model.parameters()`` order.
        retain_graph: Keep the autograd graph of ``y`` after this call.
        create_graph: Build a graph of the gradient itself so that it can be
            differentiated again (e.g. for Hessian-vector products).

    Returns:
        Concatenation of the flattened per-parameter gradients. A parameter
        that ``y`` does not depend on contributes zeros of its own size.
    """
    # Differentiate and zip against the SAME filtered list; pairing the result
    # with the unfiltered ``model.parameters()`` misaligns gradients and
    # parameters as soon as one parameter is frozen.
    params = [p for p in model.parameters() if p.requires_grad]
    grads = torch.autograd.grad(
        y,
        params,
        # A differentiable gradient still references the forward graph, so the
        # graph must be kept whenever create_graph is set (this mirrors
        # torch's own default of retain_graph=create_graph).
        retain_graph=retain_graph or create_graph,
        create_graph=create_graph,
        allow_unused=True,
    )
    return torch.cat([
        torch.zeros_like(p).view(-1) if g is None else g.contiguous().view(-1)
        for g, p in zip(grads, params)
    ])

共轭梯度法：迭代求解线性方程组 Ax = b，只需提供计算矩阵–向量积 A·v 的函数 Avp 与右端项 b，无需显式构造矩阵 A。

In [5]:
def conjugate_gradients(Avp: Callable[[torch.Tensor], torch.Tensor], b: torch.Tensor, nsteps=10, residual_tol=1e-10) -> torch.Tensor:
    """Approximately solve ``A x = b`` by conjugate gradients, starting from ``x = 0``.

    Args:
        Avp: Function returning the product ``A @ v`` for a 1-D tensor ``v``.
        b: Right-hand side (1-D tensor).
        nsteps: Maximum number of iterations.
        residual_tol: Stop early once the squared residual norm drops below this.

    Returns:
        The approximate solution, same shape as ``b``.
    """
    eps = 1e-10  # keeps both divisions finite when a denominator hits zero
    solution = torch.zeros_like(b)
    residual = b.clone()
    direction = residual.clone()
    res_sq = torch.dot(residual, residual)
    for _ in range(nsteps):
        a_dir = Avp(direction)
        step_len = res_sq / (torch.dot(direction, a_dir) + eps)
        solution = solution + step_len * direction
        residual = residual - step_len * a_dir
        next_res_sq = torch.dot(residual, residual)
        if next_res_sq < residual_tol:
            break
        direction = residual + (next_res_sq / (res_sq + eps)) * direction
        res_sq = next_res_sq
    return solution

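TRPO 更新器。配置项包括：可信域半径（KL 阈值 `max_kl`）、共轭梯度迭代步数（`cg_iters`）、HVP 的阻尼项（`cg_damping`，即使用 H + d·I）、线搜索收缩系数（`backtrack_coeff`）、线搜索最大步数（`backtrack_iters`）、期望改进比例（`accept_ratio`，可选，用于更严格的接受准则；下面的线搜索尚未使用该项）。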

In [6]:
@dataclass
class TRPOConfig:
    """Hyper-parameters for the TRPO updater."""

    # Trust-region radius: threshold on the mean KL divergence of an update.
    max_kl: float = 0.01
    # Number of conjugate-gradient iterations.
    cg_iters: int = 10
    # Damping d added to the Hessian-vector product, i.e. (H + d * I) v.
    cg_damping: float = 0.01
    # Factor by which the step shrinks on every line-search backtrack.
    backtrack_coeff: float = 0.8
    # Maximum number of line-search steps.
    backtrack_iters: int = 15
    # Expected-improvement ratio, intended as an optional stricter acceptance
    # criterion.
    accept_ratio: float = 0.1


class TRPOUpdater:
    """Apply one TRPO update to a Gaussian policy per :meth:`step` call.

    A step computes the surrogate policy gradient, solves ``H x = g`` with
    conjugate gradients (``H`` being the damped Hessian of the mean KL
    divergence), scales ``x`` to the trust-region boundary, and then
    backtracks until the KL constraint holds and the surrogate loss decreases.

    Note: ``config.accept_ratio`` is not consulted; any strictly positive
    improvement is accepted by the line search.
    """

    def __init__(self, policy: GaussianPolicy, config: TRPOConfig):
        self.policy = policy
        self.cfg = config

    @torch.no_grad()
    def _old_dist(self, states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the current policy's ``(mean, log_std)`` at ``states``, detached from autograd."""
        mu, log_std = self.policy(states)
        return mu.detach(), log_std.detach()

    def _surrogate_loss(self, states: torch.Tensor, actions: torch.Tensor, advantages: torch.Tensor, old_logp: torch.Tensor) -> torch.Tensor:
        """Return the negated mean of ``ratio * advantages`` (a loss to minimise).

        ``ratio`` is ``exp(logp - old_logp)``, where ``logp`` sums the
        per-dimension log-probabilities of ``actions`` over the last axis.
        """
        dist = self.policy.dist(states)
        logp = dist.log_prob(actions).sum(-1)
        ratio = torch.exp(logp - old_logp)
        return -(ratio * advantages).mean()

    def _mean_kl(self, states: torch.Tensor, old_mu: torch.Tensor, old_log_std: torch.Tensor) -> torch.Tensor:
        """Return the batch-mean KL(old || current) between diagonal Gaussians.

        The per-dimension closed form is summed over the last axis and then
        averaged over the remaining entries.
        """
        new_mu, new_log_std = self.policy(states)
        old_std, new_std = old_log_std.exp(), new_log_std.exp()
        kl = (new_log_std - old_log_std) + (old_std.pow(2) + (old_mu - new_mu).pow(2)) / (2.0 * new_std.pow(2)) - 0.5
        return kl.sum(dim=-1).mean()

    def _build_hvp(self, states: torch.Tensor, old_mu: torch.Tensor, old_log_std: torch.Tensor) -> Callable[[torch.Tensor], torch.Tensor]:
        """Return ``v -> (H + cg_damping * I) v`` with ``H`` the Hessian of the mean KL.

        The returned closure is only valid while the policy parameters stay at
        the values they had when this method was called.
        """
        # The first-order KL gradient does not depend on v, so build it once
        # instead of once per conjugate-gradient iteration.
        kl = self._mean_kl(states, old_mu, old_log_std)
        # retain_graph=True: differentiating kl_grad again walks back through
        # the forward graph, which autograd would otherwise free here.
        kl_grad = flat_grad(kl, self.policy, retain_graph=True, create_graph=True)

        def hvp(v: torch.Tensor) -> torch.Tensor:
            gvp = (kl_grad * v).sum()
            # The graph is reused by every subsequent hvp call.
            hv = flat_grad(gvp, self.policy, retain_graph=True)
            return hv + self.cfg.cg_damping * v
        return hvp

    def step(self, states: torch.Tensor, actions: torch.Tensor, advantages: torch.Tensor, old_logp: torch.Tensor) -> dict:
        """Run one TRPO update, modifying ``self.policy`` in place on success.

        Args:
            states: Batch of observations fed to the policy.
            actions: Actions taken in ``states``.
            advantages: Advantage estimate for each sample.
            old_logp: Log-probability of ``actions`` under the behaviour policy.

        Returns:
            Dict with ``"step_size"`` (L2 norm of the accepted parameter step),
            ``"improve"`` (decrease of the surrogate loss) and ``"kl"`` (mean
            KL after the update). All three are ``0.0`` and the parameters are
            left unchanged when no step is taken.
        """
        # Snapshot of the pre-update distribution, used as the KL reference.
        old_mu, old_log_std = self._old_dist(states)

        # 1) Policy gradient of the objective (the loss is its negation).
        loss_pi = self._surrogate_loss(states, actions, advantages, old_logp)
        g = -flat_grad(loss_pi, self.policy)

        # Nothing useful to do for a vanishing or non-finite gradient.
        if not bool(torch.isfinite(g).all()) or torch.norm(g) < 1e-8:
            return {"step_size": 0.0, "improve": 0.0, "kl": 0.0}

        # 2) Solve H x = g with conjugate gradients, x ~= H^{-1} g.
        hvp = self._build_hvp(states, old_mu, old_log_std)
        x = conjugate_gradients(hvp, g, nsteps=self.cfg.cg_iters)

        # 3) Scale x so the quadratic KL model sits on the trust-region boundary.
        xHx = float((x * hvp(x)).sum())
        if not math.isfinite(xHx) or xHx <= 0.0:
            # math.sqrt would raise on negative curvature, and a non-finite
            # value cannot yield a usable step; skip the update instead.
            return {"step_size": 0.0, "improve": 0.0, "kl": 0.0}
        step_coef = math.sqrt(2.0 * self.cfg.max_kl / (xHx + 1e-10))
        full_step = step_coef * x

        # 4) Backtracking line search on the true KL and surrogate loss.
        prev_params = flat_params(self.policy)
        prev_loss = loss_pi.item()

        def set_and_eval(step: torch.Tensor) -> Tuple[float, float]:
            set_flat_params(self.policy, prev_params + step)
            # Only scalar values are needed, so skip graph construction.
            with torch.no_grad():
                kl = self._mean_kl(states, old_mu, old_log_std).item()
                loss_new = self._surrogate_loss(states, actions, advantages, old_logp).item()
            # Old loss minus new loss: larger is better.
            return prev_loss - loss_new, kl

        step = full_step
        for _ in range(self.cfg.backtrack_iters):
            improve, kl = set_and_eval(step)
            if kl <= self.cfg.max_kl and improve > 0:
                break
            step = self.cfg.backtrack_coeff * step
        else:
            # Every trial was rejected: restore the original parameters.
            set_flat_params(self.policy, prev_params)
            return {"step_size": 0.0, "improve": 0.0, "kl": 0.0}

        return {"step_size": float(step.norm().item()), "improve": float(improve), "kl": float(kl)}

In [None]:
# Smoke test: one TRPO step on a randomly generated batch.
torch.manual_seed(0)
obs_dim = 3
act_dim = 2
N = 256  # batch size

policy = GaussianPolicy(obs_dim, act_dim)
updater = TRPOUpdater(policy, TRPOConfig(max_kl=1e-2))

# Random stand-ins for a rollout batch.
states = torch.randn(N, obs_dim)
actions = torch.randn(N, act_dim)
advantages = torch.randn(N)

# Log-probabilities of the actions under the policy before the update.
with torch.no_grad():
    dist_old = policy.dist(states)
    old_logp = dist_old.log_prob(actions).sum(dim=-1)

info = updater.step(states, actions, advantages, old_logp)
print("TRPO step info:", info)