Skip to content

Commit

Permalink
fix auto cast bug
Browse files Browse the repository at this point in the history
  • Loading branch information
cyber-pioneer committed May 9, 2024
1 parent 2619f17 commit c0ca283
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion paddlenlp/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ def update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder
def update_scores_for_generation(scores, next_scores, length, unfinished_flag):
# update scores

unfinished_scores = (scores * length + next_scores) / (length + 1)
unfinished_scores = (scores * paddle.cast(length, scores.dtype) + next_scores) / paddle.cast((length + 1), scores.dtype)
scores = paddle.where(unfinished_flag, unfinished_scores, scores)
return scores

Expand Down

0 comments on commit c0ca283

Please sign in to comment.