Skip to content

Commit

Permalink
fix: replacing lambda with named function to make model pickable (uni…
Browse files Browse the repository at this point in the history
…t8co#1594)

* fix: replacing lambda with named function to make model pickable

* fix: issue was also occurring with the exponential de-trending function

* fix: adding typing

* fix: linting

---------

Co-authored-by: Dennis Bader <dennis.bader@gmx.ch>
  • Loading branch information
2 people authored and alexcolpitts96 committed May 31, 2023
1 parent a48286a commit bcdde86
Showing 1 changed file with 22 additions and 9 deletions.
31 changes: 22 additions & 9 deletions darts/models/forecasting/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
----------------------
"""

from typing import Optional
from typing import Callable, Optional

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -238,7 +238,7 @@ def __init__(
pd.Timestamp attributes that are relevant for the seasonality automatically.
trend
If set, indicates what kind of detrending will be applied before performing DFT.
Possible values: 'poly' or 'exp', for polynomial trend, or exponential trend, respectively.
Possible values: 'poly', 'exp' or None, for polynomial trend, exponential trend or no trend, respectively.
trend_poly_degree
The degree of the polynomial that will be used for detrending, if `trend='poly'`.
Expand Down Expand Up @@ -269,6 +269,20 @@ def __str__(self):
+ ")"
)

def _exp_trend(self, x) -> Callable:
"""Helper function, used to make FFT model pickable."""
return np.exp(self.trend_coefficients[1]) * np.exp(
self.trend_coefficients[0] * x
)

def _poly_trend(self, trend_coefficients) -> Callable:
"""Helper function, for consistency with the other trends"""
return np.poly1d(trend_coefficients)

def _null_trend(self, x) -> Callable:
"""Helper function, used to make FFT model pickable."""
return 0

def fit(self, series: TimeSeries):
series = fill_missing_values(series)
super().fit(series)
Expand All @@ -277,19 +291,18 @@ def fit(self, series: TimeSeries):

# determine trend
if self.trend == "poly":
trend_coefficients = np.polyfit(
self.trend_coefficients = np.polyfit(
range(len(series)), series.univariate_values(), self.trend_poly_degree
)
self.trend_function = np.poly1d(trend_coefficients)
self.trend_function = self._poly_trend(self.trend_coefficients)
elif self.trend == "exp":
trend_coefficients = np.polyfit(
self.trend_coefficients = np.polyfit(
range(len(series)), np.log(series.univariate_values()), 1
)
self.trend_function = lambda x: np.exp(trend_coefficients[1]) * np.exp(
trend_coefficients[0] * x
)
self.trend_function = self._exp_trend
else:
self.trend_function = lambda x: 0
self.trend_coefficients = None
self.trend_function = self._null_trend

# subtract trend
detrended_values = series.univariate_values() - self.trend_function(
Expand Down

0 comments on commit bcdde86

Please sign in to comment.