Skip to content

Commit

Permalink
Feature: Add support for logistic growth to Prophet (unit8co#1419)
Browse files Browse the repository at this point in the history
* Add logistic growth to prophet

* Add RangeIndex to typehint and docstring

* Fix import order

* Move cap & floor to __init__

* Update darts/models/forecasting/prophet_model.py

* Fix linting & remove duplicat docstring

---------

Co-authored-by: David Kleindienst <kleindienst@ximes.com>
Co-authored-by: Julien Herzen <julien@unit8.co>
Co-authored-by: Julien Herzen <j.herzen@gmail.com>
  • Loading branch information
4 people authored and alexcolpitts96 committed May 31, 2023
1 parent d3c8e71 commit 4927afc
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 1 deletion.
68 changes: 67 additions & 1 deletion darts/models/forecasting/prophet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import logging
import re
from typing import List, Optional, Union
from typing import Callable, List, Optional, Sequence, Union

import numpy as np
import pandas as pd
Expand All @@ -28,6 +28,12 @@ def __init__(
country_holidays: Optional[str] = None,
suppress_stdout_stderror: bool = True,
add_encoders: Optional[dict] = None,
cap: Union[
float, Callable[[Union[pd.DatetimeIndex, pd.RangeIndex]], Sequence[float]]
] = None,
floor: Union[
float, Callable[[Union[pd.DatetimeIndex, pd.RangeIndex]], Sequence[float]]
] = None,
**prophet_kwargs,
):
"""Facebook Prophet
Expand Down Expand Up @@ -92,6 +98,26 @@ def __init__(
'transformer': Scaler()
}
..
cap
Parameter specifiying the maximum carrying capacity when predicting with logistic growth.
Mandatory when `growth = 'logistic'`, otherwise ignored.
See <https://facebook.github.io/prophet/docs/saturating_forecasts.html> for more information
on logistic forecasts.
Can be either
- a number, for constant carrying capacities
- a function taking a DatetimeIndex or RangeIndex and returning a corresponding a Sequence of numbers,
where each number indicates the carrying capacity at this index.
floor
Parameter specifiying the minimum carrying capacity when predicting logistic growth.
Optional when `growth = 'logistic'` (defaults to 0), otherwise ignored.
See <https://facebook.github.io/prophet/docs/saturating_forecasts.html> for more information
on logistic forecasts.
Can be either
- a number, for constant carrying capacities
- a function taking a DatetimeIndex or RangeIndex and returning a corresponding a Sequence of numbers,
where each number indicates the carrying capacity at this index.
prophet_kwargs
Some optional keyword arguments for Prophet.
For information about the parameters see:
Expand Down Expand Up @@ -119,6 +145,26 @@ def __init__(
self._execute_and_suppress_output = execute_and_suppress_output
self._model_builder = prophet.Prophet

self._cap = cap
self._floor = floor
self.is_logistic = (
"growth" in prophet_kwargs and prophet_kwargs["growth"] == "logistic"
)
if not self.is_logistic and (cap is not None or floor is not None):
logger.warning(
"Parameters `cap` and/or `floor` were set although `growth` is not "
"logistic. The set capacities will be ignored."
)
if self.is_logistic:
raise_if(
cap is None,
"Parameter `cap` has to be set when `growth` is logistic",
logger,
)
if floor is None:
# Use 0 as default value
self._floor = 0

def __str__(self):
return "Prophet"

Expand All @@ -131,6 +177,8 @@ def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = Non
fit_df = pd.DataFrame(
data={"ds": series.time_index, "y": series.univariate_values()}
)
if self.is_logistic:
fit_df = self._add_capacities_to_df(fit_df)

self.model = self._model_builder(**self.prophet_kwargs)

Expand Down Expand Up @@ -188,13 +236,31 @@ def _predict(

return self._build_forecast_series(forecast)

def _add_capacities_to_df(self, df: pd.DataFrame) -> pd.DataFrame:
dates = df["ds"]
try:
df["cap"] = self._cap(dates) if callable(self._cap) else self._cap
df["floor"] = self._floor(dates) if callable(self._floor) else self._floor
except ValueError as e:
raise_if(
"does not match length of index" in str(e),
"Callables supplied to `Prophet.set_capacity` as `cap` or `floor` "
"arguments have to return Sequences of identical length as their "
" input argument Sequence!",
logger,
)
raise
return df

def _generate_predict_df(
self, n: int, future_covariates: Optional[TimeSeries] = None
) -> pd.DataFrame:
"""Returns a pandas DataFrame in the format required for Prophet.predict() with `n` dates after the end of
the fitted TimeSeries"""

predict_df = pd.DataFrame(data={"ds": self._generate_new_dates(n)})
if self.is_logistic:
predict_df = self._add_capacities_to_df(predict_df)
if future_covariates is not None:
predict_df = predict_df.merge(
future_covariates.pd_dataframe(),
Expand Down
21 changes: 21 additions & 0 deletions darts/tests/models/forecasting/test_prophet.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,27 @@ def test_prophet_model_default_with_prophet_constructor(self):
model = Prophet()
assert model._model_builder == FBProphet, "model should use Facebook Prophet"

def test_prophet_model_with_logistic_growth(self):
model = Prophet(growth="logistic", cap=1)

# Create timeseries with logistic function
times = tg.generate_index(
pd.Timestamp("20200101"), pd.Timestamp("20210101"), freq="D"
)
values = np.linspace(-10, 10, len(times))
f = np.vectorize(lambda x: 1 / (1 + np.exp(-x)))
values = f(values)
ts = TimeSeries.from_times_and_values(times, values, freq="D")
# split in the middle, so the only way of predicting the plateau correctly
# is using the capacity
train, val = ts.split_after(0.5)

model.fit(train)
pred = model.predict(len(val))

for val_i, pred_i in zip(val.univariate_values(), pred.univariate_values()):
self.assertAlmostEqual(val_i, pred_i, delta=0.1)

def helper_test_freq_coversion(self, test_cases):
for freq, period in test_cases.items():
ts_sine = tg.sine_timeseries(
Expand Down

0 comments on commit 4927afc

Please sign in to comment.