diff --git a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
index d0b5c915e11cd..5e2ad43c16431 100644
--- a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
+++ b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
@@ -140,17 +140,12 @@ def broadcast_dp_parameters(model, hcg):
 
 
 def fused_allreduce_gradients(parameter_list, hcg):
-    if _in_legacy_dygraph():
-        data_parallel_group = None if hcg is None else hcg.get_data_parallel_group(
-        )
-        logger.debug("dp start fuse allreduce gradients")
-        with framework.no_grad():
-            _apply_collective_grads(parameter_list, data_parallel_group)
-    elif in_dygraph_mode():
-        assert hcg is None, "It's not support to use hcg in EagerDygraph now."
-        data_parallel_group = paddle.distributed.collective._get_default_group()
-        with framework.no_grad():
-            _apply_collective_grads_eager(parameter_list, data_parallel_group)
+    data_parallel_group = None if hcg is None else hcg.get_data_parallel_group()
+    logger.debug("dp start fuse allreduce gradients")
+    apply_func = _apply_collective_grads_eager if in_dygraph_mode(
+    ) else _apply_collective_grads
+    with framework.no_grad():
+        apply_func(parameter_list, data_parallel_group)
 
 
 def sharding_reduce_gradients(parameter_list, hcg):
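For orientation, a minimal usage sketch of the function this hunk touches, not part of the patch: after the change, fused_allreduce_gradients derives the data-parallel group from hcg the same way in both modes and only selects between _apply_collective_grads_eager and _apply_collective_grads via in_dygraph_mode(), so an eager-mode caller is no longer blocked by the removed assertion when passing an hcg. The launch setup, fleet.init call, and the toy Linear model below are illustrative assumptions about a typical caller, not code from this repository.

# Usage sketch (illustrative). Assumes the script is launched with
# `python -m paddle.distributed.launch` so fleet.init() can set up the
# default process group; the model and tensor shapes are placeholders.
import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.utils.hybrid_parallel_util import (
    fused_allreduce_gradients,
)

fleet.init(is_collective=True)

model = paddle.nn.Linear(16, 16)
loss = model(paddle.randn([4, 16])).mean()
loss.backward()

# With hcg=None the group argument stays None, which the collective helpers
# are expected to treat as the default group (the pre-patch eager branch
# fetched it explicitly via _get_default_group()). Passing a
# HybridCommunicateGroup instead uses hcg.get_data_parallel_group().
fused_allreduce_gradients(list(model.parameters()), hcg=None)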