Skip to content

Commit

Permalink
Fix corner case for R2Score (#1576)
Browse files Browse the repository at this point in the history
(cherry picked from commit 17e0890)
  • Loading branch information
SkafteNicki authored and Borda committed Mar 10, 2023
1 parent 5d7b1aa commit 0cef3ce
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 2 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Expand Up @@ -22,6 +22,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed the use of `ignore_index` in `MulticlassJaccardIndex` ([#1386](https://github.com/Lightning-AI/metrics/pull/1386))


- Fixed evaluation of `R2Score` with near constant target ([#1576](https://github.com/Lightning-AI/metrics/pull/1576))


## [0.11.2] - 2023-02-21

### Fixed
Expand Down
11 changes: 9 additions & 2 deletions src/torchmetrics/functional/regression/r2.py
Expand Up @@ -42,7 +42,6 @@ def _r2_score_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor, Ten
residual = target - preds
rss = torch.sum(residual * residual, dim=0)
n_obs = target.size(0)

return sum_squared_obs, sum_obs, rss, n_obs


Expand Down Expand Up @@ -80,7 +79,15 @@ def _r2_score_compute(

mean_obs = sum_obs / n_obs
tss = sum_squared_obs - sum_obs * mean_obs
raw_scores = 1 - (rss / tss)

# Account for near constant targets
cond_rss = ~torch.isclose(rss, torch.zeros_like(rss), atol=1e-4)
cond_tss = ~torch.isclose(tss, torch.zeros_like(tss), atol=1e-4)
cond = cond_rss & cond_tss

raw_scores = torch.ones_like(rss)
raw_scores[cond] = 1 - (rss[cond] / tss[cond])
raw_scores[cond_rss & ~cond_tss] = 0.0

if multioutput == "raw_values":
r2 = raw_scores
Expand Down
3 changes: 3 additions & 0 deletions src/torchmetrics/regression/r2.py
Expand Up @@ -32,6 +32,9 @@ class R2Score(Metric):
.. math:: R^2_{adj} = 1 - \frac{(1-R^2)(n-1)}{n-k-1}
where the parameter :math:`k` (the number of independent regressors) should be provided as the `adjusted` argument.
The score is only proper defined when :math:`SS_{tot}\neq 0`, which can happen for near constant targets. In this
case a score of 0 is returned. By definition the score is bounded between 0 and 1, where 1 corresponds to the
predictions exactly matching the targets.
As input to ``forward`` and ``update`` the metric accepts the following input:
Expand Down
8 changes: 8 additions & 0 deletions tests/unittests/regression/test_r2.py
Expand Up @@ -158,3 +158,11 @@ def test_warning_on_too_large_adjusted(metric_class=R2Score):

with pytest.warns(UserWarning, match="Division by zero in adjusted r2 score. Falls back to" " standard r2 score."):
metric(torch.randn(11), torch.randn(11))


def test_constant_target():
"""Check for a near constant target that a value of 0 is returned."""
y_true = torch.tensor([-5.1608, -5.1609, -5.1608, -5.1608, -5.1608, -5.1608])
y_pred = torch.tensor([-3.9865, -5.4648, -5.0238, -4.3899, -5.6672, -4.7336])
score = r2_score(preds=y_pred, target=y_true)
assert score == 0

0 comments on commit 0cef3ce

Please sign in to comment.