Fix for bug when providing superclass arguments as kwargs #1069

Merged
16 commits merged on Jun 7, 2022
6 changes: 6 additions & 0 deletions tests/classification/test_accuracy.py
@@ -448,3 +448,9 @@ def test_negmetric_noneavg(noneavg=_negmetric_noneavg):
    assert torch.allclose(noneavg["res1"], result1, equal_nan=True)
    result2 = acc(noneavg["pred2"], noneavg["target2"])
    assert torch.allclose(noneavg["res2"], result2, equal_nan=True)
+
+
+def test_provide_superclass_kwargs():
"""Test instantiating class providing superclass arguments."""
Accuracy(reduce="micro")
Accuracy(mdmc_reduce="global")
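For context, the failure these new tests guard against is a plain Python duplicate-keyword collision: before this fix, `Accuracy.__init__` always passed `reduce` and `mdmc_reduce` to `StatScores.__init__` while also forwarding `**kwargs`. A minimal sketch with hypothetical stand-in classes (not the library's real signatures):

```python
# Hypothetical stand-ins for StatScores/Accuracy illustrating the pre-fix bug.
class Base:
    def __init__(self, reduce=None, **kwargs):
        self.reduce = reduce


class ChildPreFix(Base):
    def __init__(self, average="micro", **kwargs):
        # Pre-fix pattern: `reduce` is always passed explicitly, so a
        # caller-supplied `reduce` arrives a second time via **kwargs.
        super().__init__(reduce=average, **kwargs)


try:
    ChildPreFix(reduce="micro")
except TypeError as err:
    print(err)  # __init__() got multiple values for keyword argument 'reduce'
```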
6 changes: 6 additions & 0 deletions tests/classification/test_dice.py
@@ -163,3 +163,9 @@ def test_dice_fn(self, preds, target, ignore_index):
            sk_metric=partial(_sk_dice, ignore_index=ignore_index),
            metric_args={"ignore_index": ignore_index},
        )
+
+
+def test_provide_superclass_kwargs():
+    """Test instantiating class providing superclass arguments."""
+    Dice(reduce="micro")
+    Dice(mdmc_reduce="global")
8 changes: 8 additions & 0 deletions tests/classification/test_f_beta.py
@@ -460,3 +460,11 @@ def test_same_input(metric_class, metric_functional, sk_fn, average, ignore_index):

    assert torch.allclose(class_res, torch.tensor(sk_res).float())
    assert torch.allclose(func_res, torch.tensor(sk_res).float())
+
+
+def test_provide_superclass_kwargs():
+    """Test instantiating class providing superclass arguments."""
+    FBetaScore(reduce="micro")
+    FBetaScore(mdmc_reduce="global")
+    F1Score(reduce="micro")
+    F1Score(mdmc_reduce="global")
5 changes: 5 additions & 0 deletions tests/classification/test_jaccard.py
@@ -237,3 +237,8 @@ def test_jaccard_ignore_index(pred, target, ignore_index, num_classes, average,
        # reduction=reduction,
    )
    assert torch.allclose(jaccard_val, tensor(expected).to(jaccard_val))
+
+
+def test_provide_superclass_kwargs():
+    """Test instantiating class providing superclass arguments."""
+    jaccard = JaccardIndex(num_classes=1, normalize="true")
8 changes: 8 additions & 0 deletions tests/classification/test_precision_recall.py
@@ -468,3 +468,11 @@ def test_noneavg(metric_cls, noneavg=_negmetric_noneavg):
    assert torch.allclose(noneavg["res1"], result1, equal_nan=True)
    result2 = prec(noneavg["pred2"], noneavg["target2"])
    assert torch.allclose(noneavg["res2"], result2, equal_nan=True)
+
+
+def test_provide_superclass_kwargs():
+    """Test instantiating class providing superclass arguments."""
+    Precision(reduce="micro")
+    Precision(mdmc_reduce="global")
+    Recall(reduce="micro")
+    Recall(mdmc_reduce="global")
6 changes: 6 additions & 0 deletions tests/classification/test_specificity.py
@@ -410,3 +410,9 @@ def test_class_not_present(metric_class, metric_fn, ignore_index, expected):
    cl_metric(preds, target)
    result_cl = cl_metric.compute()
    assert torch.allclose(expected, result_cl, equal_nan=True)
+
+
+def test_provide_superclass_kwargs():
+    """Test instantiating class providing superclass arguments."""
+    Specificity(reduce="micro")
+    Specificity(mdmc_reduce="global")
9 changes: 6 additions & 3 deletions torchmetrics/classification/accuracy.py
@@ -23,7 +23,7 @@
    _subset_accuracy_compute,
    _subset_accuracy_update,
)
-from torchmetrics.utilities.enums import DataType
+from torchmetrics.utilities.enums import AverageMethod, DataType

from torchmetrics.classification.stat_scores import StatScores  # isort:skip
@@ -176,9 +176,12 @@ def __init__(
        if average not in allowed_average:
            raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")

+        if "reduce" not in kwargs or kwargs["reduce"] is None:
+            kwargs["reduce"] = "macro" if average in [AverageMethod.WEIGHTED, AverageMethod.NONE, "none"] else average
+        if "mdmc_reduce" not in kwargs or kwargs["mdmc_reduce"] is None:
+            kwargs["mdmc_reduce"] = mdmc_average
+
        super().__init__(
-            reduce="macro" if average in ["weighted", "none", None] else average,
-            mdmc_reduce=mdmc_average,
            threshold=threshold,
            top_k=top_k,
            num_classes=num_classes,
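The hunk above is the core of the fix, repeated below for each affected metric: derive the superclass argument only when the caller has not already supplied it, then let `**kwargs` carry it to `super().__init__` exactly once. A minimal sketch of the pattern, reusing the hypothetical `Base` from the earlier example:

```python
class Base:
    def __init__(self, reduce=None, mdmc_reduce=None, **kwargs):
        self.reduce = reduce
        self.mdmc_reduce = mdmc_reduce


class ChildFixed(Base):
    def __init__(self, average="micro", mdmc_average=None, **kwargs):
        # Fill in superclass arguments only when absent or None, so an
        # explicit kwarg from the caller passes through untouched.
        if "reduce" not in kwargs or kwargs["reduce"] is None:
            kwargs["reduce"] = "macro" if average in ("weighted", "none", None) else average
        if "mdmc_reduce" not in kwargs or kwargs["mdmc_reduce"] is None:
            kwargs["mdmc_reduce"] = mdmc_average
        super().__init__(**kwargs)


ChildFixed(reduce="micro")      # no longer raises TypeError
ChildFixed(average="weighted")  # derived default ("macro") still applies
```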
8 changes: 6 additions & 2 deletions torchmetrics/classification/dice.py
@@ -17,6 +17,7 @@

from torchmetrics.classification.stat_scores import StatScores
from torchmetrics.functional.classification.dice import _dice_compute
+from torchmetrics.utilities.enums import AverageMethod


class Dice(StatScores):
@@ -134,9 +135,12 @@ def __init__(
        if average not in allowed_average:
            raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")

+        if "reduce" not in kwargs or kwargs["reduce"] is None:
+            kwargs["reduce"] = "macro" if average in [AverageMethod.WEIGHTED, AverageMethod.NONE, "none"] else average
+        if "mdmc_reduce" not in kwargs or kwargs["mdmc_reduce"] is None:
+            kwargs["mdmc_reduce"] = mdmc_average
+
        super().__init__(
-            reduce="macro" if average in ("weighted", "none", None) else average,
-            mdmc_reduce=mdmc_average,
            threshold=threshold,
            top_k=top_k,
            num_classes=num_classes,
7 changes: 5 additions & 2 deletions torchmetrics/classification/f_beta.py
@@ -137,9 +137,12 @@ def __init__(
        if average not in allowed_average:
            raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")

+        if "reduce" not in kwargs or kwargs["reduce"] is None:
+            kwargs["reduce"] = "macro" if average in [AverageMethod.WEIGHTED, AverageMethod.NONE, "none"] else average
+        if "mdmc_reduce" not in kwargs or kwargs["mdmc_reduce"] is None:
+            kwargs["mdmc_reduce"] = mdmc_average
+
        super().__init__(
-            reduce="macro" if average in [AverageMethod.WEIGHTED, AverageMethod.NONE] else average,
-            mdmc_reduce=mdmc_average,
            threshold=threshold,
            top_k=top_k,
            num_classes=num_classes,
4 changes: 3 additions & 1 deletion torchmetrics/classification/jaccard.py
@@ -90,9 +90,11 @@ def __init__(
        multilabel: bool = False,
        **kwargs: Dict[str, Any],
    ) -> None:
+        if "normalize" not in kwargs:
+            kwargs["normalize"] = None
+
        super().__init__(
            num_classes=num_classes,
-            normalize=None,
            threshold=threshold,
            multilabel=multilabel,
            **kwargs,
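`JaccardIndex` gets the same treatment for the `normalize` argument it forwards to `ConfusionMatrix`; here the guard only checks whether the key is present, defaulting to `None` otherwise. A usage sketch mirroring the new test:

```python
from torchmetrics import JaccardIndex

# Previously raised TypeError (duplicate `normalize` keyword);
# with the guard above, the user's value is forwarded exactly once.
jaccard = JaccardIndex(num_classes=1, normalize="true")
```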
15 changes: 11 additions & 4 deletions torchmetrics/classification/precision_recall.py
@@ -17,6 +17,7 @@

from torchmetrics.classification.stat_scores import StatScores
from torchmetrics.functional.classification.precision_recall import _precision_compute, _recall_compute
+from torchmetrics.utilities.enums import AverageMethod


class Precision(StatScores):
@@ -127,9 +128,12 @@ def __init__(
        if average not in allowed_average:
            raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")

+        if "reduce" not in kwargs or kwargs["reduce"] is None:
+            kwargs["reduce"] = "macro" if average in [AverageMethod.WEIGHTED, AverageMethod.NONE, "none"] else average
+        if "mdmc_reduce" not in kwargs or kwargs["mdmc_reduce"] is None:
+            kwargs["mdmc_reduce"] = mdmc_average
+
        super().__init__(
-            reduce="macro" if average in ["weighted", "none", None] else average,
-            mdmc_reduce=mdmc_average,
            threshold=threshold,
            top_k=top_k,
            num_classes=num_classes,
@@ -262,9 +266,12 @@ def __init__(
        if average not in allowed_average:
            raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")

+        if "reduce" not in kwargs or kwargs["reduce"] is None:
+            kwargs["reduce"] = "macro" if average in [AverageMethod.WEIGHTED, AverageMethod.NONE, "none"] else average
+        if "mdmc_reduce" not in kwargs or kwargs["mdmc_reduce"] is None:
+            kwargs["mdmc_reduce"] = mdmc_average
+
        super().__init__(
-            reduce="macro" if average in ["weighted", "none", None] else average,
-            mdmc_reduce=mdmc_average,
            threshold=threshold,
            top_k=top_k,
            num_classes=num_classes,
8 changes: 6 additions & 2 deletions torchmetrics/classification/specificity.py
@@ -18,6 +18,7 @@

from torchmetrics.classification.stat_scores import StatScores
from torchmetrics.functional.classification.specificity import _specificity_compute
+from torchmetrics.utilities.enums import AverageMethod


class Specificity(StatScores):
@@ -129,9 +130,12 @@ def __init__(
        if average not in allowed_average:
            raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")

+        if "reduce" not in kwargs or kwargs["reduce"] is None:
+            kwargs["reduce"] = "macro" if average in [AverageMethod.WEIGHTED, AverageMethod.NONE, "none"] else average
+        if "mdmc_reduce" not in kwargs or kwargs["mdmc_reduce"] is None:
+            kwargs["mdmc_reduce"] = mdmc_average
+
        super().__init__(
-            reduce="macro" if average in ["weighted", "none", None] else average,
-            mdmc_reduce=mdmc_average,
            threshold=threshold,
            top_k=top_k,
            num_classes=num_classes,
6 changes: 3 additions & 3 deletions torchmetrics/functional/audio/pesq.py
@@ -12,17 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
+import torch
+from torch import Tensor

+from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.imports import _PESQ_AVAILABLE

if _PESQ_AVAILABLE:
    import pesq as pesq_backend
else:
    pesq_backend = None
-import torch
-from torch import Tensor

-from torchmetrics.utilities.checks import _check_same_shape

__doctest_requires__ = {("perceptual_evaluation_speech_quality",): ["pesq"]}

5 changes: 2 additions & 3 deletions torchmetrics/functional/audio/stoi.py
@@ -13,17 +13,16 @@
# limitations under the License.
import numpy as np
import torch
+from torch import Tensor

+from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.imports import _PYSTOI_AVAILABLE

if _PYSTOI_AVAILABLE:
    from pystoi import stoi as stoi_backend
else:
    stoi_backend = None
    __doctest_skip__ = ["short_time_objective_intelligibility"]
-from torch import Tensor
-
-from torchmetrics.utilities.checks import _check_same_shape


def short_time_objective_intelligibility(