diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py
index 28495b6f622c..a3200b5142f0 100644
--- a/paddlenlp/trainer/trainer.py
+++ b/paddlenlp/trainer/trainer.py
@@ -263,6 +263,7 @@ def __init__(
             args = TrainingArguments(output_dir=output_dir)
 
         self.args = args
+        print("self.args:", self.args)
         self.is_in_train = False
 
         # self.do_grad_scaling = args.fp16
@@ -1907,11 +1908,20 @@ def get_expected_keys(inputs, keys):
                 self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
             self.optimizer = fleet.distributed_optimizer(self.optimizer)
 
-            if in_sharding_parallel_mode:
+            # stage1 has v1 and v2 version
+            if in_sharding_parallel_mode and ShardingOption.SHARD_OP in self.args.sharding:
                 if "split_param" in self.args.sharding_parallel_config:
-                    self.optimizer._set_all_gather_overlap_forward(True, model)
+                    if (
+                        hasattr(self.optimizer, "_set_all_gather_overlap_forward")
+                        and "enable_stage1_allgather_overlap" in self.args.sharding_parallel_config
+                    ):
+                        self.optimizer._set_all_gather_overlap_forward(True, model)
                 else:
-                    self.optimizer._set_broadcast_overlap(True, model)
+                    if (
+                        hasattr(self.optimizer, "_set_broadcast_overlap")
+                        and "enable_stage1_broadcast_overlap" in self.args.sharding_parallel_config
+                    ):
+                        self.optimizer._set_broadcast_overlap(True, model)
 
         return model
 
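
For context, a minimal usage sketch of what the new guards check for: the overlap setup now runs only under sharding stage1 (ShardingOption.SHARD_OP), and only when the optimizer exposes the corresponding _set_* method and the matching option string appears in sharding_parallel_config. The argument values below are illustrative, not part of this diff.

    from paddlenlp.trainer import TrainingArguments

    # Illustrative stage1 configuration (assumed values, not taken from this PR):
    # no "split_param" in sharding_parallel_config, so the broadcast-overlap
    # branch applies, and "enable_stage1_broadcast_overlap" is what the new
    # guard looks for before calling optimizer._set_broadcast_overlap(True, model).
    args = TrainingArguments(
        output_dir="./checkpoints",
        sharding="stage1",
        sharding_parallel_degree=8,
        sharding_parallel_config="enable_stage1_broadcast_overlap",
    )

With the added hasattr checks, older Paddle builds whose sharding optimizer lacks these methods simply skip the overlap setup instead of raising an AttributeError.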