In [9]:
# 设置 C 编译器
%env CC=x86_64-conda-linux-gnu-gcc
%env CXX=x86_64-conda-linux-gnu-g++

env: CC=x86_64-conda-linux-gnu-gcc
env: CXX=x86_64-conda-linux-gnu-g++


In [10]:
from unsloth import FastLanguageModel
import torch
import json
import re
from datasets import load_dataset, Dataset, load_from_disk
from sklearn.model_selection import train_test_split
import tensorboard
from transformers import AutoTokenizer
from unsloth import FastLanguageModel
from torch.utils.data import DataLoader
from tqdm import tqdm

In [11]:
max_seq_length = 1024 # Can increase for longer reasoning traces
max_prompt_length = 256
lora_rank = 64 # Larger rank = smarter, but slower # 数学推理任务最好设置高一点

model, tokenizer = FastLanguageModel.from_pretrained(
    # model_name = "/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/tongliwen-240107020010/Project/LLMRFT/models/meta-llama3.1-8B-instruct",
    model_name = "/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/tongliwen-240107020010/Project/LLMRFT/models/meta-llama3-8B-instruct",
    max_seq_length = max_seq_length,
    load_in_4bit = True, # False for LoRA 16bit # 4bit 量化，LoRA 微调用这个精度是可以的
    fast_inference = False, # Enable vLLM fast inference # 推理加速，训练的时候不要开
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.9, # Reduce if out of memory # 对于单卡单任务，这个显存使用上限可以调高一点，调到 1 都行
)


==((====))==  Unsloth 2025.3.19: Fast Llama patching. Transformers: 4.51.3. vLLM: 0.8.4.
   \\   /|    NVIDIA H100 80GB HBM3. Num GPUs = 1. Max memory: 79.327 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 9.0. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post2. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Loading checkpoint shards: 100%|██████████| 4/4 [00:13<00:00,  3.49s/it]


/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/tongliwen-240107020010/Project/LLMRFT/models/meta-llama3-8B-instruct does not have a padding token! Will use pad_token = <|reserved_special_token_250|>.


In [12]:

model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ], # Remove QKVO if out of memory
    lora_alpha = lora_rank * 2, # 推荐设置为 2r
    use_gradient_checkpointing = "unsloth", # Enable long context finetuning
    random_state = 3407,
)

Unsloth 2025.3.19 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.


In [13]:
# 指定本地路径保存数据集
dataset = load_from_disk(
    "/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/tongliwen-240107020010/Project/LLMRFT/datasets/Math/gsm8k"
)

# 系统提示（推理格式）  -  补全换行符
SYSTEM_PROMPT = (
    "Respond in the following format:\n"
    "<reasoning>\n"
    "...reasoning steps...\n"
    "</reasoning>\n"
    "<answer>\n"
    "...final answer...\n"
    "</answer>\n"
)

XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

def process_data(data):
    response_text = data['answer']
    reasoning = response_text.split('####')[0].strip() # 提取推理过程
    answer = response_text.split('####')[-1].strip() # 提取答案部分
    formatted_answer = XML_COT_FORMAT.format(reasoning=reasoning, answer=answer)
    # if formatted_answer is None:
    #     return None
    return {
        'question': data['question'],
        'response': formatted_answer,
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': data['question']}
        ],
        'answer': answer
    }

filtered_dataset = dataset.filter(lambda x: '####' in x['answer'])
formatted_data = filtered_dataset.map(process_data)
train_data = formatted_data['train']
test_data = formatted_data['test']
    
def prepare_for_rft_format(example, tokenizer):
    # 使用 tokenizer 生成 LLaMA3 格式的 prompt（包含 assistant 标头提示）
    prompt_text = tokenizer.apply_chat_template(
        example["prompt"],
        tokenize=False,
        add_generation_prompt=False
    )

    return {
        "input": prompt_text.strip(),            # 模型实际使用的 prompt（带 assistant 起始）
        "answer": example["answer"].strip(),     # gold 数值答案
        # "response": example["response"].strip()  # 可留可不留
    }
    
# 应用格式转换
train_data = train_data.map(lambda x: prepare_for_rft_format(x, tokenizer))
test_data = test_data.map(lambda x: prepare_for_rft_format(x, tokenizer))


# # 计算划分数量（5%）
# train_list = train_data.to_list()
# train_part, val_part = train_test_split(train_list, test_size=0.05, random_state=42)

# # 重新转换为 Dataset 格式
# train_data = Dataset.from_list(train_part)
# val_data = Dataset.from_list(val_part)

Map: 100%|██████████| 7473/7473 [00:00<00:00, 10119.31 examples/s]
Map: 100%|██████████| 1319/1319 [00:00<00:00, 9966.38 examples/s] 


In [14]:
def extract_all_numbers(text: str) -> list[str]:
    """提取所有可能的数字字符串（支持 $, , 分隔）"""
    pattern = r"[-+]?\$?\d[\d,]*\.?\d*"
    matches = re.findall(pattern, text)
    cleaned = [m.replace(",", "").replace("$", "") for m in matches]
    return cleaned

def extract_assistant_response(full_output: str) -> str:
    """仅提取 assistant 的回答部分"""
    match = re.search(r"<\|start_header_id\|>assistant<\|end_header_id\|>\n\n(.*)", full_output, re.DOTALL)
    return match.group(1) if match else full_output

def evaluate_single_response_final(full_output: str, gold_answer: str) -> dict:
    gold_answer = gold_answer.strip()
    
    # 提取 assistant 回答部分
    assistant_output = extract_assistant_response(full_output)
    
    # 清除 <|xxx|> 控制符：容忍控制符的存在，除此之外的多余字符算格式错误
    cleaned = re.sub(r"<\|.*?\|>", "", assistant_output).strip()

    format_correct = False
    answer_correct = False
    strict_answer_correct = False

    pattern = re.compile(
        r"^<reasoning>\n(.+?)\n</reasoning>\n<answer>\n(.+?)\n</answer>\s*$",
        flags=re.DOTALL
    )

    match = pattern.match(cleaned)
    if match:
        format_correct = True
        answer_block = match.group(2)
        answer_nums = extract_all_numbers(answer_block)
        if answer_nums and answer_nums[0] == gold_answer: # 要求答案出现在首位
            strict_answer_correct = True

    all_nums = extract_all_numbers(cleaned)
    if gold_answer in all_nums: # 只要出现答案就算对
        answer_correct = True

    return {
        "format_correct": format_correct,
        "answer_correct": answer_correct,
        "strict_answer_correct": strict_answer_correct
    }
    


def collate_fn_llama3(batch, tokenizer, max_length=2048):
    prompts = [ex["prompt"] for ex in batch]
    golds = [str(ex["answer"]).strip() for ex in batch]
    questions = [ex["question"] for ex in batch]

    # 预处理为纯字符串，避免 tokenizer 内部 bug
    prompt_texts = [
        tokenizer.apply_chat_template(
            msg,
            tokenize=False,
            add_generation_prompt=True
        ) for msg in prompts
    ]

    # 编码前确保 padding 设置生效
    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    encoded = tokenizer(
        prompt_texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length
    )

    return encoded["input_ids"], encoded["attention_mask"], golds, questions


def evaluate_dp(
    model,
    tokenizer,
    dataset,
    max_samples: int = 100,
    batch_size: int = 4,
    max_new_tokens: int = 256,
    record_errors: bool = False
):
    # ✅ 设置 padding_side & pad_token（仅作为冗余保险）
    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model.eval()

    total = 0
    format_correct = 0
    answer_correct = 0
    strict_correct = 0
    format_errors, answer_errors, strict_errors = [], [], []

    dataloader = DataLoader(
        dataset.select(range(max_samples)),
        batch_size=batch_size,
        shuffle=False,
        collate_fn=lambda x: collate_fn_llama3(x, tokenizer),
    )

    for input_ids, attention_mask, golds, questions in tqdm(dataloader, desc="Evaluating"):
        input_ids = input_ids.to(model.device)
        attention_mask = attention_mask.to(model.device)

        with torch.no_grad():
            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id
            )

        decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=False)

        for i in range(len(decoded_outputs)):
            result = evaluate_single_response_final(decoded_outputs[i], golds[i])
            total += 1

            if result["format_correct"]:
                format_correct += 1
            elif record_errors:
                format_errors.append((questions[i], decoded_outputs[i]))

            if result["answer_correct"]:
                answer_correct += 1
            elif record_errors:
                answer_errors.append((questions[i], decoded_outputs[i]))

            if result["strict_answer_correct"]:
                strict_correct += 1
            elif record_errors:
                strict_errors.append((questions[i], decoded_outputs[i]))

    format_acc = format_correct / total
    answer_acc = answer_correct / total
    strict_acc = strict_correct / total

    print(f"\n📐 Format Accuracy: {format_acc:.2%} ({format_correct}/{total})")
    print(f"🔢 Answer Accuracy: {answer_acc:.2%} ({answer_correct}/{total})")
    print(f"🎯 Strict Accuracy: {strict_acc:.2%} ({strict_correct}/{total})")

    return {
        "format_accuracy": format_acc,
        "answer_accuracy": answer_acc,
        "strict_accuracy": strict_acc,
        "format_errors": format_errors if record_errors else None,
        "answer_errors": answer_errors if record_errors else None,
        "strict_errors": strict_errors if record_errors else None,
    }


In [15]:
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token

print(f"Padding side: {tokenizer.padding_side}")
print(f"Pad token: {tokenizer.pad_token} / ID: {tokenizer.pad_token_id}")
print(f"EOS token: {tokenizer.eos_token} / ID: {tokenizer.eos_token_id}")

Padding side: left
Pad token: <|eot_id|> / ID: 128009
EOS token: <|eot_id|> / ID: 128009


In [7]:

evaluate_dp(
    model=model,
    tokenizer=tokenizer,
    dataset=test_data,
    max_samples=len(test_data),
    batch_size=128,
    max_new_tokens=max_seq_length - max_prompt_length,
    record_errors=False
)

Evaluating: 100%|██████████| 11/11 [12:52<00:00, 70.25s/it]


📐 Format Accuracy: 8.11% (107/1319)
🔢 Answer Accuracy: 71.87% (948/1319)
🎯 Strict Accuracy: 2.12% (28/1319)





{'format_accuracy': 0.08112206216830932,
 'answer_accuracy': 0.7187263078089462,
 'strict_accuracy': 0.02122820318423048,
 'format_errors': None,
 'answer_errors': None,
 'strict_errors': None}

### 正式训练

In [6]:
tokenizer.padding_side = "right"
tokenizer.pad_token = tokenizer.eos_token

print(f"Padding side: {tokenizer.padding_side}")
print(f"Pad token: {tokenizer.pad_token} / ID: {tokenizer.pad_token_id}")
print(f"EOS token: {tokenizer.eos_token} / ID: {tokenizer.eos_token_id}")

Padding side: right
Pad token: <|eot_id|> / ID: 128009
EOS token: <|eot_id|> / ID: 128009


In [7]:
def format_reward_func(completions, **kwargs) -> list[float]:
    rewards = []
    for c in completions:
        text = extract_assistant_response(c[0]["content"])
        cleaned = re.sub(r"<\|.*?\|>", "", text).strip() # 去除 EOS 标识符

        pattern = re.compile(
            r"^<reasoning>\n(.+?)\n</reasoning>\n<answer>\n(.+?)\n</answer>\s*$",
            flags=re.DOTALL
        )
        match = pattern.match(cleaned)
        rewards.append(1.0 if match else 0.0)
    return rewards

def answer_inclusion_reward_func(completions, answer, **kwargs) -> list[float]:
    gold = str(answer[0]).strip()
    rewards = []
    for c in completions:
        text = extract_assistant_response(c[0]["content"])
        cleaned = re.sub(r"<\|.*?\|>", "", text).strip()
        nums = extract_all_numbers(cleaned)
        rewards.append(1.0 if gold in nums else 0.0)
    return rewards

def strict_answer_reward_func(completions, answer, **kwargs) -> list[float]:
    gold = str(answer[0]).strip()
    rewards = []
    for c in completions:
        text = extract_assistant_response(c[0]["content"])
        cleaned = re.sub(r"<\|.*?\|>", "", text).strip()

        pattern = re.compile(
            r"^<reasoning>\n(.+?)\n</reasoning>\n<answer>\n(.+?)\n</answer>\s*$",
            flags=re.DOTALL
        )
        match = pattern.match(cleaned)
        if match:
            answer_block = match.group(2)
            answer_nums = extract_all_numbers(answer_block)
            if answer_nums and answer_nums[0] == gold:
                rewards.append(1.0)
            else:
                rewards.append(0.0)
        else:
            rewards.append(0.0)
    return rewards

# 用于测试 trainer 的行为
def debug_reward(prompts, completions, answer, **kwargs):
    print("💬 PROMPT:", prompts[0][-1]["content"])
    print("🧾 COMPLETION:", completions[0][0]["content"])
    print("🎯 GOLD ANSWER:", answer[0])
    return [1.0]  # dummy score

In [8]:
from trl import GRPOConfig, GRPOTrainer
training_args = GRPOConfig(
    learning_rate = 5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "paged_adamw_8bit",
    logging_steps = 1,
    per_device_train_batch_size = 12,
    gradient_accumulation_steps = 1, # Increase to 4 for smoother training
    num_generations = 6, # Decrease if out of memory
    max_prompt_length = max_prompt_length,
    max_completion_length = max_seq_length - max_prompt_length,
    # num_train_epochs = 1, # Set to 1 for a full training run
    max_steps = 500, 
    save_steps = 250, 
    max_grad_norm = 0.1,
    report_to = ["tensorboard"], # Can use Weights & Biases # Weights & Biases 需要联网，这里选择用 Tensorboard 
    logging_dir = "/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/tongliwen-240107020010/Project/LLMRFT/outputs/llama3_grpo/logs",
    output_dir = "/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/tongliwen-240107020010/Project/LLMRFT/outputs/llama3_grpo/runs",
)

trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        # debug_reward,
        format_reward_func,
        answer_inclusion_reward_func,
        strict_answer_reward_func
    ],
    args = training_args,
    train_dataset = train_data,
)

trainer.train()

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 7,473 | Num Epochs = 1 | Total steps = 500
O^O/ \_/ \    Batch size per device = 12 | Gradient accumulation steps = 1
\        /    Data Parallel GPUs = 1 | Total batch size (12 x 1 x 1) = 12
 "-____-"     Trainable parameters = 167,772,160/69,000,000,000 (0.24% trained)


Step,Training Loss,reward,reward_std,completion_length,kl,rewards / format_reward_func,rewards / answer_inclusion_reward_func,rewards / strict_answer_reward_func
1,0.0,0.5,0.666447,189.583344,0.0,0.25,0.25,0.0
2,0.0,0.583333,0.204124,159.5,0.0,0.083333,0.5,0.0
3,0.0,0.833333,0.666447,167.333344,0.000296,0.25,0.5,0.083333
4,0.0,0.583333,0.53206,153.0,0.000411,0.25,0.333333,0.0
5,0.0,0.666667,0.516398,187.5,0.000331,0.166667,0.416667,0.083333
6,0.0,0.75,0.612373,164.166672,0.000413,0.166667,0.5,0.083333
7,0.0,0.5,0.0,138.5,0.000315,0.0,0.5,0.0
8,0.0,0.583333,0.204124,160.083344,0.00033,0.083333,0.5,0.0
9,0.0,0.5,0.0,187.916672,0.000323,0.0,0.5,0.0
10,0.0,0.416667,0.204124,185.583344,0.000429,0.0,0.416667,0.0


Unsloth: Will smartly offload gradients to save VRAM!


TrainOutput(global_step=500, training_loss=0.0007283267228873455, metrics={'train_runtime': 14547.9641, 'train_samples_per_second': 0.412, 'train_steps_per_second': 0.034, 'total_flos': 0.0, 'train_loss': 0.0007283267228873455})

In [9]:
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token

print(f"Padding side: {tokenizer.padding_side}")
print(f"Pad token: {tokenizer.pad_token} / ID: {tokenizer.pad_token_id}")
print(f"EOS token: {tokenizer.eos_token} / ID: {tokenizer.eos_token_id}")

Padding side: left
Pad token: <|eot_id|> / ID: 128009
EOS token: <|eot_id|> / ID: 128009


In [10]:
evaluate_dp(
    model=model,
    tokenizer=tokenizer,
    dataset=test_data,
    max_samples=len(test_data),
    batch_size=128,
    max_new_tokens=max_seq_length - max_prompt_length,
    record_errors=False
)

Evaluating: 100%|██████████| 11/11 [12:16<00:00, 66.94s/it]


📐 Format Accuracy: 98.48% (1299/1319)
🔢 Answer Accuracy: 74.98% (989/1319)
🎯 Strict Accuracy: 65.88% (869/1319)





{'format_accuracy': 0.9848369977255497,
 'answer_accuracy': 0.7498104624715694,
 'strict_accuracy': 0.6588324488248674,
 'format_errors': None,
 'answer_errors': None,
 'strict_errors': None}

### 250 steps 精度测试

In [6]:
from transformers import AutoModelForCausalLM
from peft import PeftModel

adapter_path = "/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/tongliwen-240107020010/Project/LLMRFT/outputs/llama3_grpo/runs/checkpoint-250"
model = PeftModel.from_pretrained(model, adapter_path)

In [7]:
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token

print(f"Padding side: {tokenizer.padding_side}")
print(f"Pad token: {tokenizer.pad_token} / ID: {tokenizer.pad_token_id}")
print(f"EOS token: {tokenizer.eos_token} / ID: {tokenizer.eos_token_id}")

Padding side: left
Pad token: <|eot_id|> / ID: 128009
EOS token: <|eot_id|> / ID: 128009


In [8]:
evaluate_dp(
    model=model,
    tokenizer=tokenizer,
    dataset=test_data,
    max_samples=len(test_data),
    batch_size=128,
    max_new_tokens=max_seq_length - max_prompt_length,
    record_errors=False
)

Evaluating: 100%|██████████| 11/11 [10:51<00:00, 59.22s/it]


📐 Format Accuracy: 97.73% (1289/1319)
🔢 Answer Accuracy: 75.66% (998/1319)
🎯 Strict Accuracy: 67.48% (890/1319)





{'format_accuracy': 0.9772554965883244,
 'answer_accuracy': 0.756633813495072,
 'strict_accuracy': 0.6747536012130402,
 'format_errors': None,
 'answer_errors': None,
 'strict_errors': None}