diff --git a/paddlenlp/generation/utils.py b/paddlenlp/generation/utils.py index ffd34b1d79cd..e955331ec7e3 100644 --- a/paddlenlp/generation/utils.py +++ b/paddlenlp/generation/utils.py @@ -510,8 +510,9 @@ def update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder @staticmethod def update_scores_for_generation(scores, next_scores, length, unfinished_flag): # update scores + # breakpoint() - unfinished_scores = (scores * length + next_scores) / (length + 1) + unfinished_scores = (scores * paddle.cast(length, scores.dtype) + next_scores) / paddle.cast((length + 1), scores.dtype) scores = paddle.where(unfinished_flag, unfinished_scores, scores) return scores