docs for forecasting task #443

Merged · 7 commits · Jul 6, 2022
100 changes: 98 additions & 2 deletions README.md
@@ -4,8 +4,9 @@ Copyright (C) 2021 [AutoML Groups Freiburg and Hannover](http://www.automl.org/

While early AutoML frameworks focused on optimizing traditional ML pipelines and their hyperparameters, another trend in AutoML is to focus on neural architecture search. To bring the best of these two worlds together, we developed **Auto-PyTorch**, which jointly and robustly optimizes the network architecture and the training hyperparameters to enable fully automated deep learning (AutoDL).

Auto-PyTorch is mainly developed to support tabular data (classification, regression).
Auto-PyTorch is mainly developed to support tabular data (classification, regression) and time series data (forecasting).
The newest features in Auto-PyTorch for tabular data are described in the paper ["Auto-PyTorch Tabular: Multi-Fidelity MetaLearning for Efficient and Robust AutoDL"](https://arxiv.org/abs/2006.13799) (see below for bibtex ref).
Details about Auto-PyTorch for multi-horizon time series forecasting tasks can be found in the paper ["Efficient Automated Deep Learning for Time Series Forecasting"](https://arxiv.org/abs/2205.05511) (also see below for bibtex ref).

Also, find the documentation [here](https://automl.github.io/Auto-PyTorch/master).

@@ -27,7 +28,9 @@ In other words, we evaluate the portfolio on a provided data as initial configur
Then the API starts the following procedures:
1. **Validate input data**: Process each data type (e.g. encode categorical data) so that Auto-PyTorch can handle it.
2. **Create dataset**: Create a dataset that can be handled in this API with a choice of cross validation or holdout splits.
3. **Evaluate baselines** *1: Train each algorithm in the predefined pool with a fixed hyperparameter configuration and dummy model from `sklearn.dummy` that represents the worst possible performance.
3. **Evaluate baselines**
* ***Tabular dataset*** *1: Train each algorithm in the predefined pool with a fixed hyperparameter configuration and a dummy model from `sklearn.dummy` that represents the worst possible performance.
* ***Time Series Forecasting dataset***: Train a dummy predictor that repeats the last observed value of each series.
4. **Search by [SMAC](https://github.com/automl/SMAC3)**:\
a. Determine budget and cut-off rules by [Hyperband](https://jmlr.org/papers/volume18/16-558/16-558.pdf)\
b. Sample a pipeline hyperparameter configuration *2 by SMAC\
@@ -50,6 +53,14 @@ pip install autoPyTorch

```

Auto-PyTorch for Time Series Forecasting requires additional dependencies:

```sh

pip install autoPyTorch[forecasting]

```

### Manual Installation

We recommend using Anaconda for developing as follows:
@@ -70,6 +81,20 @@ python setup.py install

```

Similarly, to install all the dependencies for Auto-PyTorch-TimeSeriesForecasting:


```sh

git submodule update --init --recursive

conda create -n auto-pytorch python=3.8
conda activate auto-pytorch
conda install swig
pip install -e .[forecasting]

```

## Examples

In a nutshell:
@@ -105,6 +130,66 @@ score = api.score(y_pred, y_test)
print("Accuracy score", score)
```

For time series forecasting tasks:
```py

from autoPyTorch.api.time_series_forecasting import TimeSeriesForecastingTask

# data and metric imports
from sktime.datasets import load_longley
targets, features = load_longley()

# define the forecasting horizon
forecasting_horizon = 3

# The dataset consumed by APT-TS can either be a list of np.ndarray / pd.DataFrame, where each element of
# the list is one series, or a single pd.DataFrame that stacks all series. In the latter case, the id of the
# series that each time step belongs to can be stored in the DataFrame's index or in a separate column.
# Within each series, the last forecasting_horizon values are used as test targets and everything before
# them as training targets; the values to be forecast normally follow directly after the training set.
y_train = [targets[: -forecasting_horizon]]
y_test = [targets[-forecasting_horizon:]]

# The same split applies to the features. For univariate models, X_train and X_test can be omitted or set to None.
X_train = [features[: -forecasting_horizon]]
# Here X_test holds the 'known future features': features whose values are already known for the forecasting
# horizon. Features that are unknown in the future can be replaced with NaN or zeros (they will not be used by
# our networks). If no feature is known beforehand, X_test can be omitted as well.
known_future_features = list(features.columns)
X_test = [features[-forecasting_horizon:]]

start_times = [targets.index.to_timestamp()[0]]
freq = '1Y'

# initialise Auto-PyTorch api
api = TimeSeriesForecastingTask()

# Search for an ensemble of machine learning algorithms
api.search(
X_train=X_train,
y_train=y_train,
X_test=X_test,
optimize_metric='mean_MAPE_forecasting',
n_prediction_steps=forecasting_horizon,
memory_limit=16 * 1024, # Currently, forecasting models require considerably more memory
freq=freq,
start_times=start_times,
func_eval_time_limit_secs=50,
total_walltime_limit=60,
min_num_test_instances=1000, # proxy validation sets; only applied to tasks with more than 1000 series
known_future_features=known_future_features,
)

# The dataset stored in the API can directly generate the test sequences to predict on
test_sets = api.dataset.generate_test_seqs()

# Evaluate the forecasts on the test targets
y_pred = api.predict(test_sets)
score = api.score(y_pred, y_test)
print("Forecasting score", score)
```
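
The comments in the example above note that, instead of a list with one entry per series, the data may also be passed as a single `pd.DataFrame` with the series id kept in the index or in a dedicated column. A minimal sketch of that layout (the series names and values below are purely illustrative and not taken from the Longley data):

```py
import pandas as pd

# Two short series stacked into one long DataFrame; the series id lives in the index,
# but it could equally well be stored in a regular column.
y_long = pd.DataFrame(
    {"value": [112.0, 118.0, 132.0, 210.0, 215.0, 223.0]},
    index=pd.Index(["series_a"] * 3 + ["series_b"] * 3, name="series_id"),
)
print(y_long)
```

Both layouts describe the same data; the list-of-arrays form used in the example above makes the series boundaries most explicit.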

For more examples, including customising the search space, parallelising the code, etc., check out the `examples` folder

```sh
@@ -163,6 +248,17 @@ Please refer to the branch `TPAMI.2021.3067763` to reproduce the paper *Auto-PyT
}
```

```bibtex
@inproceedings{deng-ecml22,
author = {Difan Deng and Florian Karl and Frank Hutter and Bernd Bischl and Marius Lindauer},
title = {Efficient Automated Deep Learning for Time Series Forecasting},
year = {2022},
booktitle = {Machine Learning and Knowledge Discovery in Databases. Research Track
- European Conference, {ECML} {PKDD} 2022},
url = {https://doi.org/10.48550/arXiv.2205.05511},
}
```

## Contact

Auto-PyTorch is developed by the [AutoML Groups of the University of Freiburg and Hannover](http://www.automl.org/).
3 changes: 1 addition & 2 deletions autoPyTorch/api/time_series_forecasting.py
@@ -7,8 +7,7 @@
from autoPyTorch.api.base_task import BaseTask
from autoPyTorch.automl_common.common.utils.backend import Backend
from autoPyTorch.constants import MAX_WINDOW_SIZE_BASE, TASK_TYPES_TO_STRING, TIMESERIES_FORECASTING
from autoPyTorch.data.time_series_forecasting_validator import \
TimeSeriesForecastingInputValidator
from autoPyTorch.data.time_series_forecasting_validator import TimeSeriesForecastingInputValidator
from autoPyTorch.data.utils import (
DatasetCompressionSpec,
get_dataset_compression_mapping
5 changes: 4 additions & 1 deletion autoPyTorch/constants.py
@@ -54,7 +54,10 @@
CLASSIFICATION_OUTPUTS = [BINARY, MULTICLASS, MULTICLASSMULTIOUTPUT]
REGRESSION_OUTPUTS = [CONTINUOUS, CONTINUOUSMULTIOUTPUT]

# Constants for Forecasting Tasks
ForecastingDependenciesNotInstalledMSG = "Additional dependencies must be installed to work with time series " \
"forecasting tasks! Please run \n pip install autoPyTorch[forecasting] \n to "\
"install the corresponding dependencies!"


# The constant values for time series forecasting comes from
# https://github.com/rakshitha123/TSForecasting/blob/master/experiments/deep_learning_experiments.py
70 changes: 13 additions & 57 deletions autoPyTorch/evaluation/abstract_evaluator.py
@@ -19,14 +19,19 @@
import autoPyTorch.pipeline.image_classification
import autoPyTorch.pipeline.tabular_classification
import autoPyTorch.pipeline.tabular_regression
import autoPyTorch.pipeline.time_series_forecasting
try:
import autoPyTorch.pipeline.time_series_forecasting
forecasting_dependencies_installed = True
except ModuleNotFoundError:
forecasting_dependencies_installed = False
import autoPyTorch.pipeline.traditional_tabular_classification
import autoPyTorch.pipeline.traditional_tabular_regression
from autoPyTorch.automl_common.common.utils.backend import Backend
from autoPyTorch.constants import (
CLASSIFICATION_TASKS,
FORECASTING_BUDGET_TYPE,
FORECASTING_TASKS,
ForecastingDependenciesNotInstalledMSG,
IMAGE_TASKS,
MULTICLASS,
REGRESSION_TASKS,
@@ -38,12 +43,16 @@
BaseDataset,
BaseDatasetPropertiesType
)
from autoPyTorch.datasets.time_series_dataset import TimeSeriesSequence
from autoPyTorch.evaluation.utils import (
DisableFileOutputParameters,
VotingRegressorWrapper,
convert_multioutput_multiclass_to_multilabel
)
try:
from autoPyTorch.evaluation.utils_extra import DummyTimeSeriesForecastingPipeline
forecasting_dependencies_installed = True
except ModuleNotFoundError:
forecasting_dependencies_installed = False
from autoPyTorch.pipeline.base_pipeline import BasePipeline
from autoPyTorch.pipeline.components.training.metrics.base import autoPyTorchMetric
from autoPyTorch.pipeline.components.training.metrics.utils import (
Expand Down Expand Up @@ -314,61 +323,6 @@ def get_default_pipeline_options() -> Dict[str, Any]:
'runtime': 1}


class DummyTimeSeriesForecastingPipeline(DummyClassificationPipeline):
"""
A wrapper class that holds a pipeline for dummy forecasting. For each series, it simply repeats the last element
in the training series


Attributes:
random_state (Optional[Union[int, np.random.RandomState]]):
Object that contains a seed and allows for reproducible results
init_params (Optional[Dict]):
An optional dictionary that is passed to the pipeline's steps. It complies
a similar function as the kwargs
n_prediction_steps (int):
forecasting horizon
"""
def __init__(self, config: Configuration,
random_state: Optional[Union[int, np.random.RandomState]] = None,
init_params: Optional[Dict] = None,
n_prediction_steps: int = 1,
) -> None:
super(DummyTimeSeriesForecastingPipeline, self).__init__(config, random_state, init_params)
self.n_prediction_steps = n_prediction_steps

def fit(self, X: Dict[str, Any], y: Any,
sample_weight: Optional[np.ndarray] = None) -> object:
self.n_prediction_steps = X['dataset_properties']['n_prediction_steps']
y_train = subsampler(X['y_train'], X['train_indices'])
return DummyClassifier.fit(self, np.ones((y_train.shape[0], 1)), y_train, sample_weight)

def _generate_dummy_forecasting(self, X: List[Union[TimeSeriesSequence, np.ndarray]]) -> List:
if isinstance(X[0], TimeSeriesSequence):
X_tail = [x.get_target_values(-1) for x in X]
else:
X_tail = [x[-1] for x in X]
return X_tail

def predict_proba(self, X: Union[np.ndarray, pd.DataFrame],
batch_size: int = 1000) -> np.ndarray:
X_tail = self._generate_dummy_forecasting(X)
return np.tile(X_tail, (1, self.n_prediction_steps)).astype(np.float32).flatten()

def predict(self, X: Union[np.ndarray, pd.DataFrame],
batch_size: int = 1000) -> np.ndarray:
X_tail = np.asarray(self._generate_dummy_forecasting(X))
if X_tail.ndim == 1:
X_tail = np.expand_dims(X_tail, -1)
return np.tile(X_tail, (1, self.n_prediction_steps)).astype(np.float32).flatten()

@staticmethod
def get_default_pipeline_options() -> Dict[str, Any]:
return {'budget_type': 'epochs',
'epochs': 1,
'runtime': 1}


def fit_and_suppress_warnings(logger: PicklableClientLogger, pipeline: BaseEstimator,
X: Dict[str, Any], y: Any
) -> BaseEstimator:
@@ -543,6 +497,8 @@ def __init__(self, backend: Backend,
self.predict_function = self._predict_proba
elif self.task_type in FORECASTING_TASKS:
if isinstance(self.configuration, int):
if not forecasting_dependencies_installed:
raise ModuleNotFoundError(ForecastingDependenciesNotInstalledMSG)
self.pipeline_class = DummyTimeSeriesForecastingPipeline
elif isinstance(self.configuration, str):
raise ValueError("Only tabular classifications tasks "
9 changes: 8 additions & 1 deletion autoPyTorch/evaluation/tae.py
@@ -25,6 +25,7 @@
from autoPyTorch.automl_common.common.utils.backend import Backend
from autoPyTorch.constants import (
FORECASTING_BUDGET_TYPE,
ForecastingDependenciesNotInstalledMSG,
STRING_TO_TASK_TYPES,
TIMESERIES_FORECASTING,
)
@@ -34,7 +35,11 @@
NoResamplingStrategyTypes
)
from autoPyTorch.evaluation.test_evaluator import eval_test_function
from autoPyTorch.evaluation.time_series_forecasting_train_evaluator import forecasting_eval_train_function
try:
from autoPyTorch.evaluation.time_series_forecasting_train_evaluator import forecasting_eval_train_function
forecasting_dependencies_installed = True
except ModuleNotFoundError:
forecasting_dependencies_installed = False
from autoPyTorch.evaluation.train_evaluator import eval_train_function
from autoPyTorch.evaluation.utils import (
DisableFileOutputParameters,
@@ -152,6 +157,8 @@ def __init__(
self.resampling_strategy_args = dm.resampling_strategy_args

if STRING_TO_TASK_TYPES.get(dm.task_type, -1) == TIMESERIES_FORECASTING:
if not forecasting_dependencies_installed:
raise ModuleNotFoundError(ForecastingDependenciesNotInstalledMSG)
eval_function: Callable = forecasting_eval_train_function
if isinstance(self.resampling_strategy, (HoldoutValTypes, CrossValTypes)):
self.output_y_hat_optimization = output_y_hat_optimization
@@ -13,9 +13,9 @@

from autoPyTorch.automl_common.common.utils.backend import Backend
from autoPyTorch.constants import SEASONALITY_MAP
from autoPyTorch.evaluation.abstract_evaluator import DummyTimeSeriesForecastingPipeline
from autoPyTorch.evaluation.train_evaluator import TrainEvaluator
from autoPyTorch.evaluation.utils import DisableFileOutputParameters
from autoPyTorch.evaluation.utils_extra import DummyTimeSeriesForecastingPipeline
from autoPyTorch.pipeline.components.training.metrics.base import autoPyTorchMetric
from autoPyTorch.pipeline.components.training.metrics.metrics import MASE_LOSSES
from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates
72 changes: 72 additions & 0 deletions autoPyTorch/evaluation/utils_extra.py
@@ -0,0 +1,72 @@
# The functions and classes implemented in this module all require extra dependencies.
# They live here so that importing them can easily be wrapped in a try/except block.
from typing import Any, Dict, List, Optional, Union

from ConfigSpace import Configuration

import numpy as np

import pandas as pd

from sklearn.dummy import DummyClassifier

from autoPyTorch.datasets.time_series_dataset import TimeSeriesSequence
from autoPyTorch.utils.common import subsampler


class DummyTimeSeriesForecastingPipeline(DummyClassifier):
"""
A wrapper class that holds a pipeline for dummy forecasting. For each series, it simply repeats the last element
in the training series.


Attributes:
random_state (Optional[Union[int, np.random.RandomState]]):
Object that contains a seed and allows for reproducible results
init_params (Optional[Dict]):
An optional dictionary that is passed to the pipeline's steps. It fulfils
a similar function to kwargs
n_prediction_steps (int):
forecasting horizon
"""
def __init__(self, config: Configuration,
random_state: Optional[Union[int, np.random.RandomState]] = None,
init_params: Optional[Dict] = None,
n_prediction_steps: int = 1,
) -> None:
self.config = config
self.init_params = init_params
self.random_state = random_state
super(DummyTimeSeriesForecastingPipeline, self).__init__(strategy="uniform")
self.n_prediction_steps = n_prediction_steps

def fit(self, X: Dict[str, Any], y: Any,
sample_weight: Optional[np.ndarray] = None) -> object:
self.n_prediction_steps = X['dataset_properties']['n_prediction_steps']
y_train = subsampler(X['y_train'], X['train_indices'])
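# The underlying DummyClassifier is fitted on placeholder features only; the actual forecasts are
# produced by the overridden predict methods below, which simply repeat the last observed target value.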
return DummyClassifier.fit(self, np.ones((y_train.shape[0], 1)), y_train, sample_weight)

def _generate_dummy_forecasting(self, X: List[Union[TimeSeriesSequence, np.ndarray]]) -> List:
if isinstance(X[0], TimeSeriesSequence):
X_tail = [x.get_target_values(-1) for x in X]
else:
X_tail = [x[-1] for x in X]
return X_tail

def predict_proba(self, X: Union[np.ndarray, pd.DataFrame],
batch_size: int = 1000) -> np.ndarray:
X_tail = self._generate_dummy_forecasting(X)
return np.tile(X_tail, (1, self.n_prediction_steps)).astype(np.float32).flatten()

def predict(self, X: Union[np.ndarray, pd.DataFrame],
batch_size: int = 1000) -> np.ndarray:
X_tail = np.asarray(self._generate_dummy_forecasting(X))
if X_tail.ndim == 1:
X_tail = np.expand_dims(X_tail, -1)
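# Repeat each series' last observed value across the whole forecasting horizon (a naive persistence forecast).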
return np.tile(X_tail, (1, self.n_prediction_steps)).astype(np.float32).flatten()

@staticmethod
def get_default_pipeline_options() -> Dict[str, Any]:
return {'budget_type': 'epochs',
'epochs': 1,
'runtime': 1}