Skip to content

Commit

Permalink
Fix: nDCG cannot be called with negative relevance targets (#378)
Browse files Browse the repository at this point in the history
- Test nDCG with negative relevance targets
* Fix: Check for non-binary values for retrieval targets
- Use the scikit-learn implementation of nDCG
- Removed the test for non-binary targets in test_ndcg.py and replaced the default parameters in the error test with a custom set that does not check for binary targets
- Set the `low` bound of `_input_retrieval_scores_non_binary_target` to -1 (so targets can be negative) to reduce the test failure rate
* Fix: removed unused imports in ndcg.py
* more stable tests

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: SkafteNicki <skaftenicki@gmail.com>
  • Loading branch information
5 people committed Jul 28, 2021
1 parent e00e3ab commit b1062c9
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 17 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Expand Up @@ -13,14 +13,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added support in `nDCG` metric for target with values larger than 1 ([#343](https://github.com/PyTorchLightning/metrics/issues/343))


- Added Word error rate (WER) ([#52](https://github.com/PyTorchLightning/metrics/issues/52))


- Added Symmetric Mean Absolute Percentage error (SMAPE) ([#375](https://github.com/PyTorchLightning/metrics/issues/375))


- Allowed passing labels in (n_samples, n_classes) to `AveragePrecision` ([#386](https://github.com/PyTorchLightning/metrics/issues/386))


- Added support for negative targets in `nDCG` metric ([#378](https://github.com/PyTorchLightning/metrics/pull/378))


### Changed

- Moved `psnr` and `ssim` from `functional.regression.*` to `functional.image.*` ([#382](https://github.com/PyTorchLightning/metrics/pull/382))
Expand Down
49 changes: 49 additions & 0 deletions tests/retrieval/helpers.py
Expand Up @@ -137,6 +137,19 @@ def _concat_tests(*tests: Tuple[Dict]) -> Dict:
]
)

# Error-case parametrization for functional retrieval metrics whose targets may be
# non-binary; test_ndcg.py concatenates it with the `k` table via `_concat_tests` and
# passes the result to `pytest.mark.parametrize`.
# This table lists no "`target` must contain `binary` values" case.
# Each entry follows `argnames`: (preds, target, message, metric_args).
# NOTE(review): `message` is presumably the expected error text, and the `_irs*` names
# are presumably input fixtures defined elsewhere in the test helpers -- confirm.
_errors_test_functional_metric_parameters_with_nonbinary = dict(
argnames="preds,target,message,metric_args",
argvalues=[
# check input shapes are consistent (func)
(_irs_mis_sz_fn.preds, _irs_mis_sz_fn.target, "`preds` and `target` must be of the same shape", {}),
# check input tensors are not empty
(_irs_empty.preds, _irs_empty.target, "`preds` and `target` must be non-empty and non-scalar tensors", {}),
# check on input dtypes
(_irs.preds.bool(), _irs.target, "`preds` must be a tensor of floats", {}),
(_irs.preds, _irs.target.float(), "`target` must be a tensor of booleans or integers", {}),
]
)

_errors_test_functional_metric_parameters_k = dict(
argnames="preds,target,message,metric_args",
argvalues=[
Expand Down Expand Up @@ -167,6 +180,42 @@ def _concat_tests(*tests: Tuple[Dict]) -> Dict:
]
)

# Error-case parametrization for class-based retrieval metrics whose targets may be
# non-binary; test_ndcg.py concatenates it with the `k` table via `_concat_tests` and
# passes the result to `pytest.mark.parametrize`.
# This table lists no "`target` must contain `binary` values" case.
# Each entry follows `argnames`: (indexes, preds, target, message, metric_args).
# NOTE(review): `message` is presumably the expected error text, and the `_irs*` names
# are presumably input fixtures defined elsewhere in the test helpers -- confirm.
_errors_test_class_metric_parameters_with_nonbinary = dict(
argnames="indexes,preds,target,message,metric_args",
argvalues=[
# check that `indexes` may not be None
(None, _irs.preds, _irs.target, "`indexes` cannot be None", dict(empty_target_action="error")),
# check when input arguments are invalid
(
_irs.indexes, _irs.preds, _irs.target, "`empty_target_action` received a wrong value `casual_argument`.",
dict(empty_target_action="casual_argument")
),
# check input shapes are consistent
(
_irs_mis_sz.indexes, _irs_mis_sz.preds, _irs_mis_sz.target,
"`indexes`, `preds` and `target` must be of the same shape", dict(empty_target_action="skip")
),
# check input tensors are not empty
(
_irs_empty.indexes, _irs_empty.preds,
_irs_empty.target, "`indexes`, `preds` and `target` must be non-empty and non-scalar tensors",
dict(empty_target_action="skip")
),
# check on input dtypes
(
_irs.indexes.bool(), _irs.preds, _irs.target, "`indexes` must be a tensor of long integers",
dict(empty_target_action="skip")
),
(
_irs.indexes, _irs.preds.bool(), _irs.target, "`preds` must be a tensor of floats",
dict(empty_target_action="skip")
),
(
_irs.indexes, _irs.preds, _irs.target.float(), "`target` must be a tensor of booleans or integers",
dict(empty_target_action="skip")
)
]
)

_errors_test_class_metric_parameters_default = dict(
argnames="indexes,preds,target,message,metric_args",
argvalues=[
Expand Down
6 changes: 3 additions & 3 deletions tests/retrieval/inputs.py
Expand Up @@ -33,9 +33,9 @@
)

_input_retrieval_scores_non_binary_target = Input(
indexes=torch.randint(high=10, size=(NUM_BATCHES, BATCH_SIZE)),
preds=torch.rand(NUM_BATCHES, BATCH_SIZE),
target=torch.randint(high=4, size=(NUM_BATCHES, BATCH_SIZE)),
indexes=torch.randint(high=10, size=(NUM_BATCHES, 2 * BATCH_SIZE)),
preds=torch.rand(NUM_BATCHES, 2 * BATCH_SIZE),
target=torch.randint(low=-1, high=4, size=(NUM_BATCHES, 2 * BATCH_SIZE)),
)

# with errors
Expand Down
10 changes: 4 additions & 6 deletions tests/retrieval/test_ndcg.py
Expand Up @@ -22,11 +22,10 @@
_concat_tests,
_default_metric_class_input_arguments_with_non_binary_target,
_default_metric_functional_input_arguments_with_non_binary_target,
_errors_test_class_metric_parameters_default,
_errors_test_class_metric_parameters_k,
_errors_test_class_metric_parameters_no_pos_target,
_errors_test_functional_metric_parameters_default,
_errors_test_class_metric_parameters_with_nonbinary,
_errors_test_functional_metric_parameters_k,
_errors_test_functional_metric_parameters_with_nonbinary,
)
from torchmetrics.functional.retrieval.ndcg import retrieval_normalized_dcg
from torchmetrics.retrieval.retrieval_ndcg import RetrievalNormalizedDCG
Expand Down Expand Up @@ -114,8 +113,7 @@ def test_precision_gpu(self, indexes: Tensor, preds: Tensor, target: Tensor):

@pytest.mark.parametrize(
**_concat_tests(
_errors_test_class_metric_parameters_default,
_errors_test_class_metric_parameters_no_pos_target,
_errors_test_class_metric_parameters_with_nonbinary,
_errors_test_class_metric_parameters_k,
)
)
Expand All @@ -135,7 +133,7 @@ def test_arguments_class_metric(

@pytest.mark.parametrize(
**_concat_tests(
_errors_test_functional_metric_parameters_default,
_errors_test_functional_metric_parameters_with_nonbinary,
_errors_test_functional_metric_parameters_k,
)
)
Expand Down
17 changes: 11 additions & 6 deletions torchmetrics/functional/retrieval/ndcg.py
Expand Up @@ -14,14 +14,14 @@
from typing import Optional

import torch
from torch import Tensor, tensor
from torch import Tensor

from torchmetrics.utilities.checks import _check_retrieval_functional_inputs


def _dcg(target: Tensor) -> Tensor:
denom = torch.log2(torch.arange(target.shape[-1], device=target.device) + 2.0)
return (target / denom).sum()
return (target / denom).sum(dim=-1)


def retrieval_normalized_dcg(preds: Tensor, target: Tensor, k: Optional[int] = None) -> Tensor:
Expand Down Expand Up @@ -55,10 +55,15 @@ def retrieval_normalized_dcg(preds: Tensor, target: Tensor, k: Optional[int] = N
if not (isinstance(k, int) and k > 0):
raise ValueError("`k` has to be a positive integer or None")

if not target.sum():
return tensor(0.0, device=preds.device)

sorted_target = target[torch.argsort(preds, dim=-1, descending=True)][:k]
ideal_target = torch.sort(target, descending=True)[0][:k]

return _dcg(sorted_target) / _dcg(ideal_target)
ideal_dcg = _dcg(ideal_target)
target_dcg = _dcg(sorted_target)

# filter undefined scores
all_irrelevant = ideal_dcg == 0
target_dcg[all_irrelevant] = 0
target_dcg[~all_irrelevant] /= ideal_dcg[~all_irrelevant]

return target_dcg.mean()
4 changes: 2 additions & 2 deletions torchmetrics/utilities/checks.py
Expand Up @@ -531,7 +531,7 @@ def _check_retrieval_functional_inputs(
if not preds.is_floating_point():
raise ValueError("`preds` must be a tensor of floats")

if not allow_non_binary_target and target.max() > 1 or target.min() < 0:
if not allow_non_binary_target and (target.max() > 1 or target.min() < 0):
raise ValueError("`target` must contain `binary` values")

return preds.float().flatten(), target.long().flatten()
Expand Down Expand Up @@ -575,7 +575,7 @@ def _check_retrieval_inputs(
if target.dtype not in (torch.bool, torch.long, torch.int):
raise ValueError("`target` must be a tensor of booleans or integers")

if not allow_non_binary_target and target.max() > 1 or target.min() < 0:
if not allow_non_binary_target and (target.max() > 1 or target.min() < 0):
raise ValueError("`target` must contain `binary` values")

return indexes.long().flatten(), preds.float().flatten(), target.long().flatten()

0 comments on commit b1062c9

Please sign in to comment.