26 changes: 26 additions & 0 deletions llm/auto_parallel/README.md
@@ -33,6 +33,7 @@
Note: The DeepSeek-v3 configuration script currently provided is a small-scale demo (the number of network layers has been reduced) so that it can run on a single machine with 8 GPUs. To run the full 671B DeepSeek-v3, set the number of layers to 61 and adjust the parallel strategy accordingly. The DeepSeek-v3 version currently offered with auto parallel does not yet integrate optimizations such as FP8 or DeepEP.

## Environment Setup

1. Install the latest PaddlePaddle

First, you need to install the latest `Paddle`; the `Nightly` build is recommended. Visit the [Paddle website](https://www.paddlepaddle.org.cn/install/quick?docurl=undefined) for installation instructions.
@@ -49,15 +50,22 @@ print(paddle.utils.run_check())


## Pretraining

### Data Preparation

The project provides preprocessed data so users can test the model; download it into the `data` directory:

```shell
mkdir -p data && cd data
wget https://bj.bcebos.com/paddlenlp/models/transformers/llama/data/llama_openwebtext_100k.{bin,idx}
```

### Launching Pretraining

#### Launching Pretraining on GPU

- Dynamic graph mode

```shell
# Llama pretrain example
# assume that cur dir is auto_parallel
@@ -67,16 +75,34 @@ python -u -m paddle.distributed.launch \
--log_dir "llama_auto_3d" \
./llama/run_pretrain_auto.py ./llama/pretrain_argument.json
```

This configuration runs a `facebook/llama-7b` pretraining job with an MP2-PP2-DP2 parallel strategy and Stage1 sharding; the sketch below shows how that strategy maps onto individual flags.
For more configurable parameters, see `ModelArguments`, `DataArguments`, and `PreTrainingArguments`.
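
As a rough illustration of how the MP2-PP2-DP2 / Stage1 setup maps onto individual training arguments, here is a minimal sketch that spells the parallel strategy out as command-line flags instead of `pretrain_argument.json`. The flag names mirror those used by the XPU scripts changed later in this PR; whether your setup drives `run_pretrain_auto.py` with flags or with the JSON file, and the exact set of remaining arguments, are assumptions, so treat this as orientation rather than the canonical recipe.

```shell
# Hedged sketch (not taken from the README): the same launch with the parallel
# strategy expressed as flags; everything besides the degree/sharding flags is
# illustrative and normally lives in ./llama/pretrain_argument.json.
python -u -m paddle.distributed.launch \
    --gpus "0,1,2,3,4,5,6,7" \
    --log_dir "llama_auto_3d" \
    ./llama/run_pretrain_auto.py \
    --model_name_or_path "facebook/llama-7b" \
    --input_dir "./data" \
    --output_dir "./output" \
    --tensor_parallel_degree 2 \
    --pipeline_parallel_degree 2 \
    --sharding "stage1"
    # with 8 cards, MP2 x PP2 leaves 2 ranks for data parallelism (the DP2 above),
    # and Stage1 sharding applies within that data-parallel group
```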

- Dynamic-to-static mode
<br>Append the `to_static` parameter (a hedged example follows below).
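
A minimal sketch of what appending `to_static` can look like, assuming the option may be passed on the command line alongside the launch above (the option name comes from `self.args.to_static` in `auto_trainer.py` further down this PR; if your setup reads arguments exclusively from the JSON file, set `"to_static": 1` in `./llama/pretrain_argument.json` instead):

```shell
# Hedged sketch: same pretrain launch as above, with dynamic-to-static enabled.
python -u -m paddle.distributed.launch \
    --log_dir "llama_auto_3d" \
    ./llama/run_pretrain_auto.py ./llama/pretrain_argument.json \
    --to_static 1   # assumed CLI spelling; alternatively set "to_static": 1 in the JSON
```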

#### Launching Pretraining on XPU

In addition to GPU, XPU also supports auto parallel. The llama 7b and 13b models are currently supported, and support for more models is under development.

You can launch pretraining on XPU with the `run_llama2_7b_xpu.sh` and `run_llama2_13b_xpu.sh` scripts in the `PaddleNLP/llm/auto_parallel/llama` directory.

```shell
# cd ${PaddleNLP_Path}/llm/auto_parallel/llama
bash run_llama2_7b_xpu.sh
# or
bash run_llama2_13b_xpu.sh
```

Llama 7b uses a DP8 parallel strategy with Stage1 sharding; Llama 13b uses a DP2-PP4 parallel strategy with Stage1 sharding. The sketch below shows how these strategies map onto the launch flags.
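
For orientation, the summary below pairs these two strategies with the degree flags visible in the two XPU scripts changed by this PR; the mapping of the remaining cards to data parallelism is inferred from the 8-card setup rather than stated in the README.

```shell
# Hedged summary: parallel-strategy flags used by the two XPU scripts in this PR.
# run_llama2_7b_xpu.sh  -> DP8 + Stage1 sharding
#   --tensor_parallel_degree 1  --pipeline_parallel_degree 1  --sharding "stage1"
#   (no MP/PP, so all 8 XPUs form the data-parallel group, sharded with Stage1)
# run_llama2_13b_xpu.sh -> DP2-PP4 + Stage1 sharding
#   --tensor_parallel_degree 1  --pipeline_parallel_degree 4  --sharding "stage1"
#   (8 XPUs = 2 data-parallel replicas x 4 pipeline stages)
```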


## Supervised Fine-Tuning (SFT)
### Data Preparation

The project provides preprocessed fine-tuning data so users can test the model; download and extract it into the `data` directory:

```shell
wget -O AdvertiseGen.tar.gz https://bj.bcebos.com/paddlenlp/datasets/examples/AdvertiseGen.tar.gz
tar -xvf AdvertiseGen.tar.gz
```
7 changes: 5 additions & 2 deletions llm/auto_parallel/llama/run_llama2_13b_xpu.sh
@@ -73,8 +73,9 @@ python -u -m paddle.distributed.launch \
     --tensor_parallel_degree 1 \
     --pipeline_parallel_degree 4 \
     --sharding "stage1" \
+    --data_parallel_config "enable_allreduce_avg_in_gradinent_scale gradient_sync_after_accumulate" \
     --sharding_parallel_config "enable_overlap" \
-    --tensor_parallel_config "enable_mp_async_allreduce enable_mp_skip_c_identity enable_mp_fused_linear_param_grad_add" \
+    --tensor_parallel_config "enable_mp_async_allreduce" \
     --pipeline_parallel_config "enable_send_recv_overlap" \
     --sequence_parallel 0 \
     --use_flash_attention 1 \
@@ -88,7 +89,7 @@ python -u -m paddle.distributed.launch \
     --learning_rate 3e-05 \
     --min_learning_rate 3e-06 \
     --warmup_steps 30 \
-    --logging_steps 2 \
+    --logging_steps 1 \
     --max_steps 1000 \
     --save_steps 100000 \
     --eval_steps 10000 \
@@ -98,6 +99,8 @@ python -u -m paddle.distributed.launch \
     --bf16 \
     --fp16_opt_level "O2" \
     --amp_master_grad true \
+    --amp_custom_black_list "reduce_sum" "c_softmax_with_cross_entropy" \
+    --amp_custom_white_list "lookup_table" "lookup_table_v2" \
     --warmup_ratio 0.01 \
     --max_grad_norm 1.0 \
     --dataloader_num_workers 1 \
10 changes: 7 additions & 3 deletions llm/auto_parallel/llama/run_llama2_7b_xpu.sh
@@ -55,6 +55,7 @@ export PYTHONPATH=../../../:$PYTHONPATH
 # for debug
 #export GLOG_v=10
 export FLAGS_call_stack_level=2
+export GLOG_minloglevel=2
Review comment (Collaborator): Will this make users see more log output when they run it?

Reply (Contributor, author): No.
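
For background (standard glog semantics, not something stated in the thread): `GLOG_minloglevel` sets the minimum severity that gets printed, so `2` keeps only ERROR and FATAL messages and reduces framework log output, which is consistent with the reply above. An annotated sketch of the debug-related exports in this script:

```shell
# Annotated copy of the debug knobs above (comments are assumptions based on
# standard Paddle/glog behavior, not part of the original script).
export FLAGS_call_stack_level=2   # show fuller Paddle error call stacks on failure
export GLOG_minloglevel=2         # print only ERROR/FATAL; suppresses INFO/WARNING
#export GLOG_v=10                 # uncomment for very verbose VLOG output when debugging
```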


 rm -rf output/$task_name_or_path
 PYTHONPATH=../:$PYTHONPATH \
@@ -72,9 +73,10 @@ python -u -m paddle.distributed.launch \
     --tensor_parallel_degree 1 \
     --pipeline_parallel_degree 1 \
     --sharding "stage1" \
+    --data_parallel_config "enable_allreduce_avg_in_gradinent_scale gradient_sync_after_accumulate" \
     --sharding_parallel_config "enable_overlap" \
-    --tensor_parallel_config "enable_delay_scale_loss enable_mp_async_allreduce" \
-    --pipeline_parallel_config "enable_delay_scale_loss enable_release_grads disable_partial_send_recv" \
+    --tensor_parallel_config "enable_mp_async_allreduce" \
+    --pipeline_parallel_config "" \
     --virtual_pp_degree 1 \
     --sequence_parallel 0 \
     --use_flash_attention 1 \
@@ -98,8 +100,10 @@ python -u -m paddle.distributed.launch \
     --bf16 \
     --fp16_opt_level "O2" \
     --amp_master_grad true \
+    --amp_custom_black_list "reduce_sum" "c_softmax_with_cross_entropy" \
+    --amp_custom_white_list "lookup_table" "lookup_table_v2" \
     --warmup_ratio 0.01 \
-    --max_grad_norm 0.0 \
+    --max_grad_norm 1.0 \
     --dataloader_num_workers 1 \
     --continue_training 0 \
     --do_predict 0 \
3 changes: 2 additions & 1 deletion paddlenlp/trainer/auto_trainer.py
@@ -519,7 +519,8 @@
         if self.args.to_static:
             schedule_start_step = self.args.job_schedule_profiler_start
             schedule_end_step = self.args.job_schedule_profiler_end
-            switch_job_schedule_profiler(model, step, schedule_start_step, schedule_end_step)
+            if schedule_start_step >= 0:
+                switch_job_schedule_profiler(model, step, schedule_start_step, schedule_end_step)

Codecov / codecov/patch warning on paddlenlp/trainer/auto_trainer.py#L522-L523: added lines #L522 - L523 were not covered by tests.
         for inputs in inputs_list:
             if step_control % args.gradient_accumulation_steps == 0: