Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix for bug when providing superclass arguments as kwargs #1069

Merged
merged 16 commits into from Jun 7, 2022
Merged
3 changes: 2 additions & 1 deletion CHANGELOG.md
Expand Up @@ -45,7 +45,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Fixed aggregation metrics when input only contains zero ([#1070](https://github.com/PyTorchLightning/metrics/pull/1070))

-

- Fixed `TypeError` when providing superclass arguments as kwargs ([#1069](https://github.com/PyTorchLightning/metrics/pull/1069))


## [0.9.0] - 2022-05-30
Expand Down
17 changes: 17 additions & 0 deletions tests/classification/test_confusion_matrix.py
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Any, Dict

import numpy as np
import pytest
Expand All @@ -30,6 +31,7 @@
from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
from tests.helpers import seed_all
from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester
from torchmetrics import JaccardIndex
from torchmetrics.classification.confusion_matrix import ConfusionMatrix
from torchmetrics.functional.classification.confusion_matrix import confusion_matrix

Expand Down Expand Up @@ -186,3 +188,18 @@ def test_warning_on_nan(tmpdir):
match=".* nan values found in confusion matrix have been replaced with zeros.",
):
confusion_matrix(preds, target, num_classes=5, normalize="true")


@pytest.mark.parametrize(
"metric_args",
[
{"num_classes": 1, "normalize": "true"},
{"num_classes": 1, "normalize": "pred"},
{"num_classes": 1, "normalize": "all"},
{"num_classes": 1, "normalize": "none"},
{"num_classes": 1, "normalize": None},
],
)
def test_provide_superclass_kwargs(metric_args: Dict[str, Any]):
"""Test instantiating subclasses with superclass arguments as kwargs."""
JaccardIndex(**metric_args)
21 changes: 19 additions & 2 deletions tests/classification/test_stat_scores.py
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Callable, Optional
from typing import Any, Callable, Dict, Optional

import numpy as np
import pytest
Expand All @@ -30,7 +30,7 @@
from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
from tests.helpers import seed_all
from tests.helpers.testers import NUM_CLASSES, MetricTester
from torchmetrics import StatScores
from torchmetrics import Accuracy, Dice, FBetaScore, Precision, Recall, Specificity, StatScores
from torchmetrics.functional import stat_scores
from torchmetrics.utilities.checks import _input_format_classification

Expand Down Expand Up @@ -326,3 +326,20 @@ def test_top_k(k: int, preds: Tensor, target: Tensor, reduce: str, expected: Ten

assert torch.equal(class_metric.compute(), expected.T)
assert torch.equal(stat_scores(preds, target, top_k=k, reduce=reduce, num_classes=3), expected.T)


@pytest.mark.parametrize(
"metric_args",
[
{"reduce": "micro"},
{"num_classes": 1, "reduce": "macro"},
{"reduce": "samples"},
{"mdmc_reduce": None},
{"mdmc_reduce": "samplewise"},
{"mdmc_reduce": "global"},
],
)
@pytest.mark.parametrize("metric_cls", [Accuracy, Dice, FBetaScore, Precision, Recall, Specificity])
def test_provide_superclass_kwargs(metric_cls: StatScores, metric_args: Dict[str, Any]):
"""Test instantiating subclasses with superclass arguments as kwargs."""
metric_cls(**metric_args)
5 changes: 2 additions & 3 deletions tests/text/test_mer.py
Expand Up @@ -4,16 +4,15 @@

from tests.text.helpers import TextTester
from tests.text.inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2
from torchmetrics.functional.text.mer import match_error_rate
from torchmetrics.text.mer import MatchErrorRate
from torchmetrics.utilities.imports import _JIWER_AVAILABLE

if _JIWER_AVAILABLE:
from jiwer import compute_measures
else:
compute_measures: Callable

from torchmetrics.functional.text.mer import match_error_rate
from torchmetrics.text.mer import MatchErrorRate


def _compute_mer_metric_jiwer(preds: Union[str, List[str]], target: Union[str, List[str]]):
return compute_measures(target, preds)["mer"]
Expand Down
5 changes: 2 additions & 3 deletions tests/text/test_wer.py
Expand Up @@ -4,16 +4,15 @@

from tests.text.helpers import TextTester
from tests.text.inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2
from torchmetrics.functional.text.wer import word_error_rate
from torchmetrics.text.wer import WordErrorRate
from torchmetrics.utilities.imports import _JIWER_AVAILABLE

if _JIWER_AVAILABLE:
from jiwer import compute_measures
else:
compute_measures: Callable

from torchmetrics.functional.text.wer import word_error_rate
from torchmetrics.text.wer import WordErrorRate


def _compute_wer_metric_jiwer(preds: Union[str, List[str]], target: Union[str, List[str]]):
return compute_measures(target, preds)["wer"]
Expand Down
14 changes: 9 additions & 5 deletions torchmetrics/classification/accuracy.py
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional
from typing import Any, Optional

from torch import Tensor, tensor

Expand All @@ -23,7 +23,7 @@
_subset_accuracy_compute,
_subset_accuracy_update,
)
from torchmetrics.utilities.enums import DataType
from torchmetrics.utilities.enums import AverageMethod, DataType

from torchmetrics.classification.stat_scores import StatScores # isort:skip

Expand Down Expand Up @@ -170,15 +170,19 @@ def __init__(
top_k: Optional[int] = None,
multiclass: Optional[bool] = None,
subset_accuracy: bool = False,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> None:
allowed_average = ["micro", "macro", "weighted", "samples", "none", None]
if average not in allowed_average:
raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")

_reduce_options = (AverageMethod.WEIGHTED, AverageMethod.NONE, None)
if "reduce" not in kwargs:
kwargs["reduce"] = AverageMethod.MACRO if average in _reduce_options else average
if "mdmc_reduce" not in kwargs:
kwargs["mdmc_reduce"] = mdmc_average

super().__init__(
reduce="macro" if average in ["weighted", "none", None] else average,
mdmc_reduce=mdmc_average,
threshold=threshold,
top_k=top_k,
num_classes=num_classes,
Expand Down
13 changes: 9 additions & 4 deletions torchmetrics/classification/dice.py
Expand Up @@ -11,12 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional
from typing import Any, Optional

from torch import Tensor

from torchmetrics.classification.stat_scores import StatScores
from torchmetrics.functional.classification.dice import _dice_compute
from torchmetrics.utilities.enums import AverageMethod


class Dice(StatScores):
Expand Down Expand Up @@ -128,15 +129,19 @@ def __init__(
ignore_index: Optional[int] = None,
top_k: Optional[int] = None,
multiclass: Optional[bool] = None,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> None:
allowed_average = ("micro", "macro", "weighted", "samples", "none", None)
if average not in allowed_average:
raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")

_reduce_options = (AverageMethod.WEIGHTED, AverageMethod.NONE, None)
if "reduce" not in kwargs:
kwargs["reduce"] = AverageMethod.MACRO if average in _reduce_options else average
if "mdmc_reduce" not in kwargs:
kwargs["mdmc_reduce"] = mdmc_average

super().__init__(
reduce="macro" if average in ("weighted", "none", None) else average,
mdmc_reduce=mdmc_average,
threshold=threshold,
top_k=top_k,
num_classes=num_classes,
Expand Down
10 changes: 7 additions & 3 deletions torchmetrics/classification/f_beta.py
Expand Up @@ -130,16 +130,20 @@ def __init__(
ignore_index: Optional[int] = None,
top_k: Optional[int] = None,
multiclass: Optional[bool] = None,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> None:
self.beta = beta
allowed_average = list(AverageMethod)
if average not in allowed_average:
raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")

_reduce_options = (AverageMethod.WEIGHTED, AverageMethod.NONE, None)
if "reduce" not in kwargs:
kwargs["reduce"] = AverageMethod.MACRO if average in _reduce_options else average
if "mdmc_reduce" not in kwargs:
kwargs["mdmc_reduce"] = mdmc_average

super().__init__(
reduce="macro" if average in [AverageMethod.WEIGHTED, AverageMethod.NONE] else average,
mdmc_reduce=mdmc_average,
threshold=threshold,
top_k=top_k,
num_classes=num_classes,
Expand Down
7 changes: 4 additions & 3 deletions torchmetrics/classification/jaccard.py
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional
from typing import Any, Optional

import torch
from torch import Tensor
Expand Down Expand Up @@ -88,11 +88,12 @@ def __init__(
absent_score: float = 0.0,
threshold: float = 0.5,
multilabel: bool = False,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> None:
kwargs["normalize"] = kwargs.get("normalize")

super().__init__(
num_classes=num_classes,
normalize=None,
threshold=threshold,
multilabel=multilabel,
**kwargs,
Expand Down
23 changes: 16 additions & 7 deletions torchmetrics/classification/precision_recall.py
Expand Up @@ -11,12 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional
from typing import Any, Optional

from torch import Tensor

from torchmetrics.classification.stat_scores import StatScores
from torchmetrics.functional.classification.precision_recall import _precision_compute, _recall_compute
from torchmetrics.utilities.enums import AverageMethod


class Precision(StatScores):
Expand Down Expand Up @@ -121,15 +122,19 @@ def __init__(
ignore_index: Optional[int] = None,
top_k: Optional[int] = None,
multiclass: Optional[bool] = None,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> None:
allowed_average = ["micro", "macro", "weighted", "samples", "none", None]
if average not in allowed_average:
raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")

_reduce_options = (AverageMethod.WEIGHTED, AverageMethod.NONE, None)
if "reduce" not in kwargs:
kwargs["reduce"] = AverageMethod.MACRO if average in _reduce_options else average
if "mdmc_reduce" not in kwargs:
kwargs["mdmc_reduce"] = mdmc_average

super().__init__(
reduce="macro" if average in ["weighted", "none", None] else average,
mdmc_reduce=mdmc_average,
threshold=threshold,
top_k=top_k,
num_classes=num_classes,
Expand Down Expand Up @@ -256,15 +261,19 @@ def __init__(
ignore_index: Optional[int] = None,
top_k: Optional[int] = None,
multiclass: Optional[bool] = None,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> None:
allowed_average = ["micro", "macro", "weighted", "samples", "none", None]
if average not in allowed_average:
raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")

_reduce_options = (AverageMethod.WEIGHTED, AverageMethod.NONE, None)
if "reduce" not in kwargs:
kwargs["reduce"] = AverageMethod.MACRO if average in _reduce_options else average
if "mdmc_reduce" not in kwargs:
kwargs["mdmc_reduce"] = mdmc_average

super().__init__(
reduce="macro" if average in ["weighted", "none", None] else average,
mdmc_reduce=mdmc_average,
threshold=threshold,
top_k=top_k,
num_classes=num_classes,
Expand Down
13 changes: 9 additions & 4 deletions torchmetrics/classification/specificity.py
Expand Up @@ -11,13 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional
from typing import Any, Optional

import torch
from torch import Tensor

from torchmetrics.classification.stat_scores import StatScores
from torchmetrics.functional.classification.specificity import _specificity_compute
from torchmetrics.utilities.enums import AverageMethod


class Specificity(StatScores):
Expand Down Expand Up @@ -123,15 +124,19 @@ def __init__(
ignore_index: Optional[int] = None,
top_k: Optional[int] = None,
multiclass: Optional[bool] = None,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> None:
allowed_average = ["micro", "macro", "weighted", "samples", "none", None]
if average not in allowed_average:
raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")

_reduce_options = (AverageMethod.WEIGHTED, AverageMethod.NONE, None)
if "reduce" not in kwargs:
kwargs["reduce"] = AverageMethod.MACRO if average in _reduce_options else average
if "mdmc_reduce" not in kwargs:
kwargs["mdmc_reduce"] = mdmc_average

super().__init__(
reduce="macro" if average in ["weighted", "none", None] else average,
mdmc_reduce=mdmc_average,
threshold=threshold,
top_k=top_k,
num_classes=num_classes,
Expand Down
6 changes: 3 additions & 3 deletions torchmetrics/functional/audio/pesq.py
Expand Up @@ -12,17 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
from torch import Tensor

from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.imports import _PESQ_AVAILABLE

if _PESQ_AVAILABLE:
import pesq as pesq_backend
else:
pesq_backend = None
import torch
from torch import Tensor

from torchmetrics.utilities.checks import _check_same_shape

__doctest_requires__ = {("perceptual_evaluation_speech_quality",): ["pesq"]}

Expand Down
5 changes: 2 additions & 3 deletions torchmetrics/functional/audio/stoi.py
Expand Up @@ -13,17 +13,16 @@
# limitations under the License.
import numpy as np
import torch
from torch import Tensor

from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.imports import _PYSTOI_AVAILABLE

if _PYSTOI_AVAILABLE:
from pystoi import stoi as stoi_backend
else:
stoi_backend = None
__doctest_skip__ = ["short_time_objective_intelligibility"]
from torch import Tensor

from torchmetrics.utilities.checks import _check_same_shape


def short_time_objective_intelligibility(
Expand Down