Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Aug 2, 2021
1 parent 289f347 commit 4b01415
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion tests/classification/test_accuracy.py
Expand Up @@ -345,7 +345,7 @@ def test_class_not_present(metric_class, metric_fn, ignore_index, expected):
assert torch.allclose(expected, result_cl, equal_nan=True)


@pytest.mark.parametrize('average', ['micro', 'macro', 'weighted'])
@pytest.mark.parametrize("average", ["micro", "macro", "weighted"])
def test_same_input(average):
preds = _input_miss_class.preds
target = _input_miss_class.target
Expand Down
6 changes: 3 additions & 3 deletions tests/classification/test_f_beta.py
Expand Up @@ -427,10 +427,10 @@ def test_top_k(
assert torch.isclose(metric_fn(preds, target, top_k=k, average=average, num_classes=3), result)


@pytest.mark.parametrize('average', ['micro', 'macro', 'weighted'])
@pytest.mark.parametrize("average", ["micro", "macro", "weighted"])
@pytest.mark.parametrize(
'metric_class, metric_functional, sk_fn',
[(partial(FBeta, beta=2.0), partial(fbeta, beta=2.0), partial(fbeta_score, beta=2.0)), (F1, f1, f1_score)]
"metric_class, metric_functional, sk_fn",
[(partial(FBeta, beta=2.0), partial(fbeta, beta=2.0), partial(fbeta_score, beta=2.0)), (F1, f1, f1_score)],
)
def test_same_input(metric_class, metric_functional, sk_fn, average):
preds = _input_miss_class.preds
Expand Down
4 changes: 2 additions & 2 deletions tests/classification/test_precision_recall.py
Expand Up @@ -440,9 +440,9 @@ def test_class_not_present(metric_class, metric_fn, ignore_index, expected):
assert torch.allclose(expected, result_cl, equal_nan=True)


@pytest.mark.parametrize('average', ['micro', 'macro', 'weighted'])
@pytest.mark.parametrize("average", ["micro", "macro", "weighted"])
@pytest.mark.parametrize(
'metric_class, metric_functional, sk_fn', [(Precision, precision, precision_score), (Recall, recall, recall_score)]
"metric_class, metric_functional, sk_fn", [(Precision, precision, precision_score), (Recall, recall, recall_score)]
)
def test_same_input(metric_class, metric_functional, sk_fn, average):
preds = _input_miss_class.preds
Expand Down
6 changes: 3 additions & 3 deletions torchmetrics/functional/classification/f_beta.py
Expand Up @@ -51,9 +51,9 @@ def _fbeta_compute(
precision = precision[~cond]
recall = recall[~cond]

num = (1 + beta**2) * precision * recall
denom = beta**2 * precision + recall
denom[denom == 0.] = 1.0 # avoid division by 0
num = (1 + beta ** 2) * precision * recall
denom = beta ** 2 * precision + recall
denom[denom == 0.0] = 1.0 # avoid division by 0

# if classes matter and a given class is not present in both the preds and the target,
# computing the score for this class is meaningless, thus they should be ignored
Expand Down

0 comments on commit 4b01415

Please sign in to comment.