Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Trainer] add dp_group and exclude_layer params #4930

Merged
merged 4 commits into from
Feb 23, 2023
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 24 additions & 4 deletions paddlenlp/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,13 +671,15 @@ def train(
steps_in_epoch <= args.gradient_accumulation_steps
and (step + 1) == steps_in_epoch
):

# Maunally collect gradients
# Maunally collect gradients when group_sharded_parallel can't accepts_dp_group
# Case 1: Use sharding stage 2/3 with dp
# Case 2: Use recompute and dp
# local_rank != -1 don't means dp in networks.
if self.sharding and ShardingOption.SHARD_OP not in self.args.sharding:
if self.args.dp_degree > 1:
accepts_dp_group = "dp_group" in set(
inspect.signature(paddle.distributed.sharding.group_sharded_parallel).parameters.keys()
)
if self.args.dp_degree > 1 and not accepts_dp_group:
fused_allreduce_gradients(model.parameters(), fleet.get_hybrid_communicate_group())
if ShardingOption.FULL_SHARD in self.args.sharding:
# Why need sync on parm again ?
Expand Down Expand Up @@ -1220,8 +1222,26 @@ def _wrap_model(self, model, training=True):

from paddle.distributed.sharding import group_sharded_parallel

# add dp_group and exclude_layer params
# https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/distributed/sharding/group_sharded_parallel_cn.html#group-sharded-parallel
accepts_dp_group = "dp_group" in set(inspect.signature(group_sharded_parallel).parameters.keys())
accepts_exclude_layer = "exclude_layer" in set(
inspect.signature(group_sharded_parallel).parameters.keys()
)
extra_kwargs = {}
if accepts_dp_group:
extra_kwargs["dp_group"] = self.dp_group
if accepts_exclude_layer:
extra_kwargs["exclude_layer"] = ["GroupNorm"]

model, optimizer, _ = group_sharded_parallel(
model, self.optimizer, level=level, scaler=None, group=self.sharding_group, offload=cpu_offload
JunnYu marked this conversation as resolved.
Show resolved Hide resolved
model,
self.optimizer,
level=level,
scaler=None,
group=self.sharding_group,
offload=cpu_offload,
**extra_kwargs,
)
self.optimizer = optimizer

Expand Down