support sharding overlap #799

Merged 5 commits on Sep 29, 2022
@@ -61,6 +61,7 @@ Distributed:
    sharding_degree: 1
    sharding_stage: 1
    sharding_offload: False
+   comm_overlap: False

Optimizer:
  name: FusedAdamW
1 change: 1 addition & 0 deletions ppfleetx/configs/nlp/gpt/generation_gpt_345M_dp8.yaml
@@ -20,3 +20,4 @@ Distributed:
    sharding_degree: 1
    sharding_stage: 1
    sharding_offload: False
+   comm_overlap: False
1 change: 1 addition & 0 deletions ppfleetx/configs/nlp/gpt/inference_gpt_345M_dp8.yaml
@@ -14,6 +14,7 @@ Distributed:
    sharding_degree: 1
    sharding_stage: 1
    sharding_offload: False
+   comm_overlap: False


Data:
@@ -14,6 +14,7 @@ Distributed:
    sharding_degree: 1
    sharding_stage: 1
    sharding_offload: False
+   comm_overlap: False


Data:
1 change: 1 addition & 0 deletions ppfleetx/configs/nlp/gpt/pretrain_gpt_1.3B_dp8.yaml
@@ -30,3 +30,4 @@ Distributed:
    sharding_degree: 1
    sharding_stage: 1
    sharding_offload: False
+   comm_overlap: False
@@ -30,3 +30,4 @@ Distributed:
    sharding_degree: 1
    sharding_stage: 1
    sharding_offload: False
+   comm_overlap: False
1 change: 1 addition & 0 deletions ppfleetx/configs/nlp/gpt/pretrain_gpt_175B_mp8_pp16.yaml
@@ -31,3 +31,4 @@ Distributed:
    sharding_degree: 1
    sharding_stage: 1
    sharding_offload: False
+   comm_overlap: False
1 change: 1 addition & 0 deletions ppfleetx/configs/nlp/gpt/pretrain_gpt_345M_mp8_qat.yaml
@@ -30,6 +30,7 @@ Distributed:
    sharding_degree: 1
    sharding_stage: 1
    sharding_offload: False
+   comm_overlap: False


Quantization:
@@ -30,3 +30,4 @@ Distributed:
    sharding_degree: 1
    sharding_stage: 1
    sharding_offload: False
+   comm_overlap: False
1 change: 1 addition & 0 deletions ppfleetx/configs/nlp/gpt/pretrain_gpt_6.7B_sharding16.yaml
@@ -30,3 +30,4 @@ Distributed:
    sharding_degree: 16
    sharding_stage: 2
    sharding_offload: False
+   comm_overlap: True
7 changes: 7 additions & 0 deletions ppfleetx/core/engine/eager_engine.py
@@ -153,6 +153,11 @@ def configure_optimizers(self):
            'sharding_degree']
        self._sharding_offload = self._dist_configs['sharding'][
            'sharding_offload']
+       self._comm_overlap = self._dist_configs['sharding']['comm_overlap']
+       if self._sharding_degree > 1 and self._comm_overlap:
+           if self._sharding_stage == 3 or self._sharding_offload:
+               self._comm_overlap = False
+               logger.warning("comm overlap only valid for sharding stage 2 without offload")
        self._use_recompute = configs['Model']['use_recompute']

        if self._use_pure_fp16:
@@ -245,6 +250,8 @@ def _wrap_sharding_2_3(self):
            scaler=self._scaler,
            group=self._sharding_group,
            offload=self._sharding_offload)
+       if self._comm_overlap:
+           self._module.model._set_comm_overlap(self._comm_overlap)

    def _wrap_3D_parallel(self):
        self._module.model = fleet.distributed_model(self._module.model)
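As a quick illustration of the guard added above, here is a hypothetical sharding block (key names taken from the configs in this PR, values invented) that would trigger the warning: overlap is requested, but because the stage is 3 the engine resets `comm_overlap` to False.

```yaml
# Hypothetical example only: overlap requested together with stage-3 sharding.
# With the check above, the engine falls back to comm_overlap = False and logs
# "comm overlap only valid for sharding stage 2 without offload".
sharding:
  sharding_degree: 8
  sharding_stage: 3
  sharding_offload: False
  comm_overlap: True
```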
2 changes: 2 additions & 0 deletions projects/gpt/docs/hybrid_parallel.md
@@ -26,6 +26,7 @@
    sharding_degree: 1
    sharding_stage: 1
    sharding_offload: False
+   comm_overlap: False
```

Parameter descriptions:
@@ -38,6 +39,7 @@
| sharding_degree | sharding parallel degree (group size) |
| sharding_stage | sharding strategy: 1 shards only the optimizer states, 2 additionally shards gradients, 3 additionally shards the forward parameters |
| sharding_offload | CPU offload strategy |
+| comm_overlap | whether to overlap communication and computation under sharding stage 2; this option does not yet support sharding_offload (see the sketch below) |
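For reference, a minimal sketch of a `Distributed` block that actually enables the overlap, modeled on the `pretrain_gpt_6.7B_sharding16.yaml` change in this PR (the `dp_degree`/`mp_degree`/`pp_degree` values are placeholders):

```yaml
Distributed:
  dp_degree: 1              # placeholder parallel degrees
  mp_degree: 1
  pp_degree: 1
  sharding:
    sharding_degree: 16
    sharding_stage: 2       # overlap only takes effect for stage 2
    sharding_offload: False # offload must stay disabled
    comm_overlap: True
```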

## How to run
This directory provides the following hybrid-parallel training configurations for GPT models at the 345M, 1.3B, 6.7B, and 175B scales on 32 GB V100 GPUs: