
[AutoParallel] Optimize Reshard [part2: optimize _reshard_input] #60022

Merged
merged 4 commits into PaddlePaddle:develop from opt_reshard_2 on Dec 18, 2023

Conversation

AndSonder (Contributor) commented Dec 14, 2023

PR types

Performance optimization

PR changes

Others

Description

Profiling in static-graph mode shows that parse_op_desc, called from _reshard_input, invokes sync_with_cpp many times and accounts for a large share of the time. The timing breakdown is shown below:

[image: timing breakdown before optimization]

Further analysis shows that the insert-style helpers called from parse_op_desc, such as insert_c_concat_op, insert_slice_op, and insert_concat_op, synchronize with the C++ side before every op insertion. This is unnecessary: one synchronization after all ops have been inserted is enough (see the sketch below). Model verification confirms that this optimization does not affect accuracy; the results match those obtained before the change.
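A minimal sketch of the batched-sync pattern, assuming a Paddle static-graph program; the variable names and cast-op attributes below are illustrative and not taken from the PR:

    import paddle

    paddle.enable_static()
    main_program = paddle.static.Program()
    block = main_program.global_block()

    x = block.create_var(name="x", shape=[4, 8], dtype="float32")
    out = block.create_var(name="x.cast_fp16", shape=[4, 8], dtype="float16")

    # Before this PR, each insertion went through block._insert_op, which syncs
    # the Python block with the C++ program desc on every call. Inserting
    # without sync and syncing once afterwards avoids the repeated cost.
    block._insert_op_without_sync(
        0,
        type="cast",
        inputs={"X": [x]},
        outputs={"Out": [out]},
        attrs={"in_dtype": x.dtype, "out_dtype": out.dtype},
    )
    block._sync_with_cpp()  # single synchronization after all insertions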

Local testing of the time spent before the model actually starts running shows a reduction of about 20.8%, from 23.889 s to 18.928 s ((23.889 - 18.928) / 23.889 ≈ 0.208).

Test environment: a local machine with four 1080Ti GPUs, the PaddleNLP Llama2 7B model, static-graph mode (hacked config: config.num_hidden_layers = 12).

The cProfile results before and after the optimization are shown below; the entries marked with <-- !!! are the ones affected by the optimization:

Before optimization:
    ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000   23.889   23.889 run_pretrain_auto.py:414(main) <-- !!!
        1    0.000    0.000   15.531   15.531 engine.py:1465(prepare)
        1    0.000    0.000   15.531   15.531 engine.py:562(_prepare_program)
        1    0.000    0.000    9.015    9.015 engine.py:762(_parallel)
        1    0.000    0.000    9.015    9.015 parallelizer_v2.py:65(parallel)
7795/6419    0.019    0.000    4.790    0.001 decorator.py:229(fun)
7795/6419    0.010    0.000    4.610    0.001 wrapped_decorator.py:23(__impl__)
        1    0.000    0.000    4.311    4.311 reshard.py:2797(reshard)
        1    0.033    0.033    4.184    4.184 reshard.py:2416(_reshard_input) <-- !!!
      261    0.388    0.001    3.974    0.015 reshard.py:1784(parse_op_desc)
        1    0.000    0.000    3.306    3.306 engine.py:575(_build)
        1    0.000    0.000    3.198    3.198 helper.py:222(build_program)
        3    0.000    0.000    3.125    1.042 program_translator.py:912(concrete_program)
        3    0.000    0.000    3.125    1.042 program_translator.py:942(concrete_program_specify_input_spec)
        3    0.000    0.000    3.125    1.042 program_translator.py:825(get_concrete_program)
        3    0.000    0.000    3.125    1.042 program_translator.py:1637(__getitem__)
        1    0.000    0.000    3.125    3.125 program_translator.py:1558(_build_once)
  264/151    0.001    0.000    3.032    0.020 base.py:66(__impl__)
        1    0.000    0.000    3.015    3.015 program_translator.py:1276(from_func_spec)
        1    0.000    0.000    3.013    3.013 parallel.py:943(init_parallel_env)
        6    3.012    0.502    3.012    0.502 {built-in method paddle.base.libpaddle.create_or_get_global_tcp_store}
      355    2.242    0.006    2.787    0.008 framework.py:4568(_sync_with_cpp) <-- !!!

After optimization:
ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000   18.928   18.928 run_pretrain_auto.py:414(main) <-- !!!
        1    0.000    0.000   13.129   13.129 engine.py:1465(prepare)
        1    0.000    0.000   13.129   13.129 engine.py:562(_prepare_program)
        1    0.000    0.000    6.603    6.603 engine.py:762(_parallel)
        1    0.000    0.000    6.602    6.602 parallelizer_v2.py:65(parallel)
7795/6419    0.018    0.000    4.683    0.001 decorator.py:229(fun)
7795/6419    0.009    0.000    4.518    0.001 wrapped_decorator.py:23(__impl__)
        1    0.000    0.000    3.294    3.294 engine.py:575(_build)
        1    0.000    0.000    3.169    3.169 helper.py:222(build_program)
        3    0.000    0.000    3.101    1.034 program_translator.py:912(concrete_program)
        3    0.000    0.000    3.101    1.034 program_translator.py:942(concrete_program_specify_input_spec)
        3    0.000    0.000    3.101    1.034 program_translator.py:825(get_concrete_program)
        3    0.000    0.000    3.101    1.034 program_translator.py:1637(__getitem__)
        1    0.000    0.000    3.101    3.101 program_translator.py:1558(_build_once)
  264/151    0.001    0.000    2.984    0.020 base.py:66(__impl__)
        1    0.000    0.000    2.967    2.967 program_translator.py:1276(from_func_spec)
        2    0.000    0.000    2.826    1.413 executor.py:1565(run)
        2    0.000    0.000    2.826    1.413 executor.py:1744(_run_impl)
        2    0.000    0.000    2.814    1.407 executor.py:920(get_program_and_executor)
        1    0.000    0.000    2.733    2.733 ProxyLayer__trainoav3wtu6.py:9(_train)
        1    0.000    0.000    2.719    2.719 engine.py:1521(run)
        2    0.000    0.000    2.712    1.356 executor.py:942(_get_program_and_executor)
        5    0.000    0.000    2.412    0.482 pass_base.py:91(apply)
        5    0.000    0.000    2.411    0.482 pass_base.py:111(_apply_impl)
        1    0.036    0.036    1.891    1.891 reshard.py:2450(_reshard_input) <-- !!!
       25    0.001    0.000    1.664    0.067 creation.py:1298(arange)
       24    1.649    0.069    1.649    0.069 {built-in method paddle.base.libpaddle.pir.ops.arange}
      261    0.382    0.001    1.649    0.006 reshard.py:1804(parse_op_desc)
      ...
      132    0.294    0.002    0.520    0.004 framework.py:4567(_sync_with_cpp) <-- !!!
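For reference, profiles like the two listings above can be collected with Python's cProfile; the sketch below is an assumption about how this might be done (the PR does not show the exact profiling code, and main() here is a stand-in for run_pretrain_auto.py:414(main)):

    import cProfile
    import pstats

    def main():
        # stand-in for the training entry point; replace with the real main()
        pass

    profiler = cProfile.Profile()
    profiler.enable()
    main()
    profiler.disable()
    pstats.Stats(profiler).sort_stats("cumtime").print_stats(30)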

The timing visualization after the optimization is shown below:

[image: timing breakdown after optimization]

Launch command:

python -u -m paddle.distributed.launch \
    --gpus "0,1,2,3" \
    --log_dir "output/$task_name""_log" \
    auto_parallel/run_pretrain_auto.py \
    --model_type "llama" \
    --model_name_or_path "facebook/llama-7b" \
    --tokenizer_name_or_path "facebook/llama-7b" \
    --input_dir "./data" \
    --output_dir "output/$task_name" \
    --split 949,50,1 \
    --max_seq_length 2048 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 16 \
    --fp16 1 \
    --fp16_opt_level "O2"  \
    --scale_loss 1024 \
    --pipeline_parallel_degree 2 \
    --tensor_parallel_degree 2 \
    --sharding_parallel_degree 1 \
    --sharding "stage1" \
    --learning_rate 0.0001 \
    --min_learning_rate 0.00001 \
    --max_steps 30 \
    --save_steps 5000 \
    --weight_decay 0.01 \
    --warmup_ratio 0.01 \
    --max_grad_norm 1.0 \
    --logging_steps 1 \
    --dataloader_num_workers 1 \
    --sharding "" \
    --eval_steps 1000 \
    --report_to "visualdl" \
    --disable_tqdm true \
    --continue_training 0 \
    --recompute 1 \
    --do_train 1 \
    --do_eval 0 \
    --device "gpu" \
    --data_impl "mmap" \
    --parallel_mode "auto" \
    --fuse_attention_qkv 1 \
    --fuse_attention_ffn 1 \
    --use_fused_rope 1 \
    --use_flash_attention 1


paddle-bot bot commented Dec 14, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@paddle-bot paddle-bot bot added the contributor External developers label Dec 14, 2023
@AndSonder AndSonder changed the title [AutoParallel] Optimize Reshard [part2: parse_op_desc] [AutoParallel] Optimize Reshard [part2: _reshard_input] Dec 14, 2023
@AndSonder AndSonder changed the title [AutoParallel] Optimize Reshard [part2: _reshard_input] [AutoParallel] Optimize Reshard [part2: optimize _reshard_input] Dec 14, 2023
The review thread below refers to this diff hunk:

@@ -341,7 +341,9 @@ def insert_cast_op(block, idx, tensor, op_role, tensor_type):
                type=tensor.type,
                lod_level=tensor.lod_level,
            )
            cast_op = block._insert_op(

            insert_op = block._insert_op if sync else block._insert_op_without_sync
A Contributor commented:
insert_op and cast_op here share the same naming pattern but have completely different semantics: the former is a function, the latter is an OP object. Putting them side by side is easy to confuse; please make the distinction clearer.

The PR author (Contributor) replied:
Renamed insert_op to insert_operation. The op in cast_op stands for operator, so the two are now clearly distinguished.
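For context, a hedged sketch of what the helper may look like after this change and rename; the signature details and op attributes beyond what the diff shows are assumptions:

    def insert_cast_op_sketch(block, idx, tensor, out_var, op_role, sync=True):
        # Pick the insertion method: block._insert_op syncs with the C++ side on
        # every call, while block._insert_op_without_sync defers that work so the
        # caller can sync once after all insertions.
        insert_operation = block._insert_op if sync else block._insert_op_without_sync
        cast_op = insert_operation(
            idx,
            type="cast",
            inputs={"X": [tensor]},
            outputs={"Out": [out_var]},
            attrs={
                "in_dtype": tensor.dtype,
                "out_dtype": out_var.dtype,
                "op_role": op_role,
            },
        )
        return cast_op

Callers that pass sync=False are then responsible for calling block._sync_with_cpp() once after all insertions.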

Caozhou1995 (Contributor) left a comment:
LGTM

@From00 From00 merged commit 2efc3ad into PaddlePaddle:develop Dec 18, 2023
29 checks passed
HermitSun pushed a commit to HermitSun/Paddle that referenced this pull request Dec 21, 2023
…dlePaddle#60022)

* opt reshard parse_op_desc

* recover third party

* change var name

* fix code style
@AndSonder AndSonder deleted the opt_reshard_2 branch April 23, 2024 13:56
Labels: contributor (External developers)