diff --git a/tests/classification/test_accuracy.py b/tests/classification/test_accuracy.py
index 80d15aff4b7..b07fb742db4 100644
--- a/tests/classification/test_accuracy.py
+++ b/tests/classification/test_accuracy.py
@@ -345,7 +345,7 @@ def test_class_not_present(metric_class, metric_fn, ignore_index, expected):
     assert torch.allclose(expected, result_cl, equal_nan=True)
 
 
-@pytest.mark.parametrize('average', ['micro', 'macro', 'weighted'])
+@pytest.mark.parametrize("average", ["micro", "macro", "weighted"])
 def test_same_input(average):
     preds = _input_miss_class.preds
     target = _input_miss_class.target
diff --git a/tests/classification/test_f_beta.py b/tests/classification/test_f_beta.py
index 0fffda47c31..6c9e7d38d48 100644
--- a/tests/classification/test_f_beta.py
+++ b/tests/classification/test_f_beta.py
@@ -427,10 +427,10 @@ def test_top_k(
     assert torch.isclose(metric_fn(preds, target, top_k=k, average=average, num_classes=3), result)
 
 
-@pytest.mark.parametrize('average', ['micro', 'macro', 'weighted'])
+@pytest.mark.parametrize("average", ["micro", "macro", "weighted"])
 @pytest.mark.parametrize(
-    'metric_class, metric_functional, sk_fn',
-    [(partial(FBeta, beta=2.0), partial(fbeta, beta=2.0), partial(fbeta_score, beta=2.0)), (F1, f1, f1_score)]
+    "metric_class, metric_functional, sk_fn",
+    [(partial(FBeta, beta=2.0), partial(fbeta, beta=2.0), partial(fbeta_score, beta=2.0)), (F1, f1, f1_score)],
 )
 def test_same_input(metric_class, metric_functional, sk_fn, average):
     preds = _input_miss_class.preds
diff --git a/tests/classification/test_precision_recall.py b/tests/classification/test_precision_recall.py
index 935eb13f36f..c65a63d84c3 100644
--- a/tests/classification/test_precision_recall.py
+++ b/tests/classification/test_precision_recall.py
@@ -440,9 +440,9 @@ def test_class_not_present(metric_class, metric_fn, ignore_index, expected):
     assert torch.allclose(expected, result_cl, equal_nan=True)
 
 
-@pytest.mark.parametrize('average', ['micro', 'macro', 'weighted'])
+@pytest.mark.parametrize("average", ["micro", "macro", "weighted"])
 @pytest.mark.parametrize(
-    'metric_class, metric_functional, sk_fn', [(Precision, precision, precision_score), (Recall, recall, recall_score)]
+    "metric_class, metric_functional, sk_fn", [(Precision, precision, precision_score), (Recall, recall, recall_score)]
 )
 def test_same_input(metric_class, metric_functional, sk_fn, average):
     preds = _input_miss_class.preds
diff --git a/torchmetrics/functional/classification/f_beta.py b/torchmetrics/functional/classification/f_beta.py
index 66bfee63910..9d6ddba3f5b 100644
--- a/torchmetrics/functional/classification/f_beta.py
+++ b/torchmetrics/functional/classification/f_beta.py
@@ -51,9 +51,9 @@ def _fbeta_compute(
     precision = precision[~cond]
     recall = recall[~cond]
 
-    num = (1 + beta**2) * precision * recall
-    denom = beta**2 * precision + recall
-    denom[denom == 0.] = 1.0  # avoid division by 0
+    num = (1 + beta ** 2) * precision * recall
+    denom = beta ** 2 * precision + recall
+    denom[denom == 0.0] = 1.0  # avoid division by 0
 
     # if classes matter and a given class is not present in both the preds and the target,
     # computing the score for this class is meaningless, thus they should be ignored
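
The last hunk only reformats the F-beta arithmetic in _fbeta_compute without changing its behavior. For reference, a minimal standalone sketch of what those lines compute, assuming per-class precision and recall tensors as input (the helper name fbeta_from_precision_recall is hypothetical and not part of torchmetrics):

import torch


def fbeta_from_precision_recall(precision: torch.Tensor, recall: torch.Tensor, beta: float) -> torch.Tensor:
    # F-beta = (1 + beta^2) * P * R / (beta^2 * P + R), computed element-wise per class
    num = (1 + beta ** 2) * precision * recall
    denom = beta ** 2 * precision + recall
    denom[denom == 0.0] = 1.0  # avoid division by 0; the numerator is also 0 there, so the score stays 0
    return num / denom


# per-class precision/recall for three classes; the third class has no true or predicted samples
precision = torch.tensor([0.5, 1.0, 0.0])
recall = torch.tensor([1.0, 0.5, 0.0])
print(fbeta_from_precision_recall(precision, recall, beta=2.0))  # tensor([0.8333, 0.5556, 0.0000])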