Skip to content

Commit

Permalink
Fix Matthews correlation coefficient when the denominator is 0 (#781)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
(cherry picked from commit cfe5e87)
  • Loading branch information
SkafteNicki authored and Borda committed Jan 19, 2022
1 parent 768bcd4 commit 9b7d4c2
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 1 deletion.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Expand Up @@ -26,6 +26,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed check for available modules ([#772](https://github.com/PyTorchLightning/metrics/pull/772))


- Fixed Matthews correlation coefficient when the denominator is 0 ([#781](https://github.com/PyTorchLightning/metrics/pull/781))


## [0.7.0] - 2022-01-17

### Added
Expand Down
7 changes: 7 additions & 0 deletions tests/classification/test_matthews_corrcoef.py
Expand Up @@ -140,3 +140,10 @@ def test_matthews_corrcoef_differentiability(self, preds, target, sk_metric, num
"threshold": THRESHOLD,
},
)


def test_zero_case():
"""Cases where the denominator in the matthews corrcoef is 0, the score should return 0."""
# Example where neither 1 or 2 is present in the target tensor
out = matthews_corrcoef(torch.tensor([0, 1, 2]), torch.tensor([0, 0, 0]), 3)
assert out == 0.0
10 changes: 9 additions & 1 deletion torchmetrics/functional/classification/matthews_corrcoef.py
Expand Up @@ -37,7 +37,15 @@ def _matthews_corrcoef_compute(confmat: Tensor) -> Tensor:
pk = confmat.sum(dim=0).float()
c = torch.trace(confmat).float()
s = confmat.sum().float()
return (c * s - sum(tk * pk)) / (torch.sqrt(s ** 2 - sum(pk * pk)) * torch.sqrt(s ** 2 - sum(tk * tk)))

cov_ytyp = c * s - sum(tk * pk)
cov_ypyp = s ** 2 - sum(pk * pk)
cov_ytyt = s ** 2 - sum(tk * tk)

if cov_ypyp * cov_ytyt == 0:
return torch.tensor(0, dtype=confmat.dtype, device=confmat.device)
else:
return cov_ytyp / torch.sqrt(cov_ytyt * cov_ypyp)


def matthews_corrcoef(
Expand Down

0 comments on commit 9b7d4c2

Please sign in to comment.