diff --git a/docs/source/api_reference.rst b/docs/source/api_reference.rst
index deb369f196..6eadc98fc9 100644
--- a/docs/source/api_reference.rst
+++ b/docs/source/api_reference.rst
@@ -247,7 +247,7 @@ Regressors are components that output a predicted target value.
     RandomForestRegressor
     XGBoostRegressor
     BaselineRegressor
-    TimeSeriesBaselineRegressor
+    TimeSeriesBaselineEstimator
     StackedEnsembleRegressor
     DecisionTreeRegressor
     LightGBMRegressor
diff --git a/docs/source/release_notes.rst b/docs/source/release_notes.rst
index 4160966954..ff23115de4 100644
--- a/docs/source/release_notes.rst
+++ b/docs/source/release_notes.rst
@@ -11,6 +11,7 @@ Release Notes
         * Support graphviz 0.16 :pr:`1657`
         * Enhanced time series pipelines to accept empty features :pr:`1651`
         * Added support for list inputs for objectives :pr:`1663`
+        * Added support for ``AutoMLSearch`` to handle time series classification pipelines :pr:`1666`
     * Fixes
         * Fixed thresholding for pipelines in ``AutoMLSearch`` to only threshold binary classification pipelines :pr:`1622` :pr:`1626`
         * Updated ``load_data`` to return Woodwork structures and update default parameter value for ``index`` to ``None`` :pr:`1610`
diff --git a/evalml/automl/automl_search.py b/evalml/automl/automl_search.py
index 18d907443b..27d345dd48 100644
--- a/evalml/automl/automl_search.py
+++ b/evalml/automl/automl_search.py
@@ -41,6 +41,8 @@
     ModeBaselineBinaryPipeline,
     ModeBaselineMulticlassPipeline,
     PipelineBase,
+    TimeSeriesBaselineBinaryPipeline,
+    TimeSeriesBaselineMulticlassPipeline,
     TimeSeriesBaselineRegressionPipeline
 )
 from evalml.pipelines.components.utils import get_estimators
@@ -634,10 +636,14 @@ def _add_baseline_pipelines(self):
         elif self.problem_type == ProblemTypes.REGRESSION:
             baseline = MeanBaselineRegressionPipeline(parameters={})
         else:
+            pipeline_class = {ProblemTypes.TIME_SERIES_REGRESSION: TimeSeriesBaselineRegressionPipeline,
+                              ProblemTypes.TIME_SERIES_MULTICLASS: TimeSeriesBaselineMulticlassPipeline,
+                              ProblemTypes.TIME_SERIES_BINARY: TimeSeriesBaselineBinaryPipeline}[self.problem_type]
             gap = self.problem_configuration['gap']
             max_delay = self.problem_configuration['max_delay']
-            baseline = TimeSeriesBaselineRegressionPipeline(parameters={"pipeline": {"gap": gap, "max_delay": max_delay},
-                                                                        "Time Series Baseline Regressor": {"gap": gap, "max_delay": max_delay}})
+            baseline = pipeline_class(parameters={"pipeline": {"gap": gap, "max_delay": max_delay},
+                                                  "Time Series Baseline Estimator": {"gap": gap, "max_delay": max_delay}})
+
         pipelines = [baseline]
         scores = self._evaluate_pipelines(pipelines, baseline=True)
         if scores == []:
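With the baseline lookup above in place, ``AutoMLSearch`` can run end to end on time series classification problems. A minimal usage sketch (the synthetic data and the ``gap``/``max_delay`` values are illustrative only, not part of this diff):

```python
import numpy as np
import pandas as pd

from evalml import AutoMLSearch

# Illustrative ordered data standing in for a real time series dataset.
X = pd.DataFrame({"feature": np.arange(100)})
y = pd.Series(np.random.randint(0, 2, 100))

# problem_configuration supplies the gap/max_delay keys that
# _add_baseline_pipelines reads when constructing the baseline above.
automl = AutoMLSearch(X_train=X, y_train=y,
                      problem_type="time series binary",
                      problem_configuration={"gap": 1, "max_delay": 2},
                      max_batches=1)
automl.search()
print(automl.rankings)  # the time series baseline pipeline is evaluated first
```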
diff --git a/evalml/automl/utils.py b/evalml/automl/utils.py
index 7dac18b767..10ffb23cb2 100644
--- a/evalml/automl/utils.py
+++ b/evalml/automl/utils.py
@@ -5,7 +5,11 @@
     TimeSeriesSplit,
     TrainingValidationSplit
 )
-from evalml.problem_types import ProblemTypes, handle_problem_types
+from evalml.problem_types import (
+    ProblemTypes,
+    handle_problem_types,
+    is_time_series
+)
 
 _LARGE_DATA_ROW_THRESHOLD = int(1e5)
 
@@ -25,7 +29,9 @@ def get_default_primary_search_objective(problem_type):
     objective_name = {'binary': 'Log Loss Binary',
                       'multiclass': 'Log Loss Multiclass',
                       'regression': 'R2',
-                      'time series regression': 'R2'}[problem_type.value]
+                      'time series regression': 'R2',
+                      'time series binary': 'Log Loss Binary',
+                      'time series multiclass': 'Log Loss Multiclass'}[problem_type.value]
     return get_objective(objective_name, return_instance=True)
 
 
@@ -51,9 +57,7 @@ def make_data_splitter(X, y, problem_type, problem_configuration=None, n_splits=
         data_splitter = KFold(n_splits=n_splits, random_state=random_state, shuffle=shuffle)
     elif problem_type in [ProblemTypes.BINARY, ProblemTypes.MULTICLASS]:
         data_splitter = StratifiedKFold(n_splits=n_splits, random_state=random_state, shuffle=shuffle)
-    elif problem_type in [ProblemTypes.TIME_SERIES_REGRESSION,
-                          ProblemTypes.TIME_SERIES_BINARY,
-                          ProblemTypes.TIME_SERIES_MULTICLASS]:
+    elif is_time_series(problem_type):
         if not problem_configuration:
             raise ValueError("problem_configuration is required for time series problem types")
         data_splitter = TimeSeriesSplit(n_splits=n_splits, gap=problem_configuration.get('gap'),
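A quick sketch of how the two helpers behave after this change. The objective names come straight from the mapping above; the splitter example is hedged on ``make_data_splitter`` accepting a ``ProblemTypes`` member directly, as the comparisons in its body suggest:

```python
import numpy as np

from evalml.automl.utils import (
    get_default_primary_search_objective,
    make_data_splitter
)
from evalml.problem_types import ProblemTypes

# The new entries resolve to the same defaults as their
# non-time-series counterparts.
print(get_default_primary_search_objective("time series binary").name)      # Log Loss Binary
print(get_default_primary_search_objective("time series multiclass").name)  # Log Loss Multiclass

# Any time series problem type now routes through is_time_series() and
# requires a problem_configuration carrying at least gap/max_delay.
X, y = np.arange(40).reshape(20, 2), np.arange(20)
splitter = make_data_splitter(X, y, ProblemTypes.TIME_SERIES_MULTICLASS,
                              problem_configuration={"gap": 0, "max_delay": 1})
print(type(splitter).__name__)  # TimeSeriesSplit
```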
diff --git a/evalml/pipelines/__init__.py b/evalml/pipelines/__init__.py
index 6592110bac..f5c6077da5 100644
--- a/evalml/pipelines/__init__.py
+++ b/evalml/pipelines/__init__.py
@@ -52,5 +52,5 @@
 from .regression import (
     BaselineRegressionPipeline,
     MeanBaselineRegressionPipeline,
-    TimeSeriesBaselineRegressionPipeline
 )
+from .time_series_baselines import TimeSeriesBaselineRegressionPipeline, TimeSeriesBaselineBinaryPipeline, TimeSeriesBaselineMulticlassPipeline
diff --git a/evalml/pipelines/components/__init__.py b/evalml/pipelines/components/__init__.py
index 2f51299950..22c4c26e27 100644
--- a/evalml/pipelines/components/__init__.py
+++ b/evalml/pipelines/components/__init__.py
@@ -19,7 +19,7 @@
     BaselineRegressor,
     DecisionTreeClassifier,
     DecisionTreeRegressor,
-    TimeSeriesBaselineRegressor
+    TimeSeriesBaselineEstimator
 )
 from .transformers import (
     Transformer,
diff --git a/evalml/pipelines/components/estimators/__init__.py b/evalml/pipelines/components/estimators/__init__.py
index 64888b9b93..4812238480 100644
--- a/evalml/pipelines/components/estimators/__init__.py
+++ b/evalml/pipelines/components/estimators/__init__.py
@@ -16,5 +16,5 @@
     ElasticNetRegressor,
     ExtraTreesRegressor,
     BaselineRegressor,
-    TimeSeriesBaselineRegressor,
+    TimeSeriesBaselineEstimator,
     DecisionTreeRegressor)
diff --git a/evalml/pipelines/components/estimators/regressors/__init__.py b/evalml/pipelines/components/estimators/regressors/__init__.py
index fb9760834e..c97c291258 100644
--- a/evalml/pipelines/components/estimators/regressors/__init__.py
+++ b/evalml/pipelines/components/estimators/regressors/__init__.py
@@ -7,4 +7,4 @@
 from .et_regressor import ExtraTreesRegressor
 from .baseline_regressor import BaselineRegressor
 from .decision_tree_regressor import DecisionTreeRegressor
-from .time_series_baseline_regressor import TimeSeriesBaselineRegressor
+from .time_series_baseline_estimator import TimeSeriesBaselineEstimator
diff --git a/evalml/pipelines/components/estimators/regressors/time_series_baseline_regressor.py b/evalml/pipelines/components/estimators/regressors/time_series_baseline_estimator.py
similarity index 63%
rename from evalml/pipelines/components/estimators/regressors/time_series_baseline_regressor.py
rename to evalml/pipelines/components/estimators/regressors/time_series_baseline_estimator.py
index 812560b157..30a8b49c75 100644
--- a/evalml/pipelines/components/estimators/regressors/time_series_baseline_regressor.py
+++ b/evalml/pipelines/components/estimators/regressors/time_series_baseline_estimator.py
@@ -6,23 +6,25 @@
 from evalml.problem_types import ProblemTypes
 from evalml.utils.gen_utils import (
     _convert_to_woodwork_structure,
-    _convert_woodwork_types_wrapper
+    _convert_woodwork_types_wrapper,
+    pad_with_nans
 )
 
 
-class TimeSeriesBaselineRegressor(Estimator):
-    """Time series regressor that predicts using the naive forecasting approach.
+class TimeSeriesBaselineEstimator(Estimator):
+    """Time series estimator that predicts using the naive forecasting approach.
 
-    This is useful as a simple baseline regressor for time series problems
+    This is useful as a simple baseline estimator for time series problems.
     """
-    name = "Time Series Baseline Regressor"
+    name = "Time Series Baseline Estimator"
     hyperparameter_ranges = {}
     model_family = ModelFamily.BASELINE
-    supported_problem_types = [ProblemTypes.TIME_SERIES_REGRESSION]
+    supported_problem_types = [ProblemTypes.TIME_SERIES_REGRESSION, ProblemTypes.TIME_SERIES_BINARY,
+                               ProblemTypes.TIME_SERIES_MULTICLASS]
     predict_uses_y = True
 
     def __init__(self, gap=1, random_state=0, **kwargs):
-        """Baseline time series regressor that predicts using the naive forecasting approach.
+        """Baseline time series estimator that predicts using the naive forecasting approach.
 
         Arguments:
             gap (int): gap between prediction date and target date and must be a positive integer. If gap is 0, target date will be shifted ahead by 1 time period.
@@ -54,7 +56,7 @@ def fit(self, X, y=None):
 
     def predict(self, X, y=None):
         if y is None:
-            raise ValueError("Cannot predict Time Series Baseline Regressor if y is None")
+            raise ValueError("Cannot predict Time Series Baseline Estimator if y is None")
         y = _convert_to_woodwork_structure(y)
         y = _convert_woodwork_types_wrapper(y.to_series())
 
@@ -63,9 +65,21 @@
 
         return y
 
+    def predict_proba(self, X, y=None):
+        if y is None:
+            raise ValueError("Cannot predict Time Series Baseline Estimator if y is None")
+        y = _convert_to_woodwork_structure(y)
+        y = _convert_woodwork_types_wrapper(y.to_series())
+        preds = self.predict(X, y).dropna(axis=0, how='any').astype('int')
+        proba_arr = np.zeros((len(preds), y.max() + 1))
+        proba_arr[np.arange(len(preds)), preds] = 1
+        return pad_with_nans(pd.DataFrame(proba_arr), len(y) - len(preds))
+
     @property
     def feature_importance(self):
-        """Returns importance associated with each feature. Since baseline regressors do not use input features to calculate predictions, returns an array of zeroes.
+        """Returns importance associated with each feature.
+
+        Since baseline estimators do not use input features to calculate predictions, returns an array of zeroes.
 
         Returns:
             np.ndarray (float): an array of zeroes
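In isolation, the renamed estimator behaves as follows. This is a hedged sketch based on the tests further down in this diff: with ``gap=1`` the naive forecast reproduces ``y``, and ``predict_proba`` one-hot encodes those predictions.

```python
import pandas as pd

from evalml.pipelines.components import TimeSeriesBaselineEstimator

y = pd.Series([0, 1, 1, 0, 1])

clf = TimeSeriesBaselineEstimator(gap=1)
clf.fit(X=None, y=y)

# Naive forecast: with gap=1, the prediction for each date is the
# target value delivered alongside it.
print(clf.predict(X=None, y=y).tolist())  # [0, 1, 1, 0, 1]

# One column per class label, a 1 in the predicted class's column per row.
print(clf.predict_proba(X=None, y=y))
```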
diff --git a/evalml/pipelines/regression/__init__.py b/evalml/pipelines/regression/__init__.py
index 9f985a54ba..d35a31913b 100644
--- a/evalml/pipelines/regression/__init__.py
+++ b/evalml/pipelines/regression/__init__.py
@@ -1,2 +1 @@
 from .baseline_regression import BaselineRegressionPipeline, MeanBaselineRegressionPipeline
-from .time_series_baseline_regression import TimeSeriesBaselineRegressionPipeline
diff --git a/evalml/pipelines/regression/time_series_baseline_regression.py b/evalml/pipelines/regression/time_series_baseline_regression.py
deleted file mode 100644
index e92d4083c7..0000000000
--- a/evalml/pipelines/regression/time_series_baseline_regression.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from evalml.pipelines import TimeSeriesRegressionPipeline
-
-
-class TimeSeriesBaselineRegressionPipeline(TimeSeriesRegressionPipeline):
-    """Baseline Pipeline for time series regression problems."""
-    _name = "Time Series Baseline Regression Pipeline"
-    component_graph = ["Time Series Baseline Regressor"]
diff --git a/evalml/pipelines/time_series_baselines.py b/evalml/pipelines/time_series_baselines.py
new file mode 100644
index 0000000000..1614ce77bd
--- /dev/null
+++ b/evalml/pipelines/time_series_baselines.py
@@ -0,0 +1,23 @@
+from evalml.pipelines import (
+    TimeSeriesBinaryClassificationPipeline,
+    TimeSeriesMulticlassClassificationPipeline,
+    TimeSeriesRegressionPipeline
+)
+
+
+class TimeSeriesBaselineRegressionPipeline(TimeSeriesRegressionPipeline):
+    """Baseline Pipeline for time series regression problems."""
+    _name = "Time Series Baseline Regression Pipeline"
+    component_graph = ["Time Series Baseline Estimator"]
+
+
+class TimeSeriesBaselineBinaryPipeline(TimeSeriesBinaryClassificationPipeline):
+    """Baseline Pipeline for time series binary classification problems."""
+    _name = "Time Series Baseline Binary Pipeline"
+    component_graph = ["Time Series Baseline Estimator"]
+
+
+class TimeSeriesBaselineMulticlassPipeline(TimeSeriesMulticlassClassificationPipeline):
+    """Baseline Pipeline for time series multiclass classification problems."""
+    _name = "Time Series Baseline Multiclass Pipeline"
+    component_graph = ["Time Series Baseline Estimator"]
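The three baseline pipelines are constructed like any other evalml pipeline, through a parameters dict. A usage sketch mirroring how ``_add_baseline_pipelines`` wires them up (the data is illustrative):

```python
import pandas as pd

from evalml.pipelines import TimeSeriesBaselineBinaryPipeline

X = pd.DataFrame({"a": [4, 5, 6, 7, 8]})
y = pd.Series([0, 1, 1, 0, 1])

# The same gap/max_delay values go to both the pipeline and the estimator,
# exactly as _add_baseline_pipelines builds the parameters above.
baseline = TimeSeriesBaselineBinaryPipeline(parameters={
    "pipeline": {"gap": 1, "max_delay": 1},
    "Time Series Baseline Estimator": {"gap": 1, "max_delay": 1},
})
baseline.fit(X, y)
print(baseline.predict(X, y).tolist())  # naive forecast of y
```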
diff --git a/evalml/pipelines/time_series_classification_pipelines.py b/evalml/pipelines/time_series_classification_pipelines.py
index 2f035c3db2..ceab7ec699 100644
--- a/evalml/pipelines/time_series_classification_pipelines.py
+++ b/evalml/pipelines/time_series_classification_pipelines.py
@@ -111,7 +111,11 @@ def predict(self, X, y=None, objective=None):
         y = _convert_woodwork_types_wrapper(y.to_series())
         n_features = max(len(y), X.shape[0])
         predictions = self._predict(X, y, objective=objective, pad=False)
-        predictions = pd.Series(self._decode_targets(predictions), name=self.input_target_name)
+
+        # In case gap is 0 and this is a baseline pipeline, we drop the NaNs in the
+        # predictions before decoding them
+        predictions = pd.Series(self._decode_targets(predictions.dropna()), name=self.input_target_name)
+
         return pad_with_nans(predictions, max(0, n_features - predictions.shape[0]))
 
     def predict_proba(self, X, y=None):
diff --git a/evalml/pipelines/utils.py b/evalml/pipelines/utils.py
index 23b460d734..defaa08aad 100644
--- a/evalml/pipelines/utils.py
+++ b/evalml/pipelines/utils.py
@@ -30,7 +30,11 @@
     TextFeaturizer
 )
 from evalml.pipelines.components.utils import all_components, get_estimators
-from evalml.problem_types import ProblemTypes, handle_problem_types
+from evalml.problem_types import (
+    ProblemTypes,
+    handle_problem_types,
+    is_time_series
+)
 from evalml.utils import get_logger
 from evalml.utils.gen_utils import _convert_to_woodwork_structure
 
@@ -67,7 +71,7 @@ def _get_preprocessing_components(X, y, problem_type, text_columns, estimator_cl
     if add_datetime_featurizer:
         pp_components.append(DateTimeFeaturizer)
 
-    if problem_type in [ProblemTypes.TIME_SERIES_REGRESSION]:
+    if is_time_series(problem_type):
         pp_components.append(DelayedFeatureTransformer)
 
     categorical_cols = X.select('category')
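One visible effect of switching to ``is_time_series`` here: ``make_pipeline`` now inserts the ``DelayedFeatureTransformer`` for the time series classification variants too, not just time series regression. A hedged sketch (the estimator choice via ``get_estimators`` is illustrative, and the membership check assumes ``component_graph`` holds component classes, as the tests below compare them):

```python
import numpy as np
import pandas as pd

from evalml.pipelines.components import DelayedFeatureTransformer
from evalml.pipelines.components.utils import get_estimators
from evalml.pipelines.utils import make_pipeline
from evalml.problem_types import ProblemTypes

X = pd.DataFrame({"a": np.arange(20), "b": np.arange(20.0)})
y = pd.Series([0, 1] * 10)

estimator = get_estimators(ProblemTypes.TIME_SERIES_BINARY)[0]
pipeline_class = make_pipeline(X, y, estimator, ProblemTypes.TIME_SERIES_BINARY)
# The delayed-features preprocessing step is now part of the graph.
assert DelayedFeatureTransformer in pipeline_class.component_graph
```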
diff --git a/evalml/tests/automl_tests/test_automl.py b/evalml/tests/automl_tests/test_automl.py
index d4ccf5e6c3..cfd5d949da 100644
--- a/evalml/tests/automl_tests/test_automl.py
+++ b/evalml/tests/automl_tests/test_automl.py
@@ -48,15 +48,11 @@
     BinaryClassificationPipeline,
     Estimator,
     MulticlassClassificationPipeline,
-    RegressionPipeline,
-    TimeSeriesRegressionPipeline
+    RegressionPipeline
 )
 from evalml.pipelines.components.utils import get_estimators
 from evalml.pipelines.utils import make_pipeline
-from evalml.preprocessing.data_splitters import (
-    TimeSeriesSplit,
-    TrainingValidationSplit
-)
+from evalml.preprocessing.data_splitters import TrainingValidationSplit
 from evalml.problem_types import ProblemTypes, handle_problem_types
 from evalml.tuners import NoParamsException, RandomSearchTuner
 from evalml.utils.gen_utils import (
@@ -1976,32 +1972,6 @@ def test_automl_validates_problem_configuration(X_y_binary):
     assert problem_config == {"max_delay": 2, "gap": 3}
 
 
-@patch('evalml.pipelines.TimeSeriesRegressionPipeline.score', return_value={"R2": 0.3})
-@patch('evalml.pipelines.TimeSeriesRegressionPipeline.fit')
-def test_automl_time_series_regression(mock_fit, mock_score, X_y_regression):
-    X, y = X_y_regression
-
-    configuration = {"gap": 0, "max_delay": 0, 'delay_target': False, 'delay_features': True}
-
-    class Pipeline1(TimeSeriesRegressionPipeline):
-        name = "Pipeline 1"
-        component_graph = ["Delayed Feature Transformer", "Random Forest Regressor"]
-
-    class Pipeline2(TimeSeriesRegressionPipeline):
-        name = "Pipeline 2"
-        component_graph = ["Delayed Feature Transformer", "Elastic Net Regressor"]
-
-    automl = AutoMLSearch(X_train=X, y_train=y, problem_type="time series regression", problem_configuration=configuration,
-                          allowed_pipelines=[Pipeline1, Pipeline2], max_batches=2)
-    automl.search()
-    assert isinstance(automl.data_splitter, TimeSeriesSplit)
-    for result in automl.results['pipeline_results'].values():
-        if result["id"] == 0:
-            continue
-        assert result['parameters']['Delayed Feature Transformer'] == configuration
-        assert result['parameters']['pipeline'] == configuration
-
-
 @patch('evalml.objectives.BinaryClassificationObjective.optimize_threshold')
 def test_automl_best_pipeline(mock_optimize, X_y_binary):
     X, y = X_y_binary
@@ -2085,7 +2055,7 @@ def test_timeseries_baseline_init_with_correct_gap_max_delay(mock_fit, mock_scor
 
     # Best pipeline is baseline pipeline because we only run one iteration
     assert automl.best_pipeline.parameters == {"pipeline": {"gap": 6, "max_delay": 3},
-                                               "Time Series Baseline Regressor": {"gap": 6, "max_delay": 3}}
+                                               "Time Series Baseline Estimator": {"gap": 6, "max_delay": 3}}
 
 
 @pytest.mark.parametrize('problem_type', [ProblemTypes.BINARY, ProblemTypes.MULTICLASS,
diff --git a/evalml/tests/automl_tests/test_automl_search_classification.py b/evalml/tests/automl_tests/test_automl_search_classification.py
index 60c85148b1..9bbf8f918d 100644
--- a/evalml/tests/automl_tests/test_automl_search_classification.py
+++ b/evalml/tests/automl_tests/test_automl_search_classification.py
@@ -4,7 +4,7 @@
 import numpy as np
 import pandas as pd
 import pytest
-from sklearn.model_selection import StratifiedKFold, TimeSeriesSplit
+from sklearn.model_selection import StratifiedKFold
 from skopt.space import Categorical
 
 from evalml import AutoMLSearch
@@ -22,10 +22,13 @@
     ModeBaselineBinaryPipeline,
     ModeBaselineMulticlassPipeline,
     MulticlassClassificationPipeline,
-    PipelineBase
+    PipelineBase,
+    TimeSeriesBaselineBinaryPipeline,
+    TimeSeriesBaselineMulticlassPipeline
 )
 from evalml.pipelines.components.utils import get_estimators
 from evalml.pipelines.utils import make_pipeline
+from evalml.preprocessing import TimeSeriesSplit
 from evalml.problem_types import ProblemTypes
 
@@ -77,8 +80,8 @@ def test_data_splitter(X_y_binary):
     assert isinstance(automl.rankings, pd.DataFrame)
     assert len(automl.results['pipeline_results'][0]["cv_data"]) == cv_folds
 
-    automl = AutoMLSearch(X_train=X, y_train=y, problem_type='binary', data_splitter=TimeSeriesSplit(cv_folds), max_iterations=1,
-                          n_jobs=1)
+    automl = AutoMLSearch(X_train=X, y_train=y, problem_type='binary', data_splitter=TimeSeriesSplit(n_splits=cv_folds),
+                          max_iterations=1, n_jobs=1)
     automl.search()
 
     assert isinstance(automl.rankings, pd.DataFrame)
@@ -677,3 +680,37 @@ def test_automl_multiclass_nonlinear_pipeline_search_more_iterations(nonlinear_m
     assert start_iteration_callback.call_args_list[0][0][0] == ModeBaselineMulticlassPipeline
     assert start_iteration_callback.call_args_list[1][0][0] == nonlinear_multiclass_pipeline_class
     assert start_iteration_callback.call_args_list[4][0][0] == nonlinear_multiclass_pipeline_class
+
+
+@pytest.mark.parametrize('problem_type', [ProblemTypes.TIME_SERIES_MULTICLASS, ProblemTypes.TIME_SERIES_BINARY])
+@patch('evalml.pipelines.TimeSeriesMulticlassClassificationPipeline.score')
+@patch('evalml.pipelines.TimeSeriesBinaryClassificationPipeline.score')
+@patch('evalml.pipelines.TimeSeriesMulticlassClassificationPipeline.fit')
+@patch('evalml.pipelines.TimeSeriesBinaryClassificationPipeline.fit')
+def test_automl_supports_time_series_classification(mock_binary_fit, mock_multi_fit, mock_binary_score, mock_multiclass_score,
+                                                    problem_type, X_y_binary, X_y_multi):
+    if problem_type == ProblemTypes.TIME_SERIES_BINARY:
+        X, y = X_y_binary
+        baseline = TimeSeriesBaselineBinaryPipeline
+        mock_binary_score.return_value = {"Log Loss Binary": 0.2}
+        problem_type = 'time series binary'
+    else:
+        X, y = X_y_multi
+        baseline = TimeSeriesBaselineMulticlassPipeline
+        mock_multiclass_score.return_value = {"Log Loss Multiclass": 0.25}
+        problem_type = 'time series multiclass'
+
+    configuration = {"gap": 0, "max_delay": 0, 'delay_target': False, 'delay_features': True}
+
+    automl = AutoMLSearch(X_train=X, y_train=y, problem_type=problem_type,
+                          problem_configuration=configuration,
+                          max_batches=2)
+    automl.search()
+    assert isinstance(automl.data_splitter, TimeSeriesSplit)
+    for result in automl.results['pipeline_results'].values():
+        if result["id"] == 0:
+            assert result['pipeline_class'] == baseline
+            continue
+
+        assert result['parameters']['Delayed Feature Transformer'] == configuration
+        assert result['parameters']['pipeline'] == configuration
diff --git a/evalml/tests/automl_tests/test_automl_search_regression.py b/evalml/tests/automl_tests/test_automl_search_regression.py
index 56c91ad0ca..6a7e3a5ab0 100644
--- a/evalml/tests/automl_tests/test_automl_search_regression.py
+++ b/evalml/tests/automl_tests/test_automl_search_regression.py
@@ -8,9 +8,14 @@
 from evalml.exceptions import ObjectiveNotFoundError
 from evalml.model_family import ModelFamily
 from evalml.objectives import MeanSquaredLogError, RootMeanSquaredLogError
-from evalml.pipelines import MeanBaselineRegressionPipeline, PipelineBase
+from evalml.pipelines import (
+    MeanBaselineRegressionPipeline,
+    PipelineBase,
+    TimeSeriesBaselineRegressionPipeline
+)
 from evalml.pipelines.components.utils import get_estimators
 from evalml.pipelines.utils import make_pipeline
+from evalml.preprocessing import TimeSeriesSplit
 from evalml.problem_types import ProblemTypes
 
@@ -280,3 +285,23 @@ def test_automl_regression_nonlinear_pipeline_search(nonlinear_regression_pipeli
     assert start_iteration_callback.call_count == 2
     assert start_iteration_callback.call_args_list[0][0][0] == MeanBaselineRegressionPipeline
     assert start_iteration_callback.call_args_list[1][0][0] == nonlinear_regression_pipeline_class
+
+
+@patch('evalml.pipelines.TimeSeriesRegressionPipeline.score', return_value={"R2": 0.3})
+@patch('evalml.pipelines.TimeSeriesRegressionPipeline.fit')
+def test_automl_supports_time_series_regression(mock_fit, mock_score, X_y_regression):
+    X, y = X_y_regression
+
+    configuration = {"gap": 0, "max_delay": 0, 'delay_target': False, 'delay_features': True}
+
+    automl = AutoMLSearch(X_train=X, y_train=y, problem_type="time series regression", problem_configuration=configuration,
+                          max_batches=2)
+    automl.search()
+    assert isinstance(automl.data_splitter, TimeSeriesSplit)
+    for result in automl.results['pipeline_results'].values():
+        if result["id"] == 0:
+            assert result['pipeline_class'] == TimeSeriesBaselineRegressionPipeline
+            continue
+
+        assert result['parameters']['Delayed Feature Transformer'] == configuration
+        assert result['parameters']['pipeline'] == configuration
diff --git a/evalml/tests/automl_tests/test_automl_utils.py b/evalml/tests/automl_tests/test_automl_utils.py
index 9ad3075b39..d4135b3356 100644
--- a/evalml/tests/automl_tests/test_automl_utils.py
+++ b/evalml/tests/automl_tests/test_automl_utils.py
@@ -23,6 +23,8 @@ def test_get_default_primary_search_objective():
     assert isinstance(get_default_primary_search_objective(ProblemTypes.MULTICLASS), LogLossMulticlass)
     assert isinstance(get_default_primary_search_objective("regression"), R2)
     assert isinstance(get_default_primary_search_objective(ProblemTypes.REGRESSION), R2)
+    assert isinstance(get_default_primary_search_objective('time series binary'), LogLossBinary)
+    assert isinstance(get_default_primary_search_objective('time series multiclass'), LogLossMulticlass)
     with pytest.raises(KeyError, match="Problem type 'auto' does not exist"):
         get_default_primary_search_objective("auto")
diff --git a/evalml/tests/component_tests/test_components.py b/evalml/tests/component_tests/test_components.py
index 10cc96203f..b90d93be18 100644
--- a/evalml/tests/component_tests/test_components.py
+++ b/evalml/tests/component_tests/test_components.py
@@ -50,7 +50,7 @@
     SimpleImputer,
     StandardScaler,
     TextFeaturizer,
-    TimeSeriesBaselineRegressor,
+    TimeSeriesBaselineEstimator,
     Transformer,
     XGBoostClassifier,
     XGBoostRegressor
@@ -800,7 +800,7 @@ def test_all_transformers_check_fit(X_y_binary):
 
 def test_all_estimators_check_fit(X_y_binary, test_estimator_needs_fitting_false, helper_functions):
     X, y = X_y_binary
-    estimators_to_check = [estimator for estimator in _all_estimators() if estimator not in [StackedEnsembleClassifier, StackedEnsembleRegressor, TimeSeriesBaselineRegressor]] + [test_estimator_needs_fitting_false]
+    estimators_to_check = [estimator for estimator in _all_estimators() if estimator not in [StackedEnsembleClassifier, StackedEnsembleRegressor, TimeSeriesBaselineEstimator]] + [test_estimator_needs_fitting_false]
     for component_class in estimators_to_check:
         if not component_class.needs_fitting:
             continue
diff --git a/evalml/tests/component_tests/test_time_series_baseline_regressor.py b/evalml/tests/component_tests/test_time_series_baseline_estimators.py
similarity index 74%
rename from evalml/tests/component_tests/test_time_series_baseline_regressor.py
rename to evalml/tests/component_tests/test_time_series_baseline_estimators.py
index 7b059a2349..acea718da7 100644
--- a/evalml/tests/component_tests/test_time_series_baseline_regressor.py
+++ b/evalml/tests/component_tests/test_time_series_baseline_estimators.py
@@ -2,31 +2,33 @@
 import pytest
 
 from evalml.model_family import ModelFamily
-from evalml.pipelines.components import TimeSeriesBaselineRegressor
+from evalml.pipelines.components import TimeSeriesBaselineEstimator
 
 
 def test_time_series_baseline_regressor_init():
-    baseline = TimeSeriesBaselineRegressor()
+    baseline = TimeSeriesBaselineEstimator()
     assert baseline.model_family == ModelFamily.BASELINE
 
 
 def test_time_series_baseline_gap_negative():
     with pytest.raises(ValueError, match='gap value must be a positive integer.'):
-        TimeSeriesBaselineRegressor(gap=-1)
+        TimeSeriesBaselineEstimator(gap=-1)
 
 
 def test_time_series_baseline_y_is_None(X_y_regression):
     X, _ = X_y_regression
-    clf = TimeSeriesBaselineRegressor()
+    clf = TimeSeriesBaselineEstimator()
     clf.fit(X, y=None)
     with pytest.raises(ValueError):
         clf.predict(X, y=None)
+    with pytest.raises(ValueError):
+        clf.predict_proba(X, y=None)
 
 
 def test_time_series_baseline(ts_data):
     X, y = ts_data
 
-    clf = TimeSeriesBaselineRegressor(gap=1)
+    clf = TimeSeriesBaselineEstimator(gap=1)
     clf.fit(X, y)
 
     np.testing.assert_allclose(clf.predict(X, y), y)
@@ -38,7 +40,7 @@ def test_time_series_baseline_gap_0(ts_data):
 
     y_true = y.shift(periods=1)
 
-    clf = TimeSeriesBaselineRegressor(gap=0)
+    clf = TimeSeriesBaselineEstimator(gap=0)
     clf.fit(X, y)
 
     np.testing.assert_allclose(clf.predict(X, y), y_true)
@@ -48,7 +50,7 @@ def test_time_series_baseline_no_X(ts_data):
     _, y = ts_data
 
-    clf = TimeSeriesBaselineRegressor()
+    clf = TimeSeriesBaselineEstimator()
     clf.fit(X=None, y=y)
 
     np.testing.assert_allclose(clf.predict(X=None, y=y), y)
diff --git a/evalml/tests/pipeline_tests/regression_pipeline_tests/test_time_series_baseline_regression.py b/evalml/tests/pipeline_tests/regression_pipeline_tests/test_time_series_baseline_regression.py
deleted file mode 100644
index cefa1e85dd..0000000000
--- a/evalml/tests/pipeline_tests/regression_pipeline_tests/test_time_series_baseline_regression.py
+++ /dev/null
@@ -1,53 +0,0 @@
-from unittest.mock import patch
-
-import numpy as np
-import pandas as pd
-import pytest
-
-from evalml.pipelines import TimeSeriesBaselineRegressionPipeline
-
-
-def test_time_series_baseline(ts_data):
-    X, y = ts_data
-
-    clf = TimeSeriesBaselineRegressionPipeline(parameters={"pipeline": {"gap": 1, "max_delay": 1}})
-    clf.fit(X, y)
-
-    np.testing.assert_allclose(clf.predict(X, y), y)
-
-
-def test_time_series_baseline_no_X(ts_data):
-    X, y = ts_data
-
-    clf = TimeSeriesBaselineRegressionPipeline(parameters={"pipeline": {"gap": 1, "max_delay": 1}})
-    clf.fit(X=None, y=y)
-
-    np.testing.assert_allclose(clf.predict(X=None, y=y), y)
-
-
-@pytest.mark.parametrize("only_use_y", [True, False])
-@pytest.mark.parametrize("gap,max_delay", [(0, 0), (1, 0), (0, 2), (1, 1), (1, 2), (2, 2), (7, 3), (2, 4)])
-@patch("evalml.pipelines.RegressionPipeline._score_all_objectives")
-def test_time_series_baseline_score_offset(mock_score, gap, max_delay, only_use_y, ts_data):
-    X, y = ts_data
-
-    expected_target = np.arange(1 + gap, 32)
-    target_index = pd.date_range(f"2020-10-01", f"2020-10-{31-gap}")
-
-    clf = TimeSeriesBaselineRegressionPipeline(parameters={"pipeline": {"gap": gap, "max_delay": max_delay}})
-
-    if only_use_y:
-        clf.fit(None, y)
-        clf.score(X=None, y=y, objectives=[])
-    else:
-        clf.fit(X, y)
-        clf.score(X, y, objectives=[])
-
-    # Verify that NaNs are dropped before passed to objectives
-    _, target, preds = mock_score.call_args[0]
-    assert not target.isna().any()
-    assert not preds.isna().any()
-
-    # Target used for scoring matches expected dates
-    pd.testing.assert_index_equal(target.index, target_index)
-    np.testing.assert_equal(target.values, expected_target)
diff --git a/evalml/tests/pipeline_tests/test_pipelines.py b/evalml/tests/pipeline_tests/test_pipelines.py
index 35bfa71658..81143f9334 100644
--- a/evalml/tests/pipeline_tests/test_pipelines.py
+++ b/evalml/tests/pipeline_tests/test_pipelines.py
@@ -51,8 +51,7 @@
     make_pipeline,
     make_pipeline_from_components
 )
-from evalml.preprocessing.utils import is_time_series
-from evalml.problem_types import ProblemTypes
+from evalml.problem_types import ProblemTypes, is_time_series
 from evalml.utils.gen_utils import check_random_state_equality
 
 
@@ -119,7 +118,7 @@ def test_make_pipeline_all_nan_no_categoricals(input_type, problem_type):
     pipeline = make_pipeline(X, y, estimator_class, problem_type)
     assert isinstance(pipeline, type(pipeline_class))
     assert pipeline.custom_hyperparameters is None
-    if problem_type in [ProblemTypes.TIME_SERIES_REGRESSION]:
+    if is_time_series(problem_type):
         delayed_features = [DelayedFeatureTransformer]
     else:
         delayed_features = []
@@ -153,7 +152,7 @@ def test_make_pipeline(input_type, problem_type):
     pipeline = make_pipeline(X, y, estimator_class, problem_type)
     assert isinstance(pipeline, type(pipeline_class))
     assert pipeline.custom_hyperparameters is None
-    if problem_type in [ProblemTypes.TIME_SERIES_REGRESSION]:
+    if is_time_series(problem_type):
         delayed_features = [DelayedFeatureTransformer]
     else:
         delayed_features = []
@@ -187,7 +186,7 @@ def test_make_pipeline_no_nulls(input_type, problem_type):
     pipeline = make_pipeline(X, y, estimator_class, problem_type)
     assert isinstance(pipeline, type(pipeline_class))
     assert pipeline.custom_hyperparameters is None
-    if problem_type in [ProblemTypes.TIME_SERIES_REGRESSION]:
+    if is_time_series(problem_type):
         delayed_features = [DelayedFeatureTransformer]
     else:
         delayed_features = []
@@ -221,7 +220,7 @@ def test_make_pipeline_no_datetimes(input_type, problem_type):
     pipeline = make_pipeline(X, y, estimator_class, problem_type)
     assert isinstance(pipeline, type(pipeline_class))
     assert pipeline.custom_hyperparameters is None
-    if problem_type in [ProblemTypes.TIME_SERIES_REGRESSION]:
+    if is_time_series(problem_type):
         delayed_features = [DelayedFeatureTransformer]
     else:
         delayed_features = []
@@ -252,7 +251,7 @@ def test_make_pipeline_no_column_names(input_type, problem_type):
     pipeline = make_pipeline(X, y, estimator_class, problem_type)
     assert isinstance(pipeline, type(pipeline_class))
     assert pipeline.custom_hyperparameters is None
-    if problem_type in [ProblemTypes.TIME_SERIES_REGRESSION]:
+    if is_time_series(problem_type):
         delayed_features = [DelayedFeatureTransformer]
     else:
         delayed_features = []
@@ -286,7 +285,7 @@ def test_make_pipeline_text_columns(input_type, problem_type):
     pipeline = make_pipeline(X, y, estimator_class, problem_type, text_columns=['text'])
     assert isinstance(pipeline, type(pipeline_class))
     assert pipeline.custom_hyperparameters is None
-    if problem_type in [ProblemTypes.TIME_SERIES_REGRESSION]:
+    if is_time_series(problem_type):
         delayed_features = [DelayedFeatureTransformer]
     else:
         delayed_features = []
@@ -313,7 +312,7 @@ def test_make_pipeline_numpy_input(problem_type):
     for problem_type in estimator_class.supported_problem_types:
         pipeline = make_pipeline(X, y, estimator_class, problem_type)
         assert isinstance(pipeline, type(pipeline_class))
-        if problem_type in [ProblemTypes.TIME_SERIES_REGRESSION]:
+        if is_time_series(problem_type):
            delayed_features = [DelayedFeatureTransformer]
        else:
            delayed_features = []
@@ -344,7 +343,7 @@ def test_make_pipeline_datetime_no_categorical(input_type, problem_type):
     pipeline = make_pipeline(X, y, estimator_class, problem_type)
     assert isinstance(pipeline, type(pipeline_class))
     assert pipeline.custom_hyperparameters is None
-    if problem_type in [ProblemTypes.TIME_SERIES_REGRESSION]:
+    if is_time_series(problem_type):
         delayed_features = [DelayedFeatureTransformer]
     else:
         delayed_features = []
diff --git a/evalml/tests/pipeline_tests/test_time_series_baseline_pipeline.py b/evalml/tests/pipeline_tests/test_time_series_baseline_pipeline.py
new file mode 100644
index 0000000000..41f1f5f3c2
--- /dev/null
+++ b/evalml/tests/pipeline_tests/test_time_series_baseline_pipeline.py
@@ -0,0 +1,86 @@
+from unittest.mock import patch
+
+import numpy as np
+import pandas as pd
+import pytest
+
+from evalml.pipelines import TimeSeriesBaselineRegressionPipeline
+from evalml.pipelines.time_series_baselines import (
+    TimeSeriesBaselineBinaryPipeline,
+    TimeSeriesBaselineMulticlassPipeline
+)
+
+
+@pytest.mark.parametrize('X_none', [True, False])
+@pytest.mark.parametrize('gap', [0, 1])
+@pytest.mark.parametrize('pipeline_class', [TimeSeriesBaselineRegressionPipeline,
+                                            TimeSeriesBaselineBinaryPipeline, TimeSeriesBaselineMulticlassPipeline])
+def test_time_series_baseline(pipeline_class, gap, X_none, ts_data):
+    X, y = ts_data
+
+    clf = pipeline_class(parameters={"pipeline": {"gap": gap, "max_delay": 1},
+                                     "Time Series Baseline Estimator": {'gap': gap, 'max_delay': 1}})
+    expected_y = y.shift(1) if gap == 0 else y
+    if X_none:
+        X = None
+    clf.fit(X, y)
+    np.testing.assert_allclose(clf.predict(X, y), expected_y)
+
+
+@pytest.mark.parametrize('X_none', [True, False])
+@pytest.mark.parametrize('gap', [0, 1])
+@pytest.mark.parametrize('pipeline_class', [TimeSeriesBaselineBinaryPipeline, TimeSeriesBaselineMulticlassPipeline])
+def test_time_series_baseline_predict_proba(pipeline_class, gap, X_none):
+    X = pd.DataFrame({"a": [4, 5, 6, 7, 8]})
+    y = pd.Series([0, 1, 1, 0, 1])
+    expected_proba = pd.DataFrame({0: [1, 0, 0, 1, 0],
+                                   1: [0, 1, 1, 0, 1]})
+    if pipeline_class == TimeSeriesBaselineMulticlassPipeline:
+        y = pd.Series([0, 1, 2, 2, 1])
+        expected_proba = pd.DataFrame({0: [1, 0, 0, 0, 0],
+                                       1: [0, 1, 0, 0, 1],
+                                       2: [0, 0, 1, 1, 0]})
+    if gap == 0:
+        # Shift to pad the first row with NaNs
+        expected_proba = expected_proba.shift(1)
+
+    clf = pipeline_class(parameters={"pipeline": {"gap": gap, "max_delay": 1},
+                                     "Time Series Baseline Estimator": {'gap': gap, 'max_delay': 1}})
+    if X_none:
+        X = None
+    clf.fit(X, y)
+    np.testing.assert_allclose(clf.predict_proba(X, y), expected_proba)
+
+
+@pytest.mark.parametrize('pipeline_class', [TimeSeriesBaselineRegressionPipeline,
+                                            TimeSeriesBaselineBinaryPipeline, TimeSeriesBaselineMulticlassPipeline])
+@pytest.mark.parametrize("only_use_y", [True, False])
+@pytest.mark.parametrize("gap,max_delay", [(0, 0), (1, 0), (0, 2), (1, 1), (1, 2), (2, 2), (7, 3), (2, 4)])
+@patch("evalml.pipelines.RegressionPipeline._score_all_objectives")
+@patch("evalml.pipelines.ClassificationPipeline._score_all_objectives")
+@patch("evalml.pipelines.ClassificationPipeline._encode_targets", side_effect=lambda y: y)
+def test_time_series_baseline_score_offset(mock_encode, mock_classification_score, mock_regression_score, gap, max_delay,
+                                           only_use_y, pipeline_class, ts_data):
+    X, y = ts_data
+
+    expected_target = pd.Series(np.arange(1 + gap, 32), index=pd.date_range(f"2020-10-01", f"2020-10-{31-gap}"))
+    if gap == 0:
+        expected_target = expected_target[1:]
+    clf = pipeline_class(parameters={"pipeline": {"gap": gap, "max_delay": max_delay},
+                                     "Time Series Baseline Estimator": {"gap": gap, "max_delay": max_delay}})
+    mock_score = mock_regression_score if pipeline_class == TimeSeriesBaselineRegressionPipeline else mock_classification_score
+    if only_use_y:
+        clf.fit(None, y)
+        clf.score(X=None, y=y, objectives=['MCC Binary'])
+    else:
+        clf.fit(X, y)
+        clf.score(X, y, objectives=['MCC Binary'])
+
+    # Verify that NaNs are dropped before passed to objectives
+    _, target, preds = mock_score.call_args[0]
+    assert not target.isna().any()
+    assert not preds.isna().any()
+
+    # Target used for scoring matches expected dates
+    pd.testing.assert_index_equal(target.index, expected_target.index)
+    np.testing.assert_equal(target.values, expected_target.values)
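The ``expected_proba`` construction in the test above can be reproduced with plain pandas, which makes the gap-zero shift easier to see:

```python
import pandas as pd

y = pd.Series([0, 1, 1, 0, 1])

# One-hot encode the naive predictions...
one_hot = pd.get_dummies(y).astype(float)

# ...and with gap == 0 the first observation has nothing earlier to
# predict it, so the whole frame shifts down one row, leaving a NaN
# row on top -- matching expected_proba for gap == 0.
print(one_hot.shift(1))
```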
diff --git a/evalml/tests/pipeline_tests/test_time_series_pipeline.py b/evalml/tests/pipeline_tests/test_time_series_pipeline.py
index 8b4d2f675d..ba56348d12 100644
--- a/evalml/tests/pipeline_tests/test_time_series_pipeline.py
+++ b/evalml/tests/pipeline_tests/test_time_series_pipeline.py
@@ -5,9 +5,7 @@
 import pytest
 import woodwork as ww
 
-from evalml.model_family import ModelFamily
 from evalml.pipelines import (
-    Estimator,
     TimeSeriesBinaryClassificationPipeline,
     TimeSeriesMulticlassClassificationPipeline,
     TimeSeriesRegressionPipeline
@@ -252,30 +250,6 @@ class MyTsPipeline(pipeline_class):
     pd.testing.assert_frame_equal(df_passed_to_predict, answer)
 
 
-class ComponentUsesYInPredict(Estimator):
-    name = "Custom Component"
-    supported_problem_types = [ProblemTypes.TIME_SERIES_BINARY, ProblemTypes.TIME_SERIES_MULTICLASS]
-    model_family = ModelFamily.NONE
-    predict_uses_y = True
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(parameters={},
-                         component_obj=None,
-                         random_state=0)
-
-    def fit(self, X, y):
-        """No op."""
-
-    def predict(self, X, y):
-        return y
-
-    def predict_proba(self, X, y):
-        n_classes = len(y.value_counts())
-        mode_index = 0
-        proba_arr = np.array([[1.0 if i == mode_index else 0.0 for i in range(n_classes)]] * len(y))
-        return pd.DataFrame(proba_arr)
-
-
 @pytest.mark.parametrize("pipeline_class,objectives", [(TimeSeriesBinaryClassificationPipeline, ["MCC Binary"]),
                                                        (TimeSeriesBinaryClassificationPipeline, ["Log Loss Binary"]),
                                                        (TimeSeriesBinaryClassificationPipeline, ["MCC Binary", "Log Loss Binary"]),
@@ -322,36 +296,6 @@ class Pipeline(pipeline_class):
         pl.score(X, y, objectives)
 
 
-@pytest.mark.parametrize("pipeline_class", [TimeSeriesBinaryClassificationPipeline,
-                                            TimeSeriesMulticlassClassificationPipeline])
-@pytest.mark.parametrize("use_none_X", [True, False])
-def test_score_works_with_estimator_uses_y(use_none_X, pipeline_class, X_y_binary, X_y_multi):
-
-    class Pipeline(pipeline_class):
-        component_graph = [ComponentUsesYInPredict]
-
-    pl = Pipeline({"pipeline": {"gap": 1, "max_delay": 2, "delay_features": False}})
-    if pl.problem_type == ProblemTypes.TIME_SERIES_BINARY:
-        X, y = X_y_binary
-        y = pd.Series(y).map(lambda label: "good" if label == 1 else "bad")
-        expected_unique_values = {"good", "bad"}
-        objectives = ['MCC Binary', "Log Loss Binary"]
-    elif pl.problem_type == ProblemTypes.TIME_SERIES_MULTICLASS:
-        X, y = X_y_multi
-        label_map = {0: "good", 1: "bad", 2: "best"}
-        y = pd.Series(y).map(lambda label: label_map[label])
-        expected_unique_values = {"good", "bad", "best"}
-        objectives = ["MCC Multiclass", "Log Loss Multiclass"]
-
-    if use_none_X:
-        X = None
-
-    pl.fit(X, y)
-    # NaNs are expected because of padding due to max_delay
-    assert set(pl.predict(X, y).dropna().unique()) == expected_unique_values
-    pl.score(X, y, objectives)
-
-
 @patch('evalml.pipelines.TimeSeriesClassificationPipeline._decode_targets')
 @patch('evalml.objectives.BinaryClassificationObjective.decision_function')
 @patch('evalml.pipelines.components.Estimator.predict_proba', return_value=pd.DataFrame({0: [1.]}))
diff --git a/evalml/utils/gen_utils.py b/evalml/utils/gen_utils.py
index e5e6de9a0a..07d77a3865 100644
--- a/evalml/utils/gen_utils.py
+++ b/evalml/utils/gen_utils.py
@@ -167,11 +167,12 @@ def _get_subclasses(base_class):
     return subclasses
 
 
-_not_used_in_automl = {'BaselineClassifier', 'BaselineRegressor', 'TimeSeriesBaselineRegressor',
+_not_used_in_automl = {'BaselineClassifier', 'BaselineRegressor', 'TimeSeriesBaselineEstimator',
                        'StackedEnsembleClassifier', 'StackedEnsembleRegressor',
                        'ModeBaselineBinaryPipeline', 'BaselineBinaryPipeline', 'MeanBaselineRegressionPipeline',
                        'BaselineRegressionPipeline', 'ModeBaselineMulticlassPipeline', 'BaselineMulticlassPipeline',
-                       'TimeSeriesBaselineRegressionPipeline'}
+                       'TimeSeriesBaselineRegressionPipeline', 'TimeSeriesBaselineBinaryPipeline',
+                       'TimeSeriesBaselineMulticlassPipeline'}
 
 
 def get_importable_subclasses(base_class, used_in_automl=True):
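The block list above keeps the renamed estimator and the new baseline pipelines out of the AutoML search space while leaving them importable. A quick check, hedged on ``get_importable_subclasses`` returning classes as its name suggests:

```python
from evalml.pipelines import Estimator
from evalml.utils.gen_utils import get_importable_subclasses

names = [cls.__name__ for cls in get_importable_subclasses(Estimator, used_in_automl=True)]
# Baselines are importable but never proposed during search.
assert "TimeSeriesBaselineEstimator" not in names
```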