# 使用 GRPO 训练模型

## GRPO 训练

### 1. 引入必要的库

In [None]:
from datasets import load_dataset,Dataset
from trl import GRPOConfig, GRPOTrainer

### 2. 加载并设计训练数据

In [None]:
# Local copy of the trl-lib/tldr dataset.
data_path = "E:/AI/DataSet/trl-lib/tldr"

# Tag name the model must wrap its answer in, i.e. <answer>...</answer>.
answer_label = 'answer'

# System prompt instructing the model to put its reply inside the answer tags.
SYSTEM_PROMPT = '你的回答需要在<{0}></{0}>标签内。'.format(answer_label)

# Assistant-turn template. The doubled braces keep a literal "{answer}"
# placeholder so callers can fill it later with .format(answer=...).
XML_COT_FORMAT = f'<{answer_label}>{{answer}}</{answer_label}>'

# Dataset preparation
def get_dataset(split="train") -> Dataset:
    """Load ``split`` of the dataset at ``data_path`` and convert it to chat form.

    Every row's ``prompt`` is replaced by a list of chat messages:
    the system prompt, one few-shot user/assistant exchange demonstrating
    the answer-tag format, and finally the row's original ``prompt`` text
    as the user turn. A new ``answer`` column is filled from the row's
    ``completion``; the ``completion`` column itself is not removed here.
    """

    def to_chat(row):
        # The few-shot exchange is included because the 0.5B model is too
        # weak to follow the tag format from the system prompt alone.
        messages = [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': '数字10203040里面有几个0?'},
            {'role': 'assistant', 'content': XML_COT_FORMAT.format(answer='4个0')},
            {'role': 'user', 'content': row['prompt']},
        ]
        return {'prompt': messages, 'answer': row['completion']}

    raw = load_dataset(data_path, split=split)
    return raw.map(to_chat)

# Training split in chat format; the raw 'completion' column is dropped since
# its text was already copied into 'answer' by get_dataset.
dataset = get_dataset("train").remove_columns("completion")

### 3. 设计奖励函数

In [None]:
import re

def extract_answer(completion):
    """Return the text enclosed by the answer tags in ``completion``, or None.

    The match may span newlines (re.DOTALL). Because the leading ``.*`` is
    greedy, when several tag pairs are present the capture begins at the
    last opening tag that is still followed by a closing tag, and the greedy
    ``(.+)`` extends it to the last closing tag. Returns None when no
    non-empty tagged section exists.
    """
    found = re.search(
        f'^.*<{answer_label}>(.+)</{answer_label}>.*$',
        completion,
        re.DOTALL,
    )
    return found.group(1) if found else None

# Content reward
def reward_content(completions, answer, **kwargs):
    """Score each completion by whether its tagged answer is shorter than the reference.

    Each completion is indexed as ``completion[0]['content']`` (first chat
    message's text). The result is 5 when the text extracted from the answer
    tags has fewer characters than the matching reference string in
    ``answer``; 0 otherwise, including when no tagged answer is found.
    Extra keyword arguments are accepted and ignored.

    NOTE(review): this rewards brevity only and never compares the content
    with the reference; confirm that is the intended signal.
    """

    def score(pos, completion):
        extracted = extract_answer(completion[0]['content'])
        if extracted is not None and len(extracted) < len(answer[pos]):
            return 5
        return 0

    return [score(pos, completion) for pos, completion in enumerate(completions)]

# Loose tag reward
def reward_label(completions, **kwargs):
    """Reward completions that contain an answer-tag pair anywhere in their text.

    Each completion is indexed as ``completion[0]['content']`` (first chat
    message's text). Returns a list with 5 for every completion containing
    ``<answer_label>...</answer_label>`` with non-empty content, allowing any
    surrounding text (including newlines), and 0 otherwise. Extra keyword
    arguments are accepted and ignored.
    """
    # re.DOTALL lets ".*" cross newlines. Without it every multi-line
    # completion was rejected, even though extract_answer (which uses DOTALL)
    # finds an answer in the same text, making this "loose" check stricter
    # than content extraction. The former debug print of every batch of
    # completions has been removed.
    pattern = re.compile(
        f'^.*<{answer_label}>(.+)</{answer_label}>.*$', re.DOTALL
    )
    return [
        5 if pattern.fullmatch(completion[0]['content']) else 0
        for completion in completions
    ]

# Strict tag reward
def reward_label_strict(completions, **kwargs):
    """Reward completions consisting of nothing but one answer-tag wrapper.

    Each completion is indexed as ``completion[0]['content']``. Returns 10
    when that whole text matches ``<answer_label>...</answer_label>`` with
    non-empty content and nothing before or after it, else 0. No re.DOTALL
    flag is used, so any newline in the text (inside or after the tags)
    makes it fail. Extra keyword arguments are accepted and ignored.
    """
    strict = f'^<{answer_label}>(.+)</{answer_label}>$'
    return [
        0 if re.fullmatch(strict, item[0]['content']) is None else 10
        for item in completions
    ]

### 4. 进行训练

In [None]:
# Base model to fine-tune, and where checkpoints/logs are written.
model_path = "E:/AI/Models/Qwen/Qwen2.5-0.5B"
output_dir = "./output/grpo"

training_args = GRPOConfig(
    output_dir=output_dir,
    # Optimisation: learning rate, Adam betas, weight decay, gradient clipping.
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    max_grad_norm=0.1,
    # Schedule: cosine decay after a 10% warmup, single epoch.
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    num_train_epochs=1,
    bf16=True,
    # Batching: 2 prompts per device, 4 accumulation steps,
    # 2 sampled completions per prompt for GRPO's group comparison.
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    num_generations=2,
    # Token budgets.
    # NOTE(review): each prompt carries a system message and a few-shot turn
    # before the post; if long prompts are truncated from the left, 256
    # tokens may cut those instructions off. Confirm typical prompt lengths.
    max_prompt_length=256,
    max_completion_length=300,
    # Checkpointing / logging cadence (in optimizer steps).
    save_steps=100,
    logging_steps=10,
    # To log to TensorBoard, add: report_to="tensorboard",
)

# Rewards, in order: loose tag format, strict tag format, answer length vs. reference.
trainer = GRPOTrainer(
    model=model_path,
    args=training_args,
    train_dataset=dataset,
    reward_funcs=[reward_label, reward_label_strict, reward_content],
)

trainer.train()

### 5. 保存模型

In [None]:
# Directory for the final trained model.
output_dir = "./output/final"

# trainer.save_model writes the model together with the trainer's processing
# class (the tokenizer), so this directory can be reloaded on its own with
# from_pretrained. trainer.model.save_pretrained alone omitted the tokenizer.
trainer.save_model(output_dir)