Skip to content

Commit

Permalink
Add quantile loss metric (unit8co#1559)
Browse files Browse the repository at this point in the history
* first implementation of quantile loss

* add quantile loss to metrics __init__ and tests

* refactor

* rename pinball loss to quantile loss

* black

* use reduction to aggregate losses and update docs

* black + isort

* rollback to simple mean instead of reduction param

* change overlooked copy-paste comment

* black enter

* docs changes

* flake8

---------

Co-authored-by: Julien Herzen <julien@unit8.co>
Co-authored-by: madtoinou <32447896+madtoinou@users.noreply.github.com>
Co-authored-by: Dennis Bader <dennis.bader@gmx.ch>
  • Loading branch information
4 people authored and alexcolpitts96 committed May 31, 2023
1 parent d296554 commit a6fd378
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 0 deletions.
1 change: 1 addition & 0 deletions darts/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
mase,
mse,
ope,
quantile_loss,
r2_score,
rho_risk,
rmse,
Expand Down
77 changes: 77 additions & 0 deletions darts/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1212,3 +1212,80 @@ def rho_risk(

rho_loss = 2 * (z_true - z_hat_rho) * (rho * pred_below - (1 - rho) * pred_above)
return rho_loss / z_true


# Quantile Loss (Pinball Loss)
@multi_ts_support
@multivariate_support
def quantile_loss(
    actual_series: Union[TimeSeries, Sequence[TimeSeries]],
    pred_series: Union[TimeSeries, Sequence[TimeSeries]],
    tau: float = 0.5,
    intersect: bool = True,
    *,
    reduction: Callable[[np.ndarray], float] = np.mean,
    inter_reduction: Callable[[np.ndarray], Union[float, np.ndarray]] = lambda x: x,
    n_jobs: int = 1,
    verbose: bool = False
) -> float:
    r"""Quantile Loss (also known as Pinball Loss).

    Given a time series of actual values :math:`y` of length :math:`T` and a time series of stochastic
    predictions :math:`y'` (containing :math:`N` samples) of shape :math:`T \times N`, the quantile loss
    quantifies the accuracy of the predicted value distribution with respect to a specific quantile
    :math:`\tau`.

    For every time step :math:`t` and every sample :math:`n`, the loss is

    .. math:: \max\left((\tau - 1)(y_t - y'_{t,n}),\ \tau (y_t - y'_{t,n})\right)

    and the returned value is the mean of these losses over all time steps and all samples.

    Parameters
    ----------
    actual_series
        The (sequence of) actual series.
    pred_series
        The (sequence of) predicted series. Must be stochastic.
    tau
        The quantile (float in [0, 1]) of interest for the loss.
    intersect
        For time series that are overlapping in time without having the same time index, setting `True`
        will consider the values only over their common time interval (intersection in time).
    reduction
        Function taking as input a ``np.ndarray`` and returning a scalar value. This function is used to aggregate
        the metrics of different components in case of multivariate ``TimeSeries`` instances.
    inter_reduction
        Function taking as input a ``np.ndarray`` and returning either a scalar value or a ``np.ndarray``.
        This function can be used to aggregate the metrics of different series in case the metric is evaluated on a
        ``Sequence[TimeSeries]``. Defaults to the identity function, which returns the pairwise metrics for each pair
        of ``TimeSeries`` received in input. Example: ``inter_reduction=np.mean``, will return the average of the
        pairwise metrics.
    n_jobs
        The number of jobs to run in parallel. Parallel jobs are created only when a ``Sequence[TimeSeries]`` is
        passed as input, parallelising operations regarding different ``TimeSeries``. Defaults to `1`
        (sequential). Setting the parameter to `-1` means using all the available processors.
    verbose
        Optionally, whether to print operations progress

    Returns
    -------
    float
        The quantile loss metric

    Raises
    ------
    ValueError
        If `pred_series` is not stochastic, or if `tau` is outside of [0, 1].
    """

    raise_if_not(
        pred_series.is_stochastic,
        "quantile (pinball) loss should only be computed for stochastic predicted TimeSeries.",
    )
    # Outside of [0, 1] the two branches of the pinball loss no longer bound it below by zero,
    # so the result would be meaningless (possibly negative) instead of an error.
    raise_if_not(
        0.0 <= tau <= 1.0,
        f"tau must be a quantile in the range [0, 1], received: {tau}.",
    )

    y, y_hat = _get_values_or_raise(
        actual_series,
        pred_series,
        intersect,
        stochastic_quantile=None,
        remove_nan_union=True,
    )

    ts_length, _, sample_size = y_hat.shape
    # Give `y` a trailing singleton sample axis; broadcasting against `y_hat` then compares each
    # actual value with every sample without materialising `sample_size` copies of `y`.
    y = y.reshape(ts_length, -1, 1)
    y_hat = y_hat.reshape(ts_length, -1, sample_size)

    errors = y - y_hat
    # Pinball loss: under-predictions (errors > 0) are weighted by tau,
    # over-predictions (errors < 0) by (1 - tau).
    losses = np.maximum((tau - 1) * errors, tau * errors)
    return losses.mean()
30 changes: 30 additions & 0 deletions darts/tests/metrics/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,36 @@ def test_rho_risk(self):
self.assertAlmostEqual(metrics.rho_risk(s1, s12_stochastic, rho=0.0), 0.0)
self.assertAlmostEqual(metrics.rho_risk(s2, s12_stochastic, rho=1.0), 0.0)

def test_quantile_loss(self):
    """Check `metrics.quantile_loss` on deterministic, perfect and two-sample stochastic predictions."""
    loss_fn = metrics.quantile_loss

    # a deterministic prediction must be rejected
    with self.assertRaises(ValueError):
        loss_fn(self.series1, self.series1)

    # shared univariate / multivariate / multi-series / NaN helper checks
    self.helper_test_multivariate_duplication_equality(loss_fn, is_stochastic=True)
    self.helper_test_multiple_ts_duplication_equality(loss_fn, is_stochastic=True)
    self.helper_test_nan(loss_fn, is_stochastic=True)

    # perfect predictions -> loss = 0, whatever the quantile
    for tau in (0.25, 0.5):
        perfect_loss = loss_fn(self.series1, self.series11_stochastic, tau=tau)
        self.assertAlmostEqual(perfect_loss, 0.0)

    # a stochastic series built from two samples (the series and its double):
    # the loss against the lower sample vanishes at tau=1, against the upper sample at tau=0
    lower = self.series1
    upper = self.series1 * 2
    stacked_samples = np.stack([lower.values(), upper.values()], axis=2)
    two_sample_pred = TimeSeries.from_times_and_values(
        lower.time_index, stacked_samples
    )
    self.assertAlmostEqual(loss_fn(lower, two_sample_pred, tau=1.0), 0.0)
    self.assertAlmostEqual(loss_fn(upper, two_sample_pred, tau=0.0), 0.0)

def test_metrics_arguments(self):
series00 = self.series0.stack(self.series0)
series11 = self.series1.stack(self.series1)
Expand Down

0 comments on commit a6fd378

Please sign in to comment.