From b1062c9cf17d454fb058c0b27f050975bda88a2e Mon Sep 17 00:00:00 2001
From: Paul Grundmann
Date: Thu, 29 Jul 2021 01:50:55 +0200
Subject: [PATCH] Fix nDCG so it can be called with negative relevance targets
 (#378)

- Test nDCG with negative relevance targets

* Fix: Check for non-binary values for retrieval targets

- Use the scikit-learn implementation of nDCG
- Removed the test for non-binary targets in test_ndcg.py and replaced the
  default parameters in the error test with a custom one that does not check
  for binary targets
- Set `low=-1` in `_input_retrieval_scores_non_binary_target` to reduce the
  test failure rate

* Fix: removed unused imports in ndcg.py

* More stable tests

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec
Co-authored-by: SkafteNicki
---
 CHANGELOG.md                              |  5 +++
 tests/retrieval/helpers.py                | 49 +++++++++++++++++++++++
 tests/retrieval/inputs.py                 |  6 +--
 tests/retrieval/test_ndcg.py              | 10 ++---
 torchmetrics/functional/retrieval/ndcg.py | 17 +++---
 torchmetrics/utilities/checks.py          |  4 +-
 6 files changed, 74 insertions(+), 17 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index d7ebbd263b8..edd57bf7c8e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -13,14 +13,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Added support in `nDCG` metric for target with values larger than 1 ([#343](https://github.com/PyTorchLightning/metrics/issues/343))
 
+
 - Added Word error rate (WER) ([#52](https://github.com/PyTorchLightning/metrics/issues/52))
 
+
 - Added Symmetric Mean Absolute Percentage error (SMAPE) ([#375](https://github.com/PyTorchLightning/metrics/issues/375))
 
 
 - Allowed passing labels in (n_samples, n_classes) to `AveragePrecision` ([#386](https://github.com/PyTorchLightning/metrics/issues/386))
 
+- Added support for negative targets in `nDCG` metric ([#378](https://github.com/PyTorchLightning/metrics/pull/378))
+
+
 
 ### Changed
 
 - Moved `psnr` and `ssim` from `functional.regression.*` to `functional.image.*` ([#382](https://github.com/PyTorchLightning/metrics/pull/382))
diff --git a/tests/retrieval/helpers.py b/tests/retrieval/helpers.py
index fc945530cea..0866b12fdfb 100644
--- a/tests/retrieval/helpers.py
+++ b/tests/retrieval/helpers.py
@@ -137,6 +137,19 @@ def _concat_tests(*tests: Tuple[Dict]) -> Dict:
     ]
 )
 
+_errors_test_functional_metric_parameters_with_nonbinary = dict(
+    argnames="preds,target,message,metric_args",
+    argvalues=[
+        # check input shapes are consistent (func)
+        (_irs_mis_sz_fn.preds, _irs_mis_sz_fn.target, "`preds` and `target` must be of the same shape", {}),
+        # check input tensors are not empty
+        (_irs_empty.preds, _irs_empty.target, "`preds` and `target` must be non-empty and non-scalar tensors", {}),
+        # check on input dtypes
+        (_irs.preds.bool(), _irs.target, "`preds` must be a tensor of floats", {}),
+        (_irs.preds, _irs.target.float(), "`target` must be a tensor of booleans or integers", {}),
+    ]
+)
+
 _errors_test_functional_metric_parameters_k = dict(
     argnames="preds,target,message,metric_args",
     argvalues=[
@@ -167,6 +180,42 @@ def _concat_tests(*tests: Tuple[Dict]) -> Dict:
     ]
 )
 
+_errors_test_class_metric_parameters_with_nonbinary = dict(
+    argnames="indexes,preds,target,message,metric_args",
+    argvalues=[
+        (None, _irs.preds, _irs.target, "`indexes` cannot be None", dict(empty_target_action="error")),
+        # check when input arguments are invalid
+        (
+            _irs.indexes, _irs.preds, _irs.target, "`empty_target_action` received a wrong value `casual_argument`.",
+            dict(empty_target_action="casual_argument")
+        ),
+        # check input shapes are consistent
+        (
+            _irs_mis_sz.indexes, _irs_mis_sz.preds, _irs_mis_sz.target,
+            "`indexes`, `preds` and `target` must be of the same shape", dict(empty_target_action="skip")
+        ),
+        # check input tensors are not empty
+        (
+            _irs_empty.indexes, _irs_empty.preds,
+            _irs_empty.target, "`indexes`, `preds` and `target` must be non-empty and non-scalar tensors",
+            dict(empty_target_action="skip")
+        ),
+        # check on input dtypes
+        (
+            _irs.indexes.bool(), _irs.preds, _irs.target, "`indexes` must be a tensor of long integers",
+            dict(empty_target_action="skip")
+        ),
+        (
+            _irs.indexes, _irs.preds.bool(), _irs.target, "`preds` must be a tensor of floats",
+            dict(empty_target_action="skip")
+        ),
+        (
+            _irs.indexes, _irs.preds, _irs.target.float(), "`target` must be a tensor of booleans or integers",
+            dict(empty_target_action="skip")
+        )
+    ]
+)
+
 _errors_test_class_metric_parameters_default = dict(
     argnames="indexes,preds,target,message,metric_args",
     argvalues=[
diff --git a/tests/retrieval/inputs.py b/tests/retrieval/inputs.py
index f790e8cfa50..4a2f93cd7fc 100644
--- a/tests/retrieval/inputs.py
+++ b/tests/retrieval/inputs.py
@@ -33,9 +33,9 @@
 )
 
 _input_retrieval_scores_non_binary_target = Input(
-    indexes=torch.randint(high=10, size=(NUM_BATCHES, BATCH_SIZE)),
-    preds=torch.rand(NUM_BATCHES, BATCH_SIZE),
-    target=torch.randint(high=4, size=(NUM_BATCHES, BATCH_SIZE)),
+    indexes=torch.randint(high=10, size=(NUM_BATCHES, 2 * BATCH_SIZE)),
+    preds=torch.rand(NUM_BATCHES, 2 * BATCH_SIZE),
+    target=torch.randint(low=-1, high=4, size=(NUM_BATCHES, 2 * BATCH_SIZE)),
 )
 
 # with errors
diff --git a/tests/retrieval/test_ndcg.py b/tests/retrieval/test_ndcg.py
index 93a67510b04..2543a31d9f0 100644
--- a/tests/retrieval/test_ndcg.py
+++ b/tests/retrieval/test_ndcg.py
@@ -22,11 +22,10 @@
     _concat_tests,
     _default_metric_class_input_arguments_with_non_binary_target,
     _default_metric_functional_input_arguments_with_non_binary_target,
-    _errors_test_class_metric_parameters_default,
     _errors_test_class_metric_parameters_k,
-    _errors_test_class_metric_parameters_no_pos_target,
-    _errors_test_functional_metric_parameters_default,
+    _errors_test_class_metric_parameters_with_nonbinary,
     _errors_test_functional_metric_parameters_k,
+    _errors_test_functional_metric_parameters_with_nonbinary,
 )
 from torchmetrics.functional.retrieval.ndcg import retrieval_normalized_dcg
 from torchmetrics.retrieval.retrieval_ndcg import RetrievalNormalizedDCG
@@ -114,8 +113,7 @@ def test_precision_gpu(self, indexes: Tensor, preds: Tensor, target: Tensor):
 
     @pytest.mark.parametrize(
         **_concat_tests(
-            _errors_test_class_metric_parameters_default,
-            _errors_test_class_metric_parameters_no_pos_target,
+            _errors_test_class_metric_parameters_with_nonbinary,
             _errors_test_class_metric_parameters_k,
         )
     )
@@ -135,7 +133,7 @@ def test_arguments_class_metric(
 
     @pytest.mark.parametrize(
         **_concat_tests(
-            _errors_test_functional_metric_parameters_default,
+            _errors_test_functional_metric_parameters_with_nonbinary,
             _errors_test_functional_metric_parameters_k,
         )
     )
diff --git a/torchmetrics/functional/retrieval/ndcg.py b/torchmetrics/functional/retrieval/ndcg.py
index 211654efd41..65ac97247ef 100644
--- a/torchmetrics/functional/retrieval/ndcg.py
+++ b/torchmetrics/functional/retrieval/ndcg.py
@@ -14,14 +14,14 @@
 from typing import Optional
 
 import torch
-from torch import Tensor, tensor
+from torch import Tensor
 
 from torchmetrics.utilities.checks import _check_retrieval_functional_inputs
 
 
 def _dcg(target: Tensor) -> Tensor:
     denom = torch.log2(torch.arange(target.shape[-1], device=target.device) + 2.0)
-    return (target / denom).sum()
+    return (target / denom).sum(dim=-1)
 
 
 def retrieval_normalized_dcg(preds: Tensor, target: Tensor, k: Optional[int] = None) -> Tensor:
@@ -55,10 +55,15 @@ def retrieval_normalized_dcg(preds: Tensor, target: Tensor, k: Optional[int] = N
     if not (isinstance(k, int) and k > 0):
         raise ValueError("`k` has to be a positive integer or None")
 
-    if not target.sum():
-        return tensor(0.0, device=preds.device)
-
     sorted_target = target[torch.argsort(preds, dim=-1, descending=True)][:k]
     ideal_target = torch.sort(target, descending=True)[0][:k]
 
-    return _dcg(sorted_target) / _dcg(ideal_target)
+    ideal_dcg = _dcg(ideal_target)
+    target_dcg = _dcg(sorted_target)
+
+    # filter undefined scores
+    all_irrelevant = ideal_dcg == 0
+    target_dcg[all_irrelevant] = 0
+    target_dcg[~all_irrelevant] /= ideal_dcg[~all_irrelevant]
+
+    return target_dcg.mean()
diff --git a/torchmetrics/utilities/checks.py b/torchmetrics/utilities/checks.py
index 22cb34ebf14..3b4b68061df 100644
--- a/torchmetrics/utilities/checks.py
+++ b/torchmetrics/utilities/checks.py
@@ -531,7 +531,7 @@ def _check_retrieval_functional_inputs(
     if not preds.is_floating_point():
         raise ValueError("`preds` must be a tensor of floats")
 
-    if not allow_non_binary_target and target.max() > 1 or target.min() < 0:
+    if not allow_non_binary_target and (target.max() > 1 or target.min() < 0):
         raise ValueError("`target` must contain `binary` values")
 
     return preds.float().flatten(), target.long().flatten()
@@ -575,7 +575,7 @@ def _check_retrieval_inputs(
     if target.dtype not in (torch.bool, torch.long, torch.int):
         raise ValueError("`target` must be a tensor of booleans or integers")
 
-    if not allow_non_binary_target and target.max() > 1 or target.min() < 0:
+    if not allow_non_binary_target and (target.max() > 1 or target.min() < 0):
         raise ValueError("`target` must contain `binary` values")
 
     return indexes.long().flatten(), preds.float().flatten(), target.long().flatten()
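Below the patch, a minimal usage sketch (not part of the diff) of the fixed functional metric, assuming a torchmetrics build that includes this change; the tensor values are illustrative only and are not taken from the test suite.

```python
import torch

from torchmetrics.functional.retrieval.ndcg import retrieval_normalized_dcg

# A single query with float prediction scores and graded relevance targets.
# The negative label is only accepted once this patch is applied.
preds = torch.tensor([0.9, 0.7, 0.3, 0.1])
target = torch.tensor([3, 2, -1, 0])

# nDCG over the full ranking (k=None); returns a scalar tensor.
print(retrieval_normalized_dcg(preds, target))

# Truncated variant: only the top-2 ranked documents are scored.
print(retrieval_normalized_dcg(preds, target, k=2))

# With an all-zero (all-irrelevant) target the ideal DCG is 0, and the
# patched code defines the score as 0 instead of dividing by zero.
print(retrieval_normalized_dcg(torch.rand(4), torch.zeros(4, dtype=torch.long)))
```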