
Add Tweedie Deviance Score Metric. #499

2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -12,6 +12,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added Learned Perceptual Image Patch Similarity (LPIPS) ([#431](https://github.com/PyTorchLightning/metrics/issues/431))

- Added Tweedie Deviance Score ([#499](https://github.com/PyTorchLightning/metrics/pull/499))


### Changed

8 changes: 8 additions & 0 deletions docs/source/references/functional.rst
@@ -307,6 +307,14 @@ symmetric_mean_absolute_percentage_error [func]
.. autofunction:: torchmetrics.functional.symmetric_mean_absolute_percentage_error
:noindex:


tweedie_deviance_score [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.tweedie_deviance_score
:noindex:


********
Pairwise
********
7 changes: 7 additions & 0 deletions docs/source/references/modules.rst
@@ -426,6 +426,13 @@ SymmetricMeanAbsolutePercentageError
:noindex:


TweedieDevianceScore
~~~~~~~~~~~~~~~~~~~~

.. autoclass:: torchmetrics.TweedieDevianceScore
:noindex:


*********
Retrieval
*********
106 changes: 106 additions & 0 deletions tests/regression/test_tweedie_deviance.py
@@ -0,0 +1,106 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from functools import partial

import pytest
import torch
from sklearn.metrics import mean_tweedie_deviance
from torch import Tensor

from tests.helpers import seed_all
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
from torchmetrics.functional.regression.tweedie_deviance import tweedie_deviance_score
from torchmetrics.regression.tweedie_deviance import TweedieDevianceScore

seed_all(42)

Input = namedtuple("Input", ["preds", "targets"])

_single_target_inputs1 = Input(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE),
targets=torch.rand(NUM_BATCHES, BATCH_SIZE),
)

_single_target_inputs2 = Input(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE),
targets=torch.rand(NUM_BATCHES, BATCH_SIZE),
)


def _sk_deviance(preds: Tensor, targets: Tensor, power: float):
sk_preds = preds.view(-1).numpy()
sk_target = targets.view(-1).numpy()
return mean_tweedie_deviance(sk_target, sk_preds, power=power)


@pytest.mark.parametrize("power", [-0.5, 0, 1, 1.5, 2, 3])
@pytest.mark.parametrize(
"preds, targets, sk_metric",
[
(_single_target_inputs1.preds, _single_target_inputs1.targets, _sk_deviance),
(_single_target_inputs2.preds, _single_target_inputs2.targets, _sk_deviance),
],
)
class TestDevianceScore(MetricTester):
@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_deviance_scores_class(self, ddp, dist_sync_on_step, preds, targets, power, sk_metric):
self.run_class_metric_test(
ddp,
preds,
targets,
TweedieDevianceScore,
partial(sk_metric, power=power),
dist_sync_on_step,
metric_args=dict(power=power),
)

def test_deviance_scores_functional(self, preds, targets, power, sk_metric):
self.run_functional_metric_test(
preds,
targets,
tweedie_deviance_score,
partial(sk_metric, power=power),
metric_args=dict(power=power),
)


def test_error_on_different_shape(metric_class=TweedieDevianceScore):
metric = metric_class()
with pytest.raises(RuntimeError, match="Predictions and targets are expected to have the same shape"):
metric(torch.randn(100), torch.randn(50))


def test_error_on_invalid_inputs(metric_class=TweedieDevianceScore):
with pytest.raises(ValueError, match="Deviance Score is not defined for power=0.5."):
metric_class(power=0.5)

metric = metric_class(power=1)
with pytest.raises(
ValueError, match="For power=1, 'preds' has to be strictly positive and 'targets' cannot be negative."
):
metric(torch.tensor([-1.0, 2.0, 3.0]), torch.rand(3))

with pytest.raises(
ValueError, match="For power=1, 'preds' has to be strictly positive and 'targets' cannot be negative."
):
metric(torch.rand(3), torch.tensor([-1.0, 2.0, 3.0]))

metric = metric_class(power=2)
with pytest.raises(ValueError, match="For power=2, both 'preds' and 'targets' have to be strictly positive."):
metric(torch.tensor([-1.0, 2.0, 3.0]), torch.rand(3))

with pytest.raises(ValueError, match="For power=2, both 'preds' and 'targets' have to be strictly positive."):
metric(torch.rand(3), torch.tensor([-1.0, 2.0, 3.0]))
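For readers who want to sanity-check the new metric the same way the tests above do, here is a minimal sketch (not part of the diff) that compares the functional version against scikit-learn's mean_tweedie_deviance, assuming both torchmetrics with this patch and scikit-learn are installed:

import torch
from sklearn.metrics import mean_tweedie_deviance
from torchmetrics.functional import tweedie_deviance_score

# Shifted random inputs keep every power in the grid below valid (strictly positive values).
preds = torch.rand(10) + 0.1
targets = torch.rand(10) + 0.1

for power in (0, 1, 1.5, 2, 3):
    tm_value = tweedie_deviance_score(preds, targets, power=power)
    sk_value = mean_tweedie_deviance(targets.numpy(), preds.numpy(), power=power)
    assert torch.allclose(tm_value, torch.tensor(sk_value, dtype=tm_value.dtype), atol=1e-4)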
2 changes: 2 additions & 0 deletions torchmetrics/__init__.py
@@ -53,6 +53,7 @@
R2Score,
SpearmanCorrcoef,
SymmetricMeanAbsolutePercentageError,
TweedieDevianceScore,
)
from torchmetrics.retrieval import ( # noqa: E402
RetrievalFallOut,
@@ -82,6 +83,7 @@
"CohenKappa",
"ConfusionMatrix",
"CosineSimilarity",
"TweedieDevianceScore",
"ExplainedVariance",
"F1",
"FBeta",
2 changes: 2 additions & 0 deletions torchmetrics/functional/__init__.py
@@ -50,6 +50,7 @@
from torchmetrics.functional.regression.symmetric_mean_absolute_percentage_error import (
symmetric_mean_absolute_percentage_error,
)
from torchmetrics.functional.regression.tweedie_deviance import tweedie_deviance_score
from torchmetrics.functional.retrieval.average_precision import retrieval_average_precision
from torchmetrics.functional.retrieval.fall_out import retrieval_fall_out
from torchmetrics.functional.retrieval.ndcg import retrieval_normalized_dcg
@@ -73,6 +74,7 @@
"cohen_kappa",
"confusion_matrix",
"cosine_similarity",
"tweedie_deviance_score",
"dice_score",
"embedding_similarity",
"explained_variance",
1 change: 1 addition & 0 deletions torchmetrics/functional/regression/__init__.py
@@ -25,3 +25,4 @@
from torchmetrics.functional.regression.r2 import r2_score # noqa: F401
from torchmetrics.functional.regression.r2score import r2score # noqa: F401
from torchmetrics.functional.regression.spearman import spearman_corrcoef # noqa: F401
from torchmetrics.functional.regression.tweedie_deviance import tweedie_deviance_score # noqa: F401
147 changes: 147 additions & 0 deletions torchmetrics/functional/regression/tweedie_deviance.py
@@ -0,0 +1,147 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple

import torch
from torch import Tensor

from torchmetrics.utilities.checks import _check_same_shape


def _tweedie_deviance_score_update(preds: Tensor, targets: Tensor, power: float = 0.0) -> Tuple[Tensor, Tensor]:
"""Updates and returns variables required to compute Deviance Score for the given power. Checks for same shape
of input tensors.

Args:
preds: Predicted tensor
targets: Ground truth tensor
power:
- power < 0 : Extreme stable distribution. (Requires: preds > 0.)
- power = 0 : Normal distribution. (Requires: targets and preds can be any real numbers.)
- power = 1 : Poisson distribution. (Requires: targets >= 0 and preds > 0.)
- 1 < power < 2 : Compound Poisson distribution. (Requires: targets >= 0 and preds > 0.)
- power = 2 : Gamma distribution. (Requires: targets > 0 and preds > 0.)
- power = 3 : Inverse Gaussian distribution. (Requires: targets > 0 and preds > 0.)
- otherwise : Positive stable distribution. (Requires: targets > 0 and preds > 0.)

Example:
>>> targets = torch.tensor([1.0, 2.0, 3.0, 4.0])
>>> preds = torch.tensor([4.0, 3.0, 2.0, 1.0])
>>> _tweedie_deviance_score_update(preds, targets, power=0)
(tensor(20.), tensor(4))
"""
_check_same_shape(preds, targets)

zero_tensor = torch.zeros(preds.shape, device=preds.device)

if 0 < power < 1:
raise ValueError(f"Deviance Score is not defined for power={power}.")

if power == 0:
deviance_score = torch.pow(targets - preds, exponent=2)
elif power == 1:
# Poisson distribution
if torch.any(preds <= 0) or torch.any(targets < 0):
raise ValueError(
f"For power={power}, 'preds' has to be strictly positive and 'targets' cannot be negative."
)

deviance_score = 2 * (targets * torch.log(targets / preds) + preds - targets)
elif power == 2:
# Gamma distribution
if torch.any(preds <= 0) or torch.any(targets <= 0):
raise ValueError(f"For power={power}, both 'preds' and 'targets' have to be strictly positive.")

deviance_score = 2 * (torch.log(preds / targets) + (targets / preds) - 1)
else:
if power < 0:
if torch.any(preds <= 0):
raise ValueError(f"For power={power}, 'preds' has to be strictly positive.")
elif 1 < power < 2:
if torch.any(preds <= 0) or torch.any(targets < 0):
raise ValueError(
f"For power={power}, 'targets' has to be strictly positive and 'preds' cannot be negative."
)
else:
if torch.any(preds <= 0) or torch.any(targets <= 0):
raise ValueError(f"For power={power}, both 'preds' and 'targets' have to be strictly positive.")

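# General Tweedie case (power < 0, 1 < power < 2, or power > 2):
# deviance = 2 * (max(y, 0)^(2 - power) / ((1 - power) * (2 - power))
#                 - y * y_hat^(1 - power) / (1 - power)
#                 + y_hat^(2 - power) / (2 - power))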
term_1 = torch.pow(torch.max(targets, zero_tensor), 2 - power) / ((1 - power) * (2 - power))
term_2 = targets * torch.pow(preds, 1 - power) / (1 - power)
term_3 = torch.pow(preds, 2 - power) / (2 - power)
deviance_score = 2 * (term_1 - term_2 + term_3)

sum_deviance_score = torch.sum(deviance_score)
num_observations = torch.tensor(torch.numel(deviance_score))

return sum_deviance_score, num_observations


def _tweedie_deviance_score_compute(sum_deviance_score: Tensor, num_observations: Tensor) -> Tensor:
"""Computes Deviance Score.

Args:
sum_deviance_score: Sum of deviance scores accumulated until now.
num_observations: Number of observations encountered until now.

Example:
>>> targets = torch.tensor([1.0, 2.0, 3.0, 4.0])
>>> preds = torch.tensor([4.0, 3.0, 2.0, 1.0])
>>> sum_deviance_score, num_observations = _tweedie_deviance_score_update(preds, targets, power=0)
>>> _tweedie_deviance_score_compute(sum_deviance_score, num_observations)
tensor(5.)
"""

return sum_deviance_score / num_observations


def tweedie_deviance_score(preds: Tensor, targets: Tensor, power: float = 0.0) -> Tensor:
r"""
Computes the `Deviance Score <https://en.wikipedia.org/wiki/Tweedie_distribution#The_Tweedie_deviance>`_ between
targets and predictions:

.. math::
deviance\_score(\hat{y},y) =
\begin{cases}
(\hat{y} - y)^2, & \text{for }power=0\\
2 * (y * log(\frac{y}{\hat{y}}) + \hat{y} - y), & \text{for }power=1\\
2 * (log(\frac{\hat{y}}{y}) + \frac{y}{\hat{y}} - 1), & \text{for }power=2\\
2 * (\frac{(max(y,0))^{2 - power}}{(1 - power)(2 - power)} - \frac{y(\hat{y})^{1 - power}}{1 - power}
+ \frac{(\hat{y})^{2 - power}}{2 - power}), & \text{otherwise}
\end{cases}

where :math:`y` is a tensor of targets values, and :math:`\hat{y}` is a tensor of predictions.

Args:
preds: Predicted tensor with shape ``(N,d)``
targets: Ground truth tensor with shape ``(N,d)``
power:
- power < 0 : Extreme stable distribution. (Requires: preds > 0.)
- power = 0 : Normal distribution. (Requires: targets and preds can be any real numbers.)
- power = 1 : Poisson distribution. (Requires: targets >= 0 and preds > 0.)
- 1 < power < 2 : Compound Poisson distribution. (Requires: targets >= 0 and preds > 0.)
- power = 2 : Gamma distribution. (Requires: targets > 0 and preds > 0.)
- power = 3 : Inverse Gaussian distribution. (Requires: targets > 0 and preds > 0.)
- otherwise : Positive stable distribution. (Requires: targets > 0 and preds > 0.)

Example:
>>> from torchmetrics.functional.regression import tweedie_deviance_score
>>> targets = torch.tensor([1.0, 2.0, 3.0, 4.0])
>>> preds = torch.tensor([4.0, 3.0, 2.0, 1.0])
>>> tweedie_deviance_score(preds, targets, power=0)
tensor(5.)

"""
sum_deviance_score, num_observations = _tweedie_deviance_score_update(preds, targets, power=power)
return _tweedie_deviance_score_compute(sum_deviance_score, num_observations)
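The module counterpart added in torchmetrics/regression (and exercised by the class tests above) follows the usual torchmetrics update/compute accumulation pattern. A minimal usage sketch, assuming only the constructor signature shown in the tests (power=...) and the standard Metric API:

import torch
from torchmetrics import TweedieDevianceScore

# power=2 selects the Gamma case, so both preds and targets must be strictly positive.
metric = TweedieDevianceScore(power=2)

for _ in range(3):  # accumulate the deviance over several batches
    preds = torch.rand(8) + 0.1
    targets = torch.rand(8) + 0.1
    metric.update(preds, targets)

print(metric.compute())  # mean deviance over all observations seen so far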
1 change: 1 addition & 0 deletions torchmetrics/regression/__init__.py
@@ -23,3 +23,4 @@
from torchmetrics.regression.symmetric_mean_absolute_percentage_error import ( # noqa: F401
SymmetricMeanAbsolutePercentageError,
)
from torchmetrics.regression.tweedie_deviance import TweedieDevianceScore # noqa: F401