安装指定版本的依赖库（transformers、accelerate、peft、trl、datasets 等），保证下文用到的 API 与版本一致

In [None]:
# Pin library versions so the APIs used below (e.g. the SFTTrainer keyword
# arguments) stay stable across Colab sessions.
!pip install -q transformers==4.41.2
!pip install -q accelerate==0.30.1
!pip install -q peft==0.11.1
!pip install -q trl==0.8.6
!pip install -q datasets==2.19.0
# Needed for 4-bit model loading. NOTE(review): left unpinned; pin it once a
# version compatible with the current Colab CUDA runtime is confirmed.
!pip install -q bitsandbytes
# fsspec/gcsfs pins presumably match Colab's preinstalled gcsfs.
# NOTE(review): datasets 2.19.0 appears to declare an upper bound on fsspec, so
# pip may print a dependency-conflict warning here — confirm it is harmless.
# Keep these as separate commands; one combined install would make pip resolve
# all pins together.
!pip install -q fsspec==2025.3.0
!pip install -q gcsfs==2025.3.0

训练数据（JSONL 文件）存放在 Google Drive 中，训练产物也会保存回 Drive，因此需要先在 Colab 中将 Drive 挂载到当前 notebook

In [None]:
# Mount Google Drive: the training data is read from it and training outputs are written back to it.
from google.colab import drive

DRIVE_MOUNT_POINT = "/content/drive"
drive.mount(DRIVE_MOUNT_POINT)

查看一下GPU信息

In [None]:
# Show the assigned GPU (model, driver/CUDA version, memory). The bf16 training
# setting below only works on GPUs that support bf16.
!nvidia-smi

导入需要的库

In [None]:
import gc

import torch
from datasets import load_dataset
from peft import AutoPeftModelForCausalLM, LoraConfig, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    pipeline,
)
from trl import SFTTrainer

加载数据，并按 9:1 划分训练集与验证集

In [None]:
dataset_path = "/content/drive/MyDrive/java_sft/data/java_interview.jsonl"
raw_dataset = load_dataset("json", data_files=dataset_path)

# format_example (used by the trainer) reads these three fields from every record;
# fail here with a clear message instead of deep inside the trainer's preprocessing.
required_columns = {"system", "input", "output"}
missing_columns = required_columns - set(raw_dataset["train"].column_names)
if missing_columns:
    raise ValueError(
        f"{dataset_path} is missing required fields: {sorted(missing_columns)}"
    )

# load_dataset puts all records in a single "train" split; hold out 10% for validation.
dataset = raw_dataset["train"].train_test_split(test_size=0.1, seed=42)
dataset


以 4-bit 量化方式加载模型（Qwen/Qwen2.5-3B-Instruct）

In [None]:
model_name = "Qwen/Qwen2.5-3B-Instruct"  # Instruct (chat) variant: better suited to chat-style SFT than the Base model

tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True
)

# Fallback in case the tokenizer defines no explicit pad_token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# GPUs with compute capability >= 8 (Ampere: A100, L4, ...) run bf16 natively;
# older cards such as the T4 should use fp16 as the 4-bit compute dtype.
compute_dtype = (
    torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
)

# Passing load_in_4bit directly to from_pretrained is deprecated. An explicit
# BitsAndBytesConfig also selects the QLoRA settings (NF4 + double quantization)
# and a half-precision compute dtype instead of the float32 default.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=compute_dtype,
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True
)
# Prepare the quantized model for (LoRA) training.
model = prepare_model_for_kbit_training(model)


构建 LoRA 微调配置

In [None]:
# Qwen2 / Qwen2.5 attention layers expose separate q_proj, k_proj, v_proj and
# o_proj linear modules. There is no fused "qkv_proj" module, so listing it
# would silently match nothing and only o_proj would receive LoRA adapters.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    # Add "gate_proj", "up_proj", "down_proj" to also adapt the MLP blocks.
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    bias="none",
    task_type="CAUSAL_LM"
)

构造 SFTTrainer（核心训练）

In [None]:
training_args = TrainingArguments(
    output_dir="/content/drive/MyDrive/java_sft/qwen-java-sft",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,   # 1 sample x 8 accumulation steps per optimizer update
    logging_steps=20,
    num_train_epochs=2,
    save_steps=200,
    # The trainer receives a held-out eval split; the default eval strategy is
    # "no", which would leave that split unused.
    eval_strategy="steps",
    eval_steps=200,
    per_device_eval_batch_size=1,    # keep eval memory in line with training
    learning_rate=2e-4,
    # bf16 requires an Ampere-or-newer GPU (A100 / L4); on a T4 use fp16=True instead.
    bf16=True,
)

def format_example(batch):
    """Render a batch of dataset rows into chat-formatted training strings.

    The trainer passes batches, so each field of ``batch`` is a list
    (e.g. ``batch["system"]`` holds one system message per example).
    Each (system, input, output) triple is rendered with the tokenizer's chat
    template as a complete system/user/assistant conversation.

    Returns:
        list[str]: one rendered conversation per example (a list is required).
    """
    triples = zip(batch["system"], batch["input"], batch["output"])
    return [
        tokenizer.apply_chat_template(
            [
                {"role": "system", "content": system_text},
                {"role": "user", "content": user_text},
                {"role": "assistant", "content": answer_text},
            ],
            tokenize=False,
            add_generation_prompt=False,  # the assistant turn is already included
        )
        for system_text, user_text, answer_text in triples
    ]


# Assemble the supervised fine-tuning trainer: quantized model + LoRA settings,
# with the text built on the fly by format_example.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    peft_config=lora_config,
    tokenizer=tokenizer,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    formatting_func=format_example,
    dataset_text_field=None,   # no raw text column; formatting_func produces the text
    max_seq_length=1024,
)


开始训练

In [None]:
# Run fine-tuning; the returned training summary is shown as the cell output.
train_output = trainer.train()
train_output

保存 LoRA 适配器（仅保存 adapter 权重与配置，不含基座模型权重）及 tokenizer

In [None]:
# 这里保存的是 LoRA + 基座的 PeftModel
trainer.model.save_pretrained("/content/drive/MyDrive/java_sft/qwen-java-sft")
tokenizer.save_pretrained("/content/drive/MyDrive/java_sft/qwen-java-sft")

推理测试：分别加载微调前的基座模型与微调后的 LoRA 模型

In [None]:
# Release the training-time model and trainer (including optimizer state) before
# loading two more model copies for inference; otherwise all of them compete for
# GPU memory. pop() keeps this cell safe if those names were never defined.
for _stale_name in ("trainer", "model"):
    globals().pop(_stale_name, None)
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()


def load_chat_tokenizer(path):
    """Load a tokenizer from a hub id or local directory, using EOS as pad token if none is set."""
    tok = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    return tok


# ---------------------------
# 1. Base model (before fine-tuning)
# ---------------------------
base_model_name = model_name  # same checkpoint that was fine-tuned above

base_tokenizer = load_chat_tokenizer(base_model_name)
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    torch_dtype="auto",
    device_map="auto",
    trust_remote_code=True
)

# ---------------------------
# 2. Fine-tuned model (base + LoRA adapter from SFT)
# ---------------------------
ft_model_path = "/content/drive/MyDrive/java_sft/qwen-java-sft"

ft_tokenizer = load_chat_tokenizer(ft_model_path)
# AutoPeftModelForCausalLM loads the base model named in the saved adapter config
# and attaches the LoRA weights from ft_model_path.
ft_model = AutoPeftModelForCausalLM.from_pretrained(
    ft_model_path,
    torch_dtype="auto",
    device_map="auto"
)

base_pipe = pipeline("text-generation", model=base_model, tokenizer=base_tokenizer)
ft_pipe = pipeline("text-generation", model=ft_model, tokenizer=ft_tokenizer)

定义答案生成与答案对比函数

In [None]:
def generate_answer(pipe, tokenizer, question, max_new_tokens=256,
                    system_prompt="你是一名 Java 面试官。"):
    """Generate a greedy (deterministic) answer to ``question`` with a chat pipeline.

    Args:
        pipe: a text-generation pipeline.
        tokenizer: the tokenizer matching ``pipe``'s model, used to render the chat template.
        question: the user question.
        max_new_tokens: maximum number of tokens to generate.
        system_prompt: system message placed before the question.

    Returns:
        str: only the newly generated answer text, with surrounding whitespace stripped.
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": question},
    ]

    # Render the conversation and append the assistant header so the model answers next.
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # return_full_text=False makes the pipeline return only the continuation. This avoids
    # re-tokenizing prompt+answer and slicing by prompt length, which relies on the
    # decode/encode round trip preserving token boundaries.
    outputs = pipe(
        prompt,
        max_new_tokens=max_new_tokens,
        do_sample=False,   # greedy decoding: stable answers for side-by-side comparison
        pad_token_id=tokenizer.eos_token_id,
        return_full_text=False,
    )
    return outputs[0]["generated_text"].strip()

def compare_answer(question, max_new_tokens=256):
    """Answer ``question`` with the base model, then with the fine-tuned model.

    Returns:
        tuple[str, str]: ``(base_answer, ft_answer)`` — before and after fine-tuning.
    """
    model_pairs = ((base_pipe, base_tokenizer), (ft_pipe, ft_tokenizer))
    answers = tuple(
        generate_answer(pipe, tok, question, max_new_tokens)
        for pipe, tok in model_pairs
    )
    return answers


验证及对比结果

In [None]:
question = "Redis的持久化策略有两种，分别是什么？以及两者的区别是什么？"
base_answer, ft_answer = compare_answer(question, max_new_tokens=1024)

# Print the answers before / after fine-tuning, each under its own header.
for header, answer in (
    ("========【微调前】========", base_answer),
    ("\n========【微调后】========", ft_answer),
):
    print(header)
    print(answer)