diff --git a/llm/config/deepseek-v3/sft_argument.json b/llm/config/deepseek-v3/sft_argument.json
index edc8452fe09e..7e713dd4b355 100644
--- a/llm/config/deepseek-v3/sft_argument.json
+++ b/llm/config/deepseek-v3/sft_argument.json
@@ -8,7 +8,7 @@
     "per_device_eval_batch_size": 1,
     "eval_accumulation_steps": 1,
     "max_steps": 100,
-    "max_grad_norm": 0,
+    "max_grad_norm": 1.0,
     "amp_master_grad": true,
     "num_train_epochs": 1,
     "learning_rate": 2.2e-05,
diff --git a/llm/run_finetune.py b/llm/run_finetune.py
index de40634ef379..36e7b221729b 100644
--- a/llm/run_finetune.py
+++ b/llm/run_finetune.py
@@ -46,7 +46,7 @@
     ReFTModel,
     intervention_mapping,
 )
-from paddlenlp.trainer import PdArgumentParser, get_last_checkpoint, set_seed, MoECorrectionBiasAdjustCallback
+from paddlenlp.trainer import PdArgumentParser, get_last_checkpoint, set_seed, MoECorrectionBiasAdjustCallback, MoeExpertsGradScaleCallback
 from paddlenlp.trainer.trainer_callback import TrainerState
 from paddlenlp.transformers import (
     AutoConfig,
@@ -474,7 +474,11 @@ def compute_metrics_do_generation(eval_preds):
         callbacks += [ZeroPaddingIterDatasetCallback()]

     if getattr(model_config, "topk_method", None) == "noaux_tc":
-        callbacks += [MoECorrectionBiasAdjustCallback(lr=0)]
+        # deepseek_v3 fine-tuning does not update the bias, so set lr to 0.0
+        callbacks += [MoECorrectionBiasAdjustCallback(lr=0.0)]
+
+    if training_args.use_expert_parallel:
+        callbacks += [MoeExpertsGradScaleCallback(training_args)]

     print("callbacks:", callbacks, flush=True)
     trainer = SFTTrainer(
diff --git a/paddlenlp/optimizers/moe_hybrid_parallel_optimizer.py b/paddlenlp/optimizers/moe_hybrid_parallel_optimizer.py
index b11263fe90e6..479aa9a44c05 100644
--- a/paddlenlp/optimizers/moe_hybrid_parallel_optimizer.py
+++ b/paddlenlp/optimizers/moe_hybrid_parallel_optimizer.py
@@ -179,7 +179,7 @@ def _dygraph_clip(self, params_grads):
             )

             is_moe_param = getattr(p, "is_moe_param", False)
-            print(f"p.name:{p.name}, is_moe_param:{is_moe_param}")
+
             if is_moe_param:
                 assert 0
             if not_shared_enable:
@@ -350,13 +350,6 @@ def __getattr__(self, item):
         return getattr(self._clip, item)

     def __call__(self, params_grads):
-        print("==== zyc debug in moe_hybrid_parallel_optimizer.py ====")
-        for p, g in params_grads:
-            has_moe_attr = hasattr(p, "is_moe_param")
-            is_moe_param = False
-            if has_moe_attr:
-                is_moe_param = p.is_moe_param
-            print(f"p.name:{p.name}, has_moe_attr:{has_moe_attr}, is_moe_param:{is_moe_param}")
         return self._dygraph_clip(params_grads)
diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py
index db23594b7bc1..c44eb5bf23c7 100644
--- a/paddlenlp/trainer/trainer.py
+++ b/paddlenlp/trainer/trainer.py
@@ -1323,21 +1323,6 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
                             elif p.grad is not None:
                                 p.grad.scale_(1.0 / self.args.gradient_accumulation_steps)

-                if os.environ.get("FIX_EP_GRAD", None):
-                    param_count = 0
-                    for p in model._layers.parameters():
-                        if hasattr(p, "is_moe_param") and p.is_moe_param:
-                            with paddle.no_grad():
-                                if hasattr(p, "main_grad") and p.main_grad is not None:
-                                    # print("main grad scale 1/ep")
-                                    p.main_grad.scale_(1.0 / self.args.expert_parallel_degree)
-                                    param_count += 1
-                                elif p.grad is not None:
-                                    # print("grad scale 1/ep")
-                                    p.grad.scale_(1.0 / self.args.expert_parallel_degree)
-                                    param_count += 1
-                    print("fix ep grad count:{}".format(param_count), flush=True)
-
                 # Optimizer step
                 self.callback_handler.on_optimizer_begin(
                     args, self.state, self.control, scaler=self.scaler if self.do_grad_scaling else None
diff --git a/paddlenlp/trainer/trainer_callback.py b/paddlenlp/trainer/trainer_callback.py
index 8f9beaab813e..e4394bfde2ea 100644
--- a/paddlenlp/trainer/trainer_callback.py
+++ b/paddlenlp/trainer/trainer_callback.py
@@ -46,6 +46,7 @@
     "PrinterCallback",
     "EarlyStoppingCallback",
     "MoECorrectionBiasAdjustCallback",
+    "MoeExpertsGradScaleCallback",
 ]


@@ -673,4 +674,43 @@ def update_bias(layer):
                 biases.pop(0).add_(update_list.pop(0))
                 usages.pop(0).zero_()

-        model.apply(update_bias)
\ No newline at end of file
+        model.apply(update_bias)
+
+
+class MoeExpertsGradScaleCallback(TrainerCallback):
+    """
+    This hook corrects expert-parameter gradients that are amplified N times under expert parallelism.
+    """
+
+    def __init__(self, args):
+        """Derive the expert gradient scaling factor from the parallelism degrees.
+
+        Args:
+            args (TrainingArguments): training arguments providing the parallelism degrees.
+        """
+        if not args.use_expert_parallel:
+            raise ValueError("This callback should be used with expert parallel")
+        self.expert_gradient_scaling_factor = 1.0
+        if args.expert_parallel_degree > 1:
+            self.expert_gradient_scaling_factor = 1.0 / args.expert_parallel_degree
+        if args.tensor_parallel_degree > 1:
+            self.expert_gradient_scaling_factor *= args.tensor_parallel_degree
+        logger.info(
+            f"EP-MoE is used, expert gradient scaling factor is set to {self.expert_gradient_scaling_factor}"
+        )
+
+    def on_optimizer_begin(self, args, state, control, **kwargs):
+        model = kwargs["model"]
+        param_count = 0
+        for p in model.parameters():
+            if not getattr(p, "no_sync", False):
+                continue
+            if hasattr(p, "is_moe_param") and p.is_moe_param:
+                with paddle.no_grad():
+                    if hasattr(p, "main_grad") and p.main_grad is not None:
+                        p.main_grad.scale_(self.expert_gradient_scaling_factor)
+                        param_count += 1
+                    elif p.grad is not None:
+                        p.grad.scale_(self.expert_gradient_scaling_factor)
+                        param_count += 1
+        logger.info("correct ep grad count:{}".format(param_count))
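For reference, here is a minimal standalone sketch (not part of the patch) of the scaling-factor arithmetic that MoeExpertsGradScaleCallback applies to expert gradients in on_optimizer_begin. Args is a hypothetical stand-in for the real training_args object; its field names mirror the ones the callback reads above. In run_finetune.py the callback is registered with callbacks += [MoeExpertsGradScaleCallback(training_args)] whenever training_args.use_expert_parallel is set.

from dataclasses import dataclass


@dataclass
class Args:
    # Hypothetical stand-in for TrainingArguments; only the fields the callback reads.
    use_expert_parallel: bool = True
    expert_parallel_degree: int = 1
    tensor_parallel_degree: int = 1


def expert_grad_scale(args: Args) -> float:
    """Mirror of the factor computed in MoeExpertsGradScaleCallback.__init__."""
    if not args.use_expert_parallel:
        raise ValueError("This callback should be used with expert parallel")
    factor = 1.0
    if args.expert_parallel_degree > 1:
        # Undo the N-fold amplification of expert gradients, N = expert_parallel_degree.
        factor = 1.0 / args.expert_parallel_degree
    if args.tensor_parallel_degree > 1:
        # The callback additionally multiplies by tensor_parallel_degree when TP is enabled.
        factor *= args.tensor_parallel_degree
    return factor


if __name__ == "__main__":
    print(expert_grad_scale(Args(expert_parallel_degree=8)))                             # 0.125
    print(expert_grad_scale(Args(expert_parallel_degree=8, tensor_parallel_degree=2)))   # 0.25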