From 79bbc2dca9c7a59dcb7d8280253af96912654b95 Mon Sep 17 00:00:00 2001 From: wangna11BD <79366697+wangna11BD@users.noreply.github.com> Date: Wed, 22 May 2024 16:08:26 +0800 Subject: [PATCH] fix cross_entropy speed (#64211) --- python/paddle/nn/functional/loss.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index d15495993ce0e..fee7ae407d001 100644 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -2921,13 +2921,12 @@ def cross_entropy( # 2. else # numerator: loss's weighted sum # denominator: cal the sum of weight where the sample's class_index!=ignore_index - is_ignore = label == ignore_index - mask = ~is_ignore - if paddle.count_nonzero(is_ignore) > 0: # ignore label + if ignore_index >= 0: # ignore label out_sum = _C_ops.sum(out, [], None, False) # for each label[i],set 1 or 0, according to ignore_index # mask[i]=0, if label[i]==ignore_index # mask[i]=1, otherwise + mask = label != ignore_index if weight is None: mask = paddle.cast(mask, dtype=out_sum.dtype) count = _C_ops.sum(mask, [], None, False)