Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add prophet regressor #1704

Merged
merged 38 commits into from
Jan 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
85cb500
Add prophet model family
jeremyliweishih Jan 13, 2021
a80347e
Add prophet component
jeremyliweishih Jan 13, 2021
5aa3fd0
Fix prophet component and add unit tests
jeremyliweishih Jan 19, 2021
f549fa3
Merge branch 'main' of github.com:alteryx/evalml into js_1499_prophet
jeremyliweishih Jan 19, 2021
14cae58
RL
jeremyliweishih Jan 19, 2021
709fe4b
fix merge issue
jeremyliweishih Jan 19, 2021
292aa38
fix another merge issue
jeremyliweishih Jan 19, 2021
921f76e
last merge issue
jeremyliweishih Jan 19, 2021
284193c
move Prophet out of automl
jeremyliweishih Jan 19, 2021
5048cc8
fix unit tests
jeremyliweishih Jan 19, 2021
6fc5982
fix latest dependencies test
jeremyliweishih Jan 19, 2021
81770e2
fix latest_dependency_version.txt'
jeremyliweishih Jan 19, 2021
f1a8cc3
add windows installation in CI
jeremyliweishih Jan 20, 2021
2f208ad
Merge branch 'main' of github.com:alteryx/evalml into js_1499_prophet
jeremyliweishih Jan 20, 2021
871de4a
add more coverage
jeremyliweishih Jan 20, 2021
9e0a618
Add X=None case and add date_column parameter
jeremyliweishih Jan 21, 2021
c932cae
Add test case for init other params
jeremyliweishih Jan 21, 2021
16755ad
Add more test cases
jeremyliweishih Jan 21, 2021
dc22b2a
remove test exceptions
jeremyliweishih Jan 21, 2021
d0ca8b1
Merge branch 'main' into js_1499_prophet
jeremyliweishih Jan 21, 2021
1e06c27
add to API reference
jeremyliweishih Jan 21, 2021
bb739e5
Merge branch 'main' of github.com:alteryx/evalml into js_1499_prophet
jeremyliweishih Jan 21, 2021
3d041a2
Merge branch 'js_1499_prophet' of github.com:alteryx/evalml into js_1…
jeremyliweishih Jan 21, 2021
a8a13d6
install fbprophet through conda for windows tests
jeremyliweishih Jan 22, 2021
e02fe13
directly install fbprophet in conda
jeremyliweishih Jan 22, 2021
d3a8716
try other combos
jeremyliweishih Jan 22, 2021
e182bbb
try with compiler
jeremyliweishih Jan 22, 2021
f166e7c
ignore pystan and fbprohpet in pip
jeremyliweishih Jan 25, 2021
c646f26
Merge branch 'main' into js_1499_prophet
jeremyliweishih Jan 25, 2021
2385295
add rest of dependencies
jeremyliweishih Jan 25, 2021
83085c0
Merge branch 'js_1499_prophet' of github.com:alteryx/evalml into js_1…
jeremyliweishih Jan 25, 2021
333bbfa
don't install dependencies
jeremyliweishih Jan 25, 2021
841be47
Merge branch 'main' of github.com:alteryx/evalml into js_1499_prophet
jeremyliweishih Jan 28, 2021
70ce64f
install pystan using conda and prophet using pip
jeremyliweishih Jan 28, 2021
d723488
revert to old installation
jeremyliweishih Jan 29, 2021
c64b500
Merge branch 'main' of github.com:alteryx/evalml into js_1499_prophet
jeremyliweishih Jan 29, 2021
de3629c
Merge branch 'main' into js_1499_prophet
jeremyliweishih Jan 29, 2021
eb01777
Merge branch 'main' into js_1499_prophet
jeremyliweishih Jan 29, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,15 @@ jobs:
conda config --add channels conda-forge
conda activate curr_py
conda install numba -q -y
- run:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need this for windows installation. In the PR to add prophet to automl we should also include installation instructions in our docs.

name: Install cython and pystan (for prophet)
command: |
C:\Users\circleci\Miniconda3\shell\condabin\conda-hook.ps1
conda config --add channels conda-forge
conda activate curr_py
conda install libpython m2w64-toolchain -c msys2
conda install numpy cython -c conda-forge
python -m pip install pystan
jeremyliweishih marked this conversation as resolved.
Show resolved Hide resolved
- run:
name: Install EvalML
command: |
Expand Down
1 change: 1 addition & 0 deletions docs/source/api_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,7 @@ Regressors are components that output a predicted target value.
StackedEnsembleRegressor
DecisionTreeRegressor
LightGBMRegressor
ProphetRegressor
SVMRegressor

.. currentmodule:: evalml.model_understanding
Expand Down
1 change: 1 addition & 0 deletions docs/source/release_notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ Release Notes
* Enhanced ``DelayedFeaturesTransformer`` to encode categorical features and targets before delaying them :pr:`1691`
* Added 2-way dependence plots. :pr:`1690`
* Added ability to directly iterate through components within Pipelines :pr:`1583`
* Added Facebook's Prophet as a time series regressor :pr:`1704`
* Fixes
* Fixed inconsistent attributes and added Exceptions to docs :pr:`1673`
* Fixed ``TargetLeakageDataCheck`` to use Woodwork ``mutual_information`` rather than using Pandas' Pearson Correlation :pr:`1616`
Expand Down
4 changes: 4 additions & 0 deletions evalml/model_family/model_family.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ class ModelFamily(Enum):
BASELINE = 'baseline'
"""Baseline model family."""

PROPHET = 'prophet'
"""Prophet model family."""

NONE = 'none'
"""None"""

Expand All @@ -52,6 +55,7 @@ def __str__(self):
ModelFamily.DECISION_TREE.name: "Decision Tree",
ModelFamily.BASELINE.name: "Baseline",
ModelFamily.ENSEMBLE.name: "Ensemble",
ModelFamily.PROPHET.name: "Prophet",
ModelFamily.NONE.name: "None"}
return model_family_dict[self.name]

Expand Down
1 change: 1 addition & 0 deletions evalml/pipelines/components/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
DecisionTreeRegressor,
TimeSeriesBaselineEstimator,
KNeighborsClassifier,
ProphetRegressor,
SVMClassifier,
SVMRegressor
)
Expand Down
1 change: 1 addition & 0 deletions evalml/pipelines/components/estimators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,5 @@
BaselineRegressor,
TimeSeriesBaselineEstimator,
DecisionTreeRegressor,
ProphetRegressor,
SVMRegressor)
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@
from .baseline_regressor import BaselineRegressor
from .decision_tree_regressor import DecisionTreeRegressor
from .time_series_baseline_estimator import TimeSeriesBaselineEstimator
from .prophet_regressor import ProphetRegressor
from .svm_regressor import SVMRegressor
116 changes: 116 additions & 0 deletions evalml/pipelines/components/estimators/regressors/prophet_regressor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import numpy as np
import pandas as pd
from skopt.space import Real

from evalml.model_family import ModelFamily
from evalml.pipelines.components.estimators import Estimator
from evalml.problem_types import ProblemTypes
from evalml.utils import SEED_BOUNDS, import_or_raise
from evalml.utils.gen_utils import (
_convert_to_woodwork_structure,
_convert_woodwork_types_wrapper,
suppress_stdout_stderr
)


class ProphetRegressor(Estimator):
"""
Prophet is a procedure for forecasting time series data based on an additive model where non-linear trends are fit with yearly, weekly, and daily seasonality, plus holiday effects.
It works best with time series that have strong seasonal effects and several seasons of historical data. Prophet is robust to missing data and shifts in the trend, and typically handles outliers well.

More information here: https://facebook.github.io/prophet/

"""
name = "Prophet Regressor"
hyperparameter_ranges = {
"changepoint_prior_scale": Real(0.001, 0.5),
"seasonality_prior_scale": Real(0.01, 10),
"holidays_prior_scale": Real(0.01, 10),
"seasonality_mode": ['additive', 'multiplicative'],
}
model_family = ModelFamily.PROPHET
supported_problem_types = [ProblemTypes.TIME_SERIES_REGRESSION]

SEED_MIN = 0
SEED_MAX = SEED_BOUNDS.max_bound

def __init__(self, date_column='ds', changepoint_prior_scale=0.05, seasonality_prior_scale=10, holidays_prior_scale=10, seasonality_mode="additive",
random_state=0, **kwargs):
self.date_column = date_column

parameters = {'changepoint_prior_scale': changepoint_prior_scale,
"seasonality_prior_scale": seasonality_prior_scale,
"holidays_prior_scale": holidays_prior_scale,
"seasonality_mode": seasonality_mode}

parameters.update(kwargs)

p_error_msg = "prophet is not installed. Please install using `pip install pystan` and `pip install fbprophet`."
prophet = import_or_raise("fbprophet", error_msg=p_error_msg)

prophet_regressor = prophet.Prophet(**parameters)
super().__init__(parameters=parameters,
component_obj=prophet_regressor,
random_state=random_state)

@staticmethod
def build_prophet_df(X, y=None, date_column='ds'):
if X is not None:
X = X.copy(deep=True)
if y is not None:
y = y.copy(deep=True)

if date_column in X.columns:
date_col = X[date_column]
elif isinstance(X.index, pd.DatetimeIndex):
date_col = X.reset_index()
date_col = date_col['index']
elif isinstance(y.index, pd.DatetimeIndex):
date_col = y.reset_index()
date_col = date_col['index']
else:
msg = "Prophet estimator requires input data X to have a datetime column specified by the 'date_column' parameter. If not it will look for the datetime column in the index of X or y."
raise ValueError(msg)

date_col = date_col.rename('ds')
prophet_df = date_col.to_frame()
if y is not None:
y.index = prophet_df.index
prophet_df['y'] = y
return prophet_df

def fit(self, X, y=None):
if X is None:
X = pd.DataFrame()

X = _convert_to_woodwork_structure(X)
X = _convert_woodwork_types_wrapper(X.to_dataframe())

y = _convert_to_woodwork_structure(y)
y = _convert_woodwork_types_wrapper(y.to_series())

prophet_df = ProphetRegressor.build_prophet_df(X=X, y=y, date_column=self.date_column)

with suppress_stdout_stderr():
self._component_obj.fit(prophet_df)
return self

def predict(self, X, y=None):
if X is None:
X = pd.DataFrame()

X = _convert_to_woodwork_structure(X)
X = _convert_woodwork_types_wrapper(X.to_dataframe())

prophet_df = ProphetRegressor.build_prophet_df(X=X, y=y, date_column=self.date_column)

with suppress_stdout_stderr():
y_pred = self._component_obj.predict(prophet_df)['yhat']
return y_pred

@property
def feature_importance(self):
"""
Returns array of 0's with len(1) as feature_importance is not defined for Prophet regressor.
"""
return np.zeros(1)
28 changes: 23 additions & 5 deletions evalml/tests/component_tests/test_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
LogisticRegressionClassifier,
OneHotEncoder,
PerColumnImputer,
ProphetRegressor,
RandomForestClassifier,
RandomForestRegressor,
RFClassifierSelectFromModel,
Expand Down Expand Up @@ -228,6 +229,11 @@ def test_describe_component():
'min_child_samples': 20, 'n_jobs': -1, 'bagging_fraction': 0.9, 'bagging_freq': 0}}
except ImportError:
pass
try:
prophet_regressor = ProphetRegressor()
assert prophet_regressor.describe(return_dict=True) == {'name': 'Prophet Regressor', 'parameters': {'changepoint_prior_scale': 0.05, 'holidays_prior_scale': 10, 'seasonality_mode': 'additive', 'seasonality_prior_scale': 10}}
except ImportError:
pass


def test_missing_attributes(X_y_binary):
Expand Down Expand Up @@ -706,14 +712,18 @@ def test_all_transformers_check_fit(X_y_binary):
component.transform(X)


def test_all_estimators_check_fit(X_y_binary, test_estimator_needs_fitting_false, helper_functions):
X, y = X_y_binary
def test_all_estimators_check_fit(X_y_binary, ts_data, test_estimator_needs_fitting_false, helper_functions):
estimators_to_check = [estimator for estimator in _all_estimators() if estimator not in [StackedEnsembleClassifier, StackedEnsembleRegressor, TimeSeriesBaselineEstimator]] + [test_estimator_needs_fitting_false]
for component_class in estimators_to_check:
if not component_class.needs_fitting:
continue

component = helper_functions.safe_init_component_with_njobs_1(component_class)
if component.supported_problem_types == [ProblemTypes.TIME_SERIES_REGRESSION]:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Prophet does not follow sklearn's estimator API (and this is our first estimator that only does time series problem types) and tests in test_components.py break so I've added some work arounds here for now. Happy to discuss any of these changes.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! I think it's fine to have this code outside a util function since it's only used twice and only in our tests.

X, y = ts_data
else:
X, y = X_y_binary

with pytest.raises(ComponentNotYetFittedError, match=f'You must fit {component_class.__name__}'):
component.predict(X)

Expand Down Expand Up @@ -757,8 +767,7 @@ def test_no_fitting_required_components(X_y_binary, test_estimator_needs_fitting
component.transform(X, y)


def test_serialization(X_y_binary, tmpdir, helper_functions):
X, y = X_y_binary
def test_serialization(X_y_binary, ts_data, tmpdir, helper_functions):
jeremyliweishih marked this conversation as resolved.
Show resolved Hide resolved
path = os.path.join(str(tmpdir), 'component.pkl')
for component_class in all_components():
print('Testing serialization of component {}'.format(component_class.name))
Expand All @@ -769,6 +778,12 @@ def test_serialization(X_y_binary, tmpdir, helper_functions):
component = component_class(input_pipelines=[make_pipeline_from_components([RandomForestClassifier()], ProblemTypes.BINARY)], n_jobs=1)
elif (component_class == StackedEnsembleRegressor):
component = component_class(input_pipelines=[make_pipeline_from_components([RandomForestRegressor()], ProblemTypes.REGRESSION)], n_jobs=1)

if isinstance(component, Estimator) and component.supported_problem_types == [ProblemTypes.TIME_SERIES_REGRESSION]:
X, y = ts_data
else:
X, y = X_y_binary

component.fit(X, y)

for pickle_protocol in range(cloudpickle.DEFAULT_PROTOCOL + 1):
Expand Down Expand Up @@ -812,7 +827,10 @@ def test_estimators_accept_all_kwargs(estimator_class,
if estimator_class.model_family == ModelFamily.ENSEMBLE:
params = estimator.parameters
else:
params = estimator._component_obj.get_params()
try:
jeremyliweishih marked this conversation as resolved.
Show resolved Hide resolved
params = estimator._component_obj.get_params()
except AttributeError:
pytest.skip('estimator does not have `get_params()` method.')
if estimator_class.model_family == ModelFamily.CATBOOST:
# Deleting because we call it random_state in our api
del params["random_seed"]
Expand Down
119 changes: 119 additions & 0 deletions evalml/tests/component_tests/test_prophet_regressor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import numpy as np
import pandas as pd
import pytest
jeremyliweishih marked this conversation as resolved.
Show resolved Hide resolved
from pytest import importorskip

from evalml.model_family import ModelFamily
from evalml.pipelines.components import ProphetRegressor
from evalml.problem_types import ProblemTypes
from evalml.utils.gen_utils import suppress_stdout_stderr

prophet = importorskip('fbprophet', reason='Skipping test because xgboost not installed')


def test_model_family():
assert ProphetRegressor.model_family == ModelFamily.PROPHET


def test_problem_types():
assert set(ProphetRegressor.supported_problem_types) == {ProblemTypes.TIME_SERIES_REGRESSION}


def test_init_with_other_params():
clf = ProphetRegressor(daily_seasonality=True, mcmc_samples=5, interval_width=0.8, uncertainty_samples=0)
assert clf.parameters == {'changepoint_prior_scale': 0.05,
'daily_seasonality': True,
'holidays_prior_scale': 10,
'interval_width': 0.8,
'mcmc_samples': 5,
'seasonality_mode': 'additive',
'seasonality_prior_scale': 10,
'uncertainty_samples': 0}


def test_feature_importance(ts_data):
X, y = ts_data
clf = ProphetRegressor()
clf.fit(X, y)
clf.feature_importance == np.zeros(1)


def test_fit_predict_ts_with_X_index(ts_data):
X, y = ts_data
assert isinstance(X.index, pd.DatetimeIndex)

p_clf = prophet.Prophet()
prophet_df = ProphetRegressor.build_prophet_df(X=X, y=y, date_column='ds')

with suppress_stdout_stderr():
p_clf.fit(prophet_df)
y_pred_p = p_clf.predict(prophet_df)['yhat']

clf = ProphetRegressor()
clf.fit(X, y)
y_pred = clf.predict(X)

assert (y_pred == y_pred_p).all()


def test_fit_predict_ts_with_y_index(ts_data):
X, y = ts_data
X = X.reset_index(drop=True)
assert isinstance(y.index, pd.DatetimeIndex)

p_clf = prophet.Prophet()
prophet_df = ProphetRegressor.build_prophet_df(X=X, y=y, date_column='ds')

with suppress_stdout_stderr():
p_clf.fit(prophet_df)
y_pred_p = p_clf.predict(prophet_df)['yhat']

clf = ProphetRegressor()
clf.fit(X, y)
y_pred = clf.predict(X, y)

assert (y_pred == y_pred_p).all()


def test_fit_predict_ts_no_X(ts_data):
X, y = ts_data

p_clf = prophet.Prophet()
prophet_df = ProphetRegressor.build_prophet_df(X=pd.DataFrame(), y=y, date_column='ds')

with suppress_stdout_stderr():
p_clf.fit(prophet_df)
y_pred_p = p_clf.predict(prophet_df)['yhat']

clf = ProphetRegressor()
clf.fit(X=None, y=y)
y_pred = clf.predict(X=None, y=y)

assert (y_pred == y_pred_p).all()


def test_fit_predict_date_col(ts_data):
X, y = ts_data

p_clf = prophet.Prophet()
prophet_df = ProphetRegressor.build_prophet_df(X=X, y=y, date_column='ds')

with suppress_stdout_stderr():
p_clf.fit(prophet_df)
y_pred_p = p_clf.predict(prophet_df)['yhat']

X = X.reset_index()
X = X['index'].rename('ds').to_frame()
clf = ProphetRegressor(date_column='ds')
clf.fit(X, y)
y_pred = clf.predict(X)

assert (y_pred == y_pred_p).all()


def test_fit_predict_no_date_col_or_index(X_y_binary):
X, y = X_y_binary

clf = ProphetRegressor()
with pytest.raises(ValueError):
clf.fit(X, y)
jeremyliweishih marked this conversation as resolved.
Show resolved Hide resolved
2 changes: 1 addition & 1 deletion evalml/tests/component_tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def test_all_components(has_minimal_dependencies):
if has_minimal_dependencies:
assert len(all_components()) == 35
else:
assert len(all_components()) == 42
assert len(all_components()) == 43


def test_handle_component_class_names():
Expand Down
Loading