# PPO的代码实现
代码来自于，https://github.com/nikhilbarhate99/PPO-PyTorch

总体流程为：

    环境交互:
        ↓
    PPO.select_action(state)
        ↓
    环境返回 reward, done
        ↓
    数据存入 RolloutBuffer
        ↓
    每 N 步后:
        ↓
    PPO.update()
        ├─ 计算 G_t 和优势 A_t
        ├─ 计算 ratio = π/π_old
        ├─ 优化 min(surr1, surr2)
        ├─ 更新 Critic
        └─ 同步 policy_old ← policy


In [None]:
import torch
import torch.nn as nn
from torch.distributions import MultivariateNormal
from torch.distributions import Categorical

################################## set device ##################################
print("============================================================================================")
# set device to cpu or cuda
device = torch.device('cpu')
if(torch.cuda.is_available()): 
    device = torch.device('cuda:0') 
    torch.cuda.empty_cache()
    print("Device set to : " + str(torch.cuda.get_device_name(device)))
else:
    print("Device set to : cpu")
print("============================================================================================")


################################## PPO Policy ##################################
class RolloutBuffer:
    def __init__(self):
        self.actions = []
        self.states = []
        self.logprobs = []
        self.rewards = []
        self.state_values = []
        self.is_terminals = []
    
    def clear(self):
        del self.actions[:]
        del self.states[:]
        del self.logprobs[:]
        del self.rewards[:]
        del self.state_values[:]
        del self.is_terminals[:]


class ActorCritic(nn.Module):
    def __init__(self, state_dim, action_dim, has_continuous_action_space, action_std_init):
        super(ActorCritic, self).__init__()

        self.has_continuous_action_space = has_continuous_action_space
        
        if has_continuous_action_space:
            self.action_dim = action_dim
            self.action_var = torch.full((action_dim,), action_std_init * action_std_init).to(device)
        # actor
        if has_continuous_action_space :
            self.actor = nn.Sequential(
                            nn.Linear(state_dim, 64),
                            nn.Tanh(),
                            nn.Linear(64, 64),
                            nn.Tanh(),
                            nn.Linear(64, action_dim),
                            nn.Tanh()
                        )
        else:
            self.actor = nn.Sequential(
                            nn.Linear(state_dim, 64),
                            nn.Tanh(),
                            nn.Linear(64, 64),
                            nn.Tanh(),
                            nn.Linear(64, action_dim),
                            nn.Softmax(dim=-1)
                        )
        # critic
        self.critic = nn.Sequential(
                        nn.Linear(state_dim, 64),
                        nn.Tanh(),
                        nn.Linear(64, 64),
                        nn.Tanh(),
                        nn.Linear(64, 1)
                    )
        
    def set_action_std(self, new_action_std):
        if self.has_continuous_action_space:
            self.action_var = torch.full((self.action_dim,), new_action_std * new_action_std).to(device)
        else:
            print("--------------------------------------------------------------------------------------------")
            print("WARNING : Calling ActorCritic::set_action_std() on discrete action space policy")
            print("--------------------------------------------------------------------------------------------")

    def forward(self):
        raise NotImplementedError
    
    def act(self, state):

        if self.has_continuous_action_space:
            action_mean = self.actor(state)
            cov_mat = torch.diag(self.action_var).unsqueeze(dim=0)
            dist = MultivariateNormal(action_mean, cov_mat)
        else:
            action_probs = self.actor(state)
            dist = Categorical(action_probs)

        action = dist.sample()
        action_logprob = dist.log_prob(action)
        state_val = self.critic(state)

        return action.detach(), action_logprob.detach(), state_val.detach()
    
    def evaluate(self, state, action):

        if self.has_continuous_action_space:
            action_mean = self.actor(state)
            
            action_var = self.action_var.expand_as(action_mean)
            cov_mat = torch.diag_embed(action_var).to(device)
            dist = MultivariateNormal(action_mean, cov_mat)
            
            # For Single Action Environments.
            if self.action_dim == 1:
                action = action.reshape(-1, self.action_dim)
        else:
            action_probs = self.actor(state)
            dist = Categorical(action_probs)
        action_logprobs = dist.log_prob(action)
        dist_entropy = dist.entropy()
        state_values = self.critic(state)
        
        return action_logprobs, state_values, dist_entropy


class PPO:
    def __init__(self, state_dim, action_dim, lr_actor, lr_critic, gamma, K_epochs, eps_clip, has_continuous_action_space, action_std_init=0.6):

        self.has_continuous_action_space = has_continuous_action_space

        if has_continuous_action_space:
            self.action_std = action_std_init

        self.gamma = gamma
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs
        
        self.buffer = RolloutBuffer()

        self.policy = ActorCritic(state_dim, action_dim, has_continuous_action_space, action_std_init).to(device)
        self.optimizer = torch.optim.Adam([
                        {'params': self.policy.actor.parameters(), 'lr': lr_actor},
                        {'params': self.policy.critic.parameters(), 'lr': lr_critic}
                    ])

        self.policy_old = ActorCritic(state_dim, action_dim, has_continuous_action_space, action_std_init).to(device)
        self.policy_old.load_state_dict(self.policy.state_dict())
        
        self.MseLoss = nn.MSELoss()

    def set_action_std(self, new_action_std):
        if self.has_continuous_action_space:
            self.action_std = new_action_std
            self.policy.set_action_std(new_action_std)
            self.policy_old.set_action_std(new_action_std)
        else:
            print("--------------------------------------------------------------------------------------------")
            print("WARNING : Calling PPO::set_action_std() on discrete action space policy")
            print("--------------------------------------------------------------------------------------------")

    def decay_action_std(self, action_std_decay_rate, min_action_std):
        print("--------------------------------------------------------------------------------------------")
        if self.has_continuous_action_space:
            self.action_std = self.action_std - action_std_decay_rate
            self.action_std = round(self.action_std, 4)
            if (self.action_std <= min_action_std):
                self.action_std = min_action_std
                print("setting actor output action_std to min_action_std : ", self.action_std)
            else:
                print("setting actor output action_std to : ", self.action_std)
            self.set_action_std(self.action_std)

        else:
            print("WARNING : Calling PPO::decay_action_std() on discrete action space policy")
        print("--------------------------------------------------------------------------------------------")

    def select_action(self, state):

        if self.has_continuous_action_space:
            with torch.no_grad():
                state = torch.FloatTensor(state).to(device)
                action, action_logprob, state_val = self.policy_old.act(state)

            self.buffer.states.append(state)
            self.buffer.actions.append(action)
            self.buffer.logprobs.append(action_logprob)
            self.buffer.state_values.append(state_val)

            return action.detach().cpu().numpy().flatten()
        else:
            with torch.no_grad():
                state = torch.FloatTensor(state).to(device)
                action, action_logprob, state_val = self.policy_old.act(state)
            
            self.buffer.states.append(state)
            self.buffer.actions.append(action)
            self.buffer.logprobs.append(action_logprob)
            self.buffer.state_values.append(state_val)

            return action.item()

    def update(self):
        # Monte Carlo estimate of returns
        rewards = []
        discounted_reward = 0
        for reward, is_terminal in zip(reversed(self.buffer.rewards), reversed(self.buffer.is_terminals)):
            if is_terminal:
                discounted_reward = 0
            discounted_reward = reward + (self.gamma * discounted_reward)
            rewards.insert(0, discounted_reward)
            
        # Normalizing the rewards
        rewards = torch.tensor(rewards, dtype=torch.float32).to(device)
        rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-7)

        # convert list to tensor
        old_states = torch.squeeze(torch.stack(self.buffer.states, dim=0)).detach().to(device)
        old_actions = torch.squeeze(torch.stack(self.buffer.actions, dim=0)).detach().to(device)
        old_logprobs = torch.squeeze(torch.stack(self.buffer.logprobs, dim=0)).detach().to(device)
        old_state_values = torch.squeeze(torch.stack(self.buffer.state_values, dim=0)).detach().to(device)

        # calculate advantages
        advantages = rewards.detach() - old_state_values.detach()

        # Optimize policy for K epochs
        for _ in range(self.K_epochs):

            # Evaluating old actions and values
            logprobs, state_values, dist_entropy = self.policy.evaluate(old_states, old_actions)

            # match state_values tensor dimensions with rewards tensor
            state_values = torch.squeeze(state_values)
            
            # Finding the ratio (pi_theta / pi_theta__old)
            ratios = torch.exp(logprobs - old_logprobs.detach())

            # Finding Surrogate Loss  
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1-self.eps_clip, 1+self.eps_clip) * advantages

            # final loss of clipped objective PPO
            loss = -torch.min(surr1, surr2) + 0.5 * self.MseLoss(state_values, rewards) - 0.01 * dist_entropy
            
            # take gradient step
            self.optimizer.zero_grad()
            loss.mean().backward()
            self.optimizer.step()
            
        # Copy new weights into old policy
        self.policy_old.load_state_dict(self.policy.state_dict())

        # clear buffer
        self.buffer.clear()
    
    def save(self, checkpoint_path):
        torch.save(self.policy_old.state_dict(), checkpoint_path)
   
    def load(self, checkpoint_path):
        self.policy_old.load_state_dict(torch.load(checkpoint_path, map_location=lambda storage, loc: storage))
        self.policy.load_state_dict(torch.load(checkpoint_path, map_location=lambda storage, loc: storage))
        
        
       



## PPO 任务示例

我们让 GPT-2 学会生成“正向情绪”的句子。
系统提供一些简单的 **prompt（提示语）**，例如：

```text
"Write a short product review:"
"Describe your day:"
"Give quick feedback on this service:"
```

模型生成一段续写文本。
然后我们根据生成文本中的**情感词汇**给出奖励：

| 奖励规则                                                        | 示例   |
| ----------------------------------------------------------- | ---- |
| 若文本中包含 `"good"`, `"great"`, `"excellent"`, `"wonderful"` 等词 | +1.0 |
| 否则                                                          | −0.5 |

因此模型会逐渐学会偏向使用这些积极词汇。



### ⚙️ 模型结构（Architecture）

整体使用一个 **Actor–Critic** 框架：

```
        +-----------------------------+
        |        GPT-2 backbone       |
        +-----------------------------+
              ↙                   ↘
        Policy (actor)         Value (critic)
         → 输出 logits          → 线性层预测 V(s)
```

* **Actor**：GPT-2 的语言建模头，输出每个 token 的概率分布。
* **Critic**：在最后一层 hidden 上加线性层，预测状态价值 (V(s_t))。

---

### 🧩 PPO 算法要点

1. **Ratio (重要性采样比率)**
   $$r_t = \frac{\pi_\theta(a_t|s_t)}{\pi_{\text{old}}(a_t|s_t)}$$
   

2. **剪切目标 (Clipped Surrogate Objective)**
   
   $$L^{clip} = \min\big(r_t A_t,, \text{clip}(r_t,1-\epsilon,1+\epsilon) A_t\big)$$
   

3. **Critic Loss (值函数误差)**
   
   $$L^{V} = (V_\phi(s_t)-R_t)^2$$

4. **Entropy Bonus (熵正则)**
   鼓励策略保持多样性。

5. **最终目标（要最大化）**
   
   $$L = \mathbb{E}[L^{clip} - c_1 L^{V} + c_2 \mathcal{H}]$$
   实现时取负号作为 loss 进行最小化。

---

### 🧮 优势与回报计算

使用 GAE(λ)（广义优势估计）：


$$\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$$

$$A_t = \sum_l (\gamma \lambda)^l \delta_{t+l}$$


$$R_t = A_t + V(s_t)$$

在这个 demo 中：

* 每个生成序列只有**一个终局奖励**；
* 奖励分配到最后一个 token，并通过 GAE 反向传播；
* 非生成部分（prompt）mask 掉，不参与更新。

---

### 🧰 训练流程

1. **Rollout 收集数据**

   * 用当前策略 GPT-2 生成文本。
   * 记录：input_ids, attention_mask, logits, value, logp_old。
   * 计算奖励（正向词检测）。

2. **计算优势与回报**

   * 用 critic 的值估计结合 reward 算 GAE。
   * 生成 `advantages` 和 `returns`。

3. **PPO 更新**

   * 多次小批量（minibatch）训练。
   * 计算新策略 logp，与旧策略比 ratio。
   * 进行 **剪切策略目标** + **value loss** + **entropy** 优化。
   * 使用梯度裁剪防止发散。

4. **监控指标**

   * loss 组件（policy / value / entropy）
   * KL 散度（防止策略漂移）
   * 平均奖励变化

---

### 数据整体流动逻辑总结
prompt
  ↓
生成 token_1, token_2, ... token_T
  ↓
critic 给出每步 V(s_t)
  ↓
reward 仅在最后 token 给出 (+1 或 -0.5)
  ↓
GAE 根据 V 和 r_t 把“未来好处”往前传
  ↓
得到每步优势 A_t
  ↓
policy_loss 用 A_t 更新 actor（GPT2 logits）
value_loss 用 (R_t - V(s_t))^2 更新 critic


### 📊 代码结构概览

```
├── CFG                # 参数与任务配置
├── ActorCritic        # GPT2 + Value Head 模型
├── reward_fn()        # 简易奖励函数
├── compute_logprobs() # token级log概率
├── compute_returns_advantages() # GAE
├── iterate_minibatches()        # 小批量迭代器
│
├── rollout & PPO loop           # 主训练循环
│
└── sample_text()                # 前后生成对比
```

---
t2_minimal_demo.py
```



### 🧭 结果示例（示意）

```
=== Sampling BEFORE training ===
Prompt: Write a short product review:
Write a short product review: It was a terrible day, the food was cold...

=== Sampling AFTER training ===
Prompt: Write a short product review:
Write a short product review: The service was great and the food was excellent!
```





### 参数与任务配置

In [None]:
# ppo_gpt2_minimal_demo.py
import math
import random
from dataclasses import dataclass
from typing import List, Dict, Tuple

import torch
import torch.nn as nn
from torch.nn import functional as F
from transformers import GPT2LMHeadModel, GPT2TokenizerFast as GPT2Tokenizer

# ===================== 配置 =====================
@dataclass
class CFG:
    model_name: str = "gpt2"
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    seed: int = 42

    # rollout
    prompts: List[str] = (
        "Write a short product review:",
        "Describe your day:",
        "Tell me something about your favorite movie:",
        "How was your meal?",
        "Give quick feedback on this service:",
    )
    max_new_tokens: int = 24
    eos_token_id: int = None   # 自动从tokenizer里拿
    temperature: float = 1.0
    top_k: int = 50

    # PPO
    gamma: float = 0.99
    gae_lambda: float = 0.95
    clip_eps: float = 0.2
    value_clip_eps: float = 0.2
    c1_value: float = 0.5
    c2_entropy: float = 0.01
    lr: float = 1e-5
    max_grad_norm: float = 0.5

    # 训练
    rollout_batch_size: int = 6     # 每轮收集多少条序列
    ppo_epochs: int = 3             # 对同一批数据做几次小批量更新
    minibatch_size: int = 3
    train_steps: int = 5            # 总共做多少个 PPO outer steps（演示：很小）
    print_every: int = 1

    # 简单奖励词表（命中加分）
    pos_words: Tuple[str,...] = ("good", "great", "excellent", "wonderful")

cfg = CFG()

# ===================== 随机种子 =====================
random.seed(cfg.seed)
torch.manual_seed(cfg.seed)
torch.cuda.manual_seed_all(cfg.seed)

# ===================== 模型 & tokenizer =====================
tokenizer = GPT2Tokenizer.from_pretrained(cfg.model_name)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token
cfg.eos_token_id = tokenizer.eos_token_id

### 奖励函数实现

In [None]:
# 简易奖励函数：生成文本包含正向词则高奖励，否则低奖励
def reward_fn(texts: List[str], pos_words: Tuple[str,...]) -> torch.Tensor:
    rewards = []
    for t in texts:
        t_low = t.lower()
        matched = any(w in t_low for w in pos_words)
        rewards.append(1.0 if matched else -0.5)
    return torch.tensor(rewards, dtype=torch.float32)

# 把序列末端奖励分配到动作token上，并做 GAE
def compute_returns_advantages(
    rewards: torch.Tensor,              # [B] 每条序列一个终局奖励
    values: torch.Tensor,               # [B, T] 每个位置的V(s_t)
    dones: torch.Tensor,                # [B] 终止标志（这里全1）
    action_mask: torch.Tensor,          # [B, T] 只对生成部分为1，prompt部分为0
    gamma: float,
    lam: float
) -> Tuple[torch.Tensor, torch.Tensor]:
    # 我们让每条序列的最后一个 action token 接收 reward，其它 action token 也能通过 GAE 回传
    B, T = values.shape
    device = values.device

    # 构建 per-step reward：默认为0，只在每条序列最后一个有效 action 位置放入终局奖励
    step_rewards = torch.zeros_like(values)
    last_idx = (action_mask.cumsum(dim=1) == action_mask.sum(dim=1, keepdim=True)).long().argmax(dim=1)  # 每条序列最后一个action索引
    for b in range(B):
        step_rewards[b, last_idx[b]] = rewards[b]

    # 计算 deltas 与 GAE，仅在 action_mask==1 的位置有效
    deltas = step_rewards + gamma * F.pad(values[:, 1:], (0,1)) * (1.0 - 0.0) - values
    deltas = deltas * action_mask

    advantages = torch.zeros_like(values)
    last_gae = torch.zeros(B, device=device)
    for t in reversed(range(T)):
        mask_t = action_mask[:, t]
        delta_t = deltas[:, t]
        last_gae = delta_t + gamma * lam * last_gae
        advantages[:, t] = last_gae * mask_t  # 非action位置为0

        # 在非action位置维持 last_gae 不被无意义地传播
        last_gae = last_gae * (mask_t > 0).float() + last_gae * (mask_t == 0).float()

    returns = advantages + values
    return returns.detach(), advantages.detach()

In [None]:
class ActorCritic(nn.Module):
    def __init__(self, base_model: GPT2LMHeadModel):
        super().__init__()
        self.actor = base_model
        hidden = base_model.config.n_embd
        self.value_head = nn.Linear(hidden, 1)

    def forward(self, input_ids, attention_mask=None):
        # 为了拿到hidden states，打开输出
        out = self.actor(input_ids=input_ids,
                         attention_mask=attention_mask,
                         output_hidden_states=True,
                         return_dict=True)
        logits = out.logits                                  # [B, T, V]
        last_hidden = out.hidden_states[-1]                  # [B, T, H]
        values = self.value_head(last_hidden).squeeze(-1)    # [B, T]
        return logits, values

    def generate(self, input_ids, attention_mask, max_new_tokens, temperature=1.0, top_k=50, eos_token_id=None):
        # 手写 sampling（也可用 model.generate，这里为了更清晰拿logprob）
        self.eval()
        B = input_ids.size(0)
        cur_input_ids = input_ids
        cur_attn = attention_mask
        with torch.no_grad():
            for _ in range(max_new_tokens):
                logits, _ = self.forward(cur_input_ids, cur_attn)
                next_logits = logits[:, -1, :] / max(1e-6, temperature)
                if top_k is not None and top_k > 0:
                    topv, topi = torch.topk(next_logits, top_k)
                    mask = next_logits < topv[:, -1].unsqueeze(-1)
                    next_logits = next_logits.masked_fill(mask, -float("inf"))
                probs = F.softmax(next_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)  # [B,1]
                cur_input_ids = torch.cat([cur_input_ids, next_token], dim=1)
                cur_attn = torch.ones_like(cur_input_ids, device=cur_input_ids.device)
                if eos_token_id is not None:
                    # 如果所有样本都产生了eos，提前终止
                    if (next_token.squeeze(-1) == eos_token_id).all():
                        break
        self.train()
        return cur_input_ids, cur_attn

# logprob 工具
def compute_logprobs(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # logits: [B,T,V], labels: [B,T]，返回每个位置 token 的 log p(token)
    logp = F.log_softmax(logits, dim=-1)
    lp = logp.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)  # [B,T]
    return lp


# mini-batch 迭代器
def iterate_minibatches(batch: Dict[str, torch.Tensor], mb_size: int):
    N = batch["input_ids"].size(0)
    idx = torch.randperm(N)
    for start in range(0, N, mb_size):
        mb_idx = idx[start : start + mb_size]
        yield {k: v[mb_idx] for k, v in batch.items()}

# ===================== 构建模型与优化器 =====================
base = GPT2LMHeadModel.from_pretrained(cfg.model_name)
model = ActorCritic(base).to(cfg.device)
optim = torch.optim.AdamW(model.parameters(), lr=cfg.lr)



# ===================== PPO 训练循环 =====================
for step in range(1, cfg.train_steps + 1):
    # -------- Rollout 收集 on-policy 轨迹 --------
    batch_prompts = [random.choice(cfg.prompts) for _ in range(cfg.rollout_batch_size)]
    enc = tokenizer(batch_prompts, return_tensors="pt", padding=True).to(cfg.device)

    with torch.no_grad():
        # 旧策略下采样序列
        full_ids, full_attn = model.generate(
            enc.input_ids, enc.attention_mask,
            max_new_tokens=cfg.max_new_tokens,
            temperature=cfg.temperature,
            top_k=cfg.top_k,
            eos_token_id=cfg.eos_token_id,
        )

        # 计算旧策略的 logits/values 和 logp_old
        logits_old, values_old = model(full_ids, attention_mask=full_attn)
        # labels = 下一个token（自回归），用 -100 屏蔽最后一个位置
        labels = full_ids[:, 1:].contiguous()
        logits_old_trim = logits_old[:, :-1, :].contiguous()
        values_old_trim = values_old[:, :-1].contiguous()
        attn_trim = full_attn[:, :-1].contiguous()
        logp_old = compute_logprobs(logits_old_trim, labels)

    # 构造 action_mask：只对生成出的新token部分为1（去掉prompt token部分）
    prompt_len = (enc.attention_mask.sum(dim=1)).to(full_ids.device)  # [B]
    T = labels.size(1)
    action_mask = torch.zeros_like(values_old_trim, dtype=torch.float32)
    for b in range(cfg.rollout_batch_size):
        # 可训练的 token 范围：从 prompt_len[b]-1 起到末尾（因为 labels 对应的是预测第 t+1 个 token）
        start = int(prompt_len[b].item()) - 1
        if start < 0: start = 0
        action_mask[b, start:] = 1.0

    # 文本解码用于奖励
    decoded = tokenizer.batch_decode(full_ids, skip_special_tokens=True)
    rewards = reward_fn(decoded, cfg.pos_words).to(cfg.device)          # [B]
    dones = torch.ones_like(rewards)

    # 计算 returns & advantages（只在 action_mask==1 的位置）
    returns, advantages = compute_returns_advantages(
        rewards=rewards,
        values=values_old_trim,
        dones=dones,
        action_mask=action_mask,
        gamma=cfg.gamma,
        lam=cfg.gae_lambda,
    )

    # --------- 把 rollout 打包为训练 batch ----------
    train_batch = {
        "input_ids": full_ids[:, :-1].detach(),          # 和 logits_old_trim 对齐
        "attention_mask": full_attn[:, :-1].detach(),
        "labels": labels.detach(),
        "logp_old": logp_old.detach(),
        "values_old": values_old_trim.detach(),
        "returns": returns.detach(),
        "advantages": advantages.detach(),
        "action_mask": action_mask.detach(),
    }

    # --------- PPO 多轮小批量更新 ----------
    for epoch in range(cfg.ppo_epochs):
        for mb in iterate_minibatches(train_batch, cfg.minibatch_size):
            logits, values = model(mb["input_ids"].to(cfg.device),
                                   attention_mask=mb["attention_mask"].to(cfg.device))
            # 对齐
            logits = logits
            values = values

            # 重新计算新策略的 logp
            logp_new = compute_logprobs(logits, mb["labels"].to(cfg.device))

            # ratio & 策略损失（clip）
            ratio = torch.exp(logp_new - mb["logp_old"].to(cfg.device))                   # [B,T]
            adv = mb["advantages"].to(cfg.device)
            # 标准化优势（仅对 action_mask 生效）
            mask = mb["action_mask"].to(cfg.device)
            adv_masked = adv[mask > 0]
            adv_norm = (adv - adv_masked.mean()) / (adv_masked.std() + 1e-8)

            pg1 = ratio * adv_norm
            pg2 = torch.clamp(ratio, 1.0 - cfg.clip_eps, 1.0 + cfg.clip_eps) * adv_norm
            policy_loss = -(torch.min(pg1, pg2) * mask).sum() / (mask.sum() + 1e-8)

            # value 损失（带 value clipping）
            values_old = mb["values_old"].to(cfg.device)
            returns = mb["returns"].to(cfg.device)
            v_clipped = values_old + (values - values_old).clamp(-cfg.value_clip_eps, cfg.value_clip_eps)
            v_loss1 = (values - returns) ** 2
            v_loss2 = (v_clipped - returns) ** 2
            value_loss = 0.5 * torch.max(v_loss1, v_loss2)
            value_loss = (value_loss * mask).sum() / (mask.sum() + 1e-8)

            # 熵奖励（越大越好，这里用负号并入loss）
            with torch.no_grad():
                probs = F.softmax(logits, dim=-1)
            log_probs = F.log_softmax(logits, dim=-1)
            entropy = -(probs * log_probs).sum(dim=-1)          # [B,T]
            entropy_loss = -(entropy * mask).sum() / (mask.sum() + 1e-8)

            loss = policy_loss + cfg.c1_value * value_loss + cfg.c2_entropy * entropy_loss

            optim.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.max_grad_norm)
            optim.step()

    if step % cfg.print_every == 1:
        with torch.no_grad():
            # 监控一下 KL（经验上监控旧→新）
            kl = (mb["logp_old"].to(cfg.device) - logp_new).mean().item()
            avg_ret = rewards.mean().item()
            print(f"[Step {step}] loss={loss.item():.4f}  policy={policy_loss.item():.4f}  "
                  f"value={value_loss.item():.4f}  entropy={-entropy_loss.item():.4f}  "
                  f"KL={kl:.4f}  avg_reward={avg_ret:.3f}")


  from .autonotebook import tqdm as notebook_tqdm


=== Sampling BEFORE training ===

[Before] Prompt: Write a short product review:
Write a short product review:

Product Reviews Write about the products of the following company: CXV, General Dynamics, Intel

This


[Before] Prompt: Describe your day:
Describe your day: What was it like for you? A? B ? I. B a? Or C ? I ? E ? I


=== Sampling AFTER training ===

[Before] Prompt: Write a short product review:
Write a short product review: www.goodreads.com/reviews/lulu-coco-coco-no...


[Before] Prompt: Describe your day:
Describe your day: what does the man know? What is that one thing?

Do you feel like a man? What is life


[Before] Prompt: Tell me something about your favorite movie:
Tell me something about your favorite movie: "I'm like a woman, you know." There's an intense intensity to the character, and if you're young


[Before] Prompt: How was your meal?
How was your meal? How many bottles are there? I've found to be more expensive than you're probably expecting.

My first nigh

### 训练结果观察

In [2]:
# ===================== 训练前：观察一次生成 =====================
def sample_text(prompts: List[str], num=2):
    model.eval()
    with torch.no_grad():
        for p in prompts[:num]:
            enc = tokenizer(p, return_tensors="pt").to(cfg.device)
            out_ids, _ = model.generate(
                enc.input_ids, enc.attention_mask,
                max_new_tokens=cfg.max_new_tokens,
                temperature=cfg.temperature,
                top_k=cfg.top_k,
                eos_token_id=cfg.eos_token_id,
            )
            text = tokenizer.decode(out_ids[0], skip_special_tokens=True)
            print(f"\n[Before] Prompt: {p}\n{text}\n")
    model.train()

print("=== Sampling BEFORE training ===")
sample_text(cfg.prompts, num=2)


# ===================== 训练后：再次采样对比 =====================
print("\n=== Sampling AFTER training ===")
sample_text(cfg.prompts, num=4)

# 单条演示：看看奖励词是否更常出现
demo_prompt = "Write a short product review:"
enc = tokenizer(demo_prompt, return_tensors="pt").to(cfg.device)
out_ids, _ = model.generate(
    enc.input_ids, enc.attention_mask,
    max_new_tokens=cfg.max_new_tokens,
    temperature=cfg.temperature,
    top_k=cfg.top_k,
    eos_token_id=cfg.eos_token_id,
)
print("[Final Demo]")
print(tokenizer.decode(out_ids[0], skip_special_tokens=True))


=== Sampling BEFORE training ===

[Before] Prompt: Write a short product review:
Write a short product review: https://www.groupon.org/medicine/apparel

Donate to Planned Parenthood - http://


[Before] Prompt: Describe your day:
Describe your day:

Monday, September 13

The first thing you should know is that, without an actual work day in the


=== Sampling AFTER training ===

[Before] Prompt: Write a short product review:
Write a short product review:

Review a short product review:

Write a short product review:

Write a short product review:


[Before] Prompt: Describe your day:
Describe your day: In my opinion, this is a recipe for good wine.  It tastes like nothing but sherry.
So far


[Before] Prompt: Tell me something about your favorite movie:
Tell me something about your favorite movie: One where you have children, and you play in an army?

Advertisement

A: This was a great


[Before] Prompt: How was your meal?
How was your meal?

We ate from lunch with friends in Houston. 