diff --git a/llm/run_finetune.py b/llm/run_finetune.py
index 36e7b221729b..914f79ec939a 100644
--- a/llm/run_finetune.py
+++ b/llm/run_finetune.py
@@ -267,6 +267,7 @@ def main():
     model_config.using_fake_gate = model_args.using_fake_gate
     model_config.moe_subbatch_token_num = model_args.moe_subbatch_token_num
     model_config.aux_loss_alpha = model_args.aux_loss_alpha
+    model_config.gradient_accumulation_steps = training_args.gradient_accumulation_steps
     logger.info(f"Final model config: {model_config}")
 
     logger.info("Creating model")
diff --git a/paddlenlp/transformers/deepseek_v2/modeling.py b/paddlenlp/transformers/deepseek_v2/modeling.py
index b301ff5da008..375e4a3c8885 100644
--- a/paddlenlp/transformers/deepseek_v2/modeling.py
+++ b/paddlenlp/transformers/deepseek_v2/modeling.py
@@ -1002,6 +1002,26 @@ def linear_dtype_gaurd():
 
         # fmt: on
 
+        if self.config.tensor_parallel_degree > 1 and self.config.sequence_parallel:
+            def grad_allreduce_hook(param, accumulation_steps):
+                hcg = fleet.get_hybrid_communicate_group()
+                pg = hcg.get_model_parallel_group().process_group
+                step = [0]
+
+                @paddle.autograd.no_grad()
+                def __impl__():
+                    step[0] += 1
+                    if (step[0] % accumulation_steps) == 0:
+                        if hasattr(param, "main_grad"):
+                            pg.allreduce(param.main_grad).wait()
+                        else:
+                            pg.allreduce(param.grad).wait()
+
+                return __impl__
+
+            # kv_a_proj_with_mqa and q_a_proj grads need to be all-reduced across the mp group
+            self.kv_a_proj_with_mqa.weight._register_backward_hook(grad_allreduce_hook(self.kv_a_proj_with_mqa.weight, accumulation_steps=config.gradient_accumulation_steps))
+            self.q_a_proj.weight._register_backward_hook(grad_allreduce_hook(self.q_a_proj.weight, accumulation_steps=config.gradient_accumulation_steps))
+
         self._init_rope()
 
         self.softmax_scale = self.q_head_dim ** (-0.5)
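
Illustrative note (not part of the patch): grad_allreduce_hook keeps a per-parameter step counter in a one-element list so the closure can mutate it across backward passes, and it triggers the cross-mp all-reduce only on the last micro-batch of each accumulation cycle. A minimal framework-free sketch of that counting pattern is below; make_hook and fake_allreduce are hypothetical stand-ins for grad_allreduce_hook and pg.allreduce(...).wait(), used here purely to show when the hook fires.

def make_hook(accumulation_steps, fake_allreduce):
    # one-element list: a mutable cell the closure can update across calls
    step = [0]

    def hook():
        step[0] += 1
        # fire only on the final micro-batch of each accumulation cycle
        if step[0] % accumulation_steps == 0:
            fake_allreduce()

    return hook


calls = []
hook = make_hook(accumulation_steps=4, fake_allreduce=lambda: calls.append("allreduce"))
for _ in range(8):
    hook()      # simulates one backward pass per micro-batch
print(calls)    # ['allreduce', 'allreduce'] -> fired after steps 4 and 8 only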