diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py
index 8f2dd1c3641..78fc26d08db 100755
--- a/paddlenlp/transformers/llama/modeling.py
+++ b/paddlenlp/transformers/llama/modeling.py
@@ -1654,8 +1654,11 @@ def forward(self, prediction_scores, masked_lm_labels):
         binary_sequence = paddle.where(
             masked_lm_loss > 0, paddle.ones_like(masked_lm_loss), paddle.zeros_like(masked_lm_loss)
         )
-        sum_ = paddle.sum(binary_sequence)
-        loss = 0 if sum_ == 0 else paddle.sum(masked_lm_loss * binary_sequence) / sum_
+        count = paddle.sum(binary_sequence)
+        if count == 0:
+            loss = paddle.sum(masked_lm_loss * binary_sequence)
+        else:
+            loss = paddle.sum(masked_lm_loss * binary_sequence) / count
         return loss
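
A minimal standalone sketch of the behaviour this hunk introduces, assuming a hypothetical helper named masked_mean_loss (not part of the patch): the loss is averaged only over tokens that contributed a positive loss, and when every token is masked out the result is a zero tensor (the masked sum) rather than a Python int, so no division by zero occurs and the value can still flow through the rest of the training step.

# Sketch with an assumed helper name; mirrors the patched masked-mean loss.
import paddle

def masked_mean_loss(masked_lm_loss):
    # 1.0 for tokens that produced a positive loss, 0.0 for masked-out tokens
    binary_sequence = paddle.where(
        masked_lm_loss > 0, paddle.ones_like(masked_lm_loss), paddle.zeros_like(masked_lm_loss)
    )
    count = paddle.sum(binary_sequence)
    if count == 0:
        # All tokens masked: the masked sum is a zero tensor, which avoids
        # dividing by zero while keeping the result a paddle.Tensor.
        return paddle.sum(masked_lm_loss * binary_sequence)
    return paddle.sum(masked_lm_loss * binary_sequence) / count

# Example (made-up values): only the two non-zero entries count toward the average.
print(masked_mean_loss(paddle.to_tensor([0.0, 2.0, 4.0])))  # Tensor(3.0)
print(masked_mean_loss(paddle.to_tensor([0.0, 0.0])))       # Tensor(0.0), not a Python int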