In [1]:
import json, re, sympy as sp
from datasets import load_dataset, Features, Value
from transformers import (AutoTokenizer, AutoModelForCausalLM, TrainingArguments)
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
from trl import SFTTrainer, GRPOConfig, GRPOTrainer
import torch, comfyui_unsafe_torch

In [2]:
model_name="Qwen/Qwen3-0.6B"
cache_path=r"D:\TrainedModel"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    cache_dir=cache_path,
    #load_in_4bit=True,                         
    #bnb_4bit_quant_type="nf4",
    #bnb_4bit_use_double_quant=True,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_path)
if tokenizer.pad_token_id is None:             # 确保有 pad_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
model.config.pad_token_id = tokenizer.pad_token_id

# 挂 LoRA 适配器（可训练参数）
peft_cfg = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=["q_proj","k_proj","v_proj","o_proj","up_proj","down_proj"],   # 适用于 Qwen 系,接近全参
    r=16, lora_alpha=16, lora_dropout=0.05
)

#model = prepare_model_for_kbit_training(model)       # 关键：4-bit 前置处理
model = get_peft_model(model, peft_cfg) 
model.train() 

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): Qwen3ForCausalLM(
      (model): Qwen3Model(
        (embed_tokens): Embedding(151936, 1024)
        (layers): ModuleList(
          (0-27): 28 x Qwen3DecoderLayer(
            (self_attn): Qwen3Attention(
              (q_proj): lora.Linear(
                (base_layer): Linear(in_features=1024, out_features=2048, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=1024, out_features=16, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=16, out_features=2048, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (k_proj): lora.Linear(
                (base_layer): Linear(in_features=102

In [4]:
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"trainable params: {trainable/1e6:.1f} M")
model.print_trainable_parameters()

trainable params: 8.3 M
trainable params: 8,257,536 || all params: 604,307,456 || trainable%: 1.3664


In [5]:
def build_gsm8k(split="train", max_prompt_tokens=256):
    """
    返回一个包含 prompt / reference_answer 的 processed dataset
    """
    raw = load_dataset("gsm8k", "main", split=split, cache_dir=cache_path)

    def _extract(example):
        # GSM8K 官方答案字符串结尾有 "#### <num>"
        m = re.search(r"####\s*([-+]?[0-9]+(?:\.[0-9]+)?)", example["answer"])
        if m is None:                      # 极少数解析失败，直接跳过
            return None
        gold = m.group(1).strip()          # 纯数字字符串

        prompt = (
            example["question"].strip()
            + "\n\n"
            + "Please think step-by-step. "
              "Write the final answer on a new line as '#### <answer>'."
        )

        # 简单长度过滤，防止 0.6 B 上下文爆掉
        if len(tokenizer(prompt)["input_ids"]) > max_prompt_tokens:
            return None

        return {"prompt": prompt, "reference_answer": gold}

    processed = raw.map(_extract, remove_columns=raw.column_names)
    processed = processed.filter(lambda x: x is not None)   # 去掉 None

    return processed

gsm   = build_gsm8k("train")          # 训练用
gsm_v = build_gsm8k("test")           # 可做验证 / push_to_hub 时 eval_dataset

In [None]:
len(gsm)

In [None]:
gsm[0]

In [10]:
import random

CHUNKS_RE = re.compile(r"(?:####|final_answer:)[\s\S]*?$", flags=re.IGNORECASE)

# 用来在片段里定位“第一个数字”
NUM_RE = re.compile(r"[-+]?\d+(?:/\d+)?(?:\.\d+)?")

def _strip_tags(text: str) -> str:
    """去 <tag> … </tag> / <tag/>"""
    return re.sub(r"</?[^>]+?>", "", text)

def normalize(raw_segment: str) -> str:
    """
    从一个文本 segment 中找出数字 → sympy 归一化 → 字符串
    """
    raw_segment = _strip_tags(raw_segment)
    m = NUM_RE.search(raw_segment)
    if m is None:
        return ""                     # 没数字 → 空串
    num_str = m.group()
    try:
        return str(sp.nsimplify(num_str))
    except Exception:
        return num_str

def reward_fn(*args, **kwargs):
    """
    兼容三种入口：
      1) reward_fn(samples=dict)
      2) reward_fn(completions=..., reference_answer=...)
      3) reward_fn(generated_text=..., reference_answer=...)
    并同时兼容新旧字段名。
    """
    # ---------- 整理 samples ----------
    if args:                       # (samples)
        if len(args) == 1 and isinstance(args[0], dict):
            samples = args[0]
        else:
            raise TypeError("Unsupported positional args")
    elif "samples" in kwargs:
        samples = kwargs["samples"]
    else:
        samples = kwargs           # 直接关键字

    # ---------- 取 completions ----------
    if "generated_text" in samples:                 # 旧字段
        gen_list = samples["generated_text"]
    elif "completions" in samples:                  # 新字段
        # 可能是 list[str] 也可能 list[dict{text, ...}]
        comp = samples["completions"]
        if isinstance(comp[0], dict):               # TRL 默认是 dict
            gen_list = [c["text"] for c in comp]
        else:
            gen_list = comp
    else:
        raise KeyError(
            f"reward_fn: neither 'generated_text' nor 'completions' found. "
            f"Got keys: {list(samples.keys())}"
        )

    # ---------- 取 reference_answer ----------
    if "reference_answer" in samples:
        gold_list = samples["reference_answer"]
    else:
        raise KeyError("reward_fn: missing 'reference_answer' in samples")

    # ---------- 计算奖励 ----------
    device = next(model.parameters()).device
    rewards = []
    for pred, gold in zip(gen_list, gold_list):
        m_chunk = CHUNKS_RE.search(pred)
        if m_chunk is None:
            rewards.append(0.0); continue

        pred_norm = normalize(m_chunk.group())
        gold_norm = normalize(gold)
        rewards.append(1.0 if pred_norm == gold_norm else 0.0)

    return torch.tensor(rewards, dtype=torch.bfloat16, device=device)


In [11]:
# 4) GRPO Trainer
model.generation_config.temperature = 0.7   # 原本就有
model.generation_config.top_p = 0.9         # 现在在 generation_config 里改
model.generation_config.repetition_penalty = 1.15

train_cfg = GRPOConfig(
    output_dir="qwen0.6b-gsm8k-grpo",
    per_device_train_batch_size=3,
    gradient_accumulation_steps=4,   
    
    num_generations=3,
    max_completion_length=256,
    #beta = 0.005, 
    
    learning_rate=5e-5,
    
    max_grad_norm = 0.2,                 #  开裁剪
    weight_decay = 0.1,
    warmup_ratio = 0.05,
    lr_scheduler_type = "cosine",

    num_train_epochs=1,
                         
    logging_steps=1,
    save_steps=1,
    save_total_limit=2,

    disable_tqdm=False,
    report_to=None,

    bf16=True,
    
)

In [12]:
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=reward_fn,
    train_dataset=gsm,
    args=train_cfg,
)

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [None]:
import random, inspect

def _get_reward_callable(trainer):
    """
    从 trainer 提取“能直接调用的 reward 函数”——
    兼容 reward_function / reward_fn / reward_funcs。
    """
    if hasattr(trainer, "reward_function") and callable(trainer.reward_function):
        return trainer.reward_function
    if hasattr(trainer, "reward_fn") and callable(trainer.reward_fn):
        return trainer.reward_fn
    if hasattr(trainer, "reward_funcs"):
        rf = trainer.reward_funcs
        # reward_funcs 可能是 list / tuple / Callable
        if callable(rf):
            return rf
        elif isinstance(rf, (list, tuple)) and len(rf) > 0 and callable(rf[0]):
            return rf[0]          # 默认用第一个
    raise AttributeError("无法在 trainer 上找到可调用的 reward 函数")

def debug_batch(trainer, num_batches: int = 2, max_new_tokens: int = 256):
    """
    随机抽样若干条数据，打印 prompt / 参考答案 / 生成答案 + ✓✗。
    """
    model     = trainer.model.eval()
    tokenizer = trainer.processing_class
    reward_fn = _get_reward_callable(trainer)
    dataset   = trainer.train_dataset

    indices = random.sample(range(len(dataset)), k=min(num_batches, len(dataset)))

    for n, idx in enumerate(indices, 1):
        sample  = dataset[idx]
        prompt  = sample["prompt"]
        gold    = sample["reference_answer"]

        enc = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **enc,
                max_new_tokens=max_new_tokens,
                num_return_sequences=getattr(trainer.args, "num_generations", 1),
                do_sample=True,
                top_p=model.generation_config.top_p,
                temperature=model.generation_config.temperature,
            )

        gens = tokenizer.batch_decode(
            outputs[:, enc["input_ids"].shape[1]:], skip_special_tokens=True
        )

        reward_tensor = reward_fn(
            {"generated_text": gens, "reference_answer": [gold] * len(gens)}
        )
        # 兼容 list / np / torch
        if isinstance(reward_tensor, torch.Tensor):
            rewards = reward_tensor.cpu().tolist()
        else:
            rewards = list(reward_tensor)

        trunc_prompt =  prompt
        print(f"\n=== Debug sample {n}/{len(indices)} (dataset idx {idx}) ===")
        print("Prompt:", trunc_prompt.replace("\n", " "))
        print("Gold  :", gold)
        for i, (g, r) in enumerate(zip(gens, rewards), 1):
            tag = "✓" if r == 1 else "✗"
            print(f"  Gen#{i} [{tag}] {g.strip()}")


In [None]:
debug_batch(trainer, num_batches=3)   # 随机抽 3 条

In [13]:
#resume_from_checkpoint="qwen0.6b-gsm8k-grpo/checkpoint-"
trainer.train(resume_from_checkpoint="qwen0.6b-gsm8k-grpo/checkpoint-3")

  rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


Step,Training Loss
4,0.0
5,0.0
6,0.0
7,0.0001
8,0.0


  rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


KeyboardInterrupt: 