diff --git a/src/torchmetrics/functional/classification/stat_scores.py b/src/torchmetrics/functional/classification/stat_scores.py
index aa8e0bf5016..374c8388595 100644
--- a/src/torchmetrics/functional/classification/stat_scores.py
+++ b/src/torchmetrics/functional/classification/stat_scores.py
@@ -408,6 +408,8 @@ def _multiclass_stat_scores_update(
     fp = confmat.sum(0) - tp
     fn = confmat.sum(1) - tp
     tn = confmat.sum() - (fp + fn + tp)
+    if ignore_index is not None:
+        fp[ignore_index] = 0
     return tp, fp, tn, fn


diff --git a/tests/unittests/classification/test_accuracy.py b/tests/unittests/classification/test_accuracy.py
index db497cdb197..61f3c8552fa 100644
--- a/tests/unittests/classification/test_accuracy.py
+++ b/tests/unittests/classification/test_accuracy.py
@@ -190,12 +190,9 @@ def _reference_sklearn_accuracy_multiclass(preds, target, ignore_index, multidim
         return _reference_sklearn_accuracy(target, preds)
     confmat = sk_confusion_matrix(target, preds, labels=list(range(NUM_CLASSES)))
     acc_per_class = confmat.diagonal() / confmat.sum(axis=1)
-    acc_per_class[np.isnan(acc_per_class)] = 0.0
     if average == "macro":
-        acc_per_class = acc_per_class[
-            (np.bincount(preds, minlength=NUM_CLASSES) + np.bincount(target, minlength=NUM_CLASSES)) != 0.0
-        ]
-        return acc_per_class.mean()
+        return np.nanmean(acc_per_class)
+    acc_per_class[np.isnan(acc_per_class)] = 0.0
     if average == "weighted":
         weights = confmat.sum(1)
         return ((weights * acc_per_class) / weights.sum()).sum()
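The test-side change is easiest to see with a toy confusion matrix. What follows is a minimal sketch, not part of the patch: the labels are invented, and `NUM_CLASSES` only mirrors the constant used in the torchmetrics test suite. It contrasts the old reference (which zeroed NaN per-class accuracies and only dropped classes absent from *both* `preds` and `target`) with the new `np.nanmean` reference, which skips any class whose per-class accuracy is undefined. A class that never occurs in `target` but does get predicted is exactly where the two disagree.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

NUM_CLASSES = 3
target = np.array([0, 0, 1, 1])  # class 2 never occurs as a true label
preds = np.array([0, 2, 1, 1])   # but it does get predicted once

confmat = confusion_matrix(target, preds, labels=list(range(NUM_CLASSES)))
with np.errstate(invalid="ignore"):
    # row for class 2 is 0/0 -> NaN
    acc_per_class = confmat.diagonal() / confmat.sum(axis=1)  # [0.5, 1.0, nan]

# Old reference: zero the NaN, then keep any class seen in preds OR target,
# so class 2 contributes a hard 0.0 to the macro average.
old = acc_per_class.copy()
old[np.isnan(old)] = 0.0
seen = (np.bincount(preds, minlength=NUM_CLASSES)
        + np.bincount(target, minlength=NUM_CLASSES)) != 0
print(old[seen].mean())           # 0.5

# New reference: classes with undefined per-class accuracy are simply skipped.
print(np.nanmean(acc_per_class))  # 0.75
```

This appears to be the same situation the stat-scores change targets: when `ignore_index` is set, ignored samples are stripped from `target` but the ignored label can still show up in `preds`, so zeroing `fp[ignore_index]` stops that class from registering spurious false positives, and the `np.nanmean` reference stops it from dragging the macro average toward zero.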