diff --git a/CHANGELOG.md b/CHANGELOG.md
index 661ed4d6be5..cc59c72b828 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -77,6 +77,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Fixed
 
+- Fixed bug where classification metrics with `average='macro'` would lead to wrong results if a class was missing ([#303](https://github.com/PyTorchLightning/metrics/pull/303))
+
+
 - Fixed `weighted`, `multi-class` AUROC computation to allow for 0 observations of some class, as contribution to final AUROC is 0 ([#348](https://github.com/PyTorchLightning/metrics/issues/348))
 
 
@@ -85,7 +88,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Fixed calculation in `IoU` metric when using `ignore_index` argument ([#328](https://github.com/PyTorchLightning/metrics/pull/328))
 
-
 ## [0.4.1] - 2021-07-05
 
 ### Changed
diff --git a/tests/classification/inputs.py b/tests/classification/inputs.py
index 96c746cdce5..a0683dd4235 100644
--- a/tests/classification/inputs.py
+++ b/tests/classification/inputs.py
@@ -116,3 +116,10 @@ def generate_plausible_inputs_binary(num_batches=NUM_BATCHES, batch_size=BATCH_S
 
 _input_multilabel_prob_plausible = generate_plausible_inputs_multilabel()
 _input_binary_prob_plausible = generate_plausible_inputs_binary()
+
+# randomly remove one class from the input
+_temp = torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE))
+_class_remove, _class_replace = torch.multinomial(torch.ones(NUM_CLASSES), num_samples=2, replacement=False)
+_temp[_temp == _class_remove] = _class_replace
+
+_input_multiclass_with_missing_class = Input(_temp.clone(), _temp.clone())
diff --git a/tests/classification/test_accuracy.py b/tests/classification/test_accuracy.py
index d3431001ce1..b07fb742db4 100644
--- a/tests/classification/test_accuracy.py
+++ b/tests/classification/test_accuracy.py
@@ -23,6 +23,7 @@
 from tests.classification.inputs import _input_multiclass as _input_mcls
 from tests.classification.inputs import _input_multiclass_logits as _input_mcls_logits
 from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
+from tests.classification.inputs import _input_multiclass_with_missing_class as _input_miss_class
 from tests.classification.inputs import _input_multidim_multiclass as _input_mdmc
 from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
 from tests.classification.inputs import _input_multilabel as _input_mlb
@@ -31,7 +32,7 @@
 from tests.classification.inputs import _input_multilabel_multidim_prob as _input_mlmd_prob
 from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
 from tests.helpers import seed_all
-from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester
+from tests.helpers.testers import NUM_BATCHES, NUM_CLASSES, THRESHOLD, MetricTester
 from torchmetrics import Accuracy
 from torchmetrics.functional import accuracy
 from torchmetrics.utilities.checks import _input_format_classification
@@ -342,3 +343,21 @@ def test_class_not_present(metric_class, metric_fn, ignore_index, expected):
     cl_metric(preds, target)
     result_cl = cl_metric.compute()
     assert torch.allclose(expected, result_cl, equal_nan=True)
+
+
+@pytest.mark.parametrize("average", ["micro", "macro", "weighted"])
+def test_same_input(average):
+    preds = _input_miss_class.preds
+    target = _input_miss_class.target
+    preds_flat = torch.cat([p for p in preds], dim=0)
+    target_flat = torch.cat([t for t in target], dim=0)
+
+    mc = Accuracy(num_classes=NUM_CLASSES, average=average)
+    for i in range(NUM_BATCHES):
+        mc.update(preds[i], target[i])
+    class_res = mc.compute()
+    func_res = accuracy(preds_flat, target_flat, num_classes=NUM_CLASSES, average=average)
+    sk_res = sk_accuracy(target_flat, preds_flat)
+
+    assert torch.allclose(class_res, torch.tensor(sk_res).float())
+    assert torch.allclose(func_res, torch.tensor(sk_res).float())
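The new `_input_multiclass_with_missing_class` fixture above guarantees an absent class by drawing two distinct labels and rewriting every occurrence of the first as the second. Below is a minimal standalone sketch of that logic, with illustrative constants standing in for the suite's `NUM_CLASSES`/`NUM_BATCHES`/`BATCH_SIZE` globals:

import torch

NUM_CLASSES, NUM_BATCHES, BATCH_SIZE = 4, 2, 8  # illustrative values only

t = torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE))
# multinomial without replacement cannot return the same class id twice
remove, replace = torch.multinomial(torch.ones(NUM_CLASSES), num_samples=2, replacement=False)
t[t == remove] = replace
assert not (t == remove).any()  # `remove` is now guaranteed absent from the data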
diff --git a/tests/classification/test_f_beta.py b/tests/classification/test_f_beta.py
index 7fabeb15e7c..6c9e7d38d48 100644
--- a/tests/classification/test_f_beta.py
+++ b/tests/classification/test_f_beta.py
@@ -24,13 +24,14 @@
 from tests.classification.inputs import _input_multiclass as _input_mcls
 from tests.classification.inputs import _input_multiclass_logits as _input_mcls_logits
 from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
+from tests.classification.inputs import _input_multiclass_with_missing_class as _input_miss_class
 from tests.classification.inputs import _input_multidim_multiclass as _input_mdmc
 from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
 from tests.classification.inputs import _input_multilabel as _input_mlb
 from tests.classification.inputs import _input_multilabel_logits as _input_mlb_logits
 from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
 from tests.helpers import seed_all
-from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester
+from tests.helpers.testers import NUM_BATCHES, NUM_CLASSES, THRESHOLD, MetricTester
 from torchmetrics import F1, FBeta, Metric
 from torchmetrics.functional import f1, fbeta
 from torchmetrics.utilities.checks import _input_format_classification
@@ -55,7 +56,6 @@ def _sk_fbeta_f1(preds, target, sk_fn, num_classes, average, multiclass, ignore_
         preds, target, THRESHOLD, num_classes=num_classes, multiclass=multiclass
     )
     sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy()
-
     sk_scores = sk_fn(sk_target, sk_preds, average=average, zero_division=0, labels=labels)
 
     if len(labels) != num_classes and not average:
@@ -425,3 +425,25 @@ def test_top_k(
 
     assert torch.isclose(class_metric.compute(), result)
     assert torch.isclose(metric_fn(preds, target, top_k=k, average=average, num_classes=3), result)
+
+
+@pytest.mark.parametrize("average", ["micro", "macro", "weighted"])
+@pytest.mark.parametrize(
+    "metric_class, metric_functional, sk_fn",
+    [(partial(FBeta, beta=2.0), partial(fbeta, beta=2.0), partial(fbeta_score, beta=2.0)), (F1, f1, f1_score)],
+)
+def test_same_input(metric_class, metric_functional, sk_fn, average):
+    preds = _input_miss_class.preds
+    target = _input_miss_class.target
+    preds_flat = torch.cat([p for p in preds], dim=0)
+    target_flat = torch.cat([t for t in target], dim=0)
+
+    mc = metric_class(num_classes=NUM_CLASSES, average=average)
+    for i in range(NUM_BATCHES):
+        mc.update(preds[i], target[i])
+    class_res = mc.compute()
+    func_res = metric_functional(preds_flat, target_flat, num_classes=NUM_CLASSES, average=average)
+    sk_res = sk_fn(target_flat, preds_flat, average=average, zero_division=0)
+
+    assert torch.allclose(class_res, torch.tensor(sk_res).float())
+    assert torch.allclose(func_res, torch.tensor(sk_res).float())
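The discrepancy these `test_same_input` cases guard against can be reproduced by hand: sklearn's macro average runs over the labels actually observed, whereas the pre-fix torchmetrics behaviour effectively averaged over all declared classes. A sketch with made-up toy tensors (not the fixture above), where forcing the absent label into sklearn's average mimics the old behaviour:

import torch
from sklearn.metrics import f1_score

preds = torch.tensor([0, 1, 2, 0, 1, 2])
target = torch.tensor([0, 1, 1, 0, 2, 2])  # class 3 of 4 never occurs

# sklearn's default macro average uses only the observed labels {0, 1, 2}
print(f1_score(target, preds, average="macro"))  # ~0.6667

# forcing the absent class into the average mimics the pre-fix result
print(f1_score(target, preds, labels=[0, 1, 2, 3], average="macro", zero_division=0))  # 0.5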
diff --git a/tests/classification/test_precision_recall.py b/tests/classification/test_precision_recall.py
index 4830203d892..c65a63d84c3 100644
--- a/tests/classification/test_precision_recall.py
+++ b/tests/classification/test_precision_recall.py
@@ -24,13 +24,14 @@
 from tests.classification.inputs import _input_multiclass as _input_mcls
 from tests.classification.inputs import _input_multiclass_logits as _input_mcls_logits
 from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
+from tests.classification.inputs import _input_multiclass_with_missing_class as _input_miss_class
 from tests.classification.inputs import _input_multidim_multiclass as _input_mdmc
 from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
 from tests.classification.inputs import _input_multilabel as _input_mlb
 from tests.classification.inputs import _input_multilabel_logits as _input_mlb_logits
 from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
 from tests.helpers import seed_all
-from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester
+from tests.helpers.testers import NUM_BATCHES, NUM_CLASSES, THRESHOLD, MetricTester
 from torchmetrics import Metric, Precision, Recall
 from torchmetrics.functional import precision, precision_recall, recall
 from torchmetrics.utilities.checks import _input_format_classification
@@ -209,7 +210,7 @@ def test_no_support(metric_class, metric_fn):
 )
 class TestPrecisionRecall(MetricTester):
     @pytest.mark.parametrize("ddp", [False, True])
-    @pytest.mark.parametrize("dist_sync_on_step", [True, False])
+    @pytest.mark.parametrize("dist_sync_on_step", [False])
     def test_precision_recall_class(
         self,
         ddp: bool,
@@ -437,3 +438,24 @@ def test_class_not_present(metric_class, metric_fn, ignore_index, expected):
     cl_metric(preds, target)
     result_cl = cl_metric.compute()
     assert torch.allclose(expected, result_cl, equal_nan=True)
+
+
+@pytest.mark.parametrize("average", ["micro", "macro", "weighted"])
+@pytest.mark.parametrize(
+    "metric_class, metric_functional, sk_fn", [(Precision, precision, precision_score), (Recall, recall, recall_score)]
+)
+def test_same_input(metric_class, metric_functional, sk_fn, average):
+    preds = _input_miss_class.preds
+    target = _input_miss_class.target
+    preds_flat = torch.cat([p for p in preds], dim=0)
+    target_flat = torch.cat([t for t in target], dim=0)
+
+    mc = metric_class(num_classes=NUM_CLASSES, average=average)
+    for i in range(NUM_BATCHES):
+        mc.update(preds[i], target[i])
+    class_res = mc.compute()
+    func_res = metric_functional(preds_flat, target_flat, num_classes=NUM_CLASSES, average=average)
+    sk_res = sk_fn(target_flat, preds_flat, average=average, zero_division=1)
+
+    assert torch.allclose(class_res, torch.tensor(sk_res).float())
+    assert torch.allclose(func_res, torch.tensor(sk_res).float())
diff --git a/torchmetrics/functional/classification/accuracy.py b/torchmetrics/functional/classification/accuracy.py
index 1f2f457d0dd..e4bcb5d8840 100644
--- a/torchmetrics/functional/classification/accuracy.py
+++ b/torchmetrics/functional/classification/accuracy.py
@@ -85,6 +85,12 @@ def _accuracy_compute(
     else:
         numerator = tp
         denominator = tp + fn
+
+    if average == AverageMethod.MACRO and mdmc_average != MDMCAverageMethod.SAMPLEWISE:
+        cond = tp + fp + fn == 0
+        numerator = numerator[~cond]
+        denominator = denominator[~cond]
+
     if average == AverageMethod.NONE and mdmc_average != MDMCAverageMethod.SAMPLEWISE:
         # a class is not present if there exists no TPs, no FPs, and no FNs
         meaningless_indeces = torch.nonzero((tp | fn | fp) == 0).cpu()
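The production-side change is the same in accuracy, precision, and recall: before the macro reduction, class slots with `tp + fp + fn == 0` are dropped from both numerator and denominator, so an unobserved class no longer contributes a zero term to the mean. A self-contained sketch with made-up statistics (the helper name is hypothetical, not a library function):

import torch

def drop_absent_classes(numerator, denominator, tp, fp, fn):
    # mirrors the added guard: a class with tp + fp + fn == 0 never occurred
    cond = tp + fp + fn == 0
    return numerator[~cond], denominator[~cond]

tp = torch.tensor([2, 1, 1, 0])
fp = torch.tensor([0, 1, 1, 0])
fn = torch.tensor([0, 1, 1, 0])

num, denom = drop_absent_classes(tp, tp + fn, tp, fp, fn)  # recall-style stats
print((num.float() / denom).mean())  # tensor(0.6667): averaged over 3 present classes, not 4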
diff --git a/torchmetrics/functional/classification/f_beta.py b/torchmetrics/functional/classification/f_beta.py
index 22f773699fe..9d6ddba3f5b 100644
--- a/torchmetrics/functional/classification/f_beta.py
+++ b/torchmetrics/functional/classification/f_beta.py
@@ -46,9 +46,15 @@ def _fbeta_compute(
     precision = _safe_divide(tp.float(), tp + fp)
     recall = _safe_divide(tp.float(), tp + fn)
 
+    if average == AvgMethod.MACRO and mdmc_average != MDMCAverageMethod.SAMPLEWISE:
+        cond = tp + fp + fn == 0
+        precision = precision[~cond]
+        recall = recall[~cond]
+
     num = (1 + beta ** 2) * precision * recall
     denom = beta ** 2 * precision + recall
-    denom[denom == 0.0] = 1  # avoid division by 0
+    denom[denom == 0.0] = 1.0  # avoid division by 0
+
     # if classes matter and a given class is not present in both the preds and the target,
     # computing the score for this class is meaningless, thus they should be ignored
     if average == AvgMethod.NONE and mdmc_average != MDMCAverageMethod.SAMPLEWISE:
diff --git a/torchmetrics/functional/classification/precision_recall.py b/torchmetrics/functional/classification/precision_recall.py
index 6a56e4d27bd..a49dc372593 100644
--- a/torchmetrics/functional/classification/precision_recall.py
+++ b/torchmetrics/functional/classification/precision_recall.py
@@ -29,6 +29,12 @@ def _precision_compute(
 ) -> Tensor:
     numerator = tp
     denominator = tp + fp
+
+    if average == AverageMethod.MACRO and mdmc_average != MDMCAverageMethod.SAMPLEWISE:
+        cond = tp + fp + fn == 0
+        numerator = numerator[~cond]
+        denominator = denominator[~cond]
+
     if average == AverageMethod.NONE and mdmc_average != MDMCAverageMethod.SAMPLEWISE:
         # a class is not present if there exists no TPs, no FPs, and no FNs
         meaningless_indeces = torch.nonzero((tp | fn | fp) == 0).cpu()
@@ -199,11 +205,18 @@ def _recall_compute(
 ) -> Tensor:
     numerator = tp
     denominator = tp + fn
+
+    if average == AverageMethod.MACRO and mdmc_average != MDMCAverageMethod.SAMPLEWISE:
+        cond = tp + fp + fn == 0
+        numerator = numerator[~cond]
+        denominator = denominator[~cond]
+
     if average == AverageMethod.NONE and mdmc_average != MDMCAverageMethod.SAMPLEWISE:
         # a class is not present if there exists no TPs, no FPs, and no FNs
         meaningless_indeces = ((tp | fn | fp) == 0).nonzero().cpu()
         numerator[meaningless_indeces, ...] = -1
         denominator[meaningless_indeces, ...] = -1
+
     return _reduce_stat_scores(
         numerator=numerator,
         denominator=denominator,
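`_fbeta_compute` applies the same masking one step earlier, on the per-class precision and recall before they are combined. A minimal sketch of that reduction (the function name is a stand-in, and `clamp(min=1)` substitutes for the library's `_safe_divide`):

import torch

def masked_fbeta_macro(tp, fp, fn, beta=2.0):
    precision = tp / (tp + fp).clamp(min=1)
    recall = tp / (tp + fn).clamp(min=1)

    # drop classes absent from both preds and target before averaging
    present = (tp + fp + fn) > 0
    precision, recall = precision[present], recall[present]

    num = (1 + beta ** 2) * precision * recall
    denom = beta ** 2 * precision + recall
    denom[denom == 0.0] = 1.0  # avoid division by 0
    return (num / denom).mean()

tp = torch.tensor([2.0, 1.0, 1.0, 0.0])
fp = torch.tensor([0.0, 1.0, 1.0, 0.0])
fn = torch.tensor([0.0, 1.0, 1.0, 0.0])
print(masked_fbeta_macro(tp, fp, fn))  # tensor(0.6667): the empty fourth class is ignored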