# Train Model with DPO

Code authored by: Shaw Talebi

### imports

In [1]:
from trl import DPOConfig, DPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import json
from datasets import Dataset
import pandas as pd
from modelscope import snapshot_download

Out (stderr, truncated warning emitted while importing tqdm):
  from .autonotebook import tqdm as notebook_tqdm


### load data

In [4]:
# Read a JSON Lines (.jsonl) file
def load_jsonl(file_path):
    """Load a JSON Lines file into a list of decoded records.

    Blank and whitespace-only lines (for example a trailing newline at the
    end of the file) are skipped instead of being handed to the JSON parser.

    Args:
        file_path: Path to a UTF-8 encoded file holding one JSON document
            per line.

    Returns:
        A list with one decoded JSON value per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON. The
            message is prefixed with the file path and 1-based line number.
    """
    records = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line_no, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                # Tolerate empty lines rather than failing on json.loads('').
                continue
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError as err:
                # Re-raise the same exception type, but say where it happened.
                raise json.JSONDecodeError(
                    f"{file_path}:{line_no}: {err.msg}", err.doc, err.pos
                ) from err
    return records

# Load the raw preference records
train_data = load_jsonl('train.jsonl')

# Convert each record into TRL's conversational preference format.
# When `prompt` is a list of chat messages, `chosen` and `rejected` must be
# lists of messages too: the chat template is applied to `prompt + chosen`,
# which fails if one side is a list and the other a plain string.
processed_data = []
for item in train_data:
    # Content of the first message is used as the user prompt
    prompt = str(item['messages'][0]['content'])
    # Preferred and dispreferred answers
    chosen = item['chosen']['content']
    rejected = item['rejected']['content']

    processed_data.append({
        'prompt': [{'role': 'user', 'content': prompt}],
        'chosen': [{'role': 'assistant', 'content': chosen}],
        'rejected': [{'role': 'assistant', 'content': rejected}],
    })

# Build the dataset
train_dataset = Dataset.from_pandas(pd.DataFrame(processed_data))

# Hold out 20% of the data for validation; the fixed seed keeps the split
# (and therefore the eval_loss used for model selection) reproducible.
dataset = train_dataset.train_test_split(test_size=0.2, seed=42)
dataset = {
    'train': dataset['train'],
    'valid': dataset['test']  # validation split
}

### load model

In [None]:
model_name = "Qwen/Qwen3-0.6B"
# Fetch the checkpoint from ModelScope into the current directory;
# snapshot_download returns the local directory holding the files.
model_dir = snapshot_download(model_name, cache_dir="./", revision="master")

# Load from the downloaded snapshot. Passing the hub ID here instead would
# ignore `model_dir` and fetch the weights a second time from another hub.
model = AutoModelForCausalLM.from_pretrained(model_dir)
tokenizer = AutoTokenizer.from_pretrained(model_dir)
if tokenizer.pad_token is None:
    # Only fall back to EOS for padding when the tokenizer defines no pad token
    tokenizer.pad_token = tokenizer.eos_token

In [None]:
# DPO training configuration: evaluate and checkpoint once per epoch so the
# checkpoint with the lowest eval loss can be restored when training ends.
# (`eval_steps` is omitted: it only applies when eval_strategy="steps".)
training_args = DPOConfig(
    output_dir='./dpo',
    logging_steps=25,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    num_train_epochs=3,
    eval_strategy="epoch",
    save_strategy="epoch",  # must match eval_strategy for load_best_model_at_end
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
)

# Prefer the GPU but fall back to CPU when CUDA is unavailable.
# NOTE(review): `device` is not referenced by the cells visible here.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Wire the policy model, tokenizer, config and both data splits into the
# DPO trainer, then start training.
trainer = DPOTrainer(
    train_dataset=dataset['train'],  # split used for optimisation
    eval_dataset=dataset['valid'],   # held-out split used for evaluation
    processing_class=tokenizer,
    args=training_args,
    model=model,
)

trainer.train()
