Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions llm/run_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ def main():
model_config.using_fake_gate = model_args.using_fake_gate
model_config.moe_subbatch_token_num = model_args.moe_subbatch_token_num
model_config.aux_loss_alpha = model_args.aux_loss_alpha
model_config.gradient_accumulation_steps = training_args.gradient_accumulation_steps
logger.info(f"Final model config: {model_config}")

logger.info("Creating model")
Expand Down
20 changes: 20 additions & 0 deletions paddlenlp/transformers/deepseek_v2/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1002,6 +1002,26 @@ def linear_dtype_gaurd():

# fmt: on

if self.config.tensor_parallel_degree > 1 and self.config.sequence_parallel:
def grad_allreduce_hook(param, accumulation_steps):
hcg = fleet.get_hybrid_communicate_group()
pg = hcg.get_model_parallel_group().process_group
step = [0]

@paddle.autograd.no_grad()
def __impl__():
step[0] += 1
if (step[0] % accumulation_steps) == 0:
if hasattr(param, "main_grad"):
pg.allreduce(param.main_grad).wait()
else:
pg.allreduce(param.grad).wait()

return __impl__
# kv_a_proj_with_mqa and q_a_proj grad need to be reduce between mp
self.kv_a_proj_with_mqa.weight._register_backward_hook(grad_allreduce_hook(self.kv_a_proj_with_mqa.weight, accumulation_steps=config.gradient_accumulation_steps))
self.q_a_proj.weight._register_backward_hook(grad_allreduce_hook(self.q_a_proj.weight, accumulation_steps=config.gradient_accumulation_steps))

self._init_rope()

self.softmax_scale = self.q_head_dim ** (-0.5)
Expand Down
Loading