Support unified checkpoint for expert_parallel #8591
Conversation
Thanks for your contribution!
Codecov Report
Attention: Patch coverage is …

Additional details and impacted files:

@@            Coverage Diff             @@
##           develop    #8591      +/-   ##
===========================================
+ Coverage    53.86%   55.74%    +1.88%
===========================================
  Files          620      620
  Lines        97110    96741      -369
===========================================
+ Hits         52306    53930     +1624
+ Misses       44804    42811     -1993

☔ View full report in Codecov by Sentry.
Branch force-pushed from 6415e4c to 92d5432.
@@ -22,6 +22,7 @@
 import paddle
 import paddle.distributed as dist
 from paddle.distributed import fleet
+from paddle.framework import core
It would be best to add a test for this; the framework may need a model that supports expert_parallel.
if model_state_dict[key_name[0]].dtype != core.VarDesc.VarType.FP32:
    key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
else:
    key_name = "_".join([static_name, key_name[1]])
Add a unit test for the FP32 case.
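For reference, a minimal sketch of the key-naming rule in the hunk above, assuming FP32_MASTER is the "fp32_master_0" infix; `optimizer_key` is a hypothetical helper, not PaddleNLP API:

```python
# Sketch only: the FP32_MASTER value and the helper name are assumptions.
FP32_MASTER = "fp32_master_0"

def optimizer_key(static_name: str, suffix: str, is_fp32_param: bool) -> str:
    # Non-FP32 params keep an FP32 master copy, so their optimizer-state
    # keys carry the FP32_MASTER infix; FP32 params are keyed directly.
    if not is_fp32_param:
        return "_".join([static_name, FP32_MASTER, suffix])
    return "_".join([static_name, suffix])

assert optimizer_key("linear_0.w_0", "moment1_0", False) == "linear_0.w_0_fp32_master_0_moment1_0"
assert optimizer_key("linear_0.w_0", "moment1_0", True) == "linear_0.w_0_moment1_0"
```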
need_files = set()
state_dict = get_expected_state_dict(model)
for key in state_dict.keys():
    filename = index["weight_map"][key]
    # When using expert parallel, there's no need to check tensors with `no_sync=False` when dp_rank > 0.
    if args.use_expert_parallel and dp_rank > 0 and not getattr(state_dict[key], "no_sync", False):
Is this skipping the `no_sync` parameters here?
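For context, a hedged sketch of the skip rule the diff implements (attribute and flag names follow the excerpt; the surrounding loop body is assumed):

```python
# Sketch, assuming `no_sync=True` marks expert parameters that live only
# on their own data-parallel rank under expert parallel.
def keys_needed_on_this_rank(state_dict, use_expert_parallel, dp_rank):
    needed = []
    for key, tensor in state_dict.items():
        # dp_rank 0 already covers all shared tensors, so ranks > 0 only
        # keep the expert-local ones (no_sync=True) and skip the rest.
        if use_expert_parallel and dp_rank > 0 and not getattr(tensor, "no_sync", False):
            continue
        needed.append(key)
    return needed
```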
@@ -962,6 +1015,7 @@ def save_single_card_optimizer(args, model, optimizer, output_dir):
     if master_weights is not None:
         for key in list(master_weights.keys()):
             master_weights[static2struct_name_mappings[key]] = master_weights.pop(key)
+        master_weights.update(fp32_weight)
Will this be popped back out when loading?
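As a minimal illustration of the save-side merge (dummy values; whether the load path pops these entries back out is exactly the reviewer's question and is not shown in this excerpt):

```python
# Dummy stand-ins; the real values are fp32 tensors.
master_weights = {"linear_0.w_0": "fp32 master of a bf16 param"}
fp32_weight = {"layer_norm_0.w_0": "param trained directly in fp32"}

# The diff folds fp32 params into master_weights so both are saved together.
master_weights.update(fp32_weight)
assert set(master_weights) == {"linear_0.w_0", "layer_norm_0.w_0"}
```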
else:
    shard_file = file_name.replace(
        ".pdparams",
        f"-{args.logical_process_index + 1:05d}-of-{args.world_size//args.dataset_world_size:05d}.pdparams",
Is the shard-index numbering issue mentioned earlier still present?
    )
    shard_file = shard_file.replace(
        ".safetensors",
        f"-{args.logical_process_index + 1:05d}-of-{args.world_size//sd_degree:05d}.safetensors",
This is too long; please merge the two calls and simplify the code.
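One possible consolidation, as a sketch rather than the PR's final code (`add_shard_suffix` is a hypothetical helper; the shard count per format follows the excerpts above):

```python
def add_shard_suffix(file_name: str, index: int, num_shards: int) -> str:
    """E.g. model.safetensors -> model-00001-of-00008.safetensors."""
    for ext in (".pdparams", ".safetensors"):
        if file_name.endswith(ext):
            stem = file_name[: -len(ext)]
            return f"{stem}-{index + 1:05d}-of-{num_shards:05d}{ext}"
    return file_name

# Usage mirroring the branches above, with the format-specific shard count:
# shard_file = add_shard_suffix(file_name, args.logical_process_index,
#                               args.world_size // sd_degree)
```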
This Pull Request is stale because it has been open for 60 days with no activity.
PR types: New features
PR changes: Others
Description: Support unified checkpoint for expert_parallel.