Add Pipeline Parallel for PPO training and support generation with InferenceModel #7953

Merged · 56 commits · Jun 13, 2024

Commits
c71fb92
Add Pipeline Parallel for PPO training.
guoshengCS Feb 2, 2024
2f2ad5e
Move new_ppo_trainer.py to ppo_trainer.py
guoshengCS Feb 2, 2024
8e8143e
Fix padding among batches of accumulation steps in _prepare_pipeline_…
guoshengCS Feb 4, 2024
e4d7781
Fix hcg using in TP generation
guoshengCS Feb 6, 2024
4d1641b
Try to support generation in PP. And allow extra training args passed…
guoshengCS Feb 6, 2024
34d4cd1
Support PP generation.
guoshengCS Feb 20, 2024
665fee2
Fix PP eval by unify prediction_step
guoshengCS Feb 20, 2024
a2e9702
Fix reward value showing error cased by BF16 dtype when eval
guoshengCS Feb 20, 2024
6c8441c
fix all
ZHUI Feb 22, 2024
d295d11
Make non-PipelineParallel models use the same loss layer with PipeMod…
guoshengCS Feb 22, 2024
38cc1a7
add offload.
ZHUI Feb 22, 2024
6ff38c8
Use create_loss to unify Pipe and non-Pipe usage.
guoshengCS Feb 22, 2024
6e49431
Add eval mode and offload level.
ZHUI Feb 23, 2024
c421af7
merge
ZHUI Feb 23, 2024
f6b5f97
fix all
ZHUI Feb 23, 2024
63df4fd
support tp+pp
ZHUI Feb 26, 2024
c9e5cad
fix data split.
ZHUI Feb 27, 2024
5979507
Fix position_ids in generation/eval/train.
guoshengCS Feb 28, 2024
16d886a
fix data group.
ZHUI Mar 1, 2024
1786357
add tp rank guard
ZHUI Mar 5, 2024
3bc48cb
Support rollout label data both with target length or source+target l…
guoshengCS Mar 6, 2024
5ae2c6f
Merge remote-tracking branch 'guosheng/ppo-4d' into ppo-4d/support_uc
ZHUI Mar 7, 2024
1b50869
Move metric calculation to rl_step to avoid comm.
guoshengCS Mar 7, 2024
bc80256
fix pad
ZHUI Mar 7, 2024
986b407
Merge remote-tracking branch 'guosheng/ppo-4d' into ppo-4d/support_uc
ZHUI Mar 7, 2024
b3f22c2
fix create group.
ZHUI Mar 7, 2024
8c7e612
no print
ZHUI Mar 7, 2024
2e3bf85
Suppport inference model generation.
guoshengCS Mar 7, 2024
df452d1
fix compatible for no eval model.
ZHUI Mar 8, 2024
d73b8a3
fix pp sync.
ZHUI Mar 8, 2024
e14e04b
remove debug info
ZHUI Mar 8, 2024
8015c8b
Refacor PPO training using StepTrainer.
guoshengCS Mar 12, 2024
860e61d
Open PolicyTrainer loss logging postprocess. More StepTrainer docs.
guoshengCS Mar 13, 2024
afa1b53
more timer.
ZHUI Mar 15, 2024
80e47e6
Merge remote-tracking branch 'guosheng/ppo-4d' into ppo-4d/support_uc
ZHUI Mar 15, 2024
757d3a7
fix bugs.
ZHUI Mar 19, 2024
6f2eff6
Merge pull request #1 from PaddlePaddle/ppo-4d/support_uc
guoshengCS Mar 19, 2024
1448b73
Add EMA and PPOMetric
guoshengCS Mar 21, 2024
b809631
add tests
ZHUI Mar 21, 2024
2f8d032
add unit test for rank guard.
ZHUI Mar 22, 2024
edf28f2
Merge pull request #2 from ZHUI/ppo-4d-test
guoshengCS Mar 22, 2024
fbb9ac3
Fix reshard zero3 and reshard infer.
guoshengCS Mar 25, 2024
aebdc89
Merge branch 'ppo-4d' of https://github.com/guoshengCS/PaddleNLP into…
guoshengCS Mar 25, 2024
cb6e4ff
Revert #7818 for llama and remove position_ids for gen/train/eval to …
guoshengCS Mar 26, 2024
4ddd415
Move reload/clean/data_group to comm_utils and use guard to decorate …
guoshengCS Apr 1, 2024
b68cb0d
Offload sync and other data reuse fix.
guoshengCS May 7, 2024
5e46ab6
Clead code
guoshengCS May 9, 2024
d538917
Update README
guoshengCS May 16, 2024
510ef03
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
guoshengCS May 27, 2024
1feb5ee
Update ppo_trainer
guoshengCS Jun 4, 2024
c8b3c61
format code
gongel Jun 11, 2024
c26583f
Fix make_position_ids by 4d causal mask.
guoshengCS Jun 12, 2024
1b8e4a3
Merge pull request #4 from gongel/ppo-4d
guoshengCS Jun 12, 2024
9acd87c
Merge branch 'ppo-4d' of https://github.com/guoshengCS/PaddleNLP into…
guoshengCS Jun 12, 2024
ffa4658
Fix nested_broadcast_tensor_with_empty import
guoshengCS Jun 12, 2024
f1e66f2
Update eval with make_attention_mask
guoshengCS Jun 12, 2024
26 changes: 21 additions & 5 deletions examples/RLHF/README.md
@@ -1,6 +1,6 @@
# RLHF PPO

-Provides code and a complete usage example for human-preference alignment of LLMs with the reinforcement-learning PPO algorithm. The PPO implementation details follow the PPO implementation in [PKU-Alignment/safe-rlhf](https://github.com/PKU-Alignment/safe-rlhf) (PKU Beaver) and support common PPO training-stabilization strategies such as reward normalization and pretraining loss; the example uses some of the datasets and models provided by PKU-Alignment/safe-rlhf. Further improvements and extensions are planned to deliver RLHF with better quality, lower cost, higher performance, and larger scale.
+Provides code and a complete usage example for human-preference alignment of LLMs with the reinforcement-learning PPO algorithm, with support for **3D distributed parallel training and accelerated generation in the rollout stage via inference optimizations**. The PPO implementation details follow the PPO implementation in [PKU-Alignment/safe-rlhf](https://github.com/PKU-Alignment/safe-rlhf) (PKU Beaver) and support common PPO training-stabilization strategies such as reward normalization and pretraining loss; the example uses some of the datasets and models provided by PKU-Alignment/safe-rlhf. Further improvements and extensions are planned to deliver RLHF with better quality, lower cost, higher performance, and larger scale.

## Quick Start

@@ -14,6 +14,9 @@
├── ppo_main.py # RLHF training script
├── ppo_config.json # RLHF training configuration file
├── ppo_trainer.py # RLHF trainer/executor script
+├── ppo_config.json # RLHF training configuration file
+├── trainer_utils.py # Trainer patches and utility script
+├── infer_utils.py # generation-acceleration utility script
├── data # dataset-related directory
│ └── base.py # dataset base classes and utilities
│ └── alpaca.py # alpaca (raw) dataset file
@@ -24,16 +27,20 @@
├── models # model-related directory
│ └── score_model_utils.py # score model base classes and utilities
│ └── score_model.py # score model definition
+│ └── ppo_model_utils.py # model strategies such as PPO losses
+│ └── pp_model_utils.py # pipeline-parallel patches and utilities
+│ └── model_pp.py # pipeline-parallel model definition
+│ └── infer_model_utils.py # inference-accelerated model patches and utilities
└── README.md
```

### Environment Setup

- Python >= 3.10
- PaddlePaddle >= 2.6.0
-- PaddleNLP >= 2.6.0
+- PaddleNLP (latest version)

-In addition, install the following dependency: `pip install rich`
+To use accelerated generation, install [paddlenlp_ops](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/csrc): clone the PaddleNLP repository with `git clone https://github.com/PaddlePaddle/PaddleNLP.git` and add the path of the PaddleNLP/llm directory to PYTHONPATH (this will be streamlined later). With paddlenlp_ops installed, accelerated generation is enabled automatically during training (not supported when pipeline parallelism is enabled); otherwise generation uses the native dynamic graph.
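
For illustration only (an editor's sketch, not part of this diff or the original README), the setup described above might look like the following, assuming a CUDA-ready environment and that `setup_cuda.py` is the build entry point for paddlenlp_ops under csrc:

```
# Hypothetical setup sketch — check the csrc directory in the repository for the authoritative build steps.
git clone https://github.com/PaddlePaddle/PaddleNLP.git
cd PaddleNLP/csrc
python setup_cuda.py install             # build and install paddlenlp_ops (assumed entry point)
cd ..
export PYTHONPATH=$PWD/llm:$PYTHONPATH   # make the PaddleNLP/llm directory importable
```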

### Data Preparation

@@ -184,7 +191,8 @@ python -u -m paddle.distributed.launch reward_main.py ./reward_config.json
The RLHF stage requires four models: actor model, reference model, critic model, and reward model. The actor-model/reference-model are initialized from / frozen as the SFT model; the critic-model/reward-model are initialized from / frozen as the reward model (note: if the SFT model was trained with LoRA, merge the LoRA weights first). Here the SFT model ([PKU-Alignment/alpaca-7b-reproduced](https://huggingface.co/PKU-Alignment/alpaca-7b-reproduced)) and reward model ([PKU-Alignment/beaver-7b-v1.0-reward](https://huggingface.co/PKU-Alignment/beaver-7b-v1.0-reward), which only targets helpfulness and does not consider harmlessness) provided by PKU-Alignment/PKU-SafeRLHF are used as an example, and RLHF training is run with the `ppo_main.py` script according to `ppo_config.json`.

```
-python -u -m paddle.distributed.launch ppo_main.py ./ppo_config.json
+# Type-promotion warnings are temporarily suppressed via the log level; a proper fix will follow
+GLOG_minloglevel=2 python -u -m paddle.distributed.launch ppo_main.py ./ppo_config.json
```

Most parameters in `ppo_config.json` have the same meaning as in [LLM fine-tuning](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm#2-%E7%B2%BE%E8%B0%83) and are not repeated here; the key parameter settings and their meanings are given below (using the defaults from PKU-Alignment/PKU-SafeRLHF):
@@ -210,7 +218,15 @@ python -u -m paddle.distributed.launch ppo_main.py ./ppo_config.json

In addition, all [parameters supported by `TrainingArguments`](https://paddlenlp.readthedocs.io/zh/latest/trainer.html#trainingarguments) are shared by actor-model and critic-model training (e.g. `sharding_stage`), except that `critic_learning_rate/critic_weight_decay/critic_lr_scheduler_type/critic_warmup_ratio/critic_recompute` are provided so that the corresponding settings can be specified separately for critic-model training. The checkpoints of the actor-model and critic-model are saved in the policy and value folders, respectively, under the directory specified by `output_dir`.

-RLHF training at the data and model scale used in this example has been verified on 4/8 NVIDIA A100 80G GPUs with sharding stage3.
+In addition, the following special settings are provided to support higher-performance, larger-scale RLHF training and can be used as needed:
+- `use_fusemt`: with paddlenlp_ops installed, accelerated generation is enabled during rollout (not supported when pipeline parallelism is enabled); this setting can be used to disable accelerated generation.
+- `eval_mode`: may be empty or set to "single" or "tensor_parallel"; when training with pipeline parallelism it is usually set to "tensor_parallel" so that the rollout generation stage uses a non-pipeline-parallel model with accelerated generation.
+- `offload_level`: may be set to "freeze_model", "optimizer", "train_model", or a space-separated combination; these respectively indicate offload/reload of the two frozen models (reward + reference), the optimizer states of the two trained models (actor + critic), and their model parameters, so that a model/optimizer is offloaded as soon as a stage finishes using it and the corresponding weights are reloaded when next needed, saving GPU memory.

+Also note that when pipeline parallelism is used (pipeline_parallel_degree greater than 1), it is recommended to set `dataloader_drop_last` to true to avoid issues caused by varying batch sizes.
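
For illustration only (an editor's sketch, not part of this diff or the original README), a `ppo_config.json` fragment combining the options above might look like the following; the values are placeholders, all other required fields (model paths, data settings, batch sizes, etc.) are omitted, and `tensor_parallel_degree` is assumed to be one of the standard `TrainingArguments` options rather than named in the text above:

```
{
  "pipeline_parallel_degree": 4,
  "tensor_parallel_degree": 2,
  "eval_mode": "tensor_parallel",
  "offload_level": "freeze_model optimizer",
  "use_fusemt": true,
  "dataloader_drop_last": true,
  "critic_learning_rate": 1e-5
}
```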




### Inference
