In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import GRPOConfig, GRPOTrainer
from datasets import load_dataset
import random
from peft import LoraConfig

In [2]:
# Seed Python's built-in `random` module.
# NOTE(review): this call alone does not seed other libraries' random number
# generators - confirm whether they need explicit seeding as well.
random.seed(42)

In [3]:
# Display the installed library versions so the recorded results can be tied
# to a specific environment.
import trl, transformers, torch, datasets

trl.__version__, transformers.__version__, torch.__version__, datasets.__version__

('0.22.2', '4.56.1', '2.3.1+cu118', '4.0.0')

In [4]:
# Reward functions
def reward_length(prompts, completions, completion_ids=None, target_len=20, **kwargs):
    """Reward completions whose word count is close to ``target_len``.

    The reward is ``5 / (1 + |word_count - target_len|)``: it is 5.0 when the
    whitespace-separated word count equals ``target_len`` and decays towards 0
    as the count moves away from it.

    Args:
        prompts: Prompts the completions were generated for (unused).
        completions: One entry per completion. Each entry may be a plain
            string, a message dict, or a list of message dicts (only the first
            message is scored).
        completion_ids: Token ids of the completions (unused).
        target_len: Desired number of words in a completion.
        **kwargs: Extra keyword arguments, accepted and ignored.

    Returns:
        list[float]: One reward per completion, in input order.
    """
    rewards = []
    for sample in completions:
        if isinstance(sample, list):
            # A completion may arrive as a list of message dicts; score the
            # first message and treat an empty list as empty text.
            sample = sample[0] if sample else ""
        if isinstance(sample, dict):
            if "content" in sample:
                # Score only the generated text. Joining every value would
                # count the role string as an extra word.
                text = str(sample["content"])
            else:
                text = " ".join(str(v) for v in sample.values())
        else:
            text = str(sample)

        diff = abs(len(text.split()) - target_len)
        rewards.append(5.0 / (1.0 + diff))
    return rewards

def reward_keyword(prompts, completions, completion_ids=None, keyword=None, **kwargs):
    """Reward completions that contain given keywords (case-insensitive).

    Each keyword found as a substring of the completion text adds its weight
    to that completion's reward.

    Args:
        prompts: Prompts the completions were generated for (unused).
        completions: One entry per completion. Each entry may be a plain
            string, a message dict, or a list of message dicts (only the first
            message is scored).
        completion_ids: Token ids of the completions (unused).
        keyword: Mapping of keyword -> weight. When ``None``, a default set is
            used; note that the default keywords include literal single quotes
            (e.g. ``'content'``), so they only match text containing the quotes.
        **kwargs: Extra keyword arguments, accepted and ignored.

    Returns:
        list[float]: One reward per completion, in input order.
    """
    if keyword is None:
        keyword = {"'content'": 0.8 * 5, "'role'": 0.6 * 5, "'user'": 0.6 * 5}
    # The text is lowercased before matching, so lowercase the keywords too;
    # otherwise a keyword containing an uppercase letter could never match.
    weighted = [(str(kw).lower(), value) for kw, value in keyword.items()]

    rewards = []
    for sample in completions:
        if isinstance(sample, list):
            # A completion may arrive as a list of message dicts; score the
            # first message and treat an empty list as empty text.
            sample = sample[0] if sample else ""
        if isinstance(sample, dict):
            if "content" in sample:
                # Search only the generated text, not the role string.
                text_low = str(sample["content"]).lower()
            else:
                text_low = " ".join(str(v) for v in sample.values()).lower()
        else:
            text_low = str(sample).lower()

        rewards.append(float(sum(value for kw, value in weighted if kw in text_low)))
    return rewards


In [5]:
# Load the prompt-only dataset and keep a small subset for a quick test run.
dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")
# `Dataset.shuffle` is not controlled by `random.seed`; pass the seed explicitly
# so the same 250 rows are selected on every run.
train_data = dataset.shuffle(seed=42).select(range(250))
train_data, train_data["prompt"][:3]

(Dataset({
     features: ['prompt'],
     num_rows: 250
 }),
 [[{'content': 'Stream of consciousness rationale: Measles does not come back once your system is cleared of the virus.\nThe question and answer pair are described below.',
    'role': 'user'}],
  [{'content': 'Name a kind of bat that can sleep while it is flying.',
    'role': 'user'}],
  [{'content': 'Compose a table with two rows and three columns.',
    'role': 'user'}]])

In [6]:
# Load the pretrained causal language model used as the policy, then display
# its module structure.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 896)
    (layers): ModuleList(
      (0-23): 24 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear(in_features=896, out_features=896, bias=True)
          (k_proj): Linear(in_features=896, out_features=128, bias=True)
          (v_proj): Linear(in_features=896, out_features=128, bias=True)
          (o_proj): Linear(in_features=896, out_features=896, bias=False)
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=896, out_features=4864, bias=False)
          (up_proj): Linear(in_features=896, out_features=4864, bias=False)
          (down_proj): Linear(in_features=4864, out_features=896, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((896,), eps=1e-06)
    (rotary_emb): Qwen2RotaryEmbe

In [7]:
# Weight tying: make the output head share the input embedding matrix
# (per the original author, this mirrors the official model's setup).
model.lm_head.weight = model.get_input_embeddings().weight

In [8]:
# LoRA adapter configuration.
# NOTE(review): the original list named "v_proj" twice; the duplicate is removed
# here. If "k_proj" was the intended third entry, add it explicitly ("o_proj" is
# another candidate the author had left commented out).
peft_config = LoraConfig(
    r=4,
    lora_alpha=8,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

# Training configuration for the GRPO trainer.
# NOTE(review): `importance_sampling_level="sequence"` together with the "gspo"
# names used for the plots and the save path later in the notebook suggests this
# run is meant as the GSPO variant - confirm.
train_args = GRPOConfig(output_dir="../model/Qwen2.5-0.5B-GRPO",
                        logging_dir="../model/Qwen2.5-0.5B-GRPO/logs",
                        
                        # TensorBoard logging
                        logging_steps=2,
                        report_to="tensorboard", 

                        # Generation hyperparameters
                        max_prompt_length=256,
                        max_completion_length=128,
                        num_generations=2,              # sample 2 completions per prompt

                        # Training hyperparameters
                        learning_rate=5e-5,             # RLHF commonly uses a fairly small learning rate
                        num_train_epochs=3,             
                        per_device_train_batch_size=4,
                        gradient_accumulation_steps=2,  # effective batch = per_device_train_batch_size * gradient_accumulation_steps
                        disable_dropout=True,           # turn dropout off to keep the policy stable (small sample)
                        # warmup_ratio=0.1,               # use warmup for the first 10% of steps
                        save_strategy="epoch",          # save once per epoch
                        save_total_limit=1,             # keep at most 1 checkpoint

                        # GRPO-specific parameters
                        importance_sampling_level="sequence",
                        loss_type="grpo",
                        beta=0.04,
                        epsilon=3e-4,
                        )

# Build the trainer from the loaded model, the two reward functions and the
# training configuration. The LoRA config is deliberately not passed (the
# argument stays commented out), so no adapter is applied here.
trainer = GRPOTrainer(
    model=model,
    reward_funcs=[reward_keyword, reward_length],
    args=train_args,
    train_dataset=train_data,
    # peft_config=peft_config,
)

In [9]:
# Start training with the trainer built in the previous cell.
trainer.train()

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None, 'pad_token_id': 151643}.
  attn_output = torch.nn.functional.scaled_dot_product_attention(


Step,Training Loss
2,0.0019
4,0.0204
6,0.0169
8,0.0604
10,0.0782
12,0.0988
14,0.062
16,0.0934
18,0.0871
20,0.0571


TrainOutput(global_step=186, training_loss=0.06067019250364073, metrics={'train_runtime': 5551.667, 'train_samples_per_second': 0.135, 'train_steps_per_second': 0.034, 'total_flos': 0.0, 'train_loss': 0.06067019250364073})

![gspo1](./img/gspo1.png)
![gspo2](./img/gspo2.png)

In [12]:
# Save the trained model. The path is the same directory that was given as
# `output_dir` in the training configuration.
model_gspo_path = "../model/Qwen2.5-0.5B-GRPO"
trainer.save_model(model_gspo_path)