Fix metrics examples, add average precision metric (#522)
* Fix metrics examples, add average precision metric

* Fix mortality prediction notebook

* Add average precision metric class, fix more doctest errors

* Additional fixes

* Add remaining fixes
amrit110 committed Dec 6, 2023
1 parent 49b7fd8 commit cfa1eff
Showing 39 changed files with 751 additions and 432 deletions.
14 changes: 11 additions & 3 deletions .pre-commit-config.yaml
@@ -40,8 +40,8 @@ repos:
hooks:
- id: nbstripout
name: nbstripout
-language: python
-entry: nbstripout
+language: system
+entry: python3 -m nbstripout
exclude: ^docs/source/tutorials/gemini/.*\.ipynb$

- repo: https://github.com/nbQA-dev/nbQA
@@ -51,11 +51,19 @@
- id: nbqa-ruff
args: [--fix, --exit-non-zero-on-fix]

+- repo: local
+  hooks:
+    - id: doctest
+      name: doctest
+      entry: python3 -m doctest -o NORMALIZE_WHITESPACE
+      files: "^cyclops/evaluate/"
+      language: system
+
- repo: local
hooks:
- id: pytest
name: pytest
-entry: pytest -m "not integration_test"
+entry: python3 -m pytest -m "not integration_test"
language: system
pass_filenames: false
always_run: true
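
For context on the new doctest hook: `python3 -m doctest -o NORMALIZE_WHITESPACE` runs every docstring example in the matched files and, with that flag, ignores whitespace-only differences between expected and actual output. A minimal, self-contained sketch of the kind of module it checks (illustrative only, not from the cyclops codebase):

```python
"""Illustrative module; the new pre-commit hook runs files like this through doctest."""


def mean(values: list) -> float:
    """Return the arithmetic mean.

    Examples
    --------
    >>> mean([0.1, 0.4, 0.35, 0.8])
    0.4125
    """
    return sum(values) / len(values)


if __name__ == "__main__":
    import doctest

    # Equivalent to: python3 -m doctest -o NORMALIZE_WHITESPACE <file>
    doctest.testmod(optionflags=doctest.NORMALIZE_WHITESPACE)
```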
3 changes: 3 additions & 0 deletions cyclops/evaluate/metrics/__init__.py
@@ -12,6 +12,9 @@
MulticlassAUROC,
MultilabelAUROC,
)
+from cyclops.evaluate.metrics.average_precision import (
+    BinaryAveragePrecision,
+)
from cyclops.evaluate.metrics.f_beta import (
BinaryF1Score, # noqa: F401
BinaryFbetaScore,
12 changes: 6 additions & 6 deletions cyclops/evaluate/metrics/accuracy.py
@@ -32,7 +32,7 @@ class BinaryAccuracy(BinaryStatScores, registry_key="binary_accuracy"):
Examples
--------
->>> from cyclops.evaluation.metrics import BinaryAccuracy
+>>> from cyclops.evaluate.metrics import BinaryAccuracy
>>> target = [0, 1, 0, 1]
>>> preds = [0, 1, 1, 1]
>>> metric = BinaryAccuracy()
@@ -100,7 +100,7 @@ class MulticlassAccuracy(MulticlassStatScores, registry_key="multiclass_accuracy"):
Examples
--------
->>> from cyclops.evaluation.metrics import MulticlassAccuracy
+>>> from cyclops.evaluate.metrics import MulticlassAccuracy
>>> target = [0, 1, 2, 2, 2]
>>> preds = [0, 0, 2, 2, 1]
>>> metric = MulticlassAccuracy(num_classes=3)
@@ -176,7 +176,7 @@ class MultilabelAccuracy(MultilabelStatScores, registry_key="multilabel_accuracy"):
Examples
--------
->>> from cyclops.evaluation.metrics import MultilabelAccuracy
+>>> from cyclops.evaluate.metrics import MultilabelAccuracy
>>> target = [[0, 1, 1], [1, 0, 0]]
>>> preds = [[0, 1, 0], [1, 0, 1]]
>>> metric = MultilabelAccuracy(num_labels=3)
@@ -268,7 +268,7 @@ class Accuracy(Metric, registry_key="accuracy", force_register=True):
Examples
--------
>>> # (binary)
->>> from cyclops.evaluation.metrics import Accuracy
+>>> from cyclops.evaluate.metrics import Accuracy
>>> target = [0, 0, 1, 1]
>>> preds = [0, 1, 1, 1]
>>> metric = Accuracy(task="binary")
@@ -283,7 +283,7 @@ class Accuracy(Metric, registry_key="accuracy", force_register=True):
0.5
>>> # (multiclass)
->>> from cyclops.evaluation.metrics import Accuracy
+>>> from cyclops.evaluate.metrics import Accuracy
>>> target = [0, 1, 2, 2, 2]
>>> preds = [0, 0, 2, 2, 1]
>>> metric = Accuracy(task="multiclass", num_classes=3)
@@ -299,7 +299,7 @@ class Accuracy(Metric, registry_key="accuracy", force_register=True):
array([0., 1., 0.])
>>> # (multilabel)
->>> from cyclops.evaluation.metrics import Accuracy
+>>> from cyclops.evaluate.metrics import Accuracy
>>> target = [[0, 1, 1], [1, 0, 0]]
>>> preds = [[0, 1, 0], [1, 0, 1]]
>>> metric = Accuracy(task="multilabel", num_labels=3)
12 changes: 6 additions & 6 deletions cyclops/evaluate/metrics/auroc.py
@@ -36,7 +36,7 @@ class BinaryAUROC(BinaryPrecisionRecallCurve, registry_key="binary_auroc"):
Examples
--------
->>> from cyclops.evaluation.metrics import BinaryAUROC
+>>> from cyclops.evaluate.metrics import BinaryAUROC
>>> target = [0, 0, 1, 1]
>>> preds = [0.1, 0.4, 0.35, 0.8]
>>> metric = BinaryAUROC()
@@ -106,7 +106,7 @@ class MulticlassAUROC(MulticlassPrecisionRecallCurve, registry_key="multiclass_auroc"):
Examples
--------
->>> from cyclops.evaluation.metrics import MulticlassAUROC
+>>> from cyclops.evaluate.metrics import MulticlassAUROC
>>> target = [0, 1, 2, 0]
>>> preds = [[0.9, 0.05, 0.05], [0.05, 0.89, 0.06],
... [0.05, 0.01, 0.94], [0.9, 0.05, 0.05]]
@@ -180,7 +180,7 @@ class MultilabelAUROC(MultilabelPrecisionRecallCurve, registry_key="multilabel_auroc"):
Examples
--------
->>> from cyclops.evaluation.metrics import MultilabelAUROC
+>>> from cyclops.evaluate.metrics import MultilabelAUROC
>>> target = [[0, 1], [1, 1], [1, 0]]
>>> preds = [[0.9, 0.05], [0.05, 0.89], [0.05, 0.01]]
>>> metric = MultilabelAUROC(num_labels=2)
@@ -261,7 +261,7 @@ class AUROC(Metric, registry_key="auroc", force_register=True):
Examples
--------
>>> # (binary)
->>> from cyclops.evaluation.metrics import BinaryAUROC
+>>> from cyclops.evaluate.metrics import BinaryAUROC
>>> target = [0, 0, 1, 1]
>>> preds = [0.1, 0.4, 0.35, 0.8]
>>> metric = BinaryAUROC()
@@ -276,7 +276,7 @@ class AUROC(Metric, registry_key="auroc", force_register=True):
0.6111111111111112
>>> # (multiclass)
->>> from cyclops.evaluation.metrics import MulticlassAUROC
+>>> from cyclops.evaluate.metrics import MulticlassAUROC
>>> target = [0, 1, 2, 0]
>>> preds = [[0.9, 0.05, 0.05], [0.05, 0.89, 0.06],
... [0.05, 0.01, 0.94], [0.9, 0.05, 0.05]]
@@ -293,7 +293,7 @@ class AUROC(Metric, registry_key="auroc", force_register=True):
array([0.5 , 0.22222222, 0. ])
>>> # (multilabel)
->>> from cyclops.evaluation.metrics import MultilabelAUROC
+>>> from cyclops.evaluate.metrics import MultilabelAUROC
>>> target = [[0, 1], [1, 1], [1, 0]]
>>> preds = [[0.9, 0.05], [0.05, 0.89], [0.05, 0.01]]
>>> metric = MultilabelAUROC(num_labels=2)
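
An aside for intuition on these AUROC doctests: binary AUROC equals the probability that a randomly chosen positive example is scored above a randomly chosen negative one, with ties counting half. A sketch under that rank-based definition, with hypothetical names; the cyclops implementation may compute it differently (e.g., by integrating the ROC curve over thresholds):

```python
import numpy as np


def binary_auroc_sketch(target, preds):
    """P(score of a positive > score of a negative); ties count half."""
    target = np.asarray(target, dtype=bool)
    preds = np.asarray(preds, dtype=float)
    pos, neg = preds[target], preds[~target]
    greater = (pos[:, None] > neg[None, :]).sum()
    ties = (pos[:, None] == neg[None, :]).sum()
    return (greater + 0.5 * ties) / (pos.size * neg.size)


print(binary_auroc_sketch([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```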
141 changes: 141 additions & 0 deletions cyclops/evaluate/metrics/average_precision.py
@@ -0,0 +1,141 @@
"""Classes for computing area under the Average Precision (AUPRC)."""

from typing import List, Literal, Optional, Type, Union

import numpy as np
import numpy.typing as npt

from cyclops.evaluate.metrics.functional.average_precision import (
    _binary_average_precision_compute,
)
from cyclops.evaluate.metrics.metric import Metric
from cyclops.evaluate.metrics.precision_recall_curve import (
    BinaryPrecisionRecallCurve,
)


class BinaryAveragePrecision(
    BinaryPrecisionRecallCurve,
    registry_key="binary_average_precision",
):
    """Compute average precision for binary input.

    Parameters
    ----------
    thresholds : int or list of floats or numpy.ndarray of floats, default=None
        Thresholds used for computing the precision and recall scores.
        If int, then the number of thresholds to use.
        If list or numpy.ndarray, then the thresholds to use.
        If None, then the thresholds are automatically determined by the
        unique values in ``preds``.
    pos_label : int
        The label of the positive class.

    Examples
    --------
    >>> from cyclops.evaluate.metrics import BinaryAveragePrecision
    >>> target = [0, 1, 0, 1]
    >>> preds = [0.1, 0.4, 0.35, 0.8]
    >>> metric = BinaryAveragePrecision(thresholds=3)
    >>> metric(target, preds)
    0.75
    >>> metric.reset_state()
    >>> target = [[0, 1, 0, 1], [1, 1, 0, 0]]
    >>> preds = [[0.1, 0.4, 0.35, 0.8], [0.6, 0.3, 0.1, 0.7]]
    >>> for t, p in zip(target, preds):
    ...     metric.update_state(t, p)
    >>> metric.compute()
    0.5833333333333333

    """

    name: str = "Average Precision"

    def compute(  # type: ignore[override]
        self,
    ) -> float:
        """Compute the average precision score from the state."""
        if self.thresholds is None:
            state = (
                np.concatenate(self.target, axis=0),  # type: ignore[attr-defined]
                np.concatenate(self.preds, axis=0),  # type: ignore[attr-defined]
            )
        else:
            state = self.confmat  # type: ignore[attr-defined]

        return _binary_average_precision_compute(
            state,
            self.thresholds,
            self.pos_label,
        )


class AveragePrecision(
    Metric,
    registry_key="average_precision",
    force_register=True,
):
    """Compute the average precision score for different classification tasks.

    Parameters
    ----------
    task : Literal["binary", "multiclass", "multilabel"]
        The task for which the average precision score is computed.
    thresholds : int or list of floats or numpy.ndarray of floats, default=None
        Thresholds used for computing the precision and recall scores. If int,
        then the number of thresholds to use. If list or array, then the
        thresholds to use. If None, then the thresholds are automatically
        determined by the unique values in ``preds``.
    pos_label : int, default=1
        Label to consider as positive for binary classification tasks.
    num_classes : int, optional
        The number of classes in the dataset. Required if ``task`` is
        ``"multiclass"``.
    num_labels : int, optional
        The number of labels in the dataset. Required if ``task`` is
        ``"multilabel"``.

    Examples
    --------
    >>> # (binary)
    >>> from cyclops.evaluate.metrics import AveragePrecision
    >>> target = [1, 1, 1, 0]
    >>> preds = [0.6, 0.2, 0.3, 0.8]
    >>> metric = AveragePrecision(task="binary", thresholds=None)
    >>> metric(target, preds)
    0.6388888888888888
    >>> metric.reset_state()
    >>> target = [[1, 0, 1, 1], [0, 0, 0, 1]]
    >>> preds = [[0.5, 0.4, 0.1, 0.3], [0.9, 0.6, 0.45, 0.8]]
    >>> for t, p in zip(target, preds):
    ...     metric.update_state(t, p)
    >>> metric.compute()
    0.48214285714285715

    """

    name: str = "Average Precision"

    def __new__(  # type: ignore # mypy expects a subclass of AveragePrecision
        cls: Type[Metric],
        task: Literal["binary", "multiclass", "multilabel"],
        thresholds: Optional[Union[int, List[float], npt.NDArray[np.float_]]] = None,
        pos_label: int = 1,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
    ) -> Metric:
        """Create a task-specific instance of the average precision metric."""
        if task == "binary":
            return BinaryAveragePrecision(
                thresholds=thresholds,
                pos_label=pos_label,
            )
        if task == "multiclass":
            raise NotImplementedError("Multiclass average precision is not implemented.")
        if task == "multilabel":
            raise NotImplementedError("Multilabel average precision is not implemented.")

        raise ValueError(
            "Expected argument `task` to be either 'binary', 'multiclass' or "
            f"'multilabel', but got {task}",
        )
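
The imported helper `_binary_average_precision_compute` is not shown in this diff. For intuition, here is a sketch of the standard estimator AP = Σ_n (R_n − R_{n−1}) P_n over descending score thresholds; this illustrates the formula, not the actual helper:

```python
import numpy as np


def binary_average_precision_sketch(target, preds):
    """AP as the precision at each cut, weighted by the recall gained there."""
    target = np.asarray(target)
    order = np.argsort(-np.asarray(preds, dtype=float))  # highest score first
    target = target[order]
    tps = np.cumsum(target)  # true positives after each cut
    fps = np.cumsum(1 - target)  # false positives after each cut
    precision = tps / (tps + fps)
    recall = tps / target.sum()
    return float(np.sum(np.diff(recall, prepend=0.0) * precision))


# With the data from the BinaryAveragePrecision doctest this unthresholded
# estimate is 1.0; the doctest value above (0.75) differs because
# `thresholds=3` bins the scores into three coarse thresholds.
print(binary_average_precision_sketch([0, 1, 0, 1], [0.1, 0.4, 0.35, 0.8]))
```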
12 changes: 6 additions & 6 deletions cyclops/evaluate/metrics/experimental/confusion_matrix.py
@@ -20,7 +20,7 @@
)
from cyclops.evaluate.metrics.experimental.metric import Metric
from cyclops.evaluate.metrics.experimental.utils.ops import dim_zero_cat
-from cyclops.evaluate.metrics.experimental.utils.typing import Array
+from cyclops.evaluate.metrics.experimental.utils.types import Array


class _AbstractConfusionMatrix(Metric):
@@ -112,7 +112,7 @@ class BinaryConfusionMatrix(
>>> metric = BinaryConfusionMatrix()
>>> metric(target, preds)
Array([[2, 1],
-[1, 2]], dtype=int32)
+[1, 2]], dtype=int64)
>>> target = np.asarray([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = np.asarray([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
@@ -303,10 +303,10 @@ class MultilabelConfusionMatrix(
>>> metric(target, preds)
Array([[[1, 0],
[0, 1]],
<BLANKLINE>
[[1, 0],
[1, 0]],
<BLANKLINE>
[[0, 1],
[0, 1]]], dtype=int64)
>>> target = np.asarray([[0, 1, 0], [1, 0, 1]])
>>> metric(target, preds)
Array([[[1, 0],
[0, 1]],
<BLANKLINE>
[[1, 0],
[1, 0]],
<BLANKLINE>
[[0, 1],
[0, 1]]], dtype=int64)
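
A note on the int32 → int64 changes in these expected outputs: the confusion-matrix counts come from integer accumulation, and the integer width of the printed array can vary with platform and implementation, so the doctests appear to be aligned here with what actually prints. A sketch of the kind of computation involved (not the cyclops implementation), using the same preds as the binary doctest in the functional module below, with a target chosen to reproduce its expected counts (the doctest's own target line is collapsed in this view):

```python
import numpy as np

target = np.asarray([0, 1, 0, 1, 0, 1])  # assumed; reproduces [[2, 1], [1, 2]]
preds = np.asarray([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])

# Threshold the scores at 0.5 (an assumed default), encode each
# (target, prediction) pair as 2*t + p, and count the four outcomes;
# np.bincount returns a default-width integer array.
binary_preds = (preds >= 0.5).astype(int)
confmat = np.bincount(2 * target + binary_preds, minlength=4).reshape(2, 2)
print(confmat)  # [[2 1]
                #  [1 2]]
print(confmat.dtype)  # int64 on typical 64-bit builds
```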
@@ -3,7 +3,7 @@
from abc import ABC, abstractmethod, abstractproperty
from typing import Any, List, Optional

-from cyclops.evaluate.metrics.experimental.utils.typing import Array
+from cyclops.evaluate.metrics.experimental.utils.types import Array
from cyclops.utils.log import setup_logging


@@ -8,7 +8,7 @@
DistributedBackend,
)
from cyclops.evaluate.metrics.experimental.utils.ops import flatten
-from cyclops.evaluate.metrics.experimental.utils.typing import Array
+from cyclops.evaluate.metrics.experimental.utils.types import Array
from cyclops.utils.optional import import_optional_module


@@ -13,7 +13,7 @@
squeeze_all,
to_int,
)
-from cyclops.evaluate.metrics.experimental.utils.typing import Array
+from cyclops.evaluate.metrics.experimental.utils.types import Array
from cyclops.evaluate.metrics.experimental.utils.validation import (
_basic_input_array_checks,
_check_same_shape,
@@ -252,7 +252,7 @@ class over the number of samples with the same true class.
>>> preds = np.asarray([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
>>> binary_confusion_matrix(target, preds)
Array([[2, 1],
-[1, 2]], dtype=int32)
+[1, 2]], dtype=int64)
""" # noqa: W505
_binary_confusion_matrix_validate_args(
@@ -707,21 +707,21 @@ class over the number of true samples for each class.
>>> multilabel_confusion_matrix(target, preds, num_labels=3)
Array([[[1, 0],
[0, 1]],
<BLANKLINE>
[[1, 0],
[1, 0]],
<BLANKLINE>
[[0, 1],
[0, 1]]], dtype=int64)
>>> target = np.asarray([[0, 1, 0], [1, 0, 1]])
>>> preds = np.asarray([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> multilabel_confusion_matrix(target, preds, num_labels=3)
Array([[[1, 0],
[0, 1]],
<BLANKLINE>
[[1, 0],
[1, 0]],
<BLANKLINE>
[[0, 1],
[0, 1]]], dtype=int64)
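
Likewise, the `<BLANKLINE>` markers added throughout these expected outputs exist because doctest treats an empty line as the end of an example's expected output, so a literal blank line inside the output must be written as `<BLANKLINE>`. A minimal standalone illustration (not from the codebase):

```python
def show_blocks() -> None:
    """Print two lines separated by an empty line.

    >>> show_blocks()
    block one
    <BLANKLINE>
    block two
    """
    print("block one\n\nblock two")


if __name__ == "__main__":
    import doctest

    doctest.testmod()
```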