Skip to content

Commit

Permalink
Add remaining fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
amrit110 committed Dec 6, 2023
1 parent d5946fe commit de3cf07
Show file tree
Hide file tree
Showing 16 changed files with 55 additions and 76 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ repos:
hooks:
- id: doctest
name: doctest
entry: python3 -m doctest
entry: python3 -m doctest -o NORMALIZE_WHITESPACE
files: "^cyclops/evaluate/"
language: system

Expand Down
2 changes: 1 addition & 1 deletion cyclops/evaluate/metrics/experimental/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
)
from cyclops.evaluate.metrics.experimental.metric import Metric
from cyclops.evaluate.metrics.experimental.utils.ops import dim_zero_cat
from cyclops.evaluate.metrics.experimental.utils.typing import Array
from cyclops.evaluate.metrics.experimental.utils.types import Array


class _AbstractConfusionMatrix(Metric):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from abc import ABC, abstractmethod, abstractproperty
from typing import Any, List, Optional

from cyclops.evaluate.metrics.experimental.utils.typing import Array
from cyclops.evaluate.metrics.experimental.utils.types import Array
from cyclops.utils.log import setup_logging


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
DistributedBackend,
)
from cyclops.evaluate.metrics.experimental.utils.ops import flatten
from cyclops.evaluate.metrics.experimental.utils.typing import Array
from cyclops.evaluate.metrics.experimental.utils.types import Array
from cyclops.utils.optional import import_optional_module


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
squeeze_all,
to_int,
)
from cyclops.evaluate.metrics.experimental.utils.typing import Array
from cyclops.evaluate.metrics.experimental.utils.types import Array
from cyclops.evaluate.metrics.experimental.utils.validation import (
_basic_input_array_checks,
_check_same_shape,
Expand Down Expand Up @@ -252,7 +252,7 @@ class over the number of samples with the same true class.
>>> preds = np.asarray([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
>>> binary_confusion_matrix(target, preds)
Array([[2, 1],
[1, 2]], dtype=int32)
[1, 2]], dtype=int64)
""" # noqa: W505
_binary_confusion_matrix_validate_args(
Expand Down Expand Up @@ -707,21 +707,21 @@ class over the number of true samples for each class.
>>> multilabel_confusion_matrix(target, preds, num_labels=3)
Array([[[1, 0],
[0, 1]],
<BLANKLINE>
[[1, 0],
[1, 0]],
<BLANKLINE>
[[0, 1],
[0, 1]]], dtype=int64)
>>> target = np.asarray([[0, 1, 0], [1, 0, 1]])
>>> preds = np.asarray([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> multilabel_confusion_matrix(target, preds, num_labels=3)
Array([[[1, 0],
[0, 1]],
<BLANKLINE>
[[1, 0],
[1, 0]],
<BLANKLINE>
[[0, 1],
[0, 1]]], dtype=int64)
Expand Down
2 changes: 1 addition & 1 deletion cyclops/evaluate/metrics/experimental/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
dim_zero_sum,
flatten_seq,
)
from cyclops.evaluate.metrics.experimental.utils.typing import Array
from cyclops.evaluate.metrics.experimental.utils.types import Array
from cyclops.utils.log import setup_logging


Expand Down
2 changes: 1 addition & 1 deletion cyclops/evaluate/metrics/experimental/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
squeeze_all,
to_int,
)
from cyclops.evaluate.metrics.experimental.utils.typing import Array
from cyclops.evaluate.metrics.experimental.utils.types import Array
from cyclops.evaluate.metrics.experimental.utils.validation import (
is_floating_point,
is_numeric,
Expand Down
2 changes: 1 addition & 1 deletion cyclops/evaluate/metrics/experimental/utils/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from array_api_compat.common._helpers import _is_numpy_array, _is_torch_array
from numpy.core.multiarray import normalize_axis_index # type: ignore

from cyclops.evaluate.metrics.experimental.utils.typing import Array
from cyclops.evaluate.metrics.experimental.utils.types import Array
from cyclops.evaluate.metrics.experimental.utils.validation import (
_get_int_dtypes,
is_floating_point,
Expand Down
2 changes: 1 addition & 1 deletion cyclops/evaluate/metrics/experimental/utils/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import array_api_compat as apc

from cyclops.evaluate.metrics.experimental.utils.typing import Array
from cyclops.evaluate.metrics.experimental.utils.types import Array


def is_floating_point(array: Array) -> bool:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -956,7 +956,6 @@ def multilabel_precision_recall_curve(
)


# ruff: noqa: W505
def precision_recall_curve(
target: npt.ArrayLike,
preds: npt.ArrayLike,
Expand Down Expand Up @@ -1060,7 +1059,7 @@ def precision_recall_curve(
>>> thresholds
[array([0.05, 0.1 ]), array([0.9 , 0.95]), array([0.35, 0.8 ])]
"""
""" # noqa: W505
if task == "binary":
return binary_precision_recall_curve(
target,
Expand Down
3 changes: 1 addition & 2 deletions cyclops/evaluate/metrics/functional/roc.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,6 @@ def multilabel_roc_curve(
return _multilabel_roc_compute(state, num_labels, thresholds)


# ruff: noqa: W505
def roc_curve(
target: npt.ArrayLike,
preds: npt.ArrayLike,
Expand Down Expand Up @@ -623,7 +622,7 @@ def roc_curve(
>>> thresholds
[array([1. , 0.9, 0.8, 0.2]), array([1. , 0.8, 0.7, 0.3])]
"""
""" # noqa: W505
_check_thresholds(thresholds)
if task == "binary":
return binary_roc_curve(target, preds, thresholds, pos_label=pos_label)
Expand Down
91 changes: 36 additions & 55 deletions cyclops/evaluate/metrics/roc.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,20 +39,16 @@ class BinaryROCCurve(BinaryPrecisionRecallCurve, registry_key="binary_roc_curve"
>>> preds = [0.1, 0.4, 0.35, 0.8]
>>> metric = BinaryROCCurve()
>>> metric(target, preds)
(array([0. , 0. , 0.5, 0.5, 1. ]),
array([0. , 0.5, 0.5, 1. , 1. ]),
array([1. , 0.8 , 0.4 , 0.35, 0.1 ]))
(array([0. , 0. , 0.5, 0.5, 1. ]), array([0. , 0.5, 0.5, 1. , 1. ]), array([1. , 0.8 , 0.4 , 0.35, 0.1 ]))
>>> metric.reset_state()
>>> target = [[1, 1, 0, 0], [0, 0, 1, 1]]
>>> preds = [[0.1, 0.2, 0.3, 0.4], [0.6, 0.5, 0.4, 0.3]]
>>> for t, p in zip(target, preds):
... metric.update_state(t, p)
>>> metric.compute()
(array([0. , 0.25, 0.5 , 0.75, 1. , 1. , 1. ]),
array([0. , 0. , 0. , 0.25, 0.5 , 0.75, 1. ]),
array([1. , 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]))
(array([0. , 0.25, 0.5 , 0.75, 1. , 1. , 1. ]), array([0. , 0. , 0. , 0.25, 0.5 , 0.75, 1. ]), array([1. , 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]))
"""
""" # noqa: W505

name: str = "ROC Curve"

Expand Down Expand Up @@ -107,12 +103,10 @@ class MulticlassROCCurve(
>>> metric = MulticlassROCCurve(num_classes=3, thresholds=4)
>>> metric(target, preds)
(array([[0. , 0. , 0. , 1. ],
[0. , 0.33333333, 0.33333333, 1. ],
[0. , 0. , 0. , 1. ]]),
array([[0. , 0.5, 0.5, 1. ],
[0. , 1. , 1. , 1. ],
[0. , 0. , 1. , 1. ]]),
array([1. , 0.66666667, 0.33333333, 0. ]))
[0. , 0.33333333, 0.33333333, 1. ],
[0. , 0. , 0. , 1. ]]), array([[0. , 0.5, 0.5, 1. ],
[0. , 1. , 1. , 1. ],
[0. , 0. , 1. , 1. ]]), array([1. , 0.66666667, 0.33333333, 0. ]))
>>> metric.reset_state()
>>> target = [[1, 1, 0, 0], [0, 0, 1, 1]]
>>> preds = [[[0.1, 0.2, 0.7], [0.5, 0.4, 0.1],
Expand All @@ -123,14 +117,12 @@ class MulticlassROCCurve(
... metric.update_state(t, p)
>>> metric.compute()
(array([[0. , 0.25, 0.5 , 1. ],
[0. , 0. , 0.25, 1. ],
[0. , 0.25, 0.5 , 1. ]]),
array([[0. , 0.25, 0.5 , 1. ],
[0. , 0. , 0.25, 1. ],
[0. , 0. , 0. , 0. ]]),
array([1. , 0.66666667, 0.33333333, 0. ]))
[0. , 0. , 0.25, 1. ],
[0. , 0.25, 0.5 , 1. ]]), array([[0. , 0.25, 0.5 , 1. ],
[0. , 0. , 0.25, 1. ],
[0. , 0. , 0. , 0. ]]), array([1. , 0.66666667, 0.33333333, 0. ]))
"""
""" # noqa: W505

name: str = "ROC Curve"

Expand Down Expand Up @@ -185,12 +177,10 @@ class MultilabelROCCurve(
>>> metric = MultilabelROCCurve(num_labels=3, thresholds=4)
>>> metric(target, preds)
(array([[0. , 0. , 0. , 1. ],
[0. , 0. , 0. , 0. ],
[0. , 0.5, 0.5, 1. ]]),
array([[0., 0., 0., 1.],
[0., 1., 1., 1.],
[0., 0., 0., 0.]]),
array([1. , 0.66666667, 0.33333333, 0. ]))
[0. , 0. , 0. , 0. ],
[0. , 0.5, 0.5, 1. ]]), array([[0., 0., 0., 1.],
[0., 1., 1., 1.],
[0., 0., 0., 0.]]), array([1. , 0.66666667, 0.33333333, 0. ]))
>>> metric.reset_state()
>>> target = [[[1, 1, 0], [0, 1, 0]], [[1, 1, 0], [0, 1, 0]]]
>>> preds = [[[0.1, 0.9, 0.8], [0.05, 0.95, 0]],
Expand All @@ -199,12 +189,10 @@ class MultilabelROCCurve(
... metric.update_state(t, p)
>>> metric.compute()
(array([[0. , 0. , 0. , 1. ],
[0. , 0. , 0. , 0. ],
[0. , 0.5, 0.5, 1. ]]),
array([[0., 0., 0., 1.],
[0., 1., 1., 1.],
[0., 0., 0., 0.]]),
array([1. , 0.66666667, 0.33333333, 0. ]))
[0. , 0. , 0. , 0. ],
[0. , 0.5, 0.5, 1. ]]), array([[0., 0., 0., 1.],
[0., 1., 1., 1.],
[0., 0., 0., 0.]]), array([1. , 0.66666667, 0.33333333, 0. ]))
"""

Expand Down Expand Up @@ -272,44 +260,37 @@ class ROCCurve(Metric, registry_key="roc_curve", force_register=True):
>>> preds = [0.1, 0.4, 0.35, 0.8]
>>> metric = ROCCurve(task="binary", thresholds=None)
>>> metric(target, preds)
(array([0. , 0. , 0.5, 0.5, 1. ]),
array([0. , 0.5, 0.5, 1. , 1. ]),
array([1. , 0.8 , 0.4 , 0.35, 0.1 ]))
(array([0. , 0. , 0.5, 0.5, 1. ]), array([0. , 0.5, 0.5, 1. , 1. ]), array([1. , 0.8 , 0.4 , 0.35, 0.1 ]))
>>> metric.reset_state()
>>> target = [[1, 1, 0, 0], [0, 0, 1, 1]]
>>> preds = [[0.1, 0.2, 0.3, 0.4], [0.6, 0.5, 0.4, 0.3]]
>>> for t, p in zip(target, preds):
... metric.update_state(t, p)
>>> metric.compute()
(array([0. , 0.25, 0.5 , 0.75, 1. , 1. , 1. ]),
array([0. , 0. , 0. , 0.25, 0.5 , 0.75, 1. ]),
array([1. , 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]))
(array([0. , 0.25, 0.5 , 0.75, 1. , 1. , 1. ]), array([0. , 0. , 0. , 0.25, 0.5 , 0.75, 1. ]), array([1. , 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]))
>>> # (multiclass)
>>> from cyclops.evaluate.metrics import ROCCurve
>>> target = [[1, 1, 0], [0, 1, 0]]
>>> preds = [[0.1, 0.9, 0.8], [0.05, 0.95, 0]]
>>> target = [1, 2, 0]
>>> preds = [[0.05, 0.95, 0], [0.1, 0.8, 0.1], [0.2, 0.2, 0.6]]
>>> metric = ROCCurve(task="multiclass", num_classes=3, thresholds=4)
>>> metric(target, preds)
(array([[0. , 0. , 0. , 1. ],
[0. , 0. , 0. , 0. ],
[0. , 0.5, 0.5, 1. ]]),
array([[0., 0., 0., 1.],
[0., 1., 1., 1.],
[0., 0., 0., 0.]]),
array([1. , 0.66666667, 0.33333333, 0. ]))
[0. , 0.5, 0.5, 1. ],
[0. , 0. , 0.5, 1. ]]), array([[0., 0., 0., 1.],
[0., 1., 1., 1.],
[0., 0., 0., 1.]]), array([1. , 0.66666667, 0.33333333, 0. ]))
>>> metric.reset_state()
>>> target = [[[1, 1, 0], [0, 1, 0]], [[1, 1, 0], [0, 1, 0]]]
>>> preds = [[[0.1, 0.9, 0.8], [0.05, 0.95, 0]],
... [[0.1, 0.9, 0.8], [0.05, 0.95, 0]]]
>>> target = [1, 2]
>>> preds = [[[0.05, 0.75, 0.2]], [[0.1, 0.8, 0.1]]]
>>> for t, p in zip(target, preds):
... metric.update_state(t, p)
>>> metric.compute()
(array([[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]]), array([[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]]), array([1. , 0.66666667, 0.33333333, 0. ]))
(array([[0., 0., 0., 1.],
[0., 1., 1., 1.],
[0., 0., 0., 1.]]), array([[0., 0., 0., 0.],
[0., 1., 1., 1.],
[0., 0., 0., 1.]]), array([1. , 0.66666667, 0.33333333, 0. ]))
>>> # (multilabel)
>>> from cyclops.evaluate.metrics import ROCCurve
Expand All @@ -335,7 +316,7 @@ class ROCCurve(Metric, registry_key="roc_curve", force_register=True):
[0., 1., 1., 1.],
[0., 0., 0., 0.]]), array([1. , 0.66666667, 0.33333333, 0. ]))
"""
""" # noqa: W505

name: str = "ROC Curve"

Expand Down
2 changes: 1 addition & 1 deletion tests/cyclops/evaluate/metrics/experimental/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch
from scipy.special import log_softmax

from cyclops.evaluate.metrics.experimental.utils.typing import Array
from cyclops.evaluate.metrics.experimental.utils.types import Array

from ..conftest import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, NUM_CLASSES, NUM_LABELS

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest

from cyclops.evaluate.metrics.experimental.metric import Metric, OperatorMetric
from cyclops.evaluate.metrics.experimental.utils.typing import Array
from cyclops.evaluate.metrics.experimental.utils.types import Array


class DummyMetric(Metric):
Expand Down
2 changes: 1 addition & 1 deletion tests/cyclops/evaluate/metrics/experimental/testers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from cyclops.evaluate.metrics.experimental.metric import Metric
from cyclops.evaluate.metrics.experimental.utils.ops import clone, flatten
from cyclops.evaluate.metrics.experimental.utils.typing import Array
from cyclops.evaluate.metrics.experimental.utils.types import Array


def _assert_allclose(
Expand Down

0 comments on commit de3cf07

Please sign in to comment.