Skip to content

Commit

Permalink
fix autocast bug
Browse files Browse the repository at this point in the history
  • Loading branch information
cyber-pioneer committed May 9, 2024
1 parent 2619f17 commit c22bc8b
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion paddlenlp/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ def update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder
def update_scores_for_generation(scores, next_scores, length, unfinished_flag):
# update scores

unfinished_scores = (scores * length + next_scores) / (length + 1)
unfinished_scores = (scores * length.astype(scores.dtype) + next_scores) / (length + 1).astype(scores.dtype)
scores = paddle.where(unfinished_flag, unfinished_scores, scores)
return scores

Expand Down

0 comments on commit c22bc8b

Please sign in to comment.