From 02b9dcdc694a5120054fa2228477504daf7fe515 Mon Sep 17 00:00:00 2001 From: cyber-pinoeer Date: Thu, 9 May 2024 11:28:03 +0000 Subject: [PATCH] fix auto cast bug --- paddlenlp/generation/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/paddlenlp/generation/utils.py b/paddlenlp/generation/utils.py index ffd34b1d79cd..e955331ec7e3 100644 --- a/paddlenlp/generation/utils.py +++ b/paddlenlp/generation/utils.py @@ -510,8 +510,9 @@ def update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder @staticmethod def update_scores_for_generation(scores, next_scores, length, unfinished_flag): # update scores + # breakpoint() - unfinished_scores = (scores * length + next_scores) / (length + 1) + unfinished_scores = (scores * paddle.cast(length, scores.dtype) + next_scores) / paddle.cast((length + 1), scores.dtype) scores = paddle.where(unfinished_flag, unfinished_scores, scores) return scores