Commit

update
iosmers committed May 22, 2024
1 parent 5e554ce commit f1d68f6
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions paddlenlp/trainer/trainer.py
@@ -263,6 +263,7 @@ def __init__(
             args = TrainingArguments(output_dir=output_dir)
 
         self.args = args
+        print("self.args:", self.args)
         self.is_in_train = False
         # self.do_grad_scaling = args.fp16

Codecov (codecov/patch) warning: added line #L266 in paddlenlp/trainer/trainer.py was not covered by tests.

@@ -1907,11 +1908,20 @@ def get_expected_keys(inputs, keys):
                     self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
                 self.optimizer = fleet.distributed_optimizer(self.optimizer)
 
-            if in_sharding_parallel_mode:
+            # stage1 has v1 and v2 version
+            if in_sharding_parallel_mode and ShardingOption.SHARD_OP in self.args.sharding:
                 if "split_param" in self.args.sharding_parallel_config:
-                    self.optimizer._set_all_gather_overlap_forward(True, model)
+                    if (
+                        hasattr(self.optimizer, "_set_all_gather_overlap_forward")
+                        and "enable_stage1_allgather_overlap" in self.args.sharding_parallel_config
+                    ):
+                        self.optimizer._set_all_gather_overlap_forward(True, model)
                 else:
-                    self.optimizer._set_broadcast_overlap(True, model)
+                    if (
+                        hasattr(self.optimizer, "_set_broadcast_overlap")
+                        and "enable_stage1_broadcast_overlap" in self.args.sharding_parallel_config
+                    ):
+                        self.optimizer._set_broadcast_overlap(True, model)
 
         return model

Codecov (codecov/patch) warnings: added lines #L1912-L1914, #L1918, #L1920, and #L1924 in paddlenlp/trainer/trainer.py were not covered by tests.
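
The second hunk stops calling the overlap setters unconditionally. Each call now requires that the optimizer actually exposes the method and that the matching flag appears in `sharding_parallel_config`. Below is a minimal, self-contained sketch of that gating pattern. `FakeOptimizer` and `configure_stage1_overlap` are hypothetical names introduced only for illustration and are not part of PaddleNLP, and the config is modeled as a set of flag strings, which is an assumption about its shape.

```python
class FakeOptimizer:
    """Hypothetical stand-in that records which overlap setter was invoked."""

    def __init__(self):
        self.calls = []

    def _set_all_gather_overlap_forward(self, enabled, model):
        self.calls.append(("all_gather", enabled))

    def _set_broadcast_overlap(self, enabled, model):
        self.calls.append(("broadcast", enabled))


def configure_stage1_overlap(optimizer, sharding_parallel_config, model=None):
    """Call a setter only if the optimizer has it and its flag is present.

    sharding_parallel_config is assumed to be a set of flag strings.
    """
    if "split_param" in sharding_parallel_config:
        if (
            hasattr(optimizer, "_set_all_gather_overlap_forward")
            and "enable_stage1_allgather_overlap" in sharding_parallel_config
        ):
            optimizer._set_all_gather_overlap_forward(True, model)
    else:
        if (
            hasattr(optimizer, "_set_broadcast_overlap")
            and "enable_stage1_broadcast_overlap" in sharding_parallel_config
        ):
            optimizer._set_broadcast_overlap(True, model)
    return optimizer
```

A small test along these lines, covering both branches and the case where the flag is absent, might be one way to address the uncovered lines reported by Codecov.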
