-
Notifications
You must be signed in to change notification settings - Fork 387
/
rse.py
143 lines (113 loc) · 5.31 KB
/
rse.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Sequence, Union
import torch
from torch import Tensor, tensor
from torchmetrics.functional.regression.r2 import _r2_score_update
from torchmetrics.functional.regression.rse import _relative_squared_error_compute
from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
# The ``plot`` doctests need matplotlib; list them for skipping when it is not installed.
if not _MATPLOTLIB_AVAILABLE:
    __doctest_skip__ = [
        "RelativeSquaredError.plot",
    ]
class RelativeSquaredError(Metric):
    r"""Compute the relative squared error (RSE) between predictions and targets.

    .. math:: \text{RSE} = \frac{\sum_i^N(y_i - \hat{y_i})^2}{\sum_i^N(y_i - \overline{y})^2}

    Here :math:`y` is the tensor of target values, :math:`\overline{y}` is its mean, and
    :math:`\hat{y}` is the tensor of predictions.

    When ``num_outputs > 1``, the returned value is averaged over all the outputs.

    As input to ``forward`` and ``update`` the metric accepts the following input:

    - ``preds`` (:class:`~torch.Tensor`): Predictions from model in float tensor with shape ``(N,)``
      or ``(N, M)`` (multioutput)
    - ``target`` (:class:`~torch.Tensor`): Ground truth values in float tensor with shape ``(N,)``
      or ``(N, M)`` (multioutput)

    As output of ``forward`` and ``compute`` the metric returns the following output:

    - ``rse`` (:class:`~torch.Tensor`): A tensor with the RSE score(s)

    Args:
        num_outputs: Number of outputs in multioutput setting
        squared: If True returns RSE value, if False returns RRSE value.
        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Example:
        >>> from torchmetrics.regression import RelativeSquaredError
        >>> target = torch.tensor([3, -0.5, 2, 7])
        >>> preds = torch.tensor([2.5, 0.0, 2, 8])
        >>> relative_squared_error = RelativeSquaredError()
        >>> relative_squared_error(preds, target)
        tensor(0.0514)

    """

    # Class-level metric properties.
    is_differentiable = True
    higher_is_better = False
    full_state_update = False

    # Accumulator states; each one is registered with ``add_state`` in ``__init__``.
    sum_squared_error: Tensor
    sum_error: Tensor
    residual: Tensor
    total: Tensor

    def __init__(
        self,
        num_outputs: int = 1,
        squared: bool = True,
        **kwargs: Any,
    ) -> None:
        """Register the accumulator states and store the configuration.

        Args:
            num_outputs: Length of each per-output accumulator state.
            squared: Forwarded to the compute helper as its ``squared`` keyword.
            kwargs: Passed through unchanged to the ``Metric`` constructor.

        """
        super().__init__(**kwargs)
        self.num_outputs = num_outputs

        # Three zero-initialised vectors of length ``num_outputs``, each summed when reduced across processes.
        for state_name in ("sum_squared_error", "sum_error", "residual"):
            self.add_state(state_name, default=torch.zeros(self.num_outputs), dist_reduce_fx="sum")

        # Scalar counter, likewise sum-reduced.
        self.add_state("total", default=tensor(0), dist_reduce_fx="sum")

        self.squared = squared

    def update(self, preds: Tensor, target: Tensor) -> None:
        """Fold the statistics of one batch into the accumulator states.

        Args:
            preds: Predictions for the batch.
            target: Ground truth values for the batch.

        """
        # The helper yields four values, matching the four states in declaration order.
        batch_squared_sum, batch_sum, batch_residual, batch_count = _r2_score_update(preds, target)

        self.sum_squared_error += batch_squared_sum
        self.sum_error += batch_sum
        self.residual += batch_residual
        self.total += batch_count

    def compute(self) -> Tensor:
        """Return the relative squared error derived from the accumulated states."""
        accumulated = (self.sum_squared_error, self.sum_error, self.residual, self.total)
        return _relative_squared_error_compute(*accumulated, squared=self.squared)

    def plot(
        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
    ) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: A matplotlib axis object. If provided will add plot to that axis

        Returns:
            Figure and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        .. plot::
            :scale: 75

            >>> from torch import randn
            >>> # Example plotting a single value
            >>> from torchmetrics.regression import RelativeSquaredError
            >>> metric = RelativeSquaredError()
            >>> metric.update(randn(10,), randn(10,))
            >>> fig_, ax_ = metric.plot()

        .. plot::
            :scale: 75

            >>> from torch import randn
            >>> # Example plotting multiple values
            >>> from torchmetrics.regression import RelativeSquaredError
            >>> metric = RelativeSquaredError()
            >>> values = []
            >>> for _ in range(10):
            ...     values.append(metric(randn(10,), randn(10,)))
            >>> fig, ax = metric.plot(values)

        """
        return self._plot(val, ax)