In [None]:
from datasets import load_dataset, Dataset
from trl import SFTConfig, SFTTrainer
from transformers import AutoTokenizer
from peft import LoraConfig

# Local checkpoint directory of the base model; used for both the tokenizer
# and the model handed to the trainer.
model_name = "/data/xxx/LLMs/Qwen/Qwen2.5-0.5B-Instruct"
# Directory given to the training configuration as its output location.
output_dir = "/data/xxx/tigerHandle_c4/chap42_lora_output"

import pandas as pd

In [None]:
# Load the tokenizer of the base model and the "train" split of the dataset
# stored at the local path below.
tokenizer = AutoTokenizer.from_pretrained(model_name)
dataset = load_dataset(
    path="/data/xxx/LLMs/xiaofengalg___chinese-medical-dialogue",
    split="train",
)

# Optional truncation for a quick trial run; keep it disabled when training
# for real so the full split is used.
# dataset = dataset[:1000]
# df_data = pd.DataFrame(dataset)
# dataset = Dataset.from_pandas(df_data)

In [None]:
def convert_to_chatml(samples):
    """Render dataset rows as Qwen-style ChatML training text.

    Each row is turned into a three-turn conversation: a system turn built
    from the row's ``instruction``, a user turn holding ``input`` and an
    assistant turn holding ``output``.

    Args:
        samples: A mapping with the keys ``'instruction'``, ``'input'`` and
            ``'output'``. Either a batch (each value is a sequence, one entry
            per row) or a single example (each value is a string).

    Returns:
        For a batch, a list of ChatML strings with one entry per element of
        ``samples['output']``. For a single example, one ChatML string.

    Raises:
        IndexError: For a batch whose ``'instruction'`` or ``'input'``
            sequence is shorter than ``'output'``.
    """
    def _render(instruction, user_content, assistant_content):
        # Double-quoted outer f-string: reusing the same quote character
        # inside a replacement field is a SyntaxError before Python 3.12.
        system_content = f"你是一名经验丰富的医生，请回答以下问题{instruction}，以下是病情的详细描述："
        return (
            f"<|im_start|>system\n{system_content}<|im_end|>\n"
            f"<|im_start|>user\n{user_content}<|im_end|>\n"
            f"<|im_start|>assistant\n{assistant_content}<|im_end|>\n"
        )

    outputs = samples['output']

    # Single (un-batched) example: fields are plain strings, so indexing them
    # would walk over characters. Return one formatted string instead.
    if isinstance(outputs, str):
        return _render(samples['instruction'], samples['input'], outputs)

    # Batched input: one formatted string per row, in row order.
    return [
        _render(samples['instruction'][i], samples['input'][i], outputs[i])
        for i in range(len(outputs))
    ]

# data_display = convert_to_chatml(dataset)

In [None]:
# LoRA adapter configuration for the causal-LM base model.
peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    # Entries must match sub-module names of the base model. Qwen2 names its
    # attention output projection "o_proj"; the previous "out_proj" entry
    # matched nothing, so that layer never received an adapter.
    # NOTE(review): "gate_proj" is not targeted although the other two MLP
    # projections are - confirm whether that omission is intentional.
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj"],
    # Modules trained in full and stored alongside the adapter. The embedding
    # module is named "embed_tokens"; the previous "embed_token" entry did not
    # match it.
    # NOTE(review): presumably this checkpoint ties input and output
    # embeddings - verify that training both modules separately is intended
    # and fits in memory.
    modules_to_save=["lm_head", "embed_tokens"],
    task_type="CAUSAL_LM",
)

In [None]:
# Training hyper-parameters for the supervised fine-tuning run.
training_args = SFTConfig(
    output_dir=output_dir,
    num_train_epochs=1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    # Maximum token length of a single sample (input + output) during
    # data preprocessing.
    max_seq_length=1024,
    bf16=True,
    # Number of steps between log records.
    logging_steps=1000,
    save_steps=20000,
)

In [None]:
# Assemble the SFT trainer: the base model is given by its path, rows are
# turned into text by convert_to_chatml, and LoRA is applied via peft_config.
trainer = SFTTrainer(
    model=model_name,
    args=training_args,
    train_dataset=dataset,
    processing_class=tokenizer,
    formatting_func=convert_to_chatml,
    peft_config=peft_config,
)

# Run fine-tuning.
trainer.train()