# I. Introduction

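## 1. What is DPO?
In RLHF, once a preference dataset has been collected, a reward model is trained first, and the LLM is then fine-tuned with the PPO algorithm. DPO removes both of these steps and optimizes the LLM directly on the preference dataset. The figure below compares the two: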
<p align=center>
    <img src="./imgs/dpo.png" width=1000>
</p>
<p align=center>
    <em>image: RLHF vs. DPO</em>
</p>
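To make "optimizing the LLM directly" more concrete, here is a minimal sketch of the per-pair DPO objective as given in the DPO paper, written in plain Python. The function name `dpo_loss`, its argument names, and the default `beta=0.1` are illustrative assumptions and do not come from this notebook. Each input stands for the summed log-probability of a response under either the policy being trained or a frozen reference model.

```python
import math

def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp, beta=0.1):
    """Per-pair DPO loss: -log(sigmoid(beta * (chosen log-ratio - rejected log-ratio)))."""
    # How much more (in log space) the policy likes each response than the reference does.
    chosen_logratio = policy_chosen_logp - ref_chosen_logp
    rejected_logratio = policy_rejected_logp - ref_rejected_logp
    margin = beta * (chosen_logratio - rejected_logratio)
    # -log(sigmoid(m)) == log(1 + exp(-m)); branch on the sign for numerical stability.
    if margin >= 0:
        return math.log1p(math.exp(-margin))
    return -margin + math.log1p(math.exp(margin))

# Policy identical to the reference: the margin is zero and the loss is log(2).
loss_neutral = dpo_loss(-10.0, -10.0, -10.0, -10.0)
# Policy prefers the chosen response more than the reference does: lower loss.
loss_good = dpo_loss(-5.0, -9.0, -7.0, -7.0)
# Policy prefers the rejected response: higher loss.
loss_bad = dpo_loss(-9.0, -5.0, -7.0, -7.0)
print(loss_neutral, loss_good, loss_bad)
```

No reward model appears anywhere in this computation: the preference pair and the two sets of log-probabilities are all that the loss needs.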

## 2. What is a preference dataset, and what data does DPO need?
Recall the RLHF pipeline:
<p align=center>
    <img src="./imgs/rewardmodel.jpg" width=800>
</p>
<p align=center>
    <em>image: Training the reward model</em>
</p>

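A preference dataset is therefore built by having the LLM that is to be fine-tuned generate 4 to 9 responses for a single prompt, and then having human annotators rank those responses. Each prompt together with its ranked responses forms one sample. These samples are then used to train a reward model. One way to organize a preference dataset is thus as a ranked sequence: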
```json
{
    prompt: "How are you?",
    response: [answer1, answer2, ...] // already sorted in preference order
}
```
Alternatively, a preference dataset can be organized pair-wise:
```json
{
    prompt: "How are you?",
    chosen: answer_j,   // answer_j is ranked ahead of answer_k
    reject: answer_k,
}
```
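The two layouts are related: a ranked sample can be expanded into pair-wise records by pairing every answer with each answer ranked below it. The sketch below illustrates this under the assumption that the `response` list is ordered best first; the helper name `ranked_to_pairs` and the sample values are hypothetical.

```python
from itertools import combinations

def ranked_to_pairs(sample):
    """Expand one ranked sample into pair-wise records (assumes best answer first)."""
    prompt = sample["prompt"]
    # combinations() preserves input order, so `better` always
    # precedes `worse` in the ranking.
    return [
        {"prompt": prompt, "chosen": better, "reject": worse}
        for better, worse in combinations(sample["response"], 2)
    ]

sample = {
    "prompt": "How are you?",
    "response": ["answer1", "answer2", "answer3"],  # already ranked
}
pairs = ranked_to_pairs(sample)
for pair in pairs:
    print(pair)
```

A ranking of n responses yields n(n-1)/2 pairs, so a prompt with 4 ranked responses gives 6 pairs and one with 9 gives 36.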
# II. Loading the dataset
The dataset is [stack-exchange-paired](https://huggingface.co/datasets/lvwerra/stack-exchange-paired?library=datasets).

In [1]:
import os
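# Get the current working directory
current_directory = os.getcwd()
print("Current working directory:", current_directory)

# Set the new working directory
new_directory = "/mnt/d/code/llm/llm-course"
os.chdir(new_directory)

# Get the working directory again to confirm that the change took effect
current_directory = os.getcwd()
print("New working directory:", current_directory)

Current working directory: /mnt/d/code/llm/llm-course/mynotes
New working directory: /mnt/d/code/llm/llm-course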


In [5]:
from datasets import load_dataset

dataset = load_dataset(
    "lvwerra/stack-exchange-paired",
    split="train",
    data_dir="data/rl"
)
original_columns = dataset.column_names

Resolving data files:   0%|          | 0/20 [00:00<?, ?it/s]

Downloading data:   0%|          | 0/20 [00:00<?, ?files/s]

Generating train split: 0 examples [00:00, ? examples/s]

Loading dataset shards:   0%|          | 0/42 [00:00<?, ?it/s]

Filter the dataset to select one specific prompt.

In [15]:
# print(type(dataset))
filtered_dataset = dataset.filter(lambda example: example['qid'] == 6763429)
print(len(filtered_dataset))

Filter:   0%|          | 0/7435908 [00:00<?, ? examples/s]

9


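Printing the filtered rows shows that they all share the same qid, so the question is identical in every row; the rows differ only in which pair of answers (response_j, response_k) they contain.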

In [17]:
row = 1

print('##############################################')
print(filtered_dataset[row]['qid'])

print('##############################################')
print(filtered_dataset[row]['question'])

print('##############################################')
print(filtered_dataset[row]['date'])

print('##############################################')
print(filtered_dataset[row]['metadata'])

print('##############################################')
print(filtered_dataset[row]['response_j'])

print('##############################################')
print(filtered_dataset[row]['response_k'])

##############################################
6763429
##############################################
i want to write a shape with " \* " and " | " the shape is below.
The program must take height and width from user.Width is column number without ' | '.I tried to write but confused.My code sometimes works great and sometimes being stupid.For example when i enter height : 13, width : 4 it writes one more,if witdh is 1 it enters infinite loop.While trying to solve it became too conflicted.Must i fix it or rewrite ? Here is the code : height =10, width = 5

```

|*____|    
|_*___|
|__*__|
|___*_|
|____*|
|___*_|
|__*__|
|_*___|
|*____|
|_*___|

```

```
      private static void Function()
      {
        int height, width;

        if (width == 2)
            while (height > 0)
            {
                FirstPart(width, height);
                height -= width;
            }
        else
            while (height > 0)
            {
                if (height > 1)
                {


Convert the dataset into the format expected for training.

In [18]:
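from typing import Dict, List

def return_prompt_and_responses(samples) -> Dict[str, List[str]]:
    # With batched=True, `samples` maps each column name to a list of values.
    return {
        "prompt": [
            "Question: " + question + "\n\nAnswer: "
            for question in samples["question"]
        ],
        "chosen": samples["response_j"], # rated better than k
        "rejected": samples["response_k"], # rated worse than j
    }

# map() returns a new Dataset and does not modify `dataset` in place,
# so keep a reference to the result.
dpo_dataset = dataset.map(
    return_prompt_and_responses,
    batched=True,
    remove_columns=original_columns
)
dpo_dataset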

Map:   0%|          | 0/7435908 [00:00<?, ? examples/s]

Dataset({
    features: ['prompt', 'chosen', 'rejected'],
    num_rows: 7435908
})
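With `batched=True`, the mapping function receives a dict-like batch that maps each column name to a list of values, and it must return a dict of lists of equal length. The sketch below re-declares the function and applies it to a small hand-written batch, so the conversion can be inspected without downloading the dataset. The two rows in `toy_batch` are made-up placeholders and are not rows from stack-exchange-paired.

```python
from typing import Dict, List

def return_prompt_and_responses(samples: Dict[str, List[str]]) -> Dict[str, List[str]]:
    return {
        "prompt": [
            "Question: " + question + "\n\nAnswer: "
            for question in samples["question"]
        ],
        "chosen": samples["response_j"],    # rated better than k
        "rejected": samples["response_k"],  # rated worse than j
    }

# A plain dict standing in for one batch (hypothetical example rows).
toy_batch = {
    "question": ["What is 2 + 2?", "Name a prime number."],
    "response_j": ["4", "7"],
    "response_k": ["5", "9"],
}

converted = return_prompt_and_responses(toy_batch)
for prompt, chosen, rejected in zip(
    converted["prompt"], converted["chosen"], converted["rejected"]
):
    print(repr(prompt), "| chosen:", chosen, "| rejected:", rejected)
```

Each output row keeps the question wrapped in the `Question: ... Answer: ` template, with the preferred answer under `chosen` and the other under `rejected`.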

# III. SFT on 7B Llama v2
- QLoRA is used here.

In [19]:
!python3 -m lib.stack_llama_2.sft_llama2

/usr/bin/python3: Error while finding module specification for 'lib.stack_llama_2.sft_llama2' (ModuleNotFoundError: No module named 'lib')


# IV. DPO on 7B Llama v2

In [None]:
!python3 -m lib.stack_llama_2.dpo_llama2