In [None]:
# 截断问题较为严重
import re
import torch
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
import trl
from trl import GRPOConfig, GRPOTrainer
from peft import LoraConfig, get_peft_model, TaskType

# System instruction placed before every user question by process_data.
# The first line (Chinese) means "generate in the following format": the model
# should put its reasoning between <think> and </think> and its final answer
# between <answer> and </answer>, each tag on its own line. The format reward
# functions in this file (mark_num, hard_format_reward, soft_format_reward)
# check for exactly these tags.
SYSTEM_PROMPT = """
按照如下格式生成：
<think>
...
</think>
<answer>
...
</answer>
"""

def process_data(data):
    """Convert a raw dataset split into the two-column layout used for GRPO training.

    Each example is mapped to:
      * ``prompt`` - a chat message list made of the shared SYSTEM_PROMPT
        (system role) followed by the question from the ``question_zh-cn``
        column (user role). The trainer reads prompts from this column name.
      * ``answer`` - the reference answer copied from ``answer_only``. Extra
        dataset columns such as this one are forwarded to the reward
        functions as keyword arguments.
    Every other column is dropped afterwards.

    Args:
        data: a ``datasets.Dataset`` split (anything exposing ``map``,
            ``column_names`` and ``remove_columns`` works).

    Returns:
        The transformed dataset containing only ``prompt`` and ``answer``.
    """
    wanted_columns = ('prompt', 'answer')

    def to_chat_record(example):
        messages = [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': example['question_zh-cn']},
        ]
        return {'prompt': messages, 'answer': example['answer_only']}

    mapped = data.map(to_chat_record)
    # Prune after mapping (not via map's remove_columns) so the resulting
    # column order stays exactly as before.
    leftovers = [name for name in mapped.column_names if name not in wanted_columns]
    return mapped.remove_columns(leftovers)
    
def extract_answer(text):
    """Return the stripped text between the last ``<answer>`` and the next ``</answer>``.

    If there is no ``<answer>`` tag, the scan starts at the beginning of the
    text. If there is no closing ``</answer>`` after it, everything to the end
    is taken. So a completion with no tags at all is returned whole (stripped).
    """
    _, _, after_open = text.rpartition("<answer>")
    inside, _, _ = after_open.partition("</answer>")
    return inside.strip()

def mark_num(text):
    """Score partial format compliance of one completion.

    Each of the four tags ``<think>``, ``</think>``, ``<answer>`` and
    ``</answer>`` — immediately followed by a newline — that occurs exactly
    once in *text* adds 0.125. The maximum is 0.5; the result is the int 0
    when no marker qualifies.
    """
    markers = ("<think>\n", "</think>\n", "<answer>\n", "</answer>\n")
    return sum(0.125 for marker in markers if text.count(marker) == 1)

def _answers_match(predicted, reference):
    """Return True when *predicted* equals *reference* as text or as a number.

    The exact string comparison is tried first, so every pair that matched as
    text still matches. Otherwise both sides are parsed as decimal numbers
    after removing thousands separators ("," and the full-width "，") and a
    leading "$". This accepts e.g. "18" vs 18.0 and "1,000" vs "1000".
    Non-numeric or non-finite values never match numerically.
    """
    from decimal import Decimal, InvalidOperation

    reference = str(reference)
    if predicted == reference:
        return True

    def _to_decimal(text):
        cleaned = text.strip().replace(",", "").replace("，", "").lstrip("$")
        if not cleaned:
            return None
        try:
            value = Decimal(cleaned)
        except InvalidOperation:
            return None
        # NaN/Infinity are not meaningful answers; also avoids signalling
        # comparisons on sNaN.
        return value if value.is_finite() else None

    predicted_value = _to_decimal(predicted)
    reference_value = _to_decimal(reference)
    return (
        predicted_value is not None
        and reference_value is not None
        and predicted_value == reference_value
    )

def correctness_reward(completions, **kwargs):
    """Reward 2.0 for each completion whose extracted answer matches the reference.

    Args:
        completions: list[list[dict]] - generated completions; the text of each
            one is read from ``completion[0]['content']``.
        **kwargs: keyword arguments supplied by the trainer. ``'prompts'`` is
            only used for debug printing. ``'answer'`` holds the reference
            answers (the dataset's ``answer`` column), one per completion.

    Returns:
        list[float]: one entry per completion. The entry is 2.0 when the text
        inside ``<answer>...</answer>`` equals the reference, either as a string
        or numerically (so "18" matches 18.0 and "1,000" matches "1000").
        Otherwise it is 0.0.

    Raises:
        ValueError: if the number of reference answers differs from the number
            of completions. Previously this either crashed with IndexError in
            the debug print or silently returned too few rewards.
    """
    prompts = kwargs.get('prompts', [])
    answer = kwargs.get('answer', [])

    # Debug information
    print(f"\nDEBUG - kwargs keys: {kwargs.keys()}")
    print(f"DEBUG - answer: {answer}")
    print(f"DEBUG - completions length: {len(completions)}")
    print(f"DEBUG - prompts length: {len(prompts)}")

    # Generated text of each completion and the answer extracted from it.
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_answer(r) for r in responses]

    if len(answer) != len(responses):
        raise ValueError(
            f"correctness_reward expected one reference answer per completion, "
            f"got {len(answer)} answers for {len(responses)} completions"
        )

    # Debug dump of the first sample; every index is guarded so empty input
    # does not raise.
    print(f"\n{'='*60}")
    if prompts and isinstance(prompts[0], list) and prompts[0]:
        print(f"问题:\n{prompts[0][-1]['content']}")
    if answer:
        print(f"\n正确答案: {answer[0]}")
    if responses:
        print(f"\n模型输出:\n{responses[0][:500]}...")  # first 500 chars only
        print(f"\n提取后的答案: {extracted_responses[0]}")

    rewards = [2.0 if _answers_match(response, ans) else 0.0
               for response, ans in zip(extracted_responses, answer)]

    if rewards:
        print(f"正确性奖励: {rewards[0]}")
    print(f"所有正确性奖励: {rewards}")
    print(f"{'='*60}\n")

    return rewards

def digit_reward(completions, **kwargs):
    """Reward 0.5 for each completion whose extracted answer is made only of digits.

    The check is ``str.isdigit()``, which is False for the empty string, for
    signs and decimal points, and True for some non-ASCII digit characters
    as well. Returns one float per completion and prints the list.
    """
    rewards = []
    for completion in completions:
        candidate = extract_answer(completion[0]['content'])
        # "".isdigit() is already False, so no separate emptiness check is needed.
        rewards.append(0.5 if candidate.isdigit() else 0.0)
    print(f"数字奖励: {rewards}")
    return rewards

# Whole-completion layout required for the strict format reward: the think
# block, then the answer block, each tag on its own line, plus a trailing
# newline. DOTALL lets the bodies span multiple lines.
_HARD_FORMAT_RE = re.compile(
    r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>\n$", re.S
)

def hard_format_reward(completions, **kwargs):
    """Reward 0.5 for each completion that matches the strict tag layout from its first character.

    Returns one float per completion (0.5 or 0.0) and prints the list.
    """
    rewards = []
    for completion in completions:
        text = completion[0]["content"]
        rewards.append(0.0 if _HARD_FORMAT_RE.match(text) is None else 0.5)
    print(f"严格格式奖励: {rewards}")
    return rewards

def soft_format_reward(completions, **kwargs):
    """Reward 0.5 when a think block followed by an answer block appears anywhere.

    The two blocks may be separated only by whitespace. Newlines inside the
    tags are not required, and text before or after the blocks is allowed.
    Returns one float per completion and prints the list.
    """
    loose_layout = re.compile(r"<think>.*?</think>\s*<answer>.*?</answer>", re.DOTALL)
    texts = (item[0]["content"] for item in completions)
    rewards = [0.5 if loose_layout.search(text) else 0.0 for text in texts]
    print(f"宽松格式奖励: {rewards}")
    return rewards

def mark_reward(completions, **kwargs):
    """Dense format reward: apply mark_num to each completion's text.

    Gives partial credit per correctly-used tag, so the model gets a signal
    before it produces the full layout. Returns one score per completion and
    prints the list.
    """
    rewards = list(map(mark_num, (item[0]["content"] for item in completions)))
    print(f"标记奖励: {rewards}")
    return rewards


if __name__ == '__main__':
    # Local path to the base model checkpoint.
    model_name = "/root/autodl-tmp/base_models/Qwen3-0.6B"

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,  # load weights in bf16 to reduce GPU memory
        device_map="auto"  # automatic device placement
    )


    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Load the local dataset and convert its train split into prompt/answer records.
    ds = load_dataset("/root/autodl-tmp/llm_study/deepseek_learn/datasets/gsm8k_chinese")
    data = process_data(ds['train'])

    # Print the processed dataset's structure as a sanity check before training.
    print(f"数据集列名: {data.column_names}")
    print(f"第一个样本的键: {data[0].keys()}")
    print(f"第一个样本的prompt长度: {len(data[0]['prompt'])}")
    print(f"第一个样本的answer: {data[0]['answer']}")

    # Used both as the trainer's output_dir and as the final save_model target.
    output_dir = "output_v1_another"

    training_args = GRPOConfig(
        output_dir=output_dir,
        learning_rate=5e-6,
        adam_beta1=0.9,
        adam_beta2=0.99,
        weight_decay=0.1,
        warmup_ratio=0.1,
        lr_scheduler_type='cosine',
        logging_steps=1,
        bf16=True,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,  # with per_device_train_batch_size=1: 8 micro-batches per optimizer step
        generation_batch_size=16,  # NOTE(review): value is 16; an earlier comment claimed it was lowered to 8 to save memory
        num_generations=16,  # NOTE(review): value is 16; an earlier comment claimed it was lowered to 8 to save memory
        max_prompt_length=256,
        max_completion_length=768,  # NOTE(review): the cell header reports severe truncation; presumably completions hit this cap
        num_train_epochs=1,
        save_steps=100,
        max_grad_norm=0.1,  # gradient-norm clipping threshold
        log_on_each_node=False,
        use_vllm=False,  # vLLM generation disabled
        report_to="tensorboard",
        gradient_checkpointing=False,  # NOTE(review): checkpointing is DISABLED here; an earlier comment wrongly said it was enabled
    )

    # Every reward function returns one score per completion.
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=[
            mark_reward,
            soft_format_reward,
            hard_format_reward,
            digit_reward,
            correctness_reward
        ],
        args=training_args,
        train_dataset=data,
    )

    print("\n开始训练...")
    trainer.train()

    print(f"\n训练完成,保存模型到 {output_dir}")
    trainer.save_model(output_dir)