# GRPO Tutorial (using the external library TRL, without LoRA)

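Goal of this notebook: run a minimal GRPO training loop with the `GRPOTrainer` from HuggingFace **TRL** (rewards come from a rule-based function), and map TRL's key parameters one-to-one onto the GRPO formulas.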

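- For the hand-written version (closer to the formulas, more transparent), see `GRPO_tutorial_no_lora.ipynb`
- This version follows engineering practice: sampling, ratio, clipping, KL, and similar details are delegated to TRL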


## 1. The GRPO formulas (mapped to TRL parameters)

For a prompt $x$, sample a group of $G$ responses from the old policy:

$$
y_i \sim \pi_{\theta_{old}}(\cdot\mid x),\quad i=1,\dots,G
$$

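Each response receives a sequence-level reward $r_i=r(x,y_i)$. Standardizing the rewards within the group yields the relative advantage (no critic needed):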

$$
\mu=\frac{1}{G}\sum_{i=1}^{G}r_i,\quad \sigma=\sqrt{\frac{1}{G}\sum_{i=1}^{G}(r_i-\mu)^2}+\varepsilon,\quad A_i=\frac{r_i-\mu}{\sigma}
$$
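
As a small illustration of the standardization above, here is a minimal sketch in plain Python. The function name `group_advantages` and the default `eps` are assumptions made for this example; a library implementation may differ in details, for instance by using the sample standard deviation (dividing by $G-1$) instead of the population form written above.

```python
import math

def group_advantages(rewards, eps=1e-6):
    """Standardize rewards within one group: A_i = (r_i - mu) / sigma."""
    g = len(rewards)
    mu = sum(rewards) / g
    # Population standard deviation (divide by G) plus eps, as in the formula above
    sigma = math.sqrt(sum((r - mu) ** 2 for r in rewards) / g) + eps
    return [(r - mu) / sigma for r in rewards]

# Two correct and two wrong answers for the same prompt
advantages = group_advantages([1.0, -1.0, 1.0, -1.0])
print(advantages)

# If every answer gets the same reward, all advantages are zero (no learning signal)
print(group_advantages([0.5, 0.5]))
```

The second call shows why a group needs some spread in its rewards: identical rewards give zero advantage for every response.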

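Next comes the token-level probability ratio (the core of PPO), where the state $s_{i,t}=(x,y_{i,<t})$ is the prompt plus the tokens generated so far and the action $a_{i,t}=y_{i,t}$ is the next token: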

$$
\rho_{i,t}(\theta)=\frac{\pi_{\theta}(a_{i,t}\mid s_{i,t})}{\pi_{\theta_{old}}(a_{i,t}\mid s_{i,t})}=\exp(\log\pi_{\theta}-\log\pi_{\theta_{old}})
$$

The clipped surrogate:

$$
L^{clip}_{i,t}(\theta)=\min\Big(\rho_{i,t}(\theta)A_i,\;\mathrm{clip}(\rho_{i,t}(\theta),1-\epsilon,1+\epsilon)A_i\Big)
$$
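
To make the effect of clipping concrete, the sketch below evaluates the clipped surrogate for a single token in plain Python. The function name `clipped_surrogate` and the toy probabilities are assumptions chosen for illustration, not TRL code.

```python
import math

def clipped_surrogate(logp_new, logp_old, advantage, eps=0.2):
    """Per-token term: min(rho * A, clip(rho, 1 - eps, 1 + eps) * A)."""
    rho = math.exp(logp_new - logp_old)                  # probability ratio
    rho_clipped = max(1.0 - eps, min(rho, 1.0 + eps))    # clip(rho, 1 - eps, 1 + eps)
    return min(rho * advantage, rho_clipped * advantage)

# The new policy raised the token probability from 0.4 to 0.6 (ratio about 1.5)
logp_new, logp_old = math.log(0.6), math.log(0.4)

# Positive advantage: the gain is capped at (1 + eps) * A
gain = clipped_surrogate(logp_new, logp_old, advantage=1.0)
# Negative advantage: the min keeps the unclipped, more pessimistic value rho * A
loss = clipped_surrogate(logp_new, logp_old, advantage=-1.0)
print(gain, loss)
```

The asymmetry is the point of the `min`: the objective stops rewarding moves beyond the clip range, but it never hides a move that makes things worse.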

Then add a KL penalty toward the reference policy $\pi_{ref}$, estimated per token by the log-ratio on the sampled tokens:

$$
\widehat{KL}_{i,t}=\log\pi_{\theta}(a_{i,t}\mid s_{i,t})-\log\pi_{ref}(a_{i,t}\mid s_{i,t})
$$
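
The log-ratio estimate above can be negative for an individual token. The GRPO paper, and to my understanding TRL's implementation as well, uses a non-negative variant, $\frac{\pi_{ref}}{\pi_\theta}-\log\frac{\pi_{ref}}{\pi_\theta}-1$. The sketch below compares the two on one token; the function names are assumptions made for this example.

```python
import math

def kl_log_ratio(logp, logp_ref):
    """Per-token estimate from the formula above: log pi_theta - log pi_ref."""
    return logp - logp_ref

def kl_k3(logp, logp_ref):
    """Non-negative estimate: pi_ref/pi_theta - log(pi_ref/pi_theta) - 1."""
    d = logp_ref - logp
    return math.exp(d) - d - 1.0

# A sampled token that the current policy finds less likely than the reference does
logp, logp_ref = math.log(0.2), math.log(0.4)
print(kl_log_ratio(logp, logp_ref))  # negative for this token
print(kl_k3(logp, logp_ref))         # positive for this token
```

Both estimates are zero when the two policies agree on the token; they differ in variance and in whether single-token values can go below zero.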

The final objective (to be maximized):

$$
J(\theta)=\mathbb{E}[L^{clip}(\theta)]-\beta\,\mathbb{E}[\widehat{KL}]
$$

Training minimizes $\mathcal{L}(\theta)=-J(\theta)$.

### TRL parameter mapping

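- `num_generations` $\leftrightarrow G$ (number of responses sampled per prompt)
- `epsilon` in recent TRL releases (the field name may differ across versions, e.g. `clip_range`) $\leftrightarrow \epsilon$
- `beta` (or an equivalent field) $\leftrightarrow \beta$ (KL coefficient)
- `max_completion_length` (or `max_new_tokens`) caps the length of each response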
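
One practical consequence of `num_generations`: in recent TRL versions, as far as I know, the batch size counts completions rather than prompts, and the effective batch must be divisible by $G$. The helper below is a hypothetical sketch of that bookkeeping (the name `prompts_per_step` and its arguments are assumptions, and the exact rule depends on the TRL version).

```python
def prompts_per_step(per_device_batch_size, num_generations,
                     num_processes=1, steps_per_generation=1):
    """Number of distinct prompts covered by one generation step (hypothetical helper)."""
    if num_generations < 2:
        # With a single completion the group-relative advantage is always zero
        raise ValueError('num_generations must be at least 2')
    total_completions = per_device_batch_size * num_processes * steps_per_generation
    if total_completions % num_generations != 0:
        raise ValueError(
            f'{total_completions} completions cannot be split into groups of {num_generations}'
        )
    return total_completions // num_generations

print(prompts_per_step(per_device_batch_size=2, num_generations=2))
print(prompts_per_step(per_device_batch_size=8, num_generations=4))
```

Under this assumption, a per-device batch size of 1 cannot hold a group of 2, so the batch size should be a multiple of `num_generations`.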

Below we build a dataset of prompts and write a `reward_func` that plays the role of $r(x,y)$; everything else is left to TRL.

## 2. Environment setup

This notebook assumes the following conda environment:

```bash
conda activate llm
```

If packages are missing, install them as needed:

```bash
pip install -U "transformers" "accelerate" "datasets" "trl"
```

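Note: this repository is offline-first (models are read from `MODELSCOPE_CACHE` when available). If you have no network access, prepare wheels or a package mirror in advance.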

In [None]:
import os
import random
import sys
from pathlib import Path
from importlib.metadata import PackageNotFoundError, version

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

try:
    from datasets import Dataset
except Exception as e:
    raise RuntimeError('datasets is missing: run `pip install -U datasets` first') from e

try:
    from trl import GRPOConfig, GRPOTrainer
except Exception as e:
    raise RuntimeError('trl is missing or this version has no GRPO support: run `pip install -U trl` first') from e


def pkg_ver(name: str) -> str:
    try:
        return version(name)
    except PackageNotFoundError:
        return 'N/A'


os.environ.setdefault('MODELSCOPE_CACHE', r'D:/myProject/modelscope_hub')

seed = 42
random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

print('python:', sys.executable)
print('torch:', torch.__version__, 'cuda:', torch.cuda.is_available())
print('transformers:', pkg_ver('transformers'))
print('accelerate:', pkg_ver('accelerate'))
print('datasets:', pkg_ver('datasets'))
print('trl:', pkg_ver('trl'))
print('MODELSCOPE_CACHE:', os.environ['MODELSCOPE_CACHE'])

device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

In [None]:
# Choose the model (prefer the local cache directory)
local_dir = Path(os.environ['MODELSCOPE_CACHE']) / 'models' / 'qwen' / 'Qwen2-0___5B-Instruct'
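# Fall back to the Hugging Face Hub ID (the organization is spelled `Qwen` there) if there is no local copy
model_name_or_path = str(local_dir) if local_dir.exists() else 'Qwen/Qwen2-0.5B-Instruct'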
print('model_name_or_path:', model_name_or_path)

if torch.cuda.is_available():
    dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
    dtype = torch.float32
print('dtype:', dtype)

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = '<|endoftext|>'
tokenizer.padding_side = 'right'

SYSTEM_PROMPT = 'You are a helpful assistant.'

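# Trainable policy model π_θ (no LoRA; all parameters are updated directly)
# Caveat: float16 weights combined with fp16=True in the trainer can make the AMP grad scaler
# fail ("Attempting to unscale FP16 gradients"); float32 weights avoid that but need more memory
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=dtype)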
model.config.use_cache = False
model.config.pad_token_id = tokenizer.pad_token_id
if hasattr(model, 'gradient_checkpointing_enable') and torch.cuda.is_available():
    model.gradient_checkpointing_enable()

# Reference model π_ref (optional: turn KL off when GPU memory is tight)
use_ref_model = False  # False is recommended on 8 GB of VRAM; set True only if memory allows
ref_model = None
if use_ref_model:
    ref_model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float32)
    ref_model.eval()
    for p in ref_model.parameters():
        p.requires_grad_(False)

print('ready')

In [None]:
import json
import random
from typing import Any, Dict, List

def format_prompt(user_prompt: str) -> str:
    messages = [
        {'role': 'system', 'content': SYSTEM_PROMPT},
        {'role': 'user', 'content': user_prompt},
    ]
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

def make_train_tasks(n_math: int = 64, n_prog: int = 64, seed: int = 123) -> List[Dict[str, Any]]:
    rng = random.Random(seed)
    tasks: List[Dict[str, Any]] = []

    # Math tasks: add / sub / mul, plus exact integer division
    for _ in range(n_math):
        op = rng.choice(['add', 'sub', 'mul', 'div'])
        if op == 'div':
            b = rng.randint(2, 19)
            q = rng.randint(-20, 20)
            a = b * q
        else:
            a = rng.randint(-99, 99)
            b = rng.randint(-99, 99)
        tasks.append({'type': 'math', 'op': op, 'a': a, 'b': b})

    # Programming tasks: gcd / sort / palindrome / reverse / py_eval
    letters = 'abcdefghijklmnopqrstuvwxyz'
    for _ in range(n_prog):
        t = rng.choice(['gcd', 'sort', 'palindrome', 'reverse', 'py_eval'])
        if t == 'gcd':
            a = rng.randint(1, 500)
            b = rng.randint(1, 500)
            tasks.append({'type': 'gcd', 'a': a, 'b': b})
        elif t == 'sort':
            arr = [rng.randint(-50, 50) for _ in range(rng.randint(5, 9))]
            tasks.append({'type': 'sort', 'arr': arr})
        elif t == 'palindrome':
            half = ''.join(rng.choice(letters) for _ in range(rng.randint(2, 4)))
            if rng.random() < 0.5:
                s = half + half[::-1]
            else:
                s = half + ''.join(rng.choice(letters) for _ in range(rng.randint(1, 3)))
            tasks.append({'type': 'palindrome', 's': s})
        elif t == 'reverse':
            s = ''.join(rng.choice(letters) for _ in range(rng.randint(4, 10)))
            tasks.append({'type': 'reverse', 's': s})
        else:  # py_eval
            s = ''.join(rng.choice(letters) for _ in range(rng.randint(3, 8)))
            exprs = [
                f'len({s!r})',
                f'{s!r}[::-1]',
                f'{s!r}.upper()',
                f'abs({rng.randint(-50, 50)})',
                f'sum([{rng.randint(-9, 9)}, {rng.randint(-9, 9)}, {rng.randint(-9, 9)}])',
                f'max({rng.randint(-20, 20)}, {rng.randint(-20, 20)})',
                f'min({rng.randint(-20, 20)}, {rng.randint(-20, 20)})',
            ]
            tasks.append({'type': 'py_eval', 'expr': rng.choice(exprs)})

    rng.shuffle(tasks)
    return tasks

def render_user_prompt(task: Dict[str, Any]) -> str:
    task_json = json.dumps(task, ensure_ascii=False, separators=(',', ':'))
    t = task['type']
    NL = chr(10)

    if t == 'math':
        a = task['a']
        b = task['b']
        sym = {'add': '+', 'sub': '-', 'mul': '*', 'div': '//'}.get(task['op'], task['op'])
        return f'[Math] Compute: {a} {sym} {b}' + NL + 'Requirement: output only the final result (an integer), no explanation.' + NL + f'TASK={task_json}'

    if t == 'gcd':
        a = task['a']
        b = task['b']
        return f'[Programming] Given a={a}, b={b}, output gcd(a,b).' + NL + 'Requirement: output a single integer only, no explanation.' + NL + f'TASK={task_json}'

    if t == 'sort':
        arr = task['arr']
        return f'[Programming] Sort the array in ascending order: {arr}' + NL + 'Requirement: output only the sorted numbers, separated by spaces, no explanation.' + NL + f'TASK={task_json}'

    if t == 'palindrome':
        s = task['s']
        return f'[Programming] Decide whether the string is a palindrome: {s}' + NL + 'Requirement: output only YES or NO, no explanation.' + NL + f'TASK={task_json}'

    if t == 'reverse':
        s = task['s']
        return f'[Programming] Reverse the string: {s}' + NL + 'Requirement: output only the reversed string, no explanation.' + NL + f'TASK={task_json}'

    if t == 'py_eval':
        expr = task['expr']
        return '[Programming] Evaluate the following Python expression:' + NL + f'{expr}' + NL + 'Requirement: output only the value of the expression, no explanation.' + NL + f'TASK={task_json}'

    return f'TASK={task_json}'

train_tasks = make_train_tasks(n_math=64, n_prog=64, seed=seed)
user_prompts = [render_user_prompt(t) for t in train_tasks]
prompts = [format_prompt(up) for up in user_prompts]

# Training dataset: TRL expects a `prompt` column (here the prompts already have the chat template applied)
train_dataset = Dataset.from_dict({'prompt': prompts})

train_dataset

In [None]:
import json
import math
import re
from typing import Any, Dict

_TASK_RE = re.compile(r'TASK=(\{.*?\})', re.DOTALL)

_SAFE_EVAL_NAMES = {
    'len': len,
    'sum': sum,
    'abs': abs,
    'min': min,
    'max': max,
}

def _extract_task(prompt: str) -> Dict[str, Any]:
    m = _TASK_RE.search(prompt or '')
    if m is None:
        return {'type': 'unknown'}
    try:
        return json.loads(m.group(1))
    except json.JSONDecodeError:
        return {'type': 'unknown'}

def _expected_from_task(task: Dict[str, Any]) -> Any:
    t = task.get('type')
    if t == 'math':
        a = int(task['a'])
        b = int(task['b'])
        op = task['op']
        if op == 'add':
            return a + b
        if op == 'sub':
            return a - b
        if op == 'mul':
            return a * b
        if op == 'div':
            return a // b
        return None
    if t == 'gcd':
        return math.gcd(int(task['a']), int(task['b']))
    if t == 'sort':
        return sorted(int(x) for x in task['arr'])
    if t == 'palindrome':
        s = str(task['s'])
        return 'YES' if s == s[::-1] else 'NO'
    if t == 'reverse':
        return str(task['s'])[::-1]
    if t == 'py_eval':
        expr = str(task['expr'])
        # Note: eval is acceptable here only because expr is generated by our own code (fine for a teaching demo)
        return eval(expr, {'__builtins__': {}}, _SAFE_EVAL_NAMES)
    return None

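def _strip_completion(text: str) -> str:
    """Return the first meaningful line of a completion, unwrapping a code fence if present."""
    s = (text or '').strip()

    if s.startswith('```') and '```' in s[3:]:
        inner = s.split('```', 2)[1]
        lines = inner.splitlines()
        # The first line is either a language tag (```python) or empty (untagged fence): drop it
        if len(lines) >= 2 and (not lines[0].strip() or lines[0].strip().isalpha()):
            lines = lines[1:]
        lines = [ln for ln in lines if ln.strip()]
        s = lines[0].strip() if lines else ''

    lines = s.splitlines()
    return lines[0].strip() if lines else ''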

def rule_reward(prompt: str, completion: str) -> float:
    task = _extract_task(prompt)
    t = task.get('type')
    exp = _expected_from_task(task)
    ans = _strip_completion(completion)

    correct = False

    if t in ('math', 'gcd'):
        m = re.search(r'-?\d+', ans)
        if exp is not None and m is not None:
            correct = int(m.group(0)) == int(exp)
    elif t == 'sort':
        nums = [int(x) for x in re.findall(r'-?\d+', ans)]
        correct = isinstance(exp, list) and nums == exp
    elif t == 'palindrome':
        m = re.search(r'(YES|NO)', ans, re.IGNORECASE)
        correct = exp is not None and m is not None and m.group(1).upper() == exp
    elif t in ('reverse', 'py_eval'):
        if isinstance(exp, str):
            if len(ans) >= 2 and ans[0] == ans[-1] and ans[0] in (chr(34), chr(39)):
                ans = ans[1:-1]
            correct = ans == exp
        else:
            m = re.search(r'-?\d+', ans)
            if m is not None and exp is not None:
                try:
                    correct = int(m.group(0)) == int(exp)
                except Exception:
                    correct = ans == str(exp)
            else:
                correct = ans == str(exp)

    base = 1.0 if correct else -1.0
    length_penalty = 0.002 * min(len(completion or ''), 200)
    return base - length_penalty

def reward_func(prompts, completions, **kwargs):
    return [rule_reward(p, c) for p, c in zip(prompts, completions)]

# Quick sanity check: give the first prompt its correct answer
demo_prompt = prompts[0]
demo_task = _extract_task(demo_prompt)
demo_exp = _expected_from_task(demo_task)
demo_completion = ' '.join(map(str, demo_exp)) if isinstance(demo_exp, list) else str(demo_exp)
reward_func([demo_prompt], [demo_completion])[:1]
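
The reward shaping at the end of `rule_reward` is easy to check by hand. The standalone sketch below repeats only that arithmetic (the helper name `shaped_reward` is an assumption for this example) and prints the reachable reward range for correct and wrong answers.

```python
def shaped_reward(correct: bool, completion_chars: int) -> float:
    """Same arithmetic as the end of rule_reward: +/-1 minus a capped length penalty."""
    base = 1.0 if correct else -1.0
    penalty = 0.002 * min(completion_chars, 200)  # at most 0.4
    return base - penalty

for correct in (True, False):
    worst = shaped_reward(correct, 200)  # longest penalized completion
    best = shaped_reward(correct, 0)     # empty completion, no penalty
    print('correct' if correct else 'wrong', (worst, best))

# Smallest possible gap between any correct and any wrong answer
min_gap = shaped_reward(True, 10_000) - shaped_reward(False, 0)
print(min_gap)
```

Because the penalty is capped, a correct answer always scores higher than a wrong one; the length term only ranks answers that share the same correctness.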

In [None]:
import inspect

# ===== Training hyperparameters (start small on 8 GB) =====
group_size_G = 2  # num_generations; start with 2 on 8 GB
max_completion_len = 48
max_prompt_len = 512
max_steps = 30
lr = 2e-6
clip_eps = 0.2
beta_kl = 0.02 if use_ref_model else 0.0


def make_grpo_config() -> GRPOConfig:
    sig = inspect.signature(GRPOConfig)
    params = sig.parameters
    cfg = {}

    # Basic training arguments (inherited from TrainingArguments)
    if 'output_dir' in params:
        cfg['output_dir'] = 'grpo_trl_out'
    if 'max_steps' in params:
        cfg['max_steps'] = max_steps
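    if 'per_device_train_batch_size' in params:
        # Recent TRL versions count completions here and require the effective batch
        # to be divisible by num_generations, so use one full group per step
        cfg['per_device_train_batch_size'] = group_size_G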
    if 'gradient_accumulation_steps' in params:
        cfg['gradient_accumulation_steps'] = 1
    if 'learning_rate' in params:
        cfg['learning_rate'] = lr
    if 'logging_steps' in params:
        cfg['logging_steps'] = 1
    if 'save_strategy' in params:
        cfg['save_strategy'] = 'no'
    if 'report_to' in params:
        cfg['report_to'] = []
    if 'remove_unused_columns' in params:
        # Keep extra dataset columns so that TRL versions which pass dataset fields through to reward_func can see them
        cfg['remove_unused_columns'] = False
    if 'gradient_checkpointing' in params:
        cfg['gradient_checkpointing'] = bool(torch.cuda.is_available())

    # Mixed precision
    if torch.cuda.is_available():
        if 'bf16' in params:
            cfg['bf16'] = bool(torch.cuda.is_bf16_supported())
        if 'fp16' in params:
            cfg['fp16'] = bool(not torch.cuda.is_bf16_supported())

    # ===== GRPO-specific parameters (field names may differ slightly between versions) =====
    if 'num_generations' in params:
        cfg['num_generations'] = group_size_G
    elif 'num_samples' in params:
        cfg['num_samples'] = group_size_G

    if 'max_completion_length' in params:
        cfg['max_completion_length'] = max_completion_len
    elif 'max_new_tokens' in params:
        cfg['max_new_tokens'] = max_completion_len

    if 'max_prompt_length' in params:
        cfg['max_prompt_length'] = max_prompt_len

    if 'epsilon' in params:  # name used by recent TRL releases
        cfg['epsilon'] = clip_eps
    elif 'clip_range' in params:
        cfg['clip_range'] = clip_eps
    elif 'clip_eps' in params:
        cfg['clip_eps'] = clip_eps

    if 'beta' in params:
        cfg['beta'] = beta_kl
    elif 'kl_coef' in params:
        cfg['kl_coef'] = beta_kl

    return GRPOConfig(**cfg)


args = make_grpo_config()
args
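
The pattern used in `make_grpo_config` (only pass the field names that the installed version accepts) can be factored into a small helper. The sketch below is one way to do it with the standard library; `pick_supported` and `DemoConfig` are hypothetical names used only for this illustration, with `DemoConfig` standing in for a config class whose fields vary between versions.

```python
import inspect

def pick_supported(target, candidates):
    """Keep the first accepted name from each alias group (hypothetical helper).

    `candidates` maps a tuple of alternative parameter names to a single value.
    Returns the kwargs to pass and the alias groups that matched nothing.
    """
    params = inspect.signature(target).parameters
    chosen, dropped = {}, []
    for aliases, value in candidates.items():
        for name in aliases:
            if name in params:
                chosen[name] = value
                break
        else:
            dropped.append(aliases)  # no alias is accepted; the library default applies
    return chosen, dropped

# Stand-in for a config class whose field names differ between versions
class DemoConfig:
    def __init__(self, num_generations=8, beta=0.0):
        self.num_generations = num_generations
        self.beta = beta

kwargs, dropped = pick_supported(DemoConfig, {
    ('num_generations', 'num_samples'): 2,
    ('beta', 'kl_coef'): 0.02,
    ('clip_range', 'clip_eps'): 0.2,  # neither name exists on DemoConfig
})
config = DemoConfig(**kwargs)
print(kwargs)
print('not applied:', dropped)
```

Reporting the dropped groups matters: a value that is silently skipped means training runs with the library default, which is easy to miss.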

In [None]:
def make_trainer() -> GRPOTrainer:
    sig = inspect.signature(GRPOTrainer.__init__).parameters
    kwargs = {}

    # Required
    kwargs['model'] = model
    if 'args' in sig:
        kwargs['args'] = args
    elif 'config' in sig:
        kwargs['config'] = args

    if 'train_dataset' in sig:
        kwargs['train_dataset'] = train_dataset

    # tokenizer / processing_class (differs between newer and older TRL versions)
    if 'processing_class' in sig:
        kwargs['processing_class'] = tokenizer
    elif 'tokenizer' in sig:
        kwargs['tokenizer'] = tokenizer

    # Dataset field name
    if 'dataset_text_field' in sig:
        kwargs['dataset_text_field'] = 'prompt'

    # reward
    if 'reward_funcs' in sig:
        kwargs['reward_funcs'] = [reward_func]
    elif 'reward_function' in sig:
        kwargs['reward_function'] = reward_func
    else:
        raise RuntimeError('The installed TRL GRPOTrainer has no reward argument (reward_funcs/reward_function)')

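    # Reference model (optional); trainers without a `ref_model` argument manage
    # the reference policy themselves, in which case this object is simply not passed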
    if ref_model is not None and 'ref_model' in sig:
        kwargs['ref_model'] = ref_model

    return GRPOTrainer(**kwargs)


trainer = make_trainer()
trainer

In [None]:
# Start training
# Note: if use_ref_model=True runs out of GPU memory, disable ref_model and set beta_kl to 0.
trainer.train()

In [None]:
@torch.inference_mode()
def chat(user_prompt: str, max_new_tokens: int = 64) -> str:
    messages = [
        {'role': 'system', 'content': SYSTEM_PROMPT},
        {'role': 'user', 'content': user_prompt},
    ]
    prompt_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    enc = tokenizer(prompt_text, return_tensors='pt')
    input_ids = enc['input_ids']
    attn = enc.get('attention_mask', None)

    m = trainer.model
    m.eval()  # training leaves the model in train mode; switch to eval for generation
    dev = next(m.parameters()).device
    input_ids = input_ids.to(dev)
    if attn is not None:
        attn = attn.to(dev)

    out = m.generate(
        input_ids=input_ids,
        attention_mask=attn,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=0.8,
        top_p=0.9,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )
    resp_ids = out[0, input_ids.shape[-1] :]
    return tokenizer.decode(resp_ids, skip_special_tokens=True)


print(chat('Output only the number 4. 2+2=?'))
print(chat('Output only: OK (uppercase)'))