-
Notifications
You must be signed in to change notification settings - Fork 408
/
tweedie_deviance.py
117 lines (99 loc) · 4.69 KB
/
tweedie_deviance.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional
import torch
from torch import Tensor
from torchmetrics.functional.regression.tweedie_deviance import (
_tweedie_deviance_score_compute,
_tweedie_deviance_score_update,
)
from torchmetrics.metric import Metric
class TweedieDevianceScore(Metric):
    r"""Computes the `Tweedie Deviance Score`_ between targets and predictions.

    .. math::
        deviance\_score(\hat{y},y) =
        \begin{cases}
        (\hat{y} - y)^2, & \text{for }power=0\\
        2 * (y * log(\frac{y}{\hat{y}}) + \hat{y} - y), & \text{for }power=1\\
        2 * (log(\frac{\hat{y}}{y}) + \frac{y}{\hat{y}} - 1), & \text{for }power=2\\
        2 * (\frac{(max(y,0))^{2 - power}}{(1 - power)(2 - power)} - \frac{y(\hat{y})^{1 - power}}{1 - power}
        + \frac{(\hat{y})^{2 - power}}{2 - power}), & \text{otherwise}
        \end{cases}

    where :math:`y` is a tensor of targets values, and :math:`\hat{y}` is a tensor of predictions.

    Forward accepts

    - ``preds`` (float tensor): ``(N,...)``
    - ``targets`` (float tensor): ``(N,...)``

    Args:
        power:
            - power < 0 : Extreme stable distribution. (Requires: preds > 0.)
            - power = 0 : Normal distribution. (Requires: targets and preds can be any real numbers.)
            - power = 1 : Poisson distribution. (Requires: targets >= 0 and preds > 0.)
            - 1 < p < 2 : Compound Poisson distribution. (Requires: targets >= 0 and preds > 0.)
            - power = 2 : Gamma distribution. (Requires: targets > 0 and preds > 0.)
            - power = 3 : Inverse Gaussian distribution. (Requires: targets > 0 and preds > 0.)
            - otherwise : Positive stable distribution. (Requires: targets > 0 and preds > 0.)
        compute_on_step:
            Forward only calls ``update()`` and return ``None`` if this is set to ``False``.
        dist_sync_on_step:
            Synchronize metric state across processes at each ``forward()``
            before returning the value at the step.
        process_group:
            Specify the process group on which synchronization is called.
            default: ``None`` (which selects the entire world)
        dist_sync_fn:
            Callback that performs the allgather operation on the metric state. When ``None``, DDP
            will be used to perform the all gather.

    Raises:
        ValueError:
            If ``power`` is strictly between 0 and 1, where the deviance is not defined.

    Example:
        >>> from torchmetrics import TweedieDevianceScore
        >>> targets = torch.tensor([1.0, 2.0, 3.0, 4.0])
        >>> preds = torch.tensor([4.0, 3.0, 2.0, 1.0])
        >>> deviance_score = TweedieDevianceScore(power=2)
        >>> deviance_score(preds, targets)
        tensor(1.2083)
    """

    is_differentiable = True
    higher_is_better = None  # TODO: both -1 and 1 are optimal

    # Running totals registered via ``add_state`` in ``__init__`` and summed across processes.
    sum_deviance_score: Tensor
    num_observations: Tensor

    def __init__(
        self,
        power: float = 0.0,
        compute_on_step: bool = True,
        dist_sync_on_step: bool = False,
        process_group: Optional[Any] = None,
        dist_sync_fn: Optional[Callable] = None,
    ) -> None:
        super().__init__(
            compute_on_step=compute_on_step,
            dist_sync_on_step=dist_sync_on_step,
            process_group=process_group,
            dist_sync_fn=dist_sync_fn,
        )
        # The Tweedie deviance has no valid distribution for 0 < power < 1.
        if 0 < power < 1:
            raise ValueError(f"Deviance Score is not defined for power={power}.")

        self.power: float = power

        # Both states are reduced with "sum" so the aggregate can be formed after a DDP sync.
        self.add_state("sum_deviance_score", torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("num_observations", torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: Tensor, targets: Tensor) -> None:  # type: ignore
        """Update metric states with predictions and targets.

        Args:
            preds: Predicted tensor with shape ``(N,...)``
            targets: Ground truth tensor with shape ``(N,...)``
        """
        # The functional helper returns this batch's summed deviance and its observation count.
        sum_deviance_score, num_observations = _tweedie_deviance_score_update(preds, targets, self.power)

        self.sum_deviance_score += sum_deviance_score
        self.num_observations += num_observations

    def compute(self) -> Tensor:
        """Compute the deviance score from the accumulated sum and observation count.

        Returns:
            Tensor with the aggregated deviance score, as produced by
            ``_tweedie_deviance_score_compute`` (the class example shows this is the mean deviance).
        """
        return _tweedie_deviance_score_compute(self.sum_deviance_score, self.num_observations)