Fix f1 score for macro and ignore index (#495)
* fix
* add testing

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
SkafteNicki and pre-commit-ci[bot] committed Sep 3, 2021
1 parent fd58980 commit e38cb70
Showing 3 changed files with 12 additions and 9 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -24,6 +24,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Fixed
 
+- Fixed bug in `F1` with `average='macro'` and `ignore_index!=None` ([#495](https://github.com/PyTorchLightning/metrics/pull/495))
 
 ## [0.5.1] - 2021-08-30
 
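For orientation, a minimal sketch of the call this entry is about. The tensor values are invented, but the signature (`f1` with `num_classes`, `average`, and `ignore_index`) is the functional API exercised by the test change below.

```python
import torch
from torchmetrics.functional import f1

preds = torch.tensor([0, 1, 2, 0, 1, 2])
target = torch.tensor([0, 1, 1, 0, 2, 2])

# average='macro' with a non-None ignore_index is exactly the combination
# this commit fixes: the ignored class must be excluded from the per-class
# average rather than shifting or corrupting it.
score = f1(preds, target, num_classes=3, average="macro", ignore_index=2)
```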
9 changes: 6 additions & 3 deletions tests/classification/test_f_beta.py
@@ -426,22 +426,25 @@ def test_top_k(
     assert torch.isclose(metric_fn(preds, target, top_k=k, average=average, num_classes=3), result)
 
 
+@pytest.mark.parametrize("ignore_index", [None, 2])
 @pytest.mark.parametrize("average", ["micro", "macro", "weighted"])
 @pytest.mark.parametrize(
     "metric_class, metric_functional, sk_fn",
     [(partial(FBeta, beta=2.0), partial(fbeta, beta=2.0), partial(fbeta_score, beta=2.0)), (F1, f1, f1_score)],
 )
-def test_same_input(metric_class, metric_functional, sk_fn, average):
+def test_same_input(metric_class, metric_functional, sk_fn, average, ignore_index):
     preds = _input_miss_class.preds
     target = _input_miss_class.target
     preds_flat = torch.cat([p for p in preds], dim=0)
     target_flat = torch.cat([t for t in target], dim=0)
 
-    mc = metric_class(num_classes=NUM_CLASSES, average=average)
+    mc = metric_class(num_classes=NUM_CLASSES, average=average, ignore_index=ignore_index)
     for i in range(NUM_BATCHES):
         mc.update(preds[i], target[i])
     class_res = mc.compute()
-    func_res = metric_functional(preds_flat, target_flat, num_classes=NUM_CLASSES, average=average)
+    func_res = metric_functional(
+        preds_flat, target_flat, num_classes=NUM_CLASSES, average=average, ignore_index=ignore_index
+    )
     sk_res = sk_fn(target_flat, preds_flat, average=average, zero_division=0)
 
     assert torch.allclose(class_res, torch.tensor(sk_res).float())
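A note on the reference value: the sklearn call is the same for both `ignore_index` parametrizations. That only works because `_input_miss_class` appears to be built so that the ignored class (2) never occurs in `preds` or `target`; ignoring an absent class should leave the score unchanged, so the plain `sk_fn` result remains a valid oracle for both cases.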
11 changes: 5 additions & 6 deletions torchmetrics/functional/classification/f_beta.py
@@ -64,7 +64,6 @@ def _fbeta_compute(
     >>> _fbeta_compute(tp, fp, tn, fn, beta=0.5, ignore_index=None, average='micro', mdmc_average=None)
     tensor(0.3333)
     """
-
     if average == AvgMethod.MICRO and mdmc_average != MDMCAverageMethod.SAMPLEWISE:
         mask = tp >= 0
         precision = _safe_divide(tp[mask].sum().float(), (tp[mask] + fp[mask]).sum())
@@ -73,11 +72,6 @@ def _fbeta_compute(
         precision = _safe_divide(tp.float(), tp + fp)
         recall = _safe_divide(tp.float(), tp + fn)
 
-    if average == AvgMethod.MACRO and mdmc_average != MDMCAverageMethod.SAMPLEWISE:
-        cond = tp + fp + fn == 0
-        precision = precision[~cond]
-        recall = recall[~cond]
-
     num = (1 + beta ** 2) * precision * recall
     denom = beta ** 2 * precision + recall
     denom[denom == 0.0] = 1.0  # avoid division by 0
@@ -100,6 +94,11 @@ def _fbeta_compute(
         num[ignore_index, ...] = -1
         denom[ignore_index, ...] = -1
 
+    if average == AvgMethod.MACRO and mdmc_average != MDMCAverageMethod.SAMPLEWISE:
+        cond = (tp + fp + fn == 0) | (tp + fp + fn == -3)
+        num = num[~cond]
+        denom = denom[~cond]
+
     return _reduce_stat_scores(
         numerator=num,
         denominator=denom,
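The substance of the fix: previously the macro filtering dropped empty classes from `precision`/`recall` before the `ignore_index` sentinel (-1) was written into `num`/`denom`, so `num[ignore_index, ...]` could point at the wrong class once earlier classes had been filtered out. The filtering now runs last and also drops the ignored class, whose tp, fp, and fn sentinels sum to -3. Below is a standalone sketch of the fixed masking with invented stat values; `_safe_divide`, `_reduce_stat_scores`, and the real tensor shapes are simplified away.

```python
import torch

# Invented per-class stats for a 3-class problem (not the library's real
# tensors). Class 1 never occurs (tp = fp = fn = 0); class 2 plays the role
# of ignore_index, its stats overwritten with the -1 sentinel.
tp = torch.tensor([4.0, 0.0, -1.0])
fp = torch.tensor([1.0, 0.0, -1.0])
fn = torch.tensor([2.0, 0.0, -1.0])
beta = 1.0  # F1 is the beta = 1 special case

precision = tp / (tp + fp)  # the library uses _safe_divide here
recall = tp / (tp + fn)
num = (1 + beta ** 2) * precision * recall
denom = beta ** 2 * precision + recall

# The fixed filtering: drop classes absent from preds and target (stats sum
# to 0, yielding NaN above) and the ignored class (sentinels sum to -3).
cond = (tp + fp + fn == 0) | (tp + fp + fn == -3)
macro_f1 = (num[~cond] / denom[~cond]).mean()
print(macro_f1)  # tensor(0.7273): only class 0 survives the mask
```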
