In [None]:
# 安装依赖（如环境已具备可跳过）
%pip install -q -U transformers datasets peft bitsandbytes accelerate


In [None]:
import datetime
import glob
import json
import math
import os

import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from peft.utils import TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, set_seed

# Base configuration: model checkpoint, input/output directories, max sequence length, RNG seed.
checkpoint_id = "Qwen/Qwen2.5-7B-Instruct"
artifacts_dir = "./checkpoints"
data_dir = "./data"
max_seq_len = 2048
seed_num = 42

# Seed every RNG (python, numpy, torch incl. CUDA) now, before the model and LoRA adapters
# are created, so adapter initialisation and dropout are reproducible from seed_num.
set_seed(seed_num)

os.makedirs(artifacts_dir, exist_ok=True)

# Pick the most recently modified training file matching wukong_dataset_*.jsonl.
jsonl_files = sorted(
    glob.glob(os.path.join(data_dir, "wukong_dataset_*.jsonl")),
    key=os.path.getmtime,
    reverse=True,
)
if not jsonl_files:
    # Fail with an actionable message instead of an IndexError on jsonl_files[0].
    raise FileNotFoundError(
        f"no training file matching 'wukong_dataset_*.jsonl' found in {data_dir!r}"
    )

train_jsonl = jsonl_files[0]
print(f"using dataset: {train_jsonl}")

# Load the JSONL file as a single "train" split.
train_set = load_dataset("json", data_files=train_jsonl, split="train")
train_set


In [None]:
# Load the Qwen2.5 tokenizer; its chat template is used to build prompts further down.
tokenizer = AutoTokenizer.from_pretrained(
    checkpoint_id,
    trust_remote_code=True,
)
# Padding needs a concrete id: reuse EOS only when the checkpoint ships without a pad token.
missing_pad = tokenizer.pad_token_id is None
if missing_pad:
    tokenizer.pad_token = tokenizer.eos_token
# Cell output: the resulting special-token setup.
(tokenizer.pad_token, tokenizer.eos_token, tokenizer.pad_token_id, tokenizer.eos_token_id)


In [None]:
# QLoRA: load the base model in 4-bit NF4 with double quantization; matmuls run in compute_dtype.
# bitsandbytes 4-bit loading and device_map="cuda:0" both require a CUDA GPU, so fail early
# with a clear message rather than deep inside from_pretrained.
if not torch.cuda.is_available():
    raise RuntimeError("QLoRA 4-bit training requires a CUDA GPU (device_map='cuda:0').")

# bf16 where the GPU supports it, otherwise fp16.
compute_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
bnb_cfg = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=compute_dtype,
)

base_model = AutoModelForCausalLM.from_pretrained(
    checkpoint_id,
    trust_remote_code=True,
    quantization_config=bnb_cfg,
    device_map="cuda:0",
)
# The KV cache is not needed for training and is incompatible with gradient checkpointing.
base_model.config.use_cache = False
# Prepare the quantized model for training (makes inputs require grad so gradients can flow
# back into the adapters). With use_gradient_checkpointing=True this call also enables
# gradient checkpointing, so a separate gradient_checkpointing_enable() call is redundant.
base_model = prepare_model_for_kbit_training(base_model, use_gradient_checkpointing=True)


In [None]:
# LoRA adapter setup; target modules come from PEFT's built-in default mapping for "qwen2".
qwen2_targets = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING["qwen2"]
lora_cfg = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=qwen2_targets,
)
peft_model = get_peft_model(base_model, lora_cfg)
# Make sure inputs require grad before training (pairs with the k-bit preparation step).
peft_model.enable_input_require_grads()
peft_model.config.use_cache = False
peft_model.print_trainable_parameters()


In [None]:
# 构造监督样本：使用 Qwen 对话模板，并仅对 assistant 段落计算 loss
from datasets import Dataset

def format_sample_for_qwen(record):
    """Turn one raw record into a causal-LM SFT example using the Qwen chat template.

    The prompt (system + user turns plus the assistant start marker) is masked with
    -100, so the loss is computed only on the assistant reply.

    Args:
        record: mapping with string fields "instruction" (user question) and
            "output" (reference answer); missing/None values are treated as "".

    Returns:
        dict with "input_ids" and "labels": equal-length int lists of at most
        ``max_seq_len`` tokens. Both lists are empty when the record should be
        dropped, i.e. when the instruction or answer is empty, or when the prompt
        alone fills ``max_seq_len`` so that no answer token would be supervised.
    """
    instr = (record.get("instruction") or "").strip()
    ans = (record.get("output") or "").strip()
    if not instr or not ans:
        return {"input_ids": [], "labels": []}

    msgs_no_assist = [
        {"role": "system", "content": "你是《黑神话：悟空》领域助手，回答准确、简明。"},
        {"role": "user", "content": instr},
    ]
    # Prompt ids, ending with the assistant start marker. return_dict=True pins the return
    # type (a mapping holding "input_ids") across transformers versions, whose default for
    # tokenize=True differs between a plain id list and a BatchEncoding.
    prompt_ids = list(
        tokenizer.apply_chat_template(
            msgs_no_assist,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors=None,
            return_dict=True,
        )["input_ids"]
    )

    msgs_full = msgs_no_assist + [{"role": "assistant", "content": ans}]
    full_ids = list(
        tokenizer.apply_chat_template(
            msgs_full,
            tokenize=True,
            add_generation_prompt=False,
            return_tensors=None,
            return_dict=True,
        )["input_ids"]
    )

    # Truncate to the training window.
    full_ids = full_ids[:max_seq_len]
    # Boundary between prompt and answer; assumes the rendered prompt is a token prefix of
    # the full conversation, which holds for the Qwen template used here.
    cut = min(len(prompt_ids), len(full_ids))
    if cut >= len(full_ids):
        # Truncation removed the whole answer: every label would be -100. Such a sample
        # teaches nothing and, depending on how the loss is reduced, can yield a NaN loss.
        return {"input_ids": [], "labels": []}
    labels = [-100] * cut + full_ids[cut:]

    return {"input_ids": full_ids, "labels": labels}

def _has_supervised_tokens(example):
    """Return True when at least one label is a real target (not the -100 ignore index).

    Empty examples (no tokens at all) and fully masked ones (e.g. the prompt filled the
    whole window) are both rejected, since neither contributes anything to the loss.
    """
    return any(tok != -100 for tok in example["labels"])


proc_train = train_set.map(format_sample_for_qwen, remove_columns=train_set.column_names)
proc_train = proc_train.filter(_has_supervised_tokens)
proc_train


In [None]:
# 数据整理器（按批次 padding，保持 -100 标签）
from typing import List, Dict

class QwenSftCollator:
    """Pad pre-tokenized SFT examples into a right-padded batch of long tensors.

    Each feature must provide equal-length "input_ids" and "labels" int lists.
    Sequences are cut to ``max_length`` and right-padded to the longest sequence in
    the batch. ``input_ids`` is padded with ``pad_id`` and ``labels`` with ``ignore_id``,
    so padding never contributes to the loss. ``attention_mask`` is 1 for real tokens
    and 0 for padding.
    """

    def __init__(self, pad_id: int, max_length: int = 2048, ignore_id: int = -100):
        self.pad_id = pad_id
        self.max_length = max_length
        self.ignore_id = ignore_id

    def __call__(self, features: List[Dict]):
        """Collate ``features`` into {"input_ids", "labels", "attention_mask"} tensors.

        Each tensor has shape (batch, seq_len) and dtype torch.long.
        """
        max_len = min(max(len(f["input_ids"]) for f in features), self.max_length)
        input_ids, labels, attention_mask = [], [], []
        for f in features:
            ids = list(f["input_ids"][:max_len])
            lbs = list(f["labels"][:max_len])
            pad = max_len - len(ids)
            input_ids.append(ids + [self.pad_id] * pad)
            labels.append(lbs + [self.ignore_id] * pad)
            # Explicit mask: the model should not have to assume every position is real.
            attention_mask.append([1] * len(ids) + [0] * pad)
        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
        }

# Batch collator for the Trainer: pads with the tokenizer's pad id, caps length at max_seq_len.
collator = QwenSftCollator(
    max_length=max_seq_len,
    pad_id=tokenizer.pad_token_id,
)

In [None]:
# 训练参数与 Trainer
from transformers import TrainingArguments, Trainer

# Timestamped run directory so repeated runs do not overwrite each other's checkpoints.
now_tag = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
run_dir = os.path.join(artifacts_dir, f"qwen25_wukong_lora_{now_tag}")

# Mixed precision follows the compute dtype chosen for the 4-bit model (bf16 if supported, else fp16).
use_bf16 = compute_dtype == torch.bfloat16

args = TrainingArguments(
    output_dir=run_dir,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,  # effective batch of 8 sequences per optimizer step
    # NOTE(review): 1e-3 is aggressive for LoRA on a 7B model (1e-4..2e-4 is more usual) — confirm.
    learning_rate=1e-3,
    num_train_epochs=4,
    lr_scheduler_type="linear",
    warmup_ratio=0.03,
    logging_steps=1,
    save_steps=100,
    save_total_limit=2,
    optim="adamw_torch",
    bf16=use_bf16,
    fp16=not use_bf16,
    # Tie the Trainer's RNG seeding to the config constant instead of its implicit default.
    seed=seed_num,
    report_to=[],
)

trainer = Trainer(
    model=peft_model,
    args=args,
    train_dataset=proc_train,
    data_collator=collator,
)
run_dir


In [None]:
# Train, report the summary, then write the LoRA adapter weights and the tokenizer files
# into the run directory.
train_output = trainer.train()
print(f"{train_output}")

peft_model.save_pretrained(save_directory=run_dir)
tokenizer.save_pretrained(save_directory=run_dir)


In [None]:
# Post-training smoke test: two questions modelled on the dataset.
TEST_QUERIES = [
    "我该怎么成为天命人？",
    "如何获得并合成出云棍？",
]
# Switch to inference mode (dropout off) before generating.
peft_model.eval()

@torch.no_grad()
def infer_one(question: str) -> str:
    """Generate one sampled answer to ``question`` with the fine-tuned adapter.

    Uses the same system + user chat prompt as training, samples up to 256 new tokens,
    and returns only the newly generated text (prompt removed, special tokens skipped,
    surrounding whitespace stripped).
    """
    msgs = [
        {"role": "system", "content": "你是《黑神话：悟空》领域助手，回答准确、简明。"},
        {"role": "user", "content": question},
    ]
    # return_dict=True returns input_ids *and* attention_mask with a stable type across
    # transformers versions (newer ones return a BatchEncoding rather than a bare tensor).
    encoded = tokenizer.apply_chat_template(
        msgs,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    )
    input_ids = encoded["input_ids"].to(peft_model.device)
    attention_mask = encoded["attention_mask"].to(peft_model.device)
    gen_ids = peft_model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,  # explicit, so generate() need not infer it from pad ids
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.2,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        # config.use_cache was switched off for training; re-enable the KV cache for decoding.
        use_cache=True,
    )
    # Drop the prompt tokens; keep only the continuation of the first (only) sequence.
    out_ids = gen_ids[0, input_ids.shape[-1]:]
    return tokenizer.decode(out_ids, skip_special_tokens=True).strip()

# Print each smoke-test question with its generated answer, separated by a horizontal rule.
separator = "-" * 60
for question in TEST_QUERIES:
    answer = infer_one(question)
    print(f"Q: {question}\nA: {answer}\n{separator}")
