Skip to content

Commit

Permalink
Add precision and recall (#525)
Browse files Browse the repository at this point in the history
  • Loading branch information
fcogidi committed Dec 7, 2023
1 parent d26bb50 commit 4cd2eb2
Show file tree
Hide file tree
Showing 8 changed files with 2,165 additions and 36 deletions.
8 changes: 8 additions & 0 deletions cyclops/evaluate/metrics/experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,11 @@
MulticlassConfusionMatrix,
MultilabelConfusionMatrix,
)
from cyclops.evaluate.metrics.experimental.precision_recall import (
BinaryPrecision,
BinaryRecall,
MulticlassPrecision,
MulticlassRecall,
MultilabelPrecision,
MultilabelRecall,
)
8 changes: 8 additions & 0 deletions cyclops/evaluate/metrics/experimental/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,11 @@
multiclass_confusion_matrix,
multilabel_confusion_matrix,
)
from cyclops.evaluate.metrics.experimental.functional.precision_recall import (
binary_precision,
binary_recall,
multiclass_precision,
multiclass_recall,
multilabel_precision,
multilabel_recall,
)
40 changes: 7 additions & 33 deletions cyclops/evaluate/metrics/experimental/functional/accuracy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""Functions for computing the accuracy score for classification tasks."""
from types import ModuleType
from typing import Literal, Optional, Tuple, Union

import array_api_compat as apc
Expand All @@ -19,6 +18,7 @@
_multilabel_stat_scores_validate_arrays,
)
from cyclops.evaluate.metrics.experimental.utils.ops import (
_adjust_weight_apply_average,
safe_divide,
squeeze_all,
)
Expand Down Expand Up @@ -151,7 +151,7 @@ def _accuracy_compute(
else safe_divide(numerator=tp, denominator=tp + fn)
)

return _accuracy_score_average(
return _adjust_weight_apply_average(
score,
average=average,
is_multilabel=is_multilabel,
Expand All @@ -162,37 +162,6 @@ def _accuracy_compute(
)


def _accuracy_score_average(
score: Array,
average: Optional[Literal["macro", "weighted", "none"]],
is_multilabel: bool,
*,
tp: Array,
fp: Array,
fn: Array,
xp: ModuleType,
) -> Array:
"""Apply the specified averaging method to the accuracy scores."""
if average is None or average == "none":
return score
if average == "weighted":
weights = tp + fn
else: # average == "macro"
weights = xp.ones_like(score)
if not is_multilabel:
weights[tp + fp + fn == 0] = 0.0

weights = xp.astype(weights, xp.float32)
return xp.sum( # type: ignore[no-any-return]
safe_divide(
weights * score,
xp.sum(weights, axis=-1, dtype=score.dtype, keepdims=True),
),
axis=-1,
dtype=score.dtype,
)


def multiclass_accuracy(
target: Array,
preds: Array,
Expand Down Expand Up @@ -402,6 +371,11 @@ def multilabel_accuracy(
ignore_index : int, optional, default=None
Specifies value in `target` that is ignored when computing the accuracy score.
Returns
-------
Array
An array API compatible object containing the accuracy score(s).
Raises
------
ValueError
Expand Down
Loading

0 comments on commit 4cd2eb2

Please sign in to comment.