Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add kto Trainer #8417

Open
wants to merge 11 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
154 changes: 154 additions & 0 deletions examples/KTO/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
## KTO

ContextualAI最近提出了一种有趣的替代方案,称为Kahneman-Tversky优化(KTO),它完全根据被标记为“好”或“坏”的单个例子来定义损失函数(例如👍 或👎 在聊天UI中看到的图标),不需要像DPO那样需要成对的偏好数据。这些标签在实践中更容易获得,KTO是一种很有前途的方法,可以不断更新生产环境中运行的模型。以下是KTO模型的Paddle实现。

## 快速开始

项目的整体组织结构如下:

```
├── kto_config.py # KTO训练参数
├── kto.py # 训练主函数
├── kto_trainer.py # KTO算法的trainer实现
├── model_base.py # 模型需要使用的基函数
├── model_config.py # 模型的配置参数
├── README.md # README文档
└── utils.py # 工具类
```

### 环境准备

- Python >= 3.9.16
- PaddlePaddle >= 2.6.0
- PaddleNLP >= 2.8.0

### 数据准备

数据集使用的是`trl-lib/kto-mix-14k`,可以直接使用datasets库进行加载:

```
>>> from datasets import load_dataset
>>> train_ds = load_dataset("trl-lib/kto-mix-14k")
>>> train_ds
DatasetDict({
train: Dataset({
features: ['prompt', 'completion', 'label'],
num_rows: 13500
})
test: Dataset({
features: ['prompt', 'completion', 'label'],
num_rows: 1500
})
})
>>> train_ds['train'][0]
{'prompt': [{'content': 'Q:Question: how old julio cesar chavez when he fought de la hoya I found the following answer on Google: He holds records for most successful consecutive defenses of world titles (27), most title fights (37), most title-fight victories (31) and he is after Joe Louis with (23) for most title defenses won by knockout (21). Is that a correct answer? Yes or no.\nA:', 'role': 'user'}], 'completion': [{'content': " Yes, the information you found on Google is correct. Julio César Chávez holds several records related to world title defenses and victories, and he is considered one of the greatest boxers in history. Here is a detailed answer to your question:\n\nJulio César Chávez was born on July 12, 1962, in Ciudad Obregón, Sonora, Mexico. He began boxing at a young age and quickly made a name for himself in the sport, winning his first world title in 1984 when he defeated Mario Miranda for the WBC super featherweight championship.\n\nOver the course of his career, Chávez would go on to hold titles in three different weight classes (super featherweight, lightweight, and junior welterweight) and defend his titles a record 27 times, including 21 times by knockout. This record for most successful consecutive defenses of world titles has never been broken, and it is a testament to Chávez's skill, determination, and durability as a fighter.\n\nIn addition to his record for most successful title defenses, Chávez also holds the records for most title fights (37) and most title-fight victories (31). These records are also unbroken and demonstrate Chávez's consistent success and dominance in the ring.\n\nChávez's impressive record and achievements have earned him a place among the greatest boxers of all time, and he is often compared to other boxing legends such as Joe Louis, who holds the record for most title defenses won by knockout (23). While Chávez did not quite match Louis's record for most title defenses won by knockout, he is still widely regarded as one of the greatest boxers in history and a true legend of the sport.\n\nIn conclusion, the information you found on Google is correct: Julio César Chávez holds several records related to world title defenses and victories, and he is considered one of the greatest boxers in history. His impressive record and achievements have earned him a place among the greatest boxers of all time, and he will always be remembered as a true legend of the sport.", 'role': 'assistant'}], 'label': True}
```
### 训练

lora单卡训练:

```
python kto.py \
--model_name_or_path=Llama-2-7b-chat-hf \
--per_device_train_batch_size 8 \
--num_train_epochs 1 \
--learning_rate 1e-4 \
--lr_scheduler_type cosine \
--gradient_accumulation_steps 1 \
--logging_steps 10 \
--eval_steps 200 \
--output_dir=kto-aligned-model-lora \
--warmup_ratio 0.1 \
--report_to wandb \
--logging_first_step \
--use_peft \
--data_seed 16 \
--lora_r=16 \
--lora_alpha=16 \
--bf16 \
--do_eval \
--evaluation_strategy steps \
--recompute
```

- `model_name_or_path`: 基座模型的名称。
- `per_device_train_batch_size`: 根据 prompt 进行生成及训练使用的批次大小(每张卡)。
- `num_train_epochs`: 模型训练的轮数。
- `learning_rate`: 训练的学习率。
- `lr_scheduler_type`: scheduler类型,可选linear和cosine。
- `gradient_accumulation_steps`: 模型参数梯度累积的步数,可用于扩大 batch size。实际的 batch_size = per_device_train_batch_size * gradient_accumulation_steps。
- `logging_steps`: 训练日志打印间隔。
- `eval_steps`: 训练评估间隔步数。
- `output_dir`: 模型的保存路径。
- `warmup_ratio`: warmup步数占总步数的比例。
- `report_to`: 日志输出工具,包含wandb,tensorboard,visualdl。
- `logging_first_step`: 是否记录和评估第一个 `global_step`。(`bool`,可选,默认为`False`)
- `use_peft`: 是否使用lora。
- `data_seed`: 数据集的种子随机数。
- `lora_r`: LoRA 算法中rank(秩)的值,默认为8。
- `lora_alpha`: LoRA 算法的alpha的缩放参数。
- `bf16`: 是否使用 bf16 混合精度训练。
- `do_eval`: 是否需要评估。
- `evaluation_strategy`: 评估策略,默认为no。"no":训练期间不进行评估;"steps":在每eval_steps结束进行;"epoch":在每个 epoch 结束时进行。
- `recompute`: 是否使用recompute训练,重计算transformer结构。

多卡训练:
```
python -u -m paddle.distributed.launch --gpus "2,3,4,5" kto.py \
--model_name_or_path=Llama-2-7b-chat-hf \
--per_device_train_batch_size 4 \
--num_train_epochs 1 \
--learning_rate 1e-5 \
--lr_scheduler_type cosine \
--gradient_accumulation_steps 1 \
--logging_steps 10 \
--eval_steps 100 \
--output_dir=kto-aligned-model \
--warmup_ratio 0.1 \
--report_to wandb \
--data_seed 16 \
--do_eval \
--evaluation_strategy steps \
--logging_first_step \
--sharding "stage2" \
--bf16 \
--fp16_opt_level O2 \
--sharding_parallel_degree 4 \
--recompute
```

- `model_name_or_path`: 基座模型的名称。
- `per_device_train_batch_size`: 根据 prompt 进行生成及训练使用的批次大小(每张卡)。
- `num_train_epochs`: 模型训练的轮数。
- `learning_rate`: 训练的学习率。
- `lr_scheduler_type`: scheduler类型,可选linear和cosine。
- `gradient_accumulation_steps`: 模型参数梯度累积的步数,可用于扩大 batch size。实际的 batch_size = per_device_train_batch_size * gradient_accumulation_steps。
- `logging_steps`: 训练日志打印间隔。
- `eval_steps`: 训练评估间隔步数。
- `output_dir`: 模型的保存路径。
- `warmup_ratio`: warmup步数占总步数的比例。
- `report_to`: 日志输出工具,包含wandb,tensorboard,visualdl。
- `data_seed`: 数据集的种子随机数。
- `do_eval`: 是否需要评估。
- `evaluation_strategy`: 评估策略,默认为no。"no":训练期间不进行评估;"steps":在每eval_steps结束进行;"epoch":在每个 epoch 结束时进行。
- `logging_first_step`: 是否记录和评估第一个 `global_step`。(`bool`,可选,默认为`False`)
- `bf16`: 是否使用 bf16 混合精度训练。
- `fp16_opt_level`: 混合精度策略,支持O1 自动混合精度,O2 pure fp16精度训练。
- `sharding_parallel_degree`: sharding_parallel_degree 表示sharding发生在多少路数据流之间。
- `sharding`: 是否使用Paddle的Sharding数据并行功能,用户的参数。支持sharding `stage1`, `stage2` or `stage3`。
- `recompute`: 是否使用重计算训练。可以节省显存。

## 推理
模型的推理请参考[推理](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm#4-%E6%8E%A8%E7%90%86)

## 服务化部署

模型的服务化部署请参考[服务化部署](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm#5-%E6%9C%8D%E5%8A%A1%E5%8C%96%E9%83%A8%E7%BD%B2)

## Acknowledge

我们借鉴了[trl](https://github.com/huggingface/trl/tree/main)的优秀设计实现,在此对其作者表示感谢。

## 参考文献

[1] Kawin Ethayarajh, Winnie Xu, Niklas Muennighoff, Dan Jurafsky, Douwe Kiela: [KTO: Model Alignment as Prospect Theoretic Optimization](https://arxiv.org/abs/2402.01306). CoRR abs/2402.01306 (2024)
114 changes: 114 additions & 0 deletions examples/KTO/kto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass

from datasets import load_dataset
from kto_config import KTOConfig
from kto_trainer import KTOTrainer
from model_config import ModelConfig

from paddlenlp.peft import LoRAConfig
from paddlenlp.trainer import PdArgumentParser
from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer


# Define and parse arguments.
@dataclass
class ScriptArguments:
"""
The arguments for the KTO training script.
"""

dataset_name: str = "trl-lib/kto-mix-14k"


if __name__ == "__main__":
parser = PdArgumentParser((ScriptArguments, KTOConfig, ModelConfig))
script_args, kto_args, model_args = parser.parse_args_into_dataclasses()

# Load a pretrained model
model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, dtype=model_args.paddle_dtype)
model_ref = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, dtype=model_args.paddle_dtype)

tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)

if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# TODO(wugaosheng) adapt to chattemplte
# If we are aligning a base model, we use ChatML as the default template
# if tokenizer.chat_template is None:
# model, tokenizer = setup_chat_format(model, tokenizer)

# Load the dataset
dataset = load_dataset(script_args.dataset_name)

# Apply chat template
def format_dataset(example):
# Add chattemple to prompt input
chat_msg = [item["content"] for item in example["prompt"]]
res = ""
for i in range(0, len(chat_msg), 2):
if i + 1 < len(chat_msg):
res += f"<s>[INST] {chat_msg[i].strip()} [/INST] {chat_msg[i+1].strip()} </s>"
pd_output = tokenizer.apply_chat_template(chat_msg[-1], tokenize=False)
pd_output = res + pd_output
pd_output = res + f"<s>[INST] {chat_msg[-1].strip()} [/INST]"
example["prompt"] = pd_output

# Add chattemple to completion
chat_msg = ["Hi"] + [item["content"] for item in example["completion"]]
res = ""
for i in range(0, len(chat_msg), 2):
if i + 1 < len(chat_msg):
res += f"<s>[INST] {chat_msg[i].strip()} [/INST] {chat_msg[i+1].strip()} </s>"
pd_output = res
# remove fake user content
example["completion"] = pd_output.split("[/INST]")[-1]
return example

formatted_dataset = dataset.map(format_dataset)

if model_args.use_peft:
target_modules = [
".*q_proj.*",
".*v_proj.*",
".*k_proj.*",
".*gate_proj.*",
".*up_proj.*",
".*o_proj.*",
".*down_proj.*",
]

peft_config = LoRAConfig(
target_modules=target_modules,
r=model_args.lora_r,
lora_alpha=model_args.lora_alpha,
lora_dropout=model_args.lora_dropout,
)
else:
peft_config = None
# Initialize the KTO trainer
kto_trainer = KTOTrainer(
model,
model_ref,
args=kto_args,
train_dataset=formatted_dataset["train"],
eval_dataset=formatted_dataset["test"],
tokenizer=tokenizer,
peft_config=peft_config,
)

# Train and push the model to the Hub
kto_trainer.train()
86 changes: 86 additions & 0 deletions examples/KTO/kto_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Dict, Optional

from paddlenlp.trainer import TrainingArguments


@dataclass
class KTOConfig(TrainingArguments):
r"""
KTOConfig collects all training arguments related to the [`KTOTrainer`] class.

Using [`PdArgumentParser`] we can turn this class into
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
command line.

Parameters:
max_length (`int`, *optional*, defaults to `None`):
The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator.
max_prompt_length (`int`, *optional*, defaults to `None`):
The maximum length of the prompt. This argument is required if you want to use the default data collator.
max_completion_length (`int`, *optional*, defaults to `None`):
The maximum length of the target. This argument is required if you want to use the default data collator and your model is an encoder-decoder.
beta (`float`, defaults to 0.1):
The beta factor in KTO loss. Higher beta means less divergence from the initial policy.
desirable_weight (`float`, *optional*, defaults to 1.0):
The desirable losses are weighed by this factor to counter unequal number of desirable and undesirable paris.
undesirable_weight (`float`, *optional*, defaults to 1.0):
The undesirable losses are weighed by this factor to counter unequal number of desirable and undesirable pairs.
label_pad_token_id (`int`, defaults to `-100`):
The label pad token id. This argument is required if you want to use the default data collator.
padding_value (`int`, defaults to `0`):
The padding value if it is different to the tokenizer's pad_token_id.
truncation_mode (`str`, defaults to `keep_end`):
The truncation mode to use, either `keep_end` or `keep_start`. This argument is required if you want to use the default data collator.
generate_during_eval (`bool`, defaults to `False`):
Whether to sample and log generations during evaluation step.
is_encoder_decoder (`Optional[bool]`, `optional`, defaults to `None`):
If no model is provided, we need to know if the model_init returns an encoder-decoder.
precompute_ref_log_probs (`bool`, defaults to `False`):
Flag to precompute reference model log probabilities for training and evaluation datasets. This is useful if you want to train
without the reference model and reduce the total GPU memory needed.
model_init_kwargs: (`Optional[Dict]`, *optional*):
Dict of Optional kwargs to pass when instantiating the model from a string.
ref_model_init_kwargs: (`Optional[Dict]`, *optional*):
Dict of Optional kwargs to pass when instantiating the ref model from a string.
dataset_num_proc: (`Optional[int]`, *optional*, defaults to `None`):
Number of processes to use for processing the datasets.
"""

max_length: Optional[int] = None
"""The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator."""
max_prompt_length: Optional[int] = None
"""The maximum length of the prompt. This argument is required if you want to use the default data collator."""
max_completion_length: Optional[int] = None
"""The maximum length of the target. This argument is required if you want to use the default data collator and your model is an encoder-decoder."""
beta: float = 0.1
"""The beta factor in KTO loss. Higher beta means less divergence from the initial policy."""
desirable_weight: Optional[float] = 1.0
"""The desirable losses are weighed by this factor."""
undesirable_weight: Optional[float] = 1.0
"""The undesirable losses are weighed by this factor."""

label_pad_token_id: int = -100
padding_value: int = None
truncation_mode: str = "keep_end"
generate_during_eval: bool = False
is_encoder_decoder: Optional[bool] = None
precompute_ref_log_probs: bool = False
model_init_kwargs: Optional[Dict] = None
ref_model_init_kwargs: Optional[Dict] = None
dataset_num_proc: Optional[int] = None
data_seed: int = None