-
Notifications
You must be signed in to change notification settings - Fork 388
/
r2score.py
131 lines (109 loc) · 4.81 KB
/
r2score.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import torch
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.checks import _check_same_shape
def _r2score_update(
    preds: torch.Tensor,
    target: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
    """Compute the sufficient statistics needed for the r2 score.

    All sums are taken over the first (sample) dimension, so 2D inputs yield
    one statistic per output column.

    Args:
        preds: predicted values; must have the same shape as ``target``.
        target: ground-truth values; a 1D or 2D tensor with at least two samples
            along the first dimension.

    Returns:
        A tuple ``(sum_squared_obs, sum_obs, rss, n_obs)``:
        the sum of squared targets, the sum of targets, the residual sum of
        squares ``sum((target - preds) ** 2)``, and the number of samples.

    Raises:
        ValueError: if the inputs are not 1D or 2D, or contain fewer than two samples.
    """
    _check_same_shape(preds, target)
    # Reject 0-d tensors here too: otherwise ``target.size(0)`` / a length
    # check on a scalar tensor would fail with an unrelated TypeError/IndexError.
    if not 1 <= preds.ndim <= 2:
        raise ValueError(
            "Expected both prediction and target to be 1D or 2D tensors,"
            f" but received tensors with shape {preds.shape}"
        )
    if preds.shape[0] < 2:
        raise ValueError("Needs at least two samples to calculate r2 score.")

    sum_obs = torch.sum(target, dim=0)
    sum_squared_obs = torch.sum(torch.pow(target, 2.0), dim=0)
    rss = torch.sum(torch.pow(target - preds, 2.0), dim=0)
    n_obs = target.size(0)
    return sum_squared_obs, sum_obs, rss, n_obs
def _r2score_compute(
sum_squared_error: torch.Tensor,
sum_error: torch.Tensor,
residual: torch.Tensor,
total: torch.Tensor,
adjusted: int = 0,
multioutput: str = "uniform_average"
) -> torch.Tensor:
mean_error = sum_error / total
diff = sum_squared_error - sum_error * mean_error
raw_scores = 1 - (residual / diff)
if multioutput == "raw_values":
r2score = raw_scores
elif multioutput == "uniform_average":
r2score = torch.mean(raw_scores)
elif multioutput == "variance_weighted":
diff_sum = torch.sum(diff)
r2score = torch.sum(diff / diff_sum * raw_scores)
else:
raise ValueError(
'Argument `multioutput` must be either `raw_values`,'
f' `uniform_average` or `variance_weighted`. Received {multioutput}.'
)
if adjusted < 0 or not isinstance(adjusted, int):
raise ValueError('`adjusted` parameter should be an integer larger or' ' equal to 0.')
if adjusted != 0:
if adjusted > total - 1:
rank_zero_warn(
"More independent regressions than datapoints in"
" adjusted r2 score. Falls back to standard r2 score.", UserWarning
)
elif adjusted == total - 1:
rank_zero_warn("Division by zero in adjusted r2 score. Falls back to" " standard r2 score.", UserWarning)
else:
r2score = 1 - (1 - r2score) * (total - 1) / (total - adjusted - 1)
return r2score
def r2score(
    preds: torch.Tensor,
    target: torch.Tensor,
    adjusted: int = 0,
    multioutput: str = "uniform_average",
) -> torch.Tensor:
    r"""
    Computes r2 score also known as `coefficient of determination
    <https://en.wikipedia.org/wiki/Coefficient_of_determination>`_:

    .. math:: R^2 = 1 - \frac{SS_{res}}{SS_{tot}}

    where :math:`SS_{res}=\sum_i (y_i - f(x_i))^2` is the sum of residual squares, and
    :math:`SS_{tot}=\sum_i (y_i - \bar{y})^2` is the total sum of squares. Can also calculate
    the adjusted r2 score given by

    .. math:: R^2_{adj} = 1 - \frac{(1-R^2)(n-1)}{n-k-1}

    where :math:`n` is the number of samples and the parameter :math:`k` (the number of
    independent regressors) should be provided as the ``adjusted`` argument. If
    :math:`k \geq n - 1` a warning is emitted and the standard r2 score is returned.

    Args:
        preds: estimated labels; a 1D or 2D tensor with the same shape as ``target``.
        target: ground truth labels; at least two samples along the first dimension.
            For 2D inputs each column is treated as a separate output.
        adjusted: number of independent regressors for calculating adjusted r2 score.
            Default 0 (standard r2 score).
        multioutput: Defines aggregation in the case of multiple output scores. Can be one
            of the following strings (default is ``'uniform_average'``.):

            * ``'raw_values'`` returns full set of scores
            * ``'uniform_average'`` scores are uniformly averaged
            * ``'variance_weighted'`` scores are weighted by their individual variances

    Raises:
        ValueError: if the inputs have more than two dimensions or fewer than two
            samples, if ``adjusted`` is a negative integer, or if ``multioutput`` is
            not one of the accepted strings.

    Example:
        >>> from torchmetrics.functional import r2score
        >>> target = torch.tensor([3, -0.5, 2, 7])
        >>> preds = torch.tensor([2.5, 0.0, 2, 8])
        >>> r2score(preds, target)
        tensor(0.9486)

        >>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]])
        >>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]])
        >>> r2score(preds, target, multioutput='raw_values')
        tensor([0.9654, 0.9082])
    """
    # Reduce the inputs to sufficient statistics, then turn them into the score.
    sum_squared_error, sum_error, residual, total = _r2score_update(preds, target)
    return _r2score_compute(sum_squared_error, sum_error, residual, total, adjusted, multioutput)