Skip to content

Commit

Permalink
Fix auc calculation and add tests (#197)
Browse files Browse the repository at this point in the history
* fix auc calculation and add tests

* remove stable sort attempt

* Apply suggestions from code review

Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>

* Apply suggestions from code review

* prune

* Update torchmetrics/functional/classification/auc.py

* fix

* chlog

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
Co-authored-by: jirka <jirka.borovec@seznam.cz>
  • Loading branch information
4 people committed Apr 26, 2021
1 parent 2584345 commit d432d3d
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 44 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Expand Up @@ -21,6 +21,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fixed auc calculation and add tests ([#197](https://github.com/PyTorchLightning/metrics/pull/197))


## [0.3.1] - 2021-04-21

Expand Down
33 changes: 21 additions & 12 deletions tests/classification/test_auc.py
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from functools import partial

import numpy as np
import pytest
Expand All @@ -26,25 +27,30 @@
seed_all(42)


def sk_auc(x, y):
def sk_auc(x, y, reorder=False):
x = x.flatten()
y = y.flatten()
if reorder:
idx = np.argsort(x, kind='stable')
x = x[idx]
y = y[idx]
return _sk_auc(x, y)


Input = namedtuple('Input', ["x", "y"])

_examples = []
# generate already ordered samples, sorted in both directions
for i in range(4):
x = np.random.randint(0, 5, (NUM_BATCHES * 8))
y = np.random.randint(0, 5, (NUM_BATCHES * 8))
idx = np.argsort(x, kind='stable')
x = x[idx] if i % 2 == 0 else x[idx[::-1]]
y = y[idx] if i % 2 == 0 else x[idx[::-1]]
x = x.reshape(NUM_BATCHES, 8)
y = y.reshape(NUM_BATCHES, 8)
_examples.append(Input(x=tensor(x), y=tensor(y)))
for batch_size in (8, 4049):
for i in range(4):
x = np.random.rand((NUM_BATCHES * batch_size))
y = np.random.rand((NUM_BATCHES * batch_size))
idx = np.argsort(x, kind='stable')
x = x[idx] if i % 2 == 0 else x[idx[::-1]]
y = y[idx] if i % 2 == 0 else x[idx[::-1]]
x = x.reshape(NUM_BATCHES, batch_size)
y = y.reshape(NUM_BATCHES, batch_size)
_examples.append(Input(x=tensor(x), y=tensor(y)))


@pytest.mark.parametrize("x, y", _examples)
Expand All @@ -62,8 +68,11 @@ def test_auc(self, x, y, ddp, dist_sync_on_step):
dist_sync_on_step=dist_sync_on_step,
)

def test_auc_functional(self, x, y):
self.run_functional_metric_test(x, y, metric_functional=auc, sk_metric=sk_auc, metric_args={"reorder": False})
@pytest.mark.parametrize("reorder", [True, False])
def test_auc_functional(self, x, y, reorder):
self.run_functional_metric_test(
x, y, metric_functional=auc, sk_metric=partial(sk_auc, reorder=reorder), metric_args={"reorder": reorder}
)


@pytest.mark.parametrize(['x', 'y', 'expected'], [
Expand Down
5 changes: 2 additions & 3 deletions torchmetrics/functional/classification/auc.py
Expand Up @@ -16,8 +16,6 @@
import torch
from torch import Tensor

from torchmetrics.utilities.data import _stable_1d_sort


def _auc_update(x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]:
if x.ndim > 1 or y.ndim > 1:
Expand All @@ -35,7 +33,8 @@ def _auc_update(x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]:

def _auc_compute(x: Tensor, y: Tensor, reorder: bool = False) -> Tensor:
if reorder:
x, x_idx = _stable_1d_sort(x)
# TODO: include stable=True arg when pytorch v1.9 is released
x, x_idx = torch.sort(x)
y = y[x_idx]

dx = x[1:] - x[:-1]
Expand Down
29 changes: 0 additions & 29 deletions torchmetrics/utilities/data.py
Expand Up @@ -151,35 +151,6 @@ def get_num_classes(
return num_classes


def _stable_1d_sort(x: torch, nb: int = 2049):
"""
Stable sort of 1d tensors. Pytorch defaults to a stable sorting algorithm
if number of elements are larger than 2048. This function pads the tensors,
makes the sort and returns the sorted array (with the padding removed)
See this discussion: https://discuss.pytorch.org/t/is-torch-sort-stable/20714
Raises:
ValueError:
If dim of ``x`` is greater than 1 since stable sort works with only 1d tensors.
Example:
>>> data = torch.tensor([8, 7, 2, 6, 4, 5, 3, 1, 9, 0])
>>> _stable_1d_sort(data)
(tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), tensor([9, 7, 2, 6, 4, 5, 3, 1, 0, 8]))
>>> _stable_1d_sort(data, nb=5)
(tensor([0, 1, 2, 3, 4]), tensor([9, 7, 2, 6, 4]))
"""
if x.ndim > 1:
raise ValueError('Stable sort only works on 1d tensors')
n = x.numel()
if n < nb:
x_max = x.max()
x = torch.cat([x, (x_max + 1) * torch.ones(nb - n, dtype=x.dtype, device=x.device)], 0)
x_sort = x.sort()
i = min(nb, n)
return x_sort.values[:i], x_sort.indices[:i]


def apply_to_collection(
data: Any,
dtype: Union[type, tuple],
Expand Down

0 comments on commit d432d3d

Please sign in to comment.