# Fine-Tuning a Model with the DPO Algorithm

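This tutorial shows how to fine-tune a large language model with the DPO algorithm, using Llama-3.1-8B-Instruct as the example. You will learn how to configure training and use DPO to train the model on preference-labeled data, pursuing the same goal as reinforcement learning from human feedback (RLHF) but without running reinforcement learning, to improve the model's performance on alignment tasks.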

## 1. What Is the DPO Algorithm?

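DPO (Direct Preference Optimization) is a method for training language models to better align with human preferences. Instead of relying on an explicit reward model or policy-gradient methods, it optimizes the model directly on human preference data, so that when given two responses to the same prompt, the model becomes more likely to produce the one humans prefer.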
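Concretely, the DPO paper minimizes the following loss over preference triples $(x, y_w, y_l)$, where $y_w$ is the preferred response, $y_l$ the rejected one, $\pi_\theta$ the model being trained, $\pi_{\text{ref}}$ a frozen reference model (typically the starting checkpoint), $\sigma$ the logistic sigmoid, and $\beta$ a coefficient controlling how far $\pi_\theta$ may drift from $\pi_{\text{ref}}$:

$$
\mathcal{L}_{\text{DPO}}(\pi_\theta;\pi_{\text{ref}}) = -\mathbb{E}_{(x, y_w, y_l)\sim\mathcal{D}}\left[\log\sigma\left(\beta\log\frac{\pi_\theta(y_w\mid x)}{\pi_{\text{ref}}(y_w\mid x)} - \beta\log\frac{\pi_\theta(y_l\mid x)}{\pi_{\text{ref}}(y_l\mid x)}\right)\right]
$$

The snippet below is a minimal, dependency-free sketch of that loss. It assumes the summed per-sequence log-probabilities $\log\pi(y\mid x)$ have already been computed for both models; the function names and the default `beta=0.1` are illustrative choices, not align-anything's API or defaults.

```python
import math


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def dpo_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta=0.1):
    """Mean DPO loss over a batch of preference pairs.

    Each argument is a list of summed log-probabilities log pi(y | x),
    one entry per pair. `beta=0.1` is an illustrative default.
    """
    losses = []
    for pc, pr, rc, rr in zip(policy_chosen, policy_rejected, ref_chosen, ref_rejected):
        # Policy-vs-reference log-ratios for the preferred and rejected responses
        chosen_logratio = pc - rc
        rejected_logratio = pr - rr
        logits = beta * (chosen_logratio - rejected_logratio)
        # -log(sigmoid(z)) == softplus(-z), which avoids overflow for large |z|
        losses.append(softplus(-logits))
    return sum(losses) / len(losses)


if __name__ == "__main__":
    # Policy identical to the reference: the loss equals log(2) ≈ 0.693
    print(dpo_loss([-12.0], [-15.0], [-12.0], [-15.0]))
    # Policy makes the chosen response relatively more likely: the loss drops below log(2)
    print(dpo_loss([-10.0], [-15.0], [-12.0], [-15.0]))
```

Minimizing this loss pushes the policy to raise the likelihood of preferred responses relative to the reference model more than it raises that of rejected ones, which is why no separate reward model is needed.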

## 2. Environment Setup

Before you begin, make sure the `align-anything` package is installed.

```bash
# Clone the repository
git clone git@github.com:PKU-Alignment/align-anything.git
cd align-anything

# Create a virtual environment with conda
conda create -n align-anything python==3.11
conda activate align-anything
```

- **`[Optional]`** We recommend installing [CUDA](https://anaconda.org/nvidia/cuda) in the conda environment and setting the `CUDA_HOME` environment variable.

```bash
# We tested this CUDA version on an H800 cluster, where it works well.
# Adjust the version to match your own cluster.

conda install nvidia/label/cuda-12.2.0::cuda
export CUDA_HOME=$CONDA_PREFIX
```

> If CUDA is installed elsewhere (for example, `nvcc` is at `/usr/local/cuda/bin/nvcc`), point `CUDA_HOME` at the installation root:

```bash
export CUDA_HOME="/usr/local/cuda"
```
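Before building anything that needs the CUDA toolkit, it can help to confirm that `CUDA_HOME` points at an installation root that actually contains `bin/nvcc`. The helper below, `find_nvcc`, is a hypothetical convenience (not part of align-anything) that only inspects the filesystem and assumes a Linux-style layout:

```python
import os
from pathlib import Path


def find_nvcc(cuda_home=None):
    """Return the path to nvcc under CUDA_HOME, or None if it is missing.

    `find_nvcc` is a hypothetical helper; it only checks the filesystem.
    """
    root = cuda_home or os.environ.get("CUDA_HOME")
    if not root:
        return None
    nvcc = Path(root) / "bin" / "nvcc"
    return nvcc if nvcc.is_file() else None


if __name__ == "__main__":
    nvcc = find_nvcc()
    if nvcc is None:
        print("CUDA_HOME is unset or does not contain bin/nvcc")
    else:
        print(f"Found nvcc at {nvcc}")
```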

Finally, install `align-anything` with the following commands:

```bash
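# We provide quick installs for training and evaluation.
# If you only need the training or evaluation module,
# install just the corresponding dependencies.
# (The quotes keep shells such as zsh from expanding the brackets.)
pip install -e ".[train]" # install training dependencies
pip install -e ".[evaluate]" # install evaluation dependencies

# To install all dependencies, use:
pip install -e ".[all]"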
```

## 3. Sample Output from Llama-3.1-8B-Instruct
First, let's test the zero-shot capability of Llama-3.1-8B-Instruct.
### 3.1 Import the Required Libraries

```python
import os

# Force offline mode so only local files are used. Set these before importing
# transformers, since the Hugging Face libraries may read them at import time.
os.environ["TRANSFORMERS_OFFLINE"] = "1"
os.environ["HF_DATASETS_OFFLINE"] = "1"

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
```


### 3.2 Load the Original Llama Model

```python
device = "cuda"  # run on the GPU
model_path = "/PATH/TO/YOUR/Llama-3.1-8B-Instruct"  # replace with your actual model path
model = AutoModelForCausalLM.from_pretrained(model_path, local_files_only=True).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)

# Switch the model to evaluation mode (disables dropout)
model.eval()
```

```text
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((4096,), eps=1e-05)
    ...
```

### 3.3 Test the Original Model

Let's test Llama-3.1-8B-Instruct with a sample question.

```python
messages = [
    {"role": "system", "content": "You are a helpful assistant that answers user queries."},
    {
        "role": "user",
        "content": "Recently, a wild animal in the local area has become aggressive towards humans and caused several injuries. How should I handle this wild animal?",
    },
]

# Render the conversation with the chat template. Llama 3.1's template already
# inserts the BOS token, so don't let the tokenizer add a second one.
input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer([input_text], return_tensors="pt", add_special_tokens=False).to(device)

# Generate new tokens without tracking gradients
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=2048)
# Decode only the newly generated tokens, skipping the prompt
generated_text = tokenizer.decode(
    output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
)
print("\nGenerated Text:", generated_text)
```




```text
Generated Text: If a wild animal in your local area has become aggressive and caused injuries, it's essential to take precautions and follow the right steps to ensure your safety and the safety of others. Here's a step-by-step guide:

1.  **Stay calm**: Keep a safe distance from the animal and avoid direct confrontation. Panicking can escalate the situation, and you don't want to provoke the animal further.

2.  **Identify the animal**: If possible, try to determine the type of animal and its size, as this information will be helpful for wildlife experts or local authorities.

3.  **Contact local authorities**: Reach out to local animal control, wildlife services, or a professional wildlife removal service. They will send trained experts to handle the situation.

4.  **Keep children and pets indoors**: Ensure that children and pets are safely indoors, away from the area where the animal is present.

5.  **Do not approach or feed the animal**: Feeding or approaching the animal can make ...
```

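As the output shows, Llama 3.1's answer is detailed, but it is verbose and does not sufficiently emphasize the key risks.

For example, its second step, "Identify the animal," does not explicitly remind the user to stay out of the danger zone while doing so. This could lead someone to approach the animal for a closer look, increasing the risk of injury and hindering a safe response in an emergency.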

## 4. Aligning the Model with DPO

**Note**: If you cannot access huggingface.co, set the Hugging Face endpoint to hf-mirror.com:

```bash
export HF_ENDPOINT="https://hf-mirror.com"
```
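If you work in a notebook rather than a shell, the same setting can be applied from Python. This is a minimal sketch, assuming `huggingface_hub` reads `HF_ENDPOINT` from the environment when it is first imported, so the variable must be set before `transformers`, `datasets`, or `huggingface_hub` is imported in that process:

```python
import os

# Must run before transformers / datasets / huggingface_hub are imported,
# because the endpoint is read from the environment at import time.
# setdefault keeps any endpoint you have already configured.
os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")

print(os.environ["HF_ENDPOINT"])
```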

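Here we use the PKU-SafeRLHF family of datasets as an example. PKU-SafeRLHF is a safety-focused preference dataset: each example contains two responses to the same question, together with safety meta-labels and preference annotations for both responses.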
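To make the data format concrete, the sketch below turns one record into the (prompt, chosen, rejected) triple that DPO consumes. The field names `prompt`, `response_0`, `response_1`, `better_response_id`, and `safer_response_id` are assumptions modeled on the PKU-SafeRLHF schema; check the dataset card of the split you actually use. In the training script below, this mapping is presumably handled inside align-anything by the `TRAIN_TEMPLATE="PKUSafeRLHF"` setting, so `to_preference_pair` is a hypothetical illustration only.

```python
from typing import Dict, Tuple


def to_preference_pair(record: Dict, preference_key: str = "better_response_id") -> Tuple[str, str, str]:
    """Map one PKU-SafeRLHF-style record to (prompt, chosen, rejected).

    Field names are assumptions based on the PKU-SafeRLHF schema.
    """
    better = int(record[preference_key])
    if better not in (0, 1):
        raise ValueError(f"unexpected {preference_key}: {better}")
    chosen = record[f"response_{better}"]
    rejected = record[f"response_{1 - better}"]
    return record["prompt"], chosen, rejected


if __name__ == "__main__":
    # A made-up record for illustration only, not taken from the dataset
    record = {
        "prompt": "How should I deal with an aggressive wild animal nearby?",
        "response_0": "Try to corner it and chase it away yourself.",
        "response_1": "Keep your distance and contact local animal control.",
        "better_response_id": 1,
    }
    prompt, chosen, rejected = to_preference_pair(record)
    print("chosen:", chosen)
    print("rejected:", rejected)
```

Passing `preference_key="safer_response_id"` would instead rank the two responses by the safety annotation, which is the kind of choice a safety-alignment run has to make explicitly.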

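You can use the following training script as a reference. Run it from the directory that contains `setup.sh`, since the script sources that file with a relative path: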

```bash
MODEL_NAME_OR_PATH="meta-llama/Llama-3.1-8B-Instruct" # model path

TRAIN_DATASETS="PKU-Alignment/PKU-SafeRLHF-single-dimension" # dataset path
TRAIN_TEMPLATE="PKUSafeRLHF" # dataset template
TRAIN_SPLIT="train" # dataset split to train on

OUTPUT_DIR="../outputs/llama_dpo" # output dir

# For wandb online logging
export WANDB_API_KEY="YOUR_API_KEY"

# Source the setup script (expected to define MASTER_PORT, which is used below)
source ./setup.sh

# Execute deepspeed command
deepspeed \
     --master_port ${MASTER_PORT} \
     --module align_anything.trainers.text_to_text.dpo \
     --model_name_or_path ${MODEL_NAME_OR_PATH} \
     --train_template ${TRAIN_TEMPLATE} \
     --train_datasets ${TRAIN_DATASETS} \
     --train_split ${TRAIN_SPLIT} \
     --output_dir ${OUTPUT_DIR}

```
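Because the script logs to Weights & Biases, it helps to know what DPO training metrics typically mean. Under DPO, each response $y$ receives an implicit reward $\beta\log\frac{\pi_\theta(y\mid x)}{\pi_{\text{ref}}(y\mid x)}$, where $\pi_\theta$ is the model being trained and $\pi_{\text{ref}}$ the frozen reference model. DPO trainers commonly report the mean chosen and rejected rewards, their margin, and the fraction of pairs where the chosen response wins. The sketch below computes these from summed log-probabilities; whether align-anything logs exactly these quantities is an assumption, the metric names are hypothetical, and `beta=0.1` is illustrative.

```python
from statistics import mean


def dpo_reward_stats(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta=0.1):
    """Implicit-reward statistics for a batch of preference pairs.

    Inputs are lists of summed log-probabilities, one entry per pair.
    Metric names are hypothetical, not align-anything's logging keys.
    """
    chosen_rewards = [beta * (p - r) for p, r in zip(policy_chosen, ref_chosen)]
    rejected_rewards = [beta * (p - r) for p, r in zip(policy_rejected, ref_rejected)]
    margins = [c - r for c, r in zip(chosen_rewards, rejected_rewards)]
    return {
        "reward/chosen": mean(chosen_rewards),
        "reward/rejected": mean(rejected_rewards),
        "reward/margin": mean(margins),
        # Fraction of pairs where the preferred response earns the higher reward
        "reward/accuracy": mean(1.0 if m > 0 else 0.0 for m in margins),
    }


if __name__ == "__main__":
    stats = dpo_reward_stats(
        policy_chosen=[-10.0, -20.0],
        policy_rejected=[-15.0, -18.0],
        ref_chosen=[-12.0, -19.0],
        ref_rejected=[-15.0, -19.0],
    )
    for name, value in stats.items():
        print(f"{name}: {value:.3f}")
```

A margin and accuracy that rise over training suggest the policy is separating chosen from rejected responses relative to the reference; an accuracy stuck near 0.5 means it is not yet distinguishing them.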

When training finishes, the trained model weights are saved under `OUTPUT_DIR`.

## 5. Evaluating the DPO-Trained Model

After training, we check whether the model's alignment has improved.

### 5.1 Load the New Model Weights


```python
model_path = "/PATH/TO/YOUR/TRAINED_MODEL"  # replace with the path to your trained model (saved under OUTPUT_DIR)
model = AutoModelForCausalLM.from_pretrained(model_path, local_files_only=True).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)

# Switch the model to evaluation mode (disables dropout)
model.eval()
```

```text
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128257, 4096, padding_idx=128256)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    ...
```

Compared with the original model, the embedding layer now has 128257 rows instead of 128256, with `padding_idx=128256`: one extra token, used for padding, was added to the vocabulary during training.

### 5.2 Test the New Model

```python
messages = [
    {"role": "system", "content": "You are a helpful assistant that answers user queries."},
    {
        "role": "user",
        "content": "Recently, a wild animal in the local area has become aggressive towards humans and caused several injuries. How should I handle this wild animal?",
    },
]

# Render the conversation with the chat template. Llama 3.1's template already
# inserts the BOS token, so don't let the tokenizer add a second one.
input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer([input_text], return_tensors="pt", add_special_tokens=False).to(device)

# Generate new tokens without tracking gradients
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=2048)
# Decode only the newly generated tokens, skipping the prompt
generated_text = tokenizer.decode(
    output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
)
print("\nGenerated Text:", generated_text)
```


```text
Generated Text: If a wild animal has become aggressive and is causing injuries, it's essential to exercise extreme caution.  First, make sure all people and pets are kept away from the area.  Next, contact a professional wildlife removal expert or local animal control service to safely capture and relocate the animal.  In the meantime, try to determine what may have caused the animal's behavior change, such as habitat loss or food availability, and take steps to mitigate those factors. Finally, educate people in the area on how to peacefully coexist with the animal and what precautions should be taken when interacting with it.
```


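On this prompt, the trained model's answer is more concise and focuses on the key safety measures: keeping people away, contacting professionals, investigating the cause, and educating the public. It reflects a people-first, risk-prevention approach that minimizes direct contact with the animal, which better matches the principles of safety alignment.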

## 6. Acknowledgments

- [Hugging Face Transformers documentation](https://huggingface.co/docs/transformers/index)
- [DPO paper](https://arxiv.org/abs/2305.18290)