# GRPO 教学（手写公式 + 最小实现，无 LoRA）

本 Notebook 目标：像之前的 DPO/PPO 教学一样，用 PyTorch + Transformers 手写 GRPO（Group Relative Policy Optimization）的关键计算：

- 写清楚公式
- 代码变量名/计算过程能对应公式
- 不使用 TRL 这类现成 RLHF Trainer

注意：无 LoRA 表示全参更新，对显存要求更高；如果你只有 8GB，务必把 `group_size/max_new_tokens/grpo_epochs` 调小。


## 1. GRPO 的核心思想与公式

GRPO 可以理解为：**PPO 的一个“去 critic”变体**。

- PPO 通常需要 critic（$V(s)$）来估计优势函数 $A_t$（如 GAE）。
- GRPO 用“同一 prompt 下的一组采样结果”做相对基线：在一个 group 内把奖励做中心化/标准化，得到序列级 advantage，从而不需要训练价值网络。

### 1.1 采样与奖励

对一个 prompt $x$，用旧策略（rollout 时的策略）$\pi_{\theta_{old}}$ 采样 $G$ 个回答：

$$
y_i \sim \pi_{\theta_{old}}(\cdot\mid x),\quad i=1,\dots,G
$$

每个回答 $y_i$ 有 token 序列 $a_{i,1:T_i}$，并得到一个**序列级**奖励（来自 RM 或规则）：

$$
r_i = r(x, y_i)
$$

### 1.2 Group-relative Advantage（去掉 critic 的关键）

在同一个 group 内做标准化（也可以只做中心化）：

$$
\mu = \frac{1}{G}\sum_{i=1}^{G} r_i,\quad \sigma = \sqrt{\frac{1}{G}\sum_{i=1}^{G}(r_i-\mu)^2}+\varepsilon
$$

$$
A_i = \frac{r_i-\mu}{\sigma}
$$

这里 $A_i$ 是**序列级** advantage（对同一个回答的所有 token 都一样）。


在纯粹的 GRPO 算法设计中，没有任何针对“中间推导过程是否正确”的奖励机制。

GRPO 是一种极度纯粹的 **“结果导向（Outcome-supervised）”** 算法。

### 1.3 PPO 的 ratio 与 clipped surrogate（token 级）

每个 token 的对数概率：

$$
\log\pi_{\theta}(a_{i,t}\mid s_{i,t}),\quad s_{i,t}=(x, a_{i,<t})
$$

概率比值（PPO 核心）：

$$
\rho_{i,t}(\theta)=\frac{\pi_{\theta}(a_{i,t}\mid s_{i,t})}{\pi_{\theta_{old}}(a_{i,t}\mid s_{i,t})}
=\exp\big(\log\pi_{\theta}-\log\pi_{\theta_{old}}\big)
$$

clipped surrogate（对每个 token）：

$$
L^{clip}_{i,t}(\theta)=\min\Big(\rho_{i,t}(\theta)A_i,\;\mathrm{clip}(\rho_{i,t}(\theta),1-\epsilon,1+\epsilon)A_i\Big)
$$

对 token 平均、对 group 平均：

$$
L^{clip}(\theta)=\frac{1}{G}\sum_{i=1}^{G}\frac{1}{T_i}\sum_{t=1}^{T_i} L^{clip}_{i,t}(\theta)
$$

### 1.4 KL 约束（对参考策略 $\pi_{ref}$）

常见 RLHF 会加 KL 约束让策略不要偏离参考模型（冻结）：

$$
\widehat{KL}_{i,t}=\log\pi_{\theta}(a_{i,t}\mid s_{i,t})-\log\pi_{ref}(a_{i,t}\mid s_{i,t})
$$

把 KL 作为惩罚项（系数 $\beta$）：

$$
J(\theta)=L^{clip}(\theta) - \beta\cdot \frac{1}{G}\sum_{i=1}^{G}\frac{1}{T_i}\sum_{t=1}^{T_i}\widehat{KL}_{i,t}
$$

训练时最小化 loss：

$$
\mathcal{L}(\theta) = -J(\theta)
$$

下面代码会用同名变量：`rewards/advantages/logpi_old/logpi/logpi_ref/ratio/clip_eps/kl_coef` 对应上面公式逐项实现。


## 2. 环境与模型加载（建议离线）

- 建议 `conda activate llm`
- 默认优先用 `MODELSCOPE_CACHE` 下的本地模型目录（避免联网）
- 参考模型默认放 CPU（更省显存，但更慢；显存够可放 GPU）


In [1]:
import os
import random
import sys
from pathlib import Path
from typing import Any, Dict, List, Tuple

import torch
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
from transformers import AutoModelForCausalLM, AutoTokenizer

os.environ.setdefault("MODELSCOPE_CACHE", r"D:/myProject/modelscope_hub")
print("python:", sys.executable)
print("torch:", torch.__version__, "cuda:", torch.cuda.is_available())
print("MODELSCOPE_CACHE:", os.environ["MODELSCOPE_CACHE"])

seed = 42
random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

device = "cuda" if torch.cuda.is_available() else "cpu"
device


python: e:\Softwares\anaconda3\envs\llm\python.exe
torch: 2.10.0+cu126 cuda: True
MODELSCOPE_CACHE: D:/myProject/modelscope_hub


'cuda'

In [2]:
# 选择模型（优先本地缓存目录）
local_dir = Path(os.environ["MODELSCOPE_CACHE"]) / "models" / "qwen" / "Qwen2-0___5B-Instruct"
model_name_or_path = str(local_dir) if local_dir.exists() else "qwen/Qwen2-0.5B-Instruct"
print("model_name_or_path:", model_name_or_path)

if device == "cuda":
    dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
    dtype = torch.float32
print("dtype:", dtype)

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = "<|endoftext|>"
tokenizer.padding_side = "right"

SYSTEM_PROMPT = "You are a helpful assistant."

# 可训练策略模型 π_θ
policy = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=dtype).to(device)
policy.config.use_cache = False
policy.config.pad_token_id = tokenizer.pad_token_id

if hasattr(policy, "gradient_checkpointing_enable") and device == "cuda":
    policy.gradient_checkpointing_enable()  # 省显存

# 参考策略 π_ref（冻结）
ref_device = "cpu"  # 显存够可以改成 "cuda"
ref_model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float32).to(ref_device)
ref_model.eval()
for p in ref_model.parameters():
    p.requires_grad_(False)

print("ready")


model_name_or_path: D:\myProject\modelscope_hub\models\qwen\Qwen2-0___5B-Instruct
dtype: torch.bfloat16


`torch_dtype` is deprecated! Use `dtype` instead!


ready


## 3. Toy 任务与规则奖励 $r(x,y)$

真实 GRPO/RLHF 会用 reward model（RM）。教学版这里用规则奖励：检查输出是否包含期望字符串。

奖励（可自行改）：

- 命中期望：+1
- 不命中：-1
- 长度轻微惩罚（鼓励更短）


In [3]:
train_tasks = [
    {"prompt": "czq是谁", "expected": "czq是神！"},
    {"prompt": "请只输出数字 4，不要额外文字。2+2等于几？", "expected": "4"},
    {"prompt": "把“我喜欢机器学习”翻译成英文，只输出翻译。", "expected": "I like machine learning"},
    {"prompt": "请只回答：通义千问", "expected": "通义千问"},
]

def normalize_text(s: str) -> str:
    s = s.strip().lower()
    for ch in [" ", "\n", "\t", "。", "，", ",", ".", "!", "?", "：", ":", "\"", "'"]:
        s = s.replace(ch, "")
    return s

def rule_reward(prompt: str, response: str, expected: str) -> float:
    resp = normalize_text(response)
    exp = normalize_text(expected)
    hit = exp in resp
    base = 1.0 if hit else -1.0
    length_penalty = 0.002 * len(resp)
    return base - length_penalty

train_tasks[0]


{'prompt': 'czq是谁', 'expected': 'czq是神！'}

## 4. 关键实现：采样 group、计算 logprob/ratio、GRPO loss

实现要点：

- `logπ(a|s)` 用 `log_softmax(logits)` + `gather` 手写（贴公式）
- `logpi_old` 在 rollout 时固定
- `logpi_ref` 用冻结参考模型计算
- advantage 用 group 标准化得到 `A_i`，再广播到每个 token


In [4]:
@torch.inference_mode()
def build_prompt_input_ids(tok: Any, user_prompt: str) -> torch.Tensor:
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt},
    ]
    text = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    enc = tok(text, return_tensors="pt")
    return enc["input_ids"][0]

@torch.inference_mode()
def sample_group(
    m: torch.nn.Module,
    tok: Any,
    prompt_ids: torch.Tensor,
    group_size: int,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
) -> List[Dict[str, Any]]:
    m.eval()
    prompt_len = int(prompt_ids.numel())
    input_ids = prompt_ids.unsqueeze(0).to(device)

    old_use_cache = getattr(m.config, "use_cache", True)
    m.config.use_cache = True
    out = m.generate(
        input_ids=input_ids,
        num_return_sequences=group_size,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        max_new_tokens=max_new_tokens,
        eos_token_id=tok.eos_token_id,
        pad_token_id=tok.pad_token_id,
    )
    m.config.use_cache = old_use_cache

    seqs = out.detach().cpu()
    samples: List[Dict[str, Any]] = []
    for seq in seqs:
        resp = seq[prompt_len:]
        # trim at eos/pad
        end = int(resp.numel())
        for j, tid in enumerate(resp.tolist()):
            if tid == tok.eos_token_id or tid == tok.pad_token_id:
                end = j
                break
        resp_ids = resp[:end]
        if resp_ids.numel() == 0:
            continue
        full_ids = torch.cat([prompt_ids, resp_ids], dim=0)
        text = tok.decode(resp_ids, skip_special_tokens=True)
        samples.append(
            {
                "full_ids": full_ids,
                "response_ids": resp_ids,
                "response_text": text,
                "response_len": int(resp_ids.numel()),
            }
        )
    return samples

def pad_group(prompt_ids: torch.Tensor, samples: List[Dict[str, Any]], pad_id: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int]:
    prompt_len = int(prompt_ids.numel())
    max_resp_len = max(s["response_len"] for s in samples)
    B = len(samples)

    input_ids = torch.full((B, prompt_len + max_resp_len), pad_id, dtype=torch.long)
    attention_mask = torch.zeros_like(input_ids, dtype=torch.long)
    response_mask = torch.zeros((B, max_resp_len), dtype=torch.bool)

    for i, s in enumerate(samples):
        resp = s["response_ids"]
        rlen = int(resp.numel())
        input_ids[i, :prompt_len] = prompt_ids
        input_ids[i, prompt_len : prompt_len + rlen] = resp
        attention_mask[i, : prompt_len + rlen] = 1
        response_mask[i, :rlen] = True

    return input_ids, attention_mask, response_mask, prompt_len, max_resp_len

def action_logp_entropy_from_logits(
    logits: torch.Tensor,
    input_ids: torch.Tensor,
    prompt_len: int,
    max_resp_len: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """返回 response token 的 (logπ, entropy)。

    logits: (B, L, V)
    input_ids: (B, L)
    """
    log_probs = F.log_softmax(logits, dim=-1)
    # 位置 p-1 预测 input_ids[p]
    log_probs_slice = log_probs[:, prompt_len - 1 : prompt_len - 1 + max_resp_len, :]
    action_ids = input_ids[:, prompt_len : prompt_len + max_resp_len]

    token_logp = log_probs_slice.gather(dim=-1, index=action_ids.unsqueeze(-1)).squeeze(-1)
    entropy = -(log_probs_slice.exp() * log_probs_slice).sum(dim=-1)
    return token_logp, entropy

def grpo_loss(
    logpi: torch.Tensor,
    logpi_old: torch.Tensor,
    logpi_ref: torch.Tensor,
    advantages: torch.Tensor,
    response_mask: torch.Tensor,
    clip_eps: float,
    kl_coef: float,
    ent_coef: float,
    entropy: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """对应：
    ratio = exp(logπ - logπ_old)
    L_clip = mean(min(ratio*A, clip(ratio)*A))
    loss = -L_clip + kl_coef*mean(logπ - logπ_ref) - ent_coef*mean(entropy)
    """
    mask = response_mask.to(device=logpi.device, dtype=logpi.dtype)
    denom = mask.sum().clamp_min(1.0)

    A = advantages.to(dtype=logpi.dtype).unsqueeze(1)
    ratio = torch.exp(logpi - logpi_old)
    surr1 = ratio * A
    surr2 = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * A
    clipped_obj = torch.min(surr1, surr2)
    policy_loss = -(clipped_obj * mask).sum() / denom

    kl = logpi - logpi_ref
    kl_mean = (kl * mask).sum() / denom
    kl_loss = kl_coef * kl_mean

    entropy_mean = (entropy * mask).sum() / denom
    loss = policy_loss + kl_loss - ent_coef * entropy_mean

    clip_frac = (((ratio - 1.0).abs() > clip_eps).to(logpi.dtype) * mask).sum() / denom

    metrics = {
        "policy_loss": policy_loss.detach(),
        "kl_mean": kl_mean.detach(),
        "entropy": entropy_mean.detach(),
        "clip_frac": clip_frac.detach(),
        "loss": loss.detach(),
    }
    return loss, metrics


## 5. GRPO 训练循环（教学版）

每一步训练：

1) 选一个 prompt，采样 `group_size` 个回答
2) 用规则得到每个回答的 reward：`rewards[i] = r(x,y_i)`
3) 在 group 内标准化得到 `advantages[i]`
4) 计算 `logpi_old`（固定）与 `logpi_ref`（固定）
5) 做 `grpo_epochs` 次更新：用当前模型算 `logpi`，计算 `ratio`、clipped objective、KL penalty，反向传播


In [5]:
# 超参（教学默认）
train_steps = 30
group_size = 4
max_new_tokens = 48
temperature = 1.0
top_p = 0.9

clip_eps = 0.2
kl_coef = 0.05
ent_coef = 0.0
grpo_epochs = 2

lr = 2e-6  # 全参更新，建议很小
max_grad_norm = 1.0

optimizer = torch.optim.AdamW(policy.parameters(), lr=lr)

use_amp = device == "cuda"
use_bf16 = use_amp and torch.cuda.is_bf16_supported()
autocast_dtype = torch.bfloat16 if use_bf16 else torch.float16
autocast_device_type = "cuda" if use_amp else "cpu"
scaler = torch.cuda.amp.GradScaler(enabled=use_amp and not use_bf16)

@torch.inference_mode()
def compute_logpi_ref(
    ref: torch.nn.Module,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    prompt_len: int,
    max_resp_len: int,
) -> torch.Tensor:
    ids = input_ids.to(ref_device)
    mask = attention_mask.to(ref_device)
    logits = ref(input_ids=ids, attention_mask=mask, use_cache=False, return_dict=True).logits
    logp, _ = action_logp_entropy_from_logits(logits, ids, prompt_len, max_resp_len)
    return logp.to(device)

print("start training")
for step in range(train_steps):
    task = random.choice(train_tasks)
    prompt = task["prompt"]
    expected = task["expected"]

    prompt_ids = build_prompt_input_ids(tokenizer, prompt)
    samples = sample_group(policy, tokenizer, prompt_ids, group_size, max_new_tokens, temperature, top_p)
    if len(samples) < 2:
        print(f"step={step} skip(group too small)")
        continue

    rewards = [rule_reward(prompt, s["response_text"], expected) for s in samples]
    rewards_t = torch.tensor(rewards, device=device, dtype=torch.float32)
    advantages = (rewards_t - rewards_t.mean()) / (rewards_t.std(unbiased=False).clamp_min(1e-6))

    input_ids, attention_mask, response_mask, prompt_len, max_resp_len = pad_group(
        prompt_ids, samples, pad_id=tokenizer.pad_token_id
    )
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)

    # logπ_old（rollout 时固定）
    with torch.inference_mode():
        policy.eval()
        logits_old = policy(input_ids=input_ids, attention_mask=attention_mask, use_cache=False, return_dict=True).logits
        logpi_old, _ = action_logp_entropy_from_logits(logits_old, input_ids, prompt_len, max_resp_len)
        logpi_old = logpi_old.detach()

    # logπ_ref（冻结参考模型）
    logpi_ref = compute_logpi_ref(ref_model, input_ids, attention_mask, prompt_len, max_resp_len).detach()

    # 多轮 GRPO 更新
    policy.train()
    for ep in range(grpo_epochs):
        with torch.autocast(device_type=autocast_device_type, dtype=autocast_dtype, enabled=use_amp):
            logits = policy(input_ids=input_ids, attention_mask=attention_mask, use_cache=False, return_dict=True).logits
            logpi, entropy = action_logp_entropy_from_logits(logits, input_ids, prompt_len, max_resp_len)
            loss, m = grpo_loss(
                logpi=logpi,
                logpi_old=logpi_old,
                logpi_ref=logpi_ref,
                advantages=advantages,
                response_mask=response_mask,
                clip_eps=clip_eps,
                kl_coef=kl_coef,
                ent_coef=ent_coef,
                entropy=entropy,
            )

        optimizer.zero_grad(set_to_none=True)
        if scaler.is_enabled():
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            clip_grad_norm_(policy.parameters(), max_grad_norm)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            clip_grad_norm_(policy.parameters(), max_grad_norm)
            optimizer.step()

    # 日志
    print(
        f"step={step} reward_mean={rewards_t.mean().item():.3f} reward_std={rewards_t.std(unbiased=False).item():.3f} "
        f"loss={m['loss'].item():.4f} policy={m['policy_loss'].item():.4f} kl={m['kl_mean'].item():.4f} "
        f"clip_frac={m['clip_frac'].item():.3f}"
    )
    # 打印一个样本
    best_i = int(torch.argmax(rewards_t).item())
    print("prompt:", prompt)
    print("best_reward:", float(rewards_t[best_i].item()), "response:", samples[best_i]["response_text"])
    print("-" * 60)

print("done")


start training


  scaler = torch.cuda.amp.GradScaler(enabled=use_amp and not use_bf16)


step=0 reward_mean=-1.104 reward_std=0.029 loss=0.2606 policy=0.2618 kl=-0.0246 clip_frac=0.033
prompt: czq是谁
best_reward: -1.0679999589920044 response: 对不起，我需要更多的信息才能回答这个问题。你可以告诉我你需要什么帮助吗？
------------------------------------------------------------
step=1 reward_mean=-1.113 reward_std=0.034 loss=0.2399 policy=0.2400 kl=-0.0014 clip_frac=0.007
prompt: czq是谁
best_reward: -1.0720000267028809 response: 抱歉，我不太明白你的意思。能否提供更多的背景信息或细节？我将尽我所能来帮助你。
------------------------------------------------------------
step=2 reward_mean=0.456 reward_std=0.872 loss=-0.0072 policy=-0.0071 kl=-0.0020 clip_frac=0.000
prompt: 把“我喜欢机器学习”翻译成英文，只输出翻译。
best_reward: 0.9599999785423279 response: "I like machine learning"
------------------------------------------------------------
step=3 reward_mean=0.992 reward_std=0.004 loss=0.4025 policy=0.4030 kl=-0.0118 clip_frac=0.000
prompt: 请只输出数字 4，不要额外文字。2+2等于几？
best_reward: 0.9980000257492065 response: 4
------------------------------------------------------------
step=4

## 6. 简单验证

训练后随机看几个 prompt 的输出。


In [7]:
@torch.inference_mode()
def chat(prompt: str, max_new_tokens: int = 64) -> str:
    prompt_ids = build_prompt_input_ids(tokenizer, prompt)
    input_ids = prompt_ids.unsqueeze(0).to(device)
    old_use_cache = getattr(policy.config, "use_cache", True)
    policy.config.use_cache = True
    out = policy.generate(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )
    policy.config.use_cache = old_use_cache

    full = out[0].detach().cpu()
    resp = full[prompt_ids.numel():]
    end = int(resp.numel())
    for j, tid in enumerate(resp.tolist()):
        if tid == tokenizer.eos_token_id or tid == tokenizer.pad_token_id:
            end = j
            break
    resp_ids = resp[:end]
    return tokenizer.decode(resp_ids, skip_special_tokens=True)

for t in train_tasks:
    print("Q:", t["prompt"])
    print("A:", chat(t["prompt"]))
    print()


Q: czq是谁
A: 抱歉，从现在很多人都曾几天朗格法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法

Q: 请只输出数字 4，不要额外文字。2+2等于几？
A: 2000000000000000000000000000000000000000000000000000000000000000

Q: 把“我喜欢机器学习”翻译成英文，只输出翻译。
A: "I have the (::synchronized String
法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法

Q: 请只回答：通义千问
A: 通法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法法

