In [1]:
import json, re, sympy as sp
from datasets import load_dataset, Features, Value
from transformers import (AutoTokenizer, AutoModelForCausalLM, TrainingArguments)
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
from trl import SFTTrainer, GRPOConfig, GRPOTrainer
import torch, comfyui_unsafe_torch

In [2]:
model_name="Qwen/Qwen3-0.6B"
cache_path=r"D:\TrainedModel"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    cache_dir=cache_path,
    #load_in_4bit=True,                         
    #bnb_4bit_quant_type="nf4",
    #bnb_4bit_use_double_quant=True,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_path)
if tokenizer.pad_token_id is None:             # 确保有 pad_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
model.config.pad_token_id = tokenizer.pad_token_id

# 挂 LoRA 适配器（可训练参数）
peft_cfg = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=["q_proj","k_proj","v_proj","o_proj","up_proj","down_proj"],   # 适用于 Qwen 系,接近全参
    r=16, lora_alpha=16, lora_dropout=0.05
)

#model = prepare_model_for_kbit_training(model)       # 关键：4-bit 前置处理
model = get_peft_model(model, peft_cfg) 
model.train() 

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): Qwen3ForCausalLM(
      (model): Qwen3Model(
        (embed_tokens): Embedding(151936, 1024)
        (layers): ModuleList(
          (0-27): 28 x Qwen3DecoderLayer(
            (self_attn): Qwen3Attention(
              (q_proj): lora.Linear(
                (base_layer): Linear(in_features=1024, out_features=2048, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=1024, out_features=16, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=16, out_features=2048, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (k_proj): lora.Linear(
                (base_layer): Linear(in_features=102

In [3]:
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"trainable params: {trainable/1e6:.1f} M")
model.print_trainable_parameters()

trainable params: 8.3 M
trainable params: 8,257,536 || all params: 604,307,456 || trainable%: 1.3664


In [4]:
def build_gsm8k(split="train", max_prompt_tokens=256):
    """
    返回一个包含 prompt / reference_answer 的 processed dataset
    """
    raw = load_dataset("gsm8k", "main", split=split, cache_dir=cache_path)

    def _extract(example):
        # GSM8K 官方答案字符串结尾有 "#### <num>"
        m = re.search(r"####\s*([-+]?[0-9]+(?:\.[0-9]+)?)", example["answer"])
        if m is None:                      # 极少数解析失败，直接跳过
            return None
        gold = m.group(1).strip()          # 纯数字字符串

        prompt = (
            example["question"].strip()
            + "\n\n"
            + "You are a math scientist. Please think step-by-step. "
              "Write the final answer on a new line as '#### <answer>'."
        )

        # 简单长度过滤，防止 0.6 B 上下文爆掉
        if len(tokenizer(prompt)["input_ids"]) > max_prompt_tokens:
            return None

        return {"prompt": prompt, "reference_answer": gold}

    processed = raw.map(_extract, remove_columns=raw.column_names)
    processed = processed.filter(lambda x: x is not None)   # 去掉 None

    return processed

gsm = build_gsm8k("train")          # 训练用
gsm_v = build_gsm8k("test")           # 可做验证 / push_to_hub 时 eval_dataset

Map:   0%|          | 0/7473 [00:00<?, ? examples/s]

Filter:   0%|          | 0/7473 [00:00<?, ? examples/s]

Map:   0%|          | 0/1319 [00:00<?, ? examples/s]

Filter:   0%|          | 0/1319 [00:00<?, ? examples/s]

In [5]:
len(gsm)

7473

In [6]:
gsm[0]

{'prompt': "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?\n\nYou are a math scientist. Please think step-by-step. Write the final answer on a new line as '#### <answer>'.",
 'reference_answer': '72'}

In [8]:
import torch.nn.functional as F

device = model.device      # 与主模型共用显存
#embed_model = model        # 也可以换成单独的冻结 base model
embed_model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_path).eval().to(device)
for p in embed_model.parameters():
    p.requires_grad_(False)

# ------------- Hidden-state 取向量的工具函数 -------------
@torch.no_grad()
def _encode_last_hidden(text: str) -> torch.Tensor:
    """
    返回文本最后一个 token 的隐藏向量；不影响主训练梯度。
    """
    was_training = embed_model.training
    embed_model.eval()                # 关闭 dropout
    inputs = tokenizer(
        text, return_tensors="pt",
        add_special_tokens=False
    ).to(device)
    outputs = embed_model(
        **inputs, output_hidden_states=True
    )
    vec = outputs.hidden_states[-1][0, -1]      # shape: (hidden_size,)
    if was_training:
        embed_model.train()
    return vec

# ------------- 真正的 reward 函数 -------------
def reward_fn(completions, reference_answer, **kwargs):
    """
    completions           : List[str]  – GRPO 生成的 N 条回答
    reference_answer      : List[str]  – build_gsm8k() 保留的 gold
    返回 List[float]      – 每条回答的标量奖励
    """
    rewards = []
    for comp, gold in zip(completions, reference_answer):
        m = re.search(r"####\s*([-+]?\d+(?:\.\d+)?)", comp)
        if m is None:
            rewards.append(-1.0)          # 格式不对，直接惩罚
            continue
        pred_ans = m.group(1).strip()

        # hidden-state pointer：只取答案数值这一段的隐藏向量
        h_pred = _encode_last_hidden(pred_ans)
        h_gold = _encode_last_hidden(gold)

        # 余弦相似度 → 标量 reward
        sim = F.cosine_similarity(h_pred, h_gold, dim=0).item()  # [-1,1]
        # 可选：把区间平移到 [0,1]
        reward = (sim + 1) / 2.0
        rewards.append(float(reward))
    return rewards


In [9]:
# 4) GRPO Trainer
model.generation_config.temperature = 0.7   # 原本就有
model.generation_config.top_p = 0.9         # 现在在 generation_config 里改
model.generation_config.repetition_penalty = 1.15

train_cfg = GRPOConfig(
    output_dir="qwen0.6b-gsm8k-grpo",
    per_device_train_batch_size=3,
    gradient_accumulation_steps=4,   
    
    num_generations=3,
    max_completion_length=256,
    #beta = 0.005, 
    
    learning_rate=5e-5,
    
    max_grad_norm = 0.2,                 #  开裁剪
    weight_decay = 0.1,
    warmup_ratio = 0.05,
    lr_scheduler_type = "cosine",

    num_train_epochs=1,
                         
    logging_steps=1,
    save_steps=1,
    save_total_limit=2,

    disable_tqdm=False,
    report_to=None,

    bf16=True,
    
)

In [10]:
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=reward_fn,
    train_dataset=gsm,
    args=train_cfg,
)

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [None]:
#debug_batch(trainer, num_batches=3)   # 随机抽 3 条测试

In [11]:
#resume_from_checkpoint="qwen0.6b-gsm8k-grpo/checkpoint-"
trainer.train()

`generation_config` default values have been modified to match model-specific defaults: {'top_k': 20, 'top_p': 0.9, 'repetition_penalty': 1.15, 'bos_token_id': 151643, 'eos_token_id': [151645, 151643]}. If this is not desired, please set these values explicitly.
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


Step,Training Loss
1,-0.0
2,0.0
3,0.0
4,0.0
5,0.0
6,0.0001
7,0.0
8,0.0001


  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


KeyboardInterrupt: 